#include "llvm_gen.hpp"
#include <cybozu/option.hpp>
#include <mcl/op.hpp> // expected to provide MCL_MAX_BIT_SIZE
#include <cstdio>
#include <fstream>
#include <map>
#include <set>
#include <string>
#include <vector>

typedef std::set<std::string> StrSet;

/*
	Emits LLVM IR (via mcl::Generator) for multi-precision arithmetic helpers.
	Values are arrays of `unit`-bit limbs, least significant limb first
	(see loadN/storeN). Per-size functions are recorded in the FunctionMap
	members, keyed by the limb count N (mulPvM is keyed by bit).
*/
struct Code : public mcl::Generator {
	typedef std::map<int, Function> FunctionMap;
	typedef std::vector<Operand> OperandVec;
	Operand Void; // used as the void return type and as the operand of ret(Void)
	uint32_t unit; // limb size in bits
	uint32_t unit2; // unit * 2
	uint32_t bit; // bit size of the current element (N * unit, see setBit)
	uint32_t N; // limb count of the current element
	const StrSet *privateFuncList; // function names to force private; may be null
	bool wasm; // build unit x unit multiplies from 32x32 multiplies
	std::string suf; // suffix appended to exported function names
	std::string unitStr; // decimal text of unit
	Function mulUU;
	Function mul32x32; // for WASM
	Function extractHigh;
	Function mulPos;
	Function makeNIST_P192;
	Function mcl_fpDbl_mod_NIST_P192;
	Function mcl_fp_sqr_NIST_P192;
	FunctionMap mcl_fp_shr1_M;
	FunctionMap mcl_fp_addPreM;
	FunctionMap mcl_fp_subPreM;
	FunctionMap mcl_fp_addM;
	FunctionMap mcl_fp_subM;
	FunctionMap mulPvM;
	FunctionMap mcl_fp_mulUnitPreM;
	FunctionMap mcl_fpDbl_mulPreM;
	FunctionMap mcl_fpDbl_sqrPreM;
	FunctionMap mcl_fp_montM;
	FunctionMap mcl_fp_montRedM;
	Code() : unit(0), unit2(0), bit(0), N(0), privateFuncList(0), wasm(false) { }
	// Mark f private if its name appears in the user-supplied private list.
	void verifyAndSetPrivate(Function& f)
	{
		if (privateFuncList && privateFuncList->find(f.name) != privateFuncList->end()) {
			f.setPrivate();
		}
	}
	/*
		Store r as r.bit / unit limbs (low limb first) starting at p[offset].
		p must be a pointer to unit-bit integers and r.bit a multiple of unit;
		otherwise cybozu::Exception is thrown (a partial trailing limb would
		otherwise be silently dropped).
	*/
	void storeN(Operand r, Operand p, int offset = 0)
	{
		if (p.bit != unit) {
			throw cybozu::Exception("bad IntPtr size") << p.bit;
		}
		if (r.bit % unit != 0) {
			throw cybozu::Exception("bad Operand size") << r.bit;
		}
		if (offset > 0) {
			p = getelementptr(p, offset);
		}
		if (r.bit == unit) {
			store(r, p);
			return;
		}
		const size_t n = r.bit / unit;
		for (size_t i = 0; i < n; i++) {
			store(trunc(r, unit), getelementptr(p, i));
			if (i < n - 1) {
				r = lshr(r, unit);
			}
		}
	}
	/*
		Load n limbs starting at p[offset] and return them packed into one
		(n * unit)-bit integer, limb i occupying bits [unit * i, unit * (i + 1)).
	*/
	Operand loadN(Operand p, size_t n, int offset = 0)
	{
		if (p.bit != unit) {
			throw cybozu::Exception("bad IntPtr size") << p.bit;
		}
		if (offset > 0) {
			p = getelementptr(p, offset);
		}
		Operand v = load(p);
		for (size_t i = 1; i < n; i++) {
			v = zext(v, v.bit + unit);
			Operand t = load(getelementptr(p, i));
			t = zext(t, v.bit);
			t = shl(t, unit * i);
			v = _or(v, t);
		}
		return v;
	}
	// private i64 mul32x32L(i32 x, i32 y): full 64-bit product.
	void gen_mul32x32()
	{
		const int u = 32;
		resetGlobalIdx();
		Operand z(Int, u * 2);
		Operand x(Int, u);
		Operand y(Int, u);
		mul32x32 = Function("mul32x32L", z, x, y);
		mul32x32.setPrivate();
		verifyAndSetPrivate(mul32x32);
		beginFunc(mul32x32);

		x = zext(x, u * 2);
		y = zext(y, u * 2);
		z = mul(x, y);
		ret(z);
		endFunc();
	}
	/*
		z = x * y (64 x 64 -> 128) built from four calls to mul32x32, with
		x = [a:b] and y = [c:d] split into 32-bit halves.
	*/
	void gen_mul64x64(Operand& z, Operand& x, Operand& y)
	{
		Operand a = trunc(lshr(x, 32), 32);
		Operand b = trunc(x, 32);
		Operand c = trunc(lshr(y, 32), 32);
		Operand d = trunc(y, 32);
		Operand ad = call(mul32x32, a, d);
		Operand bd = call(mul32x32, b, d);
		bd = zext(bd, 96);
		ad = shl(zext(ad, 96), 32);
		ad = add(ad, bd);
		Operand ac = call(mul32x32, a, c);
		Operand bc = call(mul32x32, b, c);
		bc = zext(bc, 96);
		ac = shl(zext(ac, 96), 32);
		ac = add(ac, bc);
		ad = zext(ad, 128);
		ac = shl(zext(ac, 128), 32);
		z = add(ac, ad);
	}
	// private mul{unit}x{unit}L: unit x unit -> 2 * unit bit product.
	void gen_mulUU()
	{
		if (wasm) {
			gen_mul32x32();
		}
		resetGlobalIdx();
		Operand z(Int, unit2);
		Operand x(Int, unit);
		Operand y(Int, unit);
		std::string name = "mul";
		name += unitStr + "x" + unitStr + "L";
		mulUU = Function(name, z, x, y);
		mulUU.setPrivate();
		verifyAndSetPrivate(mulUU);
		beginFunc(mulUU);

		if (wasm) {
			gen_mul64x64(z, x, y);
		} else {
			x = zext(x, unit2);
			y = zext(y, unit2);
			z = mul(x, y);
		}
		ret(z);
		endFunc();
	}
	// private extractHigh{unit}: upper unit bits of a 2 * unit bit value.
	void gen_extractHigh()
	{
		resetGlobalIdx();
		Operand z(Int, unit);
		Operand x(Int, unit2);
		std::string name = "extractHigh";
		name += unitStr;
		extractHigh = Function(name, z, x);
		extractHigh.setPrivate();
		beginFunc(extractHigh);

		x = lshr(x, unit);
		z = trunc(x, unit);
		ret(z);
		endFunc();
	}
	// private mulPos{unit}x{unit}: px[i] * y as a 2 * unit bit value.
	void gen_mulPos()
	{
		resetGlobalIdx();
		Operand xy(Int, unit2);
		Operand px(IntPtr, unit);
		Operand y(Int, unit);
		Operand i(Int, unit);
		std::string name = "mulPos";
		name += unitStr + "x" + unitStr;
		mulPos = Function(name, xy, px, y, i);
		mulPos.setPrivate();
		beginFunc(mulPos);

		Operand x = load(getelementptr(px, i));

		xy = call(mulUU, x, y);
		ret(xy);
		endFunc();
	}
	// Bits [shift, shift + 64) of x as a 64-bit value.
	Operand extract192to64(const Operand& x, uint32_t shift)
	{
		Operand y = lshr(x, shift);
		y = trunc(y, 64);
		return y;
	}
	// makeNIST_P192L{suf}: returns the 192-bit constant [-1 : -2 : -1] (high to low 64-bit limbs).
	void gen_makeNIST_P192()
	{
		resetGlobalIdx();
		Operand p(Int, 192);
		Operand p0(Int, 64);
		Operand p1(Int, 64);
		Operand p2(Int, 64);
		Operand _0 = makeImm(64, 0);
		Operand _1 = makeImm(64, 1);
		Operand _2 = makeImm(64, 2);
		makeNIST_P192 = Function("makeNIST_P192L" + suf, p);
		verifyAndSetPrivate(makeNIST_P192);
		beginFunc(makeNIST_P192);
		p0 = sub(_0, _1);
		p1 = sub(_0, _2);
		p2 = sub(_0, _1);
		p0 = zext(p0, 192);
		p1 = zext(p1, 192);
		p2 = zext(p2, 192);
		p1 = shl(p1, 64);
		p2 = shl(p2, 128);
		p = add(p0, p1);
		p = add(p, p2);
		ret(p);
		endFunc();
	}
	/*
		NIST_P192
		p = 0xfffffffffffffffffffffffffffffffeffffffffffffffff
		64-bit limbs 0 (low), 1, 2 (high):
		ffffffffffffffff fffffffffffffffe ffffffffffffffff

		p = (1 << 192) - (1 << 64) - 1
		(1 << 192) % p = (1 << 64) + 1
		L : 192 bits
		H0, H1, H2 : 64 bits each
		x = [H:L] = [H2:H1:H0:L]
		mod p
		x = L + H + (H << 64)
		  = L + H + [H1:H0:0] + H2 + (H2 << 64)
		[e:t] = L + H + [H1:H0:H2] + [H2:0] ; e is the (at most 2-bit) overflow
		y = t + e + (e << 64)
		if (y >= p) y -= p
	*/
	void gen_mcl_fpDbl_mod_NIST_P192()
	{
		resetGlobalIdx();
		Operand out(IntPtr, unit);
		Operand px(IntPtr, unit);
		mcl_fpDbl_mod_NIST_P192 = Function("mcl_fpDbl_mod_NIST_P192L" + suf, Void, out, px);
		verifyAndSetPrivate(mcl_fpDbl_mod_NIST_P192);
		beginFunc(mcl_fpDbl_mod_NIST_P192);

		const int n = 192 / unit;
		Operand L = loadN(px, n);
		L = zext(L, 256);

		Operand H192 = loadN(px, n, n);
		Operand H = zext(H192, 256);

		Operand H10 = shl(H192, 64);
		H10 = zext(H10, 256);

		Operand H2 = extract192to64(H192, 128);
		H2 = zext(H2, 256);
		Operand H102 = _or(H10, H2);

		H2 = shl(H2, 64);

		Operand t = add(L, H);
		t = add(t, H102);
		t = add(t, H2);

		Operand e = lshr(t, 192);
		e = trunc(e, 64);
		e = zext(e, 256);
		Operand e2 = shl(e, 64);
		e = _or(e, e2);

		t = trunc(t, 192);
		t = zext(t, 256);

		Operand z = add(t, e);
		Operand p = call(makeNIST_P192);
		p = zext(p, 256);
		Operand zp = sub(z, p);
		// bit 192 of z - p is set when z < p (the subtraction borrowed)
		Operand c = trunc(lshr(zp, 192), 1);
		z = trunc(select(c, z, zp), 192);
		storeN(z, out);
		ret(Void);
		endFunc();
	}
	/*
		NIST_P521
		p = (1 << 521) - 1
		x = [H:L]
		x % p = (L + H) % p
	*/
	void gen_mcl_fpDbl_mod_NIST_P521()
	{
		resetGlobalIdx();
		const uint32_t len = 521;
		const uint32_t n = len / unit;
		const uint32_t round = unit * (n + 1);
		const uint32_t rem = len - n * unit;
		// bits at and above position rem set; built from an unsigned value so the
		// shift cannot overflow a signed int (the old -(1 << rem) was UB for rem >= 31)
		const size_t mask = ~size_t(0) << rem;
		const Operand py(IntPtr, unit);
		const Operand px(IntPtr, unit);
		Function f("mcl_fpDbl_mod_NIST_P521L" + suf, Void, py, px);
		verifyAndSetPrivate(f);
		beginFunc(f);
		Operand x = loadN(px, n * 2 + 1);
		Operand L = trunc(x, len);
		L = zext(L, round);
		Operand H = lshr(x, len);
		H = trunc(H, round); // x = [H:L]
		Operand t = add(L, H);
		Operand t0 = lshr(t, len);
		t0 = _and(t0, makeImm(round, 1));
		t = add(t, t0);
		t = trunc(t, len);
		Operand z0 = zext(t, round);
		t = extract(z0, n * unit);
		// force the unused top-limb bits to 1, then AND every limb: the result is
		// all ones exactly when every one of the len bits is 1, i.e. t == p
		Operand m = _or(t, makeImm(unit, mask));
		for (uint32_t i = 0; i < n; i++) {
			Operand s = extract(z0, unit * i);
			m = _and(m, s);
		}
		Operand c = icmp(eq, m, makeImm(unit, -1));
		Label zero("zero");
		Label nonzero("nonzero");
		br(c, zero, nonzero);
	putLabel(zero);
		for (uint32_t i = 0; i < n + 1; i++) {
			storeN(makeImm(unit, 0), py, i);
		}
		ret(Void);
	putLabel(nonzero);
		storeN(z0, py);
		ret(Void);
		endFunc();
	}
	// mcl_fp_sqr_NIST_P192L{suf}(py, px): square via fpDbl_sqrPre, then reduce with fpDbl_mod_NIST_P192.
	void gen_mcl_fp_sqr_NIST_P192()
	{
		resetGlobalIdx();
		Operand py(IntPtr, unit);
		Operand px(IntPtr, unit);
		mcl_fp_sqr_NIST_P192 = Function("mcl_fp_sqr_NIST_P192L" + suf, Void, py, px);
		verifyAndSetPrivate(mcl_fp_sqr_NIST_P192);
		beginFunc(mcl_fp_sqr_NIST_P192);
		Operand buf = _alloca(unit, 192 * 2 / unit);
		// QQQ define later
		Function mcl_fpDbl_sqrPre("mcl_fpDbl_sqrPre" + cybozu::itoa(192 / unit) + "L" + suf, Void, buf, px);
		call(mcl_fpDbl_sqrPre, buf, px);
		call(mcl_fpDbl_mod_NIST_P192, py, buf);
		ret(Void);
		endFunc();
	}
	// mcl_fp_mulNIST_P192L{suf}(pz, px, py): multiply via fpDbl_mulPre, then reduce with fpDbl_mod_NIST_P192.
	void gen_mcl_fp_mulNIST_P192()
	{
		resetGlobalIdx();
		Operand pz(IntPtr, unit);
		Operand px(IntPtr, unit);
		Operand py(IntPtr, unit);
		Function f("mcl_fp_mulNIST_P192L" + suf, Void, pz, px, py);
		verifyAndSetPrivate(f);
		beginFunc(f);
		Operand buf = _alloca(unit, 192 * 2 / unit);
		// QQQ define later
		Function mcl_fpDbl_mulPre("mcl_fpDbl_mulPre" + cybozu::itoa(192 / unit) + "L" + suf, Void, buf, px, py);
		call(mcl_fpDbl_mulPre, buf, px, py);
		call(mcl_fpDbl_mod_NIST_P192, pz, buf);
		ret(Void);
		endFunc();
	}
	// Helpers emitted once, independent of the per-size loop in gen().
	void gen_once()
	{
		gen_mulUU();
		gen_extractHigh();
		gen_mulPos();
		gen_makeNIST_P192();
		gen_mcl_fpDbl_mod_NIST_P192();
		gen_mcl_fp_sqr_NIST_P192();
		gen_mcl_fp_mulNIST_P192();
		gen_mcl_fpDbl_mod_NIST_P521();
	}
	// Bits [shift, shift + unit) of x as a unit-bit value.
	Operand extract(const Operand& x, uint32_t shift)
	{
		Operand t = lshr(x, shift);
		t = trunc(t, unit);
		return t;
	}
	/*
		mcl_fp_addPre{N}L / mcl_fp_subPre{N}L: z = x +/- y over N limbs without
		reduction; returns the carry (add) or the borrow bit (sub) as a unit-bit value.
	*/
	void gen_mcl_fp_addsubPre(bool isAdd)
	{
		resetGlobalIdx();
		Operand r(Int, unit);
		Operand pz(IntPtr, unit);
		Operand px(IntPtr, unit);
		Operand py(IntPtr, unit);
		std::string name;
		if (isAdd) {
			name = "mcl_fp_addPre" + cybozu::itoa(N) + "L" + suf;
			mcl_fp_addPreM[N] = Function(name, r, pz, px, py);
			verifyAndSetPrivate(mcl_fp_addPreM[N]);
			beginFunc(mcl_fp_addPreM[N]);
		} else {
			name = "mcl_fp_subPre" + cybozu::itoa(N) + "L" + suf;
			mcl_fp_subPreM[N] = Function(name, r, pz, px, py);
			verifyAndSetPrivate(mcl_fp_subPreM[N]);
			beginFunc(mcl_fp_subPreM[N]);
		}
		Operand x = zext(loadN(px, N), bit + unit);
		Operand y = zext(loadN(py, N), bit + unit);
		Operand z;
		if (isAdd) {
			z = add(x, y);
			storeN(trunc(z, bit), pz);
			r = trunc(lshr(z, bit), unit);
		} else {
			z = sub(x, y);
			storeN(trunc(z, bit), pz);
			r = _and(trunc(lshr(z, bit), unit), makeImm(unit, 1));
		}
		ret(r);
		endFunc();
	}
#if 0 // void-return version
	void gen_mcl_fp_addsubPre(bool isAdd)
	{
		resetGlobalIdx();
		Operand pz(IntPtr, bit);
		Operand px(IntPtr, bit);
		Operand py(IntPtr, bit);
		std::string name;
		if (isAdd) {
			name = "mcl_fp_addPre" + cybozu::itoa(bit) + "L";
			mcl_fp_addPreM[bit] = Function(name, Void, pz, px, py);
			verifyAndSetPrivate(mcl_fp_addPreM[bit]);
			beginFunc(mcl_fp_addPreM[bit]);
		} else {
			name = "mcl_fp_subPre" + cybozu::itoa(bit) + "L";
			mcl_fp_subPreM[bit] = Function(name, Void, pz, px, py);
			verifyAndSetPrivate(mcl_fp_subPreM[bit]);
			beginFunc(mcl_fp_subPreM[bit]);
		}
		Operand x = load(px);
		Operand y = load(py);
		Operand z;
		if (isAdd) {
			z = add(x, y);
		} else {
			z = sub(x, y);
		}
		store(z, pz);
		ret(Void);
		endFunc();
	}
#endif
	// mcl_fp_shr1_{N}L(py, px): py = px >> 1 over N limbs.
	void gen_mcl_fp_shr1()
	{
		resetGlobalIdx();
		Operand py(IntPtr, unit);
		Operand px(IntPtr, unit);
		std::string name = "mcl_fp_shr1_" + cybozu::itoa(N) + "L" + suf;
		mcl_fp_shr1_M[N] = Function(name, Void, py, px);
		verifyAndSetPrivate(mcl_fp_shr1_M[N]);
		beginFunc(mcl_fp_shr1_M[N]);
		Operand x = loadN(px, N);
		x = lshr(x, 1);
		storeN(x, py);
		ret(Void);
		endFunc();
	}
	/*
		mcl_fp_add{,NF}{N}L(pz, px, py, pp): z = x + y, then subtract p if that
		does not borrow. The full-bit version keeps one extra limb of carry;
		the NF version works in bit bits and tests the top bit of x + y - p.
	*/
	void gen_mcl_fp_add(bool isFullBit = true)
	{
		resetGlobalIdx();
		Operand pz(IntPtr, unit);
		Operand px(IntPtr, unit);
		Operand py(IntPtr, unit);
		Operand pp(IntPtr, unit);
		std::string name = "mcl_fp_add";
		if (!isFullBit) {
			name += "NF";
		}
		name += cybozu::itoa(N) + "L" + suf;
		mcl_fp_addM[N] = Function(name, Void, pz, px, py, pp);
		verifyAndSetPrivate(mcl_fp_addM[N]);
		beginFunc(mcl_fp_addM[N]);
		Operand x = loadN(px, N);
		Operand y = loadN(py, N);
		if (isFullBit) {
			x = zext(x, bit + unit);
			y = zext(y, bit + unit);
			Operand t0 = add(x, y);
			Operand t1 = trunc(t0, bit);
			storeN(t1, pz);
			Operand p = loadN(pp, N);
			p = zext(p, bit + unit);
			Operand vc = sub(t0, p);
			Operand c = lshr(vc, bit);
			c = trunc(c, 1);
			Label carry("carry");
			Label nocarry("nocarry");
			br(c, carry, nocarry);
		putLabel(nocarry);
			storeN(trunc(vc, bit), pz);
			ret(Void);
		putLabel(carry);
		} else {
			x = add(x, y);
			Operand p = loadN(pp, N);
			y = sub(x, p);
			Operand c = trunc(lshr(y, bit - 1), 1);
			x = select(c, x, y);
			storeN(x, pz);
		}
		ret(Void);
		endFunc();
	}
	/*
		mcl_fp_sub{,NF}{N}L(pz, px, py, pp): z = x - y, adding p back when the
		subtraction borrows (full-bit) or when the top bit of x - y is set (NF).
	*/
	void gen_mcl_fp_sub(bool isFullBit = true)
	{
		resetGlobalIdx();
		Operand pz(IntPtr, unit);
		Operand px(IntPtr, unit);
		Operand py(IntPtr, unit);
		Operand pp(IntPtr, unit);
		std::string name = "mcl_fp_sub";
		if (!isFullBit) {
			name += "NF";
		}
		name += cybozu::itoa(N) + "L" + suf;
		mcl_fp_subM[N] = Function(name, Void, pz, px, py, pp);
		verifyAndSetPrivate(mcl_fp_subM[N]);
		beginFunc(mcl_fp_subM[N]);
		Operand x = loadN(px, N);
		Operand y = loadN(py, N);
		if (isFullBit) {
			x = zext(x, bit + unit);
			y = zext(y, bit + unit);
			Operand vc = sub(x, y);
			Operand v, c;
			v = trunc(vc, bit);
			c = lshr(vc, bit);
			c = trunc(c, 1);
			storeN(v, pz);
			Label carry("carry");
			Label nocarry("nocarry");
			br(c, carry, nocarry);
		putLabel(nocarry);
			ret(Void);
		putLabel(carry);
			Operand p = loadN(pp, N);
			Operand t = add(v, p);
			storeN(t, pz);
		} else {
			Operand v = sub(x, y);
			Operand c;
			c = trunc(lshr(v, bit - 1), 1);
			Operand p = loadN(pp, N);
			c = select(c, p, makeImm(bit, 0));
			Operand t = add(v, c);
			storeN(t, pz);
		}
		ret(Void);
		endFunc();
	}
	/*
		mcl_fpDbl_add{N}L(pz, px, py, pp): 2N-limb x + y = [H:L]; L is stored
		as is, H is reduced once by p and stored in the upper N limbs.
	*/
	void gen_mcl_fpDbl_add()
	{
		// QQQ : generate unnecessary memory copy for large bit
		const int bu = bit + unit;
		const int b2 = bit * 2;
		const int b2u = b2 + unit;
		resetGlobalIdx();
		Operand pz(IntPtr, unit);
		Operand px(IntPtr, unit);
		Operand py(IntPtr, unit);
		Operand pp(IntPtr, unit);
		std::string name = "mcl_fpDbl_add" + cybozu::itoa(N) + "L" + suf;
		Function f(name, Void, pz, px, py, pp);
		verifyAndSetPrivate(f);
		beginFunc(f);
		Operand x = loadN(px, N * 2);
		Operand y = loadN(py, N * 2);
		x = zext(x, b2u);
		y = zext(y, b2u);
		Operand t = add(x, y); // x + y = [H:L]
		Operand L = trunc(t, bit);
		storeN(L, pz);

		Operand H = lshr(t, bit);
		H = trunc(H, bu);
		Operand p = loadN(pp, N);
		p = zext(p, bu);
		Operand Hp = sub(H, p);
		t = lshr(Hp, bit);
		t = trunc(t, 1);
		t = select(t, H, Hp);
		t = trunc(t, bit);
		storeN(t, pz, N);
		ret(Void);
		endFunc();
	}
	/*
		mcl_fpDbl_sub{N}L(pz, px, py, pp): 2N-limb x - y = [H:L]; L is stored
		as is, p is added to H when the full subtraction borrowed.
	*/
	void gen_mcl_fpDbl_sub()
	{
		// QQQ : rol is used?
		const int b2 = bit * 2;
		const int b2u = b2 + unit;
		resetGlobalIdx();
		std::string name = "mcl_fpDbl_sub" + cybozu::itoa(N) + "L" + suf;
		Operand pz(IntPtr, unit);
		Operand px(IntPtr, unit);
		Operand py(IntPtr, unit);
		Operand pp(IntPtr, unit);
		Function f(name, Void, pz, px, py, pp);
		verifyAndSetPrivate(f);
		beginFunc(f);
		Operand x = loadN(px, N * 2);
		Operand y = loadN(py, N * 2);
		x = zext(x, b2u);
		y = zext(y, b2u);
		Operand vc = sub(x, y); // x - y = [H:L]
		Operand L = trunc(vc, bit);
		storeN(L, pz);

		Operand H = lshr(vc, bit);
		H = trunc(H, bit);
		Operand c = lshr(vc, b2);
		c = trunc(c, 1);
		Operand p = loadN(pp, N);
		c = select(c, p, makeImm(bit, 0));
		Operand t = add(H, c);
		storeN(t, pz, N);
		ret(Void);
		endFunc();
	}
	/*
		return [px[n-1]:px[n-2]:...:px[0]]
	*/
	Operand pack(const Operand *px, size_t n)
	{
		Operand x = px[0];
		for (size_t i = 1; i < n; i++) {
			Operand y = px[i];
			size_t shift = x.bit;
			size_t size = x.bit + y.bit;
			x = zext(x, size);
			y = zext(y, size);
			y = shl(y, shift);
			x = _or(x, y);
		}
		return x;
	}
	/*
		private mulPv{bit}x{unit}: z = [px[N-1]:...:px[0]] * y as a
		(bit + unit)-bit value, summing the low and high halves of each
		per-limb product.
	*/
	void gen_mulPv()
	{
		const int bu = bit + unit;
		resetGlobalIdx();
		Operand z(Int, bu);
		Operand px(IntPtr, unit);
		Operand y(Int, unit);
		std::string name = "mulPv" + cybozu::itoa(bit) + "x" + cybozu::itoa(unit);
		mulPvM[bit] = Function(name, z, px, y);
		mulPvM[bit].setPrivate();
		verifyAndSetPrivate(mulPvM[bit]);
		beginFunc(mulPvM[bit]);
		OperandVec L(N), H(N);
		for (uint32_t i = 0; i < N; i++) {
			Operand xy = call(mulPos, px, y, makeImm(unit, i));
			L[i] = trunc(xy, unit);
			H[i] = call(extractHigh, xy);
		}
		Operand LL = pack(&L[0], N);
		Operand HH = pack(&H[0], N);
		LL = zext(LL, bu);
		HH = zext(HH, bu);
		HH = shl(HH, unit);
		z = add(LL, HH);
		ret(z);
		endFunc();
	}
	// mcl_fp_mulUnitPre{N}L(pz, px, y): stores the N + 1 limbs of mulPv(px, y).
	void gen_mcl_fp_mulUnitPre()
	{
		resetGlobalIdx();
		Operand pz(IntPtr, unit);
		Operand px(IntPtr, unit);
		Operand y(Int, unit);
		std::string name = "mcl_fp_mulUnitPre" + cybozu::itoa(N) + "L" + suf;
		mcl_fp_mulUnitPreM[N] = Function(name, Void, pz, px, y);
		verifyAndSetPrivate(mcl_fp_mulUnitPreM[N]);
		beginFunc(mcl_fp_mulUnitPreM[N]);
		Operand z = call(mulPvM[bit], px, y);
		storeN(z, pz);
		ret(Void);
		endFunc();
	}
	/*
		Emit the body of a 2N-limb product pz = px * py:
		- N == 1: one widening multiply;
		- N >= 8 and even: one Karatsuba level on top of mcl_fpDbl_mulPreM[N / 2]
		  (generated in an earlier iteration of gen()'s ascending size loop);
		- otherwise: row-by-row accumulation of mulPv(px, py[i]).
	*/
	void generic_fpDbl_mul(const Operand& pz, const Operand& px, const Operand& py)
	{
		if (N == 1) {
			Operand x = load(px);
			Operand y = load(py);
			x = zext(x, unit * 2);
			y = zext(y, unit * 2);
			Operand z = mul(x, y);
			storeN(z, pz);
			ret(Void);
		} else if (N >= 8 && (N % 2) == 0) {
			/*
				W = 1 << half
				(aW + b)(cW + d) = acW^2 + (ad + bc)W + bd
				ad + bc = (a + b)(c + d) - ac - bd
			*/
			const int H = N / 2;
			const int half = bit / 2;
			Operand pxW = getelementptr(px, H);
			Operand pyW = getelementptr(py, H);
			Operand pzWW = getelementptr(pz, N);
			call(mcl_fpDbl_mulPreM[H], pz, px, py); // bd
			call(mcl_fpDbl_mulPreM[H], pzWW, pxW, pyW); // ac

			Operand a = zext(loadN(pxW, H), half + unit);
			Operand b = zext(loadN(px, H), half + unit);
			Operand c = zext(loadN(pyW, H), half + unit);
			Operand d = zext(loadN(py, H), half + unit);
			Operand t1 = add(a, b);
			Operand t2 = add(c, d);
			Operand buf = _alloca(unit, N);
			Operand t1L = trunc(t1, half);
			Operand t2L = trunc(t2, half);
			// c1, c2: carries out of a + b and c + d; their cross terms are added separately
			Operand c1 = trunc(lshr(t1, half), 1);
			Operand c2 = trunc(lshr(t2, half), 1);
			Operand c0 = _and(c1, c2);
			c1 = select(c1, t2L, makeImm(half, 0));
			c2 = select(c2, t1L, makeImm(half, 0));
			Operand buf1 = _alloca(unit, half / unit);
			Operand buf2 = _alloca(unit, half / unit);
			storeN(t1L, buf1);
			storeN(t2L, buf2);
			call(mcl_fpDbl_mulPreM[N / 2], buf, buf1, buf2);
			Operand t = loadN(buf, N);
			t = zext(t, bit + unit);
			c0 = zext(c0, bit + unit);
			c0 = shl(c0, bit);
			t = _or(t, c0);
			c1 = zext(c1, bit + unit);
			c2 = zext(c2, bit + unit);
			c1 = shl(c1, half);
			c2 = shl(c2, half);
			t = add(t, c1);
			t = add(t, c2);
			t = sub(t, zext(loadN(pz, N), bit + unit));
			t = sub(t, zext(loadN(pz, N, N), bit + unit));
			if (bit + half > t.bit) {
				t = zext(t, bit + half);
			}
			t = add(t, loadN(pz, N + H, H));
			storeN(t, pz, H);
			ret(Void);
		} else {
			Operand y = load(py);
			Operand xy = call(mulPvM[bit], px, y);
			store(trunc(xy, unit), pz);
			Operand t = lshr(xy, unit);
			for (uint32_t i = 1; i < N; i++) {
				y = loadN(py, 1, i);
				xy = call(mulPvM[bit], px, y);
				t = add(t, xy);
				if (i < N - 1) {
					storeN(trunc(t, unit), pz, i);
					t = lshr(t, unit);
				}
			}
			storeN(t, pz, N - 1);
			ret(Void);
		}
	}
	// mcl_fpDbl_mulPre{N}L(pz, px, py): 2N-limb product, no reduction.
	void gen_mcl_fpDbl_mulPre()
	{
		resetGlobalIdx();
		Operand pz(IntPtr, unit);
		Operand px(IntPtr, unit);
		Operand py(IntPtr, unit);
		std::string name = "mcl_fpDbl_mulPre" + cybozu::itoa(N) + "L" + suf;
		mcl_fpDbl_mulPreM[N] = Function(name, Void, pz, px, py);
		verifyAndSetPrivate(mcl_fpDbl_mulPreM[N]);
		beginFunc(mcl_fpDbl_mulPreM[N]);
		generic_fpDbl_mul(pz, px, py);
		endFunc();
	}
	// mcl_fpDbl_sqrPre{N}L(py, px): 2N-limb square, same body as mulPre with both inputs px.
	void gen_mcl_fpDbl_sqrPre()
	{
		resetGlobalIdx();
		Operand py(IntPtr, unit);
		Operand px(IntPtr, unit);
		std::string name = "mcl_fpDbl_sqrPre" + cybozu::itoa(N) + "L" + suf;
		mcl_fpDbl_sqrPreM[N] = Function(name, Void, py, px);
		verifyAndSetPrivate(mcl_fpDbl_sqrPreM[N]);
		beginFunc(mcl_fpDbl_sqrPreM[N]);
		generic_fpDbl_mul(py, px, px);
		endFunc();
	}
	/*
		mcl_fp_mont{,NF}{N}L(pz, px, py, pp): word-by-word Montgomery-style
		multiplication. For each limb y[i] it adds x * y[i] and p * q with
		q = (low limb) * rp, then shifts right by unit; finally subtracts p once
		if that does not borrow. rp is read from pp[-1] (presumably
		-p^-1 mod 2^unit -- NOTE(review): confirm against the caller's layout).
	*/
	void gen_mcl_fp_mont(bool isFullBit = true)
	{
		const int bu = bit + unit;
		const int bu2 = bit + unit * 2;
		resetGlobalIdx();
		Operand pz(IntPtr, unit);
		Operand px(IntPtr, unit);
		Operand py(IntPtr, unit);
		Operand pp(IntPtr, unit);
		std::string name = "mcl_fp_mont";
		if (!isFullBit) {
			name += "NF";
		}
		name += cybozu::itoa(N) + "L" + suf;
		mcl_fp_montM[N] = Function(name, Void, pz, px, py, pp);
		mcl_fp_montM[N].setAlias();
		verifyAndSetPrivate(mcl_fp_montM[N]);
		beginFunc(mcl_fp_montM[N]);
		Operand rp = load(getelementptr(pp, -1));
		Operand z, s, a;
		if (isFullBit) {
			for (uint32_t i = 0; i < N; i++) {
				Operand y = load(getelementptr(py, i));
				Operand xy = call(mulPvM[bit], px, y);
				Operand at;
				if (i == 0) {
					a = zext(xy, bu2);
					at = trunc(xy, unit);
				} else {
					xy = zext(xy, bu2);
					a = add(s, xy);
					at = trunc(a, unit);
				}
				Operand q = mul(at, rp);
				Operand pq = call(mulPvM[bit], pp, q);
				pq = zext(pq, bu2);
				Operand t = add(a, pq);
				s = lshr(t, unit);
			}
			s = trunc(s, bu);
			Operand p = zext(loadN(pp, N), bu);
			Operand vc = sub(s, p);
			Operand c = trunc(lshr(vc, bit), 1);
			z = select(c, s, vc);
			z = trunc(z, bit);
			storeN(z, pz);
		} else {
			Operand y = load(py);
			Operand xy = call(mulPvM[bit], px, y);
			Operand c0 = trunc(xy, unit);
			Operand q = mul(c0, rp);
			Operand pq = call(mulPvM[bit], pp, q);
			Operand t = add(xy, pq);
			t = lshr(t, unit); // bu-bit
			for (uint32_t i = 1; i < N; i++) {
				y = load(getelementptr(py, i));
				xy = call(mulPvM[bit], px, y);
				t = add(t, xy);
				c0 = trunc(t, unit);
				q = mul(c0, rp);
				pq = call(mulPvM[bit], pp, q);
				t = add(t, pq);
				t = lshr(t, unit);
			}
			t = trunc(t, bit);
			Operand vc = sub(t, loadN(pp, N));
			Operand c = trunc(lshr(vc, bit - 1), 1);
			z = select(c, t, vc);
			storeN(z, pz);
		}
		ret(Void);
		endFunc();
	}
	/*
		mcl_fp_montRed{N}L(pz, pxy, pp): Montgomery-style reduction of the
		2N-limb value at pxy -- N rounds of "add p * (low limb * rp), shift
		right by unit" -- followed by one conditional subtraction of p.
	*/
	void gen_mcl_fp_montRed()
	{
		const int bu = bit + unit;
		const int b2 = bit * 2;
		const int b2u = b2 + unit;
		resetGlobalIdx();
		Operand pz(IntPtr, unit);
		Operand pxy(IntPtr, unit);
		Operand pp(IntPtr, unit);
		std::string name = "mcl_fp_montRed" + cybozu::itoa(N) + "L" + suf;
		mcl_fp_montRedM[N] = Function(name, Void, pz, pxy, pp);
		verifyAndSetPrivate(mcl_fp_montRedM[N]);
		beginFunc(mcl_fp_montRedM[N]);
		Operand rp = load(getelementptr(pp, -1));
		Operand p = loadN(pp, N);
		Operand xy = loadN(pxy, N * 2);
		Operand t = zext(xy, b2 + unit);
		Operand z;
		for (uint32_t i = 0; i < N; i++) {
			// distinct names instead of re-declaring (and shadowing) the outer z
			Operand low = trunc(t, unit);
			Operand q = mul(low, rp);
			Operand pq = call(mulPvM[bit], pp, q);
			pq = zext(pq, b2u - unit * i);
			Operand sum = add(t, pq);
			sum = lshr(sum, unit);
			t = trunc(sum, b2 - unit * i);
		}
		p = zext(p, bu);
		Operand vc = sub(t, p);
		Operand c = trunc(lshr(vc, bit), 1);
		z = select(c, t, vc);
		z = trunc(z, bit);
		storeN(z, pz);
		ret(Void);
		endFunc();
	}
	void gen_all()
	{
		gen_mcl_fp_addsubPre(true);
		gen_mcl_fp_addsubPre(false);
		gen_mcl_fp_shr1();
	}
	void gen_addsub()
	{
		gen_mcl_fp_add(true);
		gen_mcl_fp_add(false);
		gen_mcl_fp_sub(true);
		gen_mcl_fp_sub(false);
		gen_mcl_fpDbl_add();
		gen_mcl_fpDbl_sub();
	}
	void gen_mul()
	{
		gen_mulPv();
		gen_mcl_fp_mulUnitPre();
		gen_mcl_fpDbl_mulPre();
		gen_mcl_fpDbl_sqrPre();
		gen_mcl_fp_mont(true);
		gen_mcl_fp_mont(false);
		gen_mcl_fp_montRed();
	}
	// Select the element size for the next batch of per-size functions.
	void setBit(uint32_t bit)
	{
		this->bit = bit;
		N = bit / unit;
	}
	void setUnit(uint32_t unit)
	{
		this->unit = unit;
		unit2 = unit * 2;
		unitStr = cybozu::itoa(unit);
	}
	/*
		Generate everything: the once-only helpers, then every per-size
		function for N = 1 .. ceil(maxBitSize / unit) limbs, in ascending
		order (the Karatsuba path relies on the N / 2 functions existing).
		privateFuncList must outlive this call (only its address is kept).
	*/
	void gen(const StrSet& privateFuncList, uint32_t maxBitSize, const std::string& suf)
	{
		this->suf = suf;
		this->privateFuncList = &privateFuncList;
#ifdef FOR_WASM
		gen_mulUU();
#else
		gen_once();
		uint32_t end = ((maxBitSize + unit - 1) / unit);
		for (uint32_t n = 1; n <= end; n++) {
			setBit(n * unit);
			gen_mul();
			gen_all();
			gen_addsub();
		}
		if (unit == 64 && maxBitSize == 768) {
			for (uint32_t i = maxBitSize + unit * 2; i <= maxBitSize * 2; i += unit * 2) {
				setBit(i);
				gen_all();
			}
		}
#endif
	}
};

