Skip to content

Commit fde220c

Browse files
malfetpytorchmergebot
authored andcommitted
[BE] Get rid of six in caffe2 code (pytorch#93956)
Mostly `s/string_types/str/` `s/binary_types/bytes/` and `s/text_types/str/` Also `y.extend([str(x) for x in foo])`->`y.extend(map(str, foo))` As Python-2 is long dead Pull Request resolved: pytorch#93956 Approved by: https://github.com/albanD, https://github.com/Skylion007
1 parent 37fcc53 commit fde220c

File tree

9 files changed

+32
-38
lines changed

9 files changed

+32
-38
lines changed

.circleci/docker/common/install_conda.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then
7575
}
7676

7777
# Install PyTorch conda deps, as per https://github.com/pytorch/pytorch README
78-
CONDA_COMMON_DEPS="astunparse pyyaml mkl=2021.4.0 mkl-include=2021.4.0 setuptools six"
78+
CONDA_COMMON_DEPS="astunparse pyyaml mkl=2021.4.0 mkl-include=2021.4.0 setuptools"
7979
if [ "$ANACONDA_PYTHON_VERSION" = "3.11" ]; then
8080
# Install llvm-8 as it is required to compile llvmlite-0.30.0 from source
8181
# TODO: Stop using `-c malfet`

caffe2/python/core.py

+17-19
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from collections import namedtuple, OrderedDict, defaultdict
99
from past.builtins import basestring
1010
from itertools import chain
11-
from six import binary_type, string_types, text_type
1211

1312
from caffe2.proto import caffe2_pb2
1413
from caffe2.python import scope, utils, workspace
@@ -215,9 +214,9 @@ def __init__(self, name, net=None):
215214
Note that this does not prepends the namescope. If needed, use
216215
ScopedBlobReference() to prepend the existing namespace.
217216
"""
218-
if isinstance(name, string_types):
217+
if isinstance(name, str):
219218
self._name = name
220-
elif isinstance(name, binary_type):
219+
elif isinstance(name, bytes):
221220
self._name = name.decode('utf-8')
222221
else:
223222
self._name = str(name)
@@ -230,9 +229,9 @@ def __hash__(self):
230229
return hash(self._name)
231230

232231
def __eq__(self, other):
233-
if isinstance(other, string_types):
232+
if isinstance(other, str):
234233
return self._name == other
235-
elif isinstance(other, binary_type):
234+
elif isinstance(other, bytes):
236235
return self._name == other.decode('utf-8')
237236
elif isinstance(other, BlobReference):
238237
return self._name == other._name
@@ -249,12 +248,12 @@ def __repr__(self):
249248
return 'BlobReference("{}")'.format(self._name)
250249

251250
def __add__(self, other):
252-
if not isinstance(other, string_types):
251+
if not isinstance(other, str):
253252
raise RuntimeError('Cannot add BlobReference to a non-string.')
254253
return BlobReference(self._name + other, self._from_net)
255254

256255
def __radd__(self, other):
257-
if not isinstance(other, string_types):
256+
if not isinstance(other, str):
258257
raise RuntimeError('Cannot add a non-string to BlobReference.')
259258
return BlobReference(other + self._name, self._from_net)
260259

@@ -272,7 +271,7 @@ def _CreateAndAddToNet(self, op_type, inputs=None, *args, **kwargs):
272271
network's __getattr__ function.
273272
"""
274273
inputs = [] if inputs is None else inputs
275-
if isinstance(inputs, BlobReference) or isinstance(inputs, string_types):
274+
if isinstance(inputs, BlobReference) or isinstance(inputs, str):
276275
inputs = [inputs]
277276
# add self to the input list.
278277
inputs.insert(0, self)
@@ -317,7 +316,7 @@ def __dir__(self):
317316

318317
def ScopedName(name):
319318
"""prefix the name with the current scope."""
320-
if isinstance(name, binary_type):
319+
if isinstance(name, bytes):
321320
name = name.decode('ascii')
322321
return scope.CurrentNameScope() + name
323322

