/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.tls.crypto.impl;

import java.io.IOException;
import org.bouncycastle.tls.ProtocolVersion;
import org.bouncycastle.tls.SecurityParameters;
import org.bouncycastle.tls.TlsFatalAlert;
import org.bouncycastle.tls.TlsUtils;
import org.bouncycastle.tls.crypto.TlsCipher;
import org.bouncycastle.tls.crypto.TlsCryptoParameters;
import org.bouncycastle.tls.crypto.TlsCryptoUtils;
import org.bouncycastle.tls.crypto.TlsDecodeResult;
import org.bouncycastle.tls.crypto.TlsEncodeResult;
import org.bouncycastle.tls.crypto.TlsSecret;
import org.bouncycastle.tls.crypto.impl.AEADNonceGenerator;
import org.bouncycastle.tls.crypto.impl.BCFipsAEADNonceGenerator;
import org.bouncycastle.tls.crypto.impl.TlsAEADCipherImpl;
import org.bouncycastle.tls.crypto.impl.TlsImplUtils;
import org.bouncycastle.util.Arrays;

/**
 * {@link TlsCipher} for AEAD cipher suites (CCM, ChaCha20-Poly1305 and GCM) under TLS 1.2 and TLS 1.3.
 * <p>
 * Two per-record nonce constructions are supported:
 * <ul>
 * <li>{@code NONCE_RFC5288}: a 4-byte fixed IV from the key block followed by an 8-byte explicit record IV,
 * which is transmitted in front of each record's ciphertext.</li>
 * <li>{@code NONCE_RFC7905}: a 12-byte IV XORed with the 64-bit record sequence number; nothing extra is
 * transmitted. Used for ChaCha20-Poly1305 and for every AEAD type under TLS 1.3.</li>
 * </ul>
 * Under TLS 1.3 the outer record type is always application_data and the real content type travels as the
 * last non-zero byte of the encrypted inner plaintext.
 */
