## @package record_queue
# Module caffe2.python.record_queue
"""
Implementation of a queue wrapper.
"""
from caffe2.python import core
from caffe2.python.dataio import Reader, Writer
from caffe2.python.schema import (
    Struct, Field, from_column_list)


class _QueueReader(Reader):
    def __init__(self, blobs_queue, schema, name=None):
        """Don't call this directly. Instead, use RecordQueue(...).reader."""
        super().__init__(schema)
        self.blobs_queue = blobs_queue
        self.name = name

    def read(self, read_net):
        with core.NameScope(read_net.NextName(self.name)):
            # SafeDequeueBlobs outputs the dequeued field blobs plus a status
            # blob; the status becomes True once the queue has been closed.
            status = read_net.NextName()
            fields = read_net.SafeDequeueBlobs(
                self.blobs_queue, self._schema.field_names() + [status])
            return (fields[-1], fields[:-1])


class _QueueWriter(Writer):
    def __init__(self, blobs_queue, schema):
        self.blobs_queue = blobs_queue
        self.schema = schema

    def write(self, writer_net, fields):
        if isinstance(fields, Field):
            fields = fields.field_blobs()
        writer_net.CheckDatasetConsistency(
            fields, [], fields=self.schema.field_names())
        status = writer_net.NextName()
        writer_net.SafeEnqueueBlobs(
            [self.blobs_queue] + fields, fields + [status])
        return status


class RecordQueue:
    """Feeds data from a reader into a queue, optionally running a processing
    step on the way in, and provides a reader interface for fetching data back
    out of the queue.
    """
    def __init__(self, fields, name=None, capacity=1,
                 enforce_unique_name=False, num_threads=1):
        assert isinstance(fields, list) or isinstance(fields, Struct), (
            'fields must be either a Struct or a list of raw field names.')
        if isinstance(fields, list):
            fields = from_column_list(fields)
        self.schema = fields
        self.name = name or 'queue'
        self.num_threads = num_threads
        num_blobs = len(self.schema.field_names())
        # Create the underlying BlobsQueue once, via a one-off init net.
        init_net = core.Net(self.name + '/init_net')
        self.blobs_queue = init_net.CreateBlobsQueue(
            [], 1,
            capacity=capacity,
            num_blobs=num_blobs,
            enforce_unique_name=enforce_unique_name)
        core.workspace.RunNetOnce(init_net)
        self.writer = _QueueWriter(self.blobs_queue, self.schema)
        reader_name = self.name + '_reader'
        self.reader = _QueueReader(self.blobs_queue, self.schema, reader_name)
        # Closing the queue unblocks pending dequeues and signals readers to stop.
        exit_net = core.Net(self.name + '/exit_net')
        exit_net.CloseBlobsQueue(self.blobs_queue, 0)
        self.exit_step = core.execution_step(
            '{}_close_step'.format(str(exit_net)),
            exit_net)

    def build(self, reader, process=None):
        """
        Build the producer_step to feed data from reader into the queue, and
        return the reader interface.
        Inputs:
            reader: reads the data that will be stored in the queue.
            process: optional function to preprocess data before enqueueing.
        Outputs:
            reader: reader to fetch the data from the queue.
            producer_step: the step that inserts the data into the queue.
                           Should be run together with the consume step.
            exit_step: the step that closes the queue.
            schema: the schema for the reader.
        A usage sketch is given at the end of this file.
        """
        producer_steps = []
        for i in range(self.num_threads):
            # Read one record from the source reader.
            name = 'reader_' + str(i)
            net_reader = core.Net(name)
            should_stop, fields = reader.read_record(net_reader)
            step_read = core.execution_step(name, net_reader)

            # Optionally preprocess the record, then enqueue it.
            name = 'queue_writer' + str(i)
            net_prod = core.Net(name)
            field_blobs = fields.field_blobs()
            if process:
                field_blobs = process(net_prod, fields).field_blobs()

            self.writer.write(net_prod, field_blobs)
            step_prod = core.execution_step(name, net_prod)
            step = core.execution_step(
                'producer_' + str(i),
                [step_read, step_prod],
                should_stop_blob=should_stop)
            producer_steps.append(step)
        # Run all producer threads concurrently.
        producer_step = core.execution_step(
            'producers',
            producer_steps,
            concurrent_substeps=True)
        return self.reader, producer_step, self.exit_step, self.schema
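

# Usage sketch (not part of the original module): a minimal, hedged example of
# how the pieces returned by RecordQueue.build() are typically wired together.
# `source_reader` is assumed to be any dataio.Reader-compatible object (for
# example a Dataset reader); it is not defined here.
#
#     from caffe2.python import core, workspace
#
#     record_queue = RecordQueue(fields=['label', 'feature'], capacity=8)
#     queue_reader, producer_step, exit_step, schema = record_queue.build(
#         source_reader)
#
#     # Consumer: dequeue records until the queue is closed.
#     consume_net = core.Net('consume')
#     should_stop, record = queue_reader.read_record(consume_net)
#     # ... add ops to consume_net that use record.field_blobs() ...
#     consume_step = core.execution_step(
#         'consume', consume_net, should_stop_blob=should_stop)
#
#     # Produce until the source is exhausted, then close the queue so the
#     # consumer's dequeue reports should_stop and the plan can finish.
#     produce_then_close = core.execution_step(
#         'produce_then_close', [producer_step, exit_step])
#     plan = core.Plan('record_queue_plan')
#     plan.AddStep(core.execution_step(
#         'produce_and_consume', [produce_then_close, consume_step],
#         concurrent_substeps=True))
#     workspace.RunPlan(plan)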