@@ -331,7 +330,7 @@ def _RectifyInputOutput(blobs, net=None):
331330
"""A helper function to rectify the input or output of the CreateOperator
332331
interface.
333332
"""
334-
if isinstance(blobs, string_types) or isinstance(blobs, binary_type):
333+
if isinstance(blobs, (bytes, str)):
335334
# If blobs is a single string, prepend scope.CurrentNameScope()
336335
# and put it as a list.
337336
# TODO(jiayq): enforce using BlobReference instead of raw strings.
@@ -343,7 +342,7 @@ def _RectifyInputOutput(blobs, net=None):
343342
# If blob is a list, we go through it and type check.
344343
rectified = []
345344
for blob in blobs:
346-
if isinstance(blob, string_types) or isinstance(blob, binary_type):
345+
if isinstance(blob, (bytes, str)):
347346
rectified.append(ScopedBlobReference(blob, net=net))
348347
elif type(blob) is BlobReference:
349348
rectified.append(blob)
@@ -385,11 +384,11 @@ def CreateOperator(
385384
# Add rectified inputs and outputs
386385
inputs = _RectifyInputOutput(inputs)
387386
outputs = _RectifyInputOutput(outputs)
388-
operator.input.extend([text_type(i) for i in inputs])
389-
operator.output.extend([text_type(o) for o in outputs])
387+
operator.input.extend(map(str, inputs))
388+
operator.output.extend(map(str, outputs))
390389
if control_input:
391390
control_input = _RectifyInputOutput(control_input)
392-
operator.control_input.extend([text_type(i) for i in control_input])
391+
operator.control_input.extend(map(str, control_input))
393392
# Set device option:
394393
# (1) If device_option is explicitly set, use device_option.
395394
# (2) If not, but scope.CurrentDeviceScope() is set,
@@ -667,7 +666,7 @@ def BuildGradientGenerators( # NOQA
667666
# (2) add outputs to the locally generated blobs
668667
# If an output corresponds to the gradient of an input, we also
669668
# record it to gradient_generators
670-
locally_generated_blobs.extend([str(s) for s in grad_op.output])
669+
locally_generated_blobs.extend(map(str, grad_op.output))
671670
for i, output in enumerate(grad_op.output):
672671
input_index = GetIndexFromGradientList(g_input, output)
673672
if input_index is not None:
@@ -1095,8 +1094,7 @@ def GetBackwardPass(self, ys):
10951094
all_input_to_grad_out = {}
10961095
for key, val in all_input_to_grad.items():
10971096
if val is not None:
1098-
if (isinstance(val, string_types) or
1099-
isinstance(val, binary_type)):
1097+
if isinstance(val, (bytes, str)):
11001098
grad_out = BlobReference(val)
11011099
else:
11021100
grad_out = GradientSlice(BlobReference(val[0]),
@@ -1310,7 +1308,7 @@ def recurrent_network_op_remap(op, prefix, blob_remap):
13101308
"""
13111309

13121310
def get_remapped_str(blob_str):
1313-
if isinstance(blob_str, binary_type):
1311+
if isinstance(blob_str, bytes):
13141312
blob_str = blob_str.decode('utf-8')
13151313
return blob_remap.get(blob_str, blob_str).encode('utf-8')
13161314

@@ -1983,7 +1981,7 @@ def NextName(self, prefix=None, output_id=None):
19831981
def _ExtendOps(self, new_ops):
19841982
self._net.op.extend(new_ops)
19851983
for op in new_ops:
1986-
self._op_outputs.update([text_type(o) for o in op.output])
1984+
self._op_outputs.update([str(o) for o in op.output])
19871985

19881986
def _CheckLookupTables(self):
19891987
'''

caffe2/python/functional.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from caffe2.proto import caffe2_pb2
88
from caffe2.python.onnx.workspace import Workspace
99
from collections import namedtuple
10-
from six import string_types
1110

1211
OpSchema = workspace.C.OpSchema
1312

@@ -19,7 +18,7 @@ def namedtupledict(typename, field_names, *args, **kwargs):
1918
data = namedtuple(typename, field_names, *args, **kwargs)
2019

2120
def getitem(self, key):
22-
if isinstance(key, string_types):
21+
if isinstance(key, str):
2322
key = field_names_map[key]
2423
return super(type(self), self).__getitem__(key)
2524

caffe2/python/net_printer.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from contextlib import contextmanager
1414
from copy import copy
1515
from itertools import chain
16-
from six import binary_type, text_type
1716

1817

1918
class Visitor(object):
@@ -192,9 +191,9 @@ def __init__(self, factor_prefixes=False, c2_syntax=True):
192191

193192

194193
def _sanitize_str(s):
195-
if isinstance(s, text_type):
194+
if isinstance(s, str):
196195
sanitized = s
197-
elif isinstance(s, binary_type):
196+
elif isinstance(s, bytes):
198197
sanitized = s.decode('ascii', errors='ignore')
199198
else:
200199
sanitized = str(s)

caffe2/python/schema.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from collections import OrderedDict, namedtuple
2727
from past.builtins import basestring
2828
from itertools import islice
29-
from six import StringIO
29+
from io import StringIO
3030
from typing import Sequence
3131

3232
logger = logging.getLogger(__name__)

caffe2/python/trt/test_trt.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import tarfile
2222
import tempfile
2323
import shutil
24-
from six.moves.urllib.request import urlretrieve
24+
from urllib.request import urlretrieve
2525

2626
def _print_net(net):
2727
for i in net.external_input:

caffe2/python/utils.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import copy
1515
import functools
1616
import numpy as np
17-
from six import integer_types, binary_type, text_type, string_types
1817

1918
OPTIMIZER_ITERATION_NAME = "optimizer_iteration"
2019
OPTIMIZER_ITERATION_LR_NAME = "optimizer_iteration_lr"
@@ -30,7 +29,7 @@ def OpAlmostEqual(op_a, op_b, ignore_fields=None):
3029
if not isinstance(ignore_fields, list):
3130
ignore_fields = [ignore_fields]
3231

33-
assert all(isinstance(f, text_type) for f in ignore_fields), (
32+
assert all(isinstance(f, str) for f in ignore_fields), (
3433
'Expect each field is text type, but got {}'.format(ignore_fields))
3534

3635
def clean_op(op):
@@ -145,13 +144,13 @@ def MakeArgument(key, value):
145144

146145
if type(value) is float:
147146
argument.f = value
148-
elif type(value) in integer_types or type(value) is bool:
147+
elif type(value) in [bool, int]:
149148
# We make a relaxation that a boolean variable will also be stored as
150149
# int.
151150
argument.i = value
152-
elif isinstance(value, binary_type):
151+
elif isinstance(value, bytes):
153152
argument.s = value
154-
elif isinstance(value, text_type):
153+
elif isinstance(value, str):
155154
argument.s = value.encode('utf-8')
156155
elif isinstance(value, caffe2_pb2.NetDef):
157156
argument.n.CopyFrom(value)
@@ -162,16 +161,16 @@ def MakeArgument(key, value):
162161
v.item() if type(v) is np.float_ else v for v in value
163162
)
164163
elif iterable and all(
165-
type(v) in integer_types or type(v) in [bool, np.int_] for v in value
164+
type(v) in [bool, int, np.int_] for v in value
166165
):
167166
argument.ints.extend(
168167
v.item() if type(v) is np.int_ else v for v in value
169168
)
170169
elif iterable and all(
171-
isinstance(v, binary_type) or isinstance(v, text_type) for v in value
170+
isinstance(v, bytes) or isinstance(v, str) for v in value
172171
):
173172
argument.strings.extend(
174-
v.encode('utf-8') if isinstance(v, text_type) else v
173+
v.encode('utf-8') if isinstance(v, str) else v
175174
for v in value
176175
)
177176
elif iterable and all(isinstance(v, caffe2_pb2.NetDef) for v in value):
@@ -384,7 +383,7 @@ def EnumClassKeyVals(cls):
384383
for k in dir(cls):
385384
if k == k.upper():
386385
v = getattr(cls, k)
387-
if isinstance(v, string_types):
386+
if isinstance(v, str):
388387
assert v not in enum.values(), (
389388
"Failed to resolve {} as Enum: "
390389
"duplicate entries {}={}, {}={}".format(

functorch/examples/maml_omniglot/support/omniglot_loaders.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _check_exists(self):
8282
os.path.exists(os.path.join(self.root, self.processed_folder, "images_background"))
8383

8484
def download(self):
85-
from six.moves import urllib
85+
import urllib
8686
import zipfile
8787

8888
if self._check_exists():

requirements.txt

-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ psutil
77
pyyaml
88
requests
99
setuptools
10-
six
1110
types-dataclasses
1211
typing_extensions
1312
sympy

0 commit comments

Comments
 (0)