int main(int argc, char *argv[])
	try
{
	uint32_t unit = 0;
	bool oldLLVM = false;
	bool wasm = false;
	std::string suf;
	std::string privateFile;
	cybozu::Option opt;
	opt.appendOpt(&unit, uint32_t(sizeof(void*)) * 8, "u", ": unit");
	opt.appendBoolOpt(&oldLLVM, "old", ": old LLVM(before 3.8)");
	opt.appendBoolOpt(&wasm, "wasm", ": for wasm");
	opt.appendOpt(&suf, "", "s", ": suffix of function name");
	opt.appendOpt(&privateFile, "", "f", ": private function list file");
	opt.appendHelp("h");
	if (!opt.parse(argc, argv)) {
		opt.usage();
		return 1;
	}
	StrSet privateFuncList;
	if (!privateFile.empty()) {
		std::ifstream ifs(privateFile.c_str(), std::ios::binary);
		// a mistyped -f path used to be ignored silently, making nothing private
		if (!ifs) {
			throw cybozu::Exception("can't open private function list file") << privateFile;
		}
		std::string name;
		while (ifs >> name) {
			privateFuncList.insert(name);
		}
	}
	Code c;
	if (oldLLVM) {
		c.setOldLLVM();
	}
	c.wasm = wasm;
	c.setUnit(unit);
	uint32_t maxBitSize = MCL_MAX_BIT_SIZE;
	c.gen(privateFuncList, maxBitSize, suf);
	return 0;
} catch (const std::exception& e) {
	// diagnostics go to stderr so they cannot be mixed into generated output
	// that is presumably written to stdout -- NOTE(review): confirm in llvm_gen.hpp
	fprintf(stderr, "ERR %s\n", e.what());
	return 1;
}