---
title: Pre/post-process methods
description: "Deploy a model that makes use of pre-process"
---

Out of the box, Truss limits the number of concurrent predictions that happen on a
single container. This ensures that the CPU, and for many models the GPU, do not get
overloaded, and that the model can continue to respond to requests in periods of high load.

However, many models, in addition to having compute components, also have
IO requirements. For example, a model that classifies images may need to download
the image from a URL before it can classify it.

Truss provides a way to separate the IO component from the compute component, to
ensure that any IO does not prevent utilization of the compute on your pod.

To do this, you can use the pre/post-process methods on a Truss. These methods
can be defined like this:

```python
class Model:
    def __init__(self, **kwargs): ...
    def load(self, **kwargs) -> None: ...

    def preprocess(self, request):
        # Include any IO logic that happens _before_ predict here
        ...

    def predict(self, request):
        # Include the actual predict here
        ...

    def postprocess(self, response):
        # Include any IO logic that happens _after_ predict here
        ...
```

When the model is invoked, any logic defined in the pre- or post-process
methods runs on a separate thread and is not subject to the same concurrency limits as
predict. So, let's say you have a model that can handle 5 concurrent requests:

```yaml config.yaml
...
runtime:
  predict_concurrency: 5
...
```
| 48 | + |
| 49 | +If you hit it with 10 requests, they will _all_ begin pre-processing, but then when the |
| 50 | +the 6th request is ready to begin the predict method, it will have to wait for one of the |
| 51 | +first 5 requests to finish. This ensures that the GPU is not overloaded, while also ensuring |
| 52 | +that the compute logic does not get blocked by IO, thereby ensuring that you can achieve |
| 53 | +maximum throughput. |
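
To see this behavior from the client side, here is a minimal sketch that fires 10
simultaneous requests. The endpoint and image URLs below are placeholders, assuming a
model served locally at Truss's KServe-style predict route; substitute your actual
deployment URL:

```python
from concurrent.futures import ThreadPoolExecutor

import requests

# Hypothetical local endpoint; substitute your deployment's invocation URL.
ENDPOINT = "http://localhost:8080/v1/models/model:predict"


def classify(image_url: str):
    # Each request body carries an image URL for preprocess to download.
    return requests.post(ENDPOINT, json={"url": image_url}).json()


# All 10 requests begin pre-processing immediately; predict itself runs
# at most 5 at a time under predict_concurrency: 5.
urls = ["https://example.com/cat.jpg"] * 10  # placeholder image URLs

with ThreadPoolExecutor(max_workers=10) as pool:
    results = list(pool.map(classify, urls))
```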

<RequestExample>

```python model/model.py
import requests
from typing import Dict, List
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

CHECKPOINT = "openai/clip-vit-base-patch32"


class Model:
    """
    This is a simple example of using CLIP to classify images.
    It outputs the probability of the image being a cat or a dog.
    """
    def __init__(self, **kwargs) -> None:
        self._processor = None
        self._model = None

    def load(self):
        """
        Loads the CLIP model and processor checkpoints.
        """
        self._model = CLIPModel.from_pretrained(CHECKPOINT)
        self._processor = CLIPProcessor.from_pretrained(CHECKPOINT)

    def preprocess(self, request: Dict) -> Dict:
        """
        This method downloads the image from the URL and preprocesses it.
        The preprocess method is used for any logic that involves IO, in this
        case downloading the image. It is called before the predict method
        in a separate thread and is not subject to the same concurrency
        limits as the predict method, so it can be called many times in parallel.
        """
        image = Image.open(requests.get(request.pop("url"), stream=True).raw)
        request["inputs"] = self._processor(
            text=["a photo of a cat", "a photo of a dog"],
            images=image,
            return_tensors="pt",
            padding=True
        )
        return request

    def predict(self, request: Dict) -> List:
        """
        This performs the actual classification. The predict method is subject to
        the predict concurrency constraints.
        """
        outputs = self._model(**request["inputs"])
        logits_per_image = outputs.logits_per_image
        return logits_per_image.softmax(dim=1).tolist()
```

```yaml config.yaml
model_name: clip-example
requirements:
- transformers==4.32.0
- pillow==10.0.0
- torch==2.0.1
resources:
  cpu: "3"
  memory: 14Gi
  use_gpu: true
  accelerator: A10G
```

</RequestExample>
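
Once deployed, the model can be invoked with a JSON body containing the image URL. A
minimal sketch, assuming the same placeholder endpoint as above:

```python
import requests

# Hypothetical local endpoint; substitute your deployment's invocation URL.
resp = requests.post(
    "http://localhost:8080/v1/models/model:predict",
    json={"url": "https://example.com/cat.jpg"},  # placeholder image URL
)
print(resp.json())  # e.g. [[0.92, 0.08]] -> [cat, dog] probabilities
```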