1
1
from pathlib import Path
2
- from typing import Any , Dict , Optional
2
+ from typing import Any , Dict , List , Optional
3
3
4
4
import yaml
5
5
from google .cloud import storage
6
- from huggingface_hub import list_repo_files
6
+ from huggingface_hub import get_hf_file_metadata , hf_hub_url , list_repo_files
7
7
from huggingface_hub .utils import filter_repo_objects
8
8
from truss .constants import (
9
9
BASE_SERVER_REQUIREMENTS_TXT_FILENAME ,
29
29
)
30
30
from truss .contexts .truss_context import TrussContext
31
31
from truss .patch .hash import directory_content_hash
32
- from truss .truss_config import (
33
- Build ,
34
- HuggingFaceCache ,
35
- HuggingFaceModel ,
36
- ModelServer ,
37
- TrussConfig ,
38
- )
32
+ from truss .truss_config import Build , HuggingFaceModel , ModelServer , TrussConfig
39
33
from truss .truss_spec import TrussSpec
40
34
from truss .util .download import download_external_data
41
35
from truss .util .jinja import read_template_from_fs
@@ -133,33 +127,43 @@ def list_files(repo_id, data_dir, revision=None):
133
127
return list_repo_files (repo_id , revision = revision )
134
128
135
129
136
- def update_config_and_gather_files (
137
- config : TrussConfig , truss_dir : Path , build_dir : Path , server_name : str
138
- ):
130
+ def update_model_key (config : TrussConfig ) -> str :
131
+ server_name = config .build .model_server
132
+
133
+ if server_name == ModelServer .TGI :
134
+ return "model_id"
135
+ elif server_name == ModelServer .VLLM :
136
+ return "model"
137
+
138
+ raise ValueError (
139
+ f"Invalid server name (must be `TGI` or `VLLM`, not `{ server_name } `)."
140
+ )
141
+
142
+
143
+ def update_model_name (config : TrussConfig , model_key : str ) -> str :
144
+ if model_key not in config .build .arguments :
145
+ # We should definitely just use the same key across both vLLM and TGI
146
+ raise KeyError (
147
+ "Key for model missing in config or incorrect key used. Use `model` for VLLM and `model_id` for TGI."
148
+ )
149
+ model_name = config .build .arguments [model_key ]
150
+ if "gs://" in model_name :
151
+ # if we are pulling from a gs bucket, we want to alias it as a part of the cache
152
+ model_to_cache = HuggingFaceModel (model_name )
153
+ config .hf_cache .models .append (model_to_cache )
154
+
155
+ config .build .arguments [
156
+ model_key
157
+ ] = f"/app/hf_cache/{ model_name .replace ('gs://' , '' )} "
158
+ return model_name
159
+
160
+
161
+ def get_files_to_cache (config : TrussConfig , truss_dir : Path , build_dir : Path ):
139
162
def copy_into_build_dir (from_path : Path , path_in_build_dir : str ):
140
163
copy_tree_or_file (from_path , build_dir / path_in_build_dir ) # type: ignore[operator]
141
164
142
- if server_name == "TGI" :
143
- model_key = "model_id"
144
- elif server_name == "vLLM" :
145
- model_key = "model"
146
-
147
- if server_name != "TrussServer" :
148
- model_name = config .build .arguments [model_key ]
149
- if "gs://" in model_name :
150
- # if we are pulling from a gs bucket, we want to alias it as a part of the cache
151
- model_to_cache = {"repo_id" : model_name }
152
- if config .hf_cache :
153
- config .hf_cache .models .append (
154
- HuggingFaceModel .from_dict (model_to_cache )
155
- )
156
- else :
157
- config .hf_cache = HuggingFaceCache .from_list ([model_to_cache ])
158
- config .build .arguments [
159
- model_key
160
- ] = f"/app/hf_cache/{ model_name .replace ('gs://' , '' )} "
161
-
162
165
model_files = {}
166
+ cached_files : List [str ] = []
163
167
if config .hf_cache :
164
168
curr_dir = Path (__file__ ).parent .resolve ()
165
169
copy_into_build_dir (curr_dir / "cache_warmer.py" , "cache_warmer.py" )
@@ -180,18 +184,52 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
180
184
)
181
185
)
182
186
183
- if "gs://" in repo_id :
184
- repo_id , _ = split_gs_path (repo_id )
185
- repo_id = f"gs://{ repo_id } "
187
+ cached_files = fetch_files_to_cache (
188
+ cached_files , repo_id , filtered_repo_files
189
+ )
190
+
191
+ model_files [repo_id ] = {"files" : filtered_repo_files , "revision" : revision }
186
192
187
- model_files [repo_id ] = {
188
- "files" : filtered_repo_files ,
189
- "revision" : revision ,
190
- }
191
193
copy_into_build_dir (
192
194
TEMPLATES_DIR / "cache_requirements.txt" , "cache_requirements.txt"
193
195
)
194
- return model_files
196
+ return model_files , cached_files
197
+
198
+
199
+ def fetch_files_to_cache (cached_files : list , repo_id : str , filtered_repo_files : list ):
200
+ if "gs://" in repo_id :
201
+ bucket_name , _ = split_gs_path (repo_id )
202
+ repo_id = f"gs://{ bucket_name } "
203
+
204
+ for filename in filtered_repo_files :
205
+ cached_files .append (f"/app/hf_cache/{ bucket_name } /{ filename } " )
206
+ else :
207
+ repo_folder_name = f"models--{ repo_id .replace ('/' , '--' )} "
208
+ for filename in filtered_repo_files :
209
+ hf_url = hf_hub_url (repo_id , filename )
210
+ hf_file_metadata = get_hf_file_metadata (hf_url )
211
+
212
+ cached_files .append (f"{ repo_folder_name } /blobs/{ hf_file_metadata .etag } " )
213
+
214
+ # snapshots is just a set of folders with symlinks -- we can copy the entire thing separately
215
+ cached_files .append (f"{ repo_folder_name } /snapshots/" )
216
+
217
+ # refs just has files with revision commit hashes
218
+ cached_files .append (f"{ repo_folder_name } /refs/" )
219
+
220
+ cached_files .append ("version.txt" )
221
+
222
+ return cached_files
223
+
224
+
225
+ def update_config_and_gather_files (
226
+ config : TrussConfig , truss_dir : Path , build_dir : Path
227
+ ):
228
+ if config .build .model_server != ModelServer .TrussServer :
229
+ model_key = update_model_key (config )
230
+ update_model_name (config , model_key )
231
+
232
+ return get_files_to_cache (config , truss_dir , build_dir )
195
233
196
234
197
235
def create_tgi_build_dir (config : TrussConfig , build_dir : Path , truss_dir : Path ):
@@ -200,19 +238,24 @@ def create_tgi_build_dir(config: TrussConfig, build_dir: Path, truss_dir: Path):
200
238
if not build_dir .exists ():
201
239
build_dir .mkdir (parents = True )
202
240
203
- model_files = update_config_and_gather_files (config , truss_dir , build_dir , "TGI" )
241
+ model_files , cached_file_paths = update_config_and_gather_files (
242
+ config , truss_dir , build_dir
243
+ )
204
244
205
245
hf_access_token = config .secrets .get (HF_ACCESS_TOKEN_SECRET_NAME )
206
246
dockerfile_template = read_template_from_fs (
207
247
TEMPLATES_DIR , "tgi/tgi.Dockerfile.jinja"
208
248
)
209
249
210
250
data_dir = build_dir / "data"
251
+ credentials_file = data_dir / "service_account.json"
211
252
dockerfile_content = dockerfile_template .render (
212
253
hf_access_token = hf_access_token ,
213
254
models = model_files ,
214
255
hf_cache = config .hf_cache ,
215
- data_dir_exists = Path (data_dir ).exists (),
256
+ data_dir_exists = data_dir .exists (),
257
+ credentials_exists = credentials_file .exists (),
258
+ cached_files = cached_file_paths ,
216
259
)
217
260
dockerfile_filepath = build_dir / "Dockerfile"
218
261
dockerfile_filepath .write_text (dockerfile_content )
@@ -247,7 +290,9 @@ def create_vllm_build_dir(config: TrussConfig, build_dir: Path, truss_dir: Path)
247
290
build_config : Build = config .build
248
291
server_endpoint = server_endpoint_config [build_config .arguments .pop ("endpoint" )]
249
292
250
- model_files = update_config_and_gather_files (config , truss_dir , build_dir , "vLLM" )
293
+ model_files , cached_file_paths = update_config_and_gather_files (
294
+ config , truss_dir , build_dir
295
+ )
251
296
252
297
hf_access_token = config .secrets .get (HF_ACCESS_TOKEN_SECRET_NAME )
253
298
dockerfile_template = read_template_from_fs (
@@ -256,12 +301,15 @@ def create_vllm_build_dir(config: TrussConfig, build_dir: Path, truss_dir: Path)
256
301
nginx_template = read_template_from_fs (TEMPLATES_DIR , "vllm/proxy.conf.jinja" )
257
302
258
303
data_dir = build_dir / "data"
304
+ credentials_file = data_dir / "service_account.json"
259
305
dockerfile_content = dockerfile_template .render (
260
306
hf_access_token = hf_access_token ,
261
307
models = model_files ,
262
308
should_install_server_requirements = True ,
263
309
hf_cache = config .hf_cache ,
264
310
data_dir_exists = data_dir .exists (),
311
+ credentials_exists = credentials_file .exists (),
312
+ cached_files = cached_file_paths ,
265
313
)
266
314
dockerfile_filepath = build_dir / "Dockerfile"
267
315
dockerfile_filepath .write_text (dockerfile_content )
@@ -336,8 +384,8 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
336
384
download_external_data (self ._spec .external_data , data_dir )
337
385
338
386
# Download from HuggingFace
339
- model_files = update_config_and_gather_files (
340
- config , truss_dir , build_dir , server_name = "TrussServer"
387
+ model_files , cached_files = update_config_and_gather_files (
388
+ config , truss_dir , build_dir
341
389
)
342
390
343
391
# Copy inference server code
@@ -391,7 +439,11 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
391
439
(build_dir / SYSTEM_PACKAGES_TXT_FILENAME ).write_text (spec .system_packages_txt )
392
440
393
441
self ._render_dockerfile (
394
- build_dir , should_install_server_requirements , model_files , use_hf_secret
442
+ build_dir ,
443
+ should_install_server_requirements ,
444
+ model_files ,
445
+ use_hf_secret ,
446
+ cached_files ,
395
447
)
396
448
397
449
def _render_dockerfile (
@@ -400,10 +452,12 @@ def _render_dockerfile(
400
452
should_install_server_requirements : bool ,
401
453
model_files : Dict [str , Any ],
402
454
use_hf_secret : bool ,
455
+ cached_files : List [str ],
403
456
):
404
457
config = self ._spec .config
405
458
data_dir = build_dir / config .data_dir
406
459
bundled_packages_dir = build_dir / config .bundled_packages_dir
460
+ credentials_file = data_dir / "service_account.json"
407
461
dockerfile_template = read_template_from_fs (
408
462
TEMPLATES_DIR , SERVER_DOCKERFILE_TEMPLATE_NAME
409
463
)
@@ -437,6 +491,9 @@ def _render_dockerfile(
437
491
truss_hash = directory_content_hash (self ._truss_dir ),
438
492
models = model_files ,
439
493
use_hf_secret = use_hf_secret ,
494
+ cached_files = cached_files ,
495
+ credentials_exists = credentials_file .exists (),
496
+ hf_cache = len (config .hf_cache .models ) > 0 ,
440
497
)
441
498
docker_file_path = build_dir / MODEL_DOCKERFILE_NAME
442
499
docker_file_path .write_text (dockerfile_contents )
0 commit comments