PyTorch TorchScript and ONNX Export
After model training is complete, the model needs to be deployed to a production environment. PyTorch provides multiple model export methods, among which TorchScript and ONNX are the two most commonly used formats.

This section details the principles, usage, and best practices of these two export methods.

> Model export is the process of converting a PyTorch model into a format that can run on different platforms and frameworks. This is crucial for scenarios such as model deployment, mobile inference, and cross-framework migration.

* * *

## 1. TorchScript Basics

### 1.1 What is TorchScript

TorchScript is an intermediate representation of a PyTorch model: a statically typed subset of Python that can be serialized to disk and later loaded and executed by the PyTorch C++ runtime (libtorch). TorchScript programs can therefore run in environments without a Python interpreter.

Main features of TorchScript:

* Converts dynamic graphs to static graphs
* Supports a subset of Python syntax
* Can run in C++ environments
* Preserves the structure and parameters of the model

### 1.2 Two Conversion Methods

PyTorch provides two ways to convert a model to TorchScript:

* Tracing (`torch.jit.trace`): Generates a static computation graph by executing the model on an example input and recording the operations
* Scripting (`torch.jit.script`): Directly analyzes the Python source code and compiles it into TorchScript

* * *

## 2. TorchScript Tracing

### 2.1 Basic Tracing Method

#### Example

```python
import torch
import torch.nn as nn

# --- Define the model ---

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        # Expects 8x8 feature maps, i.e. 32x32 inputs after two 2x2 poolings
        self.fc = nn.Linear(128 * 8 * 8, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Create a model instance and switch to inference mode
model = SimpleNet()
model.eval()

# Example input: 32x32 images, so the flattened size matches self.fc
# (a 64x64 input would produce 128 * 16 * 16 features and fail in self.fc)
example_input = torch.randn(1, 3, 32, 32)

# --- Method 1: tracing with torch.jit.trace ---

# Create TorchScript by running the model and recording the operations
traced_model = torch.jit.trace(model, example_input)

print("Traced model:")
print(traced_model)

# Save the model
traced_model.save("simple_net_traced.pt")

# Load the model
loaded_model = torch.jit.load("simple_net_traced.pt")

# Perform inference with the loaded model
output = loaded_model(example_input)
print(f"Output shape: {output.shape}")
```

### 2.2 Limitations of Tracing

The tracing method has some limitations:

* Only the operations actually executed on the example input are recorded
* Data-dependent control flow (such as `if` and `for`) is frozen to the path taken during tracing
* Dynamically sized inputs may cause issues, because shape-dependent values can be recorded as constants

#### Example

```python
# Example of a tracing limitation

class DynamicModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # This data-dependent branch is frozen during tracing
        if x.sum() > 0:
            return x * 2
        else:
            return x / 2

model = DynamicModel()
model.eval()

# During tracing, the `if` is fixed to the branch taken for the example input
# (PyTorch emits a TracerWarning about this)
example_input = torch.tensor([1.0])
traced = torch.jit.trace(model, example_input)

# The traced model executes the same branch regardless of the input
print(traced(torch.tensor([1.0])))   # tensor([2.])
print(traced(torch.tensor([-1.0])))  # tensor([-2.]): still x * 2, not x / 2 = -0.5
```

> For models with data-dependent control flow, use scripting (`torch.jit.script`) instead of tracing.

* * *

## 3. TorchScript Scripting

### 3.1 Basic Usage

#### Example

```python
import torch
import torch.nn as nn

# --- Using the @torch.jit.script decorator ---

@torch.jit.script
def scripted_function(x: torch.Tensor) -> torch.Tensor:
    """Function converted via scripting."""
    if x.sum() > 0:
        return x * 2
    else:
        return x / 2

# Test the scripted function
input1 = torch.tensor([1.0, 2.0])
input2 = torch.tensor([-1.0, -2.0])

print(scripted_function(input1))  # tensor([2., 4.])
print(scripted_function(input2))  # tensor([-0.5000, -1.0000])

# --- Scripting a model ---

class ScriptableModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)

    # forward is compiled by default, so it needs no decorator
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))

    @torch.jit.export
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Additional exported method."""
        out = self.forward(x)
        return torch.argmax(out, dim=1)

model = ScriptableModel()

# Script the model
scripted_model = torch.jit.script(model)
print(scripted_model)

# Save
scripted_model.save("scripted_model.pt")
```

### 3.2 Scripting Complex Models

#### Example

```python
# A more complex scripting example: a model with a conditional output

class ConditionModel(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten()
        )
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor, use_softmax: bool = False) -> torch.Tensor:
        """Model supporting a dynamic condition."""
        features = self.features(x)
        logits = self.classifier(features)
        if use_softmax:
            return torch.softmax(logits, dim=1)
        else:
            return logits

    def get_prediction(self, x: torch.Tensor) -> torch.Tensor:
        """Helper method (not compiled unless decorated with @torch.jit.export)."""
        logits = self.forward(x, use_softmax=False)
        return torch.argmax(logits, dim=1)

# Script the model; forward is fully type-annotated, so no example inputs are needed
model = ConditionModel(num_classes=10)
scripted_model = torch.jit.script(model)

# Test
test_input = torch.randn(2, 3, 32, 32)
output1 = scripted_model(test_input, use_softmax=False)
output2 = scripted_model(test_input, use_softmax=True)

print(f"Logits output shape: {output1.shape}")
print(f"Softmax output shape: {output2.shape}")
```
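The difference between tracing and scripting described in sections 2 and 3 can be checked empirically before a model is shipped. The following is a minimal sketch, not a prescribed workflow: `SignGate` and `outputs_match` are hypothetical names introduced here for illustration. It compares each exported module against the eager model on probe inputs chosen to exercise both branches.

```python
import warnings

import torch
import torch.nn as nn


class SignGate(nn.Module):
    """Hypothetical module with data-dependent control flow."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return x / 2


def outputs_match(exported, eager, probes):
    """Return True if the exported module reproduces the eager outputs on every probe."""
    with torch.no_grad():
        return all(torch.allclose(exported(p), eager(p)) for p in probes)


model = SignGate().eval()

# One probe per branch: positive sum and negative sum
probes = [torch.tensor([1.0, 2.0]), torch.tensor([-1.0, -2.0])]

with warnings.catch_warnings():
    # Tracing a data-dependent `if` emits a TracerWarning; silence it for this demo
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(model, probes[0])

scripted = torch.jit.script(model)

trace_ok = outputs_match(traced, model, probes)
script_ok = outputs_match(scripted, model, probes)

print(f"traced matches eager:   {trace_ok}")
print(f"scripted matches eager: {script_ok}")
```

Under these assumptions the traced module disagrees with the eager model on the negative probe, because only the `x * 2` branch was recorded, while the scripted module agrees on both. A check of this kind is only as good as its probes, so they should cover every branch the model can take.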
* * *

## 4. ONNX Export

### 4.1 ONNX Basics

ONNX (Open Neural Network Exchange) is an open format for representing neural network models, which makes it possible to move models between different deep learning frameworks and runtimes.

Advantages of ONNX:

* Cross-framework: Supported by PyTorch, TensorFlow, Caffe2, etc.
* Cross-platform: Supports CPU, GPU, mobile, and other platforms
* Hardware optimization: Can leverage ONNX Runtime for efficient inference
* Rich tools: Has a large number of optimization tools and deployment solutions

### 4.2 Basic ONNX Export

#### Example

```python
import torch
import torch.nn as nn

# --- Define the model ---

class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten()
        )
        self.classifier = nn.Linear(64, 10)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

model = ImageClassifier()
model.eval()

# Example input
example_input = torch.randn(1, 3, 32, 32)

# --- Export to ONNX ---

output_path = "image_classifier.onnx"

torch.onnx.export(
    model,
    example_input,
    output_path,
    export_params=True,        # Export the model parameters
    opset_version=14,          # ONNX opset version
    do_constant_folding=True,  # Constant folding optimization
    input_names=['input'],     # Input tensor names
    output_names=['output'],   # Output tensor names
    dynamic_axes={
        'input': {0: 'batch_size'},   # Dynamic batch dimension
        'output': {0: 'batch_size'}
    }
)

print(f"Model exported to: {output_path}")

# Validate the exported model
import onnx

onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
print("ONNX model validation passed!")
```

### 4.3 Exporting Complex Models

#### Example

```python
# Export a complete ResNet model
import torchvision.models as models

# Load the pretrained model
# (torchvision versions before 0.13 use models.resnet18(pretrained=True) instead)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.eval()

# Example input (standard ResNet 224x224)
example_input = torch.randn(1, 3, 224, 224)
```