What Is Early Stopping?
Early stopping is a technique for maximising a model’s generalisability during training, commonly used in deep learning alongside model checkpointing. It typically operates on one of two main strategies:
1. Validation Error Threshold: Training halts when the validation error exceeds a preset threshold.
2. Patience Criterion: Training stops if the validation error fails to improve, or worsens, over a set number of intervals (e.g., epochs); both criteria are sketched in the example below.
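As a concrete illustration, here is a minimal, framework-agnostic sketch of both criteria. The `EarlyStopper` class and its parameter names are illustrative rather than drawn from any particular library:

```python
class EarlyStopper:
    """Minimal sketch of the two early stopping criteria.

    Stops when the validation error exceeds `threshold` (criterion 1),
    or when it has not improved for `patience` consecutive intervals,
    e.g. epochs (criterion 2). Names here are illustrative.
    """

    def __init__(self, threshold=None, patience=5, min_delta=0.0):
        self.threshold = threshold   # validation error ceiling (criterion 1)
        self.patience = patience     # allowed intervals without improvement (criterion 2)
        self.min_delta = min_delta   # minimum change that counts as improvement
        self.best_error = float("inf")
        self.intervals_without_improvement = 0

    def should_stop(self, val_error: float) -> bool:
        # Criterion 1: validation error crossed a preset threshold.
        if self.threshold is not None and val_error > self.threshold:
            return True
        # Criterion 2: validation error failed to improve over `patience` intervals.
        if val_error < self.best_error - self.min_delta:
            self.best_error = val_error
            self.intervals_without_improvement = 0
        else:
            self.intervals_without_improvement += 1
        return self.intervals_without_improvement >= self.patience
```

A training loop would then call `should_stop(val_error)` once per interval and break out of training when it returns `True`.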
Significance of Early Stopping
When training a neural network, the goal is to achieve the best possible performance that generalises beyond the training data. However, if the network is trained for too long, it is prone to learning noise specific to the training dataset.
When the model’s output varies strongly with dataset-specific noise, because it has captured noise rather than signal, the situation is referred to as overfitting (or overtraining). Identifying the optimal point at which to halt training improves model robustness and avoids unnecessary model updates. Overfitting is best illustrated by the sketch below.
The illustration above depicts the theoretical training and validation error curves during supervised neural network training. Assuming that the validation dataset mirrors the real-world dataset, we can consider the validation error a proxy for the generalisation error.
It is important to steer clear of models with substantial generalisation error, indicated by the area highlighted in red. Hence, early stopping techniques are essential for determining the appropriate point at which to stop training.
The complexity of the problem is heightened in real-world scenarios. First, machine learning practitioners widely acknowledge that neural network training on real-world datasets exhibits more intricate, non-monotonic validation error curves. Second, in federated learning, where a global model is collaboratively trained across N datasets from N clients, each client must contribute its optimal model in every aggregation round for the aggregated global model to be optimal.
This scenario is sketched below:
Early Stopping for Federated Learning
Federated learning’s server-client topology can use two early stopping hierarchies:
1. Client Early Stopping: Local training stops when a stopping criterion is met or the local dataset is exhausted, using the early stopping utilities of deep learning frameworks. The resulting local model state is used for all subsequent updates until the round ends.
2. Server Early Stopping: Applies the same approach to the aggregated network, based on an error metric computed on the server-side validation dataset. The final model is selected using the chosen criterion.
Within the client hierarchy, early stopping functions as described earlier, only now at the client-side level. For instance, one can define a threshold on the validation error, set a patience interval, or combine both. These methods are thoroughly documented in deep learning frameworks such as PyTorch Lightning, and can be seamlessly incorporated into the fit method of your customised (OctaiPipe) FL client.
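For example, with PyTorch Lightning the patience criterion can be attached as a callback. Here, `model`, `train_loader`, and `val_loader` are placeholders for your own module and data loaders:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Stop once "val_loss" has failed to improve by at least `min_delta`
# for `patience` consecutive validation checks.
early_stopping = EarlyStopping(
    monitor="val_loss",  # metric logged during validation
    min_delta=0.0,
    patience=3,
    mode="min",
)

trainer = Trainer(max_epochs=100, callbacks=[early_stopping])

# model, train_loader, and val_loader are your own objects.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
```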
Moreover, since client datasets vary in size, it proves advantageous to halt local training once the local dataset has been fully utilised, and to carry this local model state through all subsequent federated learning steps until the round concludes.
On the server side, a similar mechanism is used, only now the condition is imposed on the aggregated model. Once the error exceeds a predefined threshold, further rounds are cancelled and the latest model is deemed optimal for the training session.
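A server-side round loop might look like the following sketch, where `aggregate` and `evaluate_on_server_validation_set` are hypothetical stand-ins for your FL framework’s primitives, and the threshold and patience values are illustrative assumptions:

```python
def run_federated_training(clients, global_model, max_rounds=50,
                           error_threshold=0.35, patience=3):
    """Sketch of server-side early stopping across FL rounds.

    `error_threshold`, `patience`, and the helper functions used below
    are illustrative assumptions, not part of any specific framework.
    """
    best_error = float("inf")
    best_model = global_model
    rounds_without_improvement = 0

    for round_idx in range(max_rounds):
        # Each client trains locally (with its own early stopping) and
        # returns its model update for this round.
        updates = [client.train(global_model) for client in clients]
        global_model = aggregate(updates)  # hypothetical aggregation primitive

        # Score the aggregated model on the server-side validation set.
        val_error = evaluate_on_server_validation_set(global_model)

        if val_error < best_error:
            best_error, best_model = val_error, global_model
            rounds_without_improvement = 0
        else:
            rounds_without_improvement += 1

        # Cancel further rounds once the error exceeds the predefined
        # threshold or stops improving for `patience` rounds.
        if val_error > error_threshold or rounds_without_improvement >= patience:
            break

    return best_model
```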
Conclusion
Early stopping strategies are crucial for enhancing the resilience of FL systems and preventing overfitting. By strategically implementing early stopping in federated learning, organisations can improve model generalisability and efficiency, ensuring optimal performance in collaborative training environments.