Java Inference API
The ONNX Runtime Java API enables high-performance inference in Java applications and Android development. This guide covers the complete Java API with real examples from the codebase.
Installation
Maven
Copy
Ask AI
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.17.0</version>
</dependency>
<!-- For Android -->
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime-android</artifactId>
<version>1.17.0</version>
</dependency>
Gradle
Copy
Ask AI
implementation 'com.microsoft.onnxruntime:onnxruntime:1.17.0'
// For Android
implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.17.0'
Quick Start
Here’s a minimal Java example:
Copy
Ask AI
import ai.onnxruntime.*;
import java.util.Map;
/**
 * Minimal ONNX Runtime inference example: loads a model, runs a single
 * forward pass on a dummy NCHW float input, and prints one output value.
 */
public class QuickStart {
    public static void main(String[] args) throws OrtException {
        // The environment is a process-wide singleton; it is not closed here.
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        // OrtSession, OnnxTensor, and Result all hold native memory and are
        // AutoCloseable — manage every one of them with try-with-resources.
        // (The original leaked the input tensor by never closing it.)
        try (OrtSession session = env.createSession("model.onnx",
                new OrtSession.SessionOptions())) {
            // Assumes a single-input model; take the first input name.
            String inputName = session.getInputNames().iterator().next();
            // 1x3x224x224 is a typical image-classification input shape.
            float[][][][] inputData = new float[1][3][224][224];
            // Fill with data...
            try (OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData);
                 OrtSession.Result results = session.run(
                         Map.of(inputName, inputTensor))) {
                // Get output
                float[][] output = (float[][]) results.get(0).getValue();
                System.out.println("Output: " + output[0][0]);
            }
        }
    }
}
OrtEnvironment
The environment manages global ONNX Runtime state. Create one per application.
Copy
Ask AI
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtLoggingLevel;
// Get default environment (singleton)
OrtEnvironment env = OrtEnvironment.getEnvironment();
// Create with logging level
OrtEnvironment env = OrtEnvironment.getEnvironment(
OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING
);
// Create with name and logging level
OrtEnvironment env = OrtEnvironment.getEnvironment(
OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO,
"MyApp"
);
Copy
Ask AI
OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE
OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO
OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING
OrtLoggingLevel.ORT_LOGGING_LEVEL_ERROR
OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL
OrtSession
The session loads and runs ONNX models.
Creating a Session
From file path:Copy
Ask AI
import ai.onnxruntime.*;
OrtEnvironment env = OrtEnvironment.getEnvironment();
// Basic usage
try (OrtSession session = env.createSession("model.onnx",
new OrtSession.SessionOptions())) {
// Use session...
}
// With options
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
try (OrtSession session = env.createSession("model.onnx", options)) {
// Use session...
}
Copy
Ask AI
import java.nio.file.Files;
import java.nio.file.Paths;
byte[] modelBytes = Files.readAllBytes(Paths.get("model.onnx"));
try (OrtSession session = env.createSession(modelBytes,
new OrtSession.SessionOptions())) {
// Use session...
}
Copy
Ask AI
import java.nio.ByteBuffer;
ByteBuffer modelBuffer = loadModelToBuffer("model.onnx");
try (OrtSession session = env.createSession(modelBuffer,
new OrtSession.SessionOptions())) {
// Use session...
}
Session Metadata
Copy
Ask AI
// Get input/output counts
long numInputs = session.getNumInputs();
long numOutputs = session.getNumOutputs();
// Get input names (ordered set)
Set<String> inputNames = session.getInputNames();
for (String name : inputNames) {
System.out.println("Input: " + name);
}
// Get output names (ordered set)
Set<String> outputNames = session.getOutputNames();
for (String name : outputNames) {
System.out.println("Output: " + name);
}
// Get input information
Map<String, NodeInfo> inputInfo = session.getInputInfo();
for (Map.Entry<String, NodeInfo> entry : inputInfo.entrySet()) {
System.out.println("Input: " + entry.getKey());
System.out.println(" Info: " + entry.getValue().getInfo());
}
// Get output information
Map<String, NodeInfo> outputInfo = session.getOutputInfo();
for (Map.Entry<String, NodeInfo> entry : outputInfo.entrySet()) {
System.out.println("Output: " + entry.getKey());
TensorInfo tensorInfo = (TensorInfo) entry.getValue().getInfo();
System.out.println(" Shape: " + Arrays.toString(tensorInfo.getShape()));
System.out.println(" Type: " + tensorInfo.type);
}
Copy
Ask AI
OnnxModelMetadata metadata = session.getMetadata();
System.out.println("Producer: " + metadata.getProducerName());
System.out.println("Graph Name: " + metadata.getGraphName());
System.out.println("Domain: " + metadata.getDomain());
System.out.println("Version: " + metadata.getVersion());
System.out.println("Description: " + metadata.getDescription());
// Custom metadata
Map<String, String> customMetadata = metadata.getCustomMetadata();
for (Map.Entry<String, String> entry : customMetadata.entrySet()) {
System.out.println(entry.getKey() + ": " + entry.getValue());
}
Running Inference
Basic inference:
Copy
Ask AI
import ai.onnxruntime.*;
import java.util.Map;
// Prepare input
float[][][][] inputData = new float[1][3][224][224];
// Fill with data...
String inputName = session.getInputNames().iterator().next();
OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData);
// Run inference
try (OrtSession.Result results = session.run(
Map.of(inputName, inputTensor))) {
// Get first output
OnnxValue output = results.get(0);
float[][] predictions = (float[][]) output.getValue();
System.out.println("Predictions: " + Arrays.toString(predictions[0]));
}
Copy
Ask AI
import java.util.HashMap;
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input1", OnnxTensor.createTensor(env, input1Data));
inputs.put("input2", OnnxTensor.createTensor(env, input2Data));
try (OrtSession.Result results = session.run(inputs)) {
// Process results...
}
Copy
Ask AI
import java.util.Set;
// Only compute specific outputs
Set<String> requestedOutputs = Set.of("output1", "output2");
try (OrtSession.Result results = session.run(
inputs, requestedOutputs)) {
// Get outputs by name
OnnxValue output1 = results.get("output1").get();
OnnxValue output2 = results.get("output2").get();
}
Copy
Ask AI
OrtSession.RunOptions runOptions = new OrtSession.RunOptions();
runOptions.setLogSeverityLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING);
runOptions.setLogVerbosityLevel(0);
runOptions.setRunTag("inference_run_1");
try (OrtSession.Result results = session.run(
inputs, runOptions)) {
// Process results...
}
SessionOptions
Configure session behavior:
Copy
Ask AI
import ai.onnxruntime.OrtSession.SessionOptions;
SessionOptions options = new SessionOptions();
// Optimization level
options.setOptimizationLevel(SessionOptions.OptLevel.ALL_OPT);
// Options: NO_OPT, BASIC_OPT, EXTENDED_OPT, ALL_OPT
// Threading
options.setIntraOpNumThreads(4);
options.setInterOpNumThreads(2);
// Execution mode
options.setExecutionMode(SessionOptions.ExecutionMode.SEQUENTIAL);
// Options: SEQUENTIAL, PARALLEL
// Memory optimization
options.setCPUArenaAllocator(true);
options.setMemoryPatternOptimization(true);
// Profiling
options.setProfileOutput("ort_profile.json");
// Logging
options.setLogId("MySession");
options.setLogSeverityLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING);
// Save optimized model
options.setOptimizedModelFilePath("optimized_model.onnx");
// Register custom op library
options.registerCustomOpLibrary("/path/to/custom_ops.so");
OnnxTensor
Create tensors for model inputs. From Java arrays:
Copy
Ask AI
import ai.onnxruntime.OnnxTensor;
// 1D tensor
float[] data1D = {1.0f, 2.0f, 3.0f};
OnnxTensor tensor1D = OnnxTensor.createTensor(env, data1D);
// 2D tensor
float[][] data2D = {{1.0f, 2.0f}, {3.0f, 4.0f}};
OnnxTensor tensor2D = OnnxTensor.createTensor(env, data2D);
// 4D tensor (common for images)
float[][][][] data4D = new float[1][3][224][224];
OnnxTensor tensor4D = OnnxTensor.createTensor(env, data4D);
Copy
Ask AI
import java.nio.FloatBuffer;
import java.nio.ByteBuffer;
FloatBuffer buffer = FloatBuffer.allocate(1 * 3 * 224 * 224);
// Fill buffer...
long[] shape = {1, 3, 224, 224};
OnnxTensor tensor = OnnxTensor.createTensor(
env, buffer, shape, OnnxJavaType.FLOAT);
Copy
Ask AI
String[] strings = {"hello", "world"};
OnnxTensor tensor = OnnxTensor.createTensor(env, strings);
Copy
Ask AI
TensorInfo info = tensor.getInfo();
long[] shape = info.getShape();
OnnxJavaType type = info.type;
long size = info.getElementCount();
System.out.println("Shape: " + Arrays.toString(shape));
System.out.println("Type: " + type);
System.out.println("Elements: " + size);
Execution Providers
Adding Execution Providers
CUDA:Copy
Ask AI
import ai.onnxruntime.providers.OrtCUDAProviderOptions;
SessionOptions options = new SessionOptions();
// Simple CUDA
options.addCUDA(0); // Device ID
// With options
OrtCUDAProviderOptions cudaOptions = new OrtCUDAProviderOptions(0);
cudaOptions.add("gpu_mem_limit", "2147483648"); // 2GB
cudaOptions.add("arena_extend_strategy", "kSameAsRequested");
cudaOptions.add("cudnn_conv_algo_search", "EXHAUSTIVE");
options.addCUDA(cudaOptions);
Copy
Ask AI
import ai.onnxruntime.providers.OrtTensorRTProviderOptions;
OrtTensorRTProviderOptions trtOptions = new OrtTensorRTProviderOptions(0);
trtOptions.add("trt_max_workspace_size", "2147483648");
trtOptions.add("trt_fp16_enable", "1");
options.addTensorRT(trtOptions);
Copy
Ask AI
import ai.onnxruntime.providers.CoreMLFlags;
import java.util.EnumSet;
EnumSet<CoreMLFlags> coremlFlags = EnumSet.of(
CoreMLFlags.COREML_FLAG_ENABLE_ON_SUBGRAPH
);
options.addCoreML(coremlFlags);
Copy
Ask AI
import ai.onnxruntime.providers.NNAPIFlags;
import java.util.EnumSet;
EnumSet<NNAPIFlags> nnapiFlags = EnumSet.of(
NNAPIFlags.USE_FP16
);
options.addNnapi(nnapiFlags);
Copy
Ask AI
import ai.onnxruntime.OnnxRuntime;
Set<String> availableProviders = OnnxRuntime.getAvailableProviders();
System.out.println("Available providers: " + availableProviders);
Complete Example: MNIST Classification
From the ONNX Runtime codebase:
Copy
Ask AI
import ai.onnxruntime.*;
import java.io.*;
import java.util.*;
/**
 * Classifies 28x28 MNIST digit images with an ONNX model.
 *
 * <p>Implements {@link AutoCloseable} so instances can be used with
 * try-with-resources, as {@code main} does below. The original version
 * declared a {@code close()} method but never implemented the interface,
 * so the try-with-resources statement in {@code main} would not compile.
 */
