Skip to content

Instantly share code, notes, and snippets.

@fukuroder
Created December 13, 2025 14:04
Show Gist options
  • Select an option

  • Save fukuroder/ca55236faf73519b5964d4816fedc5d1 to your computer and use it in GitHub Desktop.

Select an option

Save fukuroder/ca55236faf73519b5964d4816fedc5d1 to your computer and use it in GitHub Desktop.
LibTorch(C++)でカスタムモジュールを実装する
#include <torch/torch.h>
#include <torch/script.h>
/// Bottleneck residual block: 1x1 conv -> 3x3 conv -> 1x1 conv, each followed
/// by batch normalization, with the block input added back before the final
/// ReLU.
///
/// Submodules are registered as "conv1", "bn1", "conv2", "bn2", "conv3",
/// "bn3", "relu" and, when a projection shortcut is required, "downsample".
/// NOTE(review): these names presumably mirror torchvision's ResNet state-dict
/// keys so that pretrained weights can be loaded by name -- confirm.
struct Bottleneck : torch::nn::Module {
    torch::nn::Conv2d _conv1;
    torch::nn::BatchNorm2d _bn1;
    torch::nn::Conv2d _conv2;
    torch::nn::BatchNorm2d _bn2;
    torch::nn::Conv2d _conv3;
    torch::nn::BatchNorm2d _bn3;
    torch::nn::ReLU _relu;
    // Default-constructed as an empty Sequential; it stays empty (and
    // unregistered) when the identity shortcut can be used as-is.
    torch::nn::Sequential _downsample;

