Skip to main content

Java Inference API

The ONNX Runtime Java API enables high-performance inference in Java applications and Android development. This guide covers the complete Java API with real examples from the codebase.

Installation

Maven

<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.17.0</version>
</dependency>

<!-- For Android -->
<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime-android</artifactId>
    <version>1.17.0</version>
</dependency>

Gradle

implementation 'com.microsoft.onnxruntime:onnxruntime:1.17.0'

// For Android
implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.17.0'

Quick Start

Here’s a minimal Java example:
import ai.onnxruntime.*;
import java.util.Map;

public class QuickStart {
    /**
     * Minimal end-to-end inference example: load a model, build an input
     * tensor, run the session, and print the first output value.
     *
     * @param args unused
     * @throws OrtException if the model cannot be loaded or inference fails
     */
    public static void main(String[] args) throws OrtException {
        // The environment is a process-wide singleton managing native state.
        OrtEnvironment env = OrtEnvironment.getEnvironment();

        // Create session
        try (OrtSession session = env.createSession("model.onnx",
                new OrtSession.SessionOptions())) {

            // Get input name
            String inputName = session.getInputNames().iterator().next();

            // Create input tensor (NCHW image layout assumed by the model)
            float[][][][] inputData = new float[1][3][224][224];
            // Fill with data...

            // OnnxTensor owns native memory — close it with try-with-resources
            // alongside the Result, instead of leaking it.
            try (OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData);
                 OrtSession.Result results = session.run(
                         Map.of(inputName, inputTensor))) {

                // Get output
                float[][] output = (float[][]) results.get(0).getValue();
                System.out.println("Output: " + output[0][0]);
            }
        }
    }
}

OrtEnvironment

The environment manages global ONNX Runtime state. Create one per application.
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtLoggingLevel;

// Get default environment (singleton)
OrtEnvironment env = OrtEnvironment.getEnvironment();

// Create with logging level
OrtEnvironment env = OrtEnvironment.getEnvironment(
    OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING
);

// Create with name and logging level
OrtEnvironment env = OrtEnvironment.getEnvironment(
    OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO,
    "MyApp"
);
Logging levels:
OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE
OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO
OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING
OrtLoggingLevel.ORT_LOGGING_LEVEL_ERROR
OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL

OrtSession

The session loads and runs ONNX models.

Creating a Session

From file path:
import ai.onnxruntime.*;

OrtEnvironment env = OrtEnvironment.getEnvironment();

// Basic usage
try (OrtSession session = env.createSession("model.onnx",
        new OrtSession.SessionOptions())) {
    // Use session...
}

// With options
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);

try (OrtSession session = env.createSession("model.onnx", options)) {
    // Use session...
}
From byte array:
import java.nio.file.Files;
import java.nio.file.Paths;

byte[] modelBytes = Files.readAllBytes(Paths.get("model.onnx"));

try (OrtSession session = env.createSession(modelBytes,
        new OrtSession.SessionOptions())) {
    // Use session...
}
From ByteBuffer:
import java.nio.ByteBuffer;

ByteBuffer modelBuffer = loadModelToBuffer("model.onnx");

try (OrtSession session = env.createSession(modelBuffer,
        new OrtSession.SessionOptions())) {
    // Use session...
}

Session Metadata

// Get input/output counts
long numInputs = session.getNumInputs();
long numOutputs = session.getNumOutputs();

// Get input names (ordered set)
Set<String> inputNames = session.getInputNames();
for (String name : inputNames) {
    System.out.println("Input: " + name);
}

// Get output names (ordered set)
Set<String> outputNames = session.getOutputNames();
for (String name : outputNames) {
    System.out.println("Output: " + name);
}

// Get input information
Map<String, NodeInfo> inputInfo = session.getInputInfo();
for (Map.Entry<String, NodeInfo> entry : inputInfo.entrySet()) {
    System.out.println("Input: " + entry.getKey());
    System.out.println("  Info: " + entry.getValue().getInfo());
}

// Get output information
Map<String, NodeInfo> outputInfo = session.getOutputInfo();
for (Map.Entry<String, NodeInfo> entry : outputInfo.entrySet()) {
    System.out.println("Output: " + entry.getKey());
    TensorInfo tensorInfo = (TensorInfo) entry.getValue().getInfo();
    System.out.println("  Shape: " + Arrays.toString(tensorInfo.getShape()));
    System.out.println("  Type: " + tensorInfo.type);
}
Get model metadata:
OnnxModelMetadata metadata = session.getMetadata();
System.out.println("Producer: " + metadata.getProducerName());
System.out.println("Graph Name: " + metadata.getGraphName());
System.out.println("Domain: " + metadata.getDomain());
System.out.println("Version: " + metadata.getVersion());
System.out.println("Description: " + metadata.getDescription());

