Commit 6a4ba7f

Separate reset for KV state and LoRA state in LLMPipeline (openvinotoolkit#1058)
Fixes a bug where the LoRA state was reset on every generate() call, which added unnecessary overhead to each call even when the LoRA tensors/alphas had not changed.
1 parent 275729c commit 6a4ba7f
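For context, a minimal, hedged sketch of the scenario this commit optimizes, written in the style of the openvino.genai LoRA samples (the model directory and adapter file below are placeholders, and the property helpers are assumed from the public API): a pipeline created with a fixed set of adapters calls generate() repeatedly, and before this commit every call fully reset the request state, wiping the LoRA variables and forcing them to be re-uploaded even though the adapter configuration never changed.

#include <iostream>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Placeholder paths, for illustration only.
    ov::genai::Adapter adapter("adapter.safetensors");
    ov::genai::LLMPipeline pipe("model_dir", "CPU", ov::genai::adapters(adapter));

    // Same LoRA config on every call: with this commit only the KV cache part of the
    // request state is reset between calls, so the LoRA tensors are not re-applied.
    for (int i = 0; i < 3; ++i) {
        std::string out = pipe.generate("Tell me a short story", ov::genai::max_new_tokens(64));
        std::cout << out << "\n";
    }
    return 0;
}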

File tree

3 files changed, +46 -40 lines changed


src/cpp/include/openvino/genai/lora_adapter.hpp (+3 -2)

@@ -190,8 +190,9 @@ class OPENVINO_GENAI_EXPORTS AdapterController {
     // Apply adapters configured in the current config set last time, or set and use new config given as optional `config` argument
     void apply(ov::InferRequest& request, const std::optional<AdapterConfig>& config = std::nullopt);
 
-    // the next call of apply will set all adapter tensors regardless of config change, use this method if full state.reset is called for the controlled model
-    void force_full_apply(bool full_apply = true);
+    // Returns true if a given name is one of the state names created by this adapter controller for dynamic LoRA
+    // Helps to distinguish LoRA states from other states (e.g. KV cache state) in the model for a partial state reset.
+    bool has_state_name(const std::string& name);
 
     operator bool() const {
         return bool(m_pimpl);
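The intended use of the new query is a partial state reset on the controlled model's infer request: walk the request states and reset only those not owned by the adapter controller. A short sketch of that caller-side pattern (`controller` and `request` are assumed to be an AdapterController and the ov::InferRequest it was applied to); the StatefulLLMPipeline change below wraps exactly this loop in a reset_kv_state() helper.

// Reset every state except the LoRA variables managed by this controller,
// e.g. drop the KV cache while keeping the already-uploaded adapter tensors.
for (auto& state : request.query_state()) {
    if (!controller.has_state_name(state.get_name())) {
        state.reset();
    }
}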

src/cpp/src/llm_pipeline.cpp (+15 -8)

@@ -86,7 +86,6 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             m_adapter_controller = AdapterController(model, m_generation_config.adapters, device);  // TODO: Make the prefix name configurable
             utils::slice_matmul_statefull_model(model);
             m_model_runner = core.compile_model(model, device, compile_plugin_config).create_infer_request();
-            m_adapter_controller->apply(m_model_runner, m_generation_config.adapters);
         } else {
             auto [core_plugin_config, compile_plugin_config] = ov::genai::utils::split_core_complile_config(plugin_config);
             core.set_property(core_plugin_config);

@@ -179,6 +178,18 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         return decoded_results;
     }
 
+    void reset_kv_state() {
+        if(m_adapter_controller) {
+            for(auto& state: m_model_runner.query_state()) {
+                if(!m_adapter_controller->has_state_name(state.get_name())) {
+                    state.reset();
+                }
+            }
+        } else {
+            m_model_runner.reset_state();
+        }
+    }
+
     EncodedResults generate(
         const EncodedInputs& inputs,
         OptionalGenerationConfig generation_config,

@@ -273,11 +284,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         }
 
         if (!is_chat_conversation) {
-            // FIXME: Reset only KV cache part of state, there is also can be LoRA applied in the states and full reset will need to reapply LoRA even if the LoRA config is not changed
-            m_model_runner.reset_state();
-            if(m_adapter_controller) {
-                m_adapter_controller->force_full_apply(); // FIXME: Reset only KV cache part to avoid this call
-            }
+            reset_kv_state();
             m_selected_beam = std::nullopt;
         } else {
             m_is_cache_empty = false;

@@ -297,7 +304,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         is_chat_conversation = true;
         m_selected_beam = std::nullopt;
         if (!m_is_cache_empty) {
-            m_model_runner.reset_state();
+            reset_kv_state();
             m_is_cache_empty = true;
             m_history = {};
             m_templated_chat_history = "";

@@ -315,7 +322,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         is_chat_conversation = false;
         m_selected_beam = std::nullopt;
         if (!m_is_cache_empty) {
-            m_model_runner.reset_state();
+            reset_kv_state();
             m_is_cache_empty = true;
             m_history.clear();
             m_templated_chat_history.clear();
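With these hunks, reset_kv_state() replaces the full reset_state() in all three places that clear the conversation: the non-chat generate() path, start_chat(), and finish_chat(). A hedged chat-mode sketch of the behaviour this protects, reusing the placeholder adapter and model directory from the sketch near the top of this page:

ov::genai::LLMPipeline pipe("model_dir", "CPU", ov::genai::adapters(adapter));

pipe.start_chat();                                        // may clear the KV cache, keeps LoRA state
std::string a = pipe.generate("Hi!", ov::genai::max_new_tokens(32));
std::string b = pipe.generate("Now answer in French.", ov::genai::max_new_tokens(32));
pipe.finish_chat();                                       // again only the KV cache part is reset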

src/cpp/src/lora_adapter.cpp (+28 -30)

@@ -367,17 +367,8 @@ struct LoRAParametersByWeightGetter {
 };
 
 
-// TODO: There is possible simplification if a new feature is implemented in OpenVINO:
-// move name from LoRAVarIDs to to LoRAIndices when the order of tensors in the model state in OV infer request will
-// be the same as the order of variables, remove LoRAVarsIDs in this case.
-
-struct LoRAIndices : public LoRAParts<size_t> {
-    std::string name;
-};
-
-struct LoRAVarIDs : public LoRAParts<ov::op::util::VariableInfo> {
-    std::string name; // layer name where LoRA with given variables is attached
-};
+using LoRAIndices = LoRAParts<size_t>;
+using LoRAVarIDs = LoRAParts<ov::op::util::VariableInfo>;
 
 
 // Deduce expected LoRA input and output static dimensions based on a given node where LoRA is applied

@@ -398,15 +389,18 @@ void deduce_input_output_dims(NodePtr node, ov::Dimension& input_dim, ov::Dimens
 }
 
 
+using LoRAVarMap = std::map<std::string, LoRAVarIDs>;
+
+
 // Creates ReadValue and Assign nodes to inject LoRA tensors as variables for a given node but
 // doesn't connect them to the model returning as LoRANode instance.
 struct LoRAWeightStateGetter {
     LoRAParametersGetter params_getter;
     std::shared_ptr<ov::Model> model;
-    std::vector<LoRAVarIDs>& variable_ids;
+    LoRAVarMap& variable_ids;
     // TODO: Use variable indices instead of variable_id for faster search for a state tensor
 
-    LoRAWeightStateGetter (const LoRAParametersGetter& params_getter, std::shared_ptr<ov::Model> model, std::vector<LoRAVarIDs>& variable_ids) :
+    LoRAWeightStateGetter (const LoRAParametersGetter& params_getter, std::shared_ptr<ov::Model> model, LoRAVarMap& variable_ids) :
         params_getter(params_getter), model(model), variable_ids(variable_ids) {}
 
     std::optional<LoRANode> operator() (NodePtr node) const {

@@ -420,7 +414,6 @@ struct LoRAWeightStateGetter {
             std::string variable_id_prefix = "lora_state_" + std::to_string(model->get_sinks().size()) + name;
             LoRANode result;
             LoRAVarIDs var_ids;
-            var_ids.name = name;
 
             // FIXME: No guarantees on ordering of state in InferRequest makes impossible using indices of variables later, forced to use variable_id instead
             //indices.A = model->get_variables().size();

@@ -446,7 +439,7 @@
                 variable_id_prefix + ".B"
             };
             result.B = add_variable(var_ids.B);
-            variable_ids.emplace_back(var_ids);
+            variable_ids.emplace(name, var_ids);
             return result;
         } else {
             return std::nullopt;

@@ -815,7 +808,8 @@ bool operator< (const Adapter& a, const Adapter& b) {
 
 
 struct AdapterControllerImpl {
-    std::vector<LoRAVarIDs> variable_ids;
+    LoRAVarMap variable_ids;
+    std::unordered_set<std::string> variable_names;
     AdapterConfig current_config;
     bool need_full_apply = true;
     InferRequestSignatureCache lora_state_evaluators;

@@ -890,6 +884,13 @@ struct AdapterControllerImpl {
 
         pm.run_passes(model);
         model->validate_nodes_and_infer_types();    // FIXME: For debugging purposes only
+
+        // Collect all variable names to quickly detect which state tensor belongs to this adapter controller later
+        for(const auto& var: variable_ids) {
+            variable_names.insert(var.second.A.variable_id);
+            variable_names.insert(var.second.B.variable_id);
+            variable_names.insert(var.second.alpha.variable_id);
+        }
     }
 
     static std::shared_ptr<Adapter::Impl> get_adapter_impl(const Adapter& adapter) {

@@ -945,15 +946,14 @@ struct AdapterControllerImpl {
         } else if(diff) {
             if(diff.adapter) {
                 set_new_adapter_tensors(infer_request);
-            } else {
-                OPENVINO_ASSERT(diff.alpha);
+            } else if(diff.alpha) {
                 set_new_adapter_alphas(infer_request);
             }
         }
     }
 
-    void force_full_apply(bool full_apply) {
-        need_full_apply = full_apply;
+    bool has_state_name(const std::string& name) {
+        return variable_names.count(name);
     }
 
     void set_new_adapter_alphas (ov::InferRequest& infer_request) {

@@ -988,12 +988,10 @@ struct AdapterControllerImpl {
         for(const auto& lora_var_ids : variable_ids) {
             // FIXME: Remove this mapping when the order of state will be the same as the order of variables
             LoRAIndices lora_indices;
-            lora_indices.alpha = state_name_to_index.at(lora_var_ids.alpha.variable_id);
-            lora_indices.A = state_name_to_index.at(lora_var_ids.A.variable_id);
-            lora_indices.B = state_name_to_index.at(lora_var_ids.B.variable_id);
-            lora_indices.name = lora_var_ids.name; // TODO: Redundant?
-
-            set_lora_tensors(state, lora_var_ids, lora_indices, weight_getters);
+            lora_indices.alpha = state_name_to_index.at(lora_var_ids.second.alpha.variable_id);
+            lora_indices.A = state_name_to_index.at(lora_var_ids.second.A.variable_id);
+            lora_indices.B = state_name_to_index.at(lora_var_ids.second.B.variable_id);
+            set_lora_tensors(state, lora_var_ids.first, lora_var_ids.second, lora_indices, weight_getters);
         }
     }
 

@@ -1191,13 +1189,13 @@ struct AdapterControllerImpl {
         return shape;
     }
 
-    void set_lora_tensors(std::vector<VariableState>& state, const LoRAVarIDs& lora_var_ids, const LoRAIndices& lora_indices, const std::vector<LoRAWeightGetter>& weight_getters) {
+    void set_lora_tensors(std::vector<VariableState>& state, const std::string& name, const LoRAVarIDs& lora_var_ids, const LoRAIndices& lora_indices, const std::vector<LoRAWeightGetter>& weight_getters) {
         LoRAParts<ov::Tensor> lora_state_tensors{
             ov::Tensor(lora_var_ids.alpha.data_type, dynamic_to_static(lora_var_ids.alpha.data_shape)),
             ov::Tensor(lora_var_ids.A.data_type, dynamic_to_static(lora_var_ids.A.data_shape)),
             ov::Tensor(lora_var_ids.B.data_type, dynamic_to_static(lora_var_ids.B.data_shape))
         };
-        auto new_tensors = prepare_lora_tensors(lora_indices.name, weight_getters, lora_state_tensors);
+        auto new_tensors = prepare_lora_tensors(name, weight_getters, lora_state_tensors);
 
         state[lora_indices.alpha].set_state(new_tensors.alpha);
         state[lora_indices.A].set_state(new_tensors.A);

@@ -1269,8 +1267,8 @@ void AdapterController::apply(ov::InferRequest& request, const std::optional<Ada
 }
 
 
-void AdapterController::force_full_apply(bool full_apply) {
-    return m_pimpl->force_full_apply(full_apply);
+bool AdapterController::has_state_name(const std::string& name) {
+    return m_pimpl->has_state_name(name);
 }
 
 
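Taken together, the lora_adapter.cpp hunks move the per-layer name out of LoRAIndices/LoRAVarIDs and into the key of a std::map, and derive a flat set of variable ids that backs has_state_name(). A simplified, self-contained sketch of that bookkeeping (types are reduced to strings here; the real code stores ov::op::util::VariableInfo and size_t indices):

#include <map>
#include <string>
#include <unordered_set>

// Simplified stand-in for LoRAParts<T>: one entry per LoRA component.
template <typename T>
struct Parts { T alpha, A, B; };

using VarIDs = Parts<std::string>;               // variable ids; real code keeps VariableInfo
using VarMap = std::map<std::string, VarIDs>;    // keyed by the layer name LoRA is attached to

int main() {
    VarMap variable_ids;
    variable_ids.emplace("layer0", VarIDs{"lora_state_0.alpha", "lora_state_0.A", "lora_state_0.B"});

    // Flatten once so a name lookup (has_state_name) is a single hash probe later.
    std::unordered_set<std::string> variable_names;
    for (const auto& var : variable_ids) {
        variable_names.insert(var.second.alpha);
        variable_names.insert(var.second.A);
        variable_names.insert(var.second.B);
    }

    // A state named like a LoRA variable is skipped during a KV-cache-only reset.
    bool is_lora_state = variable_names.count("lora_state_0.A") > 0;
    return is_lora_state ? 0 : 1;
}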
