@@ -42,16 +42,55 @@ TEST_F(CompiledModelTest, ResetStateGPT2) {
     SetUp();
 
     ov::InferRequest lm_bad = model.create_infer_request();
-    std::vector<float> logits_lennon_bad = infer_and_get_last_logits(lm, GPT2_LENNON_PROMPT_TOKEN_IDS, 0);
+    std::vector<float> logits_lennon_bad = infer_and_get_last_logits(lm_bad, GPT2_LENNON_PROMPT_TOKEN_IDS, 0);
 
     // no reset_state on purpose
 
-    std::vector<float> logits_sun_bad = infer_and_get_last_logits(lm_reset,
+    std::vector<float> logits_sun_bad = infer_and_get_last_logits(lm_bad,
                                                                   GPT2_SUN_PROMPT_TOKEN_IDS,
                                                                   0);  // GPT2_LENNON_PROMPT_TOKEN_IDS.size());
 
-    std::vector<int64_t> out_token_ids_bad = generate_n_tokens_with_positions(lm_reset,
+    std::vector<int64_t> out_token_ids_bad = generate_n_tokens_with_positions(lm_bad,
                                                                               get_token_from_logits(logits_sun_reset),
                                                                               NUM_TOKENS_TO_GENERATE,
                                                                               GPT2_SUN_PROMPT_TOKEN_IDS.size());
+    ASSERT_NE(out_token_ids_bad, out_token_ids_ref);
+}
+
+TEST_F(CompiledModelTest, StatesForDifferentInferRequestsAreIndependentGPT2) {
+    // Take two infer requests, process two different prompts with same position IDs, but for one of them, do
+    // .reset_state() in-between the inferences - check that the state is reset independently.
+
+    // the "new" sequence should have the same number of tokens as the previous one for this to work
+    std::vector<int64_t> MODIFIED_PROMPT_TOKEN_IDS = GPT2_LENNON_PROMPT_TOKEN_IDS;
+    MODIFIED_PROMPT_TOKEN_IDS.push_back(30);  // extra newline
+    ASSERT_EQ(GPT2_SUN_PROMPT_TOKEN_IDS.size(), MODIFIED_PROMPT_TOKEN_IDS.size());
+
+    ov::InferRequest first_infer_request = model.create_infer_request();
+    std::vector<float> logits_first_ref = infer_and_get_last_logits(first_infer_request, GPT2_SUN_PROMPT_TOKEN_IDS, 0);
+
+    ov::InferRequest another_infer_request = model.create_infer_request();
+    std::vector<float> logits_another_ref =
+        infer_and_get_last_logits(another_infer_request, GPT2_SUN_PROMPT_TOKEN_IDS, 0);
+
+    first_infer_request.reset_state();
+
+    std::vector<float> logits_first_new_tokens_old_positions =
+        infer_and_get_last_logits(first_infer_request, MODIFIED_PROMPT_TOKEN_IDS, 0);
+    std::vector<int64_t> out_tokens_first =
+        generate_n_tokens_with_positions(first_infer_request,
+                                         get_token_from_logits(logits_first_new_tokens_old_positions),
+                                         NUM_TOKENS_TO_GENERATE,
+                                         MODIFIED_PROMPT_TOKEN_IDS.size());
+
+    // not resetting another_infer_request state on purpose
+    std::vector<float> logits_another_new_tokens_old_positions =
+        infer_and_get_last_logits(another_infer_request, MODIFIED_PROMPT_TOKEN_IDS, 0);
+    std::vector<int64_t> out_tokens_another =
+        generate_n_tokens_with_positions(another_infer_request,
+                                         get_token_from_logits(logits_another_new_tokens_old_positions),
+                                         NUM_TOKENS_TO_GENERATE,
+                                         MODIFIED_PROMPT_TOKEN_IDS.size());
+
+    EXPECT_NE(out_tokens_another, out_tokens_first);
 }
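
Note: the helpers exercised above are defined elsewhere in this test suite and are not part of the hunk. A minimal sketch of what infer_and_get_last_logits presumably does against a stateful GPT-2 model follows; the tensor names (input_ids, attention_mask, position_ids, logits) and the [1, seq_len, vocab_size] logits layout are assumptions for illustration, not taken from this commit.

#include <openvino/openvino.hpp>

#include <cstdint>
#include <vector>

// Hypothetical sketch, not the repository's actual helper: runs one prompt
// through a stateful model and returns the logits of the last token.
std::vector<float> infer_and_get_last_logits(ov::InferRequest& request,
                                             const std::vector<int64_t>& token_ids,
                                             int64_t position_offset) {
    const size_t seq_len = token_ids.size();

    ov::Tensor input_ids(ov::element::i64, ov::Shape{1, seq_len});
    ov::Tensor attention_mask(ov::element::i64, ov::Shape{1, seq_len});
    ov::Tensor position_ids(ov::element::i64, ov::Shape{1, seq_len});
    for (size_t i = 0; i < seq_len; ++i) {
        input_ids.data<int64_t>()[i] = token_ids[i];
        attention_mask.data<int64_t>()[i] = 1;
        // position IDs continue from position_offset, which is how the tests
        // above feed a "new" prompt with the old positions
        position_ids.data<int64_t>()[i] = position_offset + static_cast<int64_t>(i);
    }
    request.set_tensor("input_ids", input_ids);
    request.set_tensor("attention_mask", attention_mask);
    request.set_tensor("position_ids", position_ids);

    request.infer();  // the KV-cache state stays inside the request afterwards

    // assumed logits layout: [1, seq_len, vocab_size]; return the last row
    ov::Tensor logits = request.get_tensor("logits");
    const size_t vocab_size = logits.get_shape().back();
    const float* last = logits.data<float>() + (seq_len - 1) * vocab_size;
    return {last, last + vocab_size};
}

This is also why reset_state() is the operation under test: it clears every ov::VariableState held by that one InferRequest (the model's KV cache), while other requests created from the same ov::CompiledModel keep their own state untouched.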
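Building on the sketch above, get_token_from_logits and generate_n_tokens_with_positions plausibly look like the following; greedy argmax decoding and the exact loop bounds are again assumptions for illustration.

#include <algorithm>

// greedy decoding: index of the highest logit
int64_t get_token_from_logits(const std::vector<float>& logits) {
    return static_cast<int64_t>(
        std::max_element(logits.begin(), logits.end()) - logits.begin());
}

// Hypothetical sketch: autoregressively extends the sequence one token at a
// time, relying on the KV cache kept inside the stateful InferRequest.
std::vector<int64_t> generate_n_tokens_with_positions(ov::InferRequest& request,
                                                      int64_t first_token,
                                                      size_t num_tokens,
                                                      size_t position_offset) {
    std::vector<int64_t> out_tokens{first_token};
    for (size_t i = 0; out_tokens.size() < num_tokens; ++i) {
        // feed exactly one new token; its position ID continues the prompt's
        std::vector<float> logits = infer_and_get_last_logits(
            request, {out_tokens.back()}, static_cast<int64_t>(position_offset + i));
        out_tokens.push_back(get_token_from_logits(logits));
    }
    return out_tokens;
}

Under these assumptions the two tests read straight through: ResetStateGPT2 checks that skipping reset_state() leaves stale KV-cache entries that change the generated tokens, and the new test checks that resetting one request's state does not disturb another request created from the same compiled model.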