aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/dexon-foundation/mcl/src/gen.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/dexon-foundation/mcl/src/gen.cpp')
-rw-r--r--vendor/github.com/dexon-foundation/mcl/src/gen.cpp999
1 files changed, 999 insertions, 0 deletions
diff --git a/vendor/github.com/dexon-foundation/mcl/src/gen.cpp b/vendor/github.com/dexon-foundation/mcl/src/gen.cpp
new file mode 100644
index 000000000..763f64b98
--- /dev/null
+++ b/vendor/github.com/dexon-foundation/mcl/src/gen.cpp
@@ -0,0 +1,999 @@
+#include "llvm_gen.hpp"
+#include <cybozu/option.hpp>
+#include <mcl/op.hpp>
+#include <map>
+#include <set>
+#include <fstream>
+
+typedef std::set<std::string> StrSet;
+
+struct Code : public mcl::Generator {
+ typedef std::map<int, Function> FunctionMap;
+ typedef std::vector<Operand> OperandVec;
+ Operand Void;
+ uint32_t unit;
+ uint32_t unit2;
+ uint32_t bit;
+ uint32_t N;
+ const StrSet *privateFuncList;
+ bool wasm;
+ std::string suf;
+ std::string unitStr;
+ Function mulUU;
+ Function mul32x32; // for WASM
+ Function extractHigh;
+ Function mulPos;
+ Function makeNIST_P192;
+ Function mcl_fpDbl_mod_NIST_P192;
+ Function mcl_fp_sqr_NIST_P192;
+ FunctionMap mcl_fp_shr1_M;
+ FunctionMap mcl_fp_addPreM;
+ FunctionMap mcl_fp_subPreM;
+ FunctionMap mcl_fp_addM;
+ FunctionMap mcl_fp_subM;
+ FunctionMap mulPvM;
+ FunctionMap mcl_fp_mulUnitPreM;
+ FunctionMap mcl_fpDbl_mulPreM;
+ FunctionMap mcl_fpDbl_sqrPreM;
+ FunctionMap mcl_fp_montM;
+ FunctionMap mcl_fp_montRedM;
+ // All size parameters start at zero; setUnit()/setBit() fill in unit/unit2/unitStr and bit/N,
+ // gen() sets suf and privateFuncList, and the caller assigns wasm directly.
+ Code() : unit(0), unit2(0), bit(0), N(0), privateFuncList(0), wasm(false) { }
+ // Mark f as private when a private-function list has been supplied
+ // and that list contains f.name; otherwise leave f untouched.
+ void verifyAndSetPrivate(Function& f)
+ {
+     if (privateFuncList == 0) {
+         return;
+     }
+     if (privateFuncList->count(f.name) == 0) {
+         return;
+     }
+     f.setPrivate();
+ }
+ /*
+     Store the integer r to memory at p[offset], one unit-bit word at a time,
+     least significant word first.
+     p      : pointer to unit-bit words (throws cybozu::Exception otherwise)
+     offset : word offset applied to p; only positive offsets are applied
+     r.bit / unit words are written; when r.bit == unit a single store is used.
+ */
+ void storeN(Operand r, Operand p, int offset = 0)
+ {
+     if (p.bit != unit) {
+         throw cybozu::Exception("bad IntPtr size") << p.bit;
+     }
+     if (offset > 0) {
+         p = getelementptr(p, offset);
+     }
+     if (r.bit == unit) {
+         store(r, p);
+         return;
+     }
+     const size_t n = r.bit / unit;
+     for (size_t i = 0; i < n; i++) {
+         // Sequence the two generator calls explicitly. When both are written as
+         // arguments of store(...), C++ leaves their evaluation order unspecified,
+         // so the order of the generator calls would depend on the compiler.
+         Operand word = trunc(r, unit);
+         Operand dst = getelementptr(p, i);
+         store(word, dst);
+         if (i < n - 1) {
+             // drop the word just stored; skipped after the last word
+             r = lshr(r, unit);
+         }
+     }
+ }
+ /*
+     Load n unit-bit words starting at p[offset] and combine them into one
+     integer of n * unit bits, with p[offset] as the least significant word.
+     p must point to unit-bit words (throws cybozu::Exception otherwise);
+     offset is applied only when it is positive.
+ */
+ Operand loadN(Operand p, size_t n, int offset = 0)
+ {
+     if (p.bit != unit) {
+         throw cybozu::Exception("bad IntPtr size") << p.bit;
+     }
+     if (offset > 0) {
+         p = getelementptr(p, offset);
+     }
+     Operand acc = load(p);
+     for (size_t idx = 1; idx < n; idx++) {
+         // grow the accumulator by one word, then place word idx above it
+         acc = zext(acc, acc.bit + unit);
+         Operand limb = load(getelementptr(p, idx));
+         limb = zext(limb, acc.bit);
+         limb = shl(limb, unit * idx);
+         acc = _or(acc, limb);
+     }
+     return acc;
+ }
+ // Emit the private helper "mul32x32L": z(64 bit) = x(32 bit) * y(32 bit).
+ // gen_mul64x64() builds its 64x64 product out of calls to this function.
+ void gen_mul32x32()
+ {
+ const int u = 32;
+ resetGlobalIdx();
+ Operand z(Int, u * 2);
+ Operand x(Int, u);
+ Operand y(Int, u);
+ mul32x32 = Function("mul32x32L", z, x, y);
+ mul32x32.setPrivate();
+ verifyAndSetPrivate(mul32x32);
+ beginFunc(mul32x32);
+
+ // widen both inputs first so the multiplication keeps all 64 result bits
+ x = zext(x, u * 2);
+ y = zext(y, u * 2);
+ z = mul(x, y);
+ ret(z);
+ endFunc();
+ }
+ /*
+     Emit z(128 bit) = x(64 bit) * y(64 bit) using only mul32x32 calls.
+     With x = [a:b] and y = [c:d] (32-bit halves, high first):
+       x * d = (a*d << 32) + b*d   (96 bit)
+       x * c = (a*c << 32) + b*c   (96 bit)
+       x * y = (x*c << 32) + x*d   (128 bit)
+     The result is written to z; x and y are only read.
+ */
+ void gen_mul64x64(Operand& z, Operand& x, Operand& y)
+ {
+ Operand a = trunc(lshr(x, 32), 32);
+ Operand b = trunc(x, 32);
+ Operand c = trunc(lshr(y, 32), 32);
+ Operand d = trunc(y, 32);
+ // ad <- x * d
+ Operand ad = call(mul32x32, a, d);
+ Operand bd = call(mul32x32, b, d);
+ bd = zext(bd, 96);
+ ad = shl(zext(ad, 96), 32);
+ ad = add(ad, bd);
+ // ac <- x * c
+ Operand ac = call(mul32x32, a, c);
+ Operand bc = call(mul32x32, b, c);
+ bc = zext(bc, 96);
+ ac = shl(zext(ac, 96), 32);
+ ac = add(ac, bc);
+ // z <- (x * c << 32) + x * d
+ ad = zext(ad, 128);
+ ac = shl(zext(ac, 128), 32);
+ z = add(ac, ad);
+ }
+ /*
+     Emit the private helper "mul<unit>x<unit>L": z(unit2 bit) = x(unit bit) * y(unit bit).
+     When wasm is set, mul32x32L is emitted first and the product is composed by
+     gen_mul64x64(); otherwise a single widened multiplication is used.
+ */
+ void gen_mulUU()
+ {
+ if (wasm) {
+ gen_mul32x32();
+ }
+ resetGlobalIdx();
+ Operand z(Int, unit2);
+ Operand x(Int, unit);
+ Operand y(Int, unit);
+ std::string name = "mul";
+ name += unitStr + "x" + unitStr + "L";
+ mulUU = Function(name, z, x, y);
+ mulUU.setPrivate();
+ verifyAndSetPrivate(mulUU);
+ beginFunc(mulUU);
+
+ if (wasm) {
+ // NOTE(review): gen_mul64x64 uses fixed 32/64/128-bit widths, so this path presumably expects unit == 64 - confirm
+ gen_mul64x64(z, x, y);
+ } else {
+ x = zext(x, unit2);
+ y = zext(y, unit2);
+ z = mul(x, y);
+ }
+ ret(z);
+ endFunc();
+ }
+ // Emit the private helper "extractHigh<unit>": returns the upper unit bits of a unit2-bit value.
+ void gen_extractHigh()
+ {
+ resetGlobalIdx();
+ Operand z(Int, unit);
+ Operand x(Int, unit2);
+ std::string name = "extractHigh";
+ name += unitStr;
+ extractHigh = Function(name, z, x);
+ extractHigh.setPrivate();
+ beginFunc(extractHigh);
+
+ x = lshr(x, unit);
+ z = trunc(x, unit);
+ ret(z);
+ endFunc();
+ }
+ // Emit the private helper "mulPos<unit>x<unit>": returns px[i] * y as a unit2-bit value
+ // (the multiplication itself is delegated to mulUU).
+ void gen_mulPos()
+ {
+ resetGlobalIdx();
+ Operand xy(Int, unit2);
+ Operand px(IntPtr, unit);
+ Operand y(Int, unit);
+ Operand i(Int, unit);
+ std::string name = "mulPos";
+ name += unitStr + "x" + unitStr;
+ mulPos = Function(name, xy, px, y, i);
+ mulPos.setPrivate();
+ beginFunc(mulPos);
+
+ Operand x = load(getelementptr(px, i));
+ xy = call(mulUU, x, y);
+ ret(xy);
+ endFunc();
+ }
+ // Return the 64 bits of x that start at bit position shift.
+ Operand extract192to64(const Operand& x, uint32_t shift)
+ {
+     return trunc(lshr(x, shift), 64);
+ }
+ /*
+     Emit "makeNIST_P192L<suf>": a function with no arguments that returns a 192-bit
+     constant assembled from three 64-bit words
+       word0 = 0 - 1, word1 = 0 - 2, word2 = 0 - 1
+     i.e. p = word0 + (word1 << 64) + (word2 << 128).
+ */
+ void gen_makeNIST_P192()
+ {
+ resetGlobalIdx();
+ Operand p(Int, 192);
+ Operand p0(Int, 64);
+ Operand p1(Int, 64);
+ Operand p2(Int, 64);
+ Operand _0 = makeImm(64, 0);
+ Operand _1 = makeImm(64, 1);
+ Operand _2 = makeImm(64, 2);
+ makeNIST_P192 = Function("makeNIST_P192L" + suf, p);
+ verifyAndSetPrivate(makeNIST_P192);
+ beginFunc(makeNIST_P192);
+ // build the three 64-bit words by subtracting from zero
+ p0 = sub(_0, _1);
+ p1 = sub(_0, _2);
+ p2 = sub(_0, _1);
+ p0 = zext(p0, 192);
+ p1 = zext(p1, 192);
+ p2 = zext(p2, 192);
+ // move word1 and word2 to their positions and sum the three parts
+ p1 = shl(p1, 64);
+ p2 = shl(p2, 128);
+ p = add(p0, p1);
+ p = add(p, p2);
+ ret(p);
+ endFunc();
+ }
+ /*
+ NIST_P192
+ p = 0xfffffffffffffffffffffffffffffffeffffffffffffffff
+ 0 1 2
+ ffffffffffffffff fffffffffffffffe ffffffffffffffff
+
+ p = (1 << 192) - (1 << 64) - 1
+ (1 << 192) % p = (1 << 64) + 1
+ L : 192bit
+ Hi: 64bit
+ x = [H:L] = [H2:H1:H0:L]
+ mod p
+ x = L + H + (H << 64)
+ = L + H + [H1:H0:0] + H2 + (H2 << 64)
+ [e:t] = L + H + [H1:H0:H2] + [H2:0] ; 2bit(e) over
+ y = t + e + (e << 64)
+ if (y >= p) y -= p
+ */
+ /*
+     Emit "mcl_fpDbl_mod_NIST_P192L<suf>"(out, px):
+     reads a 384-bit value from px (two 192-bit halves L and H), folds H into L,
+     folds the overflow once more, conditionally subtracts the value returned by
+     makeNIST_P192 and stores the 192-bit result to out.
+     All intermediate sums are done in 256 bits.
+ */
+ void gen_mcl_fpDbl_mod_NIST_P192()
+ {
+ resetGlobalIdx();
+ Operand out(IntPtr, unit);
+ Operand px(IntPtr, unit);
+ mcl_fpDbl_mod_NIST_P192 = Function("mcl_fpDbl_mod_NIST_P192L" + suf, Void, out, px);
+ verifyAndSetPrivate(mcl_fpDbl_mod_NIST_P192);
+ beginFunc(mcl_fpDbl_mod_NIST_P192);
+
+ // n = number of unit-bit words in 192 bits
+ const int n = 192 / unit;
+ Operand L = loadN(px, n);
+ L = zext(L, 256);
+
+ // upper 192 bits of the input
+ Operand H192 = loadN(px, n, n);
+ Operand H = zext(H192, 256);
+
+ // [H1:H0:0] : the 192-bit shift drops H2
+ Operand H10 = shl(H192, 64);
+ H10 = zext(H10, 256);
+
+ // H102 = [H1:H0:H2]
+ Operand H2 = extract192to64(H192, 128);
+ H2 = zext(H2, 256);
+ Operand H102 = _or(H10, H2);
+
+ // H2 <- [H2:0]
+ H2 = shl(H2, 64);
+
+ // t = L + H + [H1:H0:H2] + [H2:0]
+ Operand t = add(L, H);
+ t = add(t, H102);
+ t = add(t, H2);
+
+ // e = bits of t above bit 192; e <- e + (e << 64) (built with OR)
+ Operand e = lshr(t, 192);
+ e = trunc(e, 64);
+ e = zext(e, 256);
+ Operand e2 = shl(e, 64);
+ e = _or(e, e2);
+
+ // keep only the low 192 bits of t
+ t = trunc(t, 192);
+ t = zext(t, 256);
+
+ Operand z = add(t, e);
+ Operand p = call(makeNIST_P192);
+ p = zext(p, 256);
+ Operand zp = sub(z, p);
+ // c = bit 192 of z - p; when set, z is kept, otherwise z - p is used
+ Operand c = trunc(lshr(zp, 192), 1);
+ z = trunc(select(c, z, zp), 192);
+ storeN(z, out);
+ ret(Void);
+ endFunc();
+ }
+ /*
+ NIST_P521
+ p = (1 << 521) - 1
+ x = [H:L]
+ x % p = (L + H) % p
+ */
+ /*
+     Emit "mcl_fpDbl_mod_NIST_P521L<suf>"(py, px):
+     reads 2n+1 words from px, adds the low 521 bits to the bits above them,
+     adds the carry out of bit 521 back in, and stores n+1 words to py.
+     When all 521 result bits are one, zero is stored instead.
+ */
+ void gen_mcl_fpDbl_mod_NIST_P521()
+ {
+ resetGlobalIdx();
+ const uint32_t len = 521;
+ // n : number of whole words below 521 bits, round : (n + 1) words in bits,
+ // rem : number of bits used in the top word
+ const uint32_t n = len / unit;
+ const uint32_t round = unit * (n + 1);
+ const uint32_t rem = len - n * unit;
+ // every bit at position >= rem is set (negative int converted to size_t)
+ const size_t mask = -(1 << rem);
+ const Operand py(IntPtr, unit);
+ const Operand px(IntPtr, unit);
+ Function f("mcl_fpDbl_mod_NIST_P521L" + suf, Void, py, px);
+ verifyAndSetPrivate(f);
+ beginFunc(f);
+ Operand x = loadN(px, n * 2 + 1);
+ Operand L = trunc(x, len);
+ L = zext(L, round);
+ Operand H = lshr(x, len);
+ H = trunc(H, round); // x = [H:L]
+ Operand t = add(L, H);
+ // t0 = bit 521 of the sum; add it back and keep 521 bits
+ Operand t0 = lshr(t, len);
+ t0 = _and(t0, makeImm(round, 1));
+ t = add(t, t0);
+ t = trunc(t, len);
+ Operand z0 = zext(t, round);
+ // m becomes all ones only if the rem low bits of the top word and
+ // every lower word are all ones
+ t = extract(z0, n * unit);
+ Operand m = _or(t, makeImm(unit, mask));
+ for (uint32_t i = 0; i < n; i++) {
+ Operand s = extract(z0, unit * i);
+ m = _and(m, s);
+ }
+ Operand c = icmp(eq, m, makeImm(unit, -1));
+ Label zero("zero");
+ Label nonzero("nonzero");
+ br(c, zero, nonzero);
+ // all 521 bits set : store n + 1 zero words
+ putLabel(zero);
+ for (uint32_t i = 0; i < n + 1; i++) {
+ storeN(makeImm(unit, 0), py, i);
+ }
+ ret(Void);
+ // otherwise store the folded value as is
+ putLabel(nonzero);
+ storeN(z0, py);
+ ret(Void);
+ endFunc();
+ }
+ /*
+     Emit "mcl_fp_sqr_NIST_P192L<suf>"(py, px):
+     calls mcl_fpDbl_sqrPre<192/unit>L<suf> into a 384-bit stack buffer and then
+     reduces the buffer into py with mcl_fpDbl_mod_NIST_P192.
+ */
+ void gen_mcl_fp_sqr_NIST_P192()
+ {
+ resetGlobalIdx();
+ Operand py(IntPtr, unit);
+ Operand px(IntPtr, unit);
+ mcl_fp_sqr_NIST_P192 = Function("mcl_fp_sqr_NIST_P192L" + suf, Void, py, px);
+ verifyAndSetPrivate(mcl_fp_sqr_NIST_P192);
+ beginFunc(mcl_fp_sqr_NIST_P192);
+ // buffer of 192 * 2 / unit words for the double-width square
+ Operand buf = _alloca(unit, 192 * 2 / unit);
+ // QQQ define later
+ // the callee is referenced by name only here; it is emitted later by gen_mcl_fpDbl_sqrPre()
+ Function mcl_fpDbl_sqrPre("mcl_fpDbl_sqrPre" + cybozu::itoa(192 / unit) + "L" + suf, Void, buf, px);
+ call(mcl_fpDbl_sqrPre, buf, px);
+ call(mcl_fpDbl_mod_NIST_P192, py, buf);
+ ret(Void);
+ endFunc();
+ }
+ /*
+     Emit "mcl_fp_mulNIST_P192L<suf>"(pz, px, py):
+     calls mcl_fpDbl_mulPre<192/unit>L<suf> into a 384-bit stack buffer and then
+     reduces the buffer into pz with mcl_fpDbl_mod_NIST_P192.
+ */
+ void gen_mcl_fp_mulNIST_P192()
+ {
+ resetGlobalIdx();
+ Operand pz(IntPtr, unit);
+ Operand px(IntPtr, unit);
+ Operand py(IntPtr, unit);
+ Function f("mcl_fp_mulNIST_P192L" + suf, Void, pz, px, py);
+ verifyAndSetPrivate(f);
+ beginFunc(f);
+ // buffer of 192 * 2 / unit words for the double-width product
+ Operand buf = _alloca(unit, 192 * 2 / unit);
+ // QQQ define later
+ // the callee is referenced by name only here; it is emitted later by gen_mcl_fpDbl_mulPre()
+ Function mcl_fpDbl_mulPre("mcl_fpDbl_mulPre" + cybozu::itoa(192 / unit) + "L" + suf, Void, buf, px, py);
+ call(mcl_fpDbl_mulPre, buf, px, py);
+ call(mcl_fpDbl_mod_NIST_P192, pz, buf);
+ ret(Void);
+ endFunc();
+ }
+ // Emit the functions that do not depend on bit/N: the word-multiply helpers and the
+ // NIST P192/P521 routines. The helpers come first because later generators call them
+ // (e.g. gen_mulPos uses mulUU, the P192 routines use makeNIST_P192).
+ void gen_once()
+ {
+ gen_mulUU();
+ gen_extractHigh();
+ gen_mulPos();
+ gen_makeNIST_P192();
+ gen_mcl_fpDbl_mod_NIST_P192();
+ gen_mcl_fp_sqr_NIST_P192();
+ gen_mcl_fp_mulNIST_P192();
+ gen_mcl_fpDbl_mod_NIST_P521();
+ }
+ // Return the unit-bit word of x that starts at bit position shift.
+ Operand extract(const Operand& x, uint32_t shift)
+ {
+     return trunc(lshr(x, shift), unit);
+ }
+ /*
+     Emit "mcl_fp_addPre<N>L<suf>" (isAdd) or "mcl_fp_subPre<N>L<suf>" (!isAdd):
+       r = f(pz, px, py)
+     pz[0..N) receives the low bit bits of x + y (or x - y); the returned unit-bit
+     value r holds the carry (add) or the borrow masked to 0/1 (sub).
+     The generated Function is recorded in mcl_fp_addPreM[N] / mcl_fp_subPreM[N].
+ */
+ void gen_mcl_fp_addsubPre(bool isAdd)
+ {
+ resetGlobalIdx();
+ Operand r(Int, unit);
+ Operand pz(IntPtr, unit);
+ Operand px(IntPtr, unit);
+ Operand py(IntPtr, unit);
+ std::string name;
+ if (isAdd) {
+ name = "mcl_fp_addPre" + cybozu::itoa(N) + "L" + suf;
+ mcl_fp_addPreM[N] = Function(name, r, pz, px, py);
+ verifyAndSetPrivate(mcl_fp_addPreM[N]);
+ beginFunc(mcl_fp_addPreM[N]);
+ } else {
+ name = "mcl_fp_subPre" + cybozu::itoa(N) + "L" + suf;
+ mcl_fp_subPreM[N] = Function(name, r, pz, px, py);
+ verifyAndSetPrivate(mcl_fp_subPreM[N]);
+ beginFunc(mcl_fp_subPreM[N]);
+ }
+ // one extra word of headroom so the carry/borrow is visible above bit position bit
+ Operand x = zext(loadN(px, N), bit + unit);
+ Operand y = zext(loadN(py, N), bit + unit);
+ Operand z;
+ if (isAdd) {
+ z = add(x, y);
+ storeN(trunc(z, bit), pz);
+ r = trunc(lshr(z, bit), unit);
+ } else {
+ z = sub(x, y);
+ storeN(trunc(z, bit), pz);
+ // after a borrow every upper bit is set, so keep only the lowest one
+ r = _and(trunc(lshr(z, bit), unit), makeImm(unit, 1));
+ }
+ ret(r);
+ endFunc();
+ }
+#if 0 // void-return version
+ // Disabled alternative of gen_mcl_fp_addsubPre: operates on bit-wide pointers,
+ // keys the function maps by bit instead of N, and returns no carry/borrow.
+ void gen_mcl_fp_addsubPre(bool isAdd)
+ {
+ resetGlobalIdx();
+ Operand pz(IntPtr, bit);
+ Operand px(IntPtr, bit);
+ Operand py(IntPtr, bit);
+ std::string name;
+ if (isAdd) {
+ name = "mcl_fp_addPre" + cybozu::itoa(bit) + "L";
+ mcl_fp_addPreM[bit] = Function(name, Void, pz, px, py);
+ verifyAndSetPrivate(mcl_fp_addPreM[bit]);
+ beginFunc(mcl_fp_addPreM[bit]);
+ } else {
+ name = "mcl_fp_subPre" + cybozu::itoa(bit) + "L";
+ mcl_fp_subPreM[bit] = Function(name, Void, pz, px, py);
+ verifyAndSetPrivate(mcl_fp_subPreM[bit]);
+ beginFunc(mcl_fp_subPreM[bit]);
+ }
+ Operand x = load(px);
+ Operand y = load(py);
+ Operand z;
+ if (isAdd) {
+ z = add(x, y);
+ } else {
+ z = sub(x, y);
+ }
+ store(z, pz);
+ ret(Void);
+ endFunc();
+ }
+#endif
+ // Emit "mcl_fp_shr1_<N>L<suf>"(py, px): py[0..N) = px[0..N) >> 1 (logical shift).
+ // The generated Function is recorded in mcl_fp_shr1_M[N].
+ void gen_mcl_fp_shr1()
+ {
+ resetGlobalIdx();
+ Operand py(IntPtr, unit);
+ Operand px(IntPtr, unit);
+ std::string name = "mcl_fp_shr1_" + cybozu::itoa(N) + "L" + suf;
+ mcl_fp_shr1_M[N] = Function(name, Void, py, px);
+ verifyAndSetPrivate(mcl_fp_shr1_M[N]);
+ beginFunc(mcl_fp_shr1_M[N]);
+ Operand x = loadN(px, N);
+ x = lshr(x, 1);
+ storeN(x, py);
+ ret(Void);
+ endFunc();
+ }
+ /*
+     Emit "mcl_fp_add<N>L<suf>" (isFullBit) or "mcl_fp_addNF<N>L<suf>" (!isFullBit):
+       f(pz, px, py, pp) : pz = x + y, minus p when the subtraction does not borrow.
+     isFullBit : the sum is computed with one extra word and the borrow is taken
+                 from bit position bit; the result is chosen with a branch.
+     otherwise : everything stays bit bits wide and the top bit (bit - 1) of
+                 x + y - p is used as the borrow flag; the result is chosen with select.
+     Both variants are stored into mcl_fp_addM[N], so the later call overwrites the entry.
+ */
+ void gen_mcl_fp_add(bool isFullBit = true)
+ {
+ resetGlobalIdx();
+ Operand pz(IntPtr, unit);
+ Operand px(IntPtr, unit);
+ Operand py(IntPtr, unit);
+ Operand pp(IntPtr, unit);
+ std::string name = "mcl_fp_add";
+ if (!isFullBit) {
+ name += "NF";
+ }
+ name += cybozu::itoa(N) + "L" + suf;
+ mcl_fp_addM[N] = Function(name, Void, pz, px, py, pp);
+ verifyAndSetPrivate(mcl_fp_addM[N]);
+ beginFunc(mcl_fp_addM[N]);
+ Operand x = loadN(px, N);
+ Operand y = loadN(py, N);
+ if (isFullBit) {
+ x = zext(x, bit + unit);
+ y = zext(y, bit + unit);
+ Operand t0 = add(x, y);
+ Operand t1 = trunc(t0, bit);
+ // store x + y first; it is overwritten below when no borrow occurs
+ storeN(t1, pz);
+ Operand p = loadN(pp, N);
+ p = zext(p, bit + unit);
+ Operand vc = sub(t0, p);
+ // c = borrow of (x + y) - p
+ Operand c = lshr(vc, bit);
+ c = trunc(c, 1);
+ Label carry("carry");
+ Label nocarry("nocarry");
+ br(c, carry, nocarry);
+ putLabel(nocarry);
+ storeN(trunc(vc, bit), pz);
+ ret(Void);
+ // carry: pz already holds x + y; falls through to the common ret below
+ putLabel(carry);
+ } else {
+ x = add(x, y);
+ Operand p = loadN(pp, N);
+ y = sub(x, p);
+ // top bit of x + y - p set : keep x + y, else take x + y - p
+ Operand c = trunc(lshr(y, bit - 1), 1);
+ x = select(c, x, y);
+ storeN(x, pz);
+ }
+ ret(Void);
+ endFunc();
+ }
+ /*
+     Emit "mcl_fp_sub<N>L<suf>" (isFullBit) or "mcl_fp_subNF<N>L<suf>" (!isFullBit):
+       f(pz, px, py, pp) : pz = x - y, plus p when the subtraction borrows.
+     isFullBit : the difference is computed with one extra word and the borrow is
+                 taken from bit position bit; p is added in a separate branch.
+     otherwise : the top bit (bit - 1) of x - y selects between adding p and adding 0.
+     Both variants are stored into mcl_fp_subM[N], so the later call overwrites the entry.
+ */
+ void gen_mcl_fp_sub(bool isFullBit = true)
+ {
+ resetGlobalIdx();
+ Operand pz(IntPtr, unit);
+ Operand px(IntPtr, unit);
+ Operand py(IntPtr, unit);
+ Operand pp(IntPtr, unit);
+ std::string name = "mcl_fp_sub";
+ if (!isFullBit) {
+ name += "NF";
+ }
+ name += cybozu::itoa(N) + "L" + suf;
+ mcl_fp_subM[N] = Function(name, Void, pz, px, py, pp);
+ verifyAndSetPrivate(mcl_fp_subM[N]);
+ beginFunc(mcl_fp_subM[N]);
+ Operand x = loadN(px, N);
+ Operand y = loadN(py, N);
+ if (isFullBit) {
+ x = zext(x, bit + unit);
+ y = zext(y, bit + unit);
+ Operand vc = sub(x, y);
+ Operand v, c;
+ // v = low bit bits of x - y, c = borrow
+ v = trunc(vc, bit);
+ c = lshr(vc, bit);
+ c = trunc(c, 1);
+ // store x - y first; it is overwritten with x - y + p on borrow
+ storeN(v, pz);
+ Label carry("carry");
+ Label nocarry("nocarry");
+ br(c, carry, nocarry);
+ putLabel(nocarry);
+ ret(Void);
+ putLabel(carry);
+ Operand p = loadN(pp, N);
+ Operand t = add(v, p);
+ storeN(t, pz);
+ } else {
+ Operand v = sub(x, y);
+ Operand c;
+ c = trunc(lshr(v, bit - 1), 1);
+ Operand p = loadN(pp, N);
+ // c <- p when the top bit of x - y is set, otherwise 0
+ c = select(c, p, makeImm(bit, 0));
+ Operand t = add(v, c);
+ storeN(t, pz);
+ }
+ ret(Void);
+ endFunc();
+ }
+ /*
+     Emit "mcl_fpDbl_add<N>L<suf>"(pz, px, py, pp):
+     adds two 2N-word values; the low N words of the sum are stored unchanged and
+     the upper part H has p subtracted when H - p does not borrow.
+ */
+ void gen_mcl_fpDbl_add()
+ {
+ // QQQ : generate unnecessary memory copy for large bit
+ const int bu = bit + unit;
+ const int b2 = bit * 2;
+ const int b2u = b2 + unit;
+ resetGlobalIdx();
+ Operand pz(IntPtr, unit);
+ Operand px(IntPtr, unit);
+ Operand py(IntPtr, unit);
+ Operand pp(IntPtr, unit);
+ std::string name = "mcl_fpDbl_add" + cybozu::itoa(N) + "L" + suf;
+ Function f(name, Void, pz, px, py, pp);
+ verifyAndSetPrivate(f);
+ beginFunc(f);
+ Operand x = loadN(px, N * 2);
+ Operand y = loadN(py, N * 2);
+ x = zext(x, b2u);
+ y = zext(y, b2u);
+ Operand t = add(x, y); // x + y = [H:L]
+ Operand L = trunc(t, bit);
+ storeN(L, pz);
+
+ // H keeps one extra word so the carry of x + y is included
+ Operand H = lshr(t, bit);
+ H = trunc(H, bu);
+ Operand p = loadN(pp, N);
+ p = zext(p, bu);
+ Operand Hp = sub(H, p);
+ // borrow of H - p set : keep H, otherwise use H - p
+ t = lshr(Hp, bit);
+ t = trunc(t, 1);
+ t = select(t, H, Hp);
+ t = trunc(t, bit);
+ storeN(t, pz, N);
+ ret(Void);
+ endFunc();
+ }
+ /*
+     Emit "mcl_fpDbl_sub<N>L<suf>"(pz, px, py, pp):
+     subtracts two 2N-word values; the low N words of the difference are stored
+     unchanged and p is added to the upper N words when the subtraction borrows.
+ */
+ void gen_mcl_fpDbl_sub()
+ {
+ // QQQ : rol is used?
+ const int b2 = bit * 2;
+ const int b2u = b2 + unit;
+ resetGlobalIdx();
+ std::string name = "mcl_fpDbl_sub" + cybozu::itoa(N) + "L" + suf;
+ Operand pz(IntPtr, unit);
+ Operand px(IntPtr, unit);
+ Operand py(IntPtr, unit);
+ Operand pp(IntPtr, unit);
+ Function f(name, Void, pz, px, py, pp);
+ verifyAndSetPrivate(f);
+ beginFunc(f);
+ Operand x = loadN(px, N * 2);
+ Operand y = loadN(py, N * 2);
+ x = zext(x, b2u);
+ y = zext(y, b2u);
+ Operand vc = sub(x, y); // x - y = [H:L]
+ Operand L = trunc(vc, bit);
+ storeN(L, pz);
+
+ Operand H = lshr(vc, bit);
+ H = trunc(H, bit);
+ // c = borrow of the full 2N-word subtraction
+ Operand c = lshr(vc, b2);
+ c = trunc(c, 1);
+ Operand p = loadN(pp, N);
+ // c <- p on borrow, otherwise 0
+ c = select(c, p, makeImm(bit, 0));
+ Operand t = add(H, c);
+ storeN(t, pz, N);
+ ret(Void);
+ endFunc();
+ }
+ /*
+ return [px[n-1]:px[n-2]:...:px[0]]
+ */
+ // Concatenate px[0..n) into one integer; px[0] ends up in the lowest bits and
+ // each following element is placed directly above the bits gathered so far.
+ Operand pack(const Operand *px, size_t n)
+ {
+     Operand packed = px[0];
+     for (size_t idx = 1; idx < n; idx++) {
+         Operand next = px[idx];
+         const size_t lowBits = packed.bit;
+         const size_t totalBits = packed.bit + next.bit;
+         packed = zext(packed, totalBits);
+         next = zext(next, totalBits);
+         next = shl(next, lowBits);
+         packed = _or(packed, next);
+     }
+     return packed;
+ }
+ /*
+ z = px[0..N] * y
+ */
+ /*
+     Emit the private helper "mulPv<bit>x<unit>": returns px[0..N) * y as a
+     (bit + unit)-bit value. Each word product is split into a low and a high word;
+     the packed low words plus the packed high words shifted up by one word give
+     the result. The generated Function is recorded in mulPvM[bit].
+ */
+ void gen_mulPv()
+ {
+ const int bu = bit + unit;
+ resetGlobalIdx();
+ Operand z(Int, bu);
+ Operand px(IntPtr, unit);
+ Operand y(Int, unit);
+ std::string name = "mulPv" + cybozu::itoa(bit) + "x" + cybozu::itoa(unit);
+ mulPvM[bit] = Function(name, z, px, y);
+ mulPvM[bit].setPrivate();
+ verifyAndSetPrivate(mulPvM[bit]);
+ beginFunc(mulPvM[bit]);
+ OperandVec L(N), H(N);
+ for (uint32_t i = 0; i < N; i++) {
+ // xy = px[i] * y, split into low word L[i] and high word H[i]
+ Operand xy = call(mulPos, px, y, makeImm(unit, i));
+ L[i] = trunc(xy, unit);
+ H[i] = call(extractHigh, xy);
+ }
+ Operand LL = pack(&L[0], N);
+ Operand HH = pack(&H[0], N);
+ LL = zext(LL, bu);
+ HH = zext(HH, bu);
+ // high words belong one word above the corresponding low words
+ HH = shl(HH, unit);
+ z = add(LL, HH);
+ ret(z);
+ endFunc();
+ }
+ // Emit "mcl_fp_mulUnitPre<N>L<suf>"(pz, px, y): stores px[0..N) * y (N + 1 words,
+ // computed by mulPvM[bit]) to pz. The Function is recorded in mcl_fp_mulUnitPreM[N].
+ void gen_mcl_fp_mulUnitPre()
+ {
+ resetGlobalIdx();
+ Operand pz(IntPtr, unit);
+ Operand px(IntPtr, unit);
+ Operand y(Int, unit);
+ std::string name = "mcl_fp_mulUnitPre" + cybozu::itoa(N) + "L" + suf;
+ mcl_fp_mulUnitPreM[N] = Function(name, Void, pz, px, y);
+ verifyAndSetPrivate(mcl_fp_mulUnitPreM[N]);
+ beginFunc(mcl_fp_mulUnitPreM[N]);
+ Operand z = call(mulPvM[bit], px, y);
+ storeN(z, pz);
+ ret(Void);
+ endFunc();
+ }
+ /*
+     Emit the body (including ret) of a double-width multiplication pz[0..2N) = px[0..N) * py[0..N).
+     Must be called between beginFunc() and endFunc() of the enclosing function.
+     Three strategies are used:
+       N == 1               : one widened multiply
+       N >= 8 and N is even : Karatsuba on two halves, calling mcl_fpDbl_mulPreM[N / 2]
+       otherwise            : schoolbook, one mulPvM[bit] call per word of py
+ */
+ void generic_fpDbl_mul(const Operand& pz, const Operand& px, const Operand& py)
+ {
+ if (N == 1) {
+ Operand x = load(px);
+ Operand y = load(py);
+ x = zext(x, unit * 2);
+ y = zext(y, unit * 2);
+ Operand z = mul(x, y);
+ storeN(z, pz);
+ ret(Void);
+ } else if (N >= 8 && (N % 2) == 0) {
+ /*
+ W = 1 << half
+ (aW + b)(cW + d) = acW^2 + (ad + bc)W + bd
+ ad + bc = (a + b)(c + d) - ac - bd
+ */
+ const int H = N / 2;
+ const int half = bit / 2;
+ Operand pxW = getelementptr(px, H);
+ Operand pyW = getelementptr(py, H);
+ Operand pzWW = getelementptr(pz, N);
+ call(mcl_fpDbl_mulPreM[H], pz, px, py); // bd
+ call(mcl_fpDbl_mulPreM[H], pzWW, pxW, pyW); // ac
+
+ // a + b and c + d may each overflow half bits by one bit (c1, c2)
+ Operand a = zext(loadN(pxW, H), half + unit);
+ Operand b = zext(loadN(px, H), half + unit);
+ Operand c = zext(loadN(pyW, H), half + unit);
+ Operand d = zext(loadN(py, H), half + unit);
+ Operand t1 = add(a, b);
+ Operand t2 = add(c, d);
+ Operand buf = _alloca(unit, N);
+ Operand t1L = trunc(t1, half);
+ Operand t2L = trunc(t2, half);
+ Operand c1 = trunc(lshr(t1, half), 1);
+ Operand c2 = trunc(lshr(t2, half), 1);
+ // c0 : both sums overflowed; c1/c2 : the other sum's low part when this sum overflowed
+ Operand c0 = _and(c1, c2);
+ c1 = select(c1, t2L, makeImm(half, 0));
+ c2 = select(c2, t1L, makeImm(half, 0));
+ Operand buf1 = _alloca(unit, half / unit);
+ Operand buf2 = _alloca(unit, half / unit);
+ storeN(t1L, buf1);
+ storeN(t2L, buf2);
+ // buf = t1L * t2L (product of the low parts of the two sums)
+ call(mcl_fpDbl_mulPreM[N / 2], buf, buf1, buf2);
+ Operand t = loadN(buf, N);
+ t = zext(t, bit + unit);
+ // add the overflow corrections to get the full (a + b)(c + d)
+ c0 = zext(c0, bit + unit);
+ c0 = shl(c0, bit);
+ t = _or(t, c0);
+ c1 = zext(c1, bit + unit);
+ c2 = zext(c2, bit + unit);
+ c1 = shl(c1, half);
+ c2 = shl(c2, half);
+ t = add(t, c1);
+ t = add(t, c2);
+ // subtract bd (pz[0..N)) and ac (pz[N..2N)) to leave ad + bc
+ t = sub(t, zext(loadN(pz, N), bit + unit));
+ t = sub(t, zext(loadN(pz, N, N), bit + unit));
+ if (bit + half > t.bit) {
+ t = zext(t, bit + half);
+ }
+ // add the middle term on top of pz[H..2N) and write that range back
+ t = add(t, loadN(pz, N + H, H));
+ storeN(t, pz, H);
+ ret(Void);
+ } else {
+ // first row: px * py[0]; its lowest word is final
+ Operand y = load(py);
+ Operand xy = call(mulPvM[bit], px, y);
+ store(trunc(xy, unit), pz);
+ Operand t = lshr(xy, unit);
+ for (uint32_t i = 1; i < N; i++) {
+ y = loadN(py, 1, i);
+ xy = call(mulPvM[bit], px, y);
+ t = add(t, xy);
+ // each row finalizes one more low word, except the last row
+ if (i < N - 1) {
+ storeN(trunc(t, unit), pz, i);
+ t = lshr(t, unit);
+ }
+ }
+ // the remaining accumulator covers pz[N-1 .. 2N)
+ storeN(t, pz, N - 1);
+ ret(Void);
+ }
+ }
+ // Emit "mcl_fpDbl_mulPre<N>L<suf>"(pz, px, py) with the body produced by generic_fpDbl_mul().
+ // The Function is recorded in mcl_fpDbl_mulPreM[N] (used by the Karatsuba branch for larger N).
+ void gen_mcl_fpDbl_mulPre()
+ {
+ resetGlobalIdx();
+ Operand pz(IntPtr, unit);
+ Operand px(IntPtr, unit);
+ Operand py(IntPtr, unit);
+ std::string name = "mcl_fpDbl_mulPre" + cybozu::itoa(N) + "L" + suf;
+ mcl_fpDbl_mulPreM[N] = Function(name, Void, pz, px, py);
+ verifyAndSetPrivate(mcl_fpDbl_mulPreM[N]);
+ beginFunc(mcl_fpDbl_mulPreM[N]);
+ generic_fpDbl_mul(pz, px, py);
+ endFunc();
+ }
+ // Emit "mcl_fpDbl_sqrPre<N>L<suf>"(py, px): the square is generated as the generic
+ // multiplication with both inputs set to px. Recorded in mcl_fpDbl_sqrPreM[N].
+ void gen_mcl_fpDbl_sqrPre()
+ {
+ resetGlobalIdx();
+ Operand py(IntPtr, unit);
+ Operand px(IntPtr, unit);
+ std::string name = "mcl_fpDbl_sqrPre" + cybozu::itoa(N) + "L" + suf;
+ mcl_fpDbl_sqrPreM[N] = Function(name, Void, py, px);
+ verifyAndSetPrivate(mcl_fpDbl_sqrPreM[N]);
+ beginFunc(mcl_fpDbl_sqrPreM[N]);
+ generic_fpDbl_mul(py, px, px);
+ endFunc();
+ }
+ /*
+     Emit "mcl_fp_mont<N>L<suf>" (isFullBit) or "mcl_fp_montNF<N>L<suf>" (!isFullBit):
+       f(pz, px, py, pp)
+     For each word of py the loop adds px * py[i] to an accumulator, derives q from
+     the lowest accumulator word and rp, adds p * q and shifts the accumulator down
+     by one word. The final value has p subtracted once when that does not borrow.
+     rp is read from pp[-1], the word stored directly before p
+     (presumably the Montgomery constant - TODO confirm against the caller).
+     isFullBit : the accumulator is bit + 2 * unit wide and the borrow is bit position bit.
+     otherwise : the accumulator keeps the width returned by mulPvM[bit] and the
+                 top bit (bit - 1) of t - p acts as the borrow flag.
+     Both variants are stored into mcl_fp_montM[N], so the later call overwrites the entry.
+ */
+ void gen_mcl_fp_mont(bool isFullBit = true)
+ {
+ const int bu = bit + unit;
+ const int bu2 = bit + unit * 2;
+ resetGlobalIdx();
+ Operand pz(IntPtr, unit);
+ Operand px(IntPtr, unit);
+ Operand py(IntPtr, unit);
+ Operand pp(IntPtr, unit);
+ std::string name = "mcl_fp_mont";
+ if (!isFullBit) {
+ name += "NF";
+ }
+ name += cybozu::itoa(N) + "L" + suf;
+ mcl_fp_montM[N] = Function(name, Void, pz, px, py, pp);
+ mcl_fp_montM[N].setAlias();
+ verifyAndSetPrivate(mcl_fp_montM[N]);
+ beginFunc(mcl_fp_montM[N]);
+ Operand rp = load(getelementptr(pp, -1));
+ Operand z, s, a;
+ if (isFullBit) {
+ for (uint32_t i = 0; i < N; i++) {
+ Operand y = load(getelementptr(py, i));
+ Operand xy = call(mulPvM[bit], px, y);
+ Operand at;
+ // a = accumulator + px * py[i]; at = its lowest word
+ if (i == 0) {
+ a = zext(xy, bu2);
+ at = trunc(xy, unit);
+ } else {
+ xy = zext(xy, bu2);
+ a = add(s, xy);
+ at = trunc(a, unit);
+ }
+ Operand q = mul(at, rp);
+ Operand pq = call(mulPvM[bit], pp, q);
+ pq = zext(pq, bu2);
+ Operand t = add(a, pq);
+ // drop the lowest word of the sum
+ s = lshr(t, unit);
+ }
+ s = trunc(s, bu);
+ Operand p = zext(loadN(pp, N), bu);
+ Operand vc = sub(s, p);
+ // borrow of s - p set : keep s, otherwise use s - p
+ Operand c = trunc(lshr(vc, bit), 1);
+ z = select(c, s, vc);
+ z = trunc(z, bit);
+ storeN(z, pz);
+ } else {
+ // first round (i == 0) is peeled so the accumulator starts from px * py[0]
+ Operand y = load(py);
+ Operand xy = call(mulPvM[bit], px, y);
+ Operand c0 = trunc(xy, unit);
+ Operand q = mul(c0, rp);
+ Operand pq = call(mulPvM[bit], pp, q);
+ Operand t = add(xy, pq);
+ t = lshr(t, unit); // bu-bit
+ for (uint32_t i = 1; i < N; i++) {
+ y = load(getelementptr(py, i));
+ xy = call(mulPvM[bit], px, y);
+ t = add(t, xy);
+ c0 = trunc(t, unit);
+ q = mul(c0, rp);
+ pq = call(mulPvM[bit], pp, q);
+ t = add(t, pq);
+ t = lshr(t, unit);
+ }
+ t = trunc(t, bit);
+ Operand vc = sub(t, loadN(pp, N));
+ // top bit of t - p set : keep t, otherwise use t - p
+ Operand c = trunc(lshr(vc, bit - 1), 1);
+ z = select(c, t, vc);
+ storeN(z, pz);
+ }
+ ret(Void);
+ endFunc();
+ }
+ /*
+     Emit "mcl_fp_montRed<N>L<suf>"(pz, pxy, pp):
+     reads a 2N-word value from pxy and runs N rounds; each round derives q from
+     the lowest word and rp, adds p * q and drops the lowest word, so the working
+     value shrinks by one word per round. Finally p is subtracted once when that
+     does not borrow and N words are stored to pz.
+     rp is read from pp[-1], the word stored directly before p.
+     The generated Function is recorded in mcl_fp_montRedM[N].
+ */
+ void gen_mcl_fp_montRed()
+ {
+     const int bu = bit + unit;
+     const int b2 = bit * 2;
+     const int b2u = b2 + unit;
+     resetGlobalIdx();
+     Operand pz(IntPtr, unit);
+     Operand pxy(IntPtr, unit);
+     Operand pp(IntPtr, unit);
+     std::string name = "mcl_fp_montRed" + cybozu::itoa(N) + "L" + suf;
+     mcl_fp_montRedM[N] = Function(name, Void, pz, pxy, pp);
+     verifyAndSetPrivate(mcl_fp_montRedM[N]);
+     beginFunc(mcl_fp_montRedM[N]);
+     Operand rp = load(getelementptr(pp, -1));
+     Operand p = loadN(pp, N);
+     Operand xy = loadN(pxy, N * 2);
+     Operand t = zext(xy, b2 + unit);
+     Operand z;
+     for (uint32_t round = 0; round < N; round++) {
+         // q = (lowest word of t) * rp
+         Operand lowWord = trunc(t, unit);
+         Operand q = mul(lowWord, rp);
+         Operand pq = call(mulPvM[bit], pp, q);
+         pq = zext(pq, b2u - unit * round);
+         // add p * q, then drop the lowest word and narrow by one word
+         Operand sum = add(t, pq);
+         sum = lshr(sum, unit);
+         t = trunc(sum, b2 - unit * round);
+     }
+     p = zext(p, bu);
+     Operand vc = sub(t, p);
+     // borrow of t - p set : keep t, otherwise use t - p
+     Operand c = trunc(lshr(vc, bit), 1);
+     z = select(c, t, vc);
+     z = trunc(z, bit);
+     storeN(z, pz);
+     ret(Void);
+     endFunc();
+ }
+ // Emit addPre, subPre and shr1 for the current N.
+ void gen_all()
+ {
+ gen_mcl_fp_addsubPre(true);
+ gen_mcl_fp_addsubPre(false);
+ gen_mcl_fp_shr1();
+ }
+ // Emit the modular add/sub functions (both the full-bit and the "NF" variants)
+ // and the double-width add/sub functions for the current N.
+ void gen_addsub()
+ {
+ gen_mcl_fp_add(true);
+ gen_mcl_fp_add(false);
+ gen_mcl_fp_sub(true);
+ gen_mcl_fp_sub(false);
+ gen_mcl_fpDbl_add();
+ gen_mcl_fpDbl_sub();
+ }
+ // Emit the multiplication family for the current N. gen_mulPv() comes first
+ // because the following generators call mulPvM[bit].
+ void gen_mul()
+ {
+ gen_mulPv();
+ gen_mcl_fp_mulUnitPre();
+ gen_mcl_fpDbl_mulPre();
+ gen_mcl_fpDbl_sqrPre();
+ gen_mcl_fp_mont(true);
+ gen_mcl_fp_mont(false);
+ gen_mcl_fp_montRed();
+ }
+ // Select the operand size in bits and derive N, the number of unit-bit words.
+ // setUnit() must have been called before (N is computed as bit / unit).
+ void setBit(uint32_t bit)
+ {
+     N = bit / unit;
+     this->bit = bit;
+ }
+ // Select the word size in bits and derive unit2 (twice the word size)
+ // and unitStr (the word size as decimal text, used in helper names).
+ void setUnit(uint32_t unit)
+ {
+     unitStr = cybozu::itoa(unit);
+     unit2 = 2 * unit;
+     this->unit = unit;
+ }
+ /*
+     Generate all functions.
+     privateFuncList : names to be marked private (a pointer to it is kept, so it must outlive this call)
+     maxBitSize      : largest operand size in bits; sizes unit, 2 * unit, ... up to
+                       maxBitSize rounded up to a multiple of unit are generated
+     suf             : suffix appended to the public function names
+     With FOR_WASM defined only gen_mulUU() is run.
+ */
+ void gen(const StrSet& privateFuncList, uint32_t maxBitSize, const std::string& suf)
+ {
+ this->suf = suf;
+ this->privateFuncList = &privateFuncList;
+#ifdef FOR_WASM
+ gen_mulUU();
+#else
+ gen_once();
+ uint32_t end = ((maxBitSize + unit - 1) / unit);
+ for (uint32_t n = 1; n <= end; n++) {
+ setBit(n * unit);
+ gen_mul();
+ gen_all();
+ gen_addsub();
+ }
+ // special case: for 64-bit words and maxBitSize 768, also emit addPre/subPre/shr1
+ // for every second word count above maxBitSize up to 2 * maxBitSize
+ if (unit == 64 && maxBitSize == 768) {
+ for (uint32_t i = maxBitSize + unit * 2; i <= maxBitSize * 2; i += unit * 2) {
+ setBit(i);
+ gen_all();
+ }
+ }
+#endif
+ }
+};
+
+/*
+    Command line driver.
+    Options: -u unit (default: pointer size in bits), -old (old LLVM syntax),
+    -wasm, -s suffix of function names, -f file listing private function names
+    (whitespace separated), -h help.
+    Returns 0 on success and 1 on a bad command line or on any exception.
+*/
+int main(int argc, char *argv[])
+ try
+{
+    uint32_t unit;
+    bool oldLLVM;
+    bool wasm;
+    std::string suf;
+    std::string privateFile;
+    cybozu::Option opt;
+    opt.appendOpt(&unit, uint32_t(sizeof(void*)) * 8, "u", ": unit");
+    opt.appendBoolOpt(&oldLLVM, "old", ": old LLVM(before 3.8)");
+    opt.appendBoolOpt(&wasm, "wasm", ": for wasm");
+    opt.appendOpt(&suf, "", "s", ": suffix of function name");
+    opt.appendOpt(&privateFile, "", "f", ": private function list file");
+    opt.appendHelp("h");
+    if (!opt.parse(argc, argv)) {
+        opt.usage();
+        return 1;
+    }
+    StrSet privateFuncList;
+    if (!privateFile.empty()) {
+        std::ifstream ifs(privateFile.c_str(), std::ios::binary);
+        // A list file that cannot be opened used to be treated as an empty list,
+        // which silently produced output without any private functions.
+        if (!ifs) {
+            throw cybozu::Exception("can't open private function list file") << privateFile;
+        }
+        std::string name;
+        while (ifs >> name) {
+            privateFuncList.insert(name);
+        }
+    }
+    Code c;
+    if (oldLLVM) {
+        c.setOldLLVM();
+    }
+    c.wasm = wasm;
+    c.setUnit(unit);
+    uint32_t maxBitSize = MCL_MAX_BIT_SIZE;
+    c.gen(privateFuncList, maxBitSize, suf);
+    return 0;
+} catch (const std::exception& e) {
+    printf("ERR %s\n", e.what());
+    return 1;
+}