import * as ort from 'onnxruntime-web';
/**
 * Browser-side image classifier backed by an onnxruntime-web inference session.
 *
 * Typical flow: construct, `await initialize(modelPath)` once, then call
 * `classify(image)` for each image.
 */
class ImageClassifier {
  /**
   * @param {object} [options]
   * @param {number} [options.inputSize=224] - Side length, in pixels, of the
   *   square image fed to the model.
   * @param {number[]} [options.mean] - Per-channel (R, G, B) mean subtracted
   *   after scaling pixels to [0, 1]. The defaults are the values commonly
   *   used for ImageNet-trained models.
   * @param {number[]} [options.std] - Per-channel (R, G, B) divisor applied
   *   after mean subtraction.
   * @param {number} [options.topK=5] - Number of predictions `postprocess`
   *   returns by default.
   */
  constructor({
    inputSize = 224,
    mean = [0.485, 0.456, 0.406],
    std = [0.229, 0.224, 0.225],
    topK = 5,
  } = {}) {
    this.session = null;
    this.inputSize = inputSize;
    this.mean = mean;
    this.std = std;
    this.topK = topK;
  }

  /**
   * Creates the inference session and logs the model's input/output names.
   *
   * @param {string} modelPath - Model location handed to
   *   `ort.InferenceSession.create`.
   * @returns {Promise<void>}
   */
  async initialize(modelPath) {
    this.session = await ort.InferenceSession.create(modelPath, {
      // Providers are listed in order of preference.
      executionProviders: ['webgpu', 'wasm'],
      graphOptimizationLevel: 'all',
    });
    console.log('Model loaded');
    console.log('Inputs:', this.session.inputNames);
    console.log('Outputs:', this.session.outputNames);
  }

  /**
   * Preprocesses an image, runs the model on it, and returns the top results.
   *
   * Only the model's first input and first output are used.
   *
   * @param {CanvasImageSource} imageElement - Anything `drawImage` accepts.
   * @returns {Promise<Array<{class: number, probability: number}>>}
   * @throws {Error} If `initialize()` has not completed successfully.
   */
  async classify(imageElement) {
    // Fail fast with a clear message instead of a TypeError on `null.inputNames`.
    if (!this.session) {
      throw new Error(
        'ImageClassifier is not initialized; call initialize() before classify()',
      );
    }

    const tensor = await this.preprocessImage(imageElement);

    const [inputName] = this.session.inputNames;
    const [outputName] = this.session.outputNames;
    const results = await this.session.run({ [inputName]: tensor });

    return this.postprocess(results[outputName]);
  }

  /**
   * Resizes the image to `inputSize` x `inputSize` on an off-screen canvas and
   * converts it to a normalized float32 tensor of shape
   * `[1, 3, inputSize, inputSize]` (planar R, G, B; alpha is dropped).
   *
   * @param {CanvasImageSource} img - Source image; it is stretched to the
   *   target size without preserving aspect ratio.
   * @returns {Promise<ort.Tensor>}
   * @throws {Error} If a 2D canvas context cannot be obtained.
   */
  async preprocessImage(img) {
    const size = this.inputSize;
    const canvas = document.createElement('canvas');
    canvas.width = size;
    canvas.height = size;

    const ctx = canvas.getContext('2d');
    if (ctx === null) {
      throw new Error('Unable to obtain a 2D canvas context for preprocessing');
    }
    ctx.drawImage(img, 0, 0, size, size);
    const { data: pixels } = ctx.getImageData(0, 0, size, size);

    const { mean, std } = this;
    const planeSize = size * size;
    const data = new Float32Array(3 * planeSize);

    // Canvas pixels are interleaved RGBA (HWC); the tensor wants one full
    // plane per channel (CHW), so each channel is written at its own offset.
    for (let i = 0; i < planeSize; i++) {
      const px = i * 4;
      data[i] = (pixels[px] / 255 - mean[0]) / std[0];
      data[planeSize + i] = (pixels[px + 1] / 255 - mean[1]) / std[1];
      data[planeSize * 2 + i] = (pixels[px + 2] / 255 - mean[2]) / std[2];
    }

    return new ort.Tensor('float32', data, [1, 3, size, size]);
  }

  /**
   * Ranks the model output and keeps the highest-scoring entries.
   *
   * The `probability` field carries the model's output value unchanged; it is
   * a true probability only if the model itself ends in a softmax layer.
   *
   * @param {{data: ArrayLike<number>}} output - Output tensor; index in
   *   `data` is used as the class id.
   * @param {number} [topK=this.topK] - Maximum number of entries to return.
   * @returns {Array<{class: number, probability: number}>} Entries sorted by
   *   descending `probability`.
   */
  postprocess(output, topK = this.topK) {
    return Array.from(output.data, (score, idx) => ({
      class: idx,
      probability: score,
    }))
      .sort((a, b) => b.probability - a.probability)
      .slice(0, topK);
  }
}
// Usage: load the model, then classify the element with id "image".
const classifier = new ImageClassifier();
await classifier.initialize('./resnet50.onnx');

const img = document.getElementById('image');
if (img === null) {
  throw new Error('No element with id "image" was found in the document');
}
// An <img> that has not finished loading is drawn as nothing by drawImage(),
// which would silently classify a blank canvas. decode() resolves once the
// image is ready to draw and rejects if it cannot be decoded.
if (img instanceof HTMLImageElement) {
  await img.decode();
}

const predictions = await classifier.classify(img);
console.log('Top predictions:', predictions);