Created
December 13, 2025 14:04
-
-
Save fukuroder/ca55236faf73519b5964d4816fedc5d1 to your computer and use it in GitHub Desktop.
LibTorch(C++)でカスタムモジュールを実装する
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #include <torch/torch.h> | |
| #include <torch/script.h> | |
| struct Bottleneck : torch::nn::Module { | |
| torch::nn::Conv2d _conv1; | |
| torch::nn::BatchNorm2d _bn1; | |
| torch::nn::Conv2d _conv2; | |
| torch::nn::BatchNorm2d _bn2; | |
| torch::nn::Conv2d _conv3; | |
| torch::nn::BatchNorm2d _bn3; | |
| torch::nn::ReLU _relu; | |
| torch::nn::Sequential _downsample; | |
| Bottleneck(int in_channels, int out_channels, int width, int stride): | |
| _conv1(torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, width, 1).stride(1).bias(false))), | |
| _bn1(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width))), | |
| _conv2(torch::nn::Conv2d(torch::nn::Conv2dOptions(width, width, 3).stride(stride).padding(1).bias(false))), | |
| _bn2(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width))), | |
| _conv3(torch::nn::Conv2d(torch::nn::Conv2dOptions(width, out_channels, 1).stride(1).bias(false))), | |
| _bn3(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels))), | |
| _relu(torch::nn::ReLU(torch::nn::ReLUOptions(true))) | |
| { | |
| register_module("conv1", _conv1); | |
| register_module("bn1", _bn1); | |
| register_module("conv2", _conv2); | |
| register_module("bn2", _bn2); | |
| register_module("conv3", _conv3); | |
| register_module("bn3", _bn3); | |
| register_module("relu", _relu); | |
| if (in_channels != out_channels) { | |
| _downsample = torch::nn::Sequential( | |
| torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, 1).stride(stride).bias(false)), | |
| torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels))); | |
| register_module("downsample", _downsample); | |
| } | |
| } | |
| torch::Tensor forward(torch::Tensor x) { | |
| auto identity = x; | |
| auto out = _conv1(x); | |
| out = _bn1(out); | |
| out = _relu(out); | |
| out = _conv2(out); | |
| out = _bn2(out); | |
| out = _relu(out); | |
| out = _conv3(out); | |
| out = _bn3(out); | |
| if (_downsample->is_empty() == false) { | |
| identity = _downsample->forward(x); | |
| } | |
| out += identity; | |
| out = _relu(out); | |
| return out; | |
| } | |
| }; | |
| struct ResNet50Impl : torch::nn::Module { | |
| torch::nn::Conv2d _conv1; | |
| torch::nn::BatchNorm2d _bn1; | |
| torch::nn::ReLU _relu; | |
| torch::nn::MaxPool2d _maxpool; | |
| torch::nn::Sequential _layer1; | |
| torch::nn::Sequential _layer2; | |
| torch::nn::Sequential _layer3; | |
| torch::nn::Sequential _layer4; | |
| torch::nn::AdaptiveAvgPool2d _avgpool; | |
| torch::nn::Linear _fc; | |
| ResNet50Impl() : | |
| torch::nn::Module("ResNet50"), | |
| _conv1(torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).bias(false)), | |
| _bn1(torch::nn::BatchNorm2dOptions(64)), | |
| _relu(torch::nn::ReLUOptions(true)), | |
| _maxpool(torch::nn::MaxPool2dOptions(3).stride(2).padding(1)), | |
| _layer1( | |
| Bottleneck(64, 256, 64, 1), | |
| Bottleneck(256, 256, 64, 1), | |
| Bottleneck(256, 256, 64, 1)), | |
| _layer2( | |
| Bottleneck(256, 512, 128, 2), | |
| Bottleneck(512, 512, 128, 1), | |
| Bottleneck(512, 512, 128, 1), | |
| Bottleneck(512, 512, 128, 1)), | |
| _layer3( | |
| Bottleneck(512, 1024, 256, 2), | |
| Bottleneck(1024, 1024, 256, 1), | |
| Bottleneck(1024, 1024, 256, 1), | |
| Bottleneck(1024, 1024, 256, 1), | |
| Bottleneck(1024, 1024, 256, 1), | |
| Bottleneck(1024, 1024, 256, 1)), | |
| _layer4( | |
| Bottleneck(1024, 2048, 512, 2), | |
| Bottleneck(2048, 2048, 512, 1), | |
| Bottleneck(2048, 2048, 512, 1)), | |
| _avgpool(torch::nn::AdaptiveAvgPool2dOptions(1)), | |
| _fc(torch::nn::LinearOptions(2048,1000)) | |
| { | |
| register_module("conv1", _conv1); | |
| register_module("bn1", _bn1); | |
| register_module("relu", _relu); | |
| register_module("maxpool", _maxpool); | |
| register_module("layer1", _layer1); | |
| register_module("layer2", _layer2); | |
| register_module("layer3", _layer3); | |
| register_module("layer4", _layer4); | |
| register_module("avgpool", _avgpool); | |
| register_module("fc", _fc); | |
| } | |
| torch::Tensor forward(torch::Tensor x) { | |
| x = _conv1(x); | |
| x = _bn1(x); | |
| x = _relu(x); | |
| x = _maxpool(x); | |
| x = _layer1->forward(x); | |
| x = _layer2->forward(x); | |
| x = _layer3->forward(x); | |
| x = _layer4->forward(x); | |
| x = _avgpool(x); | |
| x = torch::flatten(x, 1); | |
| x = _fc(x); | |
| return x; | |
| } | |
| }; | |
| TORCH_MODULE(ResNet50); | |
| int main() | |
| { | |
| try { | |
| // カスタムモジュール | |
| ResNet50 resnet50_pt; | |
| //torch::load(resnet50_pt, "resnet50.pt"); //★例外が発生する | |
| torch::load(resnet50_pt, "resnet50.ts"); | |
| resnet50_pt->eval(); | |
| // TorchScript | |
| torch::jit::Module resnet50_ts = torch::jit::load("resnet50.ts"); | |
| resnet50_ts.eval(); | |
| // 推論 | |
| auto x = torch::rand({ 1, 3, 512, 512 }); | |
| auto output_pt = resnet50_pt(x); | |
| auto output_ts = resnet50_ts({ x }).toTensor(); | |
| // 誤差計算 | |
| std::cout << "error: " << (output_pt - output_ts).abs().max().item<float_t>() << std::endl; | |
| } | |
| catch (std::exception& ex) { | |
| std::cout << ex.what() << std::endl; | |
| } | |
| return 0; | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torchvision import models | |
# Build ResNet-50 with torchvision's IMAGENET1K_V2 pretrained weights (downloaded
# on first use) and switch it to inference mode; Module.eval() returns the
# module itself, so the call can be chained.
weights = models.ResNet50_Weights.IMAGENET1K_V2
resnet50_pt = models.resnet50(weights=weights).eval()

# Export 1: the raw state dict, pickled by torch.save. This is a Python pickle,
# not a TorchScript archive, so libtorch's torch::load() cannot read it.
torch.save(resnet50_pt.state_dict(), "resnet50.pt")

# Export 2: a TorchScript archive produced by tracing the model on an all-zeros
# dummy input of shape (1, 3, 512, 512) in the NCHW layout the conv stem expects.
example = torch.zeros(1, 3, 512, 512)
resnet50_ts = torch.jit.trace(resnet50_pt, example)
resnet50_ts.save("resnet50.ts")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment