Merge pull request #1429 from cfromknecht/btcec-double-is-on-curve

btcec: optimize square root using fieldVal
This commit is contained in:
Olaoluwa Osuntokun 2019-10-09 17:54:42 -07:00 committed by GitHub
commit b686b0a8eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 348 additions and 31 deletions

View file

@ -4,7 +4,10 @@
package btcec
import "testing"
import (
"encoding/hex"
"testing"
)
// BenchmarkAddJacobian benchmarks the secp256k1 curve addJacobian function with
// Z values of 1 so that the associated optimizations are used.
@ -121,3 +124,22 @@ func BenchmarkFieldNormalize(b *testing.B) {
f.Normalize()
}
}
// BenchmarkParseCompressedPubKey benchmarks how long it takes to decompress and
// validate a compressed public key from a byte array.
func BenchmarkParseCompressedPubKey(b *testing.B) {
	rawPk, err := hex.DecodeString("0234f9460f0e4f08393d192b3c5133a6ba099aa0ad9fd54ebccfacdfa239ff49c6")
	if err != nil {
		b.Fatalf("unable to decode public key hex: %v", err)
	}

	// The parse result is assigned to a variable declared outside the loop
	// so the call cannot be discarded as dead code.
	var pk *PublicKey

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		pk, err = ParsePubKey(rawPk, S256())
	}

	// Every iteration parses the same bytes, so checking the final error is
	// enough to know the timed loop exercised the success path rather than
	// an early error return.
	if err != nil {
		b.Fatalf("unable to parse compressed public key: %v", err)
	}
	_ = pk
}

View file

