WebGPU Execution Provider
The WebGPU execution provider enables GPU-accelerated inference directly in web browsers using the modern WebGPU API, providing cross-platform GPU access without plugins or native installations.

Requirements

Browser Support

  • Chrome/Edge: Version 113+ (stable support)
  • Firefox: Version 121+ (experimental, requires flag)
  • Safari: Version 18+ (Technology Preview)
  • Opera: Version 99+

Hardware

  • GPU with Vulkan, Metal, or DirectX 12 support
  • Supported GPUs:
    • NVIDIA (Vulkan/DirectX 12)
    • AMD (Vulkan/DirectX 12)
    • Intel (Vulkan/DirectX 12)
    • Apple Silicon (Metal)
WebGPU is still evolving. Check caniuse.com/webgpu for current browser support.

Installation

npm install onnxruntime-web

Basic Configuration

JavaScript API

import * as ort from 'onnxruntime-web';

// Configure the WASM backend (used as the CPU fallback for ops the
// WebGPU EP does not cover). NOTE(review): these env flags tune WASM,
// not WebGPU itself.
ort.env.wasm.numThreads = 1;
ort.env.wasm.simd = true;

// Create session with WebGPU as the execution provider
const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: ['webgpu']
});

// Run inference
// NOTE(review): assumes the model has one input named 'input' and that
// inputData is a Float32Array of length 512 — adjust for your model.
const feeds = { input: new ort.Tensor('float32', inputData, [1, 512]) };
const results = await session.run(feeds);

TypeScript

import * as ort from 'onnxruntime-web';

interface ModelInputs {
  input_ids: ort.Tensor;
  attention_mask: ort.Tensor;
}

interface ModelOutputs {
  logits: ort.Tensor;
}

/**
 * Runs the model once on WebGPU and returns its `logits` output.
 *
 * @param inputs - Pre-built `input_ids` and `attention_mask` tensors.
 * @returns The model outputs, keyed by output name.
 * @throws If the model does not produce an output named `logits`.
 */
async function runModel(inputs: ModelInputs): Promise<ModelOutputs> {
  const session = await ort.InferenceSession.create('model.onnx', {
    executionProviders: ['webgpu'],
    graphOptimizationLevel: 'all'
  });

  const feeds = {
    input_ids: inputs.input_ids,
    attention_mask: inputs.attention_mask
  };

  const results = await session.run(feeds);

  // Pick the named output explicitly instead of the double assertion
  // `results as unknown as ModelOutputs`, which would silently hide a
  // missing or renamed output until some later use of it blew up.
  const logits = results.logits;
  if (!logits) {
    throw new Error("model did not produce a 'logits' output");
  }
  return { logits };
}

Memory Management

GPU Buffer Management

WebGPU uses GPU buffers for efficient memory management:
// Simplified C++ sketch of the WebGPU device-memory wrapper.
// NOTE(review): members such as shape, p_cpu_, ort_memory_info_ and
// owned_ are assumed to be declared on DeviceBuffer / elsewhere — they
// are not shown in this excerpt.
struct WebGPUMemory final : DeviceBuffer {
  // Allocates `size` bytes of device memory through the ORT allocator.
  WebGPUMemory(size_t size, Ort::Allocator* allocator) 
      : owned_{true}, ort_allocator_{allocator} {
    size_in_bytes_ = size;
    p_device_ = static_cast<uint8_t*>(ort_allocator_->Alloc(size_in_bytes_));
  }
  
  // Copies the whole device buffer back to host memory.
  void CopyDeviceToCpu() override {
    // Wrap both raw buffers in OrtValue tensors so CopyTensors can
    // move the data; the byte size is expressed as a UINT8 tensor.
    auto src_tensor = OrtValue::CreateTensor(
      *ort_memory_info_, p_device_, size_in_bytes_, 
      shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
    
    auto cpu_mem_info = OrtMemoryInfo::CreateCpu();
    auto dst_tensor = OrtValue::CreateTensor(
      *cpu_mem_info, p_cpu_, size_in_bytes_, 
      shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
    
    // nullptr stream: the copy is expected to complete synchronously
    // before returning. NOTE(review): earlier comment said "async copy";
    // with a null stream this call behaves synchronously — confirm.
    GetOrtEnv().CopyTensors({src_tensor.get()}, {dst_tensor.get()}, nullptr);
  }
};
WebGPU buffer handles cannot use pointer arithmetic. Full-buffer copies are used for operations requiring offsets.

Memory Transfer Optimization

// Efficient tensor creation: build the typed array once and hand it to
// the Tensor constructor — no extra copies are made on the JS side.
const inputTensor = new ort.Tensor(
  'float32',
  inputData,
  [batchSize, sequenceLength]
);

// Session-level memory options (the comment previously said "reuse
// tensors", but this block configures the session, not tensors).
// NOTE(review): confirm these ORT session options are honored by the
// webgpu EP in the installed onnxruntime-web version.
const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: ['webgpu'],
  enableMemPattern: true,  // Reuse memory patterns
  enableCpuMemArena: false  // Disable CPU arena for WebGPU
});

