btcec/pubkey: optimize decompressPoint using fieldVals

This commit optimizes the decompressPoint subroutine, used in extracting
compressed pubkeys and performing pubkey recovery. We do so by replacing
the use of big.Int.Exp with with square-and-multiply exponentiation of
btcec's more optimized fieldVals, reducing the overall latency and
memory requirements of decompressPoint.

Instead of operating on bits of Q = (P+1)/4, the exponentiation applies
the square-and-multiply operations on full bytes of Q.  Compared to the
original speedup. Compared the bit-wise version, the improvement is
roughly 10%.

A new pair fieldVal methods called Sqrt and SqrtVal are added, which
applies the square-and-multiply exponentiation using precomputed
byte-slice of the value Q.

Comparison against big.Int sqrt and SAM sqrt over bytes of Q:

benchmark                            old ns/op     new ns/op     delta
BenchmarkParseCompressedPubKey-8     35545         23119         -34.96%

benchmark                            old allocs     new allocs     delta
BenchmarkParseCompressedPubKey-8     35             6            -82.86%

benchmark                            old bytes     new bytes     delta
BenchmarkParseCompressedPubKey-8     2777          256           -90.78%
This commit is contained in:
Conner Fromknecht 2019-05-14 22:46:59 -07:00
parent 39500ed5ed
commit c7d523f83c
No known key found for this signature in database
GPG key ID: E7D737B67FA592C7
4 changed files with 301 additions and 19 deletions

View file

