Merge pull request #1429 from cfromknecht/btcec-double-is-on-curve
btcec: optimize square root using fieldVal
This commit is contained in:
commit
b686b0a8eb
5 changed files with 348 additions and 31 deletions
|
@ -4,7 +4,10 @@
|
|||
|
||||
package btcec
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// BenchmarkAddJacobian benchmarks the secp256k1 curve addJacobian function with
|
||||
// Z values of 1 so that the associated optimizations are used.
|
||||
|
@ -121,3 +124,22 @@ func BenchmarkFieldNormalize(b *testing.B) {
|
|||
f.Normalize()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkParseCompressedPubKey benchmarks how long it takes to decompress and
|
||||
// validate a compressed public key from a byte array.
|
||||
func BenchmarkParseCompressedPubKey(b *testing.B) {
|
||||
rawPk, _ := hex.DecodeString("0234f9460f0e4f08393d192b3c5133a6ba099aa0ad9fd54ebccfacdfa239ff49c6")
|
||||
|
||||
var (
|
||||
pk *PublicKey
|
||||
err error
|
||||
)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
pk, err = ParsePubKey(rawPk, S256())
|
||||
}
|
||||
_ = pk
|
||||
_ = err
|
||||
}
|
||||
|
|
|
@ -36,10 +36,17 @@ var (
|
|||
// interface from crypto/elliptic.
|
||||
type KoblitzCurve struct {
|
||||
*elliptic.CurveParams
|
||||
q *big.Int
|
||||
|
||||
// q is the value (P+1)/4 used to compute the square root of field
|
||||
// elements.
|
||||
q *big.Int
|
||||
|
||||
H int // cofactor of the curve.
|
||||
halfOrder *big.Int // half the order N
|
||||
|
||||
// fieldB is the constant B of the curve as a fieldVal.
|
||||
fieldB *fieldVal
|
||||
|
||||
// byteSize is simply the bit size / 8 and is provided for convenience
|
||||
// since it is calculated repeatedly.
|
||||
byteSize int
|
||||
|
@ -879,12 +886,22 @@ func (curve *KoblitzCurve) ScalarBaseMult(k []byte) (*big.Int, *big.Int) {
|
|||
return curve.fieldJacobianToBigAffine(qx, qy, qz)
|
||||
}
|
||||
|
||||
// QPlus1Div4 returns the Q+1/4 constant for the curve for use in calculating
|
||||
// square roots via exponention.
|
||||
// QPlus1Div4 returns the (P+1)/4 constant for the curve for use in calculating
|
||||
// square roots via exponentiation.
|
||||
//
|
||||
// DEPRECATED: The actual value returned is (P+1)/4, where as the original
|
||||
// method name implies that this value is (((P+1)/4)+1)/4. This method is kept
|
||||
// to maintain backwards compatibility of the API. Use Q() instead.
|
||||
func (curve *KoblitzCurve) QPlus1Div4() *big.Int {
|
||||
return curve.q
|
||||
}
|
||||
|
||||
// Q returns the (P+1)/4 constant for the curve for use in calculating square
|
||||
// roots via exponentiation.
|
||||
func (curve *KoblitzCurve) Q() *big.Int {
|
||||
return curve.q
|
||||
}
|
||||
|
||||
var initonce sync.Once
|
||||
var secp256k1 KoblitzCurve
|
||||
|
||||
|
@ -917,6 +934,7 @@ func initS256() {
|
|||
big.NewInt(1)), big.NewInt(4))
|
||||
secp256k1.H = 1
|
||||
secp256k1.halfOrder = new(big.Int).Rsh(secp256k1.N, 1)
|
||||
secp256k1.fieldB = new(fieldVal).SetByteSlice(secp256k1.B.Bytes())
|
||||
|
||||
// Provided for convenience since this gets computed repeatedly.
|
||||
secp256k1.byteSize = secp256k1.BitSize / 8
|
||||
|
|
129
btcec/field.go
129
btcec/field.go
|
@ -102,6 +102,20 @@ const (
|
|||
fieldPrimeWordOne = 0x3ffffbf
|
||||
)
|
||||
|
||||
var (
|
||||
// fieldQBytes is the value Q = (P+1)/4 for the secp256k1 prime P. This
|
||||
// value is used to efficiently compute the square root of values in the
|
||||
// field via exponentiation. The value of Q in hex is:
|
||||
//
|
||||
// Q = 3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c
|
||||
fieldQBytes = []byte{
|
||||
0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff, 0xbf, 0xff, 0xff, 0x0c,
|
||||
}
|
||||
)
|
||||
|
||||
// fieldVal implements optimized fixed-precision arithmetic over the
|
||||
// secp256k1 finite field. This means all arithmetic is performed modulo
|
||||
// 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f. It
|
||||
|
@ -1221,3 +1235,118 @@ func (f *fieldVal) Inverse() *fieldVal {
|
|||
f.Square().Square().Square().Square().Square() // f = a^(2^256 - 4294968320)
|
||||
return f.Mul(&a45) // f = a^(2^256 - 4294968275) = a^(p-2)
|
||||
}
|
||||
|
||||
// SqrtVal computes the square root of x modulo the curve's prime, and stores
|
||||
// the result in f. The square root is computed via exponentiation of x by the
|
||||
// value Q = (P+1)/4 using the curve's precomputed big-endian representation of
|
||||
// the Q. This method uses a modified version of square-and-multiply
|
||||
// exponentiation over secp256k1 fieldVals to operate on bytes instead of bits,
|
||||
// which offers better performance over both big.Int exponentiation and bit-wise
|
||||
// square-and-multiply.
|
||||
//
|
||||
// NOTE: This method only works when P is intended to be the secp256k1 prime and
|
||||
// is not constant time. The returned value is of magnitude 1, but is
|
||||
// denormalized.
|
||||
func (f *fieldVal) SqrtVal(x *fieldVal) *fieldVal {
|
||||
// The following computation iteratively computes x^((P+1)/4) = x^Q
|
||||
// using the recursive, piece-wise definition:
|
||||
//
|
||||
// x^n = (x^2)^(n/2) mod P if n is even
|
||||
// x^n = x(x^2)^(n-1/2) mod P if n is odd
|
||||
//
|
||||
// Given n in its big-endian representation b_k, ..., b_0, x^n can be
|
||||
// computed by defining the sequence r_k+1, ..., r_0, where:
|
||||
//
|
||||
// r_k+1 = 1
|
||||
// r_i = (r_i+1)^2 * x^b_i for i = k, ..., 0
|
||||
//
|
||||
// The final value r_0 = x^n.
|
||||
//
|
||||
// See https://en.wikipedia.org/wiki/Exponentiation_by_squaring for more
|
||||
// details.
|
||||
//
|
||||
// This can be further optimized, by observing that the value of Q in
|
||||
// secp256k1 has the value:
|
||||
//
|
||||
// Q = 3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c
|
||||
//
|
||||
// We can unroll the typical bit-wise interpretation of the
|
||||
// exponentiation algorithm above to instead operate on bytes.
|
||||
// This reduces the number of comparisons by an order of magnitude,
|
||||
// reducing the overhead of failed branch predictions and additional
|
||||
// comparisons in this method.
|
||||
//
|
||||
// Since there there are only 4 unique bytes of Q, this keeps the jump
|
||||
// table small without the need to handle all possible 8-bit values.
|
||||
// Further, we observe that 29 of the 32 bytes are 0xff; making the
|
||||
// first case handle 0xff therefore optimizes the hot path.
|
||||
f.SetInt(1)
|
||||
for _, b := range fieldQBytes {
|
||||
switch b {
|
||||
|
||||
// Most common case, where all 8 bits are set.
|
||||
case 0xff:
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
|
||||
// First byte of Q (0x3f), where all but the top two bits are
|
||||
// set. Note that this case only applies six operations, since
|
||||
// the highest bit of Q resides in bit six of the first byte. We
|
||||
// ignore the first two bits, since squaring for these bits will
|
||||
// result in an invalid result. We forgo squaring f before the
|
||||
// first multiply, since 1^2 = 1.
|
||||
case 0x3f:
|
||||
f.Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
|
||||
// Byte 28 of Q (0xbf), where only bit 7 is unset.
|
||||
case 0xbf:
|
||||
f.Square().Mul(x)
|
||||
f.Square()
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
|
||||
// Byte 31 of Q (0x0c), where only bits 3 and 4 are set.
|
||||
default:
|
||||
f.Square()
|
||||
f.Square()
|
||||
f.Square()
|
||||
f.Square()
|
||||
f.Square().Mul(x)
|
||||
f.Square().Mul(x)
|
||||
f.Square()
|
||||
f.Square()
|
||||
}
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
// Sqrt computes the square root of f modulo the curve's prime, and stores the
|
||||
// result in f. The square root is computed via exponentiation of x by the value
|
||||
// Q = (P+1)/4 using the curve's precomputed big-endian representation of the Q.
|
||||
// This method uses a modified version of square-and-multiply exponentiation
|
||||
// over secp256k1 fieldVals to operate on bytes instead of bits, which offers
|
||||
// better performance over both big.Int exponentiation and bit-wise
|
||||
// square-and-multiply.
|
||||
//
|
||||
// NOTE: This method only works when P is intended to be the secp256k1 prime and
|
||||
// is not constant time. The returned value is of magnitude 1, but is
|
||||
// denormalized.
|
||||
func (f *fieldVal) Sqrt() *fieldVal {
|
||||
return f.SqrtVal(f)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,8 @@
|
|||
package btcec
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
@ -820,3 +822,146 @@ func TestInverse(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// randFieldVal returns a random, normalized element in the field.
|
||||
func randFieldVal(t *testing.T) fieldVal {
|
||||
var b [32]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
t.Fatalf("unable to create random element: %v", err)
|
||||
}
|
||||
|
||||
var x fieldVal
|
||||
return *x.SetBytes(&b).Normalize()
|
||||
}
|
||||
|
||||
type sqrtTest struct {
|
||||
name string
|
||||
in string
|
||||
expected string
|
||||
}
|
||||
|
||||
// TestSqrt asserts that a fieldVal properly computes the square root modulo the
|
||||
// sep256k1 prime.
|
||||
func TestSqrt(t *testing.T) {
|
||||
var tests []sqrtTest
|
||||
|
||||
// No valid root exists for the negative of a square.
|
||||
for i := uint(9); i > 0; i-- {
|
||||
var (
|
||||
x fieldVal
|
||||
s fieldVal // x^2 mod p
|
||||
n fieldVal // -x^2 mod p
|
||||
)
|
||||
|
||||
x.SetInt(i)
|
||||
s.SquareVal(&x).Normalize()
|
||||
n.NegateVal(&s, 1).Normalize()
|
||||
|
||||
tests = append(tests, sqrtTest{
|
||||
name: fmt.Sprintf("-%d", i),
|
||||
in: fmt.Sprintf("%x", *n.Bytes()),
|
||||
})
|
||||
}
|
||||
|
||||
// A root should exist for true squares.
|
||||
for i := uint(0); i < 10; i++ {
|
||||
var (
|
||||
x fieldVal
|
||||
s fieldVal // x^2 mod p
|
||||
)
|
||||
|
||||
x.SetInt(i)
|
||||
s.SquareVal(&x).Normalize()
|
||||
|
||||
tests = append(tests, sqrtTest{
|
||||
name: fmt.Sprintf("%d", i),
|
||||
in: fmt.Sprintf("%x", *s.Bytes()),
|
||||
expected: fmt.Sprintf("%x", *x.Bytes()),
|
||||
})
|
||||
}
|
||||
|
||||
// Compute a non-square element, by negating if it has a root.
|
||||
ns := randFieldVal(t)
|
||||
if new(fieldVal).SqrtVal(&ns).Square().Equals(&ns) {
|
||||
ns.Negate(1).Normalize()
|
||||
}
|
||||
|
||||
// For large random field values, test that:
|
||||
// 1) its square has a valid root.
|
||||
// 2) the negative of its square has no root.
|
||||
// 3) the product of its square with a non-square has no root.
|
||||
for i := 0; i < 10; i++ {
|
||||
var (
|
||||
x fieldVal
|
||||
s fieldVal // x^2 mod p
|
||||
n fieldVal // -x^2 mod p
|
||||
m fieldVal // ns*x^2 mod p
|
||||
)
|
||||
|
||||
x = randFieldVal(t)
|
||||
s.SquareVal(&x).Normalize()
|
||||
n.NegateVal(&s, 1).Normalize()
|
||||
m.Mul2(&s, &ns).Normalize()
|
||||
|
||||
// A root should exist for true squares.
|
||||
tests = append(tests, sqrtTest{
|
||||
name: fmt.Sprintf("%x", *s.Bytes()),
|
||||
in: fmt.Sprintf("%x", *s.Bytes()),
|
||||
expected: fmt.Sprintf("%x", *x.Bytes()),
|
||||
})
|
||||
|
||||
// No valid root exists for the negative of a square.
|
||||
tests = append(tests, sqrtTest{
|
||||
name: fmt.Sprintf("-%x", *s.Bytes()),
|
||||
in: fmt.Sprintf("%x", *n.Bytes()),
|
||||
})
|
||||
|
||||
// No root should be computed for product of a square and
|
||||
// non-square.
|
||||
tests = append(tests, sqrtTest{
|
||||
name: fmt.Sprintf("ns*%x", *s.Bytes()),
|
||||
in: fmt.Sprintf("%x", *m.Bytes()),
|
||||
})
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
testSqrt(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testSqrt(t *testing.T, test sqrtTest) {
|
||||
var (
|
||||
f fieldVal
|
||||
root fieldVal
|
||||
rootNeg fieldVal
|
||||
)
|
||||
|
||||
f.SetHex(test.in).Normalize()
|
||||
|
||||
// Compute sqrt(f) and its negative.
|
||||
root.SqrtVal(&f).Normalize()
|
||||
rootNeg.NegateVal(&root, 1).Normalize()
|
||||
|
||||
switch {
|
||||
|
||||
// If we expect a square root, verify that either the computed square
|
||||
// root is +/- the expected value.
|
||||
case len(test.expected) > 0:
|
||||
var expected fieldVal
|
||||
expected.SetHex(test.expected).Normalize()
|
||||
if !root.Equals(&expected) && !rootNeg.Equals(&expected) {
|
||||
t.Fatalf("fieldVal.Sqrt incorrect root\n"+
|
||||
"got: %v\ngot_neg: %v\nwant: %v",
|
||||
root, rootNeg, expected)
|
||||
}
|
||||
|
||||
// Otherwise, we expect this input not to have a square root.
|
||||
default:
|
||||
if root.Square().Equals(&f) || rootNeg.Square().Equals(&f) {
|
||||
t.Fatalf("fieldVal.Sqrt root should not exist\n"+
|
||||
"got: %v\ngot_neg: %v", root, rootNeg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,41 +22,41 @@ func isOdd(a *big.Int) bool {
|
|||
return a.Bit(0) == 1
|
||||
}
|
||||
|
||||
// decompressPoint decompresses a point on the given curve given the X point and
|
||||
// decompressPoint decompresses a point on the secp256k1 curve given the X point and
|
||||
// the solution to use.
|
||||
func decompressPoint(curve *KoblitzCurve, x *big.Int, ybit bool) (*big.Int, error) {
|
||||
// TODO: This will probably only work for secp256k1 due to
|
||||
// optimizations.
|
||||
func decompressPoint(curve *KoblitzCurve, bigX *big.Int, ybit bool) (*big.Int, error) {
|
||||
var x fieldVal
|
||||
x.SetByteSlice(bigX.Bytes())
|
||||
|
||||
// Y = +-sqrt(x^3 + B)
|
||||
x3 := new(big.Int).Mul(x, x)
|
||||
x3.Mul(x3, x)
|
||||
x3.Add(x3, curve.Params().B)
|
||||
x3.Mod(x3, curve.Params().P)
|
||||
// Compute x^3 + B mod p.
|
||||
var x3 fieldVal
|
||||
x3.SquareVal(&x).Mul(&x)
|
||||
x3.Add(curve.fieldB).Normalize()
|
||||
|
||||
// Now calculate sqrt mod p of x^3 + B
|
||||
// This code used to do a full sqrt based on tonelli/shanks,
|
||||
// but this was replaced by the algorithms referenced in
|
||||
// https://bitcointalk.org/index.php?topic=162805.msg1712294#msg1712294
|
||||
y := new(big.Int).Exp(x3, curve.QPlus1Div4(), curve.Params().P)
|
||||
|
||||
if ybit != isOdd(y) {
|
||||
y.Sub(curve.Params().P, y)
|
||||
var y fieldVal
|
||||
y.SqrtVal(&x3)
|
||||
if ybit != y.IsOdd() {
|
||||
y.Negate(1)
|
||||
}
|
||||
y.Normalize()
|
||||
|
||||
// Check that y is a square root of x^3 + B.
|
||||
y2 := new(big.Int).Mul(y, y)
|
||||
y2.Mod(y2, curve.Params().P)
|
||||
if y2.Cmp(x3) != 0 {
|
||||
var y2 fieldVal
|
||||
y2.SquareVal(&y).Normalize()
|
||||
if !y2.Equals(&x3) {
|
||||
return nil, fmt.Errorf("invalid square root")
|
||||
}
|
||||
|
||||
// Verify that y-coord has expected parity.
|
||||
if ybit != isOdd(y) {
|
||||
if ybit != y.IsOdd() {
|
||||
return nil, fmt.Errorf("ybit doesn't match oddness")
|
||||
}
|
||||
|
||||
return y, nil
|
||||
return new(big.Int).SetBytes(y.Bytes()[:]), nil
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -102,6 +102,17 @@ func ParsePubKey(pubKeyStr []byte, curve *KoblitzCurve) (key *PublicKey, err err
|
|||
if format == pubkeyHybrid && ybit != isOdd(pubkey.Y) {
|
||||
return nil, fmt.Errorf("ybit doesn't match oddness")
|
||||
}
|
||||
|
||||
if pubkey.X.Cmp(pubkey.Curve.Params().P) >= 0 {
|
||||
return nil, fmt.Errorf("pubkey X parameter is >= to P")
|
||||
}
|
||||
if pubkey.Y.Cmp(pubkey.Curve.Params().P) >= 0 {
|
||||
return nil, fmt.Errorf("pubkey Y parameter is >= to P")
|
||||
}
|
||||
if !pubkey.Curve.IsOnCurve(pubkey.X, pubkey.Y) {
|
||||
return nil, fmt.Errorf("pubkey isn't on secp256k1 curve")
|
||||
}
|
||||
|
||||
case PubKeyBytesLenCompressed:
|
||||
// format is 0x2 | solution, <X coordinate>
|
||||
// solution determines which solution of the curve we use.
|
||||
|
@ -115,20 +126,12 @@ func ParsePubKey(pubKeyStr []byte, curve *KoblitzCurve) (key *PublicKey, err err
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
default: // wrong!
|
||||
return nil, fmt.Errorf("invalid pub key length %d",
|
||||
len(pubKeyStr))
|
||||
}
|
||||
|
||||
if pubkey.X.Cmp(pubkey.Curve.Params().P) >= 0 {
|
||||
return nil, fmt.Errorf("pubkey X parameter is >= to P")
|
||||
}
|
||||
if pubkey.Y.Cmp(pubkey.Curve.Params().P) >= 0 {
|
||||
return nil, fmt.Errorf("pubkey Y parameter is >= to P")
|
||||
}
|
||||
if !pubkey.Curve.IsOnCurve(pubkey.X, pubkey.Y) {
|
||||
return nil, fmt.Errorf("pubkey isn't on secp256k1 curve")
|
||||
}
|
||||
return &pubkey, nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue