- This is my second journal paper, titled *Speeding Up Local SGD with Straggler-Tolerant Synchronization*.
- The paper was published in IEEE Transactions on Signal Processing (TSP) in 2024.
-
About this paper:
- Focuses on distributed/federated learning with synchronous local SGD.
- Aims to improve robustness of federated systems against stragglers.
- Proposes a novel local SGD strategy, STSyn, with the following key features:
- Waits for the $K$ fastest workers while ensuring continuous computation for all workers.
- Utilizes all effective (completed) local updates from every worker, even with stragglers.
- Provides rigorous convergence rates for nonconvex objectives under both homogeneous and heterogeneous data distributions.
- Validates the algorithm through simulations and investigates the impact of system hyperparameters.
- The system consists of $M$ workers, each performing $U$ local updates per round.
- The server waits for the $K$-th fastest worker to finish $U$ updates.
- Key Concept: No worker stops computing until the $K$-th fastest one completes $U$ updates.
- Workers that have completed at least one update upload their models to the server for aggregation.
- Example: $M=4$, $U=3$, $K=3$.
  - Workers 1, 2, and 3 are the fastest $K=3$ workers to complete $U=3$ updates in round 0.
  - Red arrows: additional updates performed by the fastest $K-1=2$ workers.
  - Light blue arrows: straggling updates that are cancelled.
  - All 4 workers upload their models, as each completes at least one update.
Below is the pseudocode for the STSyn algorithm:
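As a rough illustration (not the paper's pseudocode, which appears as a figure), the round structure described above can be simulated in a few lines of Python. All names are illustrative, and the i.i.d. exponential single-update times are the modeling assumption used later for the closed-form analysis:

```python
import random

def stsyn_round(M=4, U=3, K=3, rate=1.0, max_updates=50, seed=0):
    """Simulate one STSyn round with i.i.d. Exp(rate) single-update times.
    Returns the round's wall-clock time, each worker's number of effective
    (completed) updates, and the indices of the uploading workers."""
    rng = random.Random(seed)
    # Cumulative completion times of successive local updates per worker.
    finish = []
    for _ in range(M):
        t, times = 0.0, []
        for _ in range(max_updates):  # cap on sampled updates (illustration only)
            t += rng.expovariate(rate)
            times.append(t)
        finish.append(times)
    # The round ends once the K-th fastest worker has completed U updates.
    t_round = sorted(w[U - 1] for w in finish)[K - 1]
    # No worker stops early: count every update completed by t_round,
    # including the extra updates of the K-1 fastest workers.
    updates = [sum(1 for t in w if t <= t_round) for w in finish]
    # Workers with at least one completed update upload their models.
    uploaders = [i for i, u in enumerate(updates) if u >= 1]
    return t_round, updates, uploaders
```

Running `stsyn_round()` with the example values $M=4$, $U=3$, $K=3$ reproduces the behavior above: at least $K$ workers log $U$ or more updates, and every worker with at least one completed update uploads.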
- Assuming the time for a single local update by each worker follows an exponential distribution, we provide closed-form expressions for:
- The average wall-clock time per round.
- The average number of local updates per worker per round.
- The average number of uploading workers per round.
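The general closed-form expressions are in the paper; as a small sanity check, in the special case $U=1$ the average wall-clock time per round reduces to the mean of the $K$-th order statistic of $M$ i.i.d. exponential variables, a standard identity (the function name below is illustrative):

```python
from fractions import Fraction

def expected_round_time_U1(M, K, rate=1):
    """Average wall-clock time per round in the special case U = 1.
    With one Exp(rate) update per worker, the round ends at the K-th
    order statistic of M i.i.d. exponentials, whose mean is
    (1/rate) * sum_{j=0}^{K-1} 1/(M - j).  (Standard identity; the
    paper's general-U expressions are more involved.)"""
    return sum(Fraction(1, M - j) for j in range(K)) / rate

# Example: M = 4 workers, waiting for the K = 3rd fastest:
# 1/4 + 1/3 + 1/2 = 13/12 time units per round on average.
```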
Heterogeneous Data Distributions:
- The average expected squared gradient norm for nonconvex objectives is upper bounded by
$$O\left(\frac{1}{\sqrt{K\bar{U} J}} + \frac{K}{\bar{U}^3 J}\right)$$
where:
- $K$: number of workers the server waits for.
- $\bar{U}$: average number of local updates per worker per round.
- $J$: total number of communication rounds.
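To see the trade-off this bound implies, one can evaluate its two terms directly (constants are suppressed, so only the trends in $K$ are meaningful; the $\bar{U}$ and $J$ values below are arbitrary):

```python
def bound_terms(K, U_bar, J):
    """Two terms of the O(1/sqrt(K*U*J) + K/(U^3*J)) bound, constants
    suppressed: the first shrinks as K grows, the second grows linearly."""
    return (K * U_bar * J) ** -0.5, K / (U_bar ** 3 * J)

# Illustrative scan over K: the opposing trends hint at a trade-off
# in the choice of K (see the hyperparameter study below).
for K in (1, 2, 4, 8):
    t1, t2 = bound_terms(K, U_bar=5.0, J=1000)
    print(f"K={K}: {t1:.2e} + {t2:.2e}")
```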
Homogeneous Data Distributions:
- The convergence rate is the same as in the heterogeneous case:
$$O\left(\frac{1}{\sqrt{K\bar{U} J}} + \frac{K}{\bar{U}^3 J}\right)$$
- Numerical experiments validate that the STSyn algorithm achieves superior time and communication efficiency under both i.i.d. and non-i.i.d. data distributions among workers.
For more details, refer to the full paper on IEEE Xplore.
- `Comparison.py`: Compares STSyn with state-of-the-art algorithms on the CIFAR-10 dataset using a three-layer CNN.
- `impact_K`: Explores the impact of the hyperparameter $K$.
- `impact_U`: Explores the impact of the hyperparameter $U$.