@ -36,10 +36,17 @@ var (
// interface from crypto/elliptic.
type KoblitzCurve struct {
*elliptic.CurveParams
q *big.Int
// q is the value (P+1)/4 used to compute the square root of field
// elements.
q *big.Int
H int // cofactor of the curve.
halfOrder *big.Int // half the order N
// fieldB is the constant B of the curve as a fieldVal.
fieldB *fieldVal
// byteSize is simply the bit size / 8 and is provided for convenience
// since it is calculated repeatedly.
byteSize int
@ -917,6 +924,7 @@ func initS256() {
big.NewInt(1)), big.NewInt(4))
secp256k1.H = 1
secp256k1.halfOrder = new(big.Int).Rsh(secp256k1.N, 1)
secp256k1.fieldB = new(fieldVal).SetByteSlice(secp256k1.B.Bytes())
// Provided for convenience since this gets computed repeatedly.
secp256k1.byteSize = secp256k1.BitSize / 8

View file

@ -102,6 +102,20 @@ const (
fieldPrimeWordOne = 0x3ffffbf
)
var (
// fieldQBytes is the value Q = (P+1)/4 for the secp256k1 prime P. This
// value is used to efficiently compute the square root of values in the
// field via exponentiation. The value of Q in hex is:
//
// Q = 3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c
fieldQBytes = []byte{
0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xbf, 0xff, 0xff, 0x0c,
}
)
// fieldVal implements optimized fixed-precision arithmetic over the
// secp256k1 finite field. This means all arithmetic is performed modulo
// 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f. It
@ -1221,3 +1235,118 @@ func (f *fieldVal) Inverse() *fieldVal {
f.Square().Square().Square().Square().Square() // f = a^(2^256 - 4294968320)
return f.Mul(&a45) // f = a^(2^256 - 4294968275) = a^(p-2)
}
// SqrtVal computes the square root of x modulo the curve's prime, and stores
// the result in f. The square root is computed via exponentiation of x by the
// value Q = (P+1)/4 using the curve's precomputed big-endian representation of
// the Q. This method uses a modified version of square-and-multiply
// exponentiation over secp256k1 fieldVals to operate on bytes instead of bits,
// which offers better performance over both big.Int exponentiation and bit-wise
// square-and-multiply.
//
// NOTE: This method only works when P is intended to be the secp256k1 prime and
// is not constant time. The returned value is of magnitude 1, but is
// denormalized.
func (f *fieldVal) SqrtVal(x *fieldVal) *fieldVal {
// The following computation iteratively computes x^((P+1)/4) = x^Q
// using the recursive, piece-wise definition:
//
// x^n = (x^2)^(n/2) mod P if n is even
// x^n = x(x^2)^(n-1/2) mod P if n is odd
//
// Given n in its big-endian representation b_k, ..., b_0, x^n can be
// computed by defining the sequence r_k+1, ..., r_0, where:
//
// r_k+1 = 1
// r_i = (r_i+1)^2 * x^b_i for i = k, ..., 0
//
// The final value r_0 = x^n.
//
// See https://en.wikipedia.org/wiki/Exponentiation_by_squaring for more
// details.
//
// This can be further optimized, by observing that the value of Q in
// secp256k1 has the value:
//
// Q = 3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c
//
// We can unroll the typical bit-wise interpretation of the
// exponentiation algorithm above to instead operate on bytes.
// This reduces the number of comparisons by an order of magnitude,
// reducing the overhead of failed branch predictions and additional
// comparisons in this method.
//
// Since there there are only 4 unique bytes of Q, this keeps the jump
// table small without the need to handle all possible 8-bit values.
// Further, we observe that 29 of the 32 bytes are 0xff; making the
// first case handle 0xff therefore optimizes the hot path.
f.SetInt(1)
for _, b := range fieldQBytes {
switch b {
// Most common case, where all 8 bits are set.
case 0xff:
f.Square().Mul(x)
f.Square().Mul(x)
f.Square().Mul(x)
f.Square().Mul(x)
f.Square().Mul(x)
f.Square().Mul(x)
f.Square().Mul(x)
f.Square().Mul(x)
// First byte of Q (0x3f), where all but the top two bits are
// set. Note that this case only applies six operations, since
// the highest bit of Q resides in bit six of the first byte. We
// ignore the first two bits, since squaring for these bits will
// result in an invalid result. We forgo squaring f before the
// first multiply, since 1^2 = 1.
case 0x3f:
f.Mul(x)
f.Square().Mul(x)
f.Square().Mul(x)
f.Square().Mul(x)
f.Square().Mul(x)
f.Square().Mul(x)
// Byte 28 of Q (0xbf), where only bit 7 is unset.
case 0xbf:
f.Square().Mul(x)
f.Square()
f.Square().Mul(x)
f.Square().Mul(x)
f.Square().Mul(x)
f.Square().Mul(x)
f.Square().Mul(x)
f.Square().Mul(x)
// Byte 31 of Q (0x0c), where only bits 3 and 4 are set.
default:
f.Square()
f.Square()
f.Square()
f.Square()
f.Square().Mul(x)
f.Square().Mul(x)
f.Square()
f.Square()
}
}
return f
}
// Sqrt computes the square root of f modulo the curve's prime, and stores the
// result in f. The square root is computed via exponentiation of x by the value
// Q = (P+1)/4 using the curve's precomputed big-endian representation of the Q.
// This method uses a modified version of square-and-multiply exponentiation
// over secp256k1 fieldVals to operate on bytes instead of bits, which offers
// better performance over both big.Int exponentiation and bit-wise
// square-and-multiply.
//
// NOTE: This method only works when P is intended to be the secp256k1 prime and
// is not constant time. The returned value is of magnitude 1, but is
// denormalized.
func (f *fieldVal) Sqrt() *fieldVal {
return f.SqrtVal(f)
}

View file

@ -6,6 +6,8 @@
package btcec
import (
"crypto/rand"
"fmt"
"reflect"
"testing"
)
@ -820,3 +822,146 @@ func TestInverse(t *testing.T) {
}
}
}
// randFieldVal returns a random, normalized element in the field.
func randFieldVal(t *testing.T) fieldVal {
var b [32]byte
if _, err := rand.Read(b[:]); err != nil {
t.Fatalf("unable to create random element: %v", err)
}
var x fieldVal
return *x.SetBytes(&b).Normalize()
}
type sqrtTest struct {
name string
in string
expected string
}
// TestSqrt asserts that a fieldVal properly computes the square root modulo the
// sep256k1 prime.
func TestSqrt(t *testing.T) {
var tests []sqrtTest
// No valid root exists for the negative of a square.
for i := uint(9); i > 0; i-- {
var (
x fieldVal
s fieldVal // x^2 mod p
n fieldVal // -x^2 mod p
)
x.SetInt(i)
s.SquareVal(&x).Normalize()
n.NegateVal(&s, 1).Normalize()
tests = append(tests, sqrtTest{
name: fmt.Sprintf("-%d", i),
in: fmt.Sprintf("%x", *n.Bytes()),
})
}
// A root should exist for true squares.
for i := uint(0); i < 10; i++ {
var (
x fieldVal
s fieldVal // x^2 mod p
)
x.SetInt(i)
s.SquareVal(&x).Normalize()
tests = append(tests, sqrtTest{
name: fmt.Sprintf("%d", i),
in: fmt.Sprintf("%x", *s.Bytes()),
expected: fmt.Sprintf("%x", *x.Bytes()),
})
}
// Compute a non-square element, by negating if it has a root.
ns := randFieldVal(t)
if new(fieldVal).SqrtVal(&ns).Square().Equals(&ns) {
ns.Negate(1).Normalize()
}
// For large random field values, test that:
// 1) its square has a valid root.
// 2) the negative of its square has no root.
// 3) the product of its square with a non-square has no root.
for i := 0; i < 10; i++ {
var (
x fieldVal
s fieldVal // x^2 mod p
n fieldVal // -x^2 mod p
m fieldVal // ns*x^2 mod p
)
x = randFieldVal(t)
s.SquareVal(&x).Normalize()
n.NegateVal(&s, 1).Normalize()
m.Mul2(&s, &ns).Normalize()
// A root should exist for true squares.
tests = append(tests, sqrtTest{
name: fmt.Sprintf("%x", *s.Bytes()),
in: fmt.Sprintf("%x", *s.Bytes()),
expected: fmt.Sprintf("%x", *x.Bytes()),
})
// No valid root exists for the negative of a square.
tests = append(tests, sqrtTest{
name: fmt.Sprintf("-%x", *s.Bytes()),
in: fmt.Sprintf("%x", *n.Bytes()),
})
// No root should be computed for product of a square and
// non-square.
tests = append(tests, sqrtTest{
name: fmt.Sprintf("ns*%x", *s.Bytes()),
in: fmt.Sprintf("%x", *m.Bytes()),
})
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
testSqrt(t, test)
})
}
}
func testSqrt(t *testing.T, test sqrtTest) {
var (
f fieldVal
root fieldVal
rootNeg fieldVal
)
f.SetHex(test.in).Normalize()
// Compute sqrt(f) and its negative.
root.SqrtVal(&f).Normalize()
rootNeg.NegateVal(&root, 1).Normalize()
switch {
// If we expect a square root, verify that either the computed square
// root is +/- the expected value.
case len(test.expected) > 0:
var expected fieldVal
expected.SetHex(test.expected).Normalize()
if !root.Equals(&expected) && !rootNeg.Equals(&expected) {
t.Fatalf("fieldVal.Sqrt incorrect root\n"+
"got: %v\ngot_neg: %v\nwant: %v",
root, rootNeg, expected)
}
// Otherwise, we expect this input not to have a square root.
default:
if root.Square().Equals(&f) || rootNeg.Square().Equals(&f) {
t.Fatalf("fieldVal.Sqrt root should not exist\n"+
"got: %v\ngot_neg: %v", root, rootNeg)
}
}
}

