前言
產生一個簡單的 PyTorch 模型，並以 TorchScript 格式保存成檔案。
範例
建立一個只包含單一 nn.Linear 層的模型，以 torch.jit.trace 轉成 TorchScript 後保存為 model.pt：
import torch
import torch.nn as nn
class SingleLinearModel(nn.Module):
    """Minimal module that wraps a single ``nn.Linear`` layer.

    Args:
        in_features: Size of each input sample. Defaults to 10.
        out_features: Size of each output sample. Defaults to 5.
    """

    def __init__(self, in_features: int = 10, out_features: int = 5) -> None:
        super().__init__()
        # The defaults reproduce the original fixed 10 -> 5 layer, so
        # ``SingleLinearModel()`` keeps building the same architecture.
        self.linear = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)
# Instantiate the model that will be exported.
linear_model = SingleLinearModel()

# Example input that drives the trace: batch size 1, input size 10.
dummy_input = torch.rand(1, 10)

# Trace the model's forward pass on the example input, then write the
# traced module to disk.
traced_script_module = torch.jit.trace(
    linear_model,
    example_inputs=dummy_input,
)
traced_script_module.save('model.pt')

print(" ------ model.pt save completed!!! ------")