Skip to content

Commit 01933d7

Browse files
authored
Allow nesting for local data dir, for external data (#583)
* Allow nesting. * Remove warnings.
1 parent 3455c31 commit 01933d7

File tree

4 files changed

+60
-2
lines changed

4 files changed

+60
-2
lines changed

poetry.lock

+20-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ mlflow = "^1.29.0"
9393
mypy = "^1.0.0"
9494
pytest-split = "^0.8.1"
9595
httpx = {extras = ["cli"], version = "^0.24.1"}
96+
requests-mock = "^1.11.0"
9697

9798
[build-system]
9899
requires = ["poetry-core>=1.2.1"]

truss/tests/test_util.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import os
2+
from unittest.mock import patch
3+
4+
import requests_mock
5+
from truss.truss_config import ExternalData
6+
from truss.util.download import download_external_data
7+
8+
TEST_DOWNLOAD_URL = "http://example.com/some-download-url"
9+
10+
11+
def test_download(tmp_path):
12+
mocked_download_content = b"mocked content"
13+
with patch.dict(os.environ, {}), requests_mock.Mocker() as m:
14+
m.get(TEST_DOWNLOAD_URL, content=mocked_download_content)
15+
external_data = ExternalData.from_list(
16+
[{"local_data_path": "foo", "url": TEST_DOWNLOAD_URL}]
17+
)
18+
download_external_data(external_data=external_data, data_dir=tmp_path)
19+
20+
with open(tmp_path / "foo", "rb") as f:
21+
content = f.read()
22+
23+
assert content == mocked_download_content
24+
25+
26+
def test_download_into_nested_subdir(tmp_path):
27+
mocked_download_content = b"mocked content"
28+
with patch.dict(os.environ, {}), requests_mock.Mocker() as m:
29+
m.get(TEST_DOWNLOAD_URL, content=mocked_download_content)
30+
external_data = ExternalData.from_list(
31+
[{"local_data_path": "foo/bar/baz", "url": TEST_DOWNLOAD_URL}]
32+
)
33+
download_external_data(external_data=external_data, data_dir=tmp_path)
34+
35+
with open(tmp_path / "foo" / "bar" / "baz", "rb") as f:
36+
content = f.read()
37+
38+
assert content == mocked_download_content

truss/util/download.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def download_external_data(external_data: Optional[ExternalData], data_dir: Path
2525
raise ValueError(
2626
"Local data path of external data cannot point to outside data directory"
2727
)
28-
path.parent.mkdir(exist_ok=True)
28+
path.parent.mkdir(exist_ok=True, parents=True)
2929

3030
if b10cp_path is not None:
3131
print("b10cp found, using it to download external data")

0 commit comments

Comments
 (0)