package com.mapr.security.maprsasl;

import java.nio.ByteBuffer;
import java.util.Map;

import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;

import org.apache.commons.codec.binary.Base64;
import org.apache.log4j.Logger;

import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.mapr.baseutils.cldbutils.CLDBRpcCommonUtils;
import com.mapr.fs.proto.Security.AuthenticationReqFull;
import com.mapr.fs.proto.Security.AuthenticationResp;
import com.mapr.fs.proto.Security.Key;
import com.mapr.fs.proto.Security.ServerKeyType;
import com.mapr.fs.proto.Security.TicketAndKey;
import com.mapr.login.client.MapRLoginClient;
import com.mapr.login.client.MapRLoginHttpsClient;
import com.mapr.security.JNISecurity;
import com.mapr.security.MutableInt;
import com.mapr.security.Security;
import com.mapr.security.maprsasl.MaprSaslServer.QOP;

/**
 * Client side of the {@code MAPR-SECURITY} SASL mechanism.
 *
 * <p>The exchange has two steps:
 * <ol>
 *   <li>The initial response (Base64-encoded) carries the client's encrypted ticket and a
 *       freshly generated random secret encrypted with the ticket's user key.</li>
 *   <li>The server's reply (Base64-encoded, encrypted with the user key) must echo that
 *       secret and state the QOP it selected; for any QOP other than {@code auth} it also
 *       carries a session key, which {@link #wrap} and {@link #unwrap} use.</li>
 * </ol>
 *
 * <p>This class performs no internal synchronization.
 */
public class MaprSaslClient implements SaslClient {

  private static final Logger LOG = Logger.getLogger(MaprSaslClient.class);

  /** Value reported for {@link Sasl#RAW_SEND_SIZE}. */
  private static final int MAX_BUF_SIZE_FOR_WRAP = 64 * 1024;

  /** True once the server's response has been verified. */
  private boolean completed;
  /** True once the initial response has been produced. */
  private boolean firstPassDone;
  private final CallbackHandler cbh;
  /** Secret sent in the initial response; the server must echo it back. */
  private long randomSecret;
  private String authorizationId;
  private String authenticationId;
  /** Key used by wrap/unwrap; only set when the negotiated QOP is not "auth". */
  private Key sessionKey;
  /** Key from the client's ticket, used to protect the handshake itself. */
  private Key userKey;
  private String negotiatedQOPProperty;
  /** QOP requested by the caller via {@link Sasl#QOP}; defaults to "auth". */
  private final String localqopProperty;

  /**
   * Creates a client.
   *
   * @param props SASL properties; only a String-valued {@link Sasl#QOP} entry is consulted.
   *     When it is absent (or {@code props} is null/empty) the requested QOP is "auth".
   * @param cbh callback handler, stored but not otherwise used by this class
   * @throws SaslException declared for compatibility; not thrown by this constructor
   */
  public MaprSaslClient(Map<String, ?> props, CallbackHandler cbh)
      throws SaslException {
    this.cbh = cbh;
    String requestedQop = QOP.AUTHENTICATION.getQopString();
    if (props != null) {
      Object value = props.get(Sasl.QOP);
      if (value instanceof String) {
        requestedQop = (String) value;
      }
    }
    this.localqopProperty = requestedQop;
  }

  /** Drops references to key material, the handshake secret and identities. */
  @Override
  public void dispose() throws SaslException {
    sessionKey = null;
    userKey = null;
    randomSecret = -1;
    authorizationId = null;
    authenticationId = null;
  }

