
optimize first latency beam search for OVModelForCausalLM #695

Merged

Conversation

eaidova
Collaborator

@eaidova eaidova commented Apr 30, 2024

What does this PR do?

This PR reduces first-token latency for the OVModelForCausalLM class when beam search decoding is selected. During generation, beam search is represented as a batch of sequences (generation batch size = num_input_prompts * num_beams). The generation API duplicates the initial input sequence for each beam before starting, even though on the first step all sequences are identical (and, at the same time, the first inference for models with cache is the most time-consuming part). The idea is to postpone the sequence duplication for beams until after the first iteration is done, including the duplication of past key values and logits in the outputs.
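
A minimal sketch of the idea (a hypothetical helper, not the PR's actual code; it assumes a model whose forward returns logits and per-layer past key/value tensors, and that all beam copies of a prompt are identical on the first step):

import torch

def first_step_with_deferred_beam_expansion(model, input_ids, num_beams):
    # Run the expensive first forward pass on the unique prompts only,
    # since every beam copy of a prompt is identical at this point.
    outputs = model(input_ids=input_ids, use_cache=True)

    # Only afterwards expand logits and past key values to the full
    # (num_prompts * num_beams) batch that the beam search loop expects.
    logits = outputs.logits.repeat_interleave(num_beams, dim=0)
    past_key_values = tuple(
        tuple(kv.repeat_interleave(num_beams, dim=0) for kv in layer)
        for layer in outputs.past_key_values
    )
    return logits, past_key_values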

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?


@eaidova eaidova changed the title from "Ea/optimize first latency beam search for OVModelForCausalLM" to "optimize first latency beam search for OVModelForCausalLM" Apr 30, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -651,6 +764,954 @@ def _from_pretrained(

return causal_model

def _beam_search(
Collaborator


Could you use comments to indicate which parts have been modified from the original code? This would help a lot with review and future maintenance.

Member


+1

Collaborator

@echarlaix echarlaix left a comment


What is the observed performance gain for first-token generation? A lot of methods are overridden in this PR, which could lead to potential issues with future transformers compatibility for generation, so I would like to make sure the performance gain is significant before considering merging.

@echarlaix echarlaix requested a review from IlyasMoutawwakil May 3, 2024 16:04
@IlyasMoutawwakil
Member

IMO this will be very heavy to maintain with the constant changes in the transformers lib, especially since the text generation API will be undergoing heavy refactoring soon.

Would it not make sense, instead of optimizing the generation strategy, to optimize the first forward pass, with something along the lines of:

def generate(self, *args, **kwargs):
    if beam_search:  # or any generation strategy where this issue is observed
        self.first_beam_search_iteration = True
    else:
        self.first_beam_search_iteration = False

    return super().generate(*args, **kwargs)

def forward(self, inputs, **kwargs):
    if self.first_beam_search_iteration:
        # deduplicate the beam-expanded batch before the expensive first pass
        unique_inputs, inverse_order = torch.unique(inputs, dim=0, return_inverse=True)
        # we can also use what we know about how the inputs are duplicated to deduplicate them
        unique_outputs = super().forward(unique_inputs, **kwargs)
        # re-expand the deduplicated outputs back to the full beam batch
        outputs = unique_outputs[inverse_order]
        self.first_beam_search_iteration = False
    else:
        outputs = super().forward(inputs, **kwargs)
    return outputs

I admit that this is more stateful and hacky than what's suggested in the PR, but it requires maintaining less code until this duplication issue with beam search gets fixed in transformers.

@eaidova
Collaborator Author

eaidova commented May 13, 2024

@IlyasMoutawwakil, thank you for your suggestion, that is where I started from. The problem is that, for the non-stateful case, we need to know how the inputs were duplicated in order to duplicate the past key values, and this requires additional context (from the generation config) that is not available inside forward. Another problem is next_beam_idx, which should be different before the second inference (it contains the initial index duplication instead of the arranged indices used for cache reordering).
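
For illustration only (a rough sketch with made-up example sizes, not the PR's code), the difference in next_beam_idx before the second inference would look roughly like this:

import numpy as np

num_prompts, num_beams = 2, 3

# indices normally used for cache reordering after a regular decoding step
arranged = np.arange(num_prompts * num_beams)                        # [0, 1, 2, 3, 4, 5]

# before the second inference, each beam must instead point back to the row
# of the single-sequence first pass it was duplicated from
initial_duplication = np.repeat(np.arange(num_prompts), num_beams)   # [0, 0, 0, 1, 1, 1]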

@eaidova eaidova force-pushed the ea/optimize_first_latency_beam_search branch from 86c2baf to d42574a on May 13, 2024 07:05
@eaidova eaidova force-pushed the ea/optimize_first_latency_beam_search branch from 60f55ca to d216e3a on May 13, 2024 12:28
@eaidova
Copy link
Collaborator Author

eaidova commented May 13, 2024

@IlyasMoutawwakil @echarlaix please take another look; I significantly updated the code to reduce the number of overridden beam search methods.

@eaidova eaidova requested a review from echarlaix May 13, 2024 12:30
@eaidova eaidova force-pushed the ea/optimize_first_latency_beam_search branch from d216e3a to 0dbb104 on May 13, 2024 13:16
@eaidova eaidova force-pushed the ea/optimize_first_latency_beam_search branch from cbd5274 to 59c8c40 on May 14, 2024 08:50
eaidova added 2 commits May 15, 2024 09:18
@eaidova eaidova force-pushed the ea/optimize_first_latency_beam_search branch from 40c26f7 to b1fc04b on May 15, 2024 05:29
@eaidova eaidova requested a review from echarlaix May 15, 2024 06:54
Collaborator

@echarlaix echarlaix left a comment


Looks great, thanks for iterating on this @eaidova

@echarlaix echarlaix merged commit 2b902bb into huggingface:main May 15, 2024
11 checks passed