public class MNISTClassifier implements AutoCloseable {
    private final OrtEnvironment env;
    private final OrtSession session;
    private final String inputName;

    /**
     * Loads the model, configures basic graph optimization, and caches the
     * model's first input name.
     *
     * @param modelPath path to the .onnx model file
     * @throws OrtException if the model cannot be loaded
     */
    public MNISTClassifier(String modelPath) throws OrtException {
        // Environment is a process-wide singleton; not owned by this object.
        env = OrtEnvironment.getEnvironment();
        // Configure session
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        options.setOptimizationLevel(
            OrtSession.SessionOptions.OptLevel.BASIC_OPT);
        // Create session
        session = env.createSession(modelPath, options);
        // Assumes a single-input model; take the first input name.
        inputName = session.getInputNames().iterator().next();
        // Print model info
        System.out.println("Model loaded: " + modelPath);
        System.out.println("Inputs:");
        for (NodeInfo info : session.getInputInfo().values()) {
            System.out.println(" " + info);
        }
        System.out.println("Outputs:");
        for (NodeInfo info : session.getOutputInfo().values()) {
            System.out.println(" " + info);
        }
    }

    /**
     * Runs inference on a (1, 1, 28, 28) image batch and returns the index
     * of the highest-scoring class, i.e. the predicted digit 0-9.
     *
     * @param imageData a single-image batch shaped (1, 1, 28, 28)
     * @return the predicted class index
     * @throws OrtException if inference fails
     */
    public int classify(float[][][][] imageData) throws OrtException {
        // Tensor and result are native resources — close both.
        try (OnnxTensor inputTensor = OnnxTensor.createTensor(env, imageData);
             OrtSession.Result results = session.run(
                 Map.of(inputName, inputTensor))) {
            float[][] output = (float[][]) results.get(0).getValue();
            return argMax(output[0]);
        }
    }

