package com.mapr.security.maprauth;

import org.apache.hadoop.security.authentication.client.AuthenticationException;
import org.apache.hadoop.security.authentication.client.KerberosAuthenticator;
import org.apache.hadoop.security.authentication.server.AuthenticationToken;
import org.apache.hadoop.security.authentication.server.MultiMechsAuthenticationHandler;


import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.commons.codec.binary.Base64;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Properties;

import com.mapr.fs.proto.Security.AuthenticationReqFull;
import com.mapr.fs.proto.Security.AuthenticationResp;
import com.mapr.fs.proto.Security.CredentialsMsg;
import com.mapr.fs.proto.Security.Key;
import com.mapr.fs.proto.Security.Ticket;
import com.mapr.security.ClusterServerTicketGeneration;
import com.mapr.security.MutableInt;
import com.mapr.security.Security;

import java.io.IOException;

public class MaprAuthenticationHandler extends MultiMechsAuthenticationHandler {
    private static Logger LOG = LoggerFactory.
            getLogger(MaprAuthenticationHandler.class);

    /**
     * Authentication type will be embedded in the authentication token
     */
  public static final String TYPE = "maprauth";

/*
  @Override
    public String getType() {
        return TYPE;
    }
*/
    /**
     * This function is invoked when the filter is coming up.
     * we try to get the mapr serverkey which will be used later
     * to decrypt information sent by the client
     *
     * Also since we may be required to authenticate using Kerberos
     * we invoke the kerberos init code after checking if the
     * principal and keytab specified in the config file exist. If they
     * don't exist we don't invoke the kerberos init code because
     * we don't expect to use kerberos.
     *
     * @param config configuration properties to initialize the handler.
     *
     * @throws ServletException
     */
    @Override
    public void init(Properties config) throws ServletException {
    /* Get the server key */
        try {

          /* TODO: Check for UserGroupInformation.isMaprSecurityEnabled() */
          ClusterServerTicketGeneration.getInstance().generateTicketAndSetServerKey();

        } catch (Exception e) {
            throw new ServletException(e);
        }
    }

    @Override
    public void destroy() {
    }

    @Override
    public AuthenticationToken postauthenticate(HttpServletRequest request, 
        final HttpServletResponse response)
    throws IOException, AuthenticationException {
      if (request.getHeader(KerberosAuthenticator.AUTHORIZATION) != null) {
          return maprAuthenticate(request, response);
      }
      return null;
    }
    
