Merge pull request #590 from wpaulino/dependency-sort
wtxmgr: export dependencySort and add unit tests
This commit is contained in:
commit
b8074786d7
3 changed files with 186 additions and 34 deletions
|
@ -4,29 +4,33 @@
|
|||
|
||||
package wtxmgr
|
||||
|
||||
import "github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
import (
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
)
|
||||
|
||||
type graphNode struct {
|
||||
value *TxRecord
|
||||
value *wire.MsgTx
|
||||
outEdges []*chainhash.Hash
|
||||
inDegree int
|
||||
}
|
||||
|
||||
type hashGraph map[chainhash.Hash]graphNode
|
||||
|
||||
func makeGraph(set map[chainhash.Hash]*TxRecord) hashGraph {
|
||||
func makeGraph(set map[chainhash.Hash]*wire.MsgTx) hashGraph {
|
||||
graph := make(hashGraph)
|
||||
|
||||
for _, rec := range set {
|
||||
// Add a node for every transaction record. The output edges
|
||||
// and input degree are set by iterating over each record's
|
||||
// inputs below.
|
||||
if _, ok := graph[rec.Hash]; !ok {
|
||||
graph[rec.Hash] = graphNode{value: rec}
|
||||
for _, tx := range set {
|
||||
// Add a node for every transaction. The output edges and input
|
||||
// degree are set by iterating over each transaction's inputs
|
||||
// below.
|
||||
txHash := tx.TxHash()
|
||||
if _, ok := graph[txHash]; !ok {
|
||||
graph[txHash] = graphNode{value: tx}
|
||||
}
|
||||
|
||||
inputLoop:
|
||||
for _, input := range rec.MsgTx.TxIn {
|
||||
for _, input := range tx.TxIn {
|
||||
// Transaction inputs that reference transactions not
|
||||
// included in the set do not create any (local) graph
|
||||
// edges.
|
||||
|
@ -44,20 +48,20 @@ func makeGraph(set map[chainhash.Hash]*TxRecord) hashGraph {
|
|||
}
|
||||
|
||||
// Mark a directed edge from the previous transaction
|
||||
// hash to this transaction record and increase the
|
||||
// input degree for this record's node.
|
||||
inputRec := inputNode.value
|
||||
if inputRec == nil {
|
||||
inputRec = set[input.PreviousOutPoint.Hash]
|
||||
// hash to this transaction and increase the input
|
||||
// degree for this transaction's node.
|
||||
inputTx := inputNode.value
|
||||
if inputTx == nil {
|
||||
inputTx = set[input.PreviousOutPoint.Hash]
|
||||
}
|
||||
graph[input.PreviousOutPoint.Hash] = graphNode{
|
||||
value: inputRec,
|
||||
outEdges: append(inputNode.outEdges, &rec.Hash),
|
||||
value: inputTx,
|
||||
outEdges: append(inputNode.outEdges, &txHash),
|
||||
inDegree: inputNode.inDegree,
|
||||
}
|
||||
node := graph[rec.Hash]
|
||||
graph[rec.Hash] = graphNode{
|
||||
value: rec,
|
||||
node := graph[txHash]
|
||||
graph[txHash] = graphNode{
|
||||
value: tx,
|
||||
outEdges: node.outEdges,
|
||||
inDegree: node.inDegree + 1,
|
||||
}
|
||||
|
@ -69,8 +73,8 @@ func makeGraph(set map[chainhash.Hash]*TxRecord) hashGraph {
|
|||
|
||||
// graphRoots returns the roots of the graph. That is, it returns the node's
|
||||
// values for all nodes which contain an input degree of 0.
|
||||
func graphRoots(graph hashGraph) []*TxRecord {
|
||||
roots := make([]*TxRecord, 0, len(graph))
|
||||
func graphRoots(graph hashGraph) []*wire.MsgTx {
|
||||
roots := make([]*wire.MsgTx, 0, len(graph))
|
||||
for _, node := range graph {
|
||||
if node.inDegree == 0 {
|
||||
roots = append(roots, node.value)
|
||||
|
@ -79,9 +83,9 @@ func graphRoots(graph hashGraph) []*TxRecord {
|
|||
return roots
|
||||
}
|
||||
|
||||
// dependencySort topologically sorts a set of transaction records by their
|
||||
// dependency order. It is implemented using Kahn's algorithm.
|
||||
func dependencySort(txs map[chainhash.Hash]*TxRecord) []*TxRecord {
|
||||
// DependencySort topologically sorts a set of transactions by their dependency
|
||||
// order. It is implemented using Kahn's algorithm.
|
||||
func DependencySort(txs map[chainhash.Hash]*wire.MsgTx) []*wire.MsgTx {
|
||||
graph := makeGraph(txs)
|
||||
s := graphRoots(graph)
|
||||
|
||||
|
@ -91,13 +95,13 @@ func dependencySort(txs map[chainhash.Hash]*TxRecord) []*TxRecord {
|
|||
return s
|
||||
}
|
||||
|
||||
sorted := make([]*TxRecord, 0, len(txs))
|
||||
sorted := make([]*wire.MsgTx, 0, len(txs))
|
||||
for len(s) != 0 {
|
||||
rec := s[0]
|
||||
tx := s[0]
|
||||
s = s[1:]
|
||||
sorted = append(sorted, rec)
|
||||
sorted = append(sorted, tx)
|
||||
|
||||
n := graph[rec.Hash]
|
||||
n := graph[tx.TxHash()]
|
||||
for _, mHash := range n.outEdges {
|
||||
m := graph[*mHash]
|
||||
if m.inDegree != 0 {
|
||||
|
|
148
wtxmgr/kahnsort_test.go
Normal file
148
wtxmgr/kahnsort_test.go
Normal file
|
@ -0,0 +1,148 @@
|
|||
package wtxmgr_test
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/btcsuite/btcwallet/wtxmgr"
|
||||
)
|
||||
|
||||
// createTx is a helper method to create random transactions that spend
|
||||
// particular inputs.
|
||||
func createTx(t *testing.T, numOutputs int, inputs ...wire.OutPoint) *wire.MsgTx {
|
||||
t.Helper()
|
||||
|
||||
tx := wire.NewMsgTx(1)
|
||||
if len(inputs) == 0 {
|
||||
tx.AddTxIn(&wire.TxIn{})
|
||||
} else {
|
||||
for _, input := range inputs {
|
||||
tx.AddTxIn(&wire.TxIn{PreviousOutPoint: input})
|
||||
}
|
||||
}
|
||||
for i := 0; i < numOutputs; i++ {
|
||||
var pkScript [32]byte
|
||||
if _, err := rand.Read(pkScript[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tx.AddTxOut(&wire.TxOut{
|
||||
Value: rand.Int63(),
|
||||
PkScript: pkScript[:],
|
||||
})
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
// getOutPoint returns the outpoint for the output with the given index in the
|
||||
// transaction.
|
||||
func getOutPoint(tx *wire.MsgTx, index uint32) wire.OutPoint {
|
||||
return wire.OutPoint{Hash: tx.TxHash(), Index: index}
|
||||
}
|
||||
|
||||
// TestDependencySort ensures that transactions are topologically sorted by
|
||||
// their dependency order under multiple scenarios. A transaction (a) can depend
|
||||
// on another (b) as long as (a) spends an output created in (b).
|
||||
func TestDependencySort(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
// setup is in charge of setting the dependency graph and
|
||||
// returning the transactions in their expected sorted order.
|
||||
setup func(t *testing.T) []*wire.MsgTx
|
||||
}{
|
||||
{
|
||||
name: "single dependency chain",
|
||||
setup: func(t *testing.T) []*wire.MsgTx {
|
||||
// a -> b -> c
|
||||
a := createTx(t, 1)
|
||||
b := createTx(t, 1, getOutPoint(a, 0))
|
||||
c := createTx(t, 1, getOutPoint(b, 0))
|
||||
return []*wire.MsgTx{a, b, c}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "double dependency chain",
|
||||
setup: func(t *testing.T) []*wire.MsgTx {
|
||||
// a -> b
|
||||
// a -> c
|
||||
// c -> d
|
||||
// d -> b
|
||||
a := createTx(t, 2)
|
||||
c := createTx(t, 1, getOutPoint(a, 1))
|
||||
d := createTx(t, 1, getOutPoint(c, 0))
|
||||
b := createTx(t, 1, getOutPoint(a, 0), getOutPoint(d, 0))
|
||||
return []*wire.MsgTx{a, c, d, b}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multi dependency chain",
|
||||
setup: func(t *testing.T) []*wire.MsgTx {
|
||||
// a -> e
|
||||
// a -> c
|
||||
// e -> c
|
||||
// c -> g
|
||||
// a -> b
|
||||
// g -> b
|
||||
// e -> f
|
||||
// c -> f
|
||||
// g -> f
|
||||
// b -> f
|
||||
// b -> d
|
||||
// f -> d
|
||||
a := createTx(t, 3)
|
||||
|
||||
a0 := getOutPoint(a, 0)
|
||||
e := createTx(t, 2, a0)
|
||||
|
||||
a1 := getOutPoint(a, 1)
|
||||
e0 := getOutPoint(e, 0)
|
||||
c := createTx(t, 2, a1, e0)
|
||||
|
||||
c0 := getOutPoint(c, 0)
|
||||
g := createTx(t, 2, c0)
|
||||
|
||||
a2 := getOutPoint(a, 2)
|
||||
g0 := getOutPoint(g, 0)
|
||||
b := createTx(t, 1, a2, g0)
|
||||
|
||||
e1 := getOutPoint(e, 1)
|
||||
c1 := getOutPoint(c, 1)
|
||||
g1 := getOutPoint(g, 1)
|
||||
b0 := getOutPoint(b, 0)
|
||||
f := createTx(t, 1, e1, c1, g1, b0)
|
||||
|
||||
f0 := getOutPoint(f, 0)
|
||||
d := createTx(t, 1, b0, f0)
|
||||
|
||||
return []*wire.MsgTx{a, e, c, g, b, f, d}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
exp := test.setup(t)
|
||||
|
||||
txSet := make(map[chainhash.Hash]*wire.MsgTx, len(exp))
|
||||
for _, tx := range exp {
|
||||
txSet[tx.TxHash()] = tx
|
||||
}
|
||||
|
||||
sortedTxs := wtxmgr.DependencySort(txSet)
|
||||
|
||||
if !reflect.DeepEqual(sortedTxs, exp) {
|
||||
t.Fatalf("expected %v, got %v", exp, sortedTxs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -164,12 +164,12 @@ func (s *Store) UnminedTxs(ns walletdb.ReadBucket) ([]*wire.MsgTx, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
recs := dependencySort(recSet)
|
||||
txs := make([]*wire.MsgTx, 0, len(recs))
|
||||
for _, rec := range recs {
|
||||
txs = append(txs, &rec.MsgTx)
|
||||
txSet := make(map[chainhash.Hash]*wire.MsgTx, len(recSet))
|
||||
for txHash, txRec := range recSet {
|
||||
txSet[txHash] = &txRec.MsgTx
|
||||
}
|
||||
return txs, nil
|
||||
|
||||
return DependencySort(txSet), nil
|
||||
}
|
||||
|
||||
func (s *Store) unminedTxRecords(ns walletdb.ReadBucket) (map[chainhash.Hash]*TxRecord, error) {
|
||||
|
|
Loading…
Add table
Reference in a new issue