Mobilenet with Libtorch
Loading...
Searching...
No Matches
mobilenet_v2.h
1#pragma once
2
#include <torch/torch.h>
#include <opencv2/opencv.hpp>

#include <algorithm>
#include <array>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <regex>
#include <string>
#include <system_error>
#include <vector>
13
/***
 * MobileNetV2 C++ Implementation (LibTorch).
 * It's able to load pre-trained weights from torchvision
 * and has the necessary methods to enable transfer learning.
 * (c) 2025 Bernd Porr, GPLv3.
 ***/
20
21#ifdef NDEBUG
22constexpr bool debugOutput = false;
23#else
24constexpr bool debugOutput = true;
25#endif
26
31class MobileNetV2 : public torch::nn::Module
32{
33public:
44 MobileNetV2(int num_classes = 1000, float width_mult = 1.0f, int round_nearest = 8, float dropout = 0.2)
45 {
46 int input_channels = 32;
47 input_channels = make_divisible(input_channels * width_mult, round_nearest);
48 features_output_channels = make_divisible(features_output_channels * std::max(1.0f, width_mult), round_nearest);
49
50 features = torch::nn::Sequential();
51
52 features->push_back(
53 Conv2dNormActivation(3,
54 input_channels,
55 /*kernel_size=*/3,
56 /*stride =*/2));
57
58 // inverted residual blocks
59 for (const auto &cfg : inverted_residual_setting)
60 {
61 const int t = cfg[0];
62 const int c = cfg[1];
63 const int n = cfg[2];
64 const int s = cfg[3];
65
66 int output_channel = make_divisible(c * width_mult, round_nearest);
67 for (int i = 0; i < n; ++i)
68 {
69 const int stride = (i == 0) ? s : 1;
70 features->push_back(
71 InvertedResidual(input_channels, output_channel, stride, t));
72 input_channels = output_channel;
73 }
74 }
75
76 features->push_back(
77 Conv2dNormActivation(input_channels,
78 features_output_channels,
79 /*kernel_size=*/1));
80
81 register_module(featuresModuleName, features);
82
83 // classifier: Dropout + Linear
84 classifier = torch::nn::Sequential();
85 classifier->push_back(torch::nn::Dropout(torch::nn::DropoutOptions(dropout)));
86 classifier->push_back(torch::nn::Linear(torch::nn::LinearOptions(features_output_channels, num_classes)));
87 register_module(classifierModuleName, classifier);
88 }
89
94 static constexpr char featuresModuleName[] = "features";
95
100 static constexpr char classifierModuleName[] = "classifier";
101
108 torch::Tensor forward(torch::Tensor x)
109 {
110 x = features->forward(x);
111 const torch::nn::functional::AdaptiveAvgPool2dFuncOptions &ar = torch::nn::functional::AdaptiveAvgPool2dFuncOptions({1, 1});
112 x = torch::nn::functional::adaptive_avg_pool2d(x, ar);
113 x = torch::flatten(x, 1);
114 x = classifier->forward(x);
115 return x;
116 }
117
122 {
123 for (auto &module : modules(/*include_self=*/false))
124 {
125 if (auto M = dynamic_cast<torch::nn::Conv2dImpl *>(module.get()))
126 {
127 torch::nn::init::kaiming_normal_(M->weight, /*a=*/0, torch::kFanOut, torch::kReLU);
128 if (M->options.bias())
129 torch::nn::init::zeros_(M->bias);
130 }
131 else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl *>(module.get()))
132 {
133 torch::nn::init::ones_(M->weight);
134 torch::nn::init::zeros_(M->bias);
135 }
136 else if (auto M = dynamic_cast<torch::nn::LinearImpl *>(module.get()))
137 {
138 torch::nn::init::normal_(M->weight, 0.0, 0.01);
139 torch::nn::init::zeros_(M->bias);
140 }
141 }
142 }
143
154 void load_torchvision_weights(std::string pt)
155 {
156 std::ifstream input(pt, std::ios::binary);
157 input.exceptions(input.failbit);
158 std::vector<char> bytes(
159 (std::istreambuf_iterator<char>(input)),
160 (std::istreambuf_iterator<char>()));
161 input.close();
162 const c10::Dict<c10::IValue, c10::IValue> weights = torch::pickle_load(bytes).toGenericDict();
163 if (debugOutput)
164 {
165 std::cerr << "Parameters we have in this model here: " << std::endl;
166 for (auto const &m : named_parameters())
167 {
168 auto k = ourkey2torchvision(m.key());
169 std::cerr << m.key() << "->" << k << ": " << m.value().sizes() << std::endl;
170 }
171 std::cerr << "Named buffers we have in this model here: " << std::endl;
172 for (const auto &b : named_buffers())
173 {
174 auto k = ourkey2torchvision(b.key());
175 std::cout << b.key() << "->" << k << ": " << b.value().sizes() << std::endl;
176 }
177 std::cerr << "Parameters we have in the weight file " << pt << ":" << std::endl;
178 for (auto const &w : weights)
179 {
180 std::cerr << w.key() << ": " << w.value().toTensor().sizes() << std::endl;
181 }
182 }
183 torch::NoGradGuard no_grad;
184 if (debugOutput)
185 std::cerr << "Loading weights" << std::endl;
186 for (auto &m : named_parameters())
187 {
188 const std::string model_key = m.key();
189 const std::string model_key4torchvision = ourkey2torchvision(model_key);
190 if (debugOutput)
191 std::cerr << "Searching for: " << model_key4torchvision << ": " << m.value().sizes() << std::endl;
192 bool foundit = false;
193 for (auto const &w : weights)
194 {
195 if (model_key4torchvision == w.key())
196 {
197 if (debugOutput)
198 std::cerr << "Found it: " << w.key() << std::endl;
199 m.value().copy_(w.value().toTensor());
200 foundit = true;
201 break;
202 }
203 }
204 if (!foundit)
205 std::cerr << "Key: " << model_key4torchvision << " could not be found!" << std::endl;
206 }
207 if (debugOutput)
208 std::cerr << "Loading named buffers" << std::endl;
209 for (auto &b : named_buffers())
210 {
211 std::string model_key = b.key();
212 std::string model_key4torchvision = ourkey2torchvision(model_key);
213 if (debugOutput)
214 std::cerr << "Searching for: " << model_key4torchvision << ": " << b.value().sizes() << std::endl;
215 bool foundit = false;
216 for (auto const &w : weights)
217 {
218 if (model_key4torchvision == w.key())
219 {
220 if (debugOutput)
221 std::cerr << "Found it: " << w.key() << std::endl;
222 b.value().copy_(w.value().toTensor());
223 foundit = true;
224 break;
225 }
226 }
227 if (!foundit)
228 std::cerr << "Key: " << model_key4torchvision << " could not be found!" << std::endl;
229 }
230 }
231
242 static torch::Tensor preprocess(cv::Mat img, bool resizeOnly = false)
243 {
244 constexpr int imageSizeBeforeCrop = 256;
245 constexpr int finalImageSize = 224;
246 constexpr int numChannels = 3; // colour
247
248 if (img.depth() != CV_8U)
249 throw std::invalid_argument("Image is not 8bit.");
250 if (img.channels() != numChannels)
251 throw std::invalid_argument("Image is not BGR / colour.");
252
253 if (resizeOnly)
254 {
255 cv::resize(img, img, cv::Size(finalImageSize, finalImageSize));
256 }
257 else
258 {
259 cv::resize(img, img, cv::Size(imageSizeBeforeCrop, imageSizeBeforeCrop));
260 constexpr int start = (imageSizeBeforeCrop - finalImageSize) / 2;
261 const cv::Rect roi(start, start, finalImageSize, finalImageSize);
262 img = img(roi).clone();
263 }
264 cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
265
266 torch::Tensor tensor = torch::from_blob(img.data, {img.rows, img.cols, 3}, torch::kByte);
267 tensor = tensor.permute({2, 0, 1}).to(torch::kFloat).div_(255.0);
268 tensor = torch::data::transforms::Normalize({0.485, 0.456, 0.406}, {0.229, 0.224, 0.225})(tensor);
269 return tensor;
270 }
271
280 {
281 return features_output_channels;
282 }
283
290 void replaceClassifier(torch::nn::Sequential &newClassifier)
291 {
292 classifier = newClassifier;
293 replace_module(MobileNetV2::classifierModuleName, newClassifier);
294 }
295
302 void setFeaturesLearning(bool doLearn)
303 {
304 for (auto &p : features->parameters())
305 p.requires_grad_(doLearn);
306 }
307
316 torch::nn::Sequential getClassifier() const
317 {
318 return classifier;
319 }
320
321private:
325 torch::nn::Sequential classifier{nullptr};
326
330 torch::nn::Sequential features{nullptr};
331
332 // Features output channels but can be scaled.
333 int features_output_channels = 1280;
334
335 // MobileNetV2 inverted residual settings:
336 // t, c, n, s (expansion, output channels, repeats, stride)
337 const std::vector<std::array<int, 4>> inverted_residual_setting = {
338 {1, 16, 1, 1},
339 {6, 24, 2, 2},
340 {6, 32, 3, 2},
341 {6, 64, 4, 2},
342 {6, 96, 3, 1},
343 {6, 160, 3, 2},
344 {6, 320, 1, 1},
345 };
346
347 // Helper which maps the libtorch keys to pytorch keys.
348 // libtorch requires names for the submodules, for example:
349 // features.14.InvertedResidual.1.Conv2dNormActivation.1.weight.
350 // However, pytorch has no names for the submodules and needs to be removed:
351 // features.14.conv.1.1.weight.
352 // Also it renames "InvertedResidual" to "conv" which is just due to my
353 // choice to call it what it is and not just "conv".
354 std::string ourkey2torchvision(std::string k) const
355 {
356 // called simply "conv" in the weights file
357 k = std::regex_replace(k, std::regex(InvertedResidual::className), "conv");
358 // not used at all in the weights file
359 const std::string r = std::string(Conv2dNormActivation::className) + "\\.";
360 k = std::regex_replace(k, std::regex(r), "");
361 return k;
362 }
363
364 // Makes a value divisible.
365 inline int make_divisible(int v, int divisor = 8, int min_value = -1) const
366 {
367 if (min_value < 0)
368 min_value = divisor;
369 int new_v = std::max(min_value, ((int)(((int)(v + divisor / 2)) / divisor)) * divisor);
370 if (new_v < (0.9 * (float)v))
371 new_v += divisor;
372 return new_v;
373 }
374
379 class Conv2dNormActivation : public torch::nn::Module
380 {
381 public:
382 static constexpr char className[] = "Conv2dNormActivation";
383
384 static inline torch::Tensor relu6(const torch::Tensor &x)
385 {
386 return torch::clamp(torch::relu(x), 0, 6);
387 }
388
389 Conv2dNormActivation(int in_channels,
390 int out_channels,
391 int kernel_size = 3,
392 int stride = 1,
393 int padding = -1,
394 int groups = 1)
395 {
396 const int dilation = 1;
397 conv = torch::nn::Sequential();
398 if (padding < 0)
399 {
400 padding = (kernel_size - 1) / 2 * dilation;
401 }
402 conv->push_back(torch::nn::Conv2d(
403 torch::nn::Conv2dOptions(in_channels, out_channels, kernel_size)
404 .stride(stride)
405 .padding(padding)
406 .groups(groups)
407 .bias(false)));
408 conv->push_back(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels)));
409 conv->push_back(torch::nn::Functional(relu6));
410 register_module(className, conv);
411 }
412
413 torch::Tensor forward(const torch::Tensor &x)
414 {
415 return conv->forward(x);
416 }
417
418 private:
419 torch::nn::Sequential conv{nullptr};
420 };
421
426 class InvertedResidual : public torch::nn::Module
427 {
428 public:
429 static constexpr char className[] = "InvertedResidual";
430
431 InvertedResidual(int inp, int oup, int stride, int expand_ratio)
432 {
433 if ((stride < 1) || (stride > 2))
434 {
435 throw std::invalid_argument("Stride needs to be 1 or 2.");
436 }
437 const int hidden_dim = (int)round(inp * expand_ratio);
438 use_res_connect = (stride == 1) && (inp == oup);
439
440 conv = torch::nn::Sequential();
441
442 if (expand_ratio != 1)
443 {
444 conv->push_back(
445 Conv2dNormActivation(inp,
446 hidden_dim,
447 /*kernel_size*/ 1));
448 }
449
450 conv->push_back(
451 Conv2dNormActivation(hidden_dim,
452 hidden_dim,
453 /*kernel_size=*/3,
454 /*stride=*/stride,
455 /*padding=*/-1,
456 /*groups=*/hidden_dim));
457
458 conv->push_back(torch::nn::Conv2d(
459 torch::nn::Conv2dOptions(hidden_dim, oup,
460 /*kernel_size=*/1)
461 .stride(1)
462 .padding(0)
463 .bias(false)));
464 conv->push_back(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(oup)));
465
466 register_module(className, conv);
467 }
468
469 torch::Tensor forward(const torch::Tensor &x)
470 {
471 if (use_res_connect)
472 {
473 return x + conv->forward(x);
474 }
475 else
476 {
477 return conv->forward(x);
478 }
479 }
480 torch::nn::Sequential conv{nullptr};
481 bool use_res_connect;
482 };
483};
Implementation of MobileNetV2 as done in py-torchvision See: // https://github.com/pytorch/vision/blo...
Definition mobilenet_v2.h:32
torch::Tensor forward(torch::Tensor x)
Performs the forward pass.
Definition mobilenet_v2.h:108
int getNinputChannelsOfClassifier() const
Gets the number of input channels of the classifier.
Definition mobilenet_v2.h:279
void load_torchvision_weights(std::string pt)
Loads a .pt weight file containing a dict with key/parameter pairs.
Definition mobilenet_v2.h:154
MobileNetV2(int num_classes=1000, float width_mult=1.0f, int round_nearest=8, float dropout=0.2)
Construct a new MobileNetV2 object.
Definition mobilenet_v2.h:44
void setFeaturesLearning(bool doLearn)
Enables/disables learning in the feature layers.
Definition mobilenet_v2.h:302
torch::nn::Sequential getClassifier() const
Gets the Classifier object.
Definition mobilenet_v2.h:316
void initialize_weights()
Initialize conv/bn/linear similar to torchvision defaults.
Definition mobilenet_v2.h:121
void replaceClassifier(torch::nn::Sequential &newClassifier)
Replaces classifier with a new one.
Definition mobilenet_v2.h:290
static torch::Tensor preprocess(cv::Mat img, bool resizeOnly=false)
Preprocessing of an openCV image for inference or learning.
Definition mobilenet_v2.h:242
static constexpr char classifierModuleName[]
Name of the classifier submodule.
Definition mobilenet_v2.h:100
static constexpr char featuresModuleName[]
Name of the features submodule.
Definition mobilenet_v2.h:94