// MobileNetV2 constructor: builds the backbone ("features") and the
// Dropout+Linear head ("classifier"), apparently mirroring torchvision's
// Python implementation. NOTE(review): this is a Doxygen code extraction --
// the numbers fused into some lines are original-file line numbers, and
// several interior lines (braces, push_back(...) argument tails) are
// missing from this view.
//
// num_classes   -- output size of the final Linear layer
// width_mult    -- multiplier applied to per-stage channel counts
// round_nearest -- channel counts are rounded to a multiple of this
// dropout       -- dropout probability in the classifier head
44 MobileNetV2(
int num_classes = 1000,
float width_mult = 1.0f,
int round_nearest = 8,
float dropout = 0.2)
// Stem width, scaled by width_mult and rounded to a divisor-friendly value.
46 int input_channels = 32;
47 input_channels = make_divisible(input_channels * width_mult, round_nearest);
// The final feature width (default 1280) only grows when width_mult > 1.
48 features_output_channels = make_divisible(features_output_channels * std::max(1.0f, width_mult), round_nearest);
50 features = torch::nn::Sequential();
// Stem: 3-channel (RGB) input conv; remaining arguments truncated here.
53 Conv2dNormActivation(3,
// One stage per row of inverted_residual_setting; presumably each row
// unpacks to {t, c, n, s} as in torchvision -- confirm in the full file.
59 for (
const auto &cfg : inverted_residual_setting)
66 int output_channel = make_divisible(c * width_mult, round_nearest);
67 for (
int i = 0; i < n; ++i)
// Only the first block of a stage downsamples.
69 const int stride = (i == 0) ? s : 1;
71 InvertedResidual(input_channels, output_channel, stride, t));
72 input_channels = output_channel;
// Final pointwise conv lifting to features_output_channels.
77 Conv2dNormActivation(input_channels,
78 features_output_channels,
// Classifier head: Dropout followed by a Linear layer to num_classes.
84 classifier = torch::nn::Sequential();
85 classifier->push_back(torch::nn::Dropout(torch::nn::DropoutOptions(dropout)));
86 classifier->push_back(torch::nn::Linear(torch::nn::LinearOptions(features_output_channels, num_classes)));
// MobileNetV2::forward body (signature not visible in this view):
// backbone features -> global average pool to 1x1 -> flatten -> classifier.
110 x = features->forward(x);
// NOTE(review): binding the options temporary to a const reference is legal
// (lifetime extension) but unusual; a plain value would read more simply.
111 const torch::nn::functional::AdaptiveAvgPool2dFuncOptions &ar = torch::nn::functional::AdaptiveAvgPool2dFuncOptions({1, 1});
112 x = torch::nn::functional::adaptive_avg_pool2d(x, ar);
// Collapse (N, C, 1, 1) to (N, C) before the Linear head.
113 x = torch::flatten(x, 1);
114 x = classifier->forward(x);
// Weight initialization, matching torchvision's MobileNetV2 scheme:
// Conv2d -> Kaiming-normal (fan-out, relu) with zero bias;
// BatchNorm2d -> ones/zeros; Linear -> N(0, 0.01) with zero bias.
// modules(false) iterates all submodules, excluding this module itself.
123 for (
auto &module : modules(
false))
125 if (
auto M =
dynamic_cast<torch::nn::Conv2dImpl *
>(module.get()))
127 torch::nn::init::kaiming_normal_(M->weight, 0, torch::kFanOut, torch::kReLU);
// Conv layers followed by BN are typically bias-free, so the bias is
// only touched when it actually exists.
128 if (M->options.bias())
129 torch::nn::init::zeros_(M->bias);
131 else if (
auto M =
dynamic_cast<torch::nn::BatchNorm2dImpl *
>(module.get()))
133 torch::nn::init::ones_(M->weight);
134 torch::nn::init::zeros_(M->bias);
136 else if (
auto M =
dynamic_cast<torch::nn::LinearImpl *
>(module.get()))
138 torch::nn::init::normal_(M->weight, 0.0, 0.01);
139 torch::nn::init::zeros_(M->bias);
// Load pretrained weights from a torchvision-style pickle file `pt`.
// Steps: read the whole file into memory, unpickle it into a tensor dict,
// dump both sides' keys for debugging, then copy matching tensors into this
// model's parameters and buffers (under no_grad), using ourkey2torchvision()
// to translate key names.
156 std::ifstream input(pt, std::ios::binary);
// Throw on stream failure instead of silently reading garbage.
157 input.exceptions(input.failbit);
158 std::vector<char> bytes(
159 (std::istreambuf_iterator<char>(input)),
160 (std::istreambuf_iterator<char>()));
162 const c10::Dict<c10::IValue, c10::IValue> weights = torch::pickle_load(bytes).toGenericDict();
// --- Debug dumps -------------------------------------------------------
165 std::cerr <<
"Parameters we have in this model here: " << std::endl;
166 for (
auto const &m : named_parameters())
168 auto k = ourkey2torchvision(m.key());
169 std::cerr << m.key() <<
"->" << k <<
": " << m.value().sizes() << std::endl;
171 std::cerr <<
"Named buffers we have in this model here: " << std::endl;
172 for (
const auto &b : named_buffers())
174 auto k = ourkey2torchvision(b.key());
// NOTE(review): this one dump goes to std::cout while every other message
// in this function uses std::cerr -- probably unintentional; confirm.
175 std::cout << b.key() <<
"->" << k <<
": " << b.value().sizes() << std::endl;
177 std::cerr <<
"Parameters we have in the weight file " << pt <<
":" << std::endl;
178 for (
auto const &w : weights)
180 std::cerr << w.key() <<
": " << w.value().toTensor().sizes() << std::endl;
// --- Actual loading ----------------------------------------------------
// Disable autograd tracking while overwriting tensors in place.
183 torch::NoGradGuard no_grad;
185 std::cerr <<
"Loading weights" << std::endl;
186 for (
auto &m : named_parameters())
188 const std::string model_key = m.key();
189 const std::string model_key4torchvision = ourkey2torchvision(model_key);
191 std::cerr <<
"Searching for: " << model_key4torchvision <<
": " << m.value().sizes() << std::endl;
192 bool foundit =
false;
// Linear scan over the weight dict for each model key (O(n*m), which is
// fine for a one-off load of a few hundred entries).
193 for (
auto const &w : weights)
195 if (model_key4torchvision == w.key())
198 std::cerr <<
"Found it: " << w.key() << std::endl;
// copy_ writes the pretrained values into the existing parameter tensor.
199 m.value().copy_(w.value().toTensor());
// NOTE(review): the foundit = true / break lines and the if (!foundit)
// guard are missing from this extraction; the message below presumably
// only fires when no key matched.
205 std::cerr <<
"Key: " << model_key4torchvision <<
" could not be found!" << std::endl;
// Same procedure for non-parameter state (e.g. BatchNorm running stats).
208 std::cerr <<
"Loading named buffers" << std::endl;
209 for (
auto &b : named_buffers())
211 std::string model_key = b.key();
212 std::string model_key4torchvision = ourkey2torchvision(model_key);
214 std::cerr <<
"Searching for: " << model_key4torchvision <<
": " << b.value().sizes() << std::endl;
215 bool foundit =
false;
216 for (
auto const &w : weights)
218 if (model_key4torchvision == w.key())
221 std::cerr <<
"Found it: " << w.key() << std::endl;
222 b.value().copy_(w.value().toTensor());
228 std::cerr <<
"Key: " << model_key4torchvision <<
" could not be found!" << std::endl;
// Preprocess an 8-bit, 3-channel BGR OpenCV image into a float CHW tensor.
// Pipeline (standard ImageNet eval transform): resize (directly to 224 when
// resizeOnly, otherwise to 256 + center-crop to 224), BGR->RGB, scale to
// [0,1], then per-channel mean/std normalization. NOTE(review): the if/else
// selecting between the two resize paths and the final `return tensor;` are
// missing from this extraction.
242 static torch::Tensor
preprocess(cv::Mat img,
bool resizeOnly =
false)
244 constexpr int imageSizeBeforeCrop = 256;
245 constexpr int finalImageSize = 224;
246 constexpr int numChannels = 3;
// Validate the input format up front with clear error messages.
248 if (img.depth() != CV_8U)
249 throw std::invalid_argument(
"Image is not 8bit.");
250 if (img.channels() != numChannels)
251 throw std::invalid_argument(
"Image is not BGR / colour.");
// resizeOnly path: straight resize to the network input size.
255 cv::resize(img, img, cv::Size(finalImageSize, finalImageSize));
// Default path: resize to 256 then take the centered 224x224 crop.
259 cv::resize(img, img, cv::Size(imageSizeBeforeCrop, imageSizeBeforeCrop));
260 constexpr int start = (imageSizeBeforeCrop - finalImageSize) / 2;
261 const cv::Rect roi(start, start, finalImageSize, finalImageSize);
// clone() makes the crop contiguous so from_blob sees packed HWC data.
262 img = img(roi).clone();
// The model expects RGB channel ordering (torchvision convention).
264 cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
// Wrap the pixel buffer (HWC, uint8), then convert to CHW float in [0,1].
266 torch::Tensor tensor = torch::from_blob(img.data, {img.rows, img.cols, 3}, torch::kByte);
267 tensor = tensor.permute({2, 0, 1}).to(torch::kFloat).div_(255.0);
// ImageNet per-channel mean / standard deviation.
268 tensor = torch::data::transforms::Normalize({0.485, 0.456, 0.406}, {0.229, 0.224, 0.225})(tensor);
// Fragments of three small member functions (signatures and braces are
// missing from this extraction):
// (1) getter returning the channel count feeding the classifier -- useful
//     when constructing a replacement head for transfer learning;
281 return features_output_channels;
// (2) setter swapping in a user-supplied classifier head;
292 classifier = newClassifier;
// (3) freeze/unfreeze of the backbone: toggles requires_grad on every
//     feature-extractor parameter.
304 for (
auto &p : features->parameters())
305 p.requires_grad_(doLearn);
// Model components and configuration.
// classifier: Dropout + Linear head (replaceable for transfer learning).
325 torch::nn::Sequential classifier{
nullptr};
// features: stem conv + inverted-residual stages + final pointwise conv.
330 torch::nn::Sequential features{
nullptr};
// Channel count produced by the last feature layer (1280 by default;
// scaled in the constructor when width_mult > 1).
333 int features_output_channels = 1280;
// Per-stage settings; presumably rows of {expand_ratio t, channels c,
// repeats n, stride s} as in torchvision -- the initializer list itself
// is missing from this view.
337 const std::vector<std::array<int, 4>> inverted_residual_setting = {
// Translate one of this model's parameter/buffer names into the key used in
// torchvision's MobileNetV2 state_dict, so pretrained weights can be matched.
// Two rewrites are visible: "InvertedResidual" -> "conv", and removal of any
// "Conv2dNormActivation." path component. NOTE(review): the `return k;` is
// missing from this extraction.
354 std::string ourkey2torchvision(std::string k)
const
357 k = std::regex_replace(k, std::regex(InvertedResidual::className),
"conv");
359 const std::string r = std::string(Conv2dNormActivation::className) +
"\\.";
360 k = std::regex_replace(k, std::regex(r),
"");
365 inline int make_divisible(
int v,
int divisor = 8,
int min_value = -1)
const
369 int new_v = std::max(min_value, ((
int)(((
int)(v + divisor / 2)) / divisor)) * divisor);
370 if (new_v < (0.9 * (
float)v))
// Conv2d + BatchNorm2d + ReLU6 building block, analogous to torchvision's
// Conv2dNormActivation.
379 class Conv2dNormActivation :
public torch::nn::Module
// className doubles as the registered submodule name and as the path
// component stripped out by MobileNetV2::ourkey2torchvision().
382 static constexpr char className[] =
"Conv2dNormActivation";
384 static inline torch::Tensor relu6(
const torch::Tensor &x)
386 return torch::clamp(torch::relu(x), 0, 6);
// Constructor: builds conv -> batchnorm -> ReLU6. NOTE(review): several
// parameter lines and the Conv2dOptions argument chain are missing from
// this extraction.
389 Conv2dNormActivation(
int in_channels,
396 const int dilation = 1;
397 conv = torch::nn::Sequential();
// "Same"-style padding for odd kernel sizes.
400 padding = (kernel_size - 1) / 2 * dilation;
402 conv->push_back(torch::nn::Conv2d(
403 torch::nn::Conv2dOptions(in_channels, out_channels, kernel_size)
408 conv->push_back(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels)));
409 conv->push_back(torch::nn::Functional(relu6));
// Registered under className so parameter names line up with the rewrites
// in ourkey2torchvision().
410 register_module(className, conv);
// Apply the conv / batchnorm / activation stack.
413 torch::Tensor forward(
const torch::Tensor &x)
415 return conv->forward(x);
// The underlying sequential container; populated in the constructor.
419 torch::nn::Sequential conv{
nullptr};
// MobileNetV2 inverted-residual block: optional pointwise expansion, a
// middle Conv2dNormActivation (presumably the depthwise conv -- its argument
// list is truncated here), then a linear pointwise projection, with a skip
// connection when resolution and width are preserved. NOTE(review): several
// argument lists and braces are missing from this extraction.
426 class InvertedResidual :
public torch::nn::Module
429 static constexpr char className[] =
"InvertedResidual";
// inp/oup: input/output channels; stride: must be 1 or 2; expand_ratio:
// hidden-layer expansion factor (t in the MobileNetV2 paper).
431 InvertedResidual(
int inp,
int oup,
int stride,
int expand_ratio)
433 if ((stride < 1) || (stride > 2))
435 throw std::invalid_argument(
"Stride needs to be 1 or 2.")
437 const int hidden_dim = (int)round(inp * expand_ratio);
// Residual shortcut only when the block keeps both resolution and width.
438 use_res_connect = (stride == 1) && (inp == oup);
440 conv = torch::nn::Sequential();
// expand_ratio == 1 (first stage) skips the pointwise expansion layer.
442 if (expand_ratio != 1)
445 Conv2dNormActivation(inp,
451 Conv2dNormActivation(hidden_dim,
// Linear (no activation) projection back down to oup channels.
458 conv->push_back(torch::nn::Conv2d(
459 torch::nn::Conv2dOptions(hidden_dim, oup,
464 conv->push_back(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(oup)));
466 register_module(className, conv);
// Forward: adds the input back when use_res_connect is set.
469 torch::Tensor forward(
const torch::Tensor &x)
473 return x + conv->forward(x);
477 return conv->forward(x);
480 torch::nn::Sequential conv{
nullptr};
481 bool use_res_connect;
MobileNetV2(int num_classes=1000, float width_mult=1.0f, int round_nearest=8, float dropout=0.2)
Construct a new MobileNetV2 object.
Definition mobilenet_v2.h:44
static torch::Tensor preprocess(cv::Mat img, bool resizeOnly=false)
Preprocessing of an OpenCV image for inference or learning.
Definition mobilenet_v2.h:242