
Commit bb43dc1

[LLAMA_CPP] Use separate states for infer requests (#908)
* Fix previous test
* Add test
* Implement separate states
1 parent a6b9f14 commit bb43dc1
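
In user-facing terms, the KV-cache state now belongs to each infer request rather than to the compiled model: every LlamaCppSyncInferRequest creates and owns its own llama_context, so calling reset_state() on one request no longer clears the cache of another. A minimal sketch of that effect through the OpenVINO API, assuming an illustrative "LLAMA_CPP" device name and a local model.gguf path (both placeholders, not taken from this commit):

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // Placeholder GGUF path and device name; the plugin loads GGUF files directly.
    ov::CompiledModel model = core.compile_model("model.gguf", "LLAMA_CPP");

    // Each infer request now holds an independent llama_context (and KV cache).
    ov::InferRequest request_a = model.create_infer_request();
    ov::InferRequest request_b = model.create_infer_request();

    // ... run prompts on both requests ...

    // Resetting request_a clears only its own KV cache; request_b keeps its
    // accumulated context, which the new functional test below asserts.
    request_a.reset_state();
    return 0;
}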

6 files changed: +71 -23 lines
modules/llama_cpp_plugin/include/compiled_model.hpp (+1)

@@ -61,6 +61,7 @@ class LlamaCppModel : public ICompiledModel {
 private:
     gguf_context* m_gguf_ctx = nullptr;
     std::string m_gguf_fname;
+    size_t m_num_threads;
 
     llama_model* m_llama_model_ptr = nullptr;
     llama_context* m_llama_ctx = nullptr;

modules/llama_cpp_plugin/include/infer_request.hpp (+3 -2)

@@ -12,8 +12,8 @@ namespace llama_cpp_plugin {
 
 class LlamaCppSyncInferRequest : public ISyncInferRequest {
 public:
-    explicit LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model);
-    virtual ~LlamaCppSyncInferRequest(){};
+    explicit LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model, size_t num_threads);
+    virtual ~LlamaCppSyncInferRequest() override;
 
     virtual void set_tensors_impl(const ov::Output<const ov::Node> port,
                                   const std::vector<ov::SoPtr<ov::ITensor>>& tensors) override;
@@ -24,6 +24,7 @@ class LlamaCppSyncInferRequest : public ISyncInferRequest {
 
 private:
     std::shared_ptr<const LlamaCppModel> m_compiled_model_ptr;
+    llama_context* m_llama_ctx;
 };
 
 }  // namespace llama_cpp_plugin

modules/llama_cpp_plugin/include/state.hpp (+5 -4)

@@ -12,15 +12,16 @@ namespace llama_cpp_plugin {
 class LlamaCppState : public IVariableState {
 public:
     LlamaCppState() = delete;
-    LlamaCppState(const std::shared_ptr<const LlamaCppModel>& model_ptr)
-        : m_model_ptr(model_ptr),
+    LlamaCppState(llama_context* llama_context_ptr)
+        : m_llama_ctx_ptr(llama_context_ptr),
           IVariableState("llama_cpp_state") {}
     void reset() override {
-        llama_kv_cache_clear(m_model_ptr->m_llama_ctx);
+        OPENVINO_ASSERT(m_llama_ctx_ptr != nullptr);
+        llama_kv_cache_clear(m_llama_ctx_ptr);
     }
 
 private:
-    const std::shared_ptr<const LlamaCppModel>& m_model_ptr;
+    llama_context* m_llama_ctx_ptr;
 };
 }  // namespace llama_cpp_plugin
 }  // namespace ov

modules/llama_cpp_plugin/src/compiled_model.cpp (+4 -9)

@@ -9,7 +9,6 @@
 #include <openvino/opsets/opset13.hpp>
 #include <openvino/runtime/properties.hpp>
 #include <openvino/util/log.hpp>
-#include <thread>
 
 #include "infer_request.hpp"
 #include "plugin.hpp"
@@ -18,7 +17,6 @@ namespace ov {
 namespace llama_cpp_plugin {
 
 LlamaCppModel::~LlamaCppModel() {
-    llama_free(m_llama_ctx);
     llama_free_model(m_llama_model_ptr);
     llama_backend_free();
 }
@@ -27,15 +25,12 @@ LlamaCppModel::LlamaCppModel(const std::string& gguf_fname,
                              const std::shared_ptr<const IPlugin>& plugin,
                              size_t num_threads)
     : ICompiledModel(nullptr, plugin),
-      m_gguf_fname(gguf_fname) {
+      m_gguf_fname(gguf_fname),
+      m_num_threads(num_threads) {
     OPENVINO_DEBUG << "llama_cpp_plugin: loading llama model directly from GGUF... " << std::endl;
     llama_model_params mparams = llama_model_default_params();
     mparams.n_gpu_layers = 99;
     m_llama_model_ptr = llama_load_model_from_file(gguf_fname.c_str(), mparams);
-    llama_context_params cparams = llama_context_default_params();
-    cparams.n_threads = num_threads ? num_threads : std::thread::hardware_concurrency();
-    cparams.n_ctx = 0;  // this means that the actual n_ctx will be taken equal to the model's train-time value
-    m_llama_ctx = llama_new_context_with_model(m_llama_model_ptr, cparams);
     OPENVINO_DEBUG << "llama_cpp_plugin: llama model loaded successfully from GGUF..." << std::endl;
 
     auto input_ids = std::make_shared<ov::opset13::Parameter>(ov::element::Type_t::i64, ov::PartialShape({-1, -1}));
@@ -87,8 +82,8 @@ ov::Any LlamaCppModel::get_property(const std::string& name) const {
 }
 
 std::shared_ptr<ov::ISyncInferRequest> LlamaCppModel::create_sync_infer_request() const {
-    return std::make_shared<LlamaCppSyncInferRequest>(
-        std::static_pointer_cast<const LlamaCppModel>(shared_from_this()));
+    return std::make_shared<LlamaCppSyncInferRequest>(std::static_pointer_cast<const LlamaCppModel>(shared_from_this()),
+                                                      m_num_threads);
 }
 
 const std::vector<ov::Output<const ov::Node>>& LlamaCppModel::inputs() const {

modules/llama_cpp_plugin/src/infer_request.cpp (+16 -5)

@@ -5,6 +5,7 @@
 
 #include <memory>
 #include <openvino/runtime/ivariable_state.hpp>
+#include <thread>
 
 #include "llama.h"
 #include "openvino/runtime/make_tensor.hpp"
@@ -24,9 +25,14 @@ void allocate_tensor_impl(ov::SoPtr<ov::ITensor>& tensor,
     }
 }
 
-LlamaCppSyncInferRequest::LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model)
+LlamaCppSyncInferRequest::LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model,
+                                                   size_t num_threads)
     : ov::ISyncInferRequest(compiled_model) {
     OPENVINO_DEBUG << "llama_cpp_plugin: infer request ctor called\n";
+    llama_context_params cparams = llama_context_default_params();
+    cparams.n_threads = num_threads ? num_threads : std::thread::hardware_concurrency();
+    cparams.n_ctx = 0;  // this means that the actual n_ctx will be taken equal to the model's train-time value
+    m_llama_ctx = llama_new_context_with_model(compiled_model->m_llama_model_ptr, cparams);
     m_compiled_model_ptr = compiled_model;
     for (const auto& input : get_inputs()) {
         allocate_tensor(input, [input](ov::SoPtr<ov::ITensor>& tensor) {
@@ -97,8 +103,7 @@ void LlamaCppSyncInferRequest::infer() {
         }
     }
 
-    llama_context* ctx = m_compiled_model_ptr->m_llama_ctx;
-    int32_t sts = llama_decode(ctx, batch);
+    int32_t sts = llama_decode(m_llama_ctx, batch);
 
     if (sts != 0) {
         OPENVINO_THROW("llama_decode failed with code ", sts);
@@ -112,7 +117,7 @@ void LlamaCppSyncInferRequest::infer() {
     for (size_t batch_idx = 0; batch_idx < batch_size; batch_idx++) {
         for (size_t seq_idx = 0; seq_idx < sequence_length; seq_idx++) {
             size_t pos = batch_idx * sequence_length + seq_idx;
-            float* logits_from_llama = llama_get_logits_ith(ctx, pos);
+            float* logits_from_llama = llama_get_logits_ith(m_llama_ctx, pos);
             std::copy(logits_from_llama, logits_from_llama + n_vocab, output_tensor_data_ptr + pos * n_vocab);
         }
     }
@@ -132,7 +137,13 @@ std::vector<ov::ProfilingInfo> LlamaCppSyncInferRequest::get_profiling_info() co
 
 std::vector<ov::SoPtr<ov::IVariableState>> LlamaCppSyncInferRequest::query_state() const {
     OPENVINO_DEBUG << "llama_cpp_plugin: query_state() called\n";
-    return {std::static_pointer_cast<ov::IVariableState>(std::make_shared<LlamaCppState>(m_compiled_model_ptr))};
+    return {std::static_pointer_cast<ov::IVariableState>(std::make_shared<LlamaCppState>(m_llama_ctx))};
+}
+
+LlamaCppSyncInferRequest::~LlamaCppSyncInferRequest() {
+    if (m_llama_ctx != nullptr) {
+        llama_free(m_llama_ctx);
+    }
 }
 }  // namespace llama_cpp_plugin
 }  // namespace ov

modules/llama_cpp_plugin/tests/functional/src/reset_state.cpp (+42 -3)

@@ -42,16 +42,55 @@ TEST_F(CompiledModelTest, ResetStateGPT2) {
     SetUp();
 
     ov::InferRequest lm_bad = model.create_infer_request();
-    std::vector<float> logits_lennon_bad = infer_and_get_last_logits(lm, GPT2_LENNON_PROMPT_TOKEN_IDS, 0);
+    std::vector<float> logits_lennon_bad = infer_and_get_last_logits(lm_bad, GPT2_LENNON_PROMPT_TOKEN_IDS, 0);
 
     // no reset_state on purpose
 
-    std::vector<float> logits_sun_bad = infer_and_get_last_logits(lm_reset,
+    std::vector<float> logits_sun_bad = infer_and_get_last_logits(lm_bad,
                                                                   GPT2_SUN_PROMPT_TOKEN_IDS,
                                                                   0);  // GPT2_LENNON_PROMPT_TOKEN_IDS.size());
 
-    std::vector<int64_t> out_token_ids_bad = generate_n_tokens_with_positions(lm_reset,
+    std::vector<int64_t> out_token_ids_bad = generate_n_tokens_with_positions(lm_bad,
                                                                               get_token_from_logits(logits_sun_reset),
                                                                               NUM_TOKENS_TO_GENERATE,
                                                                               GPT2_SUN_PROMPT_TOKEN_IDS.size());
+    ASSERT_NE(out_token_ids_bad, out_token_ids_ref);
+}
+
+TEST_F(CompiledModelTest, StatesForDifferentInferRequestsAreIndependentGPT2) {
+    // Take two infer requests, process two different prompts with same position IDs, but for one of them, do
+    // .reset_state() in-between the inferences - check that the state is reset independently.
+
+    // the "new" sequence should have the same number of tokens as the previous one for this to work
+    std::vector<int64_t> MODIFIED_PROMPT_TOKEN_IDS = GPT2_LENNON_PROMPT_TOKEN_IDS;
+    MODIFIED_PROMPT_TOKEN_IDS.push_back(30);  // extra newline
+    ASSERT_EQ(GPT2_SUN_PROMPT_TOKEN_IDS.size(), MODIFIED_PROMPT_TOKEN_IDS.size());
+
+    ov::InferRequest first_infer_request = model.create_infer_request();
+    std::vector<float> logits_first_ref = infer_and_get_last_logits(first_infer_request, GPT2_SUN_PROMPT_TOKEN_IDS, 0);
+
+    ov::InferRequest another_infer_request = model.create_infer_request();
+    std::vector<float> logits_another_ref =
+        infer_and_get_last_logits(another_infer_request, GPT2_SUN_PROMPT_TOKEN_IDS, 0);
+
+    first_infer_request.reset_state();
+
+    std::vector<float> logits_first_new_tokens_old_positions =
+        infer_and_get_last_logits(first_infer_request, MODIFIED_PROMPT_TOKEN_IDS, 0);
+    std::vector<int64_t> out_tokens_first =
+        generate_n_tokens_with_positions(first_infer_request,
+                                         get_token_from_logits(logits_first_new_tokens_old_positions),
+                                         NUM_TOKENS_TO_GENERATE,
+                                         MODIFIED_PROMPT_TOKEN_IDS.size());
+
+    // not resetting another_infer_request state on purpose
+    std::vector<float> logits_another_new_tokens_old_positions =
+        infer_and_get_last_logits(another_infer_request, MODIFIED_PROMPT_TOKEN_IDS, 0);
+    std::vector<int64_t> out_tokens_another =
+        generate_n_tokens_with_positions(another_infer_request,
+                                         get_token_from_logits(logits_another_new_tokens_old_positions),
+                                         NUM_TOKENS_TO_GENERATE,
+                                         MODIFIED_PROMPT_TOKEN_IDS.size());
+
+    EXPECT_NE(out_tokens_another, out_tokens_first);
 }