    /// @param in_channels  Number of channels of the block input.
    /// @param out_channels Number of channels produced by the block.
    /// @param width        Number of channels used by the inner 1x1/3x3 convolutions.
    /// @param stride       Stride of the 3x3 convolution (and of the projection shortcut).
    Bottleneck(int in_channels, int out_channels, int width, int stride):
        _conv1(torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, width, 1).stride(1).bias(false))),
        _bn1(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width))),
        _conv2(torch::nn::Conv2d(torch::nn::Conv2dOptions(width, width, 3).stride(stride).padding(1).bias(false))),
        _bn2(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width))),
        _conv3(torch::nn::Conv2d(torch::nn::Conv2dOptions(width, out_channels, 1).stride(1).bias(false))),
        _bn3(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels))),
        _relu(torch::nn::ReLU(torch::nn::ReLUOptions(true)))
    {
        register_module("conv1", _conv1);
        register_module("bn1", _bn1);
        register_module("conv2", _conv2);
        register_module("bn2", _bn2);
        register_module("conv3", _conv3);
        register_module("bn3", _bn3);
        register_module("relu", _relu);

        // The shortcut must be projected whenever the main path changes the
        // tensor shape: either the channel count differs, or the strided 3x3
        // convolution shrinks the spatial size. Checking only the channel
        // count would make `out += identity` fail for a strided block whose
        // input and output channel counts happen to be equal.
        if (stride != 1 || in_channels != out_channels) {
            _downsample = torch::nn::Sequential(
                torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, 1).stride(stride).bias(false)),
                torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels)));
            register_module("downsample", _downsample);
        }
    }

    /// Runs the block on `x` and returns relu(main_path(x) + shortcut(x)).
    torch::Tensor forward(torch::Tensor x) {
        auto identity = x;

        auto out = _conv1(x);
        out = _bn1(out);
        out = _relu(out);

        out = _conv2(out);
        out = _bn2(out);
        out = _relu(out);

        out = _conv3(out);
        out = _bn3(out);

        // A non-empty `_downsample` means a projection shortcut was built in
        // the constructor; otherwise the input is added unchanged.
        if (!_downsample->is_empty()) {
            identity = _downsample->forward(x);
        }

        out += identity;
        out = _relu(out);
        return out;
    }
};
struct ResNet50Impl : torch::nn::Module {
torch::nn::Conv2d _conv1;
torch::nn::BatchNorm2d _bn1;
torch::nn::ReLU _relu;
torch::nn::MaxPool2d _maxpool;
torch::nn::Sequential _layer1;
torch::nn::Sequential _layer2;
torch::nn::Sequential _layer3;
torch::nn::Sequential _layer4;
torch::nn::AdaptiveAvgPool2d _avgpool;
torch::nn::Linear _fc;
ResNet50Impl() :
torch::nn::Module("ResNet50"),
_conv1(torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).bias(false)),
_bn1(torch::nn::BatchNorm2dOptions(64)),
_relu(torch::nn::ReLUOptions(true)),
_maxpool(torch::nn::MaxPool2dOptions(3).stride(2).padding(1)),
_layer1(
Bottleneck(64, 256, 64, 1),
Bottleneck(256, 256, 64, 1),
Bottleneck(256, 256, 64, 1)),
_layer2(
Bottleneck(256, 512, 128, 2),
Bottleneck(512, 512, 128, 1),
Bottleneck(512, 512, 128, 1),
Bottleneck(512, 512, 128, 1)),
_layer3(
Bottleneck(512, 1024, 256, 2),
Bottleneck(1024, 1024, 256, 1),
Bottleneck(1024, 1024, 256, 1),
Bottleneck(1024, 1024, 256, 1),
Bottleneck(1024, 1024, 256, 1),
Bottleneck(1024, 1024, 256, 1)),
_layer4(
Bottleneck(1024, 2048, 512, 2),
Bottleneck(2048, 2048, 512, 1),
Bottleneck(2048, 2048, 512, 1)),
_avgpool(torch::nn::AdaptiveAvgPool2dOptions(1)),
_fc(torch::nn::LinearOptions(2048,1000))
{
register_module("conv1", _conv1);
register_module("bn1", _bn1);
register_module("relu", _relu);
register_module("maxpool", _maxpool);
register_module("layer1", _layer1);
register_module("layer2", _layer2);
register_module("layer3", _layer3);
register_module("layer4", _layer4);
register_module("avgpool", _avgpool);
register_module("fc", _fc);
}
torch::Tensor forward(torch::Tensor x) {
x = _conv1(x);
x = _bn1(x);
x = _relu(x);
x = _maxpool(x);
x = _layer1->forward(x);
x = _layer2->forward(x);
x = _layer3->forward(x);
x = _layer4->forward(x);
x = _avgpool(x);
x = torch::flatten(x, 1);
x = _fc(x);
return x;
}
};
TORCH_MODULE(ResNet50);
int main()
{
try {
// カスタムモジュール
ResNet50 resnet50_pt;
//torch::load(resnet50_pt, "resnet50.pt"); //★例外が発生する
torch::load(resnet50_pt, "resnet50.ts");
resnet50_pt->eval();
// TorchScript
torch::jit::Module resnet50_ts = torch::jit::load("resnet50.ts");
resnet50_ts.eval();
// 推論
auto x = torch::rand({ 1, 3, 512, 512 });
auto output_pt = resnet50_pt(x);
auto output_ts = resnet50_ts({ x }).toTensor();
// 誤差計算
std::cout << "error: " << (output_pt - output_ts).abs().max().item<float_t>() << std::endl;
}
catch (std::exception& ex) {
std::cout << ex.what() << std::endl;
}
return 0;
}
import torch
from torchvision import models

# Fetch the pretrained ResNet-50 (IMAGENET1K_V2 weights) and switch it to eval mode.
pretrained_weights = models.ResNet50_Weights.IMAGENET1K_V2
net = models.resnet50(weights=pretrained_weights)
net.eval()

# Export 1: the plain state dict.
torch.save(net.state_dict(), "resnet50.pt")

# Export 2: a TorchScript module, traced with a zero-filled 1x3x512x512 example input.
dummy_input = torch.zeros((1, 3, 512, 512))
traced_net = torch.jit.trace(net, dummy_input)
traced_net.save("resnet50.ts")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment