Skip to content

Commit d16688d

Browse files
authored
infra: format .py and .ipynb files with black (aws#2598)
1 parent 43b71af commit d16688d

File tree

20 files changed

+437
-357
lines changed

20 files changed

+437
-357
lines changed

hyperparameter_tuning/tensorflow_mnist/hpo_tensorflow_mnist.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
{
44
"cell_type": "code",
55
"execution_count": null,
6-
"outputs": [],
76
"metadata": {
87
"collapsed": true,
98
"jupyter": {
109
"outputs_hidden": true
1110
}
1211
},
12+
"outputs": [],
1313
"source": [
1414
"# Install dependencies\n",
1515
"!pip install -q smdebug\n",
@@ -18,7 +18,7 @@
1818
"!pip install -q opencv-python\n",
1919
"!pip install -q shap\n",
2020
"!pip install -q bokeh\n",
21-
"!pip install -q imageio\n"
21+
"!pip install -q imageio"
2222
]
2323
},
2424
{

introduction_to_applying_machine_learning/US-census_population_segmentation_PCA_Kmeans/sagemaker-countycensusclustering.ipynb

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
{
44
"cell_type": "code",
55
"execution_count": null,
6-
"outputs": [],
76
"metadata": {
87
"collapsed": true,
98
"jupyter": {
109
"outputs_hidden": true
1110
}
1211
},
12+
"outputs": [],
1313
"source": [
1414
"# Install dependencies\n",
1515
"!pip install -q smdebug\n",
@@ -18,7 +18,7 @@
1818
"!pip install -q opencv-python\n",
1919
"!pip install -q shap\n",
2020
"!pip install -q bokeh\n",
21-
"!pip install -q imageio\n"
21+
"!pip install -q imageio"
2222
]
2323
},
2424
{
@@ -2860,4 +2860,4 @@
28602860
},
28612861
"nbformat": 4,
28622862
"nbformat_minor": 2
2863-
}
2863+
}

sagemaker-debugger/pytorch_model_debugging/pytorch_script_change_smdebug.ipynb