// Custom metadata
Map<String, String> customMetadata = metadata.getCustomMetadata();
for (Map.Entry<String, String> entry : customMetadata.entrySet()) {
    System.out.println(entry.getKey() + ": " + entry.getValue());
}

Running Inference

Basic inference:
import ai.onnxruntime.*;
import java.util.Map;

// Prepare input
float[][][][] inputData = new float[1][3][224][224];
// Fill with data...

String inputName = session.getInputNames().iterator().next();
OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData);

// Run inference
try (OrtSession.Result results = session.run(
        Map.of(inputName, inputTensor))) {
    
    // Get first output
    OnnxValue output = results.get(0);
    float[][] predictions = (float[][]) output.getValue();
    
    System.out.println("Predictions: " + Arrays.toString(predictions[0]));
}
Multiple inputs:
import java.util.HashMap;

Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input1", OnnxTensor.createTensor(env, input1Data));
inputs.put("input2", OnnxTensor.createTensor(env, input2Data));

try (OrtSession.Result results = session.run(inputs)) {
    // Process results...
}
Request specific outputs:
import java.util.Set;

// Only compute specific outputs
Set<String> requestedOutputs = Set.of("output1", "output2");

try (OrtSession.Result results = session.run(
        inputs, requestedOutputs)) {
    // Get outputs by name
    OnnxValue output1 = results.get("output1").get();
    OnnxValue output2 = results.get("output2").get();
}
With RunOptions:
OrtSession.RunOptions runOptions = new OrtSession.RunOptions();
runOptions.setLogSeverityLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING);
runOptions.setLogVerbosityLevel(0);
runOptions.setRunTag("inference_run_1");

try (OrtSession.Result results = session.run(
        inputs, runOptions)) {
    // Process results...
}

SessionOptions

Configure session behavior:
import ai.onnxruntime.OrtSession.SessionOptions;

SessionOptions options = new SessionOptions();

// Optimization level
options.setOptimizationLevel(SessionOptions.OptLevel.ALL_OPT);
// Options: NO_OPT, BASIC_OPT, EXTENDED_OPT, ALL_OPT

// Threading
options.setIntraOpNumThreads(4);
options.setInterOpNumThreads(2);

// Execution mode
options.setExecutionMode(SessionOptions.ExecutionMode.SEQUENTIAL);
// Options: SEQUENTIAL, PARALLEL

// Memory optimization
options.setCPUArenaAllocator(true);
options.setMemoryPatternOptimization(true);

// Profiling
options.setProfileOutput("ort_profile.json");

// Logging
options.setLogId("MySession");
options.setLogSeverityLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING);

// Save optimized model
options.setOptimizedModelFilePath("optimized_model.onnx");

// Register custom op library
options.registerCustomOpLibrary("/path/to/custom_ops.so");

OnnxTensor

Create tensors for model inputs.

From Java arrays:
import ai.onnxruntime.OnnxTensor;

// 1D tensor
float[] data1D = {1.0f, 2.0f, 3.0f};
OnnxTensor tensor1D = OnnxTensor.createTensor(env, data1D);

// 2D tensor
float[][] data2D = {{1.0f, 2.0f}, {3.0f, 4.0f}};
OnnxTensor tensor2D = OnnxTensor.createTensor(env, data2D);

// 4D tensor (common for images)
float[][][][] data4D = new float[1][3][224][224];
OnnxTensor tensor4D = OnnxTensor.createTensor(env, data4D);
From ByteBuffer:
import java.nio.FloatBuffer;
import java.nio.ByteBuffer;

FloatBuffer buffer = FloatBuffer.allocate(1 * 3 * 224 * 224);
// Fill buffer...

long[] shape = {1, 3, 224, 224};
OnnxTensor tensor = OnnxTensor.createTensor(
    env, buffer, shape, OnnxJavaType.FLOAT);
From String array:
String[] strings = {"hello", "world"};
OnnxTensor tensor = OnnxTensor.createTensor(env, strings);
Get tensor information:
TensorInfo info = tensor.getInfo();
long[] shape = info.getShape();
OnnxJavaType type = info.type;
long size = info.getElementCount();

