2
2
from typing import Any , Dict , Optional
3
3
4
4
import yaml
5
+ from google .cloud import storage
5
6
from huggingface_hub import list_repo_files
6
7
from huggingface_hub .utils import filter_repo_objects
7
8
from truss .constants import (
27
28
)
28
29
from truss .contexts .truss_context import TrussContext
29
30
from truss .patch .hash import directory_content_hash
30
- from truss .truss_config import Build , ModelServer , TrussConfig
31
+ from truss .truss_config import (
32
+ Build ,
33
+ HuggingFaceCache ,
34
+ HuggingFaceModel ,
35
+ ModelServer ,
36
+ TrussConfig ,
37
+ )
31
38
from truss .truss_spec import TrussSpec
32
39
from truss .util .download import download_external_data
33
40
from truss .util .jinja import read_template_from_fs
@@ -75,7 +82,12 @@ def create_tgi_build_dir(config: TrussConfig, build_dir: Path):
75
82
supervisord_filepath .write_text (supervisord_contents )
76
83
77
84
78
- def create_vllm_build_dir (config : TrussConfig , build_dir : Path ):
85
+ def create_vllm_build_dir (config : TrussConfig , build_dir : Path , truss_dir : Path ):
86
+ def copy_into_build_dir (from_path : Path , path_in_build_dir : str ):
87
+ copy_tree_or_file (from_path , build_dir / path_in_build_dir ) # type: ignore[operator]
88
+
89
+ copy_tree_path (truss_dir , build_dir )
90
+
79
91
server_endpoint_config = {
80
92
"Completions" : "/v1/completions" ,
81
93
"ChatCompletions" : "/v1/chat/completions" ,
@@ -85,13 +97,58 @@ def create_vllm_build_dir(config: TrussConfig, build_dir: Path):
85
97
86
98
build_config : Build = config .build
87
99
server_endpoint = server_endpoint_config [build_config .arguments .pop ("endpoint" )]
100
+
101
+ model_name = build_config .arguments .pop ("model" )
102
+ if "gs://" in model_name :
103
+ # if we are pulling from a gs bucket, we want to alias it as a part of the cache
104
+ model_to_cache = {"repo_id" : model_name }
105
+ if config .hf_cache :
106
+ config .hf_cache .models .append (HuggingFaceModel .from_dict (model_to_cache ))
107
+ else :
108
+ config .hf_cache = HuggingFaceCache .from_list ([model_to_cache ])
109
+ build_config .arguments [
110
+ "model"
111
+ ] = f"/app/hf_cache/{ model_name .replace ('gs://' , '' )} "
112
+
88
113
hf_access_token = config .secrets .get (HF_ACCESS_TOKEN_SECRET_NAME )
89
114
dockerfile_template = read_template_from_fs (
90
115
TEMPLATES_DIR , "vllm/vllm.Dockerfile.jinja"
91
116
)
92
117
nginx_template = read_template_from_fs (TEMPLATES_DIR , "vllm/proxy.conf.jinja" )
118
+ copy_into_build_dir (
119
+ TEMPLATES_DIR / "cache_requirements.txt" , "cache_requirements.txt"
120
+ )
93
121
94
- dockerfile_content = dockerfile_template .render (hf_access_token = hf_access_token )
122
+ model_files = {}
123
+ if config .hf_cache :
124
+ curr_dir = Path (__file__ ).parent .resolve ()
125
+ copy_into_build_dir (curr_dir / "cache_warmer.py" , "cache_warmer.py" )
126
+ for model in config .hf_cache .models :
127
+ repo_id = model .repo_id
128
+ revision = model .revision
129
+
130
+ allow_patterns = model .allow_patterns
131
+ ignore_patterns = model .ignore_patterns
132
+
133
+ filtered_repo_files = list (
134
+ filter_repo_objects (
135
+ items = list_files (
136
+ repo_id , truss_dir / config .data_dir , revision = revision
137
+ ),
138
+ allow_patterns = allow_patterns ,
139
+ ignore_patterns = ignore_patterns ,
140
+ )
141
+ )
142
+ model_files [repo_id ] = {
143
+ "files" : filtered_repo_files ,
144
+ "revision" : revision ,
145
+ }
146
+
147
+ dockerfile_content = dockerfile_template .render (
148
+ hf_access_token = hf_access_token ,
149
+ models = model_files ,
150
+ should_install_server_requirements = True ,
151
+ )
95
152
dockerfile_filepath = build_dir / "Dockerfile"
96
153
dockerfile_filepath .write_text (dockerfile_content )
97
154
@@ -110,6 +167,47 @@ def create_vllm_build_dir(config: TrussConfig, build_dir: Path):
110
167
supervisord_filepath .write_text (supervisord_contents )
111
168
112
169
170
def split_gs_path(gs_path):
    """Break a ``gs://`` URL into its bucket name and object prefix.

    Every occurrence of the literal ``gs://`` is removed from ``gs_path`` and
    the remainder is divided at the first ``/``.

    Returns:
        A ``(bucket_name, prefix)`` tuple. ``prefix`` is the empty string
        when nothing follows the bucket name.
    """
    without_scheme = gs_path.replace("gs://", "")
    # partition() yields an empty tail when there is no "/", which matches
    # the "no prefix" case without a separate length check.
    bucket_name, _slash, prefix = without_scheme.partition("/")
    return bucket_name, prefix
181
+
182
+
183
def list_bucket_files(bucket_name: str, data_dir: Path, is_trusted: bool = False):
    """List the file names of the objects found under a GCS bucket URL.

    Args:
        bucket_name: Bucket URL such as ``gs://bucket/some/prefix``. Despite
            the parameter name this is the full URL; it is split into a
            bucket and a prefix with ``split_gs_path``.
        data_dir: Directory that holds ``service_account.json``; only read
            when ``is_trusted`` is true.
        is_trusted: When true, build the client from the service-account key
            file in ``data_dir``. Otherwise ``storage.Client()`` is
            constructed with no explicit credentials.

    Returns:
        A list containing the final path component of every listed blob
        name (any directory part of the object name is discarded).

    Raises:
        NotImplementedError: If ``bucket_name`` is an ``s3://`` URL. Only
            Google Cloud Storage is handled here.
    """
    # TODO(varun): provide support for aws s3
    if bucket_name.startswith("s3://"):
        # Fail early with a clear message instead of sending a mangled
        # bucket name ("s3:") to the GCS client.
        raise NotImplementedError(
            f"Listing files from S3 is not supported yet: {bucket_name}"
        )

    if is_trusted:
        storage_client = storage.Client.from_service_account_json(
            data_dir / "service_account.json"
        )
    else:
        storage_client = storage.Client()

    bucket, prefix = split_gs_path(bucket_name)
    blobs = storage_client.list_blobs(bucket, prefix=prefix)
    return [Path(blob.name).name for blob in blobs]
201
+
202
+
203
def list_files(repo_id: str, data_dir: Path, revision: Optional[str] = None):
    """Return the file listing for a model source.

    Args:
        repo_id: Either a bucket URL starting with ``s3://`` or ``gs://``,
            or an identifier passed to ``list_repo_files``.
        data_dir: Forwarded to ``list_bucket_files`` for bucket URLs; unused
            otherwise.
        revision: Forwarded to ``list_repo_files``; ignored for bucket URLs.

    Returns:
        The result of ``list_bucket_files`` for bucket URLs, otherwise the
        result of ``list_repo_files``.
    """
    if repo_id.startswith(("s3://", "gs://")):
        return list_bucket_files(repo_id, data_dir, is_trusted=True)
    # Anything that is not a bucket URL is assumed to be a Hugging Face repo.
    # The listing must be returned: callers feed it to filter_repo_objects.
    return list_repo_files(repo_id, revision=revision)
209
+
210
+
113
211
class ServingImageBuilderContext (TrussContext ):
114
212
@staticmethod
115
213
def run (truss_dir : Path ):
@@ -143,7 +241,7 @@ def prepare_image_build_dir(
143
241
create_tgi_build_dir (config , build_dir )
144
242
return
145
243
elif config .build .model_server is ModelServer .VLLM :
146
- create_vllm_build_dir (config , build_dir )
244
+ create_vllm_build_dir (config , build_dir , truss_dir )
147
245
return
148
246
149
247
data_dir = build_dir / config .data_dir # type: ignore[operator]
@@ -175,7 +273,7 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
175
273
176
274
filtered_repo_files = list (
177
275
filter_repo_objects (
178
- items = list_repo_files (repo_id , revision = revision ),
276
+ items = list_files (repo_id , data_dir , revision = revision ),
179
277
allow_patterns = allow_patterns ,
180
278
ignore_patterns = ignore_patterns ,
181
279
)
0 commit comments