View file

@ -22,41 +22,41 @@ func isOdd(a *big.Int) bool {
return a.Bit(0) == 1
}
// decompressPoint decompresses a point on the given curve given the X point and
// decompressPoint decompresses a point on the secp256k1 curve given the X point and
// the solution to use.
func decompressPoint(curve *KoblitzCurve, x *big.Int, ybit bool) (*big.Int, error) {
// TODO: This will probably only work for secp256k1 due to
// optimizations.
func decompressPoint(curve *KoblitzCurve, bigX *big.Int, ybit bool) (*big.Int, error) {
var x fieldVal
x.SetByteSlice(bigX.Bytes())
// Y = +-sqrt(x^3 + B)
x3 := new(big.Int).Mul(x, x)
x3.Mul(x3, x)
x3.Add(x3, curve.Params().B)
x3.Mod(x3, curve.Params().P)
// Compute x^3 + B mod p.
var x3 fieldVal
x3.SquareVal(&x).Mul(&x)
x3.Add(curve.fieldB).Normalize()
// Now calculate sqrt mod p of x^3 + B
// This code used to do a full sqrt based on tonelli/shanks,
// but this was replaced by the algorithms referenced in
// https://bitcointalk.org/index.php?topic=162805.msg1712294#msg1712294
y := new(big.Int).Exp(x3, curve.QPlus1Div4(), curve.Params().P)
if ybit != isOdd(y) {
y.Sub(curve.Params().P, y)
var y fieldVal
y.SqrtVal(&x3)
if ybit != y.IsOdd() {
y.Negate(1)
}
y.Normalize()
// Check that y is a square root of x^3 + B.
y2 := new(big.Int).Mul(y, y)
y2.Mod(y2, curve.Params().P)
if y2.Cmp(x3) != 0 {
var y2 fieldVal
y2.SquareVal(&y).Normalize()
if !y2.Equals(&x3) {
return nil, fmt.Errorf("invalid square root")
}
// Verify that y-coord has expected parity.
if ybit != isOdd(y) {
if ybit != y.IsOdd() {
return nil, fmt.Errorf("ybit doesn't match oddness")
}
return y, nil
return new(big.Int).SetBytes(y.Bytes()[:]), nil
}
const (