diff --git a/btcec/btcec.go b/btcec/btcec.go index 5e7ce875..2cb1d929 100644 --- a/btcec/btcec.go +++ b/btcec/btcec.go @@ -36,10 +36,17 @@ var ( // interface from crypto/elliptic. type KoblitzCurve struct { *elliptic.CurveParams - q *big.Int + + // q is the value (P+1)/4 used to compute the square root of field + // elements. + q *big.Int + H int // cofactor of the curve. halfOrder *big.Int // half the order N + // fieldB is the constant B of the curve as a fieldVal. + fieldB *fieldVal + // byteSize is simply the bit size / 8 and is provided for convenience // since it is calculated repeatedly. byteSize int @@ -917,6 +924,7 @@ func initS256() { big.NewInt(1)), big.NewInt(4)) secp256k1.H = 1 secp256k1.halfOrder = new(big.Int).Rsh(secp256k1.N, 1) + secp256k1.fieldB = new(fieldVal).SetByteSlice(secp256k1.B.Bytes()) // Provided for convenience since this gets computed repeatedly. secp256k1.byteSize = secp256k1.BitSize / 8 diff --git a/btcec/field.go b/btcec/field.go index 0f2be74c..c2bb84b3 100644 --- a/btcec/field.go +++ b/btcec/field.go @@ -102,6 +102,20 @@ const ( fieldPrimeWordOne = 0x3ffffbf ) +var ( + // fieldQBytes is the value Q = (P+1)/4 for the secp256k1 prime P. This + // value is used to efficiently compute the square root of values in the + // field via exponentiation. The value of Q in hex is: + // + // Q = 3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c + fieldQBytes = []byte{ + 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xbf, 0xff, 0xff, 0x0c, + } +) + // fieldVal implements optimized fixed-precision arithmetic over the // secp256k1 finite field. This means all arithmetic is performed modulo // 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f. It @@ -1221,3 +1235,118 @@ func (f *fieldVal) Inverse() *fieldVal { f.Square().Square().Square().Square().Square() // f = a^(2^256 - 4294968320) return f.Mul(&a45) // f = a^(2^256 - 4294968275) = a^(p-2) } + +// SqrtVal computes the square root of x modulo the curve's prime, and stores +// the result in f. The square root is computed via exponentiation of x by the +// value Q = (P+1)/4 using the curve's precomputed big-endian representation of +// the Q. This method uses a modified version of square-and-multiply +// exponentiation over secp256k1 fieldVals to operate on bytes instead of bits, +// which offers better performance over both big.Int exponentiation and bit-wise +// square-and-multiply. +// +// NOTE: This method only works when P is intended to be the secp256k1 prime and +// is not constant time. The returned value is of magnitude 1, but is +// denormalized. +func (f *fieldVal) SqrtVal(x *fieldVal) *fieldVal { + // The following computation iteratively computes x^((P+1)/4) = x^Q + // using the recursive, piece-wise definition: + // + // x^n = (x^2)^(n/2) mod P if n is even + // x^n = x(x^2)^(n-1/2) mod P if n is odd + // + // Given n in its big-endian representation b_k, ..., b_0, x^n can be + // computed by defining the sequence r_k+1, ..., r_0, where: + // + // r_k+1 = 1 + // r_i = (r_i+1)^2 * x^b_i for i = k, ..., 0 + // + // The final value r_0 = x^n. + // + // See https://en.wikipedia.org/wiki/Exponentiation_by_squaring for more + // details. + // + // This can be further optimized, by observing that the value of Q in + // secp256k1 has the value: + // + // Q = 3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c + // + // We can unroll the typical bit-wise interpretation of the + // exponentiation algorithm above to instead operate on bytes. + // This reduces the number of comparisons by an order of magnitude, + // reducing the overhead of failed branch predictions and additional + // comparisons in this method. + // + // Since there there are only 4 unique bytes of Q, this keeps the jump + // table small without the need to handle all possible 8-bit values. + // Further, we observe that 29 of the 32 bytes are 0xff; making the + // first case handle 0xff therefore optimizes the hot path. + f.SetInt(1) + for _, b := range fieldQBytes { + switch b { + + // Most common case, where all 8 bits are set. + case 0xff: + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + + // First byte of Q (0x3f), where all but the top two bits are + // set. Note that this case only applies six operations, since + // the highest bit of Q resides in bit six of the first byte. We + // ignore the first two bits, since squaring for these bits will + // result in an invalid result. We forgo squaring f before the + // first multiply, since 1^2 = 1. + case 0x3f: + f.Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + + // Byte 28 of Q (0xbf), where only bit 7 is unset. + case 0xbf: + f.Square().Mul(x) + f.Square() + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + f.Square().Mul(x) + + // Byte 31 of Q (0x0c), where only bits 3 and 4 are set. + default: + f.Square() + f.Square() + f.Square() + f.Square() + f.Square().Mul(x) + f.Square().Mul(x) + f.Square() + f.Square() + } + } + + return f +} + +// Sqrt computes the square root of f modulo the curve's prime, and stores the +// result in f. The square root is computed via exponentiation of x by the value +// Q = (P+1)/4 using the curve's precomputed big-endian representation of the Q. +// This method uses a modified version of square-and-multiply exponentiation +// over secp256k1 fieldVals to operate on bytes instead of bits, which offers +// better performance over both big.Int exponentiation and bit-wise +// square-and-multiply. +// +// NOTE: This method only works when P is intended to be the secp256k1 prime and +// is not constant time. The returned value is of magnitude 1, but is +// denormalized. +func (f *fieldVal) Sqrt() *fieldVal { + return f.SqrtVal(f) +} diff --git a/btcec/field_test.go b/btcec/field_test.go index dcfb7049..27b9730f 100644 --- a/btcec/field_test.go +++ b/btcec/field_test.go @@ -6,6 +6,8 @@ package btcec import ( + "crypto/rand" + "fmt" "reflect" "testing" ) @@ -820,3 +822,146 @@ func TestInverse(t *testing.T) { } } } + +// randFieldVal returns a random, normalized element in the field. +func randFieldVal(t *testing.T) fieldVal { + var b [32]byte + if _, err := rand.Read(b[:]); err != nil { + t.Fatalf("unable to create random element: %v", err) + } + + var x fieldVal + return *x.SetBytes(&b).Normalize() +} + +type sqrtTest struct { + name string + in string + expected string +} + +// TestSqrt asserts that a fieldVal properly computes the square root modulo the +// sep256k1 prime. +func TestSqrt(t *testing.T) { + var tests []sqrtTest + + // No valid root exists for the negative of a square. + for i := uint(9); i > 0; i-- { + var ( + x fieldVal + s fieldVal // x^2 mod p + n fieldVal // -x^2 mod p + ) + + x.SetInt(i) + s.SquareVal(&x).Normalize() + n.NegateVal(&s, 1).Normalize() + + tests = append(tests, sqrtTest{ + name: fmt.Sprintf("-%d", i), + in: fmt.Sprintf("%x", *n.Bytes()), + }) + } + + // A root should exist for true squares. + for i := uint(0); i < 10; i++ { + var ( + x fieldVal + s fieldVal // x^2 mod p + ) + + x.SetInt(i) + s.SquareVal(&x).Normalize() + + tests = append(tests, sqrtTest{ + name: fmt.Sprintf("%d", i), + in: fmt.Sprintf("%x", *s.Bytes()), + expected: fmt.Sprintf("%x", *x.Bytes()), + }) + } + + // Compute a non-square element, by negating if it has a root. + ns := randFieldVal(t) + if new(fieldVal).SqrtVal(&ns).Square().Equals(&ns) { + ns.Negate(1).Normalize() + } + + // For large random field values, test that: + // 1) its square has a valid root. + // 2) the negative of its square has no root. + // 3) the product of its square with a non-square has no root. + for i := 0; i < 10; i++ { + var ( + x fieldVal + s fieldVal // x^2 mod p + n fieldVal // -x^2 mod p + m fieldVal // ns*x^2 mod p + ) + + x = randFieldVal(t) + s.SquareVal(&x).Normalize() + n.NegateVal(&s, 1).Normalize() + m.Mul2(&s, &ns).Normalize() + + // A root should exist for true squares. + tests = append(tests, sqrtTest{ + name: fmt.Sprintf("%x", *s.Bytes()), + in: fmt.Sprintf("%x", *s.Bytes()), + expected: fmt.Sprintf("%x", *x.Bytes()), + }) + + // No valid root exists for the negative of a square. + tests = append(tests, sqrtTest{ + name: fmt.Sprintf("-%x", *s.Bytes()), + in: fmt.Sprintf("%x", *n.Bytes()), + }) + + // No root should be computed for product of a square and + // non-square. + tests = append(tests, sqrtTest{ + name: fmt.Sprintf("ns*%x", *s.Bytes()), + in: fmt.Sprintf("%x", *m.Bytes()), + }) + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testSqrt(t, test) + }) + } +} + +func testSqrt(t *testing.T, test sqrtTest) { + var ( + f fieldVal + root fieldVal + rootNeg fieldVal + ) + + f.SetHex(test.in).Normalize() + + // Compute sqrt(f) and its negative. + root.SqrtVal(&f).Normalize() + rootNeg.NegateVal(&root, 1).Normalize() + + switch { + + // If we expect a square root, verify that either the computed square + // root is +/- the expected value. + case len(test.expected) > 0: + var expected fieldVal + expected.SetHex(test.expected).Normalize() + if !root.Equals(&expected) && !rootNeg.Equals(&expected) { + t.Fatalf("fieldVal.Sqrt incorrect root\n"+ + "got: %v\ngot_neg: %v\nwant: %v", + root, rootNeg, expected) + } + + // Otherwise, we expect this input not to have a square root. + default: + if root.Square().Equals(&f) || rootNeg.Square().Equals(&f) { + t.Fatalf("fieldVal.Sqrt root should not exist\n"+ + "got: %v\ngot_neg: %v", root, rootNeg) + } + } +} diff --git a/btcec/pubkey.go b/btcec/pubkey.go index a6a492e7..c72f8705 100644 --- a/btcec/pubkey.go +++ b/btcec/pubkey.go @@ -22,41 +22,41 @@ func isOdd(a *big.Int) bool { return a.Bit(0) == 1 } -// decompressPoint decompresses a point on the given curve given the X point and +// decompressPoint decompresses a point on the secp256k1 curve given the X point and // the solution to use. -func decompressPoint(curve *KoblitzCurve, x *big.Int, ybit bool) (*big.Int, error) { - // TODO: This will probably only work for secp256k1 due to - // optimizations. +func decompressPoint(curve *KoblitzCurve, bigX *big.Int, ybit bool) (*big.Int, error) { + var x fieldVal + x.SetByteSlice(bigX.Bytes()) - // Y = +-sqrt(x^3 + B) - x3 := new(big.Int).Mul(x, x) - x3.Mul(x3, x) - x3.Add(x3, curve.Params().B) - x3.Mod(x3, curve.Params().P) + // Compute x^3 + B mod p. + var x3 fieldVal + x3.SquareVal(&x).Mul(&x) + x3.Add(curve.fieldB).Normalize() // Now calculate sqrt mod p of x^3 + B // This code used to do a full sqrt based on tonelli/shanks, // but this was replaced by the algorithms referenced in // https://bitcointalk.org/index.php?topic=162805.msg1712294#msg1712294 - y := new(big.Int).Exp(x3, curve.QPlus1Div4(), curve.Params().P) - - if ybit != isOdd(y) { - y.Sub(curve.Params().P, y) + var y fieldVal + y.SqrtVal(&x3) + if ybit != y.IsOdd() { + y.Negate(1) } + y.Normalize() // Check that y is a square root of x^3 + B. - y2 := new(big.Int).Mul(y, y) - y2.Mod(y2, curve.Params().P) - if y2.Cmp(x3) != 0 { + var y2 fieldVal + y2.SquareVal(&y).Normalize() + if !y2.Equals(&x3) { return nil, fmt.Errorf("invalid square root") } // Verify that y-coord has expected parity. - if ybit != isOdd(y) { + if ybit != y.IsOdd() { return nil, fmt.Errorf("ybit doesn't match oddness") } - return y, nil + return new(big.Int).SetBytes(y.Bytes()[:]), nil } const (