Skip to content

Commit 79d171f

Browse files
authored
Add s3 session token support (#718)
* Add s3 session token support * Add more documentation. Re-order cache Dockerfile for better caching * Update pyproject.toml
1 parent 1dc9be5 commit 79d171f

File tree

5 files changed

+47
-24
lines changed

5 files changed

+47
-24
lines changed

docs/guides/model-cache.mdx

+27
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ model_cache:
135135
If you are accessing a public GCS bucket, you can ignore the subsequent steps, but make sure you set an appropriate appropriate policy on your bucket. Users should be able to list and view all files. Otherwise, the model build will fail.
136136

137137
However, for a private S3 bucket, you need to first find your `aws_access_key_id`, `aws_secret_access_key`, and `aws_region` in your AWS dashboard. Create a file named `s3_credentials.json`. Inside this file, add the credentials that you identified earlier as shown below. Place this file into the `data` directory of your Truss.
138+
The key `aws_session_token` can be included, but is optional.
138139

139140
Here is an example of how your `s3_credentials.json` file should look:
140141

@@ -156,6 +157,32 @@ your-truss
156157
|. └── s3_credentials.json
157158
```
158159

160+
When you are generating credentials, make sure that the resulting keys have at minimum the following IAM policy:
161+
162+
```json
163+
{
164+
"Version": "2012-10-17",
165+
"Statement": [
166+
{
167+
"Action": [
168+
"s3:GetObject",
169+
"s3:ListObjects",
170+
],
171+
"Effect": "Allow",
172+
"Resource": ["arn:aws:s3:::S3_BUCKET/PATH_TO_MODEL/*"]
173+
},
174+
{
175+
"Action": [
176+
"s3:ListBucket",
177+
],
178+
"Effect": "Allow",
179+
"Resource": ["arn:aws:s3:::S3_BUCKET"]
180+
}
181+
]
182+
}
183+
```
184+
185+
159186
<Warning>
160187
If you are using version control, like git, for your Truss, make sure to add `s3_credentials.json` to your `.gitignore` file. You don't want to accidentally expose your service account key.
161188
</Warning>

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.7.15rc1"
3+
version = "0.7.15rc2"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"

truss/contexts/image_builder/cache_warmer.py

+3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class AWSCredentials:
4747
access_key_id: str
4848
secret_access_key: str
4949
region: str
50+
session_token: Optional[str]
5051

5152

5253
def parse_s3_credentials_file(key_file_path: str) -> AWSCredentials:
@@ -67,6 +68,7 @@ def parse_s3_credentials_file(key_file_path: str) -> AWSCredentials:
6768
access_key_id=data["aws_access_key_id"],
6869
secret_access_key=data["aws_secret_access_key"],
6970
region=data["aws_region"],
71+
session_token=data.get("aws_session_token", None),
7072
)
7173

7274
return aws_sa
@@ -179,6 +181,7 @@ def download_to_cache(self):
179181
aws_access_key_id=s3_credentials.access_key_id,
180182
aws_secret_access_key=s3_credentials.secret_access_key,
181183
region_name=s3_credentials.region,
184+
aws_session_token=s3_credentials.session_token,
182185
config=Config(signature_version="s3v4"),
183186
)
184187
else:

truss/contexts/image_builder/serving_image_builder.py

+11-19
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import json
43
from abc import ABC, abstractmethod
54
from dataclasses import dataclass
65
from pathlib import Path
@@ -27,6 +26,10 @@
2726
TEMPLATES_DIR,
2827
TRITON_SERVER_CODE_DIR,
2928
)
29+
from truss.contexts.image_builder.cache_warmer import (
30+
AWSCredentials,
31+
parse_s3_credentials_file,
32+
)
3033
from truss.contexts.image_builder.image_builder import ImageBuilder
3134
from truss.contexts.image_builder.util import (
3235
TRUSS_BASE_IMAGE_VERSION_TAG,
@@ -136,10 +139,15 @@ def list_files(self, revision=None):
136139
s3_credentials_file = self.data_dir / S3_CREDENTIALS
137140

138141
if s3_credentials_file.exists():
139-
AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY = parse_s3_service_account_file(
142+
s3_credentials: AWSCredentials = parse_s3_credentials_file(
140143
self.data_dir / S3_CREDENTIALS
141144
)
142-
session = boto3.Session(AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY)
145+
session = boto3.Session(
146+
aws_access_key_id=s3_credentials.access_key_id,
147+
aws_secret_access_key=s3_credentials.secret_access_key,
148+
aws_session_token=s3_credentials.session_token,
149+
region_name=s3_credentials.region,
150+
)
143151
s3 = session.resource("s3")
144152
else:
145153
s3 = boto3.resource("s3", config=Config(signature_version=UNSIGNED))
@@ -274,22 +282,6 @@ class CachedFile:
274282
dst: str
275283

276284

277-
def parse_s3_service_account_file(file_path):
278-
# open the json file
279-
with open(file_path, "r") as f:
280-
data = json.load(f)
281-
282-
# validate the data
283-
if "aws_access_key_id" not in data or "aws_secret_access_key" not in data:
284-
raise ValueError("Invalid AWS credentials file")
285-
286-
# parse the data
287-
aws_access_key_id = data["aws_access_key_id"]
288-
aws_secret_access_key = data["aws_secret_access_key"]
289-
290-
return aws_access_key_id, aws_secret_access_key
291-
292-
293285
def update_model_key(config: TrussConfig) -> str:
294286
server_name = config.build.model_server
295287

truss/templates/cache.Dockerfile.jinja

+5-4
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@ WORKDIR /app
77
ENV HUGGING_FACE_HUB_TOKEN {{hf_access_token}}
88
{% endif %}
99

10-
{% for credential in credentials_to_cache %}
11-
COPY ./{{credential}} /app/{{credential}}
12-
{% endfor %}
13-
1410
RUN apt-get -y update; apt-get -y install curl; curl -s https://baseten-public.s3.us-west-2.amazonaws.com/bin/b10cp-5fe8dc7da-linux-amd64 -o /app/b10cp; chmod +x /app/b10cp
1511
ENV B10CP_PATH_TRUSS /app/b10cp
1612
COPY ./cache_requirements.txt /app/cache_requirements.txt
1713
RUN pip install -r /app/cache_requirements.txt --no-cache-dir && rm -rf /root/.cache/pip
1814
COPY ./cache_warmer.py /cache_warmer.py
15+
16+
{% for credential in credentials_to_cache %}
17+
COPY ./{{credential}} /app/{{credential}}
18+
{% endfor %}
19+
1920
{% for repo, hf_dir in models.items() %}
2021
{% for file in hf_dir.files %}
2122
{{ "RUN --mount=type=secret,id=" + hf_access_token_file_name + ",dst=/etc/secrets/" + hf_access_token_file_name if use_hf_secret else "RUN" }} python3 /cache_warmer.py {{file}} {{repo}} {% if hf_dir.revision != None %}{{hf_dir.revision}}{% endif %}

0 commit comments

Comments
 (0)