From f96edd27b3a12919be88b12835a3c290f6671047 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 17 Jan 2025 17:54:49 -0500 Subject: [PATCH 1/9] add support for flash decoding on xpu Signed-off-by: Liu, Kaixuan --- optimum/exporters/ipex/cache_utils.py | 66 +++++++++++++++++------- optimum/exporters/ipex/modeling_utils.py | 21 ++++---- 2 files changed, 58 insertions(+), 29 deletions(-) diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py index f9df2cf69a..d3555e73d0 100755 --- a/optimum/exporters/ipex/cache_utils.py +++ b/optimum/exporters/ipex/cache_utils.py @@ -5,6 +5,8 @@ from intel_extension_for_pytorch.llm.modules import PagedAttention from transformers import Cache, PretrainedConfig +from optimum.intel.utils.import_utils import is_ipex_version + class IPEXPagedCache(Cache): """ @@ -43,6 +45,10 @@ def __init__( ) -> None: super().__init__() self.batch_size = batch_size + self.device = device + self.flash_decoding = ( + is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99") + ) # Used in `generate` to keep tally of how many tokens the cache has seen self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device) default_block_size = 16 if device.type == "cpu" else 64 @@ -69,14 +75,43 @@ def __init__( key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size) value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size) elif device.type == "xpu": - key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1) - value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size) + if self.flash_decoding: + key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size) + value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size) + else: + key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1) + value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size) for i in range(config.num_hidden_layers): new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device) new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device) self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) + def reshape_and_cache( + self, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, + ): + if self.device.type == "xpu" and self.flash_decoding: + PagedAttention.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slots, + ) + else: + PagedAttention.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slots, + ) + def update_for_prefill( self, key_states: torch.Tensor, @@ -94,7 +129,7 @@ def update_for_prefill( block_table = self.free_blocks.nonzero().view(-1)[0:nb] self.block_tables[i][0:nb] = block_table self.free_blocks[block_table] = 0 - slots_range = torch.arange(input_lens[i], device=key_states.device) + slots_range = torch.arange(input_lens[i], device=self.device) block_indices = slots_range // self.block_size slot_offsets = slots_range % self.block_size all_block_indices.append(self.block_tables[i][block_indices]) @@ -104,12 +139,8 @@ def update_for_prefill( all_slot_offsets = torch.cat(all_slot_offsets) self.slots = all_block_indices * self.block_size + all_slot_offsets # Update the cache - PagedAttention.reshape_and_cache( - key_states, - value_states, - 
self.key_cache[layer_idx], - self.value_cache[layer_idx], - self.slots, + self.reshape_and_cache( + key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots ) # Update the number of seen tokens @@ -127,7 +158,7 @@ def update_for_decode( if layer_idx == 0: start_block_idx = self._seen_tokens // self.block_size slot_offset_in_block = (self._seen_tokens) % self.block_size - self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32) + self.slots = torch.zeros([batch_size], device=self.device, dtype=torch.int32) for i in range(batch_size): if slot_offset_in_block[i] == 0: # need a new block: @@ -138,12 +169,8 @@ def update_for_decode( self.free_blocks[self.block_tables[i][b_idx]] = 0 self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i] # Update the cache - PagedAttention.reshape_and_cache( - key_states, - value_states, - self.key_cache[layer_idx], - self.value_cache[layer_idx], - self.slots, + self.reshape_and_cache( + key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots ) # Update the number of seen tokens @@ -193,16 +220,15 @@ def get_max_length(self) -> Optional[int]: def reset(self): """Resets the cache values while preserving the objects""" - self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device) + self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.device) self.block_tables.fill_(-1) - self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device) + self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.device) self.max_seq_len = 0 def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" - device = self.block_tables.device origin_table = self.block_tables.clone() - updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device)) + updated_block_tables = self.block_tables.index_select(0, beam_idx.to(self.device)) mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0) num_blocks = mask.cumsum(-1)[:, -1] updated_table = torch.zeros_like(beam_idx) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 41dd5693df..cc57de14e7 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -628,10 +628,10 @@ def postprocess_attention_output(self, attn_output): return attn_output # Maybe removed after torch 2.6 released - def has_flash_attn(self, query): - if query.device.type == "cpu": + def has_flash_attn(self): + if self.module_device.type == "cpu": return is_torch_version(">", "2.4.99") - elif query.device.type == "xpu": + elif self.module_device.type == "xpu": return is_torch_version(">", "2.5.99") def attention_interface( @@ -652,20 +652,23 @@ def attention_interface( is_causal=True, ) self.use_sdpa = True - elif self.has_flash_attn(query): + elif self.has_flash_attn(): attn_output = torch.empty_like(query) seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) - query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int() - query_max_len = input_lens.max() if past_len == 0 else 1 + query_len_tensor = ( + seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0], device=query.device).int() + ) + 
max_input_lens = input_lens.max().item() + query_max_len = max_input_lens if past_len == 0 else 1 PagedAttention.flash_attn_varlen_func( attn_output, query.contiguous() if query.device.type == "xpu" else query, - key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache, - value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache, + key_cache, + value_cache, query_len_tensor, seq_len_tensor, query_max_len, - input_lens.max(), + max_input_lens, 1.0 / math.sqrt(self.head_dim), True, past_key_value.block_tables, From 850195ec7c44a7c1cafceaa008922c84579c13ad Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Tue, 21 Jan 2025 05:34:13 -0500 Subject: [PATCH 2/9] change block_size to 16 for xpu, as `single_query_cached_kv_attention` API does not support 64 Signed-off-by: Liu, Kaixuan --- optimum/exporters/ipex/cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py index 7d9f012e7c..83cffa1a4c 100755 --- a/optimum/exporters/ipex/cache_utils.py +++ b/optimum/exporters/ipex/cache_utils.py @@ -52,7 +52,7 @@ def __init__( # Used in `generate` to keep tally of how many tokens the cache has seen self._seen_tokens = torch.zeros([max_batch_size], dtype=torch.int32, device=device) - default_block_size = 16 if device.type == "cpu" else 64 + default_block_size = 16 self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size))) self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape( From 8dacb0ab82fea3222823042f63885692449efa54 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 23 Jan 2025 10:21:39 -0500 Subject: [PATCH 3/9] optimize the performance Signed-off-by: Liu, Kaixuan --- optimum/exporters/ipex/modeling_utils.py | 61 +++++++++++++++++++----- 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index cc57de14e7..e2ca92c926 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -179,8 +179,8 @@ def _llama_model_forward( past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + device = input_ids.device if input_ids is not None else inputs_embeds.device if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) @@ -200,6 +200,9 @@ def _llama_model_forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int() + max_input_lens = input_lens.max().item() if past_key_values_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask @@ -235,6 +238,9 @@ def _llama_model_forward( use_cache=use_cache, position_embeddings=position_embeddings, input_lens=input_lens, + max_input_lens=max_input_lens, + seq_len_tensor=seq_len_tensor, + query_len_tensor=query_len_tensor, ) hidden_states = layer_outputs[0] @@ -303,10 +309,11 @@ def 
_falcon_model_forward( past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 batch_size, seq_length, _ = inputs_embeds.shape + device = input_ids.device if input_ids is not None else inputs_embeds.device if cache_position is None: cache_position = torch.arange( - past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + past_key_values_length, past_key_values_length + seq_length, device=device ) if position_ids is None: @@ -323,6 +330,9 @@ def _falcon_model_forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int() + max_input_lens = input_lens.max().item() if past_key_values_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask @@ -365,6 +375,9 @@ def _falcon_model_forward( cache_position=cache_position, position_embeddings=position_embeddings, input_lens=input_lens, + max_input_lens=max_input_lens, + seq_len_tensor=seq_len_tensor, + query_len_tensor=query_len_tensor, ) hidden_states = outputs[0] @@ -459,6 +472,9 @@ def _gpt2_model_forward( hidden_states = self.drop(hidden_states) input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int() + max_input_lens = input_lens.max().item() if past_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask @@ -494,6 +510,9 @@ def _gpt2_model_forward( use_cache=use_cache, output_attentions=output_attentions, input_lens=input_lens, + max_input_lens=max_input_lens, + seq_len_tensor=seq_len_tensor, + query_len_tensor=query_len_tensor, ) hidden_states = outputs[0] @@ -635,7 +654,19 @@ def has_flash_attn(self): return is_torch_version(">", "2.5.99") def attention_interface( - self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len + self, + query, + key_cache, + value_cache, + key, + value, + past_key_value, + attention_mask, + input_lens, + past_len, + seq_len_tensor, + query_len_tensor, + max_input_lens, ): if past_key_value is None: n_rep = query.shape[1] // key.shape[1] @@ -654,18 +685,13 @@ def attention_interface( self.use_sdpa = True elif self.has_flash_attn(): attn_output = torch.empty_like(query) - seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) - query_len_tensor = ( - seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0], device=query.device).int() - ) - max_input_lens = input_lens.max().item() query_max_len = max_input_lens if past_len == 0 else 1 PagedAttention.flash_attn_varlen_func( attn_output, query.contiguous() if query.device.type == "xpu" else query, key_cache, value_cache, - query_len_tensor, + seq_len_tensor if past_len == 0 else query_len_tensor, seq_len_tensor, query_max_len, max_input_lens, @@ -677,7 +703,6 @@ def attention_interface( elif past_len == 0: # prefill, remove padding attn_output = torch.empty_like(query) - seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) varlen_attention( query.contiguous() if query.device.type == "xpu" else query, 
key.contiguous() if key.device.type == "xpu" else key, @@ -726,6 +751,9 @@ def forward( if past_key_value is None and kwargs.get("layer_past", None) is not None: past_key_value = kwargs.pop("layer_past", None) input_lens = kwargs.pop("input_lens", None) + seq_len_tensor = kwargs.pop("seq_len_tensor", None) + query_len_tensor = kwargs.pop("query_len_tensor", None) + max_input_lens = kwargs.pop("max_input_lens", 0) past_len = 0 if past_key_value is not None: past_len = past_key_value.get_seq_length() @@ -737,7 +765,18 @@ def forward( key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens) attn_output = self.attention_interface( - query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len + query, + key_cache, + value_cache, + key, + value, + past_key_value, + attention_mask, + input_lens, + past_len, + seq_len_tensor, + query_len_tensor, + max_input_lens, ) attn_output = self.postprocess_attention_output(attn_output) From 23a3b54cf7296bfff2e527e49ae40c8983678e3c Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 6 Feb 2025 18:16:52 -0500 Subject: [PATCH 4/9] fix format issue Signed-off-by: Liu, Kaixuan --- optimum/exporters/ipex/modeling_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index e2ca92c926..1eae6014ad 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -312,9 +312,7 @@ def _falcon_model_forward( device = input_ids.device if input_ids is not None else inputs_embeds.device if cache_position is None: - cache_position = torch.arange( - past_key_values_length, past_key_values_length + seq_length, device=device - ) + cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device) if position_ids is None: position_ids = cache_position.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0) From 3616dd24d39bbe2293c25864b6cf0394f8eb910d Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Mon, 17 Feb 2025 15:37:09 -0500 Subject: [PATCH 5/9] refine code Signed-off-by: Liu, Kaixuan --- optimum/exporters/ipex/cache_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py index 83cffa1a4c..7b8ab1cc7f 100755 --- a/optimum/exporters/ipex/cache_utils.py +++ b/optimum/exporters/ipex/cache_utils.py @@ -46,7 +46,7 @@ def __init__( super().__init__() self.max_batch_size = max_batch_size self.device = device - self.flash_decoding = ( + self._supports_flash_decoding = ( is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99") ) # Used in `generate` to keep tally of how many tokens the cache has seen @@ -76,7 +76,7 @@ def __init__( key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size) value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size) elif device.type == "xpu": - if self.flash_decoding: + if self._supports_flash_decoding: key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size) value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size) else: @@ -96,7 +96,8 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - if self.device.type == "xpu" and self.flash_decoding: + # TODO: unify API definition between CPU and XPU in IPEX version > 2.6 + if 
self.device.type == "xpu" and self._supports_flash_decoding: PagedAttention.reshape_and_cache_flash( key, value, From 081ff45961f2a12b4c7f95b9e0b355f3caab0752 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 17 Jan 2025 17:54:49 -0500 Subject: [PATCH 6/9] add support for flash decoding on xpu Signed-off-by: Liu, Kaixuan --- optimum/exporters/ipex/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 37f9635aa6..333a0fd7f7 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -810,7 +810,7 @@ def attention_interface( query.contiguous() if query.device.type == "xpu" else query, key_cache, value_cache, - seq_len_tensor if past_len == 0 else query_len_tensor, + query_len_tensor, seq_len_tensor, query_max_len, max_input_lens, From d5100b4244ab85b6851ac36ed69ed954e9d2c036 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 19 Feb 2025 14:13:40 +0000 Subject: [PATCH 7/9] fix conflict CI Signed-off-by: Liu, Kaixuan --- optimum/exporters/ipex/modeling_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 333a0fd7f7..ff348201ca 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -653,6 +653,7 @@ def _qwen2_model_forward( inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = inputs_embeds.shape[:2] + device = input_ids.device if input_ids is not None else inputs_embeds.device past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_position is None: @@ -677,6 +678,9 @@ def _qwen2_model_forward( position_embeddings = self.rotary_emb(hidden_states, position_ids) input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + query_len_tensor = torch.arange(seq_len_tensor.shape[0], device=device).int() + max_input_lens = input_lens.max().item() if past_key_values_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask @@ -712,6 +716,9 @@ def _qwen2_model_forward( cache_position=cache_position, position_embeddings=position_embeddings, input_lens=input_lens, + max_input_lens=max_input_lens, + seq_len_tensor=seq_len_tensor, + query_len_tensor=query_len_tensor, **kwargs, ) From 775a13e003fb34794f409edef5a98d244c369365 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 20 Feb 2025 05:43:16 -0500 Subject: [PATCH 8/9] fix bug Signed-off-by: Liu, Kaixuan --- optimum/exporters/ipex/modeling_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index ff348201ca..3ff807c17e 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -811,6 +811,7 @@ def attention_interface( self.use_sdpa = True elif self.has_flash_attn(): attn_output = torch.empty_like(query) + query_len_tensor = seq_len_tensor if past_len == 0 else query_len_tensor query_max_len = max_input_lens if past_len == 0 else 1 PagedAttention.flash_attn_varlen_func( attn_output, From fe4d106b402855b67b2fddba1f5f97ddff17e6cf Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 20 Feb 2025 09:08:13 -0500 Subject: [PATCH 9/9] fix CI bug Signed-off-by: Liu, 
Kaixuan --- tests/ipex/test_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index eda4f09614..dc15d161c8 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -291,7 +291,7 @@ def test_forward(self, model_arch): dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE) self.assertIsInstance(ipex_model.config, PretrainedConfig) - input_ids = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.long) + input_ids = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.long).to(DEVICE) outputs = ipex_model(input_ids) self.assertIsInstance(outputs.logits, torch.Tensor)
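
Note on the cache bookkeeping these patches touch: the slot arithmetic in `update_for_prefill` and `update_for_decode` (patch 1) reduces to mapping each token position to `physical_block * block_size + offset` through a per-sequence block table. The sketch below is illustrative only, not the module's API: `BLOCK_SIZE`, `prefill_slots`, and `decode_slot` are made-up names, and the real implementation additionally allocates physical blocks from `free_blocks` and writes key/value states through IPEX's `PagedAttention` kernels (`reshape_and_cache`, or `reshape_and_cache_flash` where flash decoding is available).

import torch

# Patch 2 pins block_size to 16 on both CPU and XPU because
# `single_query_cached_kv_attention` does not support 64.
BLOCK_SIZE = 16


def prefill_slots(block_table: torch.Tensor, input_len: int) -> torch.Tensor:
    # Map every prompt token to a physical slot:
    # slot = physical_block_index * BLOCK_SIZE + offset_within_block
    positions = torch.arange(input_len)
    block_indices = positions // BLOCK_SIZE   # logical block per token
    slot_offsets = positions % BLOCK_SIZE     # offset inside that block
    return block_table[block_indices] * BLOCK_SIZE + slot_offsets


def decode_slot(block_table: torch.Tensor, seen_tokens: int) -> torch.Tensor:
    # Slot for the single new token generated at decode time.
    block_idx = seen_tokens // BLOCK_SIZE
    offset = seen_tokens % BLOCK_SIZE
    return block_table[block_idx] * BLOCK_SIZE + offset


if __name__ == "__main__":
    # Toy block table: logical blocks 0..3 of this sequence live in
    # physical blocks 5, 2, 7, 9 of the shared cache pool.
    table = torch.tensor([5, 2, 7, 9])
    print(prefill_slots(table, input_len=20))   # slots 80..95 in block 5, then 32..35 in block 2
    print(decode_slot(table, seen_tokens=20))   # next token lands at slot 36 (block 2, offset 4)

The indirection through the block table is what keeps the version-gated layouts in patch 1 compatible: switching the XPU cache to the flash-decoding layout (num_blocks, block_size, num_kv_heads, head_size) only changes the physical tensor shape and the write kernel, while the slot mapping above stays the same.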