    /** Returns the index of the largest value in {@code scores}. */
    private static int argMax(float[] scores) {
        float maxVal = Float.NEGATIVE_INFINITY;
        int maxIdx = 0;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > maxVal) {
                maxVal = scores[i];
                maxIdx = i;
            }
        }
        return maxIdx;
    }

    /** Releases the native session. The shared environment stays open. */
    @Override
    public void close() throws OrtException {
        session.close();
    }

    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            System.out.println("Usage: MNISTClassifier <model-path>");
            return;
        }
        try (MNISTClassifier classifier = new MNISTClassifier(args[0])) {
            // Create test data (1, 1, 28, 28) filled with random pixels.
            float[][][][] testData = new float[1][1][28][28];
            Random rand = new Random();
            for (int i = 0; i < 28; i++) {
                for (int j = 0; j < 28; j++) {
                    testData[0][0][i][j] = rand.nextFloat();
                }
            }
            // Run classification
            int prediction = classifier.classify(testData);
            System.out.println("Predicted digit: " + prediction);
        }
    }
}
Android Example
Copy
Ask AI
import ai.onnxruntime.*;
import android.content.Context;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
/**
 * Runs ONNX model inference on Android, loading the model from the app's
 * assets and enabling NNAPI hardware acceleration.
 *
 * <p>Implements {@link AutoCloseable} so callers can use try-with-resources
 * to release the native session.
 */
