package org.apache.hadoop.hbase.security;

import java.io.IOException;
import java.nio.charset.Charset;
import java.security.PrivilegedExceptionAction;
import java.util.Random;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import org.apache.directory.api.ldap.model.constants.JndiPropertyConstants;
import org.apache.hadoop.hbase.classification.InterfaceAudience;
import org.apache.hadoop.hbase.security.HBaseSaslRpcClient;
import org.apache.hadoop.ipc.RemoteException;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hive.io.netty.buffer.ByteBuf;
import org.apache.hive.io.netty.channel.Channel;
import org.apache.hive.io.netty.channel.ChannelDuplexHandler;
import org.apache.hive.io.netty.channel.ChannelFuture;
import org.apache.hive.io.netty.channel.ChannelFutureListener;
import org.apache.hive.io.netty.channel.ChannelHandlerContext;
import org.apache.hive.io.netty.channel.ChannelPromise;
import org.apache.hive.io.netty.util.concurrent.Future;
import org.apache.hive.io.netty.util.concurrent.GenericFutureListener;
import org.apache.hive.jdbc.Utils;
import org.apache.hive.org.apache.commons.logging.Log;
import org.apache.hive.org.apache.commons.logging.LogFactory;

/**
 * Netty handler that runs the client side of a SASL handshake and, if the negotiated quality of
 * protection requires it, wraps outbound and unwraps inbound traffic afterwards.
 *
 * <p>During negotiation each inbound message is parsed as {@code int status, int length,
 * byte[length] token}, and each outbound token is written as {@code int length, byte[length]
 * token}. When the exchange completes with no QoP or QoP {@code auth}, the handler removes itself
 * from the pipeline; otherwise it stays and frames every payload as {@code int length,
 * byte[length] wrapped-bytes} in both directions.
 *
 * <p>NOTE(review): the handler keeps unsynchronized per-connection state and presumably relies on
 * Netty invoking it from a single event loop — confirm against the channel setup.
 */
@InterfaceAudience.Private
public class SaslClientHandler extends ChannelDuplexHandler {
    public static final Log LOG = LogFactory.getLog(SaslClientHandler.class);

    /** Length value the server sends in its first reply to ask for a fallback to SIMPLE auth. */
    private static final int SWITCH_TO_SIMPLE_AUTH = -88;

    /** Negotiated {@link Sasl#QOP} value meaning authentication only, i.e. no wrapping needed. */
    private static final String QOP_AUTH = "auth";

    private static final Charset UTF_8 = Charset.forName("UTF-8");

    private final boolean fallbackAllowed;
    private final UserGroupInformation ticket;
    private final SaslClient saslClient;
    private final SaslExceptionHandler exceptionHandler;
    private final SaslSuccessfulConnectHandler successfulConnectHandler;
    private boolean firstRead = true;
    private int retryCount = 0;
    private Random random;
    /**
     * Trailing bytes of an incomplete wrapped frame, kept until the rest arrives; {@code null}
     * when nothing is pending. Only used after negotiation has completed.
     */
    private ByteBuf pendingWrapped;

    /** Receives failures of the SASL exchange or of wrapped writes. */
    public interface SaslExceptionHandler {
        /**
         * Handles a failure reported by a {@link SaslClientHandler}.
         *
         * @param retryCount number of failures this handler instance reported before this one
         * @param random random source owned by the reporting handler (presumably for retry
         *     jitter — confirm with implementations)
         * @param cause the failure
         */
        void handle(int retryCount, Random random, Throwable cause);
    }

    /** Notified once the channel is ready for RPC traffic. */
    public interface SaslSuccessfulConnectHandler {
        /**
         * Called after SASL negotiation completes or after the server-requested fallback to
         * SIMPLE auth.
         *
         * @param channel the channel that became ready
         */
        void onSuccess(Channel channel);
    }