+53-41
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
"outputs": [],
7777
"source": [
7878
"import sagemaker\n",
79+
"\n",
7980
"sagemaker.__version__"
8081
]
8182
},
@@ -377,7 +378,13 @@
377378
"import pytest\n",
378379
"from sagemaker.pytorch import PyTorch\n",
379380
"from sagemaker import get_execution_role\n",
380-
"from sagemaker.debugger import Rule, DebuggerHookConfig, TensorBoardOutputConfig, CollectionConfig, rule_configs"
381+
"from sagemaker.debugger import (\n",
382+
" Rule,\n",
383+
" DebuggerHookConfig,\n",
384+
" TensorBoardOutputConfig,\n",
385+
" CollectionConfig,\n",
386+
" rule_configs,\n",
387+
")"
381388
]
382389
},
383390
{
@@ -393,12 +400,7 @@
393400
"metadata": {},
394401
"outputs": [],
395402
"source": [
396-
"hyperparameters={\n",
397-
" \"epochs\": \"5\",\n",
398-
" \"batch-size\": \"32\",\n",
399-
" \"test-batch-size\": \"100\",\n",
400-
" \"lr\": \"0.001\"\n",
401-
"}"
403+
"hyperparameters = {\"epochs\": \"5\", \"batch-size\": \"32\", \"test-batch-size\": \"100\", \"lr\": \"0.001\"}"
402404
]
403405
},
404406
{
@@ -422,7 +424,7 @@
422424
" Rule.sagemaker(rule_configs.vanishing_gradient()),\n",
423425
" Rule.sagemaker(rule_configs.overfit()),\n",
424426
" Rule.sagemaker(rule_configs.overtraining()),\n",
425-
" Rule.sagemaker(rule_configs.poor_weight_initialization())\n",
427+
" Rule.sagemaker(rule_configs.poor_weight_initialization()),\n",
426428
"]"
427429
]
428430
},
@@ -459,10 +461,7 @@
459461
"outputs": [],
460462
"source": [
461463
"hook_config = DebuggerHookConfig(\n",
462-
" hook_parameters={\n",
463-
" \"train.save_interval\": \"100\",\n",
464-
" \"eval.save_interval\": \"10\"\n",
465-
" }\n",
464+
" hook_parameters={\"train.save_interval\": \"100\", \"eval.save_interval\": \"10\"}\n",
466465
")"
467466
]
468467
},
@@ -480,20 +479,19 @@
480479
"outputs": [],
481480
"source": [
482481
"estimator = PyTorch(\n",
483-
" entry_point='scripts/pytorch_mnist.py',\n",
484-
" base_job_name='smdebugger-demo-mnist-pytorch',\n",
482+
" entry_point=\"scripts/pytorch_mnist.py\",\n",
483+
" base_job_name=\"smdebugger-demo-mnist-pytorch\",\n",
485484
" role=get_execution_role(),\n",
486485
" instance_count=1,\n",
487-
" instance_type='ml.p2.xlarge',\n",
486+
" instance_type=\"ml.p2.xlarge\",\n",
488487
" volume_size=400,\n",
489488
" max_run=3600,\n",
490489
" hyperparameters=hyperparameters,\n",
491-
" framework_version='1.8',\n",
492-
" py_version='py36',\n",
493-
" \n",
490+
" framework_version=\"1.8\",\n",
491+
" py_version=\"py36\",\n",
494492
" ## Debugger parameters\n",
495-
" rules = rules,\n",
496-
" debugger_hook_config=hook_config\n",
493+
" rules=rules,\n",
494+
" debugger_hook_config=hook_config,\n",
497495
")"
498496
]
499497
},
@@ -534,7 +532,7 @@
534532
"metadata": {},
535533
"outputs": [],
536534
"source": [
537-
"job_name=estimator.latest_training_job.name\n",
535+
"job_name = estimator.latest_training_job.name\n",
538536
"client = estimator.sagemaker_session.sagemaker_client\n",
539537
"description = client.describe_training_job(TrainingJobName=estimator.latest_training_job.name)"
540538
]
@@ -547,6 +545,7 @@
547545
"source": [
548546
"import time\n",
549547
"from IPython import display\n",
548+
"\n",
550549
"%matplotlib inline\n",
551550
"\n",
552551
"while description[\"SecondaryStatus\"] not in {\"Stopped\", \"Completed\"}:\n",
@@ -557,10 +556,12 @@
557556
" print(\"TrainingJobStatus: \", primary_status, \" | SecondaryStatus: \", secondary_status)\n",
558557
" print(\"====================================================================\")\n",
559558
" for r in range(len(estimator.latest_training_job.rule_job_summary())):\n",
560-
" rule_summary=estimator.latest_training_job.rule_job_summary()\n",
561-
" print(rule_summary[r]['RuleConfigurationName'], \": \", rule_summary[r]['RuleEvaluationStatus'])\n",
562-
" if rule_summary[r]['RuleEvaluationStatus']=='IssuesFound':\n",
563-
" print(rule_summary[r]['StatusDetails'])\n",
559+
" rule_summary = estimator.latest_training_job.rule_job_summary()\n",
560+
" print(\n",
561+
" rule_summary[r][\"RuleConfigurationName\"], \": \", rule_summary[r][\"RuleEvaluationStatus\"]\n",
562+
" )\n",
563+
" if rule_summary[r][\"RuleEvaluationStatus\"] == \"IssuesFound\":\n",
564+
" print(rule_summary[r][\"StatusDetails\"])\n",
564565
" print(\"====================================================================\")\n",
565566
" print(\"Current time: \", time.asctime())\n",
566567
" display.clear_output(wait=True)\n",
@@ -581,28 +582,36 @@
581582
"outputs": [],
582583
"source": [
583584
"def _get_rule_job_name(training_job_name, rule_configuration_name, rule_job_arn):\n",
584-
" \"\"\"Helper function to get the rule job name with correct casing\"\"\"\n",
585-
" return \"{}-{}-{}\".format(\n",
586-
" training_job_name[:26], rule_configuration_name[:26], rule_job_arn[-8:]\n",
587-
" )\n",
588-
" \n",
585+
" \"\"\"Helper function to get the rule job name with correct casing\"\"\"\n",
586+
" return \"{}-{}-{}\".format(\n",
587+
" training_job_name[:26], rule_configuration_name[:26], rule_job_arn[-8:]\n",
588+
" )\n",
589+
"\n",
590+
"\n",
589591
"def _get_cw_url_for_rule_job(rule_job_name, region):\n",
590-
" return \"https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix\".format(region, region, rule_job_name)\n",
592+
" return \"https://{}.console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix\".format(\n",
593+
" region, region, rule_job_name\n",
594+
" )\n",
591595
"\n",
592596
"\n",
593597
"def get_rule_jobs_cw_urls(estimator):\n",
594598
" region = boto3.Session().region_name\n",
595599
" training_job = estimator.latest_training_job\n",
596600
" training_job_name = training_job.describe()[\"TrainingJobName\"]\n",
597601
" rule_eval_statuses = training_job.describe()[\"DebugRuleEvaluationStatuses\"]\n",
598-
" \n",
599-
" result={}\n",
602+
"\n",
603+
" result = {}\n",
600604
" for status in rule_eval_statuses:\n",
601605
" if status.get(\"RuleEvaluationJobArn\", None) is not None:\n",
602-
" rule_job_name = _get_rule_job_name(training_job_name, status[\"RuleConfigurationName\"], status[\"RuleEvaluationJobArn\"])\n",
603-
" result[status[\"RuleConfigurationName\"]] = _get_cw_url_for_rule_job(rule_job_name, region)\n",
606+
" rule_job_name = _get_rule_job_name(\n",
607+
" training_job_name, status[\"RuleConfigurationName\"], status[\"RuleEvaluationJobArn\"]\n",
608+
" )\n",
609+
" result[status[\"RuleConfigurationName\"]] = _get_cw_url_for_rule_job(\n",
610+
" rule_job_name, region\n",
611+
" )\n",
604612
" return result\n",
605613
"\n",
614+
"\n",
606615
"get_rule_jobs_cw_urls(estimator)"
607616
]
608617
},
@@ -632,6 +641,7 @@
632641
"source": [
633642
"from smdebug.trials import create_trial\n",
634643
"from smdebug.core.modes import ModeKeys\n",
644+
"\n",
635645
"trial = create_trial(estimator.latest_job_debugger_artifacts_path())"
636646
]
637647
},
@@ -762,25 +772,26 @@
762772
"import matplotlib.pyplot as plt\n",
763773
"from mpl_toolkits.axes_grid1 import host_subplot\n",
764774
"\n",
775+
"\n",
765776
"def plot_tensor(trial, tensor_name):\n",
766777
"\n",
767778
" steps_train, vals_train = get_data(trial, tensor_name, mode=ModeKeys.TRAIN)\n",
768779
" print(\"loaded TRAIN data\")\n",
769780
" steps_eval, vals_eval = get_data(trial, tensor_name, mode=ModeKeys.EVAL)\n",
770781
" print(\"loaded EVAL data\")\n",
771782
"\n",
772-
" fig = plt.figure(figsize=(10,7))\n",
783+
" fig = plt.figure(figsize=(10, 7))\n",
773784
" host = host_subplot(111)\n",
774785
"\n",
775786
" par = host.twiny()\n",
776787
"\n",
777788
" host.set_xlabel(\"Steps (TRAIN)\")\n",
778789
" par.set_xlabel(\"Steps (EVAL)\")\n",
779790
" host.set_ylabel(tensor_name)\n",
780-
" \n",
781-
" p1, = host.plot(steps_train, vals_train, label=tensor_name)\n",
791+
"\n",
792+
" (p1,) = host.plot(steps_train, vals_train, label=tensor_name)\n",
782793
" print(\"completed TRAIN plot\")\n",
783-
" p2, = par.plot(steps_eval, vals_eval, label=\"val_\"+tensor_name)\n",
794+
" (p2,) = par.plot(steps_eval, vals_eval, label=\"val_\" + tensor_name)\n",
784795
" print(\"completed EVAL plot\")\n",
785796
" leg = plt.legend()\n",
786797
"\n",
@@ -791,7 +802,7 @@
791802
" leg.texts[1].set_color(p2.get_color())\n",
792803
"\n",
793804
" plt.ylabel(tensor_name)\n",
794-
" \n",
805+
"\n",
795806
" plt.show()"
796807
]
797808
},
@@ -15516,7 +15527,8 @@
1551615527
],
1551715528
"source": [
1551815529
"import IPython\n",
15519-
"IPython.display.HTML(filename=profiler_report_name+\"/profiler-output/profiler-report.html\")"
15530+
"\n",
15531+
"IPython.display.HTML(filename=profiler_report_name + \"/profiler-output/profiler-report.html\")"
1552015532
]
1552115533
},
1552215534
{

0 commit comments

Comments
 (0)