  /**
   * Advances the exchange.
   *
   * <p>On the first call {@code challenge} is ignored and the initial response is returned.
   * On the second call {@code challenge} must be the server's reply; it is verified and an
   * empty array is returned, completing authentication.
   *
   * @param challenge server data; ignored on the first call
   * @return the Base64-encoded initial response, or an empty array once complete
   * @throws SaslException if the ticket is unavailable/expired, crypto fails, or the server
   *     reply is malformed or does not verify
   * @throws IllegalStateException if authentication has already completed
   */
  @Override
  public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
    if (completed) {
      throw new IllegalStateException(
          "MaprSasl authentication already completed");
    }
    if (!firstPassDone) {
      return buildInitialResponse();
    }
    return processServerResponse(challenge);
  }

  /**
   * Builds the Base64-encoded initial response: the encrypted ticket plus a new random
   * secret encrypted with the ticket's user key. Marks the first pass done on success.
   */
  private byte[] buildInitialResponse() throws SaslException {
    try {
      MutableInt err = new MutableInt();
      //TODO - assuming default cluster. I think this is true for Sasl always.
      TicketAndKey ticketKey;
      if (MaprSecurityLoginModule.isUseMaprServerTicket()) {
        String currentClusterName =
            CLDBRpcCommonUtils.getInstance().getCurrentClusterName();
        ticketKey = Security.GetTicketAndKeyForCluster(
            ServerKeyType.CldbKey, currentClusterName, err);
      } else {
        // Only needed when we are not using the server ticket.
        MapRLoginClient client = new MapRLoginHttpsClient();
        ticketKey = client.authenticateIfNeeded();
      }
      if (ticketKey == null) {
        throw new SaslException("ServerTicketKey was not set");
      }
      // Expiry time is scaled by 1000 to compare against currentTimeMillis().
      long ticketExpTime = ticketKey.getExpiryTime() * 1000L;
      if (ticketExpTime < System.currentTimeMillis()) {
        throw new SaslException("MaprSaslClient My ticket Expired");
      }
      userKey = ticketKey.getUserKey();
      randomSecret = JNISecurity.GenerateRandomNumber();

      // 8-byte big-endian encoding (ByteBuffer's default byte order).
      byte[] secretBytes = ByteBuffer.allocate(8).putLong(randomSecret).array();
      byte[] secretBytesEncrypted = Security.Encrypt(userKey, secretBytes, err);
      if (err.GetValue() != 0) {
        throw new SaslException("Error while encrypting data: " + err.GetValue());
      }

      byte[] authRequestBytes = AuthenticationReqFull.newBuilder()
          .setEncryptedRandomSecret(ByteString.copyFrom(secretBytesEncrypted))
          .setEncryptedTicket(ticketKey.getEncryptedTicket())
          .build()
          .toByteArray();
      firstPassDone = true;
      return Base64.encodeBase64(authRequestBytes);
    } catch (SaslException e) {
      throw e;
    } catch (Throwable t) {
      LOG.error("Exception while processing ticket data", t);
      throw new SaslException("Exception while processing ticket data", t);
    }
  }

  /**
   * Verifies the server's reply: decrypts it with the user key, checks the status and the
   * echoed secret, records the server-selected QOP and, when that QOP is not "auth", the
   * session key. Marks the exchange complete on success.
   */
  private byte[] processServerResponse(byte[] challenge) throws SaslException {
    if (challenge == null || challenge.length < 1) {
      throw new SaslException("Received challenge is empty when secret expected");
    }
    if (userKey == null) {
      throw new SaslException("Bad userKey");
    }
    try {
      MutableInt err = new MutableInt();
      byte[] decodedResponse =
          Security.Decrypt(userKey, Base64.decodeBase64(challenge), err);
      if (err.GetValue() != 0) {
        throw new SaslException("Error while decrypting data: " + err.GetValue());
      }

      AuthenticationResp authResponse;
      try {
        authResponse = AuthenticationResp.parseFrom(decodedResponse);
      } catch (InvalidProtocolBufferException e) {
        throw new SaslException("Can not parse out the data from server response", e);
      }
      if (authResponse.getStatus() != 0) {
        throw new SaslException("Bad response");
      }

      // The server proves it could read our request by echoing the secret.
      if (!authResponse.hasChallengeResponse()) {
        throw new SaslException("No returned secret");
      }
      if (authResponse.getChallengeResponse() != randomSecret) {
        throw new SaslException("Bad returned secret");
      }

      if (!authResponse.hasEncodingType()) {
        throw new SaslException("No server QOP in response");
      }
      String serverQop = QOP.getStringFromQOPInt(authResponse.getEncodingType());
      if (serverQop != null) {
        if (!serverQop.equals(localqopProperty)) {
          // TODO make the stronger option win. If the server's option is weaker we may
          // need another round trip to the server with a proposed response.
          LOG.warn("SASL Server qopProperty: " + serverQop
              + " is different from Client: " + localqopProperty + ". Using Server one");
        }
        negotiatedQOPProperty = serverQop;
      }
      // NOTE(review): if the QOP int is not recognized, negotiatedQOPProperty stays null
      // and a session key is still required below.

      // A session key is only needed when the negotiated QOP is not plain "auth".
      if (!QOP.AUTHENTICATION.getQopString().equals(negotiatedQOPProperty)) {
        // Protobuf getters never return null; an absent field yields a default
        // (empty) Key, so presence must be checked explicitly.
        if (!authResponse.hasSessionKey()) {
          throw new SaslException("Bad returned sessionKey");
        }
        sessionKey = authResponse.getSessionKey();
      }
      completed = true;
      return new byte[0];
    } catch (SaslException e) {
      throw e;
    } catch (Throwable t) {
      LOG.error("Exception while processing ticket data", t);
      throw new SaslException("Exception while processing ticket data", t);
    }
  }

  @Override
  public String getMechanismName() {
    return "MAPR-SECURITY";
  }

  /**
   * Returns a negotiated property.
   *
   * @param propName property name
   * @return the negotiated QOP for {@link Sasl#QOP}, the maximum wrap size (as a String) for
   *     {@link Sasl#RAW_SEND_SIZE}, or {@code null} for any other property, which this
   *     mechanism does not negotiate
   * @throws IllegalStateException if authentication has not completed
   */
  @Override
  public Object getNegotiatedProperty(String propName) {
    if (!completed) {
      throw new IllegalStateException("MAPR-SECURITY authentication not completed");
    }
    if (Sasl.QOP.equals(propName)) {
      return negotiatedQOPProperty;
    }
    if (Sasl.RAW_SEND_SIZE.equals(propName)) {
      return Integer.toString(MAX_BUF_SIZE_FOR_WRAP);
    }
    // SaslClient contract: properties that were not negotiated yield null.
    return null;
  }

  @Override
  public boolean hasInitialResponse() {
    return true;
  }

  @Override
  public boolean isComplete() {
    return completed;
  }

  /**
   * Decrypts {@code len} bytes of {@code incoming} starting at {@code offset} with the
   * session key.
   *
   * @throws SaslException if decryption reports an error
   * @throws IllegalStateException if authentication has not completed or no session key is
   *     available (negotiated QOP is "auth", or the client was disposed)
   */
  @Override
  public byte[] unwrap(byte[] incoming, int offset, int len)
      throws SaslException {
    Key key = requireSessionKey();
    byte[] data = new byte[len];
    System.arraycopy(incoming, offset, data, 0, len);
    MutableInt err = new MutableInt();
    byte[] decrypted = Security.Decrypt(key, data, err);
    if (err.GetValue() != 0) {
      throw new SaslException("Error while decrypting data: " + err.GetValue());
    }
    return decrypted;
  }

  /**
   * Encrypts {@code len} bytes of {@code outgoing} starting at {@code offset} with the
   * session key.
   *
   * @throws SaslException if encryption reports an error
   * @throws IllegalStateException if authentication has not completed or no session key is
   *     available (negotiated QOP is "auth", or the client was disposed)
   */
  @Override
  public byte[] wrap(byte[] outgoing, int offset, int len) throws SaslException {
    Key key = requireSessionKey();
    byte[] data = new byte[len];
    System.arraycopy(outgoing, offset, data, 0, len);
    MutableInt err = new MutableInt();
    byte[] encrypted = Security.Encrypt(key, data, err);
    if (err.GetValue() != 0) {
      throw new SaslException("Error while encrypting data: " + err.GetValue());
    }
    return encrypted;
  }

  /** Returns the session key, enforcing the SaslClient wrap/unwrap preconditions. */
  private Key requireSessionKey() {
    if (!completed) {
      throw new IllegalStateException("MAPR-SECURITY authentication not completed");
    }
    if (sessionKey == null) {
      throw new IllegalStateException("MAPR-SECURITY has no session key (negotiated QOP: "
          + negotiatedQOPProperty + ", or client disposed); wrap/unwrap unavailable");
    }
    return sessionKey;
  }

  /** Factory producing {@link MaprSaslClient} for the MAPR-SECURITY mechanism. */
  public static class SaslMaprClientFactory implements SaslClientFactory {

    @Override
    public String[] getMechanismNames(Map<String, ?> props) {
      return new String[] { MaprSaslServer.MAPR_SECURITY_MECH_NAME };
    }

    /**
     * Returns a new client if {@code mechanisms} contains MAPR-SECURITY, else {@code null}.
     */
    @Override
    public SaslClient createSaslClient(String[] mechanisms,
        String authorizationId, String protocol, String serverName,
        Map<String, ?> props, CallbackHandler cbh) throws SaslException {
      if (mechanisms != null) {
        for (String mechanism : mechanisms) {
          if (MaprSaslServer.MAPR_SECURITY_MECH_NAME.equals(mechanism)) {
            return new MaprSaslClient(props, cbh);
          }
        }
      }
      return null;
    }
  }

}
