Skip to content

Commit 0f0a56e

Browse files
authored
fix: add locking for braket_container for writes and local creation of boto clients
1 parent eff1c22 commit 0f0a56e

File tree

1 file changed

+35
-15
lines changed

1 file changed

+35
-15
lines changed

src/braket_container.py

+35-15
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@
3636
ERROR_LOG_FILE = os.path.join(ERROR_LOG_PATH, "failure")
3737
SETUP_SCRIPT_PATH = os.path.join(OPT_BRAKET, "additional_setup")
3838

39+
# Thread-local namespace; get_s3_client() caches its boto3 S3 client on it
# as the `s3_client` attribute, so each thread gets its own client.
_local = threading.local()
40+
# Held by _log_failure() while it creates ERROR_LOG_PATH and appends to
# ERROR_LOG_FILE.
_error_log_lock = threading.Lock()
41+
# Held around the check-then-append of EXTRACTED_CUSTOMER_CODE_PATH to sys.path.
_path_lock = threading.Lock()
42+
# Held by in_extracted_code_dir() for the whole body of the context manager
# (i.e. across its `yield`). NOTE: threading.Lock is not reentrant, so entering
# that context manager again on the same thread while it is active would block.
_chdir_lock = threading.Lock()
43+
# Held around shutil.unpack_archive() in unpack_code_and_add_to_path().
_unpack_lock = threading.Lock()
44+
3945
print("Boto3 Version: ", boto3.__version__)
4046

4147

@@ -47,12 +53,13 @@ def _log_failure(*args, display=True):
4753
Args:
4854
args: variable list of text to write to the file.
4955
"""
50-
Path(ERROR_LOG_PATH).mkdir(parents=True, exist_ok=True)
51-
with open(ERROR_LOG_FILE, 'a') as error_log:
52-
for text in args:
53-
error_log.write(text)
54-
if display:
55-
print(text)
56+
with _error_log_lock:
57+
Path(ERROR_LOG_PATH).mkdir(parents=True, exist_ok=True)
58+
with open(ERROR_LOG_FILE, 'a') as error_log:
59+
for text in args:
60+
error_log.write(text)
61+
if display:
62+
print(text)
5663

5764

5865
def log_failure_and_exit(*args):
@@ -91,6 +98,12 @@ def create_symlink():
9198
log_failure_and_exit(f"Symlink failure.\n Exception: {e}")
9299

93100

101+
def get_s3_client():
    """
    Returns the S3 client cached on the thread-local `_local` object,
    creating it with `boto3.client("s3")` the first time it is requested.

    Returns:
        The object stored in `_local.s3_client`.
    """
    try:
        return _local.s3_client
    except AttributeError:
        # No client cached on `_local` yet: build one and remember it so
        # later calls reuse it.
        client = boto3.client("s3")
        _local.s3_client = client
        return client
105+
106+
94107
def download_s3_file(s3_uri: str, local_path: str) -> str:
95108
"""
96109
Downloads a file to a local path.
@@ -101,7 +114,8 @@ def download_s3_file(s3_uri: str, local_path: str) -> str:
101114
Returns:
102115
str: the path to the file containing the downloaded path.
103116
"""
104-
s3_client = boto3.client("s3")
117+
118+
s3_client = get_s3_client()
105119
parsed_url = urlparse(s3_uri, allow_fragments=False)
106120
s3_bucket = parsed_url.netloc
107121
s3_key = parsed_url.path.lstrip("/")
@@ -138,15 +152,18 @@ def unpack_code_and_add_to_path(local_s3_file: str, compression_type: str):
138152
"""
139153
if compression_type and compression_type.strip().lower() in ["gzip", "zip"]:
140154
try:
141-
shutil.unpack_archive(local_s3_file, EXTRACTED_CUSTOMER_CODE_PATH)
155+
with _unpack_lock:
156+
shutil.unpack_archive(local_s3_file, EXTRACTED_CUSTOMER_CODE_PATH)
142157
except Exception as e:
143158
log_failure_and_exit(
144159
f"Got an exception while trying to unpack archive: {local_s3_file} of type: "
145160
f"{compression_type}.\nException: {e}"
146161
)
147162
else:
148163
shutil.copy(local_s3_file, EXTRACTED_CUSTOMER_CODE_PATH)
149-
sys.path.append(EXTRACTED_CUSTOMER_CODE_PATH)
164+
with _path_lock:
165+
if EXTRACTED_CUSTOMER_CODE_PATH not in sys.path:
166+
sys.path.append(EXTRACTED_CUSTOMER_CODE_PATH)
150167

151168

152169
def try_bind_hyperparameters_to_customer_method(customer_method: Callable):
@@ -243,12 +260,13 @@ def customer_code():
243260

244261
@contextlib.contextmanager
245262
def in_extracted_code_dir():
246-
current_dir = os.getcwd()
247-
try:
248-
os.chdir(EXTRACTED_CUSTOMER_CODE_PATH)
249-
yield
250-
finally:
251-
os.chdir(current_dir)
263+
with _chdir_lock:
264+
current_dir = os.getcwd()
265+
try:
266+
os.chdir(EXTRACTED_CUSTOMER_CODE_PATH)
267+
yield
268+
finally:
269+
os.chdir(current_dir)
252270

253271

254272
def wrap_customer_code(customer_method: Callable) -> Callable:
@@ -301,6 +319,8 @@ def join_customer_script(customer_code_process: multiprocessing.Process):
301319
try:
302320
customer_code_process.join()
303321
except Exception as e:
322+
customer_code_process.terminate()
323+
customer_code_process.join()
304324
log_failure_and_exit(f"Job did not exit gracefully.\nException: {e}")
305325
print("Code Run Finished")
306326
return customer_code_process.exitcode

0 commit comments

Comments
 (0)