diff --git a/crates/starknet_mempool/src/mempool.rs b/crates/starknet_mempool/src/mempool.rs index 313ccfd9021..0d03eb2d321 100644 --- a/crates/starknet_mempool/src/mempool.rs +++ b/crates/starknet_mempool/src/mempool.rs @@ -410,7 +410,8 @@ impl Mempool { } fn remove_expired_txs(&mut self) { - let removed_txs = self.tx_pool.remove_txs_older_than(self.config.transaction_ttl); + let removed_txs = + self.tx_pool.remove_txs_older_than(self.config.transaction_ttl, &self.state.staged); self.tx_queue.remove_txs(&removed_txs); metric_count_expired_txs(removed_txs.len()); } diff --git a/crates/starknet_mempool/src/mempool_test.rs b/crates/starknet_mempool/src/mempool_test.rs index da169de23ec..16213dccba2 100644 --- a/crates/starknet_mempool/src/mempool_test.rs +++ b/crates/starknet_mempool/src/mempool_test.rs @@ -1118,3 +1118,34 @@ fn test_register_metrics() { assert_eq!(get_metric_counter_txs_dropped(&metrics, drop_reason), 0); } } + +#[rstest] +fn expired_staged_txs_are_not_deleted() { + // Create a mempool with a fake clock. + let fake_clock = Arc::new(FakeClock::default()); + let mut mempool = Mempool::new( + MempoolConfig { transaction_ttl: Duration::from_secs(60), ..Default::default() }, + fake_clock.clone(), + ); + + // Add 2 transactions to the mempool, and stage one. + let staged_tx = + add_tx_input!(tx_hash: 1, address: "0x0", tx_nonce: 0, account_nonce: 0, tip: 100); + let nonstaged_tx = + add_tx_input!(tx_hash: 2, address: "0x0", tx_nonce: 1, account_nonce: 0, tip: 100); + add_tx(&mut mempool, &staged_tx); + add_tx(&mut mempool, &nonstaged_tx); + assert_eq!(mempool.get_txs(1).unwrap(), vec![staged_tx.tx.clone()]); + + // Advance the clock beyond the TTL. + fake_clock.advance(mempool.config.transaction_ttl + Duration::from_secs(5)); + + // Add another transaction to trigger the cleanup, and verify the staged tx is still in the + // mempool. The non-staged tx should be removed. + let another_tx = + add_tx_input!(tx_hash: 3, address: "0x1", tx_nonce: 0, account_nonce: 0, tip: 100); + add_tx(&mut mempool, &another_tx); + let expected_mempool_content = + MempoolTestContentBuilder::new().with_pool([staged_tx.tx, another_tx.tx]).build(); + expected_mempool_content.assert_eq(&mempool.content()); +} diff --git a/crates/starknet_mempool/src/transaction_pool.rs b/crates/starknet_mempool/src/transaction_pool.rs index b8aba5d0736..5ed75127b90 100644 --- a/crates/starknet_mempool/src/transaction_pool.rs +++ b/crates/starknet_mempool/src/transaction_pool.rs @@ -100,8 +100,12 @@ impl TransactionPool { removed_txs.len() } - pub fn remove_txs_older_than(&mut self, duration: Duration) -> Vec { - let removed_txs = self.txs_by_submission_time.remove_txs_older_than(duration); + pub fn remove_txs_older_than( + &mut self, + duration: Duration, + exclude_txs: &HashMap, + ) -> Vec { + let removed_txs = self.txs_by_submission_time.remove_txs_older_than(duration, exclude_txs); self.remove_from_main_mapping(&removed_txs); self.remove_from_account_mapping(&removed_txs); @@ -331,19 +335,31 @@ impl TimedTransactionMap { } /// Removes all transactions that were submitted to the pool before the given duration. - #[allow(dead_code)] - pub fn remove_txs_older_than(&mut self, duration: Duration) -> Vec { + /// Transactions for accounts listed in exclude_txs with nonces lower than the specified nonce + /// are preserved. + pub fn remove_txs_older_than( + &mut self, + duration: Duration, + exclude_txs: &HashMap, + ) -> Vec { let split_off_value = SubmissionID { submission_time: self.clock.now() - duration, tx_hash: Default::default(), }; - let removed_txs: Vec<_> = - self.txs_by_submission_time.split_off(&split_off_value).into_values().collect(); - - for tx in removed_txs.iter() { - self.hash_to_submission_id.remove(&tx.tx_hash).expect( - "Transaction should have a submission ID if it is in the timed transaction map.", - ); + let old_txs = self.txs_by_submission_time.split_off(&split_off_value); + + let mut removed_txs = Vec::new(); + for (submission_id, tx) in old_txs.into_iter() { + if exclude_txs.get(&tx.address).is_some_and(|nonce| tx.nonce < *nonce) { + // The transaction should be preserved. Add it back. + self.txs_by_submission_time.insert(submission_id, tx); + } else { + self.hash_to_submission_id.remove(&tx.tx_hash).expect( + "Transaction should have a submission ID if it is in the timed transaction \ + map.", + ); + removed_txs.push(tx); + } } removed_txs