前言
分析 TorchScript 模型（以 `torch.jit.load` 載入）的輸入、輸出，以及各中間層參數（權重）的名稱與 shape
範例
# Inspect a TorchScript model: print its graph, the arguments of forward(),
# the name/shape of every parameter, and the graph outputs.
#
# Usage: python <this script> [path/to/model.pt]   (default: model.pt)
import sys

import torch

DEFAULT_MODEL_PATH = "model.pt"


def collect_model_info(model):
    """Collect the graph inputs, parameter shapes and graph outputs of *model*.

    Args:
        model: a TorchScript module (e.g. the result of ``torch.jit.load`` or
            ``torch.jit.script``); ``model.graph`` is read for inputs/outputs.

    Returns:
        A dict with:
          "self": the first graph input, or None if the graph has no inputs.
              For a scripted module this is the module reference (``self``),
              which is why it is reported separately from user inputs.
          "inputs": ``(debug name, type string)`` for each remaining graph
              input (the arguments of ``forward``).
          "parameters": ``(parameter name, shape)`` for every entry of
              ``model.named_parameters()``; shape is a ``torch.Size``.
          "outputs": ``(debug name, type string)`` for each graph output.
    """
    graph_inputs = list(model.graph.inputs())
    # Guard against an empty input list instead of raising IndexError.
    self_value = graph_inputs[0] if graph_inputs else None
    return {
        "self": self_value,
        "inputs": [(v.debugName(), str(v.type())) for v in graph_inputs[1:]],
        "parameters": [
            (name, param.shape) for name, param in model.named_parameters()
        ],
        "outputs": [
            (v.debugName(), str(v.type())) for v in model.graph.outputs()
        ],
    }


def print_model_info(model):
    """Print the graph of *model*, then its inputs, parameter shapes and outputs."""
    print(model.graph)
    info = collect_model_info(model)

    # The original label "quote" was a mistranslation of "reference":
    # this is the module's own reference (self), not user data.
    print(" ------- reference to the model (self) ------- ")
    self_value = info["self"]
    if self_value is None:
        print("(graph has no inputs)")
    else:
        print(f"name={self_value.debugName()!r}, type={self_value.type()}")

    print(" ------- input layer ------- ")
    for i, (name, type_str) in enumerate(info["inputs"]):
        print(f"[input {i}]")
        print(f"name={name!r}, type={type_str}")

    # named_parameters() yields the weights of each layer; these are parameter
    # shapes, not the shapes of intermediate activations.
    print(" ------- intermediate layer parameters ------- ")
    for name, shape in info["parameters"]:
        print(f"Layer: {name}, Shape: {shape}")

    print(" ------- output layer ------- ")
    for i, (name, type_str) in enumerate(info["outputs"]):
        print(f"[output {i}]")
        print(f"name={name!r}, type={type_str}")


def main(argv=None):
    """Load the model given as argv[1] (default ``model.pt``) and print its summary."""
    argv = sys.argv if argv is None else argv
    model_path = argv[1] if len(argv) > 1 else DEFAULT_MODEL_PATH
    # map_location="cpu" lets models saved from a GPU be inspected on CPU-only hosts.
    model = torch.jit.load(model_path, map_location="cpu")
    print_model_info(model)


if __name__ == "__main__":
    main()