From ccf875d874220ec3f096438943065e762aaf149c Mon Sep 17 00:00:00 2001
From: Vasily Shamporov <vasily.shamporov@intel.com>
Date: Mon, 22 Apr 2024 12:32:15 +0200
Subject: [PATCH 1/3] Fix infer request references and missing assertion in ResetStateGPT2 test

---
 .../llama_cpp_plugin/tests/functional/src/reset_state.cpp  | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/modules/llama_cpp_plugin/tests/functional/src/reset_state.cpp b/modules/llama_cpp_plugin/tests/functional/src/reset_state.cpp
index 020867000..044980136 100644
--- a/modules/llama_cpp_plugin/tests/functional/src/reset_state.cpp
+++ b/modules/llama_cpp_plugin/tests/functional/src/reset_state.cpp
@@ -42,16 +42,17 @@ TEST_F(CompiledModelTest, ResetStateGPT2) {
     SetUp();
 
     ov::InferRequest lm_bad = model.create_infer_request();
-    std::vector<float> logits_lennon_bad = infer_and_get_last_logits(lm, GPT2_LENNON_PROMPT_TOKEN_IDS, 0);
+    std::vector<float> logits_lennon_bad = infer_and_get_last_logits(lm_bad, GPT2_LENNON_PROMPT_TOKEN_IDS, 0);
 
     // no reset_state on purpose
 
-    std::vector<float> logits_sun_bad = infer_and_get_last_logits(lm_reset,
+    std::vector<float> logits_sun_bad = infer_and_get_last_logits(lm_bad,
                                                                   GPT2_SUN_PROMPT_TOKEN_IDS,
                                                                   0);  // GPT2_LENNON_PROMPT_TOKEN_IDS.size());
 
-    std::vector<int64_t> out_token_ids_bad = generate_n_tokens_with_positions(lm_reset,
+    std::vector<int64_t> out_token_ids_bad = generate_n_tokens_with_positions(lm_bad,
                                                                               get_token_from_logits(logits_sun_reset),
                                                                               NUM_TOKENS_TO_GENERATE,
                                                                               GPT2_SUN_PROMPT_TOKEN_IDS.size());
+    ASSERT_NE(out_token_ids_bad, out_token_ids_ref);
 }
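
Both this test and the one added in the next patch rely on shared helpers such as infer_and_get_last_logits and generate_n_tokens_with_positions, which live in the test utilities and are not shown in this series. For orientation only, below is a hypothetical sketch of what a helper along the lines of infer_and_get_last_logits could look like; the tensor names ("input_ids", "position_ids", "logits"), the [1, seq_len, vocab] layout, and the position-offset handling are assumptions of this sketch, not the plugin's actual test code.

#include <algorithm>
#include <cstdint>
#include <vector>

#include <openvino/openvino.hpp>

// Hypothetical helper in the spirit of infer_and_get_last_logits: feed a token sequence with
// consecutive position IDs starting at `position_offset`, run inference, and return the logits
// row for the last token. Tensor names and shapes here are assumptions, not the real test code.
std::vector<float> infer_and_get_last_logits_sketch(ov::InferRequest& request,
                                                    const std::vector<int64_t>& token_ids,
                                                    int64_t position_offset) {
    const size_t seq_len = token_ids.size();

    ov::Tensor input_ids(ov::element::i64, ov::Shape{1, seq_len});
    std::copy(token_ids.begin(), token_ids.end(), input_ids.data<int64_t>());

    ov::Tensor position_ids(ov::element::i64, ov::Shape{1, seq_len});
    int64_t* pos_data = position_ids.data<int64_t>();
    for (size_t i = 0; i < seq_len; ++i) {
        pos_data[i] = position_offset + static_cast<int64_t>(i);
    }

    request.set_tensor("input_ids", input_ids);
    request.set_tensor("position_ids", position_ids);
    request.infer();

    ov::Tensor logits = request.get_tensor("logits");  // assumed shape: [1, seq_len, vocab_size]
    const size_t vocab_size = logits.get_shape().back();
    const float* last_row = logits.data<float>() + (seq_len - 1) * vocab_size;
    return std::vector<float>(last_row, last_row + vocab_size);
}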

From ad9dee2657e578e015ad062e2f3d2e415a5dc1ac Mon Sep 17 00:00:00 2001
From: Vasily Shamporov <vasily.shamporov@intel.com>
Date: Mon, 22 Apr 2024 11:40:34 +0200
Subject: [PATCH 2/3] Add test for infer-request state independence

---
 .../tests/functional/src/reset_state.cpp      | 38 +++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/modules/llama_cpp_plugin/tests/functional/src/reset_state.cpp b/modules/llama_cpp_plugin/tests/functional/src/reset_state.cpp
index 044980136..ce97fb278 100644
--- a/modules/llama_cpp_plugin/tests/functional/src/reset_state.cpp
+++ b/modules/llama_cpp_plugin/tests/functional/src/reset_state.cpp
@@ -56,3 +56,41 @@ TEST_F(CompiledModelTest, ResetStateGPT2) {
                                                                               GPT2_SUN_PROMPT_TOKEN_IDS.size());
     ASSERT_NE(out_token_ids_bad, out_token_ids_ref);
 }
+
+TEST_F(CompiledModelTest, StatesForDifferentInferRequestsAreIndependentGPT2) {
+    // Run two different prompts with the same position IDs through each of two infer requests, calling
+    // .reset_state() on only one of them between the inferences - check that the requests' states stay independent.
+
+    // the "new" prompt must have the same number of tokens as the original one so that the position IDs match
+    std::vector<int64_t> MODIFIED_PROMPT_TOKEN_IDS = GPT2_LENNON_PROMPT_TOKEN_IDS;
+    MODIFIED_PROMPT_TOKEN_IDS.push_back(30);  // arbitrary extra token so that the lengths match
+    ASSERT_EQ(GPT2_SUN_PROMPT_TOKEN_IDS.size(), MODIFIED_PROMPT_TOKEN_IDS.size());
+
+    ov::InferRequest first_infer_request = model.create_infer_request();
+    std::vector<float> logits_first_ref = infer_and_get_last_logits(first_infer_request, GPT2_SUN_PROMPT_TOKEN_IDS, 0);
+
+    ov::InferRequest another_infer_request = model.create_infer_request();
+    std::vector<float> logits_another_ref =
+        infer_and_get_last_logits(another_infer_request, GPT2_SUN_PROMPT_TOKEN_IDS, 0);
+
+    first_infer_request.reset_state();
+
+    std::vector<float> logits_first_new_tokens_old_positions =
+        infer_and_get_last_logits(first_infer_request, MODIFIED_PROMPT_TOKEN_IDS, 0);
+    std::vector<int64_t> out_tokens_first =
+        generate_n_tokens_with_positions(first_infer_request,
+                                         get_token_from_logits(logits_first_new_tokens_old_positions),
+                                         NUM_TOKENS_TO_GENERATE,
+                                         MODIFIED_PROMPT_TOKEN_IDS.size());
+
+    // not resetting another_infer_request state on purpose
+    std::vector<float> logits_another_new_tokens_old_positions =
+        infer_and_get_last_logits(another_infer_request, MODIFIED_PROMPT_TOKEN_IDS, 0);
+    std::vector<int64_t> out_tokens_another =
+        generate_n_tokens_with_positions(another_infer_request,
+                                         get_token_from_logits(logits_another_new_tokens_old_positions),
+                                         NUM_TOKENS_TO_GENERATE,
+                                         MODIFIED_PROMPT_TOKEN_IDS.size());
+
+    EXPECT_NE(out_tokens_another, out_tokens_first);
+}
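
At the user-facing API level, the guarantee this test pins down looks roughly like the sketch below. It is a sketch under assumptions: the "LLAMA_CPP" device name and the gpt2.gguf path are placeholders, and the functional tests compile the GGUF model through their own fixture rather than calling ov::Core directly.

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // Placeholder model path and device name (assumptions of this sketch).
    ov::CompiledModel model = core.compile_model("gpt2.gguf", "LLAMA_CPP");

    ov::InferRequest req_a = model.create_infer_request();
    ov::InferRequest req_b = model.create_infer_request();

    // ... run prompts through both requests here ...

    // With a per-request llama_context, this clears only req_a's KV cache; req_b keeps its history.
    req_a.reset_state();

    // Equivalent per-state route, backed by LlamaCppState::reset() after patch 3:
    for (auto&& state : req_a.query_state()) {
        state.reset();
    }
    return 0;
}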

From 7a7af2220b3fec0930f386f63cfc18737e0cd13b Mon Sep 17 00:00:00 2001
From: Vasily Shamporov <vasily.shamporov@intel.com>
Date: Mon, 22 Apr 2024 12:26:37 +0200
Subject: [PATCH 3/3] Implement separate llama contexts and states per infer request

---
 .../include/compiled_model.hpp                |  1 +
 .../include/infer_request.hpp                 |  5 +++--
 modules/llama_cpp_plugin/include/state.hpp    |  9 ++++----
 .../llama_cpp_plugin/src/compiled_model.cpp   | 13 ++++--------
 .../llama_cpp_plugin/src/infer_request.cpp    | 21 ++++++++++++++-----
 5 files changed, 29 insertions(+), 20 deletions(-)

diff --git a/modules/llama_cpp_plugin/include/compiled_model.hpp b/modules/llama_cpp_plugin/include/compiled_model.hpp
index c59c2288e..90babd314 100644
--- a/modules/llama_cpp_plugin/include/compiled_model.hpp
+++ b/modules/llama_cpp_plugin/include/compiled_model.hpp
@@ -61,6 +61,7 @@ class LlamaCppModel : public ICompiledModel {
 private:
     gguf_context* m_gguf_ctx = nullptr;
     std::string m_gguf_fname;
+    size_t m_num_threads;
 
     llama_model* m_llama_model_ptr = nullptr;
     llama_context* m_llama_ctx = nullptr;
diff --git a/modules/llama_cpp_plugin/include/infer_request.hpp b/modules/llama_cpp_plugin/include/infer_request.hpp
index 8f298ab57..ef8daaaa5 100644
--- a/modules/llama_cpp_plugin/include/infer_request.hpp
+++ b/modules/llama_cpp_plugin/include/infer_request.hpp
@@ -12,8 +12,8 @@ namespace llama_cpp_plugin {
 
 class LlamaCppSyncInferRequest : public ISyncInferRequest {
 public:
-    explicit LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model);
-    virtual ~LlamaCppSyncInferRequest(){};
+    explicit LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model, size_t num_threads);
+    virtual ~LlamaCppSyncInferRequest() override;
 
     virtual void set_tensors_impl(const ov::Output<const ov::Node> port,
                                   const std::vector<ov::SoPtr<ov::ITensor>>& tensors) override;
@@ -24,6 +24,7 @@ class LlamaCppSyncInferRequest : public ISyncInferRequest {
 
 private:
     std::shared_ptr<const LlamaCppModel> m_compiled_model_ptr;
+    llama_context* m_llama_ctx = nullptr;
 };
 
 }  // namespace llama_cpp_plugin
diff --git a/modules/llama_cpp_plugin/include/state.hpp b/modules/llama_cpp_plugin/include/state.hpp
index 229970894..032f29dc6 100644
--- a/modules/llama_cpp_plugin/include/state.hpp
+++ b/modules/llama_cpp_plugin/include/state.hpp
@@ -12,15 +12,16 @@ namespace llama_cpp_plugin {
 class LlamaCppState : public IVariableState {
 public:
     LlamaCppState() = delete;
-    LlamaCppState(const std::shared_ptr<const LlamaCppModel>& model_ptr)
-        : m_model_ptr(model_ptr),
+    LlamaCppState(llama_context* llama_context_ptr)
+        : m_llama_ctx_ptr(llama_context_ptr),
           IVariableState("llama_cpp_state") {}
     void reset() override {
-        llama_kv_cache_clear(m_model_ptr->m_llama_ctx);
+        OPENVINO_ASSERT(m_llama_ctx_ptr != nullptr);
+        llama_kv_cache_clear(m_llama_ctx_ptr);
     }
 
 private:
-    const std::shared_ptr<const LlamaCppModel>& m_model_ptr;
+    llama_context* m_llama_ctx_ptr;
 };
 }  // namespace llama_cpp_plugin
 }  // namespace ov
diff --git a/modules/llama_cpp_plugin/src/compiled_model.cpp b/modules/llama_cpp_plugin/src/compiled_model.cpp
index 9e5fdec1d..b53b11363 100644
--- a/modules/llama_cpp_plugin/src/compiled_model.cpp
+++ b/modules/llama_cpp_plugin/src/compiled_model.cpp
@@ -9,7 +9,6 @@
 #include <openvino/opsets/opset13.hpp>
 #include <openvino/runtime/properties.hpp>
 #include <openvino/util/log.hpp>
-#include <thread>
 
 #include "infer_request.hpp"
 #include "plugin.hpp"
@@ -18,7 +17,6 @@ namespace ov {
 namespace llama_cpp_plugin {
 
 LlamaCppModel::~LlamaCppModel() {
-    llama_free(m_llama_ctx);
     llama_free_model(m_llama_model_ptr);
     llama_backend_free();
 }
@@ -27,15 +25,12 @@ LlamaCppModel::LlamaCppModel(const std::string& gguf_fname,
                              const std::shared_ptr<const IPlugin>& plugin,
                              size_t num_threads)
     : ICompiledModel(nullptr, plugin),
-      m_gguf_fname(gguf_fname) {
+      m_gguf_fname(gguf_fname),
+      m_num_threads(num_threads) {
     OPENVINO_DEBUG << "llama_cpp_plugin: loading llama model directly from GGUF... " << std::endl;
     llama_model_params mparams = llama_model_default_params();
     mparams.n_gpu_layers = 99;
     m_llama_model_ptr = llama_load_model_from_file(gguf_fname.c_str(), mparams);
-    llama_context_params cparams = llama_context_default_params();
-    cparams.n_threads = num_threads ? num_threads : std::thread::hardware_concurrency();
-    cparams.n_ctx = 0;  // this means that the actual n_ctx will be taken equal to the model's train-time value
-    m_llama_ctx = llama_new_context_with_model(m_llama_model_ptr, cparams);
     OPENVINO_DEBUG << "llama_cpp_plugin: llama model loaded successfully from GGUF..." << std::endl;
 
     auto input_ids = std::make_shared<ov::opset13::Parameter>(ov::element::Type_t::i64, ov::PartialShape({-1, -1}));
@@ -87,8 +82,8 @@ ov::Any LlamaCppModel::get_property(const std::string& name) const {
 }
 
 std::shared_ptr<ov::ISyncInferRequest> LlamaCppModel::create_sync_infer_request() const {
-    return std::make_shared<LlamaCppSyncInferRequest>(
-        std::static_pointer_cast<const LlamaCppModel>(shared_from_this()));
+    return std::make_shared<LlamaCppSyncInferRequest>(std::static_pointer_cast<const LlamaCppModel>(shared_from_this()),
+                                                      m_num_threads);
 }
 
 const std::vector<ov::Output<const ov::Node>>& LlamaCppModel::inputs() const {
diff --git a/modules/llama_cpp_plugin/src/infer_request.cpp b/modules/llama_cpp_plugin/src/infer_request.cpp
index 76fba58cd..3eefd56d9 100644
--- a/modules/llama_cpp_plugin/src/infer_request.cpp
+++ b/modules/llama_cpp_plugin/src/infer_request.cpp
@@ -5,6 +5,7 @@
 
 #include <memory>
 #include <openvino/runtime/ivariable_state.hpp>
+#include <thread>
 
 #include "llama.h"
 #include "openvino/runtime/make_tensor.hpp"
@@ -24,9 +25,14 @@ void allocate_tensor_impl(ov::SoPtr<ov::ITensor>& tensor,
     }
 }
 
-LlamaCppSyncInferRequest::LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model)
+LlamaCppSyncInferRequest::LlamaCppSyncInferRequest(const std::shared_ptr<const LlamaCppModel>& compiled_model,
+                                                   size_t num_threads)
     : ov::ISyncInferRequest(compiled_model) {
     OPENVINO_DEBUG << "llama_cpp_plugin: infer request ctor called\n";
+    llama_context_params cparams = llama_context_default_params();
+    cparams.n_threads = num_threads ? num_threads : std::thread::hardware_concurrency();
+    cparams.n_ctx = 0;  // 0 means the actual n_ctx is taken from the model's train-time value
+    m_llama_ctx = llama_new_context_with_model(compiled_model->m_llama_model_ptr, cparams);
     m_compiled_model_ptr = compiled_model;
     for (const auto& input : get_inputs()) {
         allocate_tensor(input, [input](ov::SoPtr<ov::ITensor>& tensor) {
@@ -97,8 +103,7 @@ void LlamaCppSyncInferRequest::infer() {
         }
     }
 
-    llama_context* ctx = m_compiled_model_ptr->m_llama_ctx;
-    int32_t sts = llama_decode(ctx, batch);
+    int32_t sts = llama_decode(m_llama_ctx, batch);
 
     if (sts != 0) {
         OPENVINO_THROW("llama_decode failed with code ", sts);
@@ -112,7 +117,7 @@ void LlamaCppSyncInferRequest::infer() {
     for (size_t batch_idx = 0; batch_idx < batch_size; batch_idx++) {
         for (size_t seq_idx = 0; seq_idx < sequence_length; seq_idx++) {
             size_t pos = batch_idx * sequence_length + seq_idx;
-            float* logits_from_llama = llama_get_logits_ith(ctx, pos);
+            float* logits_from_llama = llama_get_logits_ith(m_llama_ctx, pos);
             std::copy(logits_from_llama, logits_from_llama + n_vocab, output_tensor_data_ptr + pos * n_vocab);
         }
     }
@@ -132,7 +137,13 @@ std::vector<ov::ProfilingInfo> LlamaCppSyncInferRequest::get_profiling_info() co
 
 std::vector<ov::SoPtr<ov::IVariableState>> LlamaCppSyncInferRequest::query_state() const {
     OPENVINO_DEBUG << "llama_cpp_plugin: query_state() called\n";
-    return {std::static_pointer_cast<ov::IVariableState>(std::make_shared<LlamaCppState>(m_compiled_model_ptr))};
+    return {std::static_pointer_cast<ov::IVariableState>(std::make_shared<LlamaCppState>(m_llama_ctx))};
+}
+
+LlamaCppSyncInferRequest::~LlamaCppSyncInferRequest() {
+    if (m_llama_ctx != nullptr) {
+        llama_free(m_llama_ctx);
+    }
 }
 }  // namespace llama_cpp_plugin
 }  // namespace ov
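
For reference, the ownership model that patch 3 moves to can be pictured with llama.cpp calls alone, all of which already appear in the series: one llama_model owned by the compiled model, and one llama_context (hence one KV cache) per infer request. A minimal standalone sketch follows; the model path is a placeholder, and backend init/teardown (handled elsewhere in the plugin) is omitted.

#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model* model = llama_load_model_from_file("model.gguf", mparams);  // placeholder path

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 0;  // 0 -> use the model's train-time context length

    // Two independent contexts over the same weights, analogous to two LlamaCppSyncInferRequest objects.
    llama_context* ctx_a = llama_new_context_with_model(model, cparams);
    llama_context* ctx_b = llama_new_context_with_model(model, cparams);

    // What LlamaCppState::reset() now does: clear one request's KV cache without touching the other's.
    llama_kv_cache_clear(ctx_a);

    llama_free(ctx_a);        // as in ~LlamaCppSyncInferRequest()
    llama_free(ctx_b);
    llama_free_model(model);  // as in ~LlamaCppModel()
    return 0;
}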