@@ -99,7 +99,7 @@ class BaseReader(ClassProvider):
99
99
__provider_type__ = 'reader'
100
100
101
101
def __init__ (self , data_source , config = None , ** kwargs ):
102
- self .config = config
102
+ self .config = config or {}
103
103
self .data_source = data_source
104
104
self .read_dispatcher = singledispatch (self .read )
105
105
self .read_dispatcher .register (list , self ._read_list )
@@ -127,6 +127,7 @@ def __call__(self, context=None, identifier=None, **kwargs):
127
127
128
128
def configure (self ):
129
129
self .data_source = get_path (self .data_source , is_directory = True )
130
+ self .multi_infer = self .config .get ('multi_infer' , False )
130
131
131
132
def validate_config (self ):
132
133
pass
@@ -146,7 +147,10 @@ def _read_frames_multi_input(self, data_id):
146
147
return self .read_dispatcher (data_id .frames )
147
148
148
149
def read_item (self , data_id ):
149
- return DataRepresentation (self .read_dispatcher (data_id ), identifier = data_id )
150
+ data_rep = DataRepresentation (self .read_dispatcher (data_id ), identifier = data_id )
151
+ if self .multi_infer :
152
+ data_rep .metadata ['multi_infer' ] = True
153
+ return data_rep
150
154
151
155
@property
152
156
def name (self ):
@@ -181,6 +185,7 @@ def configure(self):
181
185
reading_scheme [pattern ] = reader
182
186
183
187
self .reading_scheme = reading_scheme
188
+ self .multi_infer = self .config .get ('multi_infer' , False )
184
189
185
190
def read (self , data_id ):
186
191
for pattern , reader in self .reading_scheme .items ():
@@ -207,14 +212,15 @@ class OpenCVImageReader(BaseReader):
207
212
208
213
def validate_config (self ):
209
214
if self .config :
210
- config_validator = OpenCVImageReaderConfig ('opencv_imread_config' )
215
+ config_validator = OpenCVImageReaderConfig (
216
+ 'opencv_imread_config' , on_extra_argument = ConfigValidator .IGNORE_ON_EXTRA_ARGUMENT
217
+ )
211
218
config_validator .validate (self .config )
212
219
213
220
def configure (self ):
214
221
super ().configure ()
215
222
self .flag = OPENCV_IMREAD_FLAGS [self .config .get ('reading_flag' , 'color' ) if self .config else 'color' ]
216
223
217
-
218
224
def read (self , data_id ):
219
225
return cv2 .imread (str (get_path (self .data_source / data_id )), self .flag )
220
226
@@ -281,6 +287,7 @@ def _read_sequence(self, data_id):
281
287
def configure (self ):
282
288
self .data_source = get_path (self .data_source )
283
289
self .videocap = cv2 .VideoCapture (str (self .data_source ))
290
+ self .multi_infer = self .config .get ('multi_infer' , False )
284
291
285
292
def reset (self ):
286
293
self .current = - 1
@@ -301,6 +308,7 @@ def validate_config(self):
301
308
302
309
def configure (self ):
303
310
self .key = self .config .get ('key' )
311
+ self .multi_infer = self .config .get ('multi_infer' , False )
304
312
305
313
def read (self , data_id ):
306
314
data = read_json (str (self .data_source / data_id ))
@@ -343,6 +351,7 @@ def configure(self):
343
351
if nib is None :
344
352
raise ImportError ('nifty backend for image reading requires nibabel. Please install it before usage.' )
345
353
self .channels_first = self .config .get ('channels_first' , False ) if self .config else False
354
+ self .multi_infer = self .config .get ('multi_infer' , False )
346
355
347
356
def read (self , data_id ):
348
357
nib_image = nib .load (str (get_path (self .data_source / data_id )))
@@ -353,10 +362,12 @@ def read(self, data_id):
353
362
354
363
return image
355
364
365
+
356
366
class NumpyReaderConfig (ConfigValidator ):
357
367
type = StringField (optional = True )
358
368
keys = StringField (optional = True , default = "" )
359
369
370
+
360
371
class NumPyReader (BaseReader ):
361
372
__provider__ = 'numpy_reader'
362
373
@@ -366,8 +377,10 @@ def validate_config(self):
366
377
config_validator .validate (self .config )
367
378
368
379
def configure (self ):
380
+ self .multi_infer = self .config .get ('multi_infer' , False )
369
381
self .keys = self .config .get ('keys' , "" ) if self .config else ""
370
382
self .keys = [t .strip () for t in self .keys .split (',' )] if len (self .keys ) > 0 else []
383
+ self .multi_infer = self .config .get ('multi_infer' , False )
371
384
372
385
def read (self , data_id ):
373
386
data = np .load (str (self .data_source / data_id ))
@@ -384,6 +397,7 @@ def read(self, data_id):
384
397
key = next (iter (data .keys ()))
385
398
return data [key ]
386
399
400
+
387
401
class TensorflowImageReader (BaseReader ):
388
402
__provider__ = 'tf_imread'
389
403
@@ -422,6 +436,7 @@ def configure(self):
422
436
self .single = len (self .feature_list ) == 1
423
437
self .counter = 0
424
438
self .subset = range (len (self .data_source ))
439
+ self .multi_infer = self .config .get ('multi_infer' , False )
425
440
426
441
def read (self , data_id ):
427
442
relevant_annotation = self .data_source [self .subset [self .counter ]]
0 commit comments