From c7d523f83ccb19ea4766e6523685105bd6c76e70 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Tue, 14 May 2019 22:46:59 -0700 Subject: [PATCH] btcec/pubkey: optimize decompressPoint using fieldVals This commit optimizes the decompressPoint subroutine, used in extracting compressed pubkeys and performing pubkey recovery. We do so by replacing the use of big.Int.Exp with with square-and-multiply exponentiation of btcec's more optimized fieldVals, reducing the overall latency and memory requirements of decompressPoint. Instead of operating on bits of Q = (P+1)/4, the exponentiation applies the square-and-multiply operations on full bytes of Q. Compared to the original speedup. Compared the bit-wise version, the improvement is roughly 10%. A new pair fieldVal methods called Sqrt and SqrtVal are added, which applies the square-and-multiply exponentiation using precomputed byte-slice of the value Q. Comparison against big.Int sqrt and SAM sqrt over bytes of Q: benchmark old ns/op new ns/op delta BenchmarkParseCompressedPubKey-8 35545 23119 -34.96% benchmark old allocs new allocs delta BenchmarkParseCompressedPubKey-8 35 6 -82.86% benchmark old bytes new bytes delta BenchmarkParseCompressedPubKey-8 2777 256 -90.78% --- btcec/btcec.go | 10 ++- btcec/field.go | 129 +++++++++++++++++++++++++++++++++++++++ btcec/field_test.go | 145 ++++++++++++++++++++++++++++++++++++++++++++ btcec/pubkey.go | 36 +++++------ 4 files changed, 301 insertions(+), 19 deletions(-) diff --git a/btcec/btcec.go b/btcec/btcec.go index 5e7ce875..2cb1d929 100644 --- a/btcec/btcec.go +++ b/btcec/btcec.go @@ -36,10 +36,17 @@ var ( // interface from crypto/elliptic. type KoblitzCurve struct { *elliptic.CurveParams - q *big.Int + + // q is the value (P+1)/4 used to compute the square root of field + // elements. + q *big.Int + H int // cofactor of the curve. halfOrder *big.Int // half the order N + // fieldB is the constant B of the curve as a fieldVal. + fieldB *fieldVal + // byteSize is simply the bit size / 8 and is provided for convenience // since it is calculated repeatedly. byteSize int @@ -917,6 +924,7 @@ func initS256() { big.NewInt(1)), big.NewInt(4)) secp256k1.H = 1 secp256k1.halfOrder = new(big.Int).Rsh(secp256k1.N, 1) + secp256k1.fieldB = new(fieldVal).SetByteSlice(secp256k1.B.Bytes()) // Provided for convenience since this gets computed repeatedly. secp256k1.byteSize = secp256k1.BitSize / 8 diff --git a/btcec/field.go b/btcec/field.go index 0f2be74c..c2bb84b3 100644 --- a/btcec/field.go +++ b/btcec/field.go @@ -102,6 +102,20 @@ const ( fieldPrimeWordOne = 0x3ffffbf ) +var ( + // fieldQBytes is the value Q = (P+1)/4 for the secp256k1 prime P. This + // value is used to efficiently compute the square root of values in the + // field via exponentiation. The value of Q in hex is: + // + // Q = 3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c + fieldQBytes = []byte{ + 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xbf, 0xff, 0xff, 0x0c, + } +) + // fieldVal implements optimized fixed-precision arithmetic over the // secp256k1 finite field. This means all arithmetic is performed modulo // 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f. It @@ -1221,3 +1235,118 @@ func (f *fieldVal) Inverse() *fieldVal { f.Square().Square().Square().Square().Square() // f = a^(2^256 - 4294968320) return f.Mul(&a45) // f = a^(2^256 - 4294968275) = a^(p-2) } + +// SqrtVal computes the square root of x modulo the curve's prime, and stores +// the result in f. The square root is computed via exponentiation of x by the +// value Q = (P+1)/4 using the curve's precomputed big-endian representation of +// the Q. This method uses a modified version of square-and-multiply +// exponentiation over secp256k1 fieldVals to operate on bytes instead of bits, +// which offers better performance over both big.Int exponentiation and bit-wise +// square-and-multiply. +// +// NOTE: This method only works when P is intended to be the secp256k1 prime and +// is not constant time. The returned value is of magnitude 1, but is +// denormalized. +func (f *fieldVal) SqrtVal(x *fieldVal) *fieldVal { + // The following computation iteratively computes x^((P+1)/4) = x^Q + // using the recursive, piece-wise definition: + // + // x^n = (x^2)^(n/2) mod P if n is even + // x^n = x(x^2)^(n-1/2) mod P if n is odd + // + // Given n in its big-endian representation b_k, ..., b_0, x^n can be + // computed by defining the sequence r_k+1, ..., r_0, where: + // + // r_k+1 = 1 + // r_i = (r_i+1)^2 * x^b_i for i = k, ..., 0 + // + // The final value r_0 = x^n. + // + // See https://en.wikipedia.org/wiki/Exponentiation_by_squaring for more + // details. + // + // This can be further optimized, by observing that the value of Q in + // secp256k1 has the value: + // + // Q = 3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c + // + // We can unroll the typical bit-wise interpretation of the + // exponentiation algorithm above to instead operate on bytes. + // This reduces the number of comparisons by an order of magnitude, + // reducing the overhead of failed branch predictions and additional + // comparisons in this method. + // + // Since there there are only 4 unique bytes of Q, this keeps the jump + // table small without the need to handle all possible 8-bit values. + // Further, we observe that 29 of the 32 bytes are 0xff; making the + // first case handle 0xff therefore optimizes the hot path. + f.SetInt(1) + for _, b := range fieldQBytes { + switch b { + + // Most common case, where all 8 bits are set. + case 0xff: + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + + // First byte of Q (0x3f), where all but the top two bits are + // set. Note that this case only applies six operations, since + // the highest bit of Q resides in bit six of the first byte. We + // ignore the first two bits, since squaring for these bits will + // result in an invalid result. We forgo squaring f before the + // first multiply, since 1^2 = 1. + case 0x3f: + f.Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + + // Byte 28 of Q (0xbf), where only bit 7 is unset. + case 0xbf: + f.Square().Mul(x) + f.Square() + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + + // Byte 31 of Q (0x0c), where only bits 3 and 4 are set. + default: + f.Square() + f.Square() + f.Square() + f.Square() + f.Square().Mul(x) + f.Square().Mul(x) + f.Square() + f.Square() + } + } + + return f +} + +// Sqrt computes the square root of f modulo the curve's prime, and stores the +// result in f. The square root is computed via exponentiation of x by the value +// Q = (P+1)/4 using the curve's precomputed big-endian representation of the Q. +// This method uses a modified version of square-and-multiply exponentiation +// over secp256k1 fieldVals to operate on bytes instead of bits, which offers +// better performance over both big.Int exponentiation and bit-wise +// square-and-multiply. +// +// NOTE: This method only works when P is intended to be the secp256k1 prime and +// is not constant time. The returned value is of magnitude 1, but is +// denormalized. +func (f *fieldVal) Sqrt() *fieldVal { + return f.SqrtVal(f) +} diff --git a/btcec/field_test.go b/btcec/field_test.go index dcfb7049..27b9730f 100644 --- a/btcec/field_test.go +++ b/btcec/field_test.go @@ -6,6 +6,8 @@ package btcec import ( + "crypto/rand" + "fmt" "reflect" "testing" ) @@ -820,3 +822,146 @@ func TestInverse(t *testing.T) { } } } + +// randFieldVal returns a random, normalized element in the field. +func randFieldVal(t *testing.T) fieldVal { + var b [32]byte + if _, err := rand.Read(b[:]); err != nil { + t.Fatalf("unable to create random element: %v", err) + } + + var x fieldVal + return *x.SetBytes(&b).Normalize() +} + +type sqrtTest struct { + name string + in string + expected string +} + +// TestSqrt asserts that a fieldVal properly computes the square root modulo the +// sep256k1 prime. +func TestSqrt(t *testing.T) { + var tests []sqrtTest + + // No valid root exists for the negative of a square. + for i := uint(9); i > 0; i-- { + var ( + x fieldVal + s fieldVal // x^2 mod p + n fieldVal // -x^2 mod p + ) + + x.SetInt(i) + s.SquareVal(&x).Normalize() + n.NegateVal(&s, 1).Normalize() + + tests = append(tests, sqrtTest{ + name: fmt.Sprintf("-%d", i), + in: fmt.Sprintf("%x", *n.Bytes()), + }) + } + + // A root should exist for true squares. + for i := uint(0); i < 10; i++ { + var ( + x fieldVal + s fieldVal // x^2 mod p + ) + + x.SetInt(i) + s.SquareVal(&x).Normalize() + + tests = append(tests, sqrtTest{ + name: fmt.Sprintf("%d", i), + in: fmt.Sprintf("%x", *s.Bytes()), + expected: fmt.Sprintf("%x", *x.Bytes()), + }) + } + + // Compute a non-square element, by negating if it has a root. + ns := randFieldVal(t) + if new(fieldVal).SqrtVal(&ns).Square().Equals(&ns) { + ns.Negate(1).Normalize() + } + + // For large random field values, test that: + // 1) its square has a valid root. + // 2) the negative of its square has no root. + // 3) the product of its square with a non-square has no root. + for i := 0; i < 10; i++ { + var ( + x fieldVal + s fieldVal // x^2 mod p + n fieldVal // -x^2 mod p + m fieldVal // ns*x^2 mod p + ) + + x = randFieldVal(t) + s.SquareVal(&x).Normalize() + n.NegateVal(&s, 1).Normalize() + m.Mul2(&s, &ns).Normalize() + + // A root should exist for true squares. + tests = append(tests, sqrtTest{ + name: fmt.Sprintf("%x", *s.Bytes()), + in: fmt.Sprintf("%x", *s.Bytes()), + expected: fmt.Sprintf("%x", *x.Bytes()), + }) + + // No valid root exists for the negative of a square. + tests = append(tests, sqrtTest{ + name: fmt.Sprintf("-%x", *s.Bytes()), + in: fmt.Sprintf("%x", *n.Bytes()), + }) + + // No root should be computed for product of a square and + // non-square. + tests = append(tests, sqrtTest{ + name: fmt.Sprintf("ns*%x", *s.Bytes()), + in: fmt.Sprintf("%x", *m.Bytes()), + }) + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testSqrt(t, test) + }) + } +} + +func testSqrt(t *testing.T, test sqrtTest) { + var ( + f fieldVal + root fieldVal + rootNeg fieldVal + ) + + f.SetHex(test.in).Normalize() + + // Compute sqrt(f) and its negative. + root.SqrtVal(&f).Normalize() + rootNeg.NegateVal(&root, 1).Normalize() + + switch { + + // If we expect a square root, verify that either the computed square + // root is +/- the expected value. + case len(test.expected) > 0: + var expected fieldVal + expected.SetHex(test.expected).Normalize() + if !root.Equals(&expected) && !rootNeg.Equals(&expected) { + t.Fatalf("fieldVal.Sqrt incorrect root\n"+ + "got: %v\ngot_neg: %v\nwant: %v", + root, rootNeg, expected) + } + + // Otherwise, we expect this input not to have a square root. + default: + if root.Square().Equals(&f) || rootNeg.Square().Equals(&f) { + t.Fatalf("fieldVal.Sqrt root should not exist\n"+ + "got: %v\ngot_neg: %v", root, rootNeg) + } + } +} diff --git a/btcec/pubkey.go b/btcec/pubkey.go index a6a492e7..c72f8705 100644 --- a/btcec/pubkey.go +++ b/btcec/pubkey.go @@ -22,41 +22,41 @@ func isOdd(a *big.Int) bool { return a.Bit(0) == 1 } -// decompressPoint decompresses a point on the given curve given the X point and +// decompressPoint decompresses a point on the secp256k1 curve given the X point and // the solution to use. -func decompressPoint(curve *KoblitzCurve, x *big.Int, ybit bool) (*big.Int, error) { - // TODO: This will probably only work for secp256k1 due to - // optimizations. +func decompressPoint(curve *KoblitzCurve, bigX *big.Int, ybit bool) (*big.Int, error) { + var x fieldVal + x.SetByteSlice(bigX.Bytes()) - // Y = +-sqrt(x^3 + B) - x3 := new(big.Int).Mul(x, x) - x3.Mul(x3, x) - x3.Add(x3, curve.Params().B) - x3.Mod(x3, curve.Params().P) + // Compute x^3 + B mod p. + var x3 fieldVal + x3.SquareVal(&x).Mul(&x) + x3.Add(curve.fieldB).Normalize() // Now calculate sqrt mod p of x^3 + B // This code used to do a full sqrt based on tonelli/shanks, // but this was replaced by the algorithms referenced in // https://bitcointalk.org/index.php?topic=162805.msg1712294#msg1712294 - y := new(big.Int).Exp(x3, curve.QPlus1Div4(), curve.Params().P) - - if ybit != isOdd(y) { - y.Sub(curve.Params().P, y) + var y fieldVal + y.SqrtVal(&x3) + if ybit != y.IsOdd() { + y.Negate(1) } + y.Normalize() // Check that y is a square root of x^3 + B. - y2 := new(big.Int).Mul(y, y) - y2.Mod(y2, curve.Params().P) - if y2.Cmp(x3) != 0 { + var y2 fieldVal + y2.SquareVal(&y).Normalize() + if !y2.Equals(&x3) { return nil, fmt.Errorf("invalid square root") } // Verify that y-coord has expected parity. - if ybit != isOdd(y) { + if ybit != y.IsOdd() { return nil, fmt.Errorf("ybit doesn't match oddness") } - return y, nil + return new(big.Int).SetBytes(y.Bytes()[:]), nil } const (