System.out.println("Shape: " + Arrays.toString(shape));
System.out.println("Type: " + type);
System.out.println("Elements: " + size);

Execution Providers

Adding Execution Providers

CUDA:
import ai.onnxruntime.providers.OrtCUDAProviderOptions;

SessionOptions options = new SessionOptions();

// Simple CUDA
options.addCUDA(0); // Device ID

// With options
OrtCUDAProviderOptions cudaOptions = new OrtCUDAProviderOptions(0);
cudaOptions.add("gpu_mem_limit", "2147483648"); // 2GB
cudaOptions.add("arena_extend_strategy", "kSameAsRequested");
cudaOptions.add("cudnn_conv_algo_search", "EXHAUSTIVE");

options.addCUDA(cudaOptions);
TensorRT:
import ai.onnxruntime.providers.OrtTensorRTProviderOptions;

OrtTensorRTProviderOptions trtOptions = new OrtTensorRTProviderOptions(0);
trtOptions.add("trt_max_workspace_size", "2147483648");
trtOptions.add("trt_fp16_enable", "1");

options.addTensorRT(trtOptions);
CoreML (macOS/iOS):
import ai.onnxruntime.providers.CoreMLFlags;
import java.util.EnumSet;

EnumSet<CoreMLFlags> coremlFlags = EnumSet.of(
    CoreMLFlags.COREML_FLAG_ENABLE_ON_SUBGRAPH
);
options.addCoreML(coremlFlags);
NNAPI (Android):
import ai.onnxruntime.providers.NNAPIFlags;
import java.util.EnumSet;

EnumSet<NNAPIFlags> nnapiFlags = EnumSet.of(
    NNAPIFlags.USE_FP16
);
options.addNnapi(nnapiFlags);
Check available providers:
import ai.onnxruntime.OnnxRuntime;

Set<String> availableProviders = OnnxRuntime.getAvailableProviders();
System.out.println("Available providers: " + availableProviders);

Complete Example: MNIST Classification

From the ONNX Runtime codebase:
import ai.onnxruntime.*;
import java.io.*;
import java.util.*;

/**
 * Loads an MNIST ONNX model once and classifies 28x28 grayscale images.
 *
 * <p>Implements {@link AutoCloseable} so instances can be used in
 * try-with-resources (as {@code main} below does) — without it, the
 * try-with-resources statement would not compile.
 */
public class MNISTClassifier implements AutoCloseable {
    private final OrtEnvironment env;
    private final OrtSession session;
    private final String inputName;

    /**
     * Creates the inference session and caches the model's input name.
     *
     * @param modelPath path to the MNIST ONNX model file
     * @throws OrtException if the model cannot be loaded
     */
    public MNISTClassifier(String modelPath) throws OrtException {
        // Create environment (process-wide singleton)
        env = OrtEnvironment.getEnvironment();

        // Configure session
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        options.setOptimizationLevel(
            OrtSession.SessionOptions.OptLevel.BASIC_OPT);

        // Create session (expensive — done once, reused for every classify call)
        session = env.createSession(modelPath, options);

        // Get input name
        inputName = session.getInputNames().iterator().next();

        // Print model info
        System.out.println("Model loaded: " + modelPath);
        System.out.println("Inputs:");
        for (NodeInfo info : session.getInputInfo().values()) {
            System.out.println("  " + info);
        }
        System.out.println("Outputs:");
        for (NodeInfo info : session.getOutputInfo().values()) {
            System.out.println("  " + info);
        }
    }

    /**
     * Runs inference on one image and returns the predicted digit.
     *
     * @param imageData input batch shaped (1, 1, 28, 28)
     * @return index of the highest-scoring output class (0-9)
     * @throws OrtException if inference fails
     */
    public int classify(float[][][][] imageData) throws OrtException {
        // Tensor and Result both own native memory — close both.
        try (OnnxTensor inputTensor = OnnxTensor.createTensor(env, imageData);
             OrtSession.Result results = session.run(
                     Map.of(inputName, inputTensor))) {

            // Get predictions
            float[][] output = (float[][]) results.get(0).getValue();

            // Argmax over the class scores
            float maxVal = Float.NEGATIVE_INFINITY;
            int maxIdx = 0;
            for (int i = 0; i < output[0].length; i++) {
                if (output[0][i] > maxVal) {
                    maxVal = output[0][i];
                    maxIdx = i;
                }
            }

            return maxIdx;
        }
    }