public class TlsAEADCipher
implements TlsCipher {
    public static final int AEAD_CCM = 1;
    public static final int AEAD_CHACHA20_POLY1305 = 2;
    public static final int AEAD_GCM = 3;

    /** Nonce = 4-byte fixed IV || 8-byte explicit record IV. */
    private static final int NONCE_RFC5288 = 1;
    /** Nonce = 12-byte IV XOR 64-bit sequence number (written into the last 8 bytes). */
    private static final int NONCE_RFC7905 = 2;

    // TLS alert descriptions raised by this class. Declared as short so they can be passed directly to
    // TlsFatalAlert (an int literal is not implicitly narrowed in a method/constructor invocation).
    private static final short ALERT_UNEXPECTED_MESSAGE = 10;
    private static final short ALERT_BAD_RECORD_MAC = 20;
    private static final short ALERT_DECODE_ERROR = 50;
    private static final short ALERT_INTERNAL_ERROR = 80;

    /** ContentType application_data: the outer record type used for every TLS 1.3 record. */
    private static final short CONTENT_TYPE_APPLICATION_DATA = 23;

    /**
     * DTLS epoch 1, XORed into the first two bytes of the explicit nonce when the FIPS GCM nonce generator is used.
     * NOTE(review): public mutable array — any caller can modify it; kept public for compatibility.
     */
    public static final byte[] EPOCH_1 = new byte[]{0, 1};

    /** Non-null only if the FIPS nonce generator class is loadable; enables the FIPS GCM nonce path. */
    private static final Class<?> fipsNonceGeneratorClass = TlsAEADCipher.lookup("org.bouncycastle.crypto.fips.FipsNonceGenerator");

    protected final TlsCryptoParameters cryptoParams;
    /** AEAD key length in bytes. */
    protected final int keySize;
    /** Authentication tag length in bytes. */
    protected final int macSize;
    /** Length in bytes of the implicit IV held in {@link #encryptNonce}/{@link #decryptNonce}. */
    protected final int fixed_iv_length;
    /** Length in bytes of the explicit nonce sent with each record (0 for RFC 7905-style nonces). */
    protected final int record_iv_length;
    protected final TlsAEADCipherImpl decryptCipher;
    protected final TlsAEADCipherImpl encryptCipher;
    /** Implicit IV for incoming records ({@code fixed_iv_length} bytes). */
    protected final byte[] decryptNonce;
    /** Implicit IV for outgoing records ({@code fixed_iv_length} bytes). */
    protected final byte[] encryptNonce;
    protected final boolean isTLSv13;
    protected final int nonceMode;
    /** When non-null, supplies outgoing nonces instead of deriving them from the sequence number. */
    protected final AEADNonceGenerator encryptNonceGenerator;

    /**
     * Creates the cipher and keys both directions from the handshake security parameters.
     *
     * @param cryptoParams  source of the handshake security parameters and of the local role (client/server)
     * @param encryptCipher AEAD implementation used for outgoing records
     * @param decryptCipher AEAD implementation used for incoming records
     * @param keySize       AEAD key length in bytes
     * @param macSize       authentication tag length in bytes
     * @param aeadType      one of {@link #AEAD_CCM}, {@link #AEAD_CHACHA20_POLY1305}, {@link #AEAD_GCM}
     * @throws IOException  {@link TlsFatalAlert} (internal_error) if the negotiated version is rejected,
     *                      {@code aeadType} is unknown or the key block is inconsistent; or any error raised
     *                      while deriving keys
     */
    public TlsAEADCipher(TlsCryptoParameters cryptoParams, TlsAEADCipherImpl encryptCipher, TlsAEADCipherImpl decryptCipher, int keySize, int macSize, int aeadType) throws IOException {
        SecurityParameters securityParameters = cryptoParams.getSecurityParametersHandshake();
        ProtocolVersion negotiatedVersion = securityParameters.getNegotiatedVersion();
        // NOTE(review): presumably isTLSv12 means "TLS 1.2 or later", since TLS 1.3 is handled below — confirm.
        if (!TlsImplUtils.isTLSv12(negotiatedVersion)) {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }
        this.isTLSv13 = TlsImplUtils.isTLSv13(negotiatedVersion);
        this.nonceMode = TlsAEADCipher.getNonceMode(this.isTLSv13, aeadType);
        switch (this.nonceMode) {
            case NONCE_RFC5288:
                this.fixed_iv_length = 4;
                this.record_iv_length = 8;
                break;
            case NONCE_RFC7905:
                this.fixed_iv_length = 12;
                this.record_iv_length = 0;
                break;
            default:
                throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }
        this.cryptoParams = cryptoParams;
        this.keySize = keySize;
        this.macSize = macSize;
        this.decryptCipher = decryptCipher;
        this.encryptCipher = encryptCipher;
        this.decryptNonce = new byte[this.fixed_iv_length];
        this.encryptNonce = new byte[this.fixed_iv_length];

        boolean isServer = cryptoParams.isServer();
        if (this.isTLSv13) {
            // TLS 1.3: keys/IVs come from the traffic secrets; each side encrypts with its own secret.
            this.encryptNonceGenerator = null;
            this.rekeyCipher(securityParameters, decryptCipher, this.decryptNonce, !isServer);
            this.rekeyCipher(securityParameters, encryptCipher, this.encryptNonce, isServer);
            return;
        }

        // TLS 1.2: the key block is laid out as client key, server key, client IV, server IV.
        int keyBlockSize = 2 * keySize + 2 * this.fixed_iv_length;
        byte[] keyBlock = TlsImplUtils.calculateKeyBlock(cryptoParams, keyBlockSize);
        int pos = 0;
        if (isServer) {
            decryptCipher.setKey(keyBlock, pos, keySize);
            pos += keySize;
            encryptCipher.setKey(keyBlock, pos, keySize);
            pos += keySize;
            System.arraycopy(keyBlock, pos, this.decryptNonce, 0, this.fixed_iv_length);
            pos += this.fixed_iv_length;
            System.arraycopy(keyBlock, pos, this.encryptNonce, 0, this.fixed_iv_length);
            pos += this.fixed_iv_length;
        } else {
            encryptCipher.setKey(keyBlock, pos, keySize);
            pos += keySize;
            decryptCipher.setKey(keyBlock, pos, keySize);
            pos += keySize;
            System.arraycopy(keyBlock, pos, this.encryptNonce, 0, this.fixed_iv_length);
            pos += this.fixed_iv_length;
            System.arraycopy(keyBlock, pos, this.decryptNonce, 0, this.fixed_iv_length);
            pos += this.fixed_iv_length;
        }
        if (keyBlockSize != pos) {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }

        // Prime both ciphers with a throwaway nonce. Every real nonce starts with the implicit IV, so flipping
        // byte 0 of the encrypt IV and byte 1 of the decrypt IV guarantees the priming nonce never equals a
        // real nonce of either cipher. (The explicit casts are required: ~byte yields an int.)
        int nonceLength = this.fixed_iv_length + this.record_iv_length;
        byte[] dummyNonce = new byte[nonceLength];
        dummyNonce[0] = (byte)~this.encryptNonce[0];
        dummyNonce[1] = (byte)~this.decryptNonce[1];
        encryptCipher.init(dummyNonce, macSize, null);
        decryptCipher.init(dummyNonce, macSize, null);

        if (AEAD_GCM == aeadType && null != fipsNonceGeneratorClass) {
            // Outgoing GCM nonces come from a counter-based generator instead of the sequence number.
            int counterBits = 64;
            byte[] baseNonce = Arrays.copyOf(this.encryptNonce, nonceLength);
            if (negotiatedVersion.isDTLS()) {
                // The first two bytes of the explicit nonce carry epoch 1, leaving 48 bits for the counter.
                counterBits = 48;
                baseNonce[nonceLength - 8] ^= EPOCH_1[0];
                baseNonce[nonceLength - 7] ^= EPOCH_1[1];
            }
            this.encryptNonceGenerator = new BCFipsAEADNonceGenerator(baseNonce, counterBits);
        } else {
            this.encryptNonceGenerator = null;
        }
    }

    /**
     * Returns the largest record body accepted for decoding given a plaintext limit: plaintext plus tag plus
     * explicit nonce, plus one content-type byte under TLS 1.3.
     */
    public int getCiphertextDecodeLimit(int plaintextLimit) {
        return plaintextLimit + this.macSize + this.record_iv_length + (this.isTLSv13 ? 1 : 0);
    }

    /**
     * Returns the record body size produced when encoding {@code plaintextLength} bytes.
     * Under TLS 1.3 the inner plaintext gains one content-type byte and no padding is added.
     */
    public int getCiphertextEncodeLimit(int plaintextLength, int plaintextLimit) {
        int innerPlaintextLimit = plaintextLength;
        if (this.isTLSv13) {
            int maxPadding = 0; // this implementation never pads TLSInnerPlaintext
            innerPlaintextLimit = 1 + Math.min(plaintextLimit, plaintextLength + maxPadding);
        }
        return innerPlaintextLimit + this.macSize + this.record_iv_length;
    }

    /**
     * Returns the maximum plaintext recoverable from a record body of {@code ciphertextLimit} bytes.
     * A negative result means the body is too short to be valid.
     */
    public int getPlaintextLimit(int ciphertextLimit) {
        return ciphertextLimit - this.macSize - this.record_iv_length - (this.isTLSv13 ? 1 : 0);
    }

    /**
     * Encrypts one record.
     *
     * @param seqNo            record sequence number (used for the nonce unless a nonce generator is active, and
     *                         for the TLS 1.2 additional data)
     * @param contentType      real content type of the record
     * @param recordVersion    version written into the additional data
     * @param headerAllocation number of leading bytes left free in the returned buffer for the record header
     * @return the output buffer: {@code headerAllocation} free bytes, then explicit nonce (if any), then
     *         ciphertext and tag; the record type is application_data under TLS 1.3
     * @throws IOException {@link TlsFatalAlert} (internal_error) if encryption fails or the output size is off
     */
    public TlsEncodeResult encodePlaintext(long seqNo, short contentType, ProtocolVersion recordVersion, int headerAllocation, byte[] plaintext, int plaintextOffset, int plaintextLength) throws IOException {
        byte[] nonce = new byte[this.encryptNonce.length + this.record_iv_length];
        if (null != this.encryptNonceGenerator) {
            this.encryptNonceGenerator.generateNonce(nonce);
        } else {
            switch (this.nonceMode) {
                case NONCE_RFC5288:
                    System.arraycopy(this.encryptNonce, 0, nonce, 0, this.encryptNonce.length);
                    TlsUtils.writeUint64(seqNo, nonce, this.encryptNonce.length);
                    break;
                case NONCE_RFC7905:
                    TlsAEADCipher.xorSequenceNumberIntoNonce(nonce, this.encryptNonce, seqNo);
                    break;
                default:
                    throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
            }
        }

        int encryptionLength = this.encryptCipher.getOutputSize(plaintextLength + (this.isTLSv13 ? 1 : 0));
        int ciphertextLength = this.record_iv_length + encryptionLength;
        byte[] output = new byte[headerAllocation + ciphertextLength];
        int outputPos = headerAllocation;
        if (this.record_iv_length != 0) {
            // The explicit part of the nonce is sent in the clear ahead of the ciphertext.
            System.arraycopy(nonce, nonce.length - this.record_iv_length, output, outputPos, this.record_iv_length);
            outputPos += this.record_iv_length;
        }

        short recordType = this.isTLSv13 ? CONTENT_TYPE_APPLICATION_DATA : contentType;
        byte[] additionalData = this.getAdditionalData(seqNo, recordType, recordVersion, ciphertextLength, plaintextLength);
        // TLS 1.3: the real content type is appended to the plaintext before encryption.
        byte[] extraInput = this.isTLSv13 ? new byte[]{(byte)contentType} : TlsUtils.EMPTY_BYTES;
        try {
            this.encryptCipher.init(nonce, this.macSize, additionalData);
            outputPos += this.encryptCipher.doFinal(plaintext, plaintextOffset, plaintextLength, extraInput, output, outputPos);
        } catch (Exception e) {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR, e);
        }
        if (outputPos != output.length) {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }
        return new TlsEncodeResult(output, 0, output.length, recordType);
    }

    /**
     * Decrypts and authenticates one record in place.
     *
     * @param recordType outer record type from the header (authenticated as additional data)
     * @return a view into {@code ciphertext} holding the plaintext; under TLS 1.3 the trailing padding and
     *         content-type byte are stripped and the content type is taken from the inner plaintext
     * @throws IOException {@link TlsFatalAlert}: decode_error if the record is too short, bad_record_mac if
     *                     decryption/authentication fails, unexpected_message if a TLS 1.3 inner plaintext is
     *                     all zeros, internal_error on an inconsistent output size
     */
    public TlsDecodeResult decodeCiphertext(long seqNo, short recordType, ProtocolVersion recordVersion, byte[] ciphertext, int ciphertextOffset, int ciphertextLength) throws IOException {
        if (this.getPlaintextLimit(ciphertextLength) < 0) {
            throw new TlsFatalAlert(ALERT_DECODE_ERROR);
        }

        byte[] nonce = new byte[this.decryptNonce.length + this.record_iv_length];
        switch (this.nonceMode) {
            case NONCE_RFC5288:
                System.arraycopy(this.decryptNonce, 0, nonce, 0, this.decryptNonce.length);
                // The explicit part of the nonce is the first record_iv_length bytes of the record body.
                System.arraycopy(ciphertext, ciphertextOffset, nonce, nonce.length - this.record_iv_length, this.record_iv_length);
                break;
            case NONCE_RFC7905:
                TlsAEADCipher.xorSequenceNumberIntoNonce(nonce, this.decryptNonce, seqNo);
                break;
            default:
                throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }

        int encryptionOffset = ciphertextOffset + this.record_iv_length;
        int encryptionLength = ciphertextLength - this.record_iv_length;
        int plaintextLength = this.decryptCipher.getOutputSize(encryptionLength);
        byte[] additionalData = this.getAdditionalData(seqNo, recordType, recordVersion, ciphertextLength, plaintextLength);

        int outputPos;
        try {
            this.decryptCipher.init(nonce, this.macSize, additionalData);
            // Output is written back over the input buffer, starting at the ciphertext offset.
            outputPos = this.decryptCipher.doFinal(ciphertext, encryptionOffset, encryptionLength, TlsUtils.EMPTY_BYTES, ciphertext, encryptionOffset);
        } catch (Exception e) {
            throw new TlsFatalAlert(ALERT_BAD_RECORD_MAC, e);
        }
        if (outputPos != plaintextLength) {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }

        short contentType = recordType;
        if (this.isTLSv13) {
            // Strip trailing zero padding; the last non-zero byte is the real content type.
            int pos = plaintextLength;
            byte octet;
            do {
                if (--pos < 0) {
                    throw new TlsFatalAlert(ALERT_UNEXPECTED_MESSAGE);
                }
                octet = ciphertext[encryptionOffset + pos];
            } while (0 == octet);
            contentType = (short)(octet & 0xFF);
            plaintextLength = pos;
        }
        return new TlsDecodeResult(ciphertext, encryptionOffset, plaintextLength, contentType);
    }

    /** Re-derives the read key/IV from the peer's current traffic secret (TLS 1.3 only). */
    public void rekeyDecoder() throws IOException {
        this.rekeyCipher(this.cryptoParams.getSecurityParametersConnection(), this.decryptCipher, this.decryptNonce, !this.cryptoParams.isServer());
    }

    /** Re-derives the write key/IV from this side's current traffic secret (TLS 1.3 only). */
    public void rekeyEncoder() throws IOException {
        this.rekeyCipher(this.cryptoParams.getSecurityParametersConnection(), this.encryptCipher, this.encryptNonce, this.cryptoParams.isServer());
    }

    /** Returns true under TLS 1.3, where the outer record type hides the real content type. */
    public boolean usesOpaqueRecordType() {
        return this.isTLSv13;
    }

    /**
     * Builds the AEAD additional data.
     * <p>
     * TLS 1.3: record type (1) || record version (2) || ciphertext length (2) — 5 bytes.<br>
     * TLS 1.2: sequence number (8) || record type (1) || record version (2) || plaintext length (2) — 13 bytes.
     */
    protected byte[] getAdditionalData(long seqNo, short recordType, ProtocolVersion recordVersion, int ciphertextLength, int plaintextLength) throws IOException {
        if (this.isTLSv13) {
            byte[] additional_data = new byte[5];
            TlsUtils.writeUint8(recordType, additional_data, 0);
            TlsUtils.writeVersion(recordVersion, additional_data, 1);
            TlsUtils.writeUint16(ciphertextLength, additional_data, 3);
            return additional_data;
        }
        byte[] additional_data = new byte[13];
        TlsUtils.writeUint64(seqNo, additional_data, 0);
        TlsUtils.writeUint8(recordType, additional_data, 8);
        TlsUtils.writeVersion(recordVersion, additional_data, 9);
        TlsUtils.writeUint16(plaintextLength, additional_data, 11);
        return additional_data;
    }

    /**
     * Keys {@code cipher} and fills {@code nonce} from the server or client traffic secret (TLS 1.3 only).
     *
     * @param serverSecret true to use the server traffic secret, false for the client one
     * @throws IOException {@link TlsFatalAlert} (internal_error) if not TLS 1.3 or the secret is missing
     */
    protected void rekeyCipher(SecurityParameters securityParameters, TlsAEADCipherImpl cipher, byte[] nonce, boolean serverSecret) throws IOException {
        if (!this.isTLSv13) {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }
        TlsSecret secret = serverSecret ? securityParameters.getTrafficSecretServer() : securityParameters.getTrafficSecretClient();
        if (null == secret) {
            throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }
        this.setup13Cipher(cipher, nonce, secret, TlsCryptoUtils.getHash(securityParameters.getPRFHashAlgorithm()));
    }

    /**
     * Derives the write key ("key" label) and IV ("iv" label) from {@code secret} via HKDF-Expand-Label,
     * installs the key and copies the IV into {@code nonce}.
     */
    protected void setup13Cipher(TlsAEADCipherImpl cipher, byte[] nonce, TlsSecret secret, int cryptoHashAlgorithm) throws IOException {
        byte[] key = TlsCryptoUtils.hkdfExpandLabel(secret, cryptoHashAlgorithm, "key", TlsUtils.EMPTY_BYTES, this.keySize).extract();
        byte[] iv = TlsCryptoUtils.hkdfExpandLabel(secret, cryptoHashAlgorithm, "iv", TlsUtils.EMPTY_BYTES, this.fixed_iv_length).extract();
        cipher.setKey(key, 0, this.keySize);
        System.arraycopy(iv, 0, nonce, 0, this.fixed_iv_length);
        // Prime with a nonce that differs from the real IV in byte 0; record nonces only alter the last
        // 8 bytes of the 12-byte IV, so the priming nonce never coincides with a real one.
        iv[0] = (byte)(iv[0] ^ 0x80);
        cipher.init(iv, this.macSize, null);
    }

    /**
     * Writes {@code seqNo} into the last 8 bytes of {@code nonce} and XORs {@code iv} into its leading bytes
     * (RFC 7905-style nonce). Expects a freshly allocated, all-zero {@code nonce}.
     */
    private static void xorSequenceNumberIntoNonce(byte[] nonce, byte[] iv, long seqNo) {
        TlsUtils.writeUint64(seqNo, nonce, nonce.length - 8);
        for (int i = 0; i < iv.length; ++i) {
            nonce[i] ^= iv[i];
        }
    }

    /** Maps an AEAD type to its nonce construction; only CCM/GCM under TLS 1.2 use explicit nonces. */
    private static int getNonceMode(boolean isTLSv13, int aeadType) throws IOException {
        switch (aeadType) {
            case AEAD_CCM:
            case AEAD_GCM:
                return isTLSv13 ? NONCE_RFC7905 : NONCE_RFC5288;
            case AEAD_CHACHA20_POLY1305:
                return NONCE_RFC7905;
            default:
                throw new TlsFatalAlert(ALERT_INTERNAL_ERROR);
        }
    }

    /** Best-effort class lookup: returns null if the class (or this class's loader) is unavailable. */
    private static Class<?> lookup(String className) {
        try {
            return TlsAEADCipher.class.getClassLoader().loadClass(className);
        } catch (Exception e) {
            // Absence of the optional FIPS class is expected; fall back to sequence-number nonces.
            return null;
        }
    }
}

