#!/usr/bin/env python3
#
# Copyright (c) 2017-2019 NVIDIA CORPORATION. All rights reserved.
# This file is part of webloader (see TBD).
# See the LICENSE file for licensing terms (BSD-style).
#
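# Example invocation (illustrative; flags defined below):
#
#   tarpcat -b -s 10000 -o mixed.tar 'shards-{000..099}.tar'
#
# reads the brace-expanded shards with parallel worker processes, shuffles
# samples through a 10000-sample buffer, and writes one combined tar file.
#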
import argparse
import multiprocessing as mp
import queue as mpq
import random
import sys

import braceexpand
from tarproclib import gopen, proc, reader, writer

parser = argparse.ArgumentParser(
    description="Read, shuffle, and combine multiple shards in parallel."
)
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument("-T", "--filelist", default=None,
                    help="file containing the list of input shards")
parser.add_argument("-b", "--braceexpand", action="store_true",
                    help="brace-expand the single input argument")
parser.add_argument("-c", "--count", type=int, default=1000000000,
                    help="maximum number of samples to copy")
parser.add_argument("-s", "--shuffle", type=int, default=0,
                    help="shuffle buffer size (0 disables shuffling)")
parser.add_argument("-p", "--workers", type=int, default=8,
                    help="number of parallel reader processes")
parser.add_argument("-o", "--output", default="-",
                    help="output tar file ('-' for stdout)")
parser.add_argument("--dummy", action="store_true",
                    help="print sample keys instead of writing samples")
parser.add_argument("input", nargs="*")
args = parser.parse_args()


def dprint(*args, **kw):
    print(*args, file=sys.stderr, **kw)


def read_filelist(filelist):
    with gopen.gopen(filelist, "r") as stream:
        for line in stream:
            yield line.strip()
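

# Worker process: repeatedly pulls a shard name from file_queue and streams
# that shard's samples into sample_queue; it exits once file_queue has been
# empty for the full timeout.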
def reader_proc(file_queue, sample_queue):
    try:
        # Loop over shards; without this loop each worker would read only
        # one shard and any remaining files would be silently skipped.
        while True:
            fname = file_queue.get(timeout=5.0)
            print(f"# opening {fname}", file=sys.stderr)
            for sample in reader.TarIterator(fname, braceexpand=False):
                if "__source__" not in sample:
                    sample["__source__"] = fname
                sample_queue.put(sample)
            print(f"# done {fname}", file=sys.stderr)
    except mpq.Empty:
        pass
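

# Build the list of input shards: from an explicit file list (-T), from a
# single brace-expanded pattern (-b), or directly from the positional args.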
if args.filelist is not None:
    filelist = list(read_filelist(args.filelist))
elif args.braceexpand:
    assert len(args.input) == 1, args.input
    filelist = list(braceexpand.braceexpand(args.input[0]))
else:
    filelist = args.input
dprint(f"# got {len(filelist)} files")

if args.shuffle > 0:
    random.shuffle(filelist)

file_queue = mp.Queue(len(filelist) + 10)
sample_queue = mp.Queue(10000)
for fname in filelist:
    file_queue.put(fname)
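

# Merge generator: fans shard reading out to worker processes and yields
# their samples as they arrive; prunes dead workers on queue timeouts and
# kills any stragglers when the generator is closed.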
def parallel_source():
    jobs = []
    for i in range(args.workers):
        job_args = (file_queue, sample_queue)
        process = mp.Process(target=reader_proc, args=job_args)
        jobs.append(process)
    for job in jobs:
        job.start()
    dprint(f"# started {len(jobs)} jobs")
    try:
        while len(jobs) > 0:
            try:
                sample = sample_queue.get(timeout=5.0)
                yield sample
            except mpq.Empty:
                dprint("timeout")
                # Reap finished workers; the list must be reset once per
                # sweep, not once per job, or all but the last job is lost.
                alive = []
                for job in jobs:
                    if not job.is_alive():
                        job.join()
                    else:
                        alive.append(job)
                jobs = alive
    finally:
        for job in jobs:
            job.kill()
            job.join()
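

# Main pipeline: optionally shuffle the merged stream, then either print
# sample keys (--dummy) or write the samples to the output tar.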
source = parallel_source()
if args.shuffle > 0:
    source = proc.ishuffle(source, args.shuffle)

sink = writer.TarWriter(args.output, keep_meta=True)
total = 0
for sample in source:
    total += 1
    if args.dummy:
        dprint(sample.get("__key__", total), sample.get("__source__", None))
    else:
        sink.write(sample)
    # >= rather than > so that at most args.count samples are emitted
    if total >= args.count:
        break
sink.close()