optimize first latency beam search for OVModelForCausalLM #695
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@@ -651,6 +764,954 @@ def _from_pretrained(

        return causal_model

    def _beam_search(
Could you specify, using comments, which parts have been modified from the original code? This would help a lot for review and future maintenance.
+1
What is the observed performance gain for first-token generation? There are a lot of methods overridden in this PR, which could lead to potential issues with future transformers compatibility for generation, so I would like to make sure the performance gain is significant before considering merging.
IMO this will be very heavy to maintain with the constant changes in the transformers lib, especially since the text generation API will be undergoing heavy refactoring soon. Instead of optimizing the generation strategy, would it not make more sense to optimize the first forward pass, with something along the lines of:

```python
def generate():
    if beam_search:  # or any generation strategy where this issue is observed
        self.first_beam_search_iteration = True
    else:
        self.first_beam_search_iteration = False
    return super().generate()

def forward():
    if self.first_beam_search_iteration:
        # deduplicate the beam-expanded inputs before the expensive first pass
        unique_inputs, inverse_order = torch.unique(inputs, dim=0, return_inverse=True)
        # we can also use what we know about how the inputs are duplicated to deduplicate them
        unique_outputs = super().forward(unique_inputs)
        # map the unique outputs back to the original (duplicated) batch order
        outputs = unique_outputs[inverse_order]
        self.first_beam_search_iteration = False
    else:
        outputs = super().forward(inputs)
    return outputs
```

I admit that this is more stateful and hacky than what's suggested in the PR, but it requires maintaining less code, until this duplication issue with beam search gets fixed in transformers.
@IlyasMoutawwakil, thank you for your suggestion; that is where I started from. The problem is that, in the non-stateful case, we need to know how the inputs were duplicated in order to duplicate the past key values, and that requires additional context (from the generation config) which is not available inside forward. Another problem is next_beam_idx, which has to be different before the second inference: it should contain the initial index duplication instead of the arange indices used for cache reordering.
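To make that constraint concrete, here is a minimal sketch of the re-expansion that would have to happen between the first and second steps in the non-stateful case; the helper name, cache layout, and beam-index example are assumptions for illustration, not optimum-intel internals.

```python
import torch

# Hypothetical helper, assuming a [batch, num_heads, seq_len, head_dim] cache layout:
# after a deduplicated first forward pass the cached key/value tensors cover only the
# unique prompts, so they would need to be repeated for every beam before step two.
def expand_past_key_values(past_key_values, num_beams):
    expanded = []
    for key, value in past_key_values:
        # repeat each unique prompt's cache num_beams times along the batch axis
        expanded.append(
            (
                key.repeat_interleave(num_beams, dim=0),
                value.repeat_interleave(num_beams, dim=0),
            )
        )
    return tuple(expanded)

# For the stateful model, the beam index used before the second inference would
# similarly need to map every beam back to its source prompt, e.g.
# [0, 0, 0, 1, 1, 1] for 2 prompts and 3 beams, instead of a plain arange.
```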
@IlyasMoutawwakil @echarlaix please take a look one more time; I significantly updated the code to reduce the number of overridden beam search methods.
Looks great, thanks for iterating on this @eaidova
What does this PR do?
This PR reduces first-token latency for the OVModelForCausalLM class when beam search decoding is selected. During generation, beam search is represented as a batch of sequences (generation batch size = num_input_prompts * num_beams). The generation API duplicates the initial input sequence for each beam before starting work, even though on the first step all of these sequences are identical (and, at the same time, the first inference for models with a cache is the most time-consuming part). The idea is to postpone the duplication of sequences across beams until after the first iteration is done (including the duplication of past key values and logits in the outputs); a rough sketch of the idea is shown below.
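A minimal sketch of this idea, assuming a standard transformers-style forward call that returns logits and past_key_values (the function and argument names are hypothetical, not the code added by this PR):

```python
import torch

# Sketch only: run the expensive first forward pass on the unique prompts,
# then duplicate the outputs across beams so the rest of beam search proceeds
# as if the inputs had been expanded up front.
def first_step_without_beam_duplication(model, input_ids, attention_mask, num_beams):
    # first forward pass on num_prompts sequences instead of num_prompts * num_beams copies
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)

    # duplicate the logits for each beam so the beam scorer sees the expected batch size
    logits = outputs.logits.repeat_interleave(num_beams, dim=0)

    # duplicate the cached key/values in the same way
    past_key_values = tuple(
        (key.repeat_interleave(num_beams, dim=0), value.repeat_interleave(num_beams, dim=0))
        for key, value in outputs.past_key_values
    )
    return logits, past_key_values
```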
Before submitting