Deep Neuronal Filter
Loading...
Searching...
No Matches
dnf_torch.h
1
7#ifndef _DNF_H
8#define _DNF_H
9
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include <deque>
#include <iostream>
#include <thread>
#include <vector>

#include <torch/torch.h>

// Compile-time switch for diagnostic output: true in debug builds, false when
// NDEBUG is defined (the same macro that disables assert()).
// `inline` (C++17) makes this header-defined variable a single entity shared by
// every translation unit instead of one internal copy per TU.
#ifdef NDEBUG
inline constexpr bool debugOutput = false;
#else
inline constexpr bool debugOutput = true;
#endif

29class DNF
30{
31public:
36 {
37 Act_Sigmoid = 1,
38 Act_Tanh = 2,
39 Act_ReLU = 3,
40 Act_NONE = 0
41 };
42
43private:
44 struct Net : public torch::nn::Module
45 {
46 std::vector<torch::nn::Linear> fc;
47 Net(int nLayers, int nInput, bool withBias = false);
48 torch::Tensor forward(torch::Tensor x, ActMethod am);
49 };
50
51public:
61 DNF(const int nLayers,
62 const int nTaps,
63 const ActMethod am = Act_Tanh,
64 const bool tryGPU = false);
65
72 void setLearningRate(float mu);
73
80 float filter(const float signal, const float noise);
81
87 inline int getSignalDelaySteps() const
88 {
89 return signalDelayLineLength;
90 }
91
97 inline float getDelayedSignal() const
98 {
99 return signal_delayLine.get(0);
100 }
101
106 inline float getRemover() const
107 {
108 return remover;
109 }
110
116 inline float getOutput() const
117 {
118 return f_nn;
119 }
120
125 const std::vector<float> getLayerWeightDistances() const;
126
131 float getWeightDistance() const;
132
137 const torch::Device getTorchDevice() const
138 {
139 return device;
140 }
141
145 const Net getModel() const
146 {
147 return model;
148 }
149
153 static constexpr double xavierGain = 0.01;
154
155private:
156 class DelayLine
157 {
158 public:
159 void init(int delay)
160 {
161 delaySamples = delay;
162 buffer = std::deque<float>(delaySamples, 0.0f);
163 }
164
165 inline float process(float input)
166 {
167 float output = buffer.front();
168 buffer.pop_front();
169 buffer.push_back(input);
170 return output;
171 }
172
173 float get(int i) const
174 {
175 return buffer[i];
176 }
177
178 float getNewest() const
179 {
180 return buffer.back();
181 }
182
183 private:
184 int delaySamples = 0;
185 std::deque<float> buffer;
186 };
187
188 void saveInitialParameters()
189 {
190 for (const auto &p : model.parameters())
191 {
192 initialParameters.push_back(p.detach().clone());
193 }
194 }
195
196 const int noiseDelayLineLength;
197 const int signalDelayLineLength;
198 const ActMethod actMethod;
199 Net model;
200 torch::optim::SGD optimizer;
201 std::vector<torch::Tensor> initialParameters;
202 DelayLine signal_delayLine;
203 DelayLine noise_delayLine;
204 float remover = 0;
205 float f_nn = 0;
206 torch::Device device = torch::kCPU;
207};
208
209#endif
Deep Neuronal Filter https://journals.plos.org/plosone/article?id=10.1371/journal....
Definition dnf_torch.h:30
float getOutput() const
Returns the output of the DNF: the noise-free signal.
Definition dnf_torch.h:116
float filter(const float signal, const float noise)
Realtime sample by sample filtering operation.
float getWeightDistance() const
Gets the overall weight distance.
void setLearningRate(float mu)
Sets the learning rate of the entire network.
ActMethod
Options for activation functions of all neurons in the network.
Definition dnf_torch.h:36
int getSignalDelaySteps() const
Returns the length of the delay line which delays the signal polluted with noise.
Definition dnf_torch.h:87
DNF(const int nLayers, const int nTaps, const ActMethod am=Act_Tanh, const bool tryGPU=false)
Constructor which sets up the delay lines, network layers and also calculates the number of neurons p...
float getRemover() const
Returns the remover signal.
Definition dnf_torch.h:106
const std::vector< float > getLayerWeightDistances() const
Gets the weight distances per layer.
static constexpr double xavierGain
Xavier gain for the weight init.
Definition dnf_torch.h:153
const Net getModel() const
Gets the torch model, for example, to read out the weights.
Definition dnf_torch.h:145
const torch::Device getTorchDevice() const
Gets the torch device, for example to determine whether the GPU is being used.
Definition dnf_torch.h:137
float getDelayedSignal() const
Returns the noise-polluted signal, delayed by the number of steps indicated by getSignalDelaySteps().
Definition dnf_torch.h:97