
Commit 233d0da

Author: Varun Shenoy

Caching Weights from S3 (#709)

* init s3 caching
* update toml to test on dev
* fix gcs tests + add s3 tests
* cleanup
* add boto to deps
* update pyproject to include boto
* bump dev
* update poetry lock
* public s3 buckets are working
* update dev
* bump rc

1 parent: 91603f5

File tree

6 files changed: +236 -42 lines changed


poetry.lock

+11-1
Generated file; diff not rendered by default.

pyproject.toml

+4-1
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.7.14rc2"
+version = "0.7.14rc3"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
@@ -40,6 +40,8 @@ huggingface_hub = "^0.16.4"
 rich-click = "^1.6.1"
 inquirerpy = "^0.3.4"
 google-cloud-storage = "2.10.0"
+botocore =">=1.31.7"
+

 [tool.poetry.group.builder.dependencies]
 python = ">=3.8,<3.12"
@@ -58,6 +60,7 @@ httpx = "^0.24.1"
 psutil = "^5.9.4"
 huggingface_hub = "^0.16.4"
 google-cloud-storage = "2.10.0"
+boto3 = "^1.26.157"

 [tool.poetry.dev-dependencies]
 ipython = "^7.16"

truss/contexts/image_builder/cache_warmer.py

+82-26
@@ -1,10 +1,13 @@
 import datetime
+import json
 import os
 import subprocess
 import sys
 from pathlib import Path
 from typing import Optional

+import boto3
+from botocore.client import Config
 from google.cloud import storage
 from huggingface_hub import hf_hub_download

@@ -32,52 +35,105 @@ def _download_from_url_using_b10cp(
     )


-def split_gs_path(gs_path):
+def split_path(path, prefix="gs://"):
     # Remove the 'gs://' prefix
-    path = gs_path.replace("gs://", "")
+    path = path.replace(prefix, "")

     # Split on the first slash
     parts = path.split("/", 1)

     bucket_name = parts[0]
-    prefix = parts[1] if len(parts) > 1 else ""
+    path = parts[1] if len(parts) > 1 else ""

-    return bucket_name, prefix
+    return bucket_name, path
+
+
+def parse_s3_service_account_file(file_path):
+    # open the json file
+    with open(file_path, "r") as f:
+        data = json.load(f)
+
+    # validate the data
+    if "aws_access_key_id" not in data or "aws_secret_access_key" not in data:
+        raise ValueError("Invalid AWS credentials file")
+
+    # parse the data
+    aws_access_key_id = data["aws_access_key_id"]
+    aws_secret_access_key = data["aws_secret_access_key"]
+    aws_region = data["aws_region"]
+
+    return aws_access_key_id, aws_secret_access_key, aws_region


 def download_file(
     repo_name, file_name, revision_name=None, key_file="/app/data/service_account.json"
 ):
     # Check if repo_name starts with "gs://"
-    if "gs://" in repo_name:
+    if repo_name.startswith(("gs://", "s3://")):
+        prefix = repo_name[:5]
+
         # Create directory if not exist
-        bucket_name, _ = split_gs_path(repo_name)
-        repo_name = repo_name.replace("gs://", "")
+        bucket_name, _ = split_path(repo_name, prefix=prefix)
+        repo_name = repo_name.replace(prefix, "")
         cache_dir = Path(f"/app/hf_cache/{bucket_name}")
         cache_dir.mkdir(parents=True, exist_ok=True)

-        # Connect to GCS storage
-        storage_client = storage.Client.from_service_account_json(key_file)
-        bucket = storage_client.bucket(bucket_name)
-        blob = bucket.blob(file_name)
+        if prefix == "gs://":
+            # Connect to GCS storage
+            storage_client = storage.Client.from_service_account_json(key_file)
+            bucket = storage_client.bucket(bucket_name)
+            blob = bucket.blob(file_name)

-        dst_file = Path(f"{cache_dir}/{file_name}")
-        if not dst_file.parent.exists():
-            dst_file.parent.mkdir(parents=True)
+            dst_file = Path(f"{cache_dir}/{file_name}")
+            if not dst_file.parent.exists():
+                dst_file.parent.mkdir(parents=True)

-        if not blob.exists(storage_client):
-            raise RuntimeError(f"File not found on GCS bucket: {blob.name}")
+            if not blob.exists(storage_client):
+                raise RuntimeError(f"File not found on GCS bucket: {blob.name}")

-        url = blob.generate_signed_url(
-            version="v4",
-            expiration=datetime.timedelta(minutes=15),
-            method="GET",
-        )
-        try:
-            proc = _download_from_url_using_b10cp(_b10cp_path(), url, dst_file)
-            proc.wait()
-        except Exception as e:
-            raise RuntimeError(f"Failure downloading file from GCS: {e}")
+            url = blob.generate_signed_url(
+                version="v4",
+                expiration=datetime.timedelta(minutes=15),
+                method="GET",
+            )
+            try:
+                proc = _download_from_url_using_b10cp(_b10cp_path(), url, dst_file)
+                proc.wait()
+            except Exception as e:
+                raise RuntimeError(f"Failure downloading file from GCS: {e}")
+        elif prefix == "s3://":
+            (
+                AWS_ACCESS_KEY_ID,
+                AWS_SECRET_ACCESS_KEY,
+                AWS_REGION,
+            ) = parse_s3_service_account_file(key_file)
+            client = boto3.client(
+                "s3",
+                aws_access_key_id=AWS_ACCESS_KEY_ID,
+                aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+                region_name=AWS_REGION,
+                config=Config(signature_version="s3v4"),
+            )
+            bucket_name, _ = split_path(bucket_name, prefix="s3://")
+
+            dst_file = Path(f"{cache_dir}/{file_name}")
+            if not dst_file.parent.exists():
+                dst_file.parent.mkdir(parents=True)
+
+            try:
+                url = client.generate_presigned_url(
+                    "get_object",
+                    Params={"Bucket": bucket_name, "Key": file_name},
+                    ExpiresIn=3600,
+                )
+            except Exception:
+                raise RuntimeError(f"File not found on S3 bucket: {file_name}")
+
+            try:
+                proc = _download_from_url_using_b10cp(_b10cp_path(), url, dst_file)
+                proc.wait()
+            except Exception as e:
+                raise RuntimeError(f"Failure downloading file from S3: {e}")
     else:
         secret_path = Path("/etc/secrets/hf-access-token")
         secret = secret_path.read_text().strip() if secret_path.exists() else None
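
For context, the new S3 branch mirrors the existing GCS flow: read credentials from the mounted service-account file, presign a short-lived GET URL, and hand it to b10cp. Below is a minimal standalone sketch of that presigning step, assuming hypothetical bucket, key, and credential values that are not part of this commit.

import boto3
from botocore.client import Config

# Hypothetical credentials; in the image they are read from
# /app/data/service_account.json by parse_s3_service_account_file().
client = boto3.client(
    "s3",
    aws_access_key_id="AKIAEXAMPLE",
    aws_secret_access_key="example-secret-key",
    region_name="us-east-1",
    config=Config(signature_version="s3v4"),
)

# Same call the cache warmer makes: a GET URL valid for one hour,
# which download_file() then fetches with b10cp into /app/hf_cache/<bucket>/.
url = client.generate_presigned_url(
    "get_object",
    Params={"Bucket": "example-model-weights", "Key": "pytorch_model.bin"},
    ExpiresIn=3600,
)
print(url)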

truss/contexts/image_builder/serving_image_builder.py

+65-12
@@ -1,6 +1,8 @@
+import json
 from pathlib import Path
 from typing import Any, Dict, List, Optional

+import boto3
 import yaml
 from google.cloud import storage
 from huggingface_hub import get_hf_file_metadata, hf_hub_url, list_repo_files
@@ -88,29 +90,31 @@ def create_triton_build_dir(config: TrussConfig, build_dir: Path, truss_dir: Pat
     (build_dir / SYSTEM_PACKAGES_TXT_FILENAME).write_text(_spec.system_packages_txt)


-def split_gs_path(gs_path):
+def split_path(path, prefix="gs://"):
     # Remove the 'gs://' prefix
-    path = gs_path.replace("gs://", "")
+    path = path.replace(prefix, "")

     # Split on the first slash
     parts = path.split("/", 1)

     bucket_name = parts[0]
-    prefix = parts[1] if len(parts) > 1 else ""
+    path = parts[1] if len(parts) > 1 else ""

-    return bucket_name, prefix
+    return bucket_name, path


-def list_bucket_files(bucket_name, data_dir, is_trusted=False):
-    # TODO(varun): provide support for aws s3
-
+def list_gcs_bucket_files(
+    bucket_name,
+    data_dir,
+    is_trusted=False,
+):
     if is_trusted:
         storage_client = storage.Client.from_service_account_json(
             data_dir / "service_account.json"
         )
     else:
         storage_client = storage.Client()
-    bucket_name, prefix = split_gs_path(bucket_name)
+    bucket_name, prefix = split_path(bucket_name)
     blobs = storage_client.list_blobs(bucket_name, prefix=prefix)

     all_objects = []
@@ -123,9 +127,52 @@ def list_bucket_files(bucket_name, data_dir, is_trusted=False):
     return all_objects


+def parse_s3_service_account_file(file_path):
+    # open the json file
+    with open(file_path, "r") as f:
+        data = json.load(f)
+
+    # validate the data
+    if "aws_access_key_id" not in data or "aws_secret_access_key" not in data:
+        raise ValueError("Invalid AWS credentials file")
+
+    # parse the data
+    aws_access_key_id = data["aws_access_key_id"]
+    aws_secret_access_key = data["aws_secret_access_key"]
+
+    return aws_access_key_id, aws_secret_access_key
+
+
+def list_s3_bucket_files(bucket_name, data_dir, is_trusted=False):
+    if is_trusted:
+        AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY = parse_s3_service_account_file(
+            data_dir / "service_account.json"
+        )
+        session = boto3.Session(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY)
+        s3 = session.resource("s3")
+    else:
+        s3 = boto3.client("s3")
+
+    bucket_name, _ = split_path(bucket_name, prefix="s3://")
+    bucket = s3.Bucket(bucket_name)
+
+    all_objects = []
+    for blob in bucket.objects.all():
+        all_objects.append(blob.key)
+
+    return all_objects
+
+
 def list_files(repo_id, data_dir, revision=None):
-    if repo_id.startswith(("s3://", "gs://")):
-        return list_bucket_files(repo_id, data_dir, is_trusted=True)
+    credentials_file = data_dir / "service_account.json"
+    if repo_id.startswith("gs://"):
+        return list_gcs_bucket_files(
+            repo_id, data_dir, is_trusted=credentials_file.exists()
+        )
+    elif repo_id.startswith("s3://"):
+        return list_s3_bucket_files(
+            repo_id, data_dir, is_trusted=credentials_file.exists()
+        )
     else:
         # we assume it's a HF bucket
         return list_repo_files(repo_id, revision=revision)
@@ -201,10 +248,16 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):


 def fetch_files_to_cache(cached_files: list, repo_id: str, filtered_repo_files: list):
-    if "gs://" in repo_id:
-        bucket_name, _ = split_gs_path(repo_id)
+    if repo_id.startswith("gs://"):
+        bucket_name, _ = split_path(repo_id)
         repo_id = f"gs://{bucket_name}"

+        for filename in filtered_repo_files:
+            cached_files.append(f"/app/hf_cache/{bucket_name}/{filename}")
+    elif repo_id.startswith("s3://"):
+        bucket_name, _ = split_path(repo_id, prefix="s3://")
+        repo_id = f"s3://{bucket_name}"
+
         for filename in filtered_repo_files:
             cached_files.append(f"/app/hf_cache/{bucket_name}/{filename}")
     else:
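
As a quick illustration of the renamed split_path helper and the prefix-based dispatch in list_files, here is a small sketch; the repo_id values are hypothetical examples, not taken from this commit.

def split_path(path, prefix="gs://"):
    # Same logic as the helper above: strip the scheme, then split the
    # bucket name from the object prefix on the first slash.
    path = path.replace(prefix, "")
    parts = path.split("/", 1)
    bucket_name = parts[0]
    path = parts[1] if len(parts) > 1 else ""
    return bucket_name, path


# Hypothetical repo_id values:
print(split_path("gs://example-bucket/llama-7b"))                  # ('example-bucket', 'llama-7b')
print(split_path("s3://example-bucket/llama-7b", prefix="s3://"))  # ('example-bucket', 'llama-7b')

# list_files() now routes on the scheme: "gs://" -> list_gcs_bucket_files,
# "s3://" -> list_s3_bucket_files, and anything else is treated as a
# Hugging Face repo via list_repo_files().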
+1
@@ -1,3 +1,4 @@
 huggingface-hub==0.16.4
 google-cloud-storage==2.10.0
+boto3==1.28.70
 hf-transfer==0.1.3

0 commit comments
