/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.tls.crypto.impl.jcajce;

import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.KeyFactory;
import java.security.KeyPairGenerator;
import java.security.PublicKey;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.KeySpec;
import java.security.spec.X509EncodedKeySpec;
import javax.crypto.KeyGenerator;
import org.bouncycastle.asn1.ASN1ObjectIdentifier;
import org.bouncycastle.asn1.x509.AlgorithmIdentifier;
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
import org.bouncycastle.jcajce.util.JcaJceHelper;
import org.bouncycastle.jcajce.util.NamedJcaJceHelper;
import org.bouncycastle.pqc.jcajce.interfaces.MLKEMPublicKey;
import org.bouncycastle.pqc.jcajce.spec.MLKEMParameterSpec;
import org.bouncycastle.pqc.jcajce.spec.MLKEMPublicKeySpec;
import org.bouncycastle.tls.AlertDescription;
import org.bouncycastle.tls.TlsFatalAlert;
import org.bouncycastle.tls.crypto.impl.jcajce.JcaTlsCrypto;
import org.bouncycastle.tls.crypto.impl.jcajce.JcaTlsCryptoProvider;

/*
 * Multiple versions of this class in jar - see https://www.benf.org/other/cfr/multi-version-jar.html
 */
/**
 * Package-private helpers for creating, encoding and decoding ML-KEM keys through the JCA.
 * <p>
 * All JCA objects are obtained from a lazily created {@link JcaTlsCrypto} whose helper is bound
 * to the provider named {@code "BCPQC"}. That instance is built once, from the first
 * {@code JcaTlsCrypto} passed to any method here, and is then shared by every later call.
 */
class KemUtil {
    /** Shared crypto bound to the "BCPQC" provider; created on first use by {@link #makeKemCrypto}. */
    static JcaTlsCrypto kemCrypto = null;
    /** NIST algorithm arc 2.16.840.1.101.3.4. */
    static final ASN1ObjectIdentifier nistAlgorithm = new ASN1ObjectIdentifier("2.16.840.1.101.3.4");
    /** KEM sub-arc (2.16.840.1.101.3.4.4) under the NIST algorithm arc. */
    static final ASN1ObjectIdentifier kems = nistAlgorithm.branch("4");
    /** OID used for "ML-KEM-512" keys. */
    static final ASN1ObjectIdentifier id_alg_ml_kem_512 = kems.branch("1");
    /** OID used for "ML-KEM-768" keys. */
    static final ASN1ObjectIdentifier id_alg_ml_kem_768 = kems.branch("2");
    /** OID used for "ML-KEM-1024" keys. */
    static final ASN1ObjectIdentifier id_alg_ml_kem_1024 = kems.branch("3");

    KemUtil() {
    }

    /**
     * Creates {@link #kemCrypto} if it does not exist yet, reusing the secure random and helper of
     * {@code crypto}. Once created, later calls are no-ops even if a different {@code crypto} is
     * supplied.
     *
     * @param crypto source of the secure random and of the helper used for the nonce entropy source
     * @throws IllegalStateException if a {@link GeneralSecurityException} occurs while building the
     *     shared crypto, or if construction fails for any other reason (reported as the BCPQC
     *     provider being absent); the original exception is attached as the cause
     */
    private static synchronized void makeKemCrypto(JcaTlsCrypto crypto) {
        if (kemCrypto != null) {
            return;
        }
        try {
            JcaJceHelper pqcHelper = new NamedJcaJceHelper("BCPQC");
            kemCrypto = new JcaTlsCrypto(pqcHelper, crypto.getSecureRandom(),
                new JcaTlsCryptoProvider.NonceEntropySource(crypto.getHelper(), crypto.getSecureRandom()));
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException("cannot use passed in crypto with KEM", e);
        } catch (Exception e) {
            throw new IllegalStateException("BCPQC provider not present", e);
        }
    }

    /**
     * Decodes a raw ML-KEM public key.
     * <p>
     * If the key factory comes from the "BCPQC" provider, the bytes are first tried as an
     * {@link MLKEMPublicKeySpec}. Otherwise, or if that attempt fails, they are wrapped in a
     * SubjectPublicKeyInfo carrying the OID for {@code kemName} and decoded as X.509.
     *
     * @param crypto   crypto used to initialise the shared KEM crypto on first use
     * @param kemName  algorithm name, e.g. "ML-KEM-768"
     * @param encoding raw public key bytes
     * @return the decoded public key
     * @throws TlsFatalAlert with {@code illegal_parameter} if the key cannot be decoded, the name
     *     is unknown, or the shared KEM crypto cannot be created
     */
    static PublicKey decodePublicKey(JcaTlsCrypto crypto, String kemName, byte[] encoding) throws TlsFatalAlert {
        try {
            makeKemCrypto(crypto);
            KeyFactory kf = kemCrypto.getHelper().createKeyFactory(kemName);
            if ("BCPQC".equals(kf.getProvider().getName())) {
                try {
                    MLKEMParameterSpec params = MLKEMParameterSpec.fromName(kemName);
                    return kf.generatePublic(new MLKEMPublicKeySpec(params, encoding));
                } catch (Exception ignored) {
                    // Best effort only: fall through to the X.509 route, whose failure is reported.
                }
            }
            X509EncodedKeySpec keySpec = createX509EncodedKeySpec(getAlgorithmOID(kemName), encoding);
            return kf.generatePublic(keySpec);
        } catch (Exception e) {
            throw new TlsFatalAlert(AlertDescription.illegal_parameter, e);
        }
    }

