/*
 * Decompiled with CFR 0.152.
 */
package cn.org.bjca.gaia.pqc.crypto.ntru;

import cn.org.bjca.gaia.crypto.CipherParameters;
import cn.org.bjca.gaia.crypto.Digest;
import cn.org.bjca.gaia.pqc.crypto.ntru.NTRUSignerPrng;
import cn.org.bjca.gaia.pqc.crypto.ntru.NTRUSigningParameters;
import cn.org.bjca.gaia.pqc.crypto.ntru.NTRUSigningPrivateKeyParameters;
import cn.org.bjca.gaia.pqc.crypto.ntru.NTRUSigningPublicKeyParameters;
import cn.org.bjca.gaia.pqc.math.ntru.polynomial.IntegerPolynomial;
import cn.org.bjca.gaia.pqc.math.ntru.polynomial.Polynomial;
import java.nio.ByteBuffer;

/**
 * Signs and verifies messages with the NTRUSign lattice signature scheme.
 *
 * <p>Usage: call {@link #init(boolean, CipherParameters)}, feed the message through the
 * {@code update} methods, then call {@link #generateSignature()} or
 * {@link #verifySignature(byte[])}.
 *
 * <p>A signature is the binary encoding of the signature polynomial ({@code toBinary(q)})
 * followed by the 4-byte big-endian retry counter {@code r} used to build the message
 * representative.
 *
 * <p>The message digest is the {@code hashAlg} object held by the parameter set. The same
 * object is also handed to the PRNG in {@link #createMsgRep(byte[], int)}. Signers that share
 * one parameter instance therefore share mutable digest state and should not be used
 * concurrently.
 */
public class NTRUSigner {
    private final NTRUSigningParameters params;
    private Digest hashAlg;
    private NTRUSigningPrivateKeyParameters signingKeyPair;
    private NTRUSigningPublicKeyParameters verificationKey;

    /**
     * Creates a signer for the given parameter set.
     *
     * @param params scheme parameters (N, q, number of perturbation bases, norm bound, digest, ...)
     */
    public NTRUSigner(NTRUSigningParameters params) {
        this.params = params;
    }

    /**
     * Prepares the signer for signing or verification and resets the message digest.
     *
     * @param forSigning {@code true} to sign, {@code false} to verify
     * @param params a {@link NTRUSigningPrivateKeyParameters} when signing, otherwise a
     *     {@link NTRUSigningPublicKeyParameters}
     * @throws ClassCastException if {@code params} is not the key type required by the mode
     */
    public void init(boolean forSigning, CipherParameters params) {
        if (forSigning) {
            this.signingKeyPair = (NTRUSigningPrivateKeyParameters)params;
        } else {
            this.verificationKey = (NTRUSigningPublicKeyParameters)params;
        }
        // The digest object is owned by the parameter set; clear any state left from earlier use.
        this.hashAlg = this.params.hashAlg;
        this.hashAlg.reset();
    }

    /**
     * Feeds one byte of the message into the digest.
     *
     * @param b the message byte
     * @throws IllegalStateException if {@link #init(boolean, CipherParameters)} was not called
     */
    public void update(byte b) {
        if (this.hashAlg == null) {
            throw new IllegalStateException("Call initSign or initVerify first!");
        }
        this.hashAlg.update(b);
    }

    /**
     * Feeds a slice of the message into the digest.
     *
     * @param m3 buffer holding message bytes
     * @param off offset of the first byte to use
     * @param length number of bytes to use
     * @throws IllegalStateException if {@link #init(boolean, CipherParameters)} was not called
     */
    public void update(byte[] m3, int off, int length) {
        if (this.hashAlg == null) {
            throw new IllegalStateException("Call initSign or initVerify first!");
        }
        this.hashAlg.update(m3, off, length);
    }

