/*
 * Decompiled with CFR 0.152.
 */
package org.apache.nifi.security.util.crypto;

import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.crypto.Cipher;
import javax.crypto.spec.SecretKeySpec;
import org.apache.commons.codec.DecoderException;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.codec.binary.Hex;
import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.security.util.EncryptionMethod;
import org.apache.nifi.security.util.crypto.AESKeyedCipherProvider;
import org.apache.nifi.security.util.crypto.CipherUtility;
import org.apache.nifi.security.util.crypto.KeyedCipherProvider;
import org.apache.nifi.security.util.crypto.RandomIVPBECipherProvider;
import org.apache.nifi.security.util.crypto.ScryptSecureHasher;
import org.apache.nifi.security.util.crypto.scrypt.Scrypt;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ScryptCipherProvider
extends RandomIVPBECipherProvider {
    private static final Logger logger = LoggerFactory.getLogger(ScryptCipherProvider.class);
    private final int n;
    private final int r;
    private final int p;
    private static final int DEFAULT_N = Double.valueOf(Math.pow(2.0, 14.0)).intValue();
    private static final int DEFAULT_R = 8;
    private static final int DEFAULT_P = 1;
    private static final Pattern SCRYPT_SALT_FORMAT = Pattern.compile("^\\$s0\\$[a-f0-9]{5,16}\\$[\\w\\/\\+]{12,44}");
    private static final Pattern MCRYPT_SALT_FORMAT = Pattern.compile("^\\$\\d+\\$\\d+\\$\\d+\\$[a-f0-9]{16,64}");

    public ScryptCipherProvider() {
        this(DEFAULT_N, 8, 1);
    }

    public ScryptCipherProvider(int n, int r, int p) {
        this.n = n;
        this.r = r;
        this.p = p;
        if (n < DEFAULT_N) {
            logger.warn("The provided iteration count {} is below the recommended minimum {}", (Object)n, (Object)DEFAULT_N);
        }
        if (r < 8) {
            logger.warn("The provided block size {} is below the recommended minimum {}", (Object)r, (Object)8);
        }
        if (p < 1) {
            logger.warn("The provided parallelization factor {} is below the recommended minimum {}", (Object)p, (Object)1);
        }
        if (!ScryptCipherProvider.isPValid(r, p)) {
            logger.warn("Based on the provided block size {}, the provided parallelization factor {} is out of bounds", (Object)r, (Object)p);
            throw new IllegalArgumentException("Invalid p value exceeds p boundary");
        }
    }

    public static boolean isPValid(int r, int p) {
        if (!ScryptCipherProvider.isRValid(r)) {
            logger.warn("The provided block size {} must be greater than 0", (Object)r);
            throw new IllegalArgumentException("Invalid r value; must be greater than 0");
        }
        double pBoundary = (Math.pow(2.0, 32.0) - 1.0) * (32.0 / (double)(r * 128));
        return (double)p <= pBoundary && p > 0;
    }

    public static boolean isRValid(int r) {
        return r > 0;
    }

    @Override
    public Cipher getCipher(EncryptionMethod encryptionMethod, String password, byte[] salt, byte[] iv, int keyLength, boolean encryptMode) throws Exception {
        try {
            return this.getInitializedCipher(encryptionMethod, password, salt, iv, keyLength, encryptMode);
        }
        catch (IllegalArgumentException e) {
            throw e;
        }
        catch (Exception e) {
            throw new ProcessException("Error initializing the cipher", (Throwable)e);
        }
    }

    @Override
    Logger getLogger() {
        return logger;
    }

    @Override
    public Cipher getCipher(EncryptionMethod encryptionMethod, String password, byte[] salt, int keyLength, boolean encryptMode) throws Exception {
        return this.getCipher(encryptionMethod, password, salt, new byte[0], keyLength, encryptMode);
    }

    protected Cipher getInitializedCipher(EncryptionMethod encryptionMethod, String password, byte[] salt, byte[] iv, int keyLength, boolean encryptMode) throws Exception {
        int p;
        int r;
        int n;
        if (encryptionMethod == null) {
            throw new IllegalArgumentException("The encryption method must be specified");
        }
        if (!encryptionMethod.isCompatibleWithStrongKDFs()) {
            throw new IllegalArgumentException(encryptionMethod.name() + " is not compatible with Scrypt");
        }
        if (StringUtils.isEmpty((CharSequence)password)) {
            throw new IllegalArgumentException("Encryption with an empty password is not supported");
        }
        String algorithm = encryptionMethod.getAlgorithm();
        String cipherName = CipherUtility.parseCipherFromAlgorithm(algorithm);
        if (!CipherUtility.isValidKeyLength(keyLength, cipherName)) {
            throw new IllegalArgumentException(String.valueOf(keyLength) + " is not a valid key length for " + cipherName);
        }
        String saltString = new String(salt, StandardCharsets.UTF_8);
        byte[] rawSalt = new byte[this.getDefaultSaltLength()];
        if (ScryptCipherProvider.isScryptFormattedSalt(saltString)) {
            ArrayList<Integer> params = new ArrayList<Integer>(3);
            this.parseSalt(saltString, rawSalt, params);
            n = (Integer)params.get(0);
            r = (Integer)params.get(1);
            p = (Integer)params.get(2);
        } else {
            rawSalt = salt;
            n = this.getN();
            r = this.getR();
            p = this.getP();
        }
        ScryptSecureHasher scryptSecureHasher = new ScryptSecureHasher(n, r, p, keyLength / 8);
        try {
            byte[] keyBytes = scryptSecureHasher.hashRaw(password.getBytes(StandardCharsets.UTF_8), rawSalt);
            SecretKeySpec tempKey = new SecretKeySpec(keyBytes, algorithm);
            AESKeyedCipherProvider keyedCipherProvider = new AESKeyedCipherProvider();
            return ((KeyedCipherProvider)keyedCipherProvider).getCipher(encryptionMethod, tempKey, iv, encryptMode);
        }
        catch (IllegalArgumentException e) {
            if (e.getMessage().contains("The salt length")) {
                throw new IllegalArgumentException("The raw salt must be greater than or equal to 8 bytes", e);
            }
            logger.error("Encountered an error generating the Scrypt hash", (Throwable)e);
            throw e;
        }
    }

    public static byte[] extractRawSaltFromScryptSalt(String scryptSalt) {
        String[] saltComponents = scryptSalt.split("\\$");
        if (saltComponents.length < 4) {
            throw new IllegalArgumentException("Could not parse salt");
        }
        return Base64.decodeBase64((String)saltComponents[3]);
    }

    public static boolean isScryptFormattedSalt(String salt) {
        if (salt == null || salt.length() == 0) {
            throw new IllegalArgumentException("The salt cannot be empty. To generate a salt, use ScryptCipherProvider#generateSalt()");
        }
        Matcher matcher = SCRYPT_SALT_FORMAT.matcher(salt);
        return matcher.find();
    }

    private void parseSalt(String scryptSalt, byte[] rawSalt, List<Integer> params) {
        if (StringUtils.isEmpty((CharSequence)scryptSalt)) {
            throw new IllegalArgumentException("Cannot parse empty salt");
        }
        byte[] salt = ScryptCipherProvider.extractRawSaltFromScryptSalt(scryptSalt);
        if (rawSalt.length < salt.length) {
            byte[] tempBytes = new byte[salt.length];
            System.arraycopy(rawSalt, 0, tempBytes, 0, rawSalt.length);
            rawSalt = tempBytes;
        }
        System.arraycopy(salt, 0, rawSalt, 0, salt.length);
        if (params == null) {
            params = new ArrayList<Integer>(3);
        }
        String[] saltComponents = scryptSalt.split("\\$");
        params.addAll(Scrypt.parseParameters(saltComponents[2]));
    }

    public String formatSaltForScrypt(byte[] salt) {
        String saltString = new String(salt, StandardCharsets.UTF_8);
        if (ScryptCipherProvider.isScryptFormattedSalt(saltString)) {
            return saltString;
        }
        return ScryptCipherProvider.formatSaltForScrypt(salt, this.getN(), this.getR(), this.getP());
    }

    public static String formatSaltForScrypt(byte[] salt, int n, int r, int p) {
        String saltString = new String(salt, StandardCharsets.UTF_8);
        if (ScryptCipherProvider.isScryptFormattedSalt(saltString)) {
            return saltString;
        }
        if (saltString.startsWith("$")) {
            logger.warn("Salt starts with $ but is not valid scrypt salt");
            Matcher matcher = MCRYPT_SALT_FORMAT.matcher(saltString);
            if (matcher.find()) {
                logger.warn("The salt appears to be of the modified mcrypt format. Use ScryptCipherProvider#translateSalt(mcryptSalt) to form a valid salt");
                return ScryptCipherProvider.translateSalt(saltString);
            }
            logger.info("Salt is not modified mcrypt format");
        }
        logger.info("Treating as raw salt bytes");
        int saltLength = salt.length;
        if (saltLength < 8 || saltLength > 32) {
            throw new IllegalArgumentException("The raw salt must be between 8 and 32 bytes");
        }
        return Scrypt.formatSalt(salt, n, r, p);
    }

    public static String translateSalt(String mcryptSalt) {
        if (StringUtils.isEmpty((CharSequence)mcryptSalt)) {
            throw new IllegalArgumentException("Cannot translate empty salt");
        }
        Matcher matcher = MCRYPT_SALT_FORMAT.matcher(mcryptSalt);
        if (!matcher.matches()) {
            throw new IllegalArgumentException("Salt is not valid mcrypt format of $n$r$p$saltHex");
        }
        String[] components = mcryptSalt.split("\\$");
        try {
            return Scrypt.formatSalt(Hex.decodeHex((char[])components[4].toCharArray()), Integer.parseInt(components[1]), Integer.parseInt(components[2]), Integer.parseInt(components[3]));
        }
        catch (DecoderException e) {
            String msg = "Mcrypt salt was not properly hex-encoded";
            logger.warn("Mcrypt salt was not properly hex-encoded");
            throw new IllegalArgumentException("Mcrypt salt was not properly hex-encoded");
        }
    }

    @Override
    public byte[] generateSalt() {
        byte[] salt = new byte[Scrypt.getDefaultSaltLength()];
        new SecureRandom().nextBytes(salt);
        return Scrypt.formatSalt(salt, this.n, this.r, this.p).getBytes(StandardCharsets.UTF_8);
    }

    @Override
    public int getDefaultSaltLength() {
        return Scrypt.getDefaultSaltLength();
    }

    protected int getN() {
        return this.n;
    }

    protected int getR() {
        return this.r;
    }

    protected int getP() {
        return this.p;
    }
}