@ -36,10 +36,17 @@ var (
// interface from crypto/elliptic.
type KoblitzCurve struct {
*elliptic.CurveParams
q *big.Int
// q is the value (P+1)/4 used to compute the square root of field
// elements.
q *big.Int
H int // cofactor of the curve.
halfOrder *big.Int // half the order N
// fieldB is the constant B of the curve as a fieldVal.
fieldB *fieldVal
// byteSize is simply the bit size / 8 and is provided for convenience
// since it is calculated repeatedly.
byteSize int
@ -879,12 +886,22 @@ func (curve *KoblitzCurve) ScalarBaseMult(k []byte) (*big.Int, *big.Int) {
return curve.fieldJacobianToBigAffine(qx, qy, qz)
}
// QPlus1Div4 returns the Q+1/4 constant for the curve for use in calculating
// square roots via exponention.
// QPlus1Div4 returns the (P+1)/4 constant for the curve for use in calculating
// square roots via exponentiation.
//
// Deprecated: The actual value returned is (P+1)/4, whereas the original
// method name implies that this value is (((P+1)/4)+1)/4. This method is kept
// to maintain backwards compatibility of the API. Use Q() instead.
func (curve *KoblitzCurve) QPlus1Div4() *big.Int {
	return curve.q
}
// Q returns the (P+1)/4 constant for the curve for use in calculating square
// roots via exponentiation.
//
// The pointer returned is the one stored on the curve, not a copy, so
// modifying the returned value modifies the curve's constant as well.
func (curve *KoblitzCurve) Q() *big.Int {
	return curve.q
}
var initonce sync.Once
var secp256k1 KoblitzCurve
@ -917,6 +934,7 @@ func initS256() {
big.NewInt(1)), big.NewInt(4))
secp256k1.H = 1
secp256k1.halfOrder = new(big.Int).Rsh(secp256k1.N, 1)
secp256k1.fieldB = new(fieldVal).SetByteSlice(secp256k1.B.Bytes())
// Provided for convenience since this gets computed repeatedly.
secp256k1.byteSize = secp256k1.BitSize / 8

View file

@ -102,6 +102,20 @@ const (
fieldPrimeWordOne = 0x3ffffbf
)
var (
	// fieldQBytes is the value Q = (P+1)/4 for the secp256k1 prime P. This
	// value is used to efficiently compute the square root of values in the
	// field via exponentiation. The bytes are stored most significant byte
	// first. The value of Q in hex is:
	//
	//  Q = 3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c
	//
	// NOTE: fieldVal.SqrtVal switches on the individual byte values found
	// here (0x3f, 0xff, 0xbf and 0x0c), so the two must be kept in sync.
	fieldQBytes = []byte{
		0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
		0xff, 0xff, 0xff, 0xff, 0xbf, 0xff, 0xff, 0x0c,
	}
)
// fieldVal implements optimized fixed-precision arithmetic over the
// secp256k1 finite field. This means all arithmetic is performed modulo
// 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f. It
@ -1221,3 +1235,118 @@ func (f *fieldVal) Inverse() *fieldVal {
f.Square().Square().Square().Square().Square() // f = a^(2^256 - 4294968320)
return f.Mul(&a45) // f = a^(2^256 - 4294968275) = a^(p-2)
}
// SqrtVal computes the square root of x modulo the curve's prime, and stores
// the result in f. The square root is computed via exponentiation of x by the
// value Q = (P+1)/4 using the curve's precomputed big-endian representation of
// Q. This method uses a modified version of square-and-multiply exponentiation
// over secp256k1 fieldVals to operate on bytes instead of bits, which offers
// better performance over both big.Int exponentiation and bit-wise
// square-and-multiply.
//
// The value stored in f is x^Q, which is only a square root of x when x
// actually has one. Callers must square the result and compare it against x to
// find out whether a root exists.
//
// x may be the same fieldVal as f, in which case the input is replaced by the
// result. f is returned to allow chaining.
//
// NOTE: This method only works when P is intended to be the secp256k1 prime and
// is not constant time. The returned value is of magnitude 1, but is
// denormalized.
func (f *fieldVal) SqrtVal(x *fieldVal) *fieldVal {
	// The following computation iteratively computes x^((P+1)/4) = x^Q
	// using the recursive, piece-wise definition:
	//
	//   x^n = (x^2)^(n/2) mod P         if n is even
	//   x^n = x(x^2)^((n-1)/2) mod P    if n is odd
	//
	// Given n in its big-endian representation b_k, ..., b_0, x^n can be
	// computed by defining the sequence r_k+1, ..., r_0, where:
	//
	//   r_k+1 = 1
	//   r_i   = (r_i+1)^2 * x^b_i    for i = k, ..., 0
	//
	// The final value r_0 = x^n.
	//
	// See https://en.wikipedia.org/wiki/Exponentiation_by_squaring for more
	// details.
	//
	// This can be further optimized, by observing that the value of Q in
	// secp256k1 has the value:
	//
	//   Q = 3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c
	//
	// We can unroll the typical bit-wise interpretation of the
	// exponentiation algorithm above to instead operate on bytes.
	// This reduces the number of comparisons by an order of magnitude,
	// reducing the overhead of failed branch predictions and additional
	// comparisons in this method.
	//
	// Since there are only 4 unique bytes of Q, this keeps the jump table
	// small without the need to handle all possible 8-bit values. Further,
	// we observe that 29 of the 32 bytes are 0xff; making the first case
	// handle 0xff therefore optimizes the hot path.

	// The accumulator f is reset to 1 before the base is ever read. When
	// the caller passes the receiver as the argument, as Sqrt does, that
	// reset would wipe out the input and every call would yield 1. Work
	// from a private copy of the base in that case. The copy is skipped
	// when the two are distinct so the common path pays nothing extra.
	if x == f {
		base := *x
		x = &base
	}

	f.SetInt(1)
	for _, b := range fieldQBytes {
		switch b {

		// Most common case, where all 8 bits are set.
		case 0xff:
			f.Square().Mul(x)
			f.Square().Mul(x)
			f.Square().Mul(x)
			f.Square().Mul(x)
			f.Square().Mul(x)
			f.Square().Mul(x)
			f.Square().Mul(x)
			f.Square().Mul(x)

		// First byte of Q (0x3f), where all but the top two bits are
		// set. Note that this case only applies six operations, since
		// the highest bit of Q resides in bit six of the first byte. We
		// ignore the first two bits, since squaring for these bits will
		// result in an invalid result. We forgo squaring f before the
		// first multiply, since 1^2 = 1.
		case 0x3f:
			f.Mul(x)
			f.Square().Mul(x)
			f.Square().Mul(x)
			f.Square().Mul(x)
			f.Square().Mul(x)
			f.Square().Mul(x)

		// Byte 28 of Q (0xbf), where only bit 7 is unset.
		case 0xbf:
			f.Square().Mul(x)
			f.Square()
			f.Square().Mul(x)
			f.Square().Mul(x)
			f.Square().Mul(x)
			f.Square().Mul(x)
			f.Square().Mul(x)
			f.Square().Mul(x)

		// Byte 31 of Q (0x0c), where only bits 3 and 4 are set.
		default:
			f.Square()
			f.Square()
			f.Square()
			f.Square()
			f.Square().Mul(x)
			f.Square().Mul(x)
			f.Square()
			f.Square()
		}
	}

	return f
}
// Sqrt computes the square root of f modulo the curve's prime, and stores the
// result in f. The square root is computed via exponentiation of f by the value
// Q = (P+1)/4 using the curve's precomputed big-endian representation of Q.
// This method uses a modified version of square-and-multiply exponentiation
// over secp256k1 fieldVals to operate on bytes instead of bits, which offers
// better performance over both big.Int exponentiation and bit-wise
// square-and-multiply.
//
// The value left in f is only a square root of the original value when one
// exists; callers must square it and compare to confirm. f is returned to
// allow chaining.
//
// NOTE: This method only works when P is intended to be the secp256k1 prime and
// is not constant time. The returned value is of magnitude 1, but is
// denormalized.
func (f *fieldVal) Sqrt() *fieldVal {
	// Hand SqrtVal a copy of the input so the base stays intact while the
	// receiver is being overwritten with the running result.
	x := *f
	return f.SqrtVal(&x)
}

View file

@ -6,6 +6,8 @@
package btcec
import (
"crypto/rand"
"fmt"
"reflect"
"testing"
)
@ -820,3 +822,146 @@ func TestInverse(t *testing.T) {
}
}
}
// randFieldVal draws 32 bytes from crypto/rand, loads them into a fieldVal and
// normalizes it, returning the resulting field element by value. The test is
// aborted if the random source cannot be read.
func randFieldVal(t *testing.T) fieldVal {
	var raw [32]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		t.Fatalf("unable to create random element: %v", err)
	}

	var val fieldVal
	val.SetBytes(&raw)
	val.Normalize()
	return val
}
// sqrtTest describes a single square root test case for fieldVal.SqrtVal.
type sqrtTest struct {
	// name is the label used for the subtest.
	name string

	// in is the hex encoding of the field element whose square root is
	// taken.
	in string

	// expected is the hex encoding of a square root of in. The computed
	// root may match it directly or after negation. It is left empty when
	// in is expected to have no square root.
	expected string
}
// TestSqrt asserts that a fieldVal properly computes the square root modulo the
// secp256k1 prime. It builds a table of inputs that are known squares (a root
// is expected) and inputs that are known non-squares (no root is expected), and
// runs each one as a subtest through testSqrt.
func TestSqrt(t *testing.T) {
	var tests []sqrtTest

	// No valid root exists for the negative of a square.
	for i := uint(9); i > 0; i-- {
		var (
			x fieldVal
			s fieldVal // x^2 mod p
			n fieldVal // -x^2 mod p
		)

		x.SetInt(i)
		s.SquareVal(&x).Normalize()
		n.NegateVal(&s, 1).Normalize()

		tests = append(tests, sqrtTest{
			name: fmt.Sprintf("-%d", i),
			in:   fmt.Sprintf("%x", *n.Bytes()),
		})
	}

	// A root should exist for true squares.
	for i := uint(0); i < 10; i++ {
		var (
			x fieldVal
			s fieldVal // x^2 mod p
		)

		x.SetInt(i)
		s.SquareVal(&x).Normalize()

		tests = append(tests, sqrtTest{
			name:     fmt.Sprintf("%d", i),
			in:       fmt.Sprintf("%x", *s.Bytes()),
			expected: fmt.Sprintf("%x", *x.Bytes()),
		})
	}

	// Compute a non-square element, by negating if it has a root. The
	// squared candidate is normalized before the comparison because SqrtVal
	// and Square leave their results denormalized, while ns is normalized.
	ns := randFieldVal(t)
	if new(fieldVal).SqrtVal(&ns).Square().Normalize().Equals(&ns) {
		ns.Negate(1).Normalize()
	}

	// For large random field values, test that:
	//  1) its square has a valid root.
	//  2) the negative of its square has no root.
	//  3) the product of its square with a non-square has no root.
	for i := 0; i < 10; i++ {
		var (
			x fieldVal
			s fieldVal // x^2 mod p
			n fieldVal // -x^2 mod p
			m fieldVal // ns*x^2 mod p
		)

		x = randFieldVal(t)
		s.SquareVal(&x).Normalize()
		n.NegateVal(&s, 1).Normalize()
		m.Mul2(&s, &ns).Normalize()

		// A root should exist for true squares.
		tests = append(tests, sqrtTest{
			name:     fmt.Sprintf("%x", *s.Bytes()),
			in:       fmt.Sprintf("%x", *s.Bytes()),
			expected: fmt.Sprintf("%x", *x.Bytes()),
		})

		// No valid root exists for the negative of a square.
		tests = append(tests, sqrtTest{
			name: fmt.Sprintf("-%x", *s.Bytes()),
			in:   fmt.Sprintf("%x", *n.Bytes()),
		})

		// No root should be computed for product of a square and
		// non-square.
		tests = append(tests, sqrtTest{
			name: fmt.Sprintf("ns*%x", *s.Bytes()),
			in:   fmt.Sprintf("%x", *m.Bytes()),
		})
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			testSqrt(t, test)
		})
	}
}
// testSqrt runs a single square root test case. It computes SqrtVal of the
// input along with the negation of that result. When the case carries an
// expected root, one of the two candidates must equal it. Otherwise neither
// candidate may square back to the input.
func testSqrt(t *testing.T, test sqrtTest) {
	var (
		f       fieldVal
		root    fieldVal
		rootNeg fieldVal
	)

	f.SetHex(test.in).Normalize()

	// Compute sqrt(f) and its negative.
	root.SqrtVal(&f).Normalize()
	rootNeg.NegateVal(&root, 1).Normalize()

	switch {

	// If we expect a square root, verify that either the computed square
	// root is +/- the expected value.
	case len(test.expected) > 0:
		var expected fieldVal
		expected.SetHex(test.expected).Normalize()
		if !root.Equals(&expected) && !rootNeg.Equals(&expected) {
			t.Fatalf("fieldVal.Sqrt incorrect root\n"+
				"got:     %v\ngot_neg: %v\nwant:    %v",
				root, rootNeg, expected)
		}

	// Otherwise, we expect this input not to have a square root.
	default:
		// Square into separate values so that root and rootNeg still
		// hold the candidates when they are printed on failure, and
		// normalize the squares so the comparison against the
		// normalized input is well defined.
		var rootSq, rootNegSq fieldVal
		rootSq.SquareVal(&root).Normalize()
		rootNegSq.SquareVal(&rootNeg).Normalize()
		if rootSq.Equals(&f) || rootNegSq.Equals(&f) {
			t.Fatalf("fieldVal.Sqrt root should not exist\n"+
				"got:     %v\ngot_neg: %v", root, rootNeg)
		}
	}
}