    /**
     * Finalizes the digest of the message fed so far and signs it.
     *
     * @return the encoded signature (polynomial bytes followed by the 4-byte retry counter)
     * @throws IllegalStateException if the signer was not initialized for signing, or if no
     *     signature within the norm bound was found in {@code signFailTolerance} attempts
     */
    public byte[] generateSignature() {
        if (this.hashAlg == null || this.signingKeyPair == null) {
            throw new IllegalStateException("Call initSign first!");
        }
        byte[] msgHash = new byte[this.hashAlg.getDigestSize()];
        this.hashAlg.doFinal(msgHash, 0);
        return this.signHash(msgHash, this.signingKeyPair);
    }

    /**
     * Signs a message hash, retrying with successive counters {@code r = 1, 2, ...} until the
     * produced signature passes {@link #verify}.
     */
    private byte[] signHash(byte[] msgHash, NTRUSigningPrivateKeyParameters kp) {
        NTRUSigningPublicKeyParameters kPub = kp.getPublicKey();
        for (int r = 1; r <= this.params.signFailTolerance; ++r) {
            IntegerPolynomial i = this.createMsgRep(msgHash, r);
            IntegerPolynomial s2 = this.sign(i, kp);
            if (this.verify(i, s2, kPub.h)) {
                byte[] rawSig = s2.toBinary(this.params.q);
                ByteBuffer sbuf = ByteBuffer.allocate(rawSig.length + 4);
                sbuf.put(rawSig);
                // The counter travels with the signature so the verifier can rebuild the same msg rep.
                sbuf.putInt(r);
                return sbuf.array();
            }
        }
        throw new IllegalStateException("Signing failed: too many retries (max=" + this.params.signFailTolerance + ")");
    }

    /**
     * Computes the signature polynomial for message representative {@code i}.
     *
     * <p>It walks the perturbation bases from {@code params.B} down to 1, adding each basis
     * term and carrying the perturbed point into the next basis. It then adds the term for
     * basis 0 and reduces the result into {@code [0, q)} with {@code modPositive}.
     */
    private IntegerPolynomial sign(IntegerPolynomial i, NTRUSigningPrivateKeyParameters kp) {
        int N = this.params.N;
        int q = this.params.q;
        int perturbationBases = this.params.B;
        NTRUSigningPublicKeyParameters kPub = kp.getPublicKey();
        IntegerPolynomial s2 = new IntegerPolynomial(N);
        for (int iLoop = perturbationBases; iLoop >= 1; --iLoop) {
            Polynomial f = kp.getBasis(iLoop).f;
            Polynomial fPrime = kp.getBasis(iLoop).fPrime;
            IntegerPolynomial si = this.basisTerm(f, fPrime, i, q);
            s2.add(si);
            // Map this basis' contribution onto the next basis: si * (h_iLoop - h_next) mod q,
            // where h_next is the previous basis' h, or the public h for the last step.
            IntegerPolynomial hi = (IntegerPolynomial)kp.getBasis(iLoop).h.clone();
            if (iLoop > 1) {
                hi.sub(kp.getBasis(iLoop - 1).h);
            } else {
                hi.sub(kPub.h);
            }
            i = si.mult(hi, q);
        }
        s2.add(this.basisTerm(kp.getBasis(0).f, kp.getBasis(0).fPrime, i, q));
        s2.modPositive(q);
        return s2;
    }

    /**
     * Computes one basis' contribution {@code fPrime*((f*i).div(q)) - f*((fPrime*i).div(q))}.
     * Operations are performed in the same order as the original inline code.
     */
    private IntegerPolynomial basisTerm(Polynomial f, Polynomial fPrime, IntegerPolynomial i, int q) {
        IntegerPolynomial y = f.mult(i);
        y.div(q);
        y = fPrime.mult(y);
        IntegerPolynomial x = fPrime.mult(i);
        x.div(q);
        x = f.mult(x);
        y.sub(x);
        return y;
    }

