from torch._C._onnx import OperatorExportTypes
import torch
def torch_onnx_export(
    model,
    args,
    f,
    export_params=True,
    verbose=False,
    training=torch.onnx.TrainingMode.EVAL,
    input_names=None,
    output_names=None,
    operator_export_type=OperatorExportTypes.ONNX,
    opset_version=14,
    do_constant_folding=True,
    dynamic_axes=None,
    keep_initializers_as_inputs=None,
    custom_opsets=None,
    export_modules_as_functions=False,
):
    """Call ``torch.onnx.export`` with ``dynamo=False`` pinned.

    Every parameter is forwarded unchanged to ``torch.onnx.export`` under the
    keyword of the same name; see the PyTorch documentation for what each
    option means. The only value this wrapper adds on its own is
    ``dynamo=False``, which callers cannot override.

    Returns:
        None. Whatever ``torch.onnx.export`` returns is discarded.
    """
    # Options passed straight through, plus the one setting fixed here.
    passthrough = {
        "export_params": export_params,
        "verbose": verbose,
        "training": training,
        "input_names": input_names,
        "output_names": output_names,
        "operator_export_type": operator_export_type,
        "opset_version": opset_version,
        "do_constant_folding": do_constant_folding,
        "dynamic_axes": dynamic_axes,
        "keep_initializers_as_inputs": keep_initializers_as_inputs,
        "custom_opsets": custom_opsets,
        "export_modules_as_functions": export_modules_as_functions,
        "dynamo": False,
    }
    torch.onnx.export(model=model, args=args, f=f, **passthrough)
# Usage example: export a model to "model.onnx" via torch_onnx_export.
# NOTE: ``YourModel`` is a placeholder; it is not defined or imported here.
model = YourModel()
model.eval()
# Random example input of shape (1, 3, 224, 224), handed to the exporter as
# the sole element of ``args`` (presumably one 3-channel 224x224 image, given
# the input name "image" - confirm against the real model).
dummy_input = torch.randn(1, 3, 224, 224)
torch_onnx_export(
    model,
    (dummy_input,),
    "model.onnx",
    input_names=["image"],
    output_names=["output"],
    opset_version=14,
)