View file

@ -22,41 +22,41 @@ func isOdd(a *big.Int) bool {
return a.Bit(0) == 1
}
// decompressPoint decompresses a point on the given curve given the X point and
// decompressPoint decompresses a point on the secp256k1 curve given the X point and
// the solution to use.
func decompressPoint(curve *KoblitzCurve, x *big.Int, ybit bool) (*big.Int, error) {
// TODO: This will probably only work for secp256k1 due to
// optimizations.
func decompressPoint(curve *KoblitzCurve, bigX *big.Int, ybit bool) (*big.Int, error) {
var x fieldVal
x.SetByteSlice(bigX.Bytes())
// Y = +-sqrt(x^3 + B)
x3 := new(big.Int).Mul(x, x)
x3.Mul(x3, x)
x3.Add(x3, curve.Params().B)
x3.Mod(x3, curve.Params().P)
// Compute x^3 + B mod p.
var x3 fieldVal
x3.SquareVal(&x).Mul(&x)
x3.Add(curve.fieldB).Normalize()
// Now calculate sqrt mod p of x^3 + B
// This code used to do a full sqrt based on tonelli/shanks,
// but this was replaced by the algorithms referenced in
// https://bitcointalk.org/index.php?topic=162805.msg1712294#msg1712294
y := new(big.Int).Exp(x3, curve.QPlus1Div4(), curve.Params().P)
if ybit != isOdd(y) {
y.Sub(curve.Params().P, y)
var y fieldVal
y.SqrtVal(&x3)
if ybit != y.IsOdd() {
y.Negate(1)
}
y.Normalize()
// Check that y is a square root of x^3 + B.
y2 := new(big.Int).Mul(y, y)
y2.Mod(y2, curve.Params().P)
if y2.Cmp(x3) != 0 {
var y2 fieldVal
y2.SquareVal(&y).Normalize()
if !y2.Equals(&x3) {
return nil, fmt.Errorf("invalid square root")
}
// Verify that y-coord has expected parity.
if ybit != isOdd(y) {
if ybit != y.IsOdd() {
return nil, fmt.Errorf("ybit doesn't match oddness")
}
return y, nil
return new(big.Int).SetBytes(y.Bytes()[:]), nil
}
const (
@ -102,6 +102,17 @@ func ParsePubKey(pubKeyStr []byte, curve *KoblitzCurve) (key *PublicKey, err err
if format == pubkeyHybrid && ybit != isOdd(pubkey.Y) {
return nil, fmt.Errorf("ybit doesn't match oddness")
}
if pubkey.X.Cmp(pubkey.Curve.Params().P) >= 0 {
return nil, fmt.Errorf("pubkey X parameter is >= to P")
}
if pubkey.Y.Cmp(pubkey.Curve.Params().P) >= 0 {
return nil, fmt.Errorf("pubkey Y parameter is >= to P")
}
if !pubkey.Curve.IsOnCurve(pubkey.X, pubkey.Y) {
return nil, fmt.Errorf("pubkey isn't on secp256k1 curve")
}
case PubKeyBytesLenCompressed:
// format is 0x2 | solution, <X coordinate>
// solution determines which solution of the curve we use.
@ -115,20 +126,12 @@ func ParsePubKey(pubKeyStr []byte, curve *KoblitzCurve) (key *PublicKey, err err
if err != nil {
return nil, err
}
default: // wrong!
return nil, fmt.Errorf("invalid pub key length %d",
len(pubKeyStr))
}
if pubkey.X.Cmp(pubkey.Curve.Params().P) >= 0 {
return nil, fmt.Errorf("pubkey X parameter is >= to P")
}
if pubkey.Y.Cmp(pubkey.Curve.Params().P) >= 0 {
return nil, fmt.Errorf("pubkey Y parameter is >= to P")
}
if !pubkey.Curve.IsOnCurve(pubkey.X, pubkey.Y) {
return nil, fmt.Errorf("pubkey isn't on secp256k1 curve")
}
return &pubkey, nil
}