/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.security;

import com.google.common.annotations.VisibleForTesting;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FilterInputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;
import java.util.regex.Pattern;
import javax.security.auth.kerberos.KerberosPrincipal;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.GlobPattern;
import org.apache.hadoop.ipc.ProtobufRpcEngine;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.ipc.RemoteException;
import org.apache.hadoop.ipc.RpcConstants;
import org.apache.hadoop.ipc.Server;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos;
import org.apache.hadoop.security.AccessControlException;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.SaslPropertiesResolver;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.authentication.util.KerberosName;
import org.apache.hadoop.security.rpcauth.RpcAuthMethod;
import org.apache.hadoop.security.rpcauth.RpcAuthRegistry;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.security.token.TokenInfo;
import org.apache.hadoop.security.token.TokenSelector;
import org.apache.hadoop.util.ProtoUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Client side of the SASL handshake performed on a Hadoop RPC connection.
 *
 * <p>{@link #saslConnect(InputStream, OutputStream)} sends a NEGOTIATE request, picks one of the
 * authentication types offered by the server, and exchanges SASL tokens until the server reports
 * SUCCESS (or the client settles on SIMPLE auth). When the negotiated QOP is anything other than
 * {@code auth}, {@link #getInputStream(InputStream)} and {@link #getOutputStream(OutputStream)}
 * wrap the connection streams so every payload travels inside a SASL WRAP message.
 *
 * <p>NOTE(review): negotiation state ({@code saslClient}, {@code authMethod}) is mutable and not
 * synchronized; presumably one instance is used per connection — confirm with callers.
 */
@InterfaceAudience.LimitedPrivate(value={"HDFS", "MapReduce"})
@InterfaceStability.Evolving
public class SaslRpcClient {
    public static final Logger LOG = LoggerFactory.getLogger(SaslRpcClient.class);
    private final UserGroupInformation ugi;
    private final Class<?> protocol;
    private final InetSocketAddress serverAddr;
    private final Configuration conf;
    /** SASL mechanism client chosen during negotiation; {@code null} for SIMPLE auth or after {@link #dispose()}. */
    private SaslClient saslClient;
    private SaslPropertiesResolver saslPropsResolver;
    /** Auth method in effect; reset to SIMPLE at the start of {@link #saslConnect}. */
    private RpcAuthMethod authMethod;
    /** RPC request header attached to every SASL message this client sends. */
    private static final RpcHeaderProtos.RpcRequestHeaderProto saslHeader = ProtoUtil.makeRpcRequestHeader(RPC.RpcKind.RPC_PROTOCOL_BUFFER, RpcHeaderProtos.RpcRequestHeaderProto.OperationProto.RPC_FINAL_PACKET, Server.AuthProtocol.SASL.callId, -1, RpcConstants.DUMMY_CLIENT_ID);
    /** First message of every handshake: asks the server which auth types it offers. */
    private static final RpcHeaderProtos.RpcSaslProto negotiateRequest = RpcHeaderProtos.RpcSaslProto.newBuilder().setState(RpcHeaderProtos.RpcSaslProto.SaslState.NEGOTIATE).build();

    /**
     * Creates a SASL client for one RPC connection.
     *
     * @param ugi        the user whose credentials (tokens / Kerberos login) are used
     * @param protocol   the RPC protocol interface; its token/Kerberos annotations drive selection
     * @param serverAddr the server address, used for token service lookup and principal checks
     * @param conf       configuration used for SASL properties and server principal lookup
     */
    public SaslRpcClient(UserGroupInformation ugi, Class<?> protocol, InetSocketAddress serverAddr, Configuration conf) {
        this.ugi = ugi;
        this.protocol = protocol;
        this.serverAddr = serverAddr;
        this.conf = conf;
        this.saslPropsResolver = SaslPropertiesResolver.getInstance(conf);
    }

    /**
     * Returns a negotiated SASL property, or {@code null} when no SASL client exists.
     *
     * @param key the SASL property name
     * @return the property value, or {@code null}
     */
    @InterfaceAudience.Private
    @VisibleForTesting
    public Object getNegotiatedProperty(String key) {
        return this.saslClient != null ? this.saslClient.getNegotiatedProperty(key) : null;
    }

    /** @return the auth method selected by the last {@link #saslConnect} call. */
    @InterfaceAudience.Private
    public RpcAuthMethod getAuthMethod() {
        return this.authMethod;
    }

    /**
     * Picks the first server-offered auth type this client can use, creating its SASL client.
     *
     * @param authTypes the auth types from the server's NEGOTIATE message, in server order
     * @return the selected auth type (never {@code null})
     * @throws AccessControlException if none of the offered auth types is usable
     * @throws IOException if creating a SASL client fails
     */
    private RpcHeaderProtos.RpcSaslProto.SaslAuth selectSaslClient(List<RpcHeaderProtos.RpcSaslProto.SaslAuth> authTypes) throws SaslException, AccessControlException, IOException {
        RpcHeaderProtos.RpcSaslProto.SaslAuth selectedAuthType = null;
        for (RpcHeaderProtos.RpcSaslProto.SaslAuth authType : authTypes) {
            if (!this.isValidAuthType(authType)) {
                continue;
            }
            RpcAuthMethod candidate = RpcAuthRegistry.getAuthMethod(authType.getMethod());
            if (!candidate.equals(RpcAuthRegistry.SIMPLE)) {
                this.saslClient = this.createSaslClient(authType);
                if (this.saslClient == null) {
                    // No usable credentials for this mechanism; try the server's next offer.
                    continue;
                }
            }
            selectedAuthType = authType;
            break;
        }
        // Judge failure by the loop outcome, not by saslClient: a client left over from an earlier
        // negotiation round must not let us return null here.
        if (selectedAuthType == null) {
            List<String> serverAuthMethods = new ArrayList<String>();
            for (RpcHeaderProtos.RpcSaslProto.SaslAuth authType : authTypes) {
                serverAuthMethods.add(authType.getMethod());
            }
            throw new AccessControlException("Client cannot authenticate via:" + serverAuthMethods);
        }
        LOG.debug("Use {} authentication for protocol {}", selectedAuthType.getMethod(), this.protocol.getSimpleName());
        return selectedAuthType;
    }

    /**
     * Checks that the auth method is registered and that its mechanism matches the server's.
     *
     * @param authType one auth type offered by the server
     * @return {@code true} if this client knows the method and the mechanism names agree
     */
    private boolean isValidAuthType(RpcHeaderProtos.RpcSaslProto.SaslAuth authType) {
        RpcAuthMethod method = RpcAuthRegistry.getAuthMethod(authType.getMethod());
        return method != null && method.getMechanismName().equals(authType.getMechanism());
    }

    /**
     * Builds a SASL client for the given auth type.
     *
     * @param authType the server-offered auth type
     * @return the SASL client, or {@code null} when this client lacks credentials for it
     *         (no matching token, not logged in via Kerberos, or no Kerberos info for the protocol)
     * @throws IOException if token/principal lookup or SASL client creation fails
     */
    private SaslClient createSaslClient(RpcHeaderProtos.RpcSaslProto.SaslAuth authType) throws SaslException, IOException {
        TreeMap<String, Object> saslProperties = new TreeMap<String, Object>(this.saslPropsResolver.getClientProperties(this.serverAddr.getAddress()));
        RpcAuthMethod method = RpcAuthRegistry.getAuthMethod(authType.getMethod());
        switch (method.getAuthenticationMethod()) {
            case TOKEN: {
                Token<?> token = this.getServerToken(authType);
                if (token == null) {
                    return null;
                }
                saslProperties.put("org.apache.hadoop.auth.token", token);
                break;
            }
            case KERBEROS: {
                if (this.ugi.getRealAuthenticationMethod() != UserGroupInformation.AuthenticationMethod.KERBEROS) {
                    return null;
                }
                String serverPrincipal = this.getServerPrincipal(authType);
                if (serverPrincipal == null) {
                    return null;
                }
                LOG.debug("RPC Server's Kerberos principal name for protocol={} is {}", this.protocol.getCanonicalName(), serverPrincipal);
                saslProperties.put("org.apache.hadoop.auth.kerberos.principal", serverPrincipal);
                break;
            }
            default: {
                // Other auth methods need no extra SASL properties.
                break;
            }
        }
        String mechanism = method.getMechanismName();
        LOG.debug("Creating SASL {}({})  client to authenticate to service at {}", mechanism, method, authType.getServerId());
        return method.createSaslClient(saslProperties);
    }

    /**
     * Selects the user's token for this server using the protocol's {@link TokenInfo} selector.
     *
     * @param authType the TOKEN auth type offered by the server (currently unused)
     * @return the matching token, or {@code null} if the protocol has no token info or no token matches
     * @throws IOException if the token selector cannot be instantiated (cause preserved)
     */
    private Token<?> getServerToken(RpcHeaderProtos.RpcSaslProto.SaslAuth authType) throws IOException {
        TokenInfo tokenInfo = SecurityUtil.getTokenInfo(this.protocol, this.conf);
        LOG.debug("Get token info proto:{} info:{}", this.protocol, tokenInfo);
        if (tokenInfo == null) {
            return null;
        }
        TokenSelector<? extends TokenIdentifier> tokenSelector;
        try {
            // Class.newInstance() is deprecated; the constructor-based form is its documented replacement.
            tokenSelector = tokenInfo.value().getDeclaredConstructor().newInstance();
        }
        catch (ReflectiveOperationException e) {
            throw new IOException(e.toString(), e);
        }
        return tokenSelector.selectToken(SecurityUtil.buildTokenService(this.serverAddr), this.ugi.getTokens());
    }

    /**
     * Derives the server's Kerberos principal from the negotiation and validates it against config.
     *
     * <p>If {@code <serverKey>.pattern} is set, the principal must match that glob. Otherwise it must
     * equal the configured {@code <serverKey>} principal (after host substitution), which must have a
     * hostname part.
     *
     * @param authType the KERBEROS auth type offered by the server (supplies protocol and server id)
     * @return the validated principal, or {@code null} if the protocol has no {@link KerberosInfo}
     * @throws IllegalArgumentException if configuration is missing/malformed or the principal is invalid
     * @throws IOException if principal substitution fails
     */
    @VisibleForTesting
    String getServerPrincipal(RpcHeaderProtos.RpcSaslProto.SaslAuth authType) throws IOException {
        KerberosInfo krbInfo = SecurityUtil.getKerberosInfo(this.protocol, this.conf);
        LOG.debug("Get kerberos info proto:{} info:{}", this.protocol, krbInfo);
        if (krbInfo == null) {
            return null;
        }
        String serverKey = krbInfo.serverPrincipal();
        if (serverKey == null) {
            throw new IllegalArgumentException("Can't obtain server Kerberos config key from protocol=" + this.protocol.getCanonicalName());
        }
        String serverPrincipal = new KerberosPrincipal(authType.getProtocol() + "/" + authType.getServerId(), KerberosPrincipal.KRB_NT_SRV_HST).getName();
        boolean isPrincipalValid;
        String serverKeyPattern = this.conf.get(serverKey + ".pattern");
        if (serverKeyPattern != null && !serverKeyPattern.isEmpty()) {
            Pattern pattern = GlobPattern.compile(serverKeyPattern);
            isPrincipalValid = pattern.matcher(serverPrincipal).matches();
        } else {
            String confValue = this.conf.get(serverKey);
            String confPrincipal = SecurityUtil.getServerPrincipal(confValue, this.serverAddr.getAddress());
            LOG.debug("getting serverKey: {} conf value: {} principal: {}", serverKey, confValue, confPrincipal);
            if (confPrincipal == null || confPrincipal.isEmpty()) {
                throw new IllegalArgumentException("Failed to specify server's Kerberos principal name");
            }
            KerberosName name = new KerberosName(confPrincipal);
            if (name.getHostName() == null) {
                throw new IllegalArgumentException("Kerberos principal name does NOT have the expected hostname part: " + confPrincipal);
            }
            isPrincipalValid = serverPrincipal.equals(confPrincipal);
        }
        if (!isPrincipalValid) {
            throw new IllegalArgumentException("Server has invalid Kerberos principal: " + serverPrincipal);
        }
        return serverPrincipal;
    }

    /**
     * Runs the SASL handshake over the given connection streams.
     *
     * @param inS  stream carrying server responses
     * @param outS stream to send SASL messages on
     * @return the auth method the connection ended up using
     * @throws RemoteException if the server replies with an ERROR/FATAL status
     * @throws SaslException on protocol violations or SASL mechanism failures
     * @throws IOException on I/O failure or if no offered auth type is usable
     */
    public RpcAuthMethod saslConnect(InputStream inS, OutputStream outS) throws IOException {
        DataInputStream inStream = new DataInputStream(new BufferedInputStream(inS));
        DataOutputStream outStream = new DataOutputStream(new BufferedOutputStream(outS));
        this.authMethod = RpcAuthRegistry.SIMPLE;
        this.sendSaslMessage(outStream, negotiateRequest);
        boolean done = false;
        do {
            int totalLen = inStream.readInt();
            ProtobufRpcEngine.RpcResponseMessageWrapper responseWrapper = new ProtobufRpcEngine.RpcResponseMessageWrapper();
            responseWrapper.readFields(inStream);
            RpcHeaderProtos.RpcResponseHeaderProto header = (RpcHeaderProtos.RpcResponseHeaderProto)responseWrapper.getMessageHeader();
            switch (header.getStatus()) {
                case ERROR:
                case FATAL: {
                    throw new RemoteException(header.getExceptionClassName(), header.getErrorMsg());
                }
                default: {
                    break;
                }
            }
            if (totalLen != responseWrapper.getLength()) {
                throw new SaslException("Received malformed response length");
            }
            if (header.getCallId() != Server.AuthProtocol.SASL.callId) {
                throw new SaslException("Non-SASL response during negotiation");
            }
            RpcHeaderProtos.RpcSaslProto saslMessage = RpcHeaderProtos.RpcSaslProto.parseFrom(responseWrapper.getMessageBytes());
            LOG.debug("Received SASL message {}", saslMessage);
            RpcHeaderProtos.RpcSaslProto.Builder response = null;
            switch (saslMessage.getState()) {
                case NEGOTIATE: {
                    RpcHeaderProtos.RpcSaslProto.SaslAuth saslAuthType = this.selectSaslClient(saslMessage.getAuthsList());
                    this.authMethod = RpcAuthRegistry.getAuthMethod(saslAuthType.getMethod());
                    byte[] responseToken = null;
                    if (this.authMethod.equals(RpcAuthRegistry.SIMPLE)) {
                        // Server allows SIMPLE: send INITIATE without a token and stop negotiating.
                        done = true;
                    } else {
                        byte[] challengeToken = null;
                        if (saslAuthType.hasChallenge()) {
                            // The server piggybacked its first challenge; consume it and don't echo it back.
                            challengeToken = saslAuthType.getChallenge().toByteArray();
                            saslAuthType = RpcHeaderProtos.RpcSaslProto.SaslAuth.newBuilder(saslAuthType).clearChallenge().build();
                        } else if (this.saslClient.hasInitialResponse()) {
                            challengeToken = new byte[]{};
                        }
                        responseToken = challengeToken != null ? this.saslClient.evaluateChallenge(challengeToken) : new byte[]{};
                    }
                    response = this.createSaslReply(RpcHeaderProtos.RpcSaslProto.SaslState.INITIATE, responseToken);
                    response.addAuths(saslAuthType);
                    break;
                }
                case CHALLENGE: {
                    if (this.saslClient == null) {
                        throw new SaslException("Server sent unsolicited challenge");
                    }
                    byte[] responseToken = this.saslEvaluateToken(saslMessage, false);
                    response = this.createSaslReply(RpcHeaderProtos.RpcSaslProto.SaslState.RESPONSE, responseToken);
                    break;
                }
                case SUCCESS: {
                    if (this.saslClient == null) {
                        // Server accepted us without any SASL exchange.
                        this.authMethod = RpcAuthRegistry.SIMPLE;
                    } else {
                        this.saslEvaluateToken(saslMessage, true);
                    }
                    done = true;
                    break;
                }
                default: {
                    throw new SaslException("RPC client doesn't support SASL " + saslMessage.getState());
                }
            }
            if (response != null) {
                this.sendSaslMessage(outStream, response.build());
            }
        } while (!done);
        return this.authMethod;
    }

    /**
     * Frames one SASL message with the SASL RPC header and flushes it.
     *
     * @param out     destination stream
     * @param message the SASL message to send
     * @throws IOException on write failure
     */
    private void sendSaslMessage(DataOutputStream out, RpcHeaderProtos.RpcSaslProto message) throws IOException {
        LOG.debug("Sending sasl message {}", message);
        ProtobufRpcEngine.RpcRequestMessageWrapper request = new ProtobufRpcEngine.RpcRequestMessageWrapper(saslHeader, message);
        out.writeInt(request.getLength());
        request.write(out);
        out.flush();
    }

    /**
     * Feeds a server token to the SASL client.
     *
     * @param saslResponse the server's CHALLENGE or SUCCESS message
     * @param serverIsDone {@code true} for SUCCESS: the client must then be complete and produce no reply
     * @return the client's response token, or {@code null} if the server sent no token
     * @throws SaslException if a challenge lacks a token or the client is out of sync with the server
     */
    private byte[] saslEvaluateToken(RpcHeaderProtos.RpcSaslProto saslResponse, boolean serverIsDone) throws SaslException {
        byte[] saslToken = null;
        if (saslResponse.hasToken()) {
            saslToken = saslResponse.getToken().toByteArray();
            saslToken = this.saslClient.evaluateChallenge(saslToken);
        } else if (!serverIsDone) {
            throw new SaslException("Server challenge contains no token");
        }
        if (serverIsDone) {
            if (!this.saslClient.isComplete()) {
                throw new SaslException("Client is out of sync with server");
            }
            if (saslToken != null) {
                throw new SaslException("Client generated spurious response");
            }
        }
        return saslToken;
    }

    /**
     * Starts a reply in the given state, attaching the token when present.
     *
     * @param state         SASL state of the reply
     * @param responseToken token to attach, or {@code null} for none
     * @return a builder the caller may extend before building
     */
    private RpcHeaderProtos.RpcSaslProto.Builder createSaslReply(RpcHeaderProtos.RpcSaslProto.SaslState state, byte[] responseToken) {
        RpcHeaderProtos.RpcSaslProto.Builder response = RpcHeaderProtos.RpcSaslProto.newBuilder();
        response.setState(state);
        if (responseToken != null) {
            response.setToken(ByteString.copyFrom(responseToken));
        }
        return response;
    }

    /**
     * Tells whether traffic must be SASL-wrapped: true when the negotiated QOP is set and is not
     * {@code auth}. Returns false when no SASL client was negotiated (e.g. SIMPLE auth).
     */
    private boolean useWrap() {
        if (this.saslClient == null) {
            return false;
        }
        String qop = (String)this.saslClient.getNegotiatedProperty("javax.security.sasl.qop");
        LOG.debug("QOP supported by {}: {}", this.saslClient, qop);
        return qop != null && !"auth".equalsIgnoreCase(qop);
    }

    /**
     * Wraps the connection input stream so WRAP messages are unwrapped, if the QOP requires it.
     *
     * @param in the raw connection input stream
     * @return an unwrapping stream, or {@code in} unchanged when no wrapping is negotiated
     */
    public InputStream getInputStream(InputStream in) throws IOException {
        if (this.useWrap()) {
            in = new WrappedInputStream(in);
        }
        return in;
    }

    /**
     * Wraps the connection output stream so writes are SASL-wrapped, if the QOP requires it.
     * Writes are buffered in chunks of the negotiated {@code javax.security.sasl.rawsendsize}.
     *
     * @param out the raw connection output stream
     * @return a wrapping stream, or {@code out} unchanged when no wrapping is negotiated
     */
    public OutputStream getOutputStream(OutputStream out) throws IOException {
        if (this.useWrap()) {
            // NOTE(review): parseInt throws if the mechanism did not negotiate rawsendsize.
            String maxBuf = (String)this.saslClient.getNegotiatedProperty("javax.security.sasl.rawsendsize");
            out = new BufferedOutputStream(new WrappedOutputStream(out), Integer.parseInt(maxBuf));
        }
        return out;
    }

    /** Disposes the SASL client, if any, and forgets it. Safe to call more than once. */
    public void dispose() throws SaslException {
        if (this.saslClient != null) {
            this.saslClient.dispose();
            this.saslClient = null;
        }
    }

    /** Output stream that sends each write as a SASL WRAP message framed with the SASL RPC header. */
    class WrappedOutputStream
    extends FilterOutputStream {
        public WrappedOutputStream(OutputStream out) throws IOException {
            super(out);
        }

        @Override
        public void write(int b) throws IOException {
            // The inherited FilterOutputStream.write(int) would emit a raw, unwrapped byte into the
            // framed stream; route single bytes through the wrapping path instead.
            this.write(new byte[]{(byte)b}, 0, 1);
        }

        @Override
        public void write(byte[] buf, int off, int len) throws IOException {
            LOG.debug("wrapping token of length:{}", len);
            byte[] wrapped = SaslRpcClient.this.saslClient.wrap(buf, off, len);
            RpcHeaderProtos.RpcSaslProto saslMessage = RpcHeaderProtos.RpcSaslProto.newBuilder().setState(RpcHeaderProtos.RpcSaslProto.SaslState.WRAP).setToken(ByteString.copyFrom(wrapped, 0, wrapped.length)).build();
            ProtobufRpcEngine.RpcRequestMessageWrapper request = new ProtobufRpcEngine.RpcRequestMessageWrapper(saslHeader, saslMessage);
            DataOutputStream dob = new DataOutputStream(this.out);
            dob.writeInt(request.getLength());
            request.write(dob);
        }
    }

    /** Input stream that reads SASL WRAP packets from the server and serves their unwrapped bytes. */
    class WrappedInputStream
    extends FilterInputStream {
        /** Unwrapped bytes of the current packet not yet handed to the caller. */
        private ByteBuffer unwrappedRpcBuffer;

        public WrappedInputStream(InputStream in) throws IOException {
            super(in);
            this.unwrappedRpcBuffer = ByteBuffer.allocate(0);
        }

        @Override
        public int read() throws IOException {
            byte[] b = new byte[1];
            int n = this.read(b, 0, 1);
            // Mask to 0-255: returning the signed byte would make a 0xFF data byte look like EOF (-1).
            return n == -1 ? -1 : b[0] & 0xFF;
        }

        @Override
        public int read(byte[] b) throws IOException {
            return this.read(b, 0, b.length);
        }

        @Override
        public synchronized int read(byte[] buf, int off, int len) throws IOException {
            if (len == 0) {
                // InputStream contract: a zero-length read returns 0 without blocking.
                return 0;
            }
            // A packet may unwrap to zero bytes; keep reading so we never report a 0-byte read.
            while (!this.unwrappedRpcBuffer.hasRemaining()) {
                this.readNextRpcPacket();
            }
            int readLen = Math.min(len, this.unwrappedRpcBuffer.remaining());
            this.unwrappedRpcBuffer.get(buf, off, readLen);
            return readLen;
        }

        /**
         * Reads one length-prefixed RPC packet and replaces the buffer with its unwrapped payload.
         *
         * @throws SaslException if the length is negative or the packet is not a SASL WRAP message
         * @throws IOException on read failure (including EOF)
         */
        private void readNextRpcPacket() throws IOException {
            LOG.debug("reading next wrapped RPC packet");
            DataInputStream dis = new DataInputStream(this.in);
            int rpcLen = dis.readInt();
            if (rpcLen < 0) {
                // Length comes from the peer; reject it instead of failing with NegativeArraySizeException.
                throw new SaslException("Received invalid wrapped RPC packet length: " + rpcLen);
            }
            byte[] rpcBuf = new byte[rpcLen];
            dis.readFully(rpcBuf);
            ByteArrayInputStream bis = new ByteArrayInputStream(rpcBuf);
            RpcHeaderProtos.RpcResponseHeaderProto.Builder headerBuilder = RpcHeaderProtos.RpcResponseHeaderProto.newBuilder();
            headerBuilder.mergeDelimitedFrom(bis);
            boolean isWrapped = false;
            if (headerBuilder.getCallId() == Server.AuthProtocol.SASL.callId) {
                RpcHeaderProtos.RpcSaslProto.Builder saslMessage = RpcHeaderProtos.RpcSaslProto.newBuilder();
                saslMessage.mergeDelimitedFrom(bis);
                if (saslMessage.getState() == RpcHeaderProtos.RpcSaslProto.SaslState.WRAP) {
                    isWrapped = true;
                    byte[] token = saslMessage.getToken().toByteArray();
                    LOG.debug("unwrapping token of length:{}", token.length);
                    token = SaslRpcClient.this.saslClient.unwrap(token, 0, token.length);
                    this.unwrappedRpcBuffer = ByteBuffer.wrap(token);
                }
            }
            if (!isWrapped) {
                throw new SaslException("Server sent non-wrapped response");
            }
        }
    }
}