    /**
     * Creates a handler together with its underlying {@link SaslClient}.
     *
     * @param ticket user as which SASL challenges are evaluated (via {@code doAs})
     * @param method authentication mechanism; DIGEST, MAPRSASL and KERBEROS are supported
     * @param token token handed to the DIGEST callback handler (only read for DIGEST)
     * @param serverPrincipal server's Kerberos principal; required for KERBEROS
     * @param fallbackAllowed whether a server request to fall back to SIMPLE auth is honored
     * @param rpcProtection passed to {@code SaslUtil.initSaslProperties}
     * @param exceptionHandler receives negotiation and write failures
     * @param successfulConnectHandler notified when the channel is ready
     * @throws IOException if the method is unsupported, the Kerberos principal is missing or
     *     malformed, or no SASL client implementation could be created
     */
    public SaslClientHandler(UserGroupInformation ticket, AuthMethod method,
            Token<? extends TokenIdentifier> token, String serverPrincipal,
            boolean fallbackAllowed, String rpcProtection,
            SaslExceptionHandler exceptionHandler,
            SaslSuccessfulConnectHandler successfulConnectHandler) throws IOException {
        this.ticket = ticket;
        this.fallbackAllowed = fallbackAllowed;
        this.exceptionHandler = exceptionHandler;
        this.successfulConnectHandler = successfulConnectHandler;
        SaslUtil.initSaslProperties(rpcProtection);
        switch (method) {
            case DIGEST:
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Creating SASL " + AuthMethod.DIGEST.getMechanismName()
                            + " client to authenticate to service at " + token.getService());
                }
                this.saslClient = createDigestSaslClient(
                        new String[]{AuthMethod.DIGEST.getMechanismName()}, "default",
                        new HBaseSaslRpcClient.SaslClientCallbackHandler(token));
                break;
            case MAPRSASL:
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Creating SASL " + AuthMethod.MAPRSASL.getMechanismName()
                            + " client to authenticate to service.");
                }
                this.saslClient = createMaprSaslClient(
                        new String[]{AuthMethod.MAPRSASL.getMechanismName()});
                break;
            case KERBEROS: {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Creating SASL " + AuthMethod.KERBEROS.getMechanismName()
                            + " client. Server's Kerberos principal name is " + serverPrincipal);
                }
                if (serverPrincipal == null || serverPrincipal.isEmpty()) {
                    throw new IOException("Failed to specify server's Kerberos principal name");
                }
                String[] names = SaslUtil.splitKerberosName(serverPrincipal);
                if (names.length != 3) {
                    throw new IOException(
                            "Kerberos principal does not have the expected format: "
                                    + serverPrincipal);
                }
                this.saslClient = createKerberosSaslClient(
                        new String[]{AuthMethod.KERBEROS.getMechanismName()}, names[0], names[1]);
                break;
            }
            default:
                throw new IOException("Unknown authentication method " + method);
        }
        if (this.saslClient == null) {
            throw new IOException("Unable to find SASL client implementation");
        }
    }

    /**
     * Creates the SASL client for DIGEST authentication.
     *
     * @param mechanismNames SASL mechanisms to offer
     * @param serverName server name handed to {@link Sasl#createSaslClient}
     * @param callbackHandler supplies the credentials
     * @return the client, or {@code null} if no factory supports the mechanisms
     */
    protected SaslClient createDigestSaslClient(String[] mechanismNames, String serverName,
            CallbackHandler callbackHandler) throws IOException {
        return Sasl.createSaslClient(mechanismNames, (String) null, (String) null, serverName,
                SaslUtil.SASL_PROPS, callbackHandler);
    }

    /**
     * Creates the SASL client for Kerberos authentication.
     *
     * @param mechanismNames SASL mechanisms to offer
     * @param protocol protocol component of the server principal
     * @param serverName host component of the server principal
     * @return the client, or {@code null} if no factory supports the mechanisms
     */
    protected SaslClient createKerberosSaslClient(String[] mechanismNames, String protocol,
            String serverName) throws IOException {
        return Sasl.createSaslClient(mechanismNames, (String) null, protocol, serverName,
                SaslUtil.SASL_PROPS, (CallbackHandler) null);
    }

    /**
     * Creates the SASL client for MapR authentication. Its callback handler rejects every
     * callback, so the mechanism must not need interactive credentials.
     *
     * @param mechanismNames SASL mechanisms to offer
     * @return the client, or {@code null} if no factory supports the mechanisms
     */
    protected SaslClient createMaprSaslClient(String[] mechanismNames) throws IOException {
        return Sasl.createSaslClient(mechanismNames, (String) null, (String) null, (String) null,
                SaslUtil.SASL_PROPS, callbacks -> {
                    throw new UnsupportedCallbackException(callbacks[0]);
                });
    }

    /** Disposes the SASL client, drops any pending wrapped bytes and propagates the event. */
    @Override
    public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
        try {
            this.saslClient.dispose();
        } finally {
            releasePendingWrapped();
            // Propagate so handlers further down the pipeline still observe unregistration.
            ctx.fireChannelUnregistered();
        }
    }

    /** Evaluates a SASL challenge as {@link #ticket}; may return {@code null} (nothing to send). */
    private byte[] evaluateChallenge(final byte[] challenge) throws Exception {
        return this.ticket.doAs(
                (PrivilegedExceptionAction<byte[]>) () -> this.saslClient.evaluateChallenge(challenge));
    }

    /**
     * Starts negotiation: sends the client's initial response if the mechanism has one,
     * otherwise an empty token.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        byte[] initialToken = new byte[0];
        if (this.saslClient.hasInitialResponse()) {
            initialToken = evaluateChallenge(initialToken);
        }
        if (initialToken != null) {
            writeSaslToken(ctx, initialToken);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Have sent token of size " + initialToken.length
                        + " from initSASLContext.");
            }
        }
    }

    /**
     * Dispatches inbound bytes to negotiation or, once negotiation is complete, to unwrapping.
     * This handler always consumes (releases) the inbound buffer.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        ByteBuf in = (ByteBuf) msg;
        try {
            if (this.saslClient.isComplete()) {
                readWrapped(ctx, in);
            } else {
                negotiate(ctx, in);
            }
        } finally {
            in.release();
        }
    }

    /**
     * Consumes negotiation frames until the SASL exchange completes or the buffer is drained,
     * handling the server's request to fall back to SIMPLE auth on the first frame.
     */
    private void negotiate(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
        while (!this.saslClient.isComplete() && in.isReadable()) {
            readStatus(in);
            int length = in.readInt();
            if (this.firstRead) {
                this.firstRead = false;
                if (length == SWITCH_TO_SIMPLE_AUTH) {
                    if (!this.fallbackAllowed) {
                        throw new IOException("Server asks us to fall back to SIMPLE auth, but this client is configured to only allow secure connections.");
                    }
                    if (LOG.isDebugEnabled()) {
                        LOG.debug("Server asks us to fall back to simple auth.");
                    }
                    this.saslClient.dispose();
                    ctx.pipeline().remove(this);
                    this.successfulConnectHandler.onSuccess(ctx.channel());
                    return;
                }
            }
            byte[] challenge = new byte[length];
            if (LOG.isDebugEnabled()) {
                LOG.debug("Will read input token of size " + challenge.length
                        + " for processing by initSASLContext");
            }
            in.readBytes(challenge);
            byte[] response = evaluateChallenge(challenge);
            if (response != null) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Will send token of size " + response.length
                            + " from initSASLContext.");
                }
                writeSaslToken(ctx, response);
            }
        }
        if (this.saslClient.isComplete()) {
            String qop = (String) this.saslClient.getNegotiatedProperty(Sasl.QOP);
            if (LOG.isDebugEnabled()) {
                LOG.debug("SASL client context established. Negotiated QoP: " + qop);
            }
            boolean useWrap = qop != null && !QOP_AUTH.equalsIgnoreCase(qop);
            if (!useWrap) {
                // No integrity/confidentiality layer: later traffic passes through unchanged.
                ctx.pipeline().remove(this);
            }
            this.successfulConnectHandler.onSuccess(ctx.channel());
        }
    }

    /**
     * Unwraps every complete {@code int length, byte[length]} frame available and forwards the
     * plaintext downstream. A trailing partial frame is kept in {@link #pendingWrapped} and
     * completed by later reads instead of being dropped.
     */
    private void readWrapped(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
        ByteBuf data = in;
        if (this.pendingWrapped != null) {
            ByteBuf merged = ctx.alloc().buffer(this.pendingWrapped.readableBytes() + in.readableBytes());
            merged.writeBytes(this.pendingWrapped);
            merged.writeBytes(in);
            releasePendingWrapped();
            data = merged;
        }
        try {
            while (data.readableBytes() >= 4) {
                int length = data.getInt(data.readerIndex());
                if (length < 0) {
                    throw new IOException("Invalid wrapped frame length " + length);
                }
                if (data.readableBytes() - 4 < length) {
                    break; // frame not fully received yet
                }
                data.skipBytes(4);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Actual length is " + length);
                }
                byte[] wrapped = new byte[length];
                data.readBytes(wrapped);
                fireUnwrapped(ctx, wrapped);
            }
            if (data.isReadable()) {
                // Copy so the caller can release its buffer unconditionally.
                this.pendingWrapped = ctx.alloc().buffer(data.readableBytes());
                this.pendingWrapped.writeBytes(data);
            }
        } finally {
            if (data != in) {
                data.release();
            }
        }
    }

    /** Unwraps one frame and fires the plaintext; disposes the SASL client if unwrapping fails. */
    private void fireUnwrapped(ChannelHandlerContext ctx, byte[] wrapped) throws SaslException {
        byte[] plain;
        try {
            plain = this.saslClient.unwrap(wrapped, 0, wrapped.length);
        } catch (SaslException e) {
            disposeQuietly();
            throw e;
        }
        // Allocate only after a successful unwrap so a failure cannot leak the buffer.
        ByteBuf out = ctx.channel().alloc().buffer(plain.length);
        out.writeBytes(plain);
        ctx.fireChannelRead((Object) out);
    }

    /** Releases and clears {@link #pendingWrapped}, if any. */
    private void releasePendingWrapped() {
        if (this.pendingWrapped != null) {
            this.pendingWrapped.release();
            this.pendingWrapped = null;
        }
    }

    /** Disposes the SASL client, logging rather than propagating a {@link SaslException}. */
    private void disposeQuietly() {
        try {
            this.saslClient.dispose();
        } catch (SaslException e) {
            LOG.debug("Ignoring SASL exception", e);
        }
    }

    /**
     * Writes and flushes {@code token} as {@code int length, byte[length]}; a failed write is
     * routed to {@link #exceptionCaught}.
     */
    private void writeSaslToken(final ChannelHandlerContext ctx, byte[] token) {
        ByteBuf out = ctx.alloc().buffer(4 + token.length);
        out.writeInt(token.length);
        out.writeBytes(token, 0, token.length);
        ctx.writeAndFlush(out).addListener2((GenericFutureListener<? extends Future<? super Void>>) new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                if (!future.isSuccess()) {
                    SaslClientHandler.this.exceptionCaught(ctx, future.cause());
                }
            }
        });
    }

    /**
     * Reads the status int of a negotiation frame.
     *
     * @throws RemoteException if the status is not {@code SaslStatus.SUCCESS}; the remaining
     *     buffer content, decoded as UTF-8, is used as both its class name and message
     */
    private static void readStatus(ByteBuf in) throws RemoteException {
        if (in.readInt() != SaslStatus.SUCCESS.state) {
            String remaining = in.toString(UTF_8);
            throw new RemoteException(remaining, remaining);
        }
    }

    /**
     * Disposes the SASL client, closes the channel and reports the failure, together with a
     * running retry count, to the {@link SaslExceptionHandler}. A failing {@code dispose()} no
     * longer prevents the close and the report.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        disposeQuietly();
        ctx.close();
        if (this.random == null) {
            this.random = new Random();
        }
        this.exceptionHandler.handle(this.retryCount++, this.random, cause);
    }

    /**
     * Passes writes through during negotiation; afterwards wraps each {@link ByteBuf} and sends
     * it as {@code int length, byte[length]}. The original message is always released, and
     * {@code promise} is completed with the outcome of the wrapped write.
     */
    @Override
    public void write(final ChannelHandlerContext ctx, Object msg, final ChannelPromise promise)
            throws Exception {
        if (!this.saslClient.isComplete()) {
            super.write(ctx, msg, promise);
            return;
        }
        ByteBuf plain = (ByteBuf) msg;
        byte[] wrapped;
        try {
            // getBytes honours the buffer's offset and also works for direct buffers,
            // unlike array() + readerIndex().
            byte[] bytes = new byte[plain.readableBytes()];
            plain.getBytes(plain.readerIndex(), bytes);
            wrapped = this.saslClient.wrap(bytes, 0, bytes.length);
        } catch (SaslException e) {
            disposeQuietly();
            promise.setFailure((Throwable) e);
            return; // nothing to send; never fall through to writing stale data
        } finally {
            plain.release();
        }
        ByteBuf out = ctx.channel().alloc().buffer(4 + wrapped.length);
        out.writeInt(wrapped.length);
        out.writeBytes(wrapped, 0, wrapped.length);
        ctx.write(out).addListener2((GenericFutureListener<? extends Future<? super Void>>) new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                if (future.isSuccess()) {
                    promise.trySuccess();
                    return;
                }
                // A void promise would re-fire the failure through the pipeline; skip it there.
                if (!promise.isVoid()) {
                    promise.tryFailure(future.cause());
                }
                SaslClientHandler.this.exceptionCaught(ctx, future.cause());
            }
        });
    }
}
