37
37
get_cluster_table ,
38
38
get_res_during_tuning ,
39
39
is_valid_task ,
40
+ is_valid_uuid ,
40
41
list_to_string ,
41
42
serialize ,
42
43
)
@@ -97,7 +98,8 @@ def ping():
97
98
msg = "Ping fail! Make sure Neural Solution runner is running!"
98
99
break
99
100
except Exception as e :
100
- msg = "Ping fail! {}" .format (e )
101
+ print (e )
102
+ msg = "Ping fail!"
101
103
break
102
104
sock .close ()
103
105
return {"status" : "Healthy" , "msg" : msg } if count == 2 else {"status" : "Failed" , "msg" : msg }
@@ -167,26 +169,31 @@ async def submit_task(task: Task):
167
169
cursor = conn .cursor ()
168
170
task_id = str (uuid .uuid4 ()).replace ("-" , "" )
169
171
sql = (
170
- r"insert into task(id, script_url, optimized, arguments, approach, requirements, workers, status)"
171
- + r" values ('{}', '{}', {}, '{}', '{}', '{}', {}, 'pending')" .format (
172
- task_id ,
173
- task .script_url ,
174
- task .optimized ,
175
- list_to_string (task .arguments ),
176
- task .approach ,
177
- list_to_string (task .requirements ),
178
- task .workers ,
179
- )
172
+ "INSERT INTO task "
173
+ "(id, script_url, optimized, arguments, approach, requirements, workers, status) "
174
+ "VALUES (?, ?, ?, ?, ?, ?, ?, 'pending')"
180
175
)
181
- cursor .execute (sql )
176
+
177
+ task_params = (
178
+ task_id ,
179
+ task .script_url ,
180
+ task .optimized ,
181
+ list_to_string (task .arguments ),
182
+ task .approach ,
183
+ list_to_string (task .requirements ),
184
+ task .workers ,
185
+ )
186
+
187
+ conn .execute (sql , task_params )
182
188
conn .commit ()
183
189
try :
184
190
task_submitter .submit_task (task_id )
185
191
except ConnectionRefusedError :
186
192
msg = "Task Submitted fail! Make sure Neural Solution runner is running!"
187
193
status = "failed"
188
194
except Exception as e :
189
- msg = "Task Submitted fail! {}" .format (e )
195
+ msg = "Task Submitted fail!"
196
+ print (e )
190
197
status = "failed"
191
198
conn .close ()
192
199
else :
@@ -205,6 +212,8 @@ def get_task_by_id(task_id: str):
205
212
Returns:
206
213
json: task status, result, quantized model path
207
214
"""
215
+ if not is_valid_uuid (task_id ):
216
+ raise HTTPException (status_code = 422 , detail = "Invalid task id" )
208
217
res = None
209
218
db_path = get_db_path (config .workspace )
210
219
if os .path .isfile (db_path ):
@@ -246,6 +255,8 @@ def get_task_status_by_id(request: Request, task_id: str):
246
255
Returns:
247
256
json: task status and information
248
257
"""
258
+ if not is_valid_uuid (task_id ):
259
+ raise HTTPException (status_code = 422 , detail = "Invalid task id" )
249
260
status = "unknown"
250
261
tuning_info = {}
251
262
optimization_result = {}
@@ -290,7 +301,13 @@ async def read_logs(task_id: str):
290
301
Yields:
291
302
str: log lines
292
303
"""
293
- log_path = "{}/task_{}.txt" .format (get_task_log_workspace (config .workspace ), task_id )
304
+ if not is_valid_uuid (task_id ):
305
+ raise HTTPException (status_code = 422 , detail = "Invalid task id" )
306
+ log_path = os .path .normpath (os .path .join (get_task_log_workspace (config .workspace ), "task_{}.txt" .format (task_id )))
307
+
308
+ if not log_path .startswith (os .path .normpath (config .workspace )):
309
+ return {"error" : "Logfile not found." }
310
+
294
311
if not os .path .exists (log_path ):
295
312
return {"error" : "Logfile not found." }
296
313
@@ -388,12 +405,17 @@ async def websocket_endpoint(websocket: WebSocket, task_id: str):
388
405
Raises:
389
406
HTTPException: exception
390
407
"""
408
+ if not is_valid_uuid (task_id ):
409
+ raise HTTPException (status_code = 422 , detail = "Invalid task id" )
391
410
if not check_log_exists (task_id = task_id , task_log_path = get_task_log_workspace (config .workspace )):
392
411
raise HTTPException (status_code = 404 , detail = "Task not found" )
393
412
await websocket .accept ()
394
413
395
414
# send the log that has been written
396
- log_path = "{}/task_{}.txt" .format (get_task_log_workspace (config .workspace ), task_id )
415
+ log_path = os .path .normpath (os .path .join (get_task_log_workspace (config .workspace ), "task_{}.txt" .format (task_id )))
416
+
417
+ if not log_path .startswith (os .path .normpath (config .workspace )):
418
+ return {"error" : "Logfile not found." }
397
419
last_position = 0
398
420
previous_log = []
399
421
if os .path .exists (log_path ):
@@ -429,6 +451,8 @@ async def download_file(task_id: str):
429
451
Returns:
430
452
FileResponse: quantized model of zip file format
431
453
"""
454
+ if not is_valid_uuid (task_id ):
455
+ raise HTTPException (status_code = 422 , detail = "Invalid task id" )
432
456
db_path = get_db_path (config .workspace )
433
457
if os .path .isfile (db_path ):
434
458
conn = sqlite3 .connect (db_path )
@@ -444,6 +468,9 @@ async def download_file(task_id: str):
444
468
path = res [2 ]
445
469
zip_filename = "quantized_model.zip"
446
470
zip_filepath = os .path .abspath (os .path .join (get_task_workspace (config .workspace ), task_id , zip_filename ))
471
+
472
+ if not zip_filepath .startswith (os .path .normpath (os .path .abspath (get_task_workspace (config .workspace )))):
473
+ raise HTTPException (status_code = 422 , detail = "Invalid File" )
447
474
# create zipfile and add file
448
475
with zipfile .ZipFile (zip_filepath , "w" , zipfile .ZIP_DEFLATED ) as zip_file :
449
476
for root , dirs , files in os .walk (path ):
0 commit comments