13
13
import contextlib
14
14
import errno
15
15
import importlib
16
+ import threading
16
17
import inspect
17
18
import os
18
19
import json
36
37
ERROR_LOG_FILE = os .path .join (ERROR_LOG_PATH , "failure" )
37
38
SETUP_SCRIPT_PATH = os .path .join (OPT_BRAKET , "additional_setup" )
38
39
40
# Thread-local namespace: get_s3_client() stores one boto3 S3 client per thread on it.
_local = threading.local()

# Module-wide mutexes, one for each piece of process-global state this module touches.
# Guards creating the error-log directory and appending to the failure file.
_error_log_lock = threading.Lock()
# Guards the check-then-append of the extracted code path on sys.path.
_path_lock = threading.Lock()
# Held for as long as in_extracted_code_dir() has the working directory switched.
_chdir_lock = threading.Lock()
# Serializes shutil.unpack_archive() into the extracted customer code directory.
_unpack_lock = threading.Lock()
45
+
39
46
print ("Boto3 Version: " , boto3 .__version__ )
40
47
41
48
@@ -47,12 +54,13 @@ def _log_failure(*args, display=True):
47
54
Args:
48
55
args: variable list of text to write to the file.
49
56
"""
50
- Path (ERROR_LOG_PATH ).mkdir (parents = True , exist_ok = True )
51
- with open (ERROR_LOG_FILE , 'a' ) as error_log :
52
- for text in args :
53
- error_log .write (text )
54
- if display :
55
- print (text )
57
+ with _error_log_lock :
58
+ Path (ERROR_LOG_PATH ).mkdir (parents = True , exist_ok = True )
59
+ with open (ERROR_LOG_FILE , 'a' ) as error_log :
60
+ for text in args :
61
+ error_log .write (text )
62
+ if display :
63
+ print (text )
56
64
57
65
58
66
def log_failure_and_exit (* args ):
@@ -91,6 +99,12 @@ def create_symlink():
91
99
log_failure_and_exit (f"Symlink failure.\n Exception: { e } " )
92
100
93
101
102
def get_s3_client():
    """
    Returns the S3 client belonging to the calling thread.

    The client is cached on the module's thread-local storage, so each thread
    builds its own client with `boto3.client("s3")` the first time it calls this
    function and gets that same object back on every later call.

    Returns:
        The boto3 S3 client cached for the current thread.
    """
    try:
        # Fast path: this thread has already created its client.
        return _local.s3_client
    except AttributeError:
        # First call on this thread: build the client and remember it.
        client = boto3.client("s3")
        _local.s3_client = client
        return client
106
+
107
+
94
108
def download_s3_file (s3_uri : str , local_path : str ) -> str :
95
109
"""
96
110
Downloads a file to a local path.
@@ -101,7 +115,8 @@ def download_s3_file(s3_uri: str, local_path: str) -> str:
101
115
Returns:
102
116
str: the path to the file containing the downloaded path.
103
117
"""
104
- s3_client = boto3 .client ("s3" )
118
+
119
+ s3_client = get_s3_client ()
105
120
parsed_url = urlparse (s3_uri , allow_fragments = False )
106
121
s3_bucket = parsed_url .netloc
107
122
s3_key = parsed_url .path .lstrip ("/" )
@@ -138,15 +153,18 @@ def unpack_code_and_add_to_path(local_s3_file: str, compression_type: str):
138
153
"""
139
154
if compression_type and compression_type .strip ().lower () in ["gzip" , "zip" ]:
140
155
try :
141
- shutil .unpack_archive (local_s3_file , EXTRACTED_CUSTOMER_CODE_PATH )
156
+ with _unpack_lock :
157
+ shutil .unpack_archive (local_s3_file , EXTRACTED_CUSTOMER_CODE_PATH )
142
158
except Exception as e :
143
159
log_failure_and_exit (
144
160
f"Got an exception while trying to unpack archive: { local_s3_file } of type: "
145
161
f"{ compression_type } .\n Exception: { e } "
146
162
)
147
163
else :
148
164
shutil .copy (local_s3_file , EXTRACTED_CUSTOMER_CODE_PATH )
149
- sys .path .append (EXTRACTED_CUSTOMER_CODE_PATH )
165
+ with _path_lock :
166
+ if EXTRACTED_CUSTOMER_CODE_PATH not in sys .path :
167
+ sys .path .append (EXTRACTED_CUSTOMER_CODE_PATH )
150
168
151
169
152
170
def try_bind_hyperparameters_to_customer_method (customer_method : Callable ):
@@ -243,12 +261,13 @@ def customer_code():
243
261
244
262
@contextlib.contextmanager
def in_extracted_code_dir():
    """
    Context manager that runs its body inside the extracted customer code directory.

    The working directory is process-wide state, so the module-level chdir lock is
    taken before the directory is changed and is only released after the previous
    directory has been restored. The previous directory is restored even when the
    body raises.

    NOTE(review): the chdir lock is a plain `threading.Lock`, which is not
    reentrant. Entering this context manager again from the same thread while it
    is already active would block forever; confirm no caller nests it.
    """
    _chdir_lock.acquire()
    try:
        previous_dir = os.getcwd()
        try:
            os.chdir(EXTRACTED_CUSTOMER_CODE_PATH)
            yield
        finally:
            # Always go back to where we started, whether or not the body failed.
            os.chdir(previous_dir)
    finally:
        _chdir_lock.release()
252
271
253
272
254
273
def wrap_customer_code (customer_method : Callable ) -> Callable :
@@ -301,6 +320,8 @@ def join_customer_script(customer_code_process: multiprocessing.Process):
301
320
try :
302
321
customer_code_process .join ()
303
322
except Exception as e :
323
+ customer_code_process .terminate ()
324
+ customer_code_process .join ()
304
325
log_failure_and_exit (f"Job did not exit gracefully.\n Exception: { e } " )
305
326
print ("Code Run Finished" )
306
327
return customer_code_process .exitcode
0 commit comments