diff --git a/wtxmgr/kahnsort.go b/wtxmgr/kahnsort.go index f6a85b6..1225e4a 100644 --- a/wtxmgr/kahnsort.go +++ b/wtxmgr/kahnsort.go @@ -4,29 +4,33 @@ package wtxmgr -import "github.com/btcsuite/btcd/chaincfg/chainhash" +import ( + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" +) type graphNode struct { - value *TxRecord + value *wire.MsgTx outEdges []*chainhash.Hash inDegree int } type hashGraph map[chainhash.Hash]graphNode -func makeGraph(set map[chainhash.Hash]*TxRecord) hashGraph { +func makeGraph(set map[chainhash.Hash]*wire.MsgTx) hashGraph { graph := make(hashGraph) - for _, rec := range set { - // Add a node for every transaction record. The output edges - // and input degree are set by iterating over each record's - // inputs below. - if _, ok := graph[rec.Hash]; !ok { - graph[rec.Hash] = graphNode{value: rec} + for _, tx := range set { + // Add a node for every transaction. The output edges and input + // degree are set by iterating over each transaction's inputs + // below. + txHash := tx.TxHash() + if _, ok := graph[txHash]; !ok { + graph[txHash] = graphNode{value: tx} } inputLoop: - for _, input := range rec.MsgTx.TxIn { + for _, input := range tx.TxIn { // Transaction inputs that reference transactions not // included in the set do not create any (local) graph // edges. @@ -44,20 +48,20 @@ func makeGraph(set map[chainhash.Hash]*TxRecord) hashGraph { } // Mark a directed edge from the previous transaction - // hash to this transaction record and increase the - // input degree for this record's node. - inputRec := inputNode.value - if inputRec == nil { - inputRec = set[input.PreviousOutPoint.Hash] + // hash to this transaction and increase the input + // degree for this transaction's node. + inputTx := inputNode.value + if inputTx == nil { + inputTx = set[input.PreviousOutPoint.Hash] } graph[input.PreviousOutPoint.Hash] = graphNode{ - value: inputRec, - outEdges: append(inputNode.outEdges, &rec.Hash), + value: inputTx, + outEdges: append(inputNode.outEdges, &txHash), inDegree: inputNode.inDegree, } - node := graph[rec.Hash] - graph[rec.Hash] = graphNode{ - value: rec, + node := graph[txHash] + graph[txHash] = graphNode{ + value: tx, outEdges: node.outEdges, inDegree: node.inDegree + 1, } @@ -69,8 +73,8 @@ func makeGraph(set map[chainhash.Hash]*TxRecord) hashGraph { // graphRoots returns the roots of the graph. That is, it returns the node's // values for all nodes which contain an input degree of 0. -func graphRoots(graph hashGraph) []*TxRecord { - roots := make([]*TxRecord, 0, len(graph)) +func graphRoots(graph hashGraph) []*wire.MsgTx { + roots := make([]*wire.MsgTx, 0, len(graph)) for _, node := range graph { if node.inDegree == 0 { roots = append(roots, node.value) @@ -79,9 +83,9 @@ func graphRoots(graph hashGraph) []*TxRecord { return roots } -// dependencySort topologically sorts a set of transaction records by their -// dependency order. It is implemented using Kahn's algorithm. -func dependencySort(txs map[chainhash.Hash]*TxRecord) []*TxRecord { +// dependencySort topologically sorts a set of transactions by their dependency +// order. It is implemented using Kahn's algorithm. +func dependencySort(txs map[chainhash.Hash]*wire.MsgTx) []*wire.MsgTx { graph := makeGraph(txs) s := graphRoots(graph) @@ -91,13 +95,13 @@ func dependencySort(txs map[chainhash.Hash]*TxRecord) []*TxRecord { return s } - sorted := make([]*TxRecord, 0, len(txs)) + sorted := make([]*wire.MsgTx, 0, len(txs)) for len(s) != 0 { - rec := s[0] + tx := s[0] s = s[1:] - sorted = append(sorted, rec) + sorted = append(sorted, tx) - n := graph[rec.Hash] + n := graph[tx.TxHash()] for _, mHash := range n.outEdges { m := graph[*mHash] if m.inDegree != 0 { diff --git a/wtxmgr/unconfirmed.go b/wtxmgr/unconfirmed.go index 37bf7e1..3e76387 100644 --- a/wtxmgr/unconfirmed.go +++ b/wtxmgr/unconfirmed.go @@ -164,12 +164,12 @@ func (s *Store) UnminedTxs(ns walletdb.ReadBucket) ([]*wire.MsgTx, error) { return nil, err } - recs := dependencySort(recSet) - txs := make([]*wire.MsgTx, 0, len(recs)) - for _, rec := range recs { - txs = append(txs, &rec.MsgTx) + txSet := make(map[chainhash.Hash]*wire.MsgTx, len(recSet)) + for txHash, txRec := range recSet { + txSet[txHash] = &txRec.MsgTx } - return txs, nil + + return dependencySort(txSet), nil } func (s *Store) unminedTxRecords(ns walletdb.ReadBucket) (map[chainhash.Hash]*TxRecord, error) {