    /**
     * Finalizes the digest of the message fed so far and checks {@code sig} against it.
     *
     * @param sig an encoded signature as produced by {@link #generateSignature()}
     * @return {@code true} if the signature is valid for the message; {@code false} otherwise,
     *     including when {@code sig} is too short to hold the retry counter
     * @throws IllegalStateException if the signer was not initialized for verification
     */
    public boolean verifySignature(byte[] sig) {
        if (this.hashAlg == null || this.verificationKey == null) {
            throw new IllegalStateException("Call initVerify first!");
        }
        byte[] msgHash = new byte[this.hashAlg.getDigestSize()];
        this.hashAlg.doFinal(msgHash, 0);
        return this.verifyHash(msgHash, sig, this.verificationKey);
    }

    /** Decodes {@code sig} into (polynomial, retry counter) and verifies it against {@code msgHash}. */
    private boolean verifyHash(byte[] msgHash, byte[] sig, NTRUSigningPublicKeyParameters pub) {
        // Every signature ends with a 4-byte counter; anything shorter is malformed. Reject it
        // instead of failing with NegativeArraySizeException below.
        if (sig.length < 4) {
            return false;
        }
        ByteBuffer sbuf = ByteBuffer.wrap(sig);
        byte[] rawSig = new byte[sig.length - 4];
        sbuf.get(rawSig);
        IntegerPolynomial s2 = IntegerPolynomial.fromBinary(rawSig, this.params.N, this.params.q);
        int r = sbuf.getInt();
        return this.verify(this.createMsgRep(msgHash, r), s2, pub.h);
    }

    /**
     * Checks the norm condition {@code |s2|^2 + betaSq * |h*s2 - i|^2 <= normBoundSq}, with both
     * norms centered mod q.
     *
     * <p>NOTE(review): the weighted sum is truncated to {@code long} before the comparison, so a
     * fractional excess over the bound can still be accepted. This is kept unchanged for
     * compatibility with existing signers/verifiers; confirm whether that is intended.
     */
    private boolean verify(IntegerPolynomial i, IntegerPolynomial s2, IntegerPolynomial h2) {
        int q = this.params.q;
        double normBoundSq = this.params.normBoundSq;
        double betaSq = this.params.betaSq;
        IntegerPolynomial t = h2.mult(s2, q);
        t.sub(i);
        long centeredNormSq = (long)((double)s2.centeredNormSq(q) + betaSq * (double)t.centeredNormSq(q));
        return (double)centeredNormSq <= normBoundSq;
    }

    /**
     * Deterministically derives the message representative polynomial from a hash and a retry
     * counter.
     *
     * <p>The PRNG is seeded with {@code msgHash || r} (r big-endian). Each of the N coefficients
     * is built from {@code B = ceil(c / 8)} PRNG bytes, where {@code c = floor(log2 q)}. The bytes
     * are read little-endian (the first byte is least significant).
     *
     * @param msgHash digest of the message
     * @param r retry counter embedded in the signature
     * @return the message representative
     */
    protected IntegerPolynomial createMsgRep(byte[] msgHash, int r) {
        int N = this.params.N;
        int q = this.params.q;
        int c = 31 - Integer.numberOfLeadingZeros(q);
        int B = (c + 7) / 8;
        int shift = 8 * B - c;
        IntegerPolynomial i = new IntegerPolynomial(N);
        ByteBuffer cbuf = ByteBuffer.allocate(msgHash.length + 4);
        cbuf.put(msgHash);
        cbuf.putInt(r);
        NTRUSignerPrng prng = new NTRUSignerPrng(cbuf.array(), this.params.hashAlg);
        for (int t = 0; t < N; ++t) {
            byte[] o = prng.nextBytes(B);
            // Clears the low `shift` bits of the last byte. NOTE(review): after the little-endian
            // read below, that byte is the MOST significant, so this does not bound the value to c
            // bits. For B >= 2 it zeroes some bits below 2^c. Kept as-is because changing it would
            // change the signature format; confirm against the NTRUSign specification.
            int last = o[o.length - 1];
            o[o.length - 1] = (byte)((last >> shift) << shift);
            ByteBuffer obuf = ByteBuffer.allocate(4);
            obuf.put(o);
            obuf.rewind();
            // Reverse the big-endian read so that o[0] is the least significant byte.
            i.coeffs[t] = Integer.reverseBytes(obuf.getInt());
        }
        return i;
    }
}