public class AndroidInference implements AutoCloseable {
    private final OrtEnvironment env;
    private final OrtSession session;

    /**
     * Loads {@code modelFileName} from the app's assets and creates a
     * session optimized for mobile with NNAPI enabled.
     *
     * @param context       Android context used to access assets
     * @param modelFileName asset name of the .onnx model
     * @throws Exception if the asset cannot be read or the session fails
     */
    public AndroidInference(Context context, String modelFileName)
            throws Exception {
        env = OrtEnvironment.getEnvironment();
        // Read the entire asset. The original used available() plus a single
        // read(), which is not guaranteed to return the whole file (available()
        // is only a hint and read() may return fewer bytes than requested);
        // loop to EOF instead, and close the stream even on failure.
        byte[] modelBytes;
        try (InputStream modelStream = context.getAssets().open(modelFileName)) {
            modelBytes = readAllBytes(modelStream);
        }
        // Configure for mobile
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        options.setOptimizationLevel(
            OrtSession.SessionOptions.OptLevel.ALL_OPT);
        // Add NNAPI for Android acceleration
        options.addNnapi(
            java.util.EnumSet.of(ai.onnxruntime.providers.NNAPIFlags.USE_FP16)
        );
        session = env.createSession(modelBytes, options);
    }

    /** Reads a stream to EOF (InputStream.readAllBytes needs a newer API level). */
    private static byte[] readAllBytes(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buffer = new byte[8192];
        int n;
        while ((n = in.read(buffer)) != -1) {
            out.write(buffer, 0, n);
        }
        return out.toByteArray();
    }

    /**
     * Runs inference on a 4-D float input and returns the first row of the
     * first output. Assumes a single-input model with a 2-D float output.
     */
    public float[] infer(float[][][][] input) throws OrtException {
        String inputName = session.getInputNames().iterator().next();
        try (OnnxTensor tensor = OnnxTensor.createTensor(env, input);
             OrtSession.Result results = session.run(
                 java.util.Map.of(inputName, tensor))) {
            float[][] output = (float[][]) results.get(0).getValue();
            return output[0];
        }
    }

    /** Releases the native session. The shared environment stays open. */
    @Override
    public void close() throws OrtException {
        session.close();
    }
}
Error Handling
Copy
Ask AI
try {
OrtEnvironment env = OrtEnvironment.getEnvironment();
try (OrtSession session = env.createSession("model.onnx",
new OrtSession.SessionOptions())) {
// Run inference...
}
} catch (OrtException e) {
System.err.println("ONNX Runtime error: " + e.getMessage());
e.printStackTrace();
} catch (Exception e) {
System.err.println("Error: " + e.getMessage());
}
Supported Data Types
Copy
Ask AI
OnnxJavaType.FLOAT // float
OnnxJavaType.DOUBLE // double
OnnxJavaType.INT8 // byte
OnnxJavaType.INT16 // short
OnnxJavaType.INT32 // int
OnnxJavaType.INT64 // long
OnnxJavaType.UINT8 // unsigned byte
OnnxJavaType.BOOL // boolean
OnnxJavaType.STRING // String
Best Practices
Use try-with-resources
Always use try-with-resources for OrtSession, OnnxTensor, and Result to ensure proper cleanup.
Reuse OrtEnvironment
Create one OrtEnvironment per application and reuse it for all sessions.
Reuse Sessions
Session creation is expensive. Create once and reuse for multiple inferences.
Choose the Right Provider
Use NNAPI on Android, CUDA on desktop with NVIDIA GPUs for best performance.
Enable Optimization
Set optimization level to ALL_OPT for production deployments.
Next Steps
Model Optimization
Optimize models for production deployment
Execution Providers
Configure hardware acceleration