    /** Releases the native session. The shared environment is left open. */
    @Override
    public void close() throws OrtException {
        session.close();
    }

    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            System.out.println("Usage: MNISTClassifier <model-path>");
            return;
        }

        try (MNISTClassifier classifier = new MNISTClassifier(args[0])) {
            // Create test data (1, 1, 28, 28)
            float[][][][] testData = new float[1][1][28][28];

            // Fill with sample data
            Random rand = new Random();
            for (int i = 0; i < 28; i++) {
                for (int j = 0; j < 28; j++) {
                    testData[0][0][i][j] = rand.nextFloat();
                }
            }

            // Run classification
            int prediction = classifier.classify(testData);
            System.out.println("Predicted digit: " + prediction);
        }
    }
}

Android Example

import ai.onnxruntime.*;
import android.content.Context;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;

/**
 * Runs ONNX inference on Android, loading the model from the app's assets
 * and accelerating with NNAPI where available.
 *
 * <p>Implements {@link AutoCloseable} so callers can use try-with-resources.
 */
public class AndroidInference implements AutoCloseable {
    private final OrtEnvironment env;
    private final OrtSession session;

    /**
     * Loads the model from assets and creates the inference session.
     *
     * @param context       Android context used to access the asset manager
     * @param modelFileName asset-relative model file name
     * @throws Exception if the asset cannot be read or the session fails
     */
    public AndroidInference(Context context, String modelFileName)
            throws Exception {
        env = OrtEnvironment.getEnvironment();

        // Read the full asset. InputStream.available() is only an estimate
        // and a single read() call need not fill the buffer, so copy in a
        // loop; try-with-resources closes the stream even on failure.
        byte[] modelBytes;
        try (InputStream modelStream = context.getAssets().open(modelFileName);
             ByteArrayOutputStream out = new ByteArrayOutputStream()) {
            byte[] chunk = new byte[8192];
            int read;
            while ((read = modelStream.read(chunk)) != -1) {
                out.write(chunk, 0, read);
            }
            modelBytes = out.toByteArray();
        }

        // Configure for mobile
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        options.setOptimizationLevel(
            OrtSession.SessionOptions.OptLevel.ALL_OPT);

        // Add NNAPI for Android acceleration
        options.addNnapi(
            java.util.EnumSet.of(ai.onnxruntime.providers.NNAPIFlags.USE_FP16)
        );

        session = env.createSession(modelBytes, options);
    }

    /**
     * Runs inference on one input batch.
     *
     * @param input input tensor data (batch, channels, height, width)
     * @return the first row of the model's first output
     * @throws OrtException if inference fails
     */
    public float[] infer(float[][][][] input) throws OrtException {
        String inputName = session.getInputNames().iterator().next();

        try (OnnxTensor tensor = OnnxTensor.createTensor(env, input);
             OrtSession.Result results = session.run(
                 java.util.Map.of(inputName, tensor))) {

            float[][] output = (float[][]) results.get(0).getValue();
            return output[0];
        }
    }

    /** Releases the native session. The shared environment is left open. */
    @Override
    public void close() throws OrtException {
        session.close();
    }
}

Error Handling

try {
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    try (OrtSession session = env.createSession("model.onnx",
            new OrtSession.SessionOptions())) {
        // Run inference...
    }
} catch (OrtException e) {
    System.err.println("ONNX Runtime error: " + e.getMessage());
    e.printStackTrace();
} catch (Exception e) {
    System.err.println("Error: " + e.getMessage());
}

Supported Data Types

OnnxJavaType.FLOAT      // float
OnnxJavaType.DOUBLE     // double
OnnxJavaType.INT8       // byte
OnnxJavaType.INT16      // short
OnnxJavaType.INT32      // int
OnnxJavaType.INT64      // long
OnnxJavaType.UINT8      // unsigned byte
OnnxJavaType.BOOL       // boolean
OnnxJavaType.STRING     // String

Best Practices

Always use try-with-resources for OrtSession, OnnxTensor, and Result to ensure proper cleanup.
Create one OrtEnvironment per application and reuse it for all sessions.
Session creation is expensive. Create once and reuse for multiple inferences.
Use NNAPI on Android, CUDA on desktop with NVIDIA GPUs for best performance.
Set optimization level to ALL_OPT for production deployments.

Next Steps

Model Optimization

Optimize models for production deployment

Execution Providers

Configure hardware acceleration