    /**
     * This function is called once we establish the client is authenticating
     * using Mapr ticket and has responded with Mapr negotiate header.
     *
     * Here the server tries to decrypt the bytes (ticket and key) sent by the client using
     * serverkey. Verifies the credentials in the ticket and if the
     * ticket has not expired. Once the server decrypts the ticket and key it has the userkey
     * Using this userkey it decrypts the random secret challenge sent by the client.
     * Increments this by one, encrypts it using userkey and adds it to the response.
     * Since the server has completed verifying the client it generates the authentication
     * token and completes the handshake
     *
     * Anytime there is an error the server sets the error header (WWW_ERR_AUTHENTICATE) and appends
     * the reason for the error, so that the client can display relevant error message
     *
     * @param request the HTTP client request
     * @param response the HTTP client response
     * @return
     * @throws IOException
     * @throws AuthenticationException
     */
    public AuthenticationToken maprAuthenticate(HttpServletRequest request, HttpServletResponse response) throws IOException, AuthenticationException {

        // changed to standard Authorization
        String authorization = request.getHeader(KerberosAuthenticator.AUTHORIZATION);

        /* Sanity check: Make sure header contains Mapr negotiate */
        if (!authorization.startsWith(MaprAuthenticator.NEGOTIATE)) {
          return null;
        } else {
            authorization = authorization.substring(MaprAuthenticator.NEGOTIATE.length()).trim();
            try {
                byte[] base64decoded = Base64.decodeBase64(authorization);

                LOG.trace("MaprAuthentication is started");
                AuthenticationReqFull req = AuthenticationReqFull.
                        parseFrom(base64decoded);

                if (req != null && req.getEncryptedTicket() != null ) {
                    byte [] encryptedTicket =
                            req.getEncryptedTicket().toByteArray();
                    MutableInt err = new MutableInt();

                    /* During login ServerKey should have been set -
                     * if it was not set we have a wrong server here
                     */
                    Ticket decryptedTicket = Security.DecryptTicket(encryptedTicket,
                            err);

                    if (err.GetValue() != 0 || decryptedTicket == null) {
                        String decryptError = "Error while decrypting ticket and key " + err.GetValue();

                        // set the error header 
                        response.setHeader(MaprAuthenticator.WWW_ERR_AUTHENTICATE, decryptError);

                        // set status to be 401 
                        response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);

                        return null;
                    }

                    CredentialsMsg userCreds = decryptedTicket.getUserCreds();
                    Key userKey = decryptedTicket.getUserKey();
                    String userName = userCreds.getUserName();

                    // decrypt randomSecret
                    byte[] secretNumberBytes = req.getEncryptedRandomSecret().toByteArray();
                    byte[] secretNumberBytesDecrypted = Security.
                            Decrypt(userKey, secretNumberBytes, err);

                    if (secretNumberBytesDecrypted.length != Long.SIZE/8) {

                        String badSecretError = "Bad random secret";
                        LOG.error(badSecretError);

                        // set the error header 
                        response.setHeader(MaprAuthenticator.WWW_ERR_AUTHENTICATE, badSecretError);

                        // set status to be 401 
                        response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);

                        return null;
                    }

                    long returnLong = (((long)secretNumberBytesDecrypted[0] << 56) +
                            ((long)(secretNumberBytesDecrypted[1] & 255) << 48) +
                            ((long)(secretNumberBytesDecrypted[2] & 255) << 40) +
                            ((long)(secretNumberBytesDecrypted[3] & 255) << 32) +
                            ((long)(secretNumberBytesDecrypted[4] & 255) << 24) +
                            ((secretNumberBytesDecrypted[5] & 255) << 16) +
                            ((secretNumberBytesDecrypted[6] & 255) <<  8) +
                            ((secretNumberBytesDecrypted[7] & 255) <<  0));

                    LOG.trace("Received secret number: " + returnLong);

                    //increment the secret number and generate response
                    returnLong++;

                    AuthenticationResp.Builder authResp =
                            AuthenticationResp.newBuilder();
                    authResp.setChallengeResponse(returnLong);
                    authResp.setStatus(0);
                    byte [] resp = authResp.build().toByteArray();
                    byte [] respEncrypted = Security.Encrypt(userKey, resp, err);

                    Base64 base64 = new Base64(0);
                    String authenticate = base64.encodeToString(respEncrypted);

                    response.setHeader(KerberosAuthenticator.AUTHORIZATION,
                            MaprAuthenticator.NEGOTIATE + " " + authenticate);

                    LOG.trace("MaprAuthentication is completed on server side");

                    /* generate the authentication token for the user */
                    return new AuthenticationToken(userName, userName, getType());


                } else {

                    String clientRequestError = "Malformed client request";

                    LOG.error(clientRequestError);

                    // set the error header 
                    response.setHeader(MaprAuthenticator.WWW_ERR_AUTHENTICATE, clientRequestError);

                    // set status to be 401 
                    response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);

                    return null;
                }

            } catch (Throwable t) {
                String serverKeyError = "Bad server key";

                LOG.error(serverKeyError, t);

                /* set the error header */
                response.setHeader(MaprAuthenticator.WWW_ERR_AUTHENTICATE, serverKeyError);

                /* set status to be 401 */
                response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);

                return null;
            }
        }
    }
    
    @Override
    public void addHeader(HttpServletResponse response) {
      response.addHeader(KerberosAuthenticator.WWW_AUTHENTICATE, MaprAuthenticator.NEGOTIATE);
    }

    @Override
    public MultiMechsAuthenticationHandler getAuthBasedEntity(String authorization) {
      if ( authorization != null && authorization.startsWith(MaprAuthenticator.NEGOTIATE)) {
        return this;
      }
      return null;
    }

}
