8
8
from collections import namedtuple , OrderedDict , defaultdict
9
9
from past .builtins import basestring
10
10
from itertools import chain
11
- from six import binary_type , string_types , text_type
12
11
13
12
from caffe2 .proto import caffe2_pb2
14
13
from caffe2 .python import scope , utils , workspace
@@ -215,9 +214,9 @@ def __init__(self, name, net=None):
215
214
Note that this does not prepends the namescope. If needed, use
216
215
ScopedBlobReference() to prepend the existing namespace.
217
216
"""
218
- if isinstance (name , string_types ):
217
+ if isinstance (name , str ):
219
218
self ._name = name
220
- elif isinstance (name , binary_type ):
219
+ elif isinstance (name , bytes ):
221
220
self ._name = name .decode ('utf-8' )
222
221
else :
223
222
self ._name = str (name )
@@ -230,9 +229,9 @@ def __hash__(self):
230
229
return hash (self ._name )
231
230
232
231
def __eq__ (self , other ):
233
- if isinstance (other , string_types ):
232
+ if isinstance (other , str ):
234
233
return self ._name == other
235
- elif isinstance (other , binary_type ):
234
+ elif isinstance (other , bytes ):
236
235
return self ._name == other .decode ('utf-8' )
237
236
elif isinstance (other , BlobReference ):
238
237
return self ._name == other ._name
@@ -249,12 +248,12 @@ def __repr__(self):
249
248
return 'BlobReference("{}")' .format (self ._name )
250
249
251
250
def __add__ (self , other ):
252
- if not isinstance (other , string_types ):
251
+ if not isinstance (other , str ):
253
252
raise RuntimeError ('Cannot add BlobReference to a non-string.' )
254
253
return BlobReference (self ._name + other , self ._from_net )
255
254
256
255
def __radd__ (self , other ):
257
- if not isinstance (other , string_types ):
256
+ if not isinstance (other , str ):
258
257
raise RuntimeError ('Cannot add a non-string to BlobReference.' )
259
258
return BlobReference (other + self ._name , self ._from_net )
260
259
@@ -272,7 +271,7 @@ def _CreateAndAddToNet(self, op_type, inputs=None, *args, **kwargs):
272
271
network's __getattr__ function.
273
272
"""
274
273
inputs = [] if inputs is None else inputs
275
- if isinstance (inputs , BlobReference ) or isinstance (inputs , string_types ):
274
+ if isinstance (inputs , BlobReference ) or isinstance (inputs , str ):
276
275
inputs = [inputs ]
277
276
# add self to the input list.
278
277
inputs .insert (0 , self )
@@ -317,7 +316,7 @@ def __dir__(self):
317
316
318
317
def ScopedName (name ):
319
318
"""prefix the name with the current scope."""
320
- if isinstance (name , binary_type ):
319
+ if isinstance (name , bytes ):
321
320
name = name .decode ('ascii' )
322
321
return scope .CurrentNameScope () + name
323
322
@@ -331,7 +330,7 @@ def _RectifyInputOutput(blobs, net=None):
331
330
"""A helper function to rectify the input or output of the CreateOperator
332
331
interface.
333
332
"""
334
- if isinstance (blobs , string_types ) or isinstance ( blobs , binary_type ):
333
+ if isinstance (blobs , ( bytes , str ) ):
335
334
# If blobs is a single string, prepend scope.CurrentNameScope()
336
335
# and put it as a list.
337
336
# TODO(jiayq): enforce using BlobReference instead of raw strings.
@@ -343,7 +342,7 @@ def _RectifyInputOutput(blobs, net=None):
343
342
# If blob is a list, we go through it and type check.
344
343
rectified = []
345
344
for blob in blobs :
346
- if isinstance (blob , string_types ) or isinstance ( blob , binary_type ):
345
+ if isinstance (blob , ( bytes , str ) ):
347
346
rectified .append (ScopedBlobReference (blob , net = net ))
348
347
elif type (blob ) is BlobReference :
349
348
rectified .append (blob )
@@ -385,11 +384,11 @@ def CreateOperator(
385
384
# Add rectified inputs and outputs
386
385
inputs = _RectifyInputOutput (inputs )
387
386
outputs = _RectifyInputOutput (outputs )
388
- operator .input .extend ([ text_type ( i ) for i in inputs ] )
389
- operator .output .extend ([ text_type ( o ) for o in outputs ] )
387
+ operator .input .extend (map ( str , inputs ) )
388
+ operator .output .extend (map ( str , outputs ) )
390
389
if control_input :
391
390
control_input = _RectifyInputOutput (control_input )
392
- operator .control_input .extend ([ text_type ( i ) for i in control_input ] )
391
+ operator .control_input .extend (map ( str , control_input ) )
393
392
# Set device option:
394
393
# (1) If device_option is explicitly set, use device_option.
395
394
# (2) If not, but scope.CurrentDeviceScope() is set,
@@ -667,7 +666,7 @@ def BuildGradientGenerators( # NOQA
667
666
# (2) add outputs to the locally generated blobs
668
667
# If an output corresponds to the gradient of an input, we also
669
668
# record it to gradient_generators
670
- locally_generated_blobs .extend ([ str ( s ) for s in grad_op .output ] )
669
+ locally_generated_blobs .extend (map ( str , grad_op .output ) )
671
670
for i , output in enumerate (grad_op .output ):
672
671
input_index = GetIndexFromGradientList (g_input , output )
673
672
if input_index is not None :
@@ -1095,8 +1094,7 @@ def GetBackwardPass(self, ys):
1095
1094
all_input_to_grad_out = {}
1096
1095
for key , val in all_input_to_grad .items ():
1097
1096
if val is not None :
1098
- if (isinstance (val , string_types ) or
1099
- isinstance (val , binary_type )):
1097
+ if isinstance (val , (bytes , str )):
1100
1098
grad_out = BlobReference (val )
1101
1099
else :
1102
1100
grad_out = GradientSlice (BlobReference (val [0 ]),
@@ -1310,7 +1308,7 @@ def recurrent_network_op_remap(op, prefix, blob_remap):
1310
1308
"""
1311
1309
1312
1310
def get_remapped_str (blob_str ):
1313
- if isinstance (blob_str , binary_type ):
1311
+ if isinstance (blob_str , bytes ):
1314
1312
blob_str = blob_str .decode ('utf-8' )
1315
1313
return blob_remap .get (blob_str , blob_str ).encode ('utf-8' )
1316
1314
@@ -1983,7 +1981,7 @@ def NextName(self, prefix=None, output_id=None):
1983
1981
def _ExtendOps (self , new_ops ):
1984
1982
self ._net .op .extend (new_ops )
1985
1983
for op in new_ops :
1986
- self ._op_outputs .update ([text_type (o ) for o in op .output ])
1984
+ self ._op_outputs .update ([str (o ) for o in op .output ])
1987
1985
1988
1986
def _CheckLookupTables (self ):
1989
1987
'''
0 commit comments