Skip to content

Commit 5c8a6f5

Browse files
committed
Move SSL redirection into util.py
1 parent 782929e commit 5c8a6f5

File tree

2 files changed

+60
-30
lines changed

2 files changed

+60
-30
lines changed

neam.py

+25-29
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
#!/usr/bin/python3
22
from neam.python.neam import main as run_neam
3+
from neam.python.util import unverified_https_context
34

45
import nltk
56

6-
from contextlib import contextmanager
7-
import ssl
87
import sys
98

109
NEEDED_COLLECTIONS = {
@@ -18,28 +17,21 @@ def main():
1817
Makes sure the NLTK dependencies have been downloaded, then passes off to
1918
the real NEAM script
2019
"""
21-
ensure_collections_exist(NEEDED_COLLECTIONS)
20+
ensure_nltk_collections_exist(NEEDED_COLLECTIONS)
2221
run_neam()
2322

2423

25-
def ensure_collections_exist(collections):
24+
def ensure_nltk_collections_exist(collections):
2625
"""
27-
Ensures that all required collections exist on the system, and downloads
28-
them if they are not present
26+
Ensures that all required NLTK collections exist on the system, and
27+
downloads them if they are not present
2928
3029
:param collections: The collections that must be present - a dict of
3130
folder/name pairs, corresponding to where the
3231
collections should exist in the filesystem
3332
:see: http://www.nltk.org/_modules/nltk/downloader.html
3433
"""
35-
missing_collections = []
36-
37-
for collection_type, collection_names in collections.items():
38-
for collection in collection_names:
39-
try:
40-
nltk.data.find(f"{collection_type}/{collection}")
41-
except LookupError:
42-
missing_collections.append(collection)
34+
missing_collections = find_missing_nltk_collections(collections)
4335

4436
if missing_collections:
4537
print("Downloading NLTK models...", file=sys.stderr)
@@ -49,23 +41,27 @@ def ensure_collections_exist(collections):
4941
nltk.download(collection)
5042

5143

52-
@contextmanager
53-
def unverified_https_context():
44+
def find_missing_nltk_collections(collections):
    """
    Reports which of the requested NLTK collections cannot be found locally

    :param collections: The collections that must be present - a dict of
                        folder/name pairs, corresponding to where the
                        collections should exist in the filesystem
    :return: A list with the names of the collections that nltk.data.find
             could not locate, in the order they were checked
    """
    absent = []

    for folder, names in collections.items():
        for name in names:
            resource_path = f"{folder}/{name}"
            try:
                nltk.data.find(resource_path)
            except LookupError:
                # nltk.data.find raised LookupError, so treat it as missing
                absent.append(name)

    return absent
6964

7065

71-
main()
66+
if __name__ == '__main__':
67+
main()

neam/python/util.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,40 @@
1+
"""
2+
Miscellaneous helper functions - this module should be split up if it gets too
3+
large.
4+
"""
5+
from contextlib import contextmanager
16
import re
7+
import ssl
8+
29

310
def multi_sub(correspondences, text):
    """
    Substitutes multiple keys from a dict with their values within a string

    All replacements are made in a single pass over the text, so a value
    inserted by one replacement is never itself replaced by another key.

    :param correspondences: A dict where the key is the string to look up and
                            the value is the string to replace it with
    :param text: The string on which to conduct the replacements
    :return: The substituted string (the text unchanged if there are no
             correspondences)
    """
    # With no keys the pattern would be '()', which matches the empty string
    # at every position and makes the lookup below raise KeyError('')
    if not correspondences:
        return text

    # Keys are escaped so they are matched literally, not as regex syntax
    options = '|'.join(re.escape(key) for key in correspondences)
    pattern = '(' + options + ')'
    return re.sub(pattern, lambda match: correspondences[match.group(0)], text)
622

23+
24+
@contextmanager
25+
def unverified_https_context():
26+
"""
27+
Turns off the SSL check for the duration of the function
28+
"""
29+
try:
30+
create_unverified_https_context = ssl._create_unverified_context
31+
create_default_https_context = ssl.create_default_https_context
32+
except AttributeError:
33+
pass
34+
else:
35+
ssl._create_default_https_context = create_unverified_https_context
36+
37+
try:
38+
yield
39+
finally:
40+
ssl._create_default_https_context = create_default_https_context

0 commit comments

Comments
 (0)