Configuration Options

Session Options

// Full session configuration for the WebGPU execution provider.
const sessionOptions = {
  executionProviders: ['webgpu'],
  graphOptimizationLevel: 'all',   // apply every graph-level optimization
  executionMode: 'sequential',     // execute nodes one at a time
  enableProfiling: false,
  
  // WebGPU-specific options
  preferredLayout: 'NCHW',  // or 'NHWC'
  
  // Enable INT64 support (required for some models)
  // NOTE(review): the exact 'extra' key may vary across onnxruntime-web
  // releases — confirm against the installed version's docs.
  extra: {
    session: {
      'ep.webgpu.enableInt64': '1'
    }
  }
};

const session = await ort.InferenceSession.create(
  'model.onnx',
  sessionOptions
);

INT64 Support

WebGPU requires explicit INT64 enablement:
// Per-EP option object form.
// NOTE(review): this spelling (enableInt64 on the EP object) differs
// from the 'ep.webgpu.enableInt64' extra key shown earlier — confirm
// which form the installed onnxruntime-web version accepts.
const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: [{
    name: 'webgpu',
    enableInt64: true
  }]
});

Browser Compatibility

Feature Detection

/**
 * Detects whether WebGPU is usable in the current browser.
 *
 * @returns {Promise<boolean>} true when navigator.gpu exists and an
 *   adapter can be acquired; false otherwise (never throws).
 */
async function checkWebGPUSupport() {
  if (!navigator.gpu) {
    console.log('WebGPU not supported');
    return false;
  }
  
  try {
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
      console.log('No WebGPU adapter found');
      return false;
    }
    
    console.log('WebGPU supported');
    // GPUAdapter has no 'name' property in the WebGPU spec; adapter
    // details live on GPUAdapterInfo, exposed as adapter.info in
    // current browsers. Guarded because older implementations only
    // offered requestAdapterInfo().
    if (adapter.info) {
      console.log('Adapter:', adapter.info.vendor, adapter.info.architecture);
    }
    return true;
  } catch (e) {
    console.error('WebGPU error:', e);
    return false;
  }
}

// Create a session with WebGPU when available, otherwise fall back to
// the WebAssembly backend.
async function createSession(modelPath) {
  const hasWebGPU = await checkWebGPUSupport();
  const providers = hasWebGPU ? ['webgpu'] : ['wasm'];
  return ort.InferenceSession.create(modelPath, {
    executionProviders: providers
  });
}

Progressive Enhancement

class ModelRunner {
  async initialize(modelPath) {
    // Try WebGPU first
    try {
      this.session = await ort.InferenceSession.create(modelPath, {
        executionProviders: ['webgpu']
      });
      this.provider = 'webgpu';
      console.log('Using WebGPU');
    } catch (e) {
      // Fallback to WASM with SIMD
      this.session = await ort.InferenceSession.create(modelPath, {
        executionProviders: ['wasm']
      });
      this.provider = 'wasm';
      console.log('Using WebAssembly');
    }
  }
  
  async run(inputs) {
    return await this.session.run(inputs);
  }
}

Performance Optimization

Model Optimization

// Use optimized model formats
const session = await ort.InferenceSession.create('model.ort', {
  executionProviders: ['webgpu'],
  graphOptimizationLevel: 'all',
  enableMemPattern: true
});

// Warm up the model
const dummyInput = new ort.Tensor('float32', new Float32Array(512), [1, 512]);
await session.run({ input: dummyInput });

Batch Processing

async function processBatch(session, inputs) {
  const batchSize = inputs.length;
  
  // Concatenate inputs into single batch
  const batchedData = new Float32Array(batchSize * 512);
  inputs.forEach((input, i) => {
    batchedData.set(input, i * 512);
  });
  
  const batchedTensor = new ort.Tensor(
    'float32',
    batchedData,
    [batchSize, 512]
  );
  
  const results = await session.run({ input: batchedTensor });
  return results;
}

Precision Control

// Use FP16 for better performance (browser-dependent)
const session = await ort.InferenceSession.create('model_fp16.onnx', {
  executionProviders: [{
    name: 'webgpu',
    preferredLayout: 'NCHW'
  }]
});
By comparison, running the default FP32 model offers:
  • Full precision
  • Best accuracy
  • Higher memory usage
  • Broader browser support

Advanced Usage

Web Worker Integration

// worker.js
importScripts('https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.webgpu.min.js');

let session;

self.onmessage = async (e) => {
  const { type, data } = e.data;
  
  if (type === 'init') {
    session = await ort.InferenceSession.create(data.modelPath, {
      executionProviders: ['webgpu']
    });
    self.postMessage({ type: 'ready' });
  }
  
  if (type === 'run') {
    const results = await session.run(data.inputs);
    self.postMessage({ type: 'result', data: results });
  }
};

// main.js
const worker = new Worker('worker.js');

worker.postMessage({
  type: 'init',
  data: { modelPath: 'model.onnx' }
});

worker.onmessage = (e) => {
  if (e.data.type === 'ready') {
    worker.postMessage({
      type: 'run',
      data: { inputs: { input: inputTensor } }
    });
  }
  
  if (e.data.type === 'result') {
    console.log('Results:', e.data.data);
  }
};

Streaming Inference

/**
 * Autoregressive token generation that yields each new token as soon as
 * it is decoded. Assumes argmax, vocabSize and eosTokenId are defined
 * elsewhere, and that the model exposes 'input_ids' / 'logits'.
 */
async function* streamGeneration(session, inputTokens, maxLength) {
  const tokens = inputTokens.slice();
  
  for (let step = 0; step < maxLength; step++) {
    const ids = new ort.Tensor(
      'int64',
      BigInt64Array.from(tokens, (t) => BigInt(t)),
      [1, tokens.length]
    );
    
    const outputs = await session.run({ input_ids: ids });
    
    // Greedy decode (simplified): take the highest-scoring token from
    // the final position's logits.
    const nextToken = argmax(outputs.logits.data.slice(-vocabSize));
    tokens.push(nextToken);
    
    yield nextToken;
    
    if (nextToken === eosTokenId) break;
  }
}

// Usage
for await (const token of streamGeneration(session, inputTokens, 100)) {
  console.log(tokenizer.decode([token]));
}

Troubleshooting

WebGPU Not Available

// Detect missing WebGPU before attempting to create a session.
// NOTE(review): the logged list omits Firefox 121+ (experimental),
// which the support table above mentions — confirm intended.
if (!navigator.gpu) {
  console.error('WebGPU not supported in this browser');
  console.log('Supported browsers: Chrome 113+, Edge 113+, Safari 18+');
  console.log('Fallback to WebAssembly');
}

Memory Errors

// Use smaller batches to reduce peak GPU memory.
const batchSize = 1;
// Release the session when done to free its buffers.
await session.release();
// performance.memory is a non-standard, Chrome-only API and reports the
// JS heap — not GPU memory. Use only as a rough indicator.
if (performance.memory) {
  console.log('Memory:', performance.memory.usedJSHeapSize / 1048576, 'MB');
}

Performance Issues

// Enable all optimizations
const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: ['webgpu'],
  graphOptimizationLevel: 'all',
  enableMemPattern: true,
  executionMode: 'sequential'
});

// Warm up model
await session.run(dummyInputs);

Benchmarking

async function benchmark(session, inputs, iterations = 10) {
  // Warm up
  await session.run(inputs);
  
  // Benchmark
  const times = [];
  for (let i = 0; i < iterations; i++) {
    const start = performance.now();
    await session.run(inputs);
    const end = performance.now();
    times.push(end - start);
  }
  
  const avg = times.reduce((a, b) => a + b) / times.length;
  const min = Math.min(...times);
  const max = Math.max(...times);
  
  console.log(`Average: ${avg.toFixed(2)}ms`);
  console.log(`Min: ${min.toFixed(2)}ms`);
  console.log(`Max: ${max.toFixed(2)}ms`);
  
  return { avg, min, max };
}

Best Practices

Feature Detection

Always check for WebGPU support before using it.

Fallback Strategy

Implement fallback to WebAssembly for unsupported browsers.

Model Optimization

Use optimized .ort format and FP16 precision when supported.

Web Workers

Run inference in Web Workers to avoid blocking the main thread.

Next Steps

Web Deployment

Deploy models to the web

Model Optimization

Optimize models for browser inference

Build docs developers (and LLMs) love