Simplify removeRecursive

This commit is contained in:
Suhas Daftuar 2025-02-05 09:11:40 -05:00
parent 01d8520038
commit a5a7905d83
2 changed files with 37 additions and 23 deletions

View File

@ -308,35 +308,41 @@ CTxMemPool::txiter CTxMemPool::CalculateDescendants(const CTxMemPoolEntry& entry
return mapTx.iterator_to(entry);
}
void CTxMemPool::removeRecursive(CTxMemPool::txiter to_remove, MemPoolRemovalReason reason)
{
AssertLockHeld(cs);
Assume(!m_have_changeset);
auto descendants = m_txgraph->GetDescendants(*to_remove, TxGraph::Level::MAIN);
for (auto tx: descendants) {
removeUnchecked(mapTx.iterator_to(static_cast<const CTxMemPoolEntry&>(*tx)), reason);
}
}
void CTxMemPool::removeRecursive(const CTransaction &origTx, MemPoolRemovalReason reason)
{
// Remove transaction from memory pool
AssertLockHeld(cs);
Assume(!m_have_changeset);
setEntries txToRemove;
txiter origit = mapTx.find(origTx.GetHash());
if (origit != mapTx.end()) {
txToRemove.insert(origit);
} else {
// When recursively removing but origTx isn't in the mempool
// be sure to remove any children that are in the pool. This can
// happen during chain re-orgs if origTx isn't re-accepted into
// the mempool for any reason.
for (unsigned int i = 0; i < origTx.vout.size(); i++) {
auto it = mapNextTx.find(COutPoint(origTx.GetHash(), i));
if (it == mapNextTx.end())
continue;
txiter nextit = it->second;
assert(nextit != mapTx.end());
txToRemove.insert(nextit);
}
txiter origit = mapTx.find(origTx.GetHash());
if (origit != mapTx.end()) {
removeRecursive(origit, reason);
} else {
// When recursively removing but origTx isn't in the mempool
// be sure to remove any descendants that are in the pool. This can
// happen during chain re-orgs if origTx isn't re-accepted into
// the mempool for any reason.
auto iter = mapNextTx.lower_bound(COutPoint(origTx.GetHash(), 0));
std::vector<const TxGraph::Ref*> to_remove;
while (iter != mapNextTx.end() && iter->first->hash == origTx.GetHash()) {
to_remove.emplace_back(&*(iter->second));
++iter;
}
setEntries setAllRemoves;
for (txiter it : txToRemove) {
CalculateDescendants(it, setAllRemoves);
auto all_removes = m_txgraph->GetDescendantsUnion(to_remove, TxGraph::Level::MAIN);
for (auto ref : all_removes) {
auto tx = mapTx.iterator_to(static_cast<const CTxMemPoolEntry&>(*ref));
removeUnchecked(tx, reason);
}
RemoveStaged(setAllRemoves, reason);
}
}
void CTxMemPool::removeForReorg(CChain& chain, std::function<bool(txiter)> check_final_and_mature)
@ -372,7 +378,7 @@ void CTxMemPool::removeConflicts(const CTransaction &tx)
if (Assume(txConflict.GetHash() != tx.GetHash()))
{
ClearPrioritisation(txConflict.GetHash());
removeRecursive(txConflict, MemPoolRemovalReason::CONFLICT);
removeRecursive(it->second, MemPoolRemovalReason::CONFLICT);
}
}
}

View File

@ -320,6 +320,11 @@ public:
*/
void check(const CCoinsViewCache& active_coins_tip, int64_t spendheight) const EXCLUSIVE_LOCKS_REQUIRED(::cs_main);
/**
* Remove a transaction from the mempool along with any descendants.
* If the transaction is not already in the mempool, find any descendants
* and remove them.
*/
void removeRecursive(const CTransaction& tx, MemPoolRemovalReason reason) EXCLUSIVE_LOCKS_REQUIRED(cs);
/** After reorg, filter the entries that would no longer be valid in the next block, and update
* the entries' cached LockPoints if needed. The mempool does not have any knowledge of
@ -581,6 +586,9 @@ private:
*/
void RemoveStaged(setEntries& stage, MemPoolRemovalReason reason) EXCLUSIVE_LOCKS_REQUIRED(cs);
/* Helper for the public removeRecursive() */
void removeRecursive(txiter to_remove, MemPoolRemovalReason reason) EXCLUSIVE_LOCKS_REQUIRED(cs);
/** Before calling removeUnchecked for a given transaction,
* UpdateForRemoveFromMempool must be called on the entire (dependent) set
* of transactions being removed at the same time. We use each