package com.mapr.security.simplesasl;

import java.nio.charset.StandardCharsets;
import java.security.AccessControlContext;
import java.security.AccessController;
import java.security.Principal;
import java.util.Map;

import javax.security.auth.Subject;
import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;

import org.apache.commons.codec.binary.Base64;

/**
 * Client side of the {@code SIMPLE-SECURITY} SASL mechanism.
 *
 * <p>The exchange consists of a single client response containing the
 * Base64-encoded name of a principal taken from the {@link Subject} bound to
 * the caller's {@link AccessControlContext} (or an empty name when none is
 * available). After that response the exchange is complete. No integrity or
 * privacy layer is negotiated, so {@link #wrap} and {@link #unwrap} pass data
 * through unchanged.
 */
public class SimpleSaslClient implements SaslClient {

  /** Set once {@link #evaluateChallenge} has produced the single response. */
  private boolean completed;

  /** Identity sent to the server; empty when no subject/principal was found. */
  private final String currentId;

  /**
   * Creates a client whose identity is the name of a principal of the
   * {@link Subject} associated with the current access control context.
   * If there is no such subject, or it has no principals, the identity is
   * the empty string.
   */
  public SimpleSaslClient() {
    AccessControlContext context = AccessController.getContext();
    Subject subject = Subject.getSubject(context);
    if (subject != null && !subject.getPrincipals().isEmpty()) {
      // NOTE(review): uses whichever principal the set's iterator yields first;
      // if a subject can carry several principals, confirm this is the intended one.
      final Principal clientPrincipal = subject.getPrincipals().iterator().next();
      currentId = clientPrincipal.getName();
    } else {
      // No subject or no principals: fall back to an empty identity.
      currentId = "";
    }
  }

  /** Releases resources; this mechanism holds none, so this is a no-op. */
  @Override
  public void dispose() throws SaslException {
    // nothing to release
  }

  /**
   * Produces the single client response: the Base64-encoded UTF-8 bytes of
   * the client identity. The server challenge is not inspected.
   *
   * @param challenge the server challenge (ignored)
   * @return the Base64-encoded identity
   * @throws IllegalStateException if the response has already been produced
   */
  @Override
  public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
    if (completed) {
      throw new IllegalStateException(
          "SimpleSasl authentication already completed");
    }
    // Explicit charset: the wire bytes must not depend on the client host's
    // default charset. NOTE(review): confirm SimpleSaslServer decodes as UTF-8.
    byte[] authRequestBytes =
        Base64.encodeBase64(currentId.getBytes(StandardCharsets.UTF_8));
    completed = true;
    return authRequestBytes;
  }

  /**
   * Returns the IANA-style mechanism name.
   *
   * @return {@code "SIMPLE-SECURITY"}
   */
  @Override
  public String getMechanismName() {
    // NOTE(review): presumably equal to SimpleSaslServer.SIMPLE_SECURITY_MECH_NAME,
    // which the factory below advertises; confirm they stay in sync.
    return "SIMPLE-SECURITY";
  }

  /**
   * Returns a negotiated property. This mechanism negotiates none.
   *
   * @param propName the property name (ignored)
   * @return always {@code null}
   */
  @Override
  public Object getNegotiatedProperty(String propName) {
    return null;
  }

  /**
   * Indicates that the client sends a response before any server challenge.
   *
   * @return always {@code true}
   */
  @Override
  public boolean hasInitialResponse() {
    return true;
  }

  /**
   * Reports whether the exchange has finished.
   *
   * @return {@code true} once {@link #evaluateChallenge} has produced its response
   */
  @Override
  public boolean isComplete() {
    return completed;
  }

  /**
   * Returns the bytes {@code incoming[offset, offset + len)} unchanged, since no
   * security layer is negotiated.
   *
   * @throws SaslException if the range lies outside {@code incoming}
   */
  @Override
  public byte[] unwrap(byte[] incoming, int offset, int len)
      throws SaslException {
    return passThrough(incoming, offset, len);
  }

  /**
   * Returns the bytes {@code outgoing[offset, offset + len)} unchanged, since no
   * security layer is negotiated.
   *
   * @throws SaslException if the range lies outside {@code outgoing}
   */
  @Override
  public byte[] wrap(byte[] outgoing, int offset, int len) throws SaslException {
    return passThrough(outgoing, offset, len);
  }

  /**
   * Extracts {@code buf[offset, offset + len)}; returns {@code buf} itself when
   * the whole array (or {@code null}) is requested, avoiding a copy.
   */
  private static byte[] passThrough(byte[] buf, int offset, int len)
      throws SaslException {
    if (buf == null || (offset == 0 && len == buf.length)) {
      return buf;
    }
    // offset > buf.length - len cannot overflow because len >= 0 is checked first.
    if (offset < 0 || len < 0 || offset > buf.length - len) {
      throw new SaslException("Invalid range: offset=" + offset + ", len=" + len
          + ", buffer length=" + buf.length);
    }
    byte[] result = new byte[len];
    System.arraycopy(buf, offset, result, 0, len);
    return result;
  }

  /** {@link SaslClientFactory} that creates {@link SimpleSaslClient} instances. */
  public static class SaslSimpleClientFactory implements SaslClientFactory {

    /**
     * Lists the supported mechanism.
     *
     * @param props ignored
     * @return a one-element array holding the {@code SIMPLE-SECURITY} mechanism name
     */
    @Override
    public String[] getMechanismNames(Map<String, ?> props) {
      return new String[] { SimpleSaslServer.SIMPLE_SECURITY_MECH_NAME };
    }

    /**
     * Creates a client if the simple-security mechanism is among the requested ones.
     *
     * @param mechanisms requested mechanism names; may be {@code null}
     * @return a new {@link SimpleSaslClient}, or {@code null} if the mechanism
     *     was not requested
     */
    @Override
    public SaslClient createSaslClient(String[] mechanisms,
        String authorizationId, String protocol, String serverName,
        Map<String, ?> props, CallbackHandler cbh) throws SaslException {
      if (mechanisms != null) {
        for (String mechanism : mechanisms) {
          if (SimpleSaslServer.SIMPLE_SECURITY_MECH_NAME.equals(mechanism)) {
            return new SimpleSaslClient();
          }
        }
      }
      return null;
    }
  }
}