    /**
     * Returns the raw bytes of an ML-KEM public key.
     * <p>
     * {@link MLKEMPublicKey} instances expose their raw data directly. Any other key must report
     * the "X.509" format; for those keys, the octets of the SubjectPublicKeyInfo public-key bit
     * string are returned.
     *
     * @param publicKey key to encode
     * @return the raw public key bytes
     * @throws TlsFatalAlert with {@code internal_error} if the key format is not X.509 or its
     *     encoding cannot be parsed
     */
    static byte[] encodePublicKey(PublicKey publicKey) throws TlsFatalAlert {
        if (publicKey instanceof MLKEMPublicKey) {
            return ((MLKEMPublicKey) publicKey).getPublicData();
        }
        if (!"X.509".equals(publicKey.getFormat())) {
            throw new TlsFatalAlert(AlertDescription.internal_error, "Public key format unrecognized");
        }
        try {
            SubjectPublicKeyInfo spki = SubjectPublicKeyInfo.getInstance(publicKey.getEncoded());
            return spki.getPublicKeyData().getOctets();
        } catch (Exception e) {
            throw new TlsFatalAlert(AlertDescription.internal_error, e);
        }
    }

    /**
     * Returns a {@link KeyFactory} for {@code kemName} from the shared KEM crypto.
     *
     * @return the key factory, or {@code null} if it cannot be created
     */
    static KeyFactory getKeyFactory(JcaTlsCrypto crypto, String kemName) {
        try {
            makeKemCrypto(crypto);
            return kemCrypto.getHelper().createKeyFactory(kemName);
        } catch (AssertionError ignored) {
            // Treated as "not supported"; isKemSupported checks for null.
        } catch (Exception ignored) {
            // Treated as "not supported"; isKemSupported checks for null.
        }
        return null;
    }

    /**
     * Returns a {@link KeyGenerator} for {@code kemName} from the shared KEM crypto.
     *
     * @return the key generator, or {@code null} if it cannot be created
     */
    static KeyGenerator getKeyGenerator(JcaTlsCrypto crypto, String kemName) {
        try {
            makeKemCrypto(crypto);
            return kemCrypto.getHelper().createKeyGenerator(kemName);
        } catch (AssertionError ignored) {
            // Treated as "not supported"; isKemSupported checks for null.
        } catch (Exception ignored) {
            // Treated as "not supported"; isKemSupported checks for null.
        }
        return null;
    }

    /**
     * Returns an "ML-KEM" {@link KeyPairGenerator} initialised with the parameter set named by
     * {@code kemName} and the shared KEM crypto's secure random.
     *
     * @return the initialised generator, or {@code null} if it cannot be created or initialised
     */
    static KeyPairGenerator getKeyPairGenerator(JcaTlsCrypto crypto, String kemName) {
        try {
            makeKemCrypto(crypto);
            KeyPairGenerator keyPairGenerator = kemCrypto.getHelper().createKeyPairGenerator("ML-KEM");
            keyPairGenerator.initialize(MLKEMParameterSpec.fromName(kemName), kemCrypto.getSecureRandom());
            return keyPairGenerator;
        } catch (AssertionError ignored) {
            // Treated as "not supported"; isKemSupported checks for null.
        } catch (Exception ignored) {
            // Treated as "not supported"; isKemSupported checks for null.
        }
        return null;
    }

    /**
     * Reports whether a key factory, a key generator and an initialised key pair generator can all
     * be created for {@code kemName}.
     *
     * @return {@code false} if {@code kemName} is {@code null} or any of the three cannot be created
     */
    static boolean isKemSupported(JcaTlsCrypto crypto, String kemName) {
        return kemName != null
            && getKeyFactory(crypto, kemName) != null
            && getKeyGenerator(crypto, kemName) != null
            && getKeyPairGenerator(crypto, kemName) != null;
    }

    /**
     * Wraps raw key bytes in a DER-encoded SubjectPublicKeyInfo with the given algorithm OID and
     * no algorithm parameters.
     *
     * @throws IOException if DER encoding fails
     */
    private static X509EncodedKeySpec createX509EncodedKeySpec(ASN1ObjectIdentifier oid, byte[] encoding) throws IOException {
        AlgorithmIdentifier algID = new AlgorithmIdentifier(oid);
        SubjectPublicKeyInfo spki = new SubjectPublicKeyInfo(algID, encoding);
        return new X509EncodedKeySpec(spki.getEncoded("DER"));
    }

    /**
     * Maps an ML-KEM parameter-set name (case-insensitive) to its OID.
     *
     * @throws IllegalArgumentException if the name is not ML-KEM-512, ML-KEM-768 or ML-KEM-1024
     */
    private static ASN1ObjectIdentifier getAlgorithmOID(String kemName) {
        if ("ML-KEM-512".equalsIgnoreCase(kemName)) {
            return id_alg_ml_kem_512;
        }
        if ("ML-KEM-768".equalsIgnoreCase(kemName)) {
            return id_alg_ml_kem_768;
        }
        if ("ML-KEM-1024".equalsIgnoreCase(kemName)) {
            return id_alg_ml_kem_1024;
        }
        throw new IllegalArgumentException("unknown kem name " + kemName);
    }
}

