Building a feedforward neural network in C++: base class
In my implementation of a feedforward neural network in C++, the base class ANN_MLP serves as the foundation for constructing and training the network. This post covers the design and functionality of this class, which encapsulates the core components needed to manage the neural network efficiently.
Class overview
The ANN_MLP class is within the nn namespace and provides essential methods for managing network parameters, such as weights and biases, which are crucial to the network’s ability to learn and generalize from data. To enable flexibility, I have included methods like PrintNetworkInfo, PrintBiases, and PrintWeights, which help inspect the current state of the network at any point during training.
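To make the shape of the class concrete, here is a minimal sketch of its interface; the storage layout and member declarations below are illustrative assumptions, not the exact code from the repository:

#include <cstddef>
#include <vector>

namespace nn {

template <typename T>
class ANN_MLP {
public:
    // Inspect the current state of the network
    void PrintNetworkInfo();
    void PrintBiases();
    void PrintWeights();

private:
    std::vector<std::vector<T>> weights; // one weight matrix per layer (illustrative layout)
    std::vector<std::vector<T>> biases;  // one bias vector per layer (illustrative layout)
    size_t nEpochs{0};                   // total number of training epochs
};

} // namespace nn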
Core functionalities
At its core, the ANN_MLP class supports the ability to store, serialize, and deserialize the network state. This feature is particularly useful when training large models, where saving intermediate states is crucial for resuming training without starting from scratch. By serializing the network, I can save its current state to a file, and with deserialization, I can restore the network later for further training or testing.
void Serialize(const std::string& fname);   // write the network state to the file fname
void Deserialize(const std::string& fname); // restore the network state from the file fname
These two methods allow for persistent storage of the model in files, making it practical to store complex configurations and retrieve them as needed. Serialization and deserialization are essential for long-term projects or when working with large datasets that require multiple training sessions.
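A typical round trip might look like this; the template parameter, construction details, and file name are assumptions for illustration:

nn::ANN_MLP<double> net;
// ... configure and train the network ...
net.Serialize("mlp_state.dat"); // persist the current weights and biases

nn::ANN_MLP<double> restored;
restored.Deserialize("mlp_state.dat"); // reload the saved state for further training or testing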
Epoch management
Another vital aspect of the ANN_MLP class is its support for managing the number of training epochs. The methods SetEpochs and UpdateEpochs allow me to set the total number of epochs and dynamically update them as training progresses. This flexibility ensures that the training process can be adjusted on the fly, depending on the network’s performance.
void SetEpochs(size_t n) { nEpochs = n; }         // set the total number of training epochs
void UpdateEpochs(size_t n = 1) { nEpochs += n; } // extend training by n additional epochs
By using these methods, I can easily control the length of the training process, making it more adaptive to the complexity of the problem and ensuring that the model trains for the optimal number of iterations.
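In practice this might look like the following sketch; the training and evaluation steps are assumptions for illustration:

nn::ANN_MLP<double> net;  // as before, the template parameter is an assumption
net.SetEpochs(100);       // plan an initial 100 training epochs
// ... train, then check the validation error ...
net.UpdateEpochs(50);     // still improving: extend to 150 epochs without restarting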
Activation functions
Non-linear activation functions are crucial for neural networks to learn complex patterns. In my implementation, the ANN_MLP class supports both the sigmoid and tanh activation functions. These are implemented in both scalar and vectorized forms, ensuring efficiency when applying these functions across matrices during forward propagation.
// T_C(x) is assumed to be a cast-to-T helper macro used throughout the implementation
template <typename T> inline T sigmoid(T x)
{
    return T_C(1) / (T_C(1) + std::exp(-x));
}

template <typename T> inline T tanh(T x)
{
    return std::tanh(x);
}
The forward propagation uses these activation functions to introduce non-linearity into the network, allowing it to model complex data. Additionally, the derivatives of these functions are used during backpropagation to compute gradients for stochastic gradient descent (SGD), which updates the network’s weights during training.
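For reference, both derivatives can be expressed in terms of the activation values themselves, which is why storing the forward-pass outputs is enough for backpropagation. A sketch in the same style as the functions above; the function names are illustrative, and T_C is again assumed to be a cast-to-T helper macro:

// With s = sigmoid(x): d/dx sigmoid(x) = s * (1 - s)
template <typename T> inline T sigmoid_derivative(T s)
{
    return s * (T_C(1) - s);
}

// With t = tanh(x): d/dx tanh(x) = 1 - t^2
template <typename T> inline T tanh_derivative(T t)
{
    return T_C(1) - t * t;
}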
Random number generation
The ANN_MLP class also integrates random number generators for initializing weights and biases. I use inline methods for generating random numbers from uniform and normal distributions, ensuring that the network starts with diverse initial parameters. These random initializations are key to breaking the symmetry in neural networks and ensuring that different neurons learn different features from the data.
inline T GetRandomNormal() { return normal_distribution(generator); }                            // sample from a normal distribution
inline T GetRandomUniformReal() { return static_cast<T>(uniform_real_distribution(generator)); } // sample from a uniform real distribution
inline int GetRandomUniformInt() { return uniform_int_distribution(generator); }                 // sample from a uniform integer distribution
These randomization methods help maintain flexibility in setting up the neural network and avoid cases where neurons would be initialized with identical weights.
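These accessors imply an engine and distribution members on the class template; one possible set of member declarations is sketched below, where the engine choice, distribution parameters, and seeding are assumptions for illustration:

#include <random> // for the engine and distribution types

std::mt19937_64 generator{std::random_device{}()};               // Mersenne Twister engine, nondeterministically seeded
std::normal_distribution<T> normal_distribution{T_C(0), T_C(1)}; // mean 0, standard deviation 1
std::uniform_real_distribution<double> uniform_real_distribution{0.0, 1.0};
std::uniform_int_distribution<int> uniform_int_distribution{0, 99};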
Conclusion
The ANN_MLP class provides a robust foundation for building and managing feedforward neural networks in C++. By encapsulating network parameters, offering methods for data persistence, and supporting dynamic training processes, this class allows me to focus on higher-level aspects of network design and performance optimization.
For more insights into this topic, you can find the details here.
The complete code for this base class is available on GitHub here.