
Commit bc2540f

gcramer23 authored and facebook-github-bot committed on May 8, 2021
benchmark rpc ps (pytorch#57454)
Summary: Pull Request resolved: pytorch#57454

DDP with NCCL AllReduce for the entire model; experiment from Quip https://fb.quip.com/iQUtAeKIxWpF

I have been testing this on the AI cluster. There seem to be some connection problems with RPC when using multiple trainers or parameter servers.

```
Namespace(bconfig_id='3', dconfig_id='DummyData', mconfig_id='DummyModel', pconfig_id='None', tconfig_id='DdpNcclTrainer')
benchmark warmup done

metrics for trainer=0
+-----------------------------------+----------+---------+----------+------------+-----------+
| name                              | min      | max     | mean     | variance   | stdev     |
+===================================+==========+=========+==========+============+===========+
| backward_metric,backward          | 2.45248  | 4.18304 | 3.972    | 0.097122   | 0.311644  |
+-----------------------------------+----------+---------+----------+------------+-----------+
| batch_level_metric,batch_all      | 4.11955  | 4.58138 | 4.31439  | 0.00229848 | 0.0479424 |
+-----------------------------------+----------+---------+----------+------------+-----------+
| foward_metric,forward_pass        | 0.141312 | 1.4807  | 0.222566 | 0.0555432  | 0.235676  |
+-----------------------------------+----------+---------+----------+------------+-----------+
| hook_future_metric,nccl_allreduce | 0.191488 | 3.54099 | 3.11694  | 0.557106   | 0.746395  |
+-----------------------------------+----------+---------+----------+------------+-----------+

metrics for trainer=1
+-----------------------------------+----------+---------+----------+-------------+------------+
| name                              | min      | max     | mean     | variance    | stdev      |
+===================================+==========+=========+==========+=============+============+
| backward_metric,backward          | 2.4617   | 2.59174 | 2.51196  | 0.000938276 | 0.0306313  |
+-----------------------------------+----------+---------+----------+-------------+------------+
| batch_level_metric,batch_all      | 4.22605  | 4.71757 | 4.27921  | 0.00468424  | 0.0684415  |
+-----------------------------------+----------+---------+----------+-------------+------------+
| foward_metric,forward_pass        | 0.807936 | 1.50118 | 0.846008 | 0.00601693  | 0.0775688  |
+-----------------------------------+----------+---------+----------+-------------+------------+
| hook_future_metric,nccl_allreduce | 0.108544 | 0.1536  | 0.11222  | 2.16726e-05 | 0.00465538 |
+-----------------------------------+----------+---------+----------+-------------+------------+

metrics for all trainer
+-----------------------------------+----------+---------+----------+------------+-----------+
| name                              | min      | max     | mean     | variance   | stdev     |
+===================================+==========+=========+==========+============+===========+
| backward_metric,backward          | 2.45248  | 4.18304 | 3.24198  | 0.584391   | 0.764455  |
+-----------------------------------+----------+---------+----------+------------+-----------+
| batch_level_metric,batch_all      | 4.11955  | 4.71757 | 4.2968   | 0.00378467 | 0.0615197 |
+-----------------------------------+----------+---------+----------+------------+-----------+
| foward_metric,forward_pass        | 0.141312 | 1.50118 | 0.534287 | 0.128284   | 0.358167  |
+-----------------------------------+----------+---------+----------+------------+-----------+
| hook_future_metric,nccl_allreduce | 0.108544 | 3.54099 | 1.61458  | 2.5456     | 1.59549   |
+-----------------------------------+----------+---------+----------+------------+-----------+
```

Test Plan: Imported from OSS

Reviewed By: H-Huang, ngimel

Differential Revision: D28296175

Pulled By: gcramer23

fbshipit-source-id: 5dd208fc86f8b5558d7c8860d685bb25c2e09fe7
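For context, the Namespace line above maps onto the configuration files added in this commit: bconfig_id='3' selects the "3" entry in configurations/benchmark_configurations.json (2 trainers, no parameter server, batch_size 5, rpc_async_timeout 15), and the remaining ids select the DummyData, DummyModel, and DdpNcclTrainer entries in their respective configuration files.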
1 parent 94080f4 commit bc2540f

21 files changed: 1,071 additions, 0 deletions
 
BenchmarkConfigurations.py (new file, +15)

```
from dataclasses import dataclass


@dataclass
class BenchmarkConfigurations:
    trainer_count: int = 1
    ps_count: int = 0
    batch_size: int = 1
    print_metrics_to_dir: bool = False
    master_addr: str = "localhost"
    master_port: str = "29500"
    rpc_async_timeout: int = 5
    rpc_init_method: str = "tcp://localhost:29501"
    trainer_config: dict = None
    ps_config: dict = None
```
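For orientation, a minimal sketch of how launcher.py (added below) ends up materializing this dataclass; the dict mirrors the "3" entry in benchmark_configurations.json plus the DdpNcclTrainer entry from trainer_configurations.json, with defaults filling the rest (run from the benchmark directory):

```
from BenchmarkConfigurations import BenchmarkConfigurations

benchmark_config = {
    "trainer_count": 2,
    "ps_count": 0,
    "rpc_async_timeout": 15,
    "batch_size": 5,
    "trainer_config": {"trainer_class": "DdpNcclTrainer", "configurations": {"epochs": 10}},
    "ps_config": None,
}
config = BenchmarkConfigurations(**benchmark_config)
assert config.master_addr == "localhost"  # unset fields keep their defaults
```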
README.md (new file, +56)

# RPC PS Benchmark

## How to add your experiment

1. Data
   - Create a data class and add it to the data directory
   - Update benchmark_class_helper.py to include your data class in the data_map
   - Add configurations to data_configurations.json in the configurations directory
2. Model
   - Create a model class and add it to the model directory
   - Update benchmark_class_helper.py to include your model class in the model_map
   - Add configurations to model_configurations.json in the configurations directory
3. Trainer
   - Create a trainer class and add it to the trainer directory
   - Update benchmark_class_helper.py to include your trainer class in the trainer_map
   - Add configurations to trainer_configurations.json in the configurations directory
4. Parameter Server
   - Create a parameter server class and add it to the parameter_servers directory
   - Update benchmark_class_helper.py to include your parameter_server class in the ps_map
   - Add configurations to parameter_server_configurations.json in the configurations directory
5. Script
   - Create a bash script for your experiment and add it to the bash_experiment_scripts directory

## Trainer class

The trainer directory contains base classes that provide a starting point for implementing a trainer.
Inherit from a base class and implement your trainer. The benchmark has two requirements for trainers.

1. It must implement an `__init__` method whose leading arguments are rank, trainer_count, and ps_rref

```python
def __init__(self, rank, trainer_count, ps_rref, backend, use_cuda_rpc):
```

2. It must implement a train method that takes model and data as arguments.

```python
def train(self, model, data):
```

## Parameter Server class

The parameter_server directory contains base classes that provide a starting point for implementing a parameter server.
Inherit from a base class and implement your parameter server. The benchmark has two requirements for parameter servers.

1. It must implement an `__init__` method whose leading arguments are rank and ps_trainer_count

```python
def __init__(self, rank, ps_trainer_count, backend, use_cuda_rpc):
```

2. It must implement a reset_state method

```python
def reset_state(ps_rref):
```
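A minimal trainer sketch satisfying the two requirements above, using the TrainerBase helpers added later in this commit; the loop body is illustrative only, not part of the benchmark:

```
from trainers.TrainerBase import TrainerBase


class MyTrainer(TrainerBase):
    def __init__(self, rank, trainer_count, ps_rref, backend, use_cuda_rpc):
        super().__init__(rank)
        self.rank = rank
        self.trainer_count = trainer_count
        self.ps_rref = ps_rref

    def train(self, model, data):
        for index, (input, target) in enumerate(data):
            # record per-batch timings through the TrainerBase helpers
            self.record_batch_start(str(index), cuda=False)
            # forward / backward / optimizer step would go here
            self.record_batch_end(str(index))
```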
Experiment script in bash_experiment_scripts/ (new file, +13)

```
#!/bin/sh

# requires slurm
# configuration ids
benchmark=3
data="DummyData"
model="DummyModel"
trainer="DdpNcclTrainer"
server="None"
# move to the benchmark directory and run with the selected configurations
cd "$(dirname "$(dirname "$0")")"
source ./bash_experiment_scripts/helper_functions.sh
run_benchmark_basic "$benchmark" "$data" "$model" "$trainer" "$server"
```
bash_experiment_scripts/helper_functions.sh (new file, +7)

```
#!/bin/sh

run_benchmark_basic() {
    # requires slurm
    gpurun='srun -p q2 --cpus-per-task=16 -t 5:00:00 --gpus-per-node=4'
    $gpurun python launcher.py --benchmark="$1" --data="$2" --model="$3" --trainer="$4" --server="$5"
}
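```

Without slurm, run_benchmark_basic reduces to invoking the launcher directly from the benchmark directory, e.g. `python launcher.py --benchmark=3 --data=DummyData --model=DummyModel --trainer=DdpNcclTrainer --server=None` (assuming the required GPUs are available locally).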
benchmark_class_helper.py (new file, +33)

```
from data.DummyData import DummyData
from models.DummyModel import DummyModel
from trainers.DdpNcclTrainer import DdpNcclTrainer

trainer_map = {
    "DdpNcclTrainer": DdpNcclTrainer
}

ps_map = {}

model_map = {
    "DummyModel": DummyModel
}

data_map = {
    "DummyData": DummyData
}


def get_benchmark_trainer_map():
    return trainer_map


def get_benchmark_ps_map():
    return ps_map


def get_benchmark_model_map():
    return model_map


def get_benchmark_data_map():
    return data_map
```
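Hypothetical example of step 3 from the README: after adding trainers/MyTrainer.py (a placeholder module, not shipped in this commit), register it here so --trainer=MyTrainer resolves:

```
from trainers.MyTrainer import MyTrainer  # placeholder, not in this commit

trainer_map = {
    "DdpNcclTrainer": DdpNcclTrainer,
    "MyTrainer": MyTrainer,  # plus a matching "MyTrainer" entry in trainer_configurations.json
}
```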
configurations/benchmark_configurations.json (new file, +8)

```
{
    "3": {
        "trainer_count": 2,
        "ps_count": 0,
        "rpc_async_timeout": 15,
        "batch_size": 5
    }
}
```
configurations/data_configurations.json (new file, +20)

```
{
    "DummyData": {
        "data_class": "DummyData",
        "configurations": {
            "max_val": 100,
            "input_samples": 100,
            "input_dim": 100,
            "sparsity_percentage": 20
        }
    },
    "DummyData2": {
        "data_class": "DummyData",
        "configurations": {
            "max_val": 100,
            "input_samples": 100,
            "input_dim": 100,
            "sparsity_percentage": 80
        }
    }
}
```
configurations/model_configurations.json (new file, +22)

```
{
    "DummyModel": {
        "model_class": "DummyModel",
        "configurations": {
            "num_embeddings": 100,
            "embedding_dim": 100,
            "dense_input_size": 100,
            "dense_output_size": 100,
            "sparse": false
        }
    },
    "DummyModelSparse": {
        "model_class": "DummyModel",
        "configurations": {
            "num_embeddings": 100,
            "embedding_dim": 100,
            "dense_input_size": 100,
            "dense_output_size": 100,
            "sparse": true
        }
    }
}
```
configurations/server_configurations.json (new file, +1)

```
{}
```
configurations/trainer_configurations.json (new file, +8)

```
{
    "DdpNcclTrainer": {
        "trainer_class": "DdpNcclTrainer",
        "configurations": {
            "epochs": 10
        }
    }
}
```
data/DummyData.py (new file, +46)

```
import random

import numpy as np
import torch
from torch.utils.data import Dataset


class DummyData(Dataset):

    def __init__(
        self,
        max_val: int,
        input_samples: int,
        input_dim: int,
        sparsity_percentage: int
    ):
        self.max_val = max_val
        self.input_samples = input_samples
        self.input_dim = input_dim
        self.sparsity_percentage = sparsity_percentage

        def generate_input():
            percentage_of_elements = (100 - self.sparsity_percentage) / float(100)
            index_count = int(self.max_val * percentage_of_elements)
            elements = list(range(self.max_val))
            random.shuffle(elements)
            elements = elements[:index_count]
            data = [
                [
                    elements[random.randint(0, index_count - 1)]
                    for _ in range(self.input_dim)
                ]
                for _ in range(self.input_samples)
            ]
            return torch.from_numpy(np.array(data))

        self.input = generate_input()
        self.target = torch.randint(0, max_val, [input_samples])
        self.start = 0
        self.end = max_val

    def __len__(self):
        return len(self.input)

    def __getitem__(self, index):
        return self.input[index], self.target[index]
```
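A quick sanity check of the sparsity arithmetic, using the "DummyData" configuration values from this commit (run from the benchmark directory):

```
from data.DummyData import DummyData

# sparsity_percentage=20 -> (100 - 20)% of max_val = 80 distinct values may appear
data = DummyData(max_val=100, input_samples=100, input_dim=100, sparsity_percentage=20)
assert len(data) == 100
sample, label = data[0]
assert sample.shape == (100,)
assert 0 <= int(label) < 100
```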
launcher.py (new file, +343)

```
import argparse
import copy
import json
import os
from pathlib import Path

import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from torch.distributed.rpc import TensorPipeRpcBackendOptions
from torch.utils.data import DataLoader

from benchmark_class_helper import (get_benchmark_data_map,
                                    get_benchmark_model_map,
                                    get_benchmark_ps_map,
                                    get_benchmark_trainer_map)
from BenchmarkConfigurations import BenchmarkConfigurations
from metrics.ProcessedMetricsPrinter import ProcessedMetricsPrinter

USE_CUDA_RPC = "use_cuda_rpc"


def get_name(rank, configs):
    t_count = configs.trainer_count
    ps_count = configs.ps_count
    if rank < t_count:
        return f"trainer{rank}"
    elif rank < (t_count + ps_count):
        return f"ps{rank}"
    else:
        return "master"


def get_parameter_server_rank(rank, config):
    # rank mod parameter server count to get parameter server number
    # add trainer_count to get parameter server rank
    rank_mod_ps_count = rank % config.ps_count
    return rank_mod_ps_count + config.trainer_count


def get_ps_rref(parameter_server_rank, config):
    ps_config = config.ps_config
    ps = get_benchmark_ps_map()[str(ps_config["ps_class"])]
    name = get_name(
        parameter_server_rank,
        config
    )
    ps_args = ps_config["configurations"].values()
    # each server handles trainer_count // ps_count trainers; the first
    # (trainer_count % ps_count) servers take one extra
    ps_trainer_count = config.trainer_count // config.ps_count
    rem = config.trainer_count % config.ps_count
    if parameter_server_rank - config.trainer_count < rem:
        ps_trainer_count += 1
    return rpc.remote(
        name,
        ps,
        args=(
            parameter_server_rank,
            ps_trainer_count,
            *ps_args,
        ),
    )


def run_trainer(
    config, model, data, rank, ps_rref
):
    trainer_config = config.trainer_config
    trainer_class = get_benchmark_trainer_map()[str(trainer_config["trainer_class"])]
    trainer_args = trainer_config["configurations"].values()
    trainer = trainer_class(
        rank,
        config.trainer_count,
        ps_rref,
        *trainer_args
    )
    trainer.train(model, data)
    metrics = trainer.get_metrics()
    return [rank, metrics]


def call_trainers(config, model, train_data, parameter_server_rrefs):
    futs = []
    for trainer_rank in range(0, config.trainer_count):
        trainer_name = get_name(
            trainer_rank,
            config
        )
        ps_rref = None
        if parameter_server_rrefs:
            ps_rank = get_parameter_server_rank(trainer_rank, config)
            ps_rref = parameter_server_rrefs[ps_rank]
        fut = rpc.rpc_async(
            trainer_name,
            run_trainer,
            args=(
                config,
                copy.deepcopy(model),
                train_data[trainer_rank],
                trainer_rank,
                ps_rref,
            ),
            timeout=config.rpc_async_timeout
        )
        futs.append(fut)
    return futs


def benchmark_warmup(
    config, model, data, parameter_server_rrefs
):
    if config.ps_count > 0:
        ps_config = config.ps_config
        ps = get_benchmark_ps_map()[str(ps_config["ps_class"])]
    futs = call_trainers(config, model, data, parameter_server_rrefs)
    for fut in futs:
        fut.wait()
    for ps_rref in parameter_server_rrefs.values():
        rpc.rpc_sync(
            ps_rref.owner(),
            ps.reset_state,
            args=(ps_rref,)
        )
    print("benchmark warmup done\n")


def split_list(arr, n):
    return [arr[i::n] for i in range(n)]


def run_master(rank, model, data, config, rpc_backend_options):
    world_size = config.trainer_count + config.ps_count + 1
    rpc.init_rpc(
        get_name(
            rank,
            config
        ),
        rank=rank,
        world_size=world_size,
        rpc_backend_options=rpc_backend_options
    )
    parameter_server_rrefs = {}
    for i in range(
        config.trainer_count, world_size - 1
    ):
        parameter_server_rrefs[i] = get_ps_rref(i, config)

    train_data = split_list(
        list(DataLoader(data, batch_size=config.batch_size)),
        config.trainer_count
    )

    # warmup run the benchmark
    benchmark_warmup(
        config, model, train_data, parameter_server_rrefs
    )
    # run the benchmark
    trainer_futs = call_trainers(
        config, model, train_data, parameter_server_rrefs
    )
    # collect metrics and print
    metrics_printer = ProcessedMetricsPrinter()
    rank_metrics_list = [fut.wait() for fut in trainer_futs]
    metrics_printer.print_metrics("trainer", rank_metrics_list)


def run_benchmark(rank, model, data, config):

    world_size = config.trainer_count + config.ps_count + 1
    os.environ['MASTER_ADDR'] = config.master_addr
    os.environ['MASTER_PORT'] = config.master_port
    rpc_backend_options = TensorPipeRpcBackendOptions()
    rpc_backend_options.init_method = config.rpc_init_method
    if rank == world_size - 1:
        # master is the last rank: trainer_count + parameter_server_count
        run_master(rank, model, data, config, rpc_backend_options)
    elif rank >= config.trainer_count:
        # parameter_servers = [trainer_count, trainer_count + parameter_server_count)
        rpc.init_rpc(
            get_name(
                rank,
                config
            ),
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options
        )
    else:
        # trainers = [0, trainer_count)
        trainer_config = config.trainer_config
        ps_config = config.ps_config
        # ps_config is None when no parameter server is configured,
        # so check ps_count first
        if (config.ps_count > 0 and
                USE_CUDA_RPC in trainer_config and
                trainer_config[USE_CUDA_RPC] and
                USE_CUDA_RPC in ps_config and
                ps_config[USE_CUDA_RPC]):
            ps_rank = get_parameter_server_rank(rank, config)
            ps_name = get_name(
                ps_rank,
                config
            )
            rpc_backend_options.set_device_map(
                ps_name,
                {rank: ps_rank}
            )
        trainer_name = get_name(
            rank,
            config
        )
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options
        )
    rpc.shutdown()


def get_json_config(file_name, id):
    with open(
        os.path.join(
            Path(__file__).parent, file_name
        ),
        "r"
    ) as f:
        json_config = json.load(f)[id]
    return json_config


def load_configurations(args):
    trainer_config_file = args.trainer_config_path
    ps_config_file = args.server_config_path
    benchmark_config = get_json_config(args.benchmark_config_path, args.benchmark)
    benchmark_config["trainer_config"] = get_json_config(trainer_config_file, args.trainer)
    if args.server != "None":
        benchmark_config["ps_config"] = get_json_config(ps_config_file, args.server)
    else:
        benchmark_config["ps_config"] = None
    return BenchmarkConfigurations(**benchmark_config)


def get_data(data_class, data_config):
    data_class = get_benchmark_data_map()[data_class]
    return data_class(**data_config)


def load_data(args):
    data_config_file = args.data_config_path
    data_config = get_json_config(data_config_file, args.data)
    return get_data(data_config["data_class"], data_config["configurations"])


def get_model(model_class, model_config):
    model_class = get_benchmark_model_map()[model_class]
    return model_class(**model_config)


def load_model(args):
    model_config_file = args.model_config_path
    model_config = get_json_config(model_config_file, args.model)
    return get_model(model_config["model_class"], model_config["configurations"])


def main():
    parser = argparse.ArgumentParser(description="RPC PS Benchmark")

    parser.add_argument(
        "--benchmark_config_path",
        type=str,
        default="configurations/benchmark_configurations.json",
        help="path to benchmark configuration file"
    )
    parser.add_argument(
        "--data_config_path",
        type=str,
        default="configurations/data_configurations.json",
        help="path to data configuration file"
    )
    parser.add_argument(
        "--model_config_path",
        type=str,
        default="configurations/model_configurations.json",
        help="path to model configuration file"
    )
    parser.add_argument(
        "--server_config_path",
        type=str,
        default="configurations/server_configurations.json",
        help="path to server configuration file"
    )
    parser.add_argument(
        "--trainer_config_path",
        type=str,
        default="configurations/trainer_configurations.json",
        help="path to trainer configuration file"
    )
    parser.add_argument(
        "--benchmark",
        type=str,
        help="id for benchmark configuration"
    )
    parser.add_argument(
        "--data",
        type=str,
        help="id for data configuration"
    )
    parser.add_argument(
        "--model",
        type=str,
        help="id for model configuration"
    )
    parser.add_argument(
        "--server",
        type=str,
        help="id for parameter server configuration"
    )
    parser.add_argument(
        "--trainer",
        type=str,
        help="id for trainer configuration"
    )
    args = parser.parse_args()
    print(f"{args}\n")

    config = load_configurations(args)
    data = load_data(args)
    model = load_model(args)

    world_size = config.trainer_count + config.ps_count + 1

    mp.spawn(
        run_benchmark,
        args=(
            model,
            data,
            config,
        ),
        nprocs=world_size,
        join=True
    )


if __name__ == "__main__":
    main()
```
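A worked example of the rank layout in run_benchmark, using this commit's "3" benchmark configuration (trainer_count=2) plus a hypothetical ps_count=1 (this commit ships ps_count=0 and an empty ps_map); run from the benchmark directory:

```
from BenchmarkConfigurations import BenchmarkConfigurations
from launcher import get_name

config = BenchmarkConfigurations(trainer_count=2, ps_count=1)
world_size = config.trainer_count + config.ps_count + 1  # 4: 2 trainers, 1 ps, 1 master
assert [get_name(r, config) for r in range(world_size)] == \
    ["trainer0", "trainer1", "ps2", "master"]
```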
metrics/CPUMetric.py (new file, +23)

```
import time

from .MetricBase import MetricBase


class CPUMetric(MetricBase):
    def __init__(self, name: str):
        self.name = name
        self.start = None
        self.end = None

    def record_start(self):
        self.start = time.time()

    def record_end(self):
        self.end = time.time()

    def elapsed_time(self):
        if self.start is None:
            raise RuntimeError("start is None")
        if self.end is None:
            raise RuntimeError("end is None")
        return self.end - self.start
```
metrics/CUDAMetric.py (new file, +32)

```
import torch

from .MetricBase import MetricBase


class CUDAMetric(MetricBase):
    def __init__(self, rank: int, name: str):
        self.rank = rank
        self.name = name
        self.start = None
        self.end = None

    def record_start(self):
        self.start = torch.cuda.Event(enable_timing=True)
        with torch.cuda.device(self.rank):
            self.start.record()

    def record_end(self):
        self.end = torch.cuda.Event(enable_timing=True)
        with torch.cuda.device(self.rank):
            self.end.record()

    def elapsed_time(self):
        if not self.start.query():
            raise RuntimeError("start event did not complete")
        if not self.end.query():
            raise RuntimeError("end event did not complete")
        return self.start.elapsed_time(self.end)

    def synchronize(self):
        self.start.synchronize()
        self.end.synchronize()
```
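A minimal usage sketch (assumes a CUDA-capable GPU on device 0); the events are recorded asynchronously, so synchronize() must run before the timing is read:

```
import torch

from metrics.CUDAMetric import CUDAMetric

metric = CUDAMetric(rank=0, name="matmul")
metric.record_start()
a = torch.randn(1024, 1024, device="cuda:0")
b = torch.randn(1024, 1024, device="cuda:0")
c = a @ b  # work timed between the two events
metric.record_end()
metric.synchronize()  # wait for both events before reading
print(f"{metric.elapsed_time():.3f} ms")
```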
metrics/MetricBase.py (new file, +26)

```
from abc import ABC, abstractmethod


class MetricBase(ABC):
    def __init__(self, name):
        self.name = name
        self.start = None
        self.end = None

    @abstractmethod
    def record_start(self):
        return

    @abstractmethod
    def record_end(self):
        return

    @abstractmethod
    def elapsed_time(self):
        return

    def get_name(self):
        return self.name

    def get_end(self):
        return self.end
```
metrics/MetricsLogger.py (new file, +78)

```
from .CPUMetric import CPUMetric
from .CUDAMetric import CUDAMetric


class MetricsLogger:

    def __init__(self, rank=None):
        self.rank = rank
        self.metrics = {}

    def record_start(self, type, key, name, cuda):
        if type in self.metrics and key in self.metrics[type]:
            raise RuntimeError(f"metric_type={type} with key={key} already exists")
        if cuda:
            if self.rank is None:
                raise RuntimeError("rank is required for cuda")
            metric = CUDAMetric(self.rank, name)
        else:
            metric = CPUMetric(name)
        if type not in self.metrics:
            self.metrics[type] = {}
        self.metrics[type][key] = metric
        metric.record_start()

    def record_end(self, type, key):
        if type not in self.metrics or key not in self.metrics[type]:
            raise RuntimeError(f"metric_type={type} with key={key} not found")
        if self.metrics[type][key].get_end() is not None:
            raise RuntimeError(f"end for metric_type={type} with key={key} already exists")
        self.metrics[type][key].record_end()

    def clear_metrics(self):
        self.metrics.clear()

    def get_metrics(self):
        return self.metrics

    def get_processed_metrics(self):
        r"""
        A method that processes the metrics recorded during the benchmark.

        Returns::
            A dictionary whose keys are the processed metric names
            and whose values are lists of elapsed times.

        Examples::

            >>> instance = MetricsLogger(rank)
            >>> instance.record_start("forward_metric_type", "1", "forward_pass", cuda=True)
            >>> instance.record_end("forward_metric_type", "1")
            >>> instance.record_start("forward_metric_type", "2", "forward_pass", cuda=True)
            >>> instance.record_end("forward_metric_type", "2")
            >>> print(instance.metrics)
            {
                "forward_metric_type": {
                    "1": metric1,
                    "2": metric2
                }
            }

            >>> print(instance.get_processed_metrics())
            {
                "forward_metric_type,forward_pass" : [.0429, .0888]
            }
        """
        processed_metrics = {}
        for metric_type in self.metrics.keys():
            for metric_key in self.metrics[metric_type].keys():
                metric = self.metrics[metric_type][metric_key]
                if isinstance(metric, CUDAMetric):
                    metric.synchronize()
                metric_name = metric.get_name()
                elapsed_time = metric.elapsed_time()
                processed_metric_name = f"{metric_type},{metric_name}"
                if processed_metric_name not in processed_metrics:
                    processed_metrics[processed_metric_name] = []
                processed_metrics[processed_metric_name].append(elapsed_time)
        return processed_metrics
```
metrics/ProcessedMetricsPrinter.py (new file, +82)

```
import statistics

import pandas as pd
from tabulate import tabulate


class ProcessedMetricsPrinter:

    def print_data_frame(self, name, processed_metrics):
        print(f"metrics for {name}")
        data_frame = self.get_data_frame(processed_metrics)
        print(tabulate(data_frame, showindex=False, headers=data_frame.columns, tablefmt="grid"))

    def combine_processed_metrics(self, processed_metrics_list):
        r"""
        A method that merges the value arrays of the keys in the dictionary
        of processed metrics.

        Args:
            processed_metrics_list (list): a list containing dictionaries with
                recorded metrics as keys, and the values are lists of elapsed times.

        Returns::
            A merged dictionary that is created from the list of dictionaries passed
            into the method.

        Examples::
            >>> instance = ProcessedMetricsPrinter()
            >>> dict_1 = trainer1.get_processed_metrics()
            >>> dict_2 = trainer2.get_processed_metrics()
            >>> print(dict_1)
            {
                "forward_metric_type,forward_pass" : [.0429, .0888]
            }
            >>> print(dict_2)
            {
                "forward_metric_type,forward_pass" : [.0111, .0222]
            }
            >>> processed_metrics_list = [dict_1, dict_2]
            >>> result = instance.combine_processed_metrics(processed_metrics_list)
            >>> print(result)
            {
                "forward_metric_type,forward_pass" : [.0429, .0888, .0111, .0222]
            }
        """
        processed_metric_totals = {}
        for processed_metrics in processed_metrics_list:
            for metric_name, values in processed_metrics.items():
                if metric_name not in processed_metric_totals:
                    processed_metric_totals[metric_name] = []
                processed_metric_totals[metric_name] += values
        return processed_metric_totals

    def get_data_frame(self, processed_metrics):
        df = pd.DataFrame(
            columns=['name', 'min', 'max', 'mean', 'variance', 'stdev']
        )
        for metric_name in sorted(processed_metrics.keys()):
            values = processed_metrics[metric_name]
            row = {
                "name": metric_name,
                "min": min(values),
                "max": max(values),
                "mean": statistics.mean(values),
                "variance": statistics.variance(values),
                "stdev": statistics.stdev(values)
            }
            df = df.append(row, ignore_index=True)
        return df

    def print_metrics(self, name, rank_metrics_list):
        if rank_metrics_list:
            metrics_list = []
            for rank, metric in rank_metrics_list:
                self.print_data_frame(f"{name}={rank}", metric)
                metrics_list.append(metric)
            combined_metrics = self.combine_processed_metrics(metrics_list)
            self.print_data_frame(f"all {name}", combined_metrics)

    def save_to_file(self, data_frame, file_name):
        file_name = f"data_frames/{file_name}.csv"
        data_frame.to_csv(file_name, encoding='utf-8', index=False)
```
models/DummyModel.py (new file, +22)

```
import torch.nn as nn
import torch.nn.functional as F


class DummyModel(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        dense_input_size: int,
        dense_output_size: int,
        sparse: bool
    ):
        super().__init__()
        self.embedding = nn.EmbeddingBag(
            num_embeddings, embedding_dim, sparse=sparse
        )
        self.dense = nn.Sequential(*[nn.Linear(dense_input_size, dense_output_size) for _ in range(10)])

    def forward(self, x):
        x = self.embedding(x)
        return F.softmax(self.dense(x), dim=1)
```
trainers/DdpNcclTrainer.py (new file, +107)

```
import torch
import torch.distributed as c10d
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from .DdpTrainerBase import DdpTrainerBase


class DdpNcclTrainer(DdpTrainerBase):

    class HookState:

        def __init__(self, cref, process_group):
            self.cref = cref
            self.process_group = process_group
            self.process_group_size = process_group.size()
            self.param_location = 0
            self.batch_number = -1

        def get_key(self):
            return f"{self.batch_number},{self.param_location}"

        def next_batch_state(self):
            self.param_location = 0
            self.batch_number += 1

    def __init__(self, rank, trainer_count, ps_rref, epochs):
        super().__init__(rank)
        self.rank = rank
        self.trainer_count = trainer_count
        self.epochs = epochs

    @staticmethod
    def hook(state, bucket):
        cref = state.cref
        tensors_count = len(cref.bucket_to_parameters(bucket))
        tensors = [bucket.get_tensor() / state.process_group_size]
        key = state.get_key()
        cref.record_hook_fut_start(key, cref.NCCL_ALLREDUCE)
        fut = state.process_group.allreduce(tensors).get_future()
        state.param_location += tensors_count

        def callback(fut):
            cref.record_hook_fut_end(key)
            return fut.wait()

        return fut.then(callback)

    def train(self, model, data):
        torch.manual_seed(0)
        model = model.cuda(self.rank)
        for i in range(len(data)):
            data[i][0] = data[i][0].cuda(self.rank)
            data[i][1] = data[i][1].cuda(self.rank)
        torch.cuda.synchronize(self.rank)

        process_group_size = self.trainer_count

        store = c10d.FileStore("/tmp/tmpn_k_8so02", process_group_size)

        process_group = c10d.ProcessGroupNCCL(
            store, self.rank, process_group_size
        )

        ddp_model = DDP(
            model, device_ids=[self.rank], process_group=process_group
        )

        hook_state = self.HookState(self, process_group)

        ddp_model.register_comm_hook(hook_state, DdpNcclTrainer.hook)

        criterion = nn.CrossEntropyLoss().cuda(self.rank)

        optimizer = torch.optim.SGD(ddp_model.parameters(), 1e-4)

        def epoch_key(epoch, index):
            return f"{epoch},{index}"

        for epoch in range(self.epochs):
            for index, batch in enumerate(data):
                hook_state.next_batch_state()
                input, target = batch[0], batch[1]

                self.record_batch_start(epoch_key(epoch, index))

                optimizer.zero_grad()

                self.record_forward_start(epoch_key(epoch, index))

                out = ddp_model(input)

                self.record_forward_end(epoch_key(epoch, index))

                loss = criterion(out, target)

                self.record_backward_start(epoch_key(epoch, index))

                loss.backward()

                self.record_backward_end(epoch_key(epoch, index))

                optimizer.step()

                self.record_batch_end(epoch_key(epoch, index))

        torch.cuda.synchronize(self.rank)
```
@@ -0,0 +1,32 @@
1+
from abc import abstractmethod
2+
3+
from .TrainerBase import TrainerBase
4+
5+
6+
class DdpTrainerBase(TrainerBase):
7+
8+
HOOK_FUTURE_METRIC = "hook_future_metric"
9+
NCCL_ALLREDUCE = "nccl_allreduce"
10+
GLOO_ALLREDUCE = "gloo_allreduce"
11+
12+
def __init__(self, rank):
13+
super().__init__(rank)
14+
15+
@staticmethod
16+
@abstractmethod
17+
def hook(state, bucket):
18+
return
19+
20+
def record_hook_fut_start(self, key, name, cuda=True):
21+
self.record_start(self.HOOK_FUTURE_METRIC, key, name, cuda)
22+
23+
def record_hook_fut_end(self, key):
24+
self.record_end(self.HOOK_FUTURE_METRIC, key)
25+
26+
def bucket_to_parameters(self, bucket):
27+
parameter_tensors = bucket.get_per_parameter_tensors()
28+
parameter_tensors_count = len(parameter_tensors)
29+
if parameter_tensors_count > 0:
30+
return parameter_tensors
31+
else:
32+
return [bucket.get_tensor()]
trainers/TrainerBase.py (new file, +97)

```
import functools
import time
from abc import ABC, abstractmethod

from metrics.MetricsLogger import MetricsLogger


class TrainerBase(ABC):

    BATCH_LEVEL_METRIC = "batch_level_metric"
    BATCH_ALL = "batch_all"
    FORWARD_METRIC = "foward_metric"  # "foward" (sic) matches the metric names in the recorded output
    FORWARD_PASS = "forward_pass"
    BACKWARD_METRIC = "backward_metric"
    BACKWARD = "backward"

    def __init__(self, rank):
        self.__metrics_logger = MetricsLogger(rank)

    @abstractmethod
    def train(self):
        return

    def record_start(self, type, key, name, cuda=True):
        self.__metrics_logger.record_start(
            type,
            key,
            name,
            cuda
        )

    def record_end(self, type, key):
        self.__metrics_logger.record_end(
            type,
            key
        )

    def record_batch_start(self, key, cuda=True):
        self.__metrics_logger.record_start(
            self.BATCH_LEVEL_METRIC,
            key,
            self.BATCH_ALL,
            cuda
        )

    def record_batch_end(self, key):
        self.__metrics_logger.record_end(
            self.BATCH_LEVEL_METRIC,
            key
        )

    def record_forward_start(self, key, cuda=True):
        self.__metrics_logger.record_start(
            self.FORWARD_METRIC,
            key,
            self.FORWARD_PASS,
            cuda
        )

    def record_forward_end(self, key):
        self.__metrics_logger.record_end(
            self.FORWARD_METRIC,
            key
        )

    def record_backward_start(self, key, cuda=True):
        self.__metrics_logger.record_start(
            self.BACKWARD_METRIC,
            key,
            self.BACKWARD,
            cuda
        )

    def record_backward_end(self, key):
        self.__metrics_logger.record_end(
            self.BACKWARD_METRIC,
            key
        )

    @staticmethod
    def methodmetric(name, type="method_metric", cuda=True):
        def decorator(function):
            @functools.wraps(function)
            def wrapper(self, *args):
                key = time.time()
                self.__metrics_logger.record_start(type, key, name, cuda)
                result = function(self, *args)
                self.__metrics_logger.record_end(type, key)
                return result
            return wrapper
        return decorator

    def get_metrics(self):
        return self.__metrics_logger.get_processed_metrics()

    def clear_metrics(self):
        return self.__metrics_logger.clear_metrics()
```
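A small usage sketch of the methodmetric decorator above; the subclass and method names are hypothetical, and each call is timed under a fresh time.time() key (cuda=False keeps it CPU-only):

```
from trainers.TrainerBase import TrainerBase


class ExampleTrainer(TrainerBase):
    def train(self, model, data):
        pass

    @TrainerBase.methodmetric("sync_step", cuda=False)
    def sync_step(self):
        pass


t = ExampleTrainer(rank=0)
t.sync_step()
print(t.get_metrics())  # {'method_metric,sync_step': [<elapsed seconds>]}
```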
