package com.mapr.security.zookeeper.auth;

import java.nio.ByteBuffer;

import org.apache.commons.codec.binary.Base64;
import org.apache.log4j.Logger;
import org.apache.zookeeper.KeeperException;
import org.apache.zookeeper.KeeperException.Code;
import org.apache.zookeeper.data.Id;
import org.apache.zookeeper.server.ServerCnxn;
import org.apache.zookeeper.server.auth.AuthenticationProvider;

import com.mapr.fs.proto.Security.AuthenticationReqFull;
import com.mapr.fs.proto.Security.CredentialsMsg;
import com.mapr.fs.proto.Security.Key;
import com.mapr.fs.proto.Security.ServerKeyType;
import com.mapr.fs.proto.Security.Ticket;
import com.mapr.security.MutableInt;
import com.mapr.security.Security;

/**
 * MapR ZooKeeper authentication provider for the {@code maprauth} scheme.
 *
 * <p>An incoming auth packet is expected to be a Base64-encoded
 * {@link AuthenticationReqFull} carrying an encrypted user ticket and an
 * encrypted "random secret". The ticket is decrypted with the server key
 * installed by the constructor; the secret is decrypted with the user key
 * found inside the ticket. On success the uid from the ticket is attached
 * to the connection as an {@link Id} of this scheme.
 */
public class MaprZKAuthProvider implements AuthenticationProvider {

  static {
    // Runs when the class is initialized, i.e. before any Security call below.
    com.mapr.fs.ShimLoader.load();
  }

  private static final Logger LOG = Logger.getLogger(MaprZKAuthProvider.class);

  public static final String MAPR_ZK_AUTH_PROVIDER_SCHEME = "maprauth";

  /** Default tolerated client/server clock difference: 10 minutes, in ms. */
  private static final long MAX_EPOCH_DELAY = 10 * 60 * 1000L;

  /**
   * Expected size of the decrypted secret: two 8-byte big-endian longs
   * (client epoch in bytes 0..7, client session id in bytes 8..15).
   */
  private static final int SECRET_LENGTH = 2 * (Long.SIZE / Byte.SIZE);

  // Location of the CLDB key file, read from the system property
  // "zookeeper.mapr.cldbkeyfile.location". It must be set before the
  // ZK server starts (presumably via "mapr.cldbkeyfile.location" in zoo.cfg,
  // which ZooKeeper re-exports with the "zookeeper." prefix - TODO confirm).
  private static final String cldbKeyFile = System.getProperty(
      "zookeeper.mapr.cldbkeyfile.location");

  // Optional override (in ms) for the tolerated epoch difference.
  private static final String configuredEpochDelay = System.getProperty(
      "zookeeper.mapr.epoch.delay", "600000");

  private static long epochDelay = MAX_EPOCH_DELAY;

  /**
   * Loads the CLDB key from the configured key file, derives the server key
   * from it and installs that server key into {@link Security}, so that
   * tickets can be decrypted later in {@link #handleAuthentication}.
   *
   * <p>As this is called once during ProviderRegistry init, this is the
   * place where the provider sets up its own security data.
   *
   * @throws InstantiationException if the key file location is not
   *         configured or any of the key setup steps fails
   */
  public MaprZKAuthProvider() throws InstantiationException {
    try {
      epochDelay = Long.parseLong(configuredEpochDelay);
    } catch (NumberFormatException nfe) {
      // Keep whatever value epochDelay already has (the default on first init).
      LOG.warn("Configured Epoch Delay is not a number: " + configuredEpochDelay + ". Will use default");
    }

    if (cldbKeyFile == null) {
      LOG.error("Location of ZK cldb key is not set");
      throw new InstantiationException("Location of ZK cldb key is not set");
    }

    // SetKeyFile reports failure through its return value, so that is what
    // must be logged (not an untouched MutableInt).
    int errCode = Security.SetKeyFile(ServerKeyType.CldbKey, cldbKeyFile);
    if (errCode != 0) {
      String msg = "Failed to set cldb key file " + cldbKeyFile + " err " + errCode;
      LOG.error(msg);
      throw new InstantiationException(msg);
    }
    LOG.info("Set the cldb key file to " + cldbKeyFile);

    MutableInt err = new MutableInt();
    Key cldbKey = Security.GetKey(ServerKeyType.CldbKey, err);
    if (cldbKey == null) {
      String msg = "Cldb key can not be obtained: " + err.GetValue();
      LOG.error(msg);
      throw new InstantiationException(msg);
    }

    Key serverKey = Security.GetServerKey(cldbKey, 0);
    if (serverKey == null) {
      LOG.error("Server key can not be obtained");
      throw new InstantiationException("Server key can not be obtained");
    }

    errCode = Security.SetKey(ServerKeyType.ServerKey, serverKey);
    if (errCode != 0) {
      String msg = "Failed to set Server key with error: " + errCode;
      LOG.error(msg);
      throw new InstantiationException(msg);
    }
  }

  /**
   * @return the scheme name handled by this provider, {@code "maprauth"}
   */
  @Override
  public String getScheme() {
    return MAPR_ZK_AUTH_PROVIDER_SCHEME;
  }

  /**
   * Authenticates a connection from the supplied auth packet.
   *
   * <p>authData should be a ticket with user credentials plus a secret
   * (client epoch and session id) encrypted with the user key; the flow
   * resembles the first challenge step of the MapR SASL code.
   *
   * @param cnxn     connection that receives the authenticated {@link Id}
   * @param authData Base64-encoded {@link AuthenticationReqFull}
   * @return {@code OK} when the ticket and the secret could be decrypted and
   *         the id was added to the connection, {@code AUTHFAILED} otherwise
   */
  @Override
  public Code handleAuthentication(ServerCnxn cnxn, byte[] authData) {
    if (authData == null || authData.length < 1) {
      LOG.error("Received challenge is empty when secret expected");
      return KeeperException.Code.AUTHFAILED;
    }

    try {
      byte[] base64decoded = Base64.decodeBase64(authData);
      AuthenticationReqFull reply = AuthenticationReqFull.parseFrom(base64decoded);
      if (reply == null || reply.getEncryptedTicket() == null) {
        LOG.error("Malformed auth info");
        return KeeperException.Code.AUTHFAILED;
      }

      MutableInt err = new MutableInt();
      byte[] encryptedTicket = reply.getEncryptedTicket().toByteArray();
      Ticket decryptedTicket = Security.DecryptTicket(encryptedTicket, err);
      if (err.GetValue() != 0 || decryptedTicket == null) {
        LOG.error("Error while trying to decrypt ticket: " + err.GetValue());
        return KeeperException.Code.AUTHFAILED;
      }

      CredentialsMsg userCreds = decryptedTicket.getUserCreds();
      Key userKey = decryptedTicket.getUserKey();
      if (userCreds == null || userKey == null) {
        LOG.error("Incoming info is not valid");
        return KeeperException.Code.AUTHFAILED;
      }
      String uID = Integer.toString(userCreds.getUid());

      // Decrypt the random secret with the user key taken from the ticket.
      byte[] encryptedSecret = reply.getEncryptedRandomSecret().toByteArray();
      byte[] secret = Security.Decrypt(userKey, encryptedSecret, err);
      // Same failure handling as for the ticket: a null result or a non-zero
      // error code means the secret could not be decrypted.
      if (err.GetValue() != 0 || secret == null) {
        LOG.error("Error while trying to decrypt random secret: " + err.GetValue());
        return KeeperException.Code.AUTHFAILED;
      }
      if (secret.length != SECRET_LENGTH) {
        LOG.error("Bad random secret");
        return KeeperException.Code.AUTHFAILED;
      }

      // Bytes 0..7: client epoch as a big-endian long (ByteBuffer's default order).
      long clientEpoch = ByteBuffer.wrap(secret).getLong();

      // For now the secret is assumed to be the client epoch. A large
      // difference is only logged, not rejected.
      long currentTime = System.currentTimeMillis();
      if (Math.abs(currentTime - clientEpoch) > epochDelay) {
        LOG.warn("Epoch on client differs > " + epochDelay + "ms. then on server: " + currentTime + ", client: " + clientEpoch +
            ". Most likely it is related to ZK disconnect");
      } else if (LOG.isDebugEnabled()) {
        LOG.debug("on server: " + currentTime + ", client: " + clientEpoch);
      }

      // NOTE: bytes 8..15 of the secret carry the client session id. Matching
      // it against cnxn.getSessionId() is currently disabled, so the value is
      // not decoded here.

      // Record the authenticated uid on the connection for later ACL matching.
      cnxn.addAuthInfo(new Id(getScheme(), uID));
      if (LOG.isDebugEnabled()) {
        LOG.debug("Auth info size: " + cnxn.getAuthInfo().size());
      }
      return KeeperException.Code.OK;
    } catch (Throwable t) {
      // Deliberately broad: any failure while parsing or decrypting is
      // reported to the client as an authentication failure.
      LOG.error("Bad server key ", t);
      return KeeperException.Code.AUTHFAILED;
    }
  }

  /**
   * @return always {@code true}
   */
  @Override
  public boolean isAuthenticated() {
    return true;
  }

  /**
   * Validates the syntax of an id used in an ACL.
   *
   * @param id id to validate
   * @return always {@code true}; syntax validation can be added if needed
   */
  @Override
  public boolean isValid(String id) {
    return true;
  }

  /**
   * Checks whether an authenticated id satisfies an ACL expression.
   *
   * <p>Usually READ is allowed for everybody and every other operation only
   * for the owner; pattern matching can be introduced later if needed.
   *
   * @param id      authenticated id (the uid string set during authentication)
   * @param aclExpr ACL expression, either {@code "anyone"} or an exact id
   * @return {@code true} if {@code aclExpr} is {@code "anyone"} or equals
   *         {@code id}; {@code false} otherwise, including for null arguments
   */
  @Override
  public boolean matches(String id, String aclExpr) {
    if ("anyone".equals(aclExpr)) {
      return true;
    }
    return id != null && id.equals(aclExpr);
  }
}
