From 1203c6a237cb87b78ec495772cecb178200499ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Szil=C3=A1gyi?= Date: Tue, 20 Mar 2018 01:13:54 +0900 Subject: crypto/bn256: full switchover to cloudflare's code (#16301) * crypto/bn256: full switchover to cloudflare's code * crypto/bn256: only use cloudflare for optimized architectures * crypto/bn256: upstream fallback for non-optimized code * .travis, build: drop support for Go 1.8 (need type aliases) * crypto/bn256/cloudflare: enable curve mul lattice optimization --- crypto/bn256/cloudflare/lattice.go | 115 +++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 crypto/bn256/cloudflare/lattice.go (limited to 'crypto/bn256/cloudflare/lattice.go') diff --git a/crypto/bn256/cloudflare/lattice.go b/crypto/bn256/cloudflare/lattice.go new file mode 100644 index 000000000..f9ace4d9f --- /dev/null +++ b/crypto/bn256/cloudflare/lattice.go @@ -0,0 +1,115 @@ +package bn256 + +import ( + "math/big" +) + +var half = new(big.Int).Rsh(Order, 1) + +var curveLattice = &lattice{ + vectors: [][]*big.Int{ + {bigFromBase10("147946756881789319000765030803803410728"), bigFromBase10("147946756881789319010696353538189108491")}, + {bigFromBase10("147946756881789319020627676272574806254"), bigFromBase10("-147946756881789318990833708069417712965")}, + }, + inverse: []*big.Int{ + bigFromBase10("147946756881789318990833708069417712965"), + bigFromBase10("147946756881789319010696353538189108491"), + }, + det: bigFromBase10("43776485743678550444492811490514550177096728800832068687396408373151616991234"), +} + +var targetLattice = &lattice{ + vectors: [][]*big.Int{ + {bigFromBase10("9931322734385697761"), bigFromBase10("9931322734385697761"), bigFromBase10("9931322734385697763"), bigFromBase10("9931322734385697764")}, + {bigFromBase10("4965661367192848881"), bigFromBase10("4965661367192848881"), bigFromBase10("4965661367192848882"), bigFromBase10("-9931322734385697762")}, + {bigFromBase10("-9931322734385697762"), bigFromBase10("-4965661367192848881"), bigFromBase10("4965661367192848881"), bigFromBase10("-4965661367192848882")}, + {bigFromBase10("9931322734385697763"), bigFromBase10("-4965661367192848881"), bigFromBase10("-4965661367192848881"), bigFromBase10("-4965661367192848881")}, + }, + inverse: []*big.Int{ + bigFromBase10("734653495049373973658254490726798021314063399421879442165"), + bigFromBase10("147946756881789319000765030803803410728"), + bigFromBase10("-147946756881789319005730692170996259609"), + bigFromBase10("1469306990098747947464455738335385361643788813749140841702"), + }, + det: new(big.Int).Set(Order), +} + +type lattice struct { + vectors [][]*big.Int + inverse []*big.Int + det *big.Int +} + +// decompose takes a scalar mod Order as input and finds a short, positive decomposition of it wrt to the lattice basis. +func (l *lattice) decompose(k *big.Int) []*big.Int { + n := len(l.inverse) + + // Calculate closest vector in lattice to with Babai's rounding. + c := make([]*big.Int, n) + for i := 0; i < n; i++ { + c[i] = new(big.Int).Mul(k, l.inverse[i]) + round(c[i], l.det) + } + + // Transform vectors according to c and subtract . + out := make([]*big.Int, n) + temp := new(big.Int) + + for i := 0; i < n; i++ { + out[i] = new(big.Int) + + for j := 0; j < n; j++ { + temp.Mul(c[j], l.vectors[j][i]) + out[i].Add(out[i], temp) + } + + out[i].Neg(out[i]) + out[i].Add(out[i], l.vectors[0][i]).Add(out[i], l.vectors[0][i]) + } + out[0].Add(out[0], k) + + return out +} + +func (l *lattice) Precompute(add func(i, j uint)) { + n := uint(len(l.vectors)) + total := uint(1) << n + + for i := uint(0); i < n; i++ { + for j := uint(0); j < total; j++ { + if (j>>i)&1 == 1 { + add(i, j) + } + } + } +} + +func (l *lattice) Multi(scalar *big.Int) []uint8 { + decomp := l.decompose(scalar) + + maxLen := 0 + for _, x := range decomp { + if x.BitLen() > maxLen { + maxLen = x.BitLen() + } + } + + out := make([]uint8, maxLen) + for j, x := range decomp { + for i := 0; i < maxLen; i++ { + out[i] += uint8(x.Bit(i)) << uint(j) + } + } + + return out +} + +// round sets num to num/denom rounded to the nearest integer. +func round(num, denom *big.Int) { + r := new(big.Int) + num.DivMod(num, denom, r) + + if r.Cmp(half) == 1 { + num.Add(num, big.NewInt(1)) + } +} -- cgit v1.2.3