|
From d8a885d3d35c4512f665357ab3c25a54dc5731ca Mon Sep 17 00:00:00 2001
From: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Date: Tue, 3 Dec 2024 10:51:04 -0800
Subject: [PATCH] feat: support non-cuda devices for text models

This commit adds support for non-CUDA PyTorch backend devices
to text models. It extends the existing test to run on an
externally specified device (cuda is the default). The commit was
verified with the Llama3.2-3B-Instruct model on:
* "cuda" device type on an NVIDIA A10 GPU
* "cpu" device type
* "xpu" device type on an Intel Data Center GPU Max Series (PVC)

Co-authored-by: anordin95 <alexander.f.nordin@gmail.com>
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
---
 models/llama3/reference_impl/generation.py | 43 +++++++++++++++++-----
 models/llama3/reference_impl/model.py      |  4 +-
 models/llama3/tests/api/test_generation.py |  1 +
 3 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/models/llama3/reference_impl/generation.py b/models/llama3/reference_impl/generation.py
index a1b7b69..2882afb 100644
--- a/models/llama3/reference_impl/generation.py
+++ b/models/llama3/reference_impl/generation.py
@@ -74,6 +74,7 @@ class Llama:
         model_parallel_size: Optional[int] = None,
         tokenizer_path: Optional[str] = None,
         seed: int = 1,
+        device: str = "cuda"
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.
@@ -85,6 +86,7 @@ class Llama:
             max_batch_size (int): Maximum batch size for inference.
             model_parallel_size (Optional[int], optional): Number of model parallel processes.
                 If not provided, it's determined from the environment. Defaults to None.
+            device (str, optional): Device to use, e.g. cuda (default), xpu, cpu, etc.

         Returns:
             Llama: An instance of the Llama class with the loaded model and tokenizer.
@@ -92,6 +94,7 @@ class Llama:
         Raises:
             AssertionError: If there are no checkpoint files in the specified directory,
                 or if the model parallel size does not match the number of checkpoint files.
+            RuntimeError: If PyTorch backend for the specified device is not available.


         Note:
@@ -99,8 +102,16 @@ class Llama:
             and loads the pre-trained model and tokenizer.
         """

+        device = torch.device(device)
+        if (device.type == "cuda" and not torch.cuda.is_available() or
+                device.type == "xpu" and not torch.xpu.is_available()):
+            raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
+
         if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group("nccl")
+            if device.type == "cuda":
+                torch.distributed.init_process_group("nccl")
+            else:
+                torch.distributed.init_process_group("gloo")

         if not model_parallel_is_initialized():
             if model_parallel_size is None:
@@ -108,7 +119,10 @@ class Llama:
             initialize_model_parallel(model_parallel_size)

         local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        torch.cuda.set_device(local_rank)
+        if device.type == "cuda":
+            torch.cuda.set_device(local_rank)
+        elif device.type == "xpu":
+            torch.xpu.set_device(local_rank)

         torch.manual_seed(seed)

@@ -138,10 +152,20 @@ class Llama:
         tokenizer = Tokenizer.get_instance()

         assert model_args.vocab_size == tokenizer.n_words
-        if torch.cuda.is_bf16_supported():
-            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+        torch.set_default_device(device)
+        if device.type == "cuda":
+            if torch.cuda.is_bf16_supported():
+                torch.set_default_dtype(torch.bfloat16)
+            else:
+                torch.set_default_dtype(torch.half)
+        elif device.type == "xpu":
+            if torch.xpu.is_bf16_supported():
+                torch.set_default_dtype(torch.bfloat16)
+            else:
+                torch.set_default_dtype(torch.half)
         else:
-            torch.set_default_tensor_type(torch.cuda.HalfTensor)
+            torch.set_default_dtype(torch.half)
+
         if model_args.vision_chunk_size > 0:
             from .multimodal.model import CrossAttentionTransformer

@@ -150,6 +174,7 @@ class Llama:
         else:
             model = Transformer(model_args)
         model.load_state_dict(checkpoint, strict=True)
+        model.to(device)
         print(f"Loaded in {time.time() - start_time:.2f} seconds")

         return Llama(model, tokenizer, model_args)
@@ -213,14 +238,14 @@ class Llama:
             )

         pad_id = self.tokenizer.pad_id
-        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
         for k, t in enumerate(prompt_tokens):
-            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
         if logprobs:
             token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

         prev_pos = 0
-        eos_reached = torch.tensor([False] * bsz, device="cuda")
+        eos_reached = torch.tensor([False] * bsz)
         input_text_mask = tokens != pad_id

         if echo:
@@ -237,7 +262,7 @@ class Llama:
         for cur_pos in range(min_prompt_len, total_len):
             if is_vision:
                 position_ids = torch.arange(
-                    prev_pos, cur_pos, dtype=torch.long, device="cuda"
+                    prev_pos, cur_pos, dtype=torch.long
                 )
                 text_only_inference = model_input.vision is None
                 logits = self.model.forward(
diff --git a/models/llama3/reference_impl/model.py b/models/llama3/reference_impl/model.py
index 099a1ed..0713544 100644
--- a/models/llama3/reference_impl/model.py
+++ b/models/llama3/reference_impl/model.py
@@ -158,7 +158,7 @@ class Attention(nn.Module):
                 self.n_local_kv_heads,
                 self.head_dim,
             )
-        ).cuda()
+        )
         self.cache_v = torch.zeros(
             (
                 args.max_batch_size,
@@ -166,7 +166,7 @@ class Attention(nn.Module):
                 self.n_local_kv_heads,
                 self.head_dim,
             )
-        ).cuda()
+        )

     def forward(
         self,
diff --git a/models/llama3/tests/api/test_generation.py b/models/llama3/tests/api/test_generation.py
index a71738b..259e3ae 100644
--- a/models/llama3/tests/api/test_generation.py
+++ b/models/llama3/tests/api/test_generation.py
@@ -33,6 +33,7 @@ def build_generator(env_var: str):
         max_seq_len=128,
         max_batch_size=1,
         model_parallel_size=1,
+        device=os.getenv("DEVICE", "cuda")
     )


-- 
2.34.1
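With the patch applied, the reference generator can be built on a non-CUDA device by passing the new device argument. A minimal sketch follows; the import path and checkpoint directory are illustrative assumptions rather than part of the patch, and the script still has to be started through a torch.distributed launcher (for example torchrun) so that init_process_group can complete:

    # Sketch only: build the text generator on an externally selected device.
    # The module path and ckpt_dir below are placeholders for your checkout.
    import os

    from models.llama3.reference_impl.generation import Llama

    generator = Llama.build(
        ckpt_dir="/path/to/Llama3.2-3B-Instruct",  # placeholder checkpoint dir
        max_seq_len=128,
        max_batch_size=1,
        model_parallel_size=1,
        device=os.getenv("DEVICE", "cuda"),        # "cuda" (default), "xpu", or "cpu"
    )

The extended test reads the same DEVICE environment variable, so running it against another backend only requires exporting, for example, DEVICE=xpu or DEVICE=cpu before invoking it.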
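The device handling added to Llama.build follows a single pattern: verify that the PyTorch backend for the requested device type is available, pick a distributed backend (nccl for CUDA, gloo otherwise), and pick the widest floating-point dtype the device supports. The same logic, pulled into a hypothetical helper purely for illustration (the helper itself is not part of the patch):

    # Illustration of the patch's selection logic; not part of the patch itself.
    # nccl is used for CUDA, gloo (a CPU-capable backend) for everything else;
    # bfloat16 is preferred where the device reports support, float16 otherwise.
    import torch

    def pick_backend_and_dtype(device: torch.device) -> tuple:
        if device.type == "cuda":
            if not torch.cuda.is_available():
                raise RuntimeError("PyTorch backend for cuda device type is not available")
            backend = "nccl"
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.half
        elif device.type == "xpu":
            if not torch.xpu.is_available():
                raise RuntimeError("PyTorch backend for xpu device type is not available")
            backend = "gloo"
            dtype = torch.bfloat16 if torch.xpu.is_bf16_supported() else torch.half
        else:
            backend = "gloo"
            dtype = torch.half
        return backend, dtype

Falling back to gloo keeps process-group initialization working on hosts without NVIDIA GPUs; swapping in a vendor-specific collective backend would be a separate change.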