/*--------------------------------------------------------------------------
 *  Copyright 2011 Taro L. Saito
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *--------------------------------------------------------------------------*/
package com.mapr.fs;

import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.security.DigestInputStream;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.jar.Attributes;
import java.util.jar.Attributes.Name;
import java.util.jar.Manifest;
import java.util.Date;
import java.text.SimpleDateFormat;

public class ShimLoader {

  static final String NATIVE_LOADER_CLASS_NAME =  "com.mapr.fs.shim.LibraryLoader";

  /**
   * Classes whose byte code is defined in the root class loader (in this order)
   * after the native loader class, so that every child class loader shares them.
   */
  static final String[] PRELOAD_CLASSES = {
    "com.mapr.fs.jni.Errno",
    "com.mapr.fs.jni.MapRConstants",
    "com.mapr.fs.jni.MapRConstants$JniUsername",
    "com.mapr.fs.jni.MapRConstants$ErrorValue",
    "com.mapr.fs.jni.MapRConstants$RowConstants",
    "com.mapr.fs.jni.MapRConstants$PutConstants",
    "com.mapr.fs.jni.JNIBlockLocation",
    "com.mapr.fs.jni.JNIFsStatus",
    "com.mapr.fs.jni.JNIFileStatus",
    "com.mapr.fs.jni.JNIFileStatus$VolumeInfo",
    "com.mapr.fs.jni.JNILoggerProxy",
    "com.mapr.fs.jni.IPPort",
    "com.mapr.fs.jni.GatewaySource",
    "com.mapr.fs.jni.Page",
    "com.mapr.fs.jni.Page$CacheState",
    "com.mapr.fs.jni.InodeAttributes",
    "com.mapr.fs.jni.SFid",
    "com.mapr.fs.jni.MapRAsyncRpc",
    "com.mapr.fs.jni.MapRGet",
    "com.mapr.fs.jni.MapRJSONPut",
    "com.mapr.fs.jni.MapRPut",
    "com.mapr.fs.jni.MapRIncrement",
    "com.mapr.fs.jni.MapRKeyValue",
    "com.mapr.fs.jni.MapRRowConstraint",
    "com.mapr.fs.jni.MapRScan",
    "com.mapr.fs.jni.MapRCallBackQueue",
    "com.mapr.fs.jni.MapRClient",
    "com.mapr.fs.jni.MapRTableTools",
    "com.mapr.security.JNISecurity",
    "com.mapr.security.JNISecurity$MutableErr",
    "com.mapr.security.UnixUserGroupHelper",
    "com.mapr.fs.jni.MapRUserGroupInfo",
    "com.mapr.fs.jni.MapRUserInfo",
    "com.mapr.fs.jni.RpcNative",
    "com.mapr.fs.RpcCallContext",
    "com.mapr.fs.jni.MapRClientInitParams",
    "com.mapr.fs.jni.RowColDecoder",
    "com.mapr.fs.jni.RowColDecoder$1",
    "com.mapr.fs.jni.RowColDecoderCallback",
    "com.mapr.fs.jni.RowColParser",
    "com.mapr.fs.jni.RowColParser$1",
    "com.mapr.fs.jni.RowColParser$STATE",
    "com.mapr.fs.jni.RowColParser$ValType",
    "com.mapr.fs.jni.MapRResult",
    "com.mapr.fs.jni.MapRResult$MapRResultDecoderCallback",
    "com.mapr.fs.jni.ParsedRow",
    "com.mapr.fs.jni.MarlinProducerResult",
    "com.mapr.fs.jni.NativeData",
    "com.mapr.fs.jni.ListenerRecord",
    "com.mapr.fs.jni.MarlinJniClient",
    "com.mapr.fs.jni.MarlinJniAdmin",
    "com.mapr.fs.jni.MarlinJniProducer",
    "com.mapr.fs.jni.MarlinJniListener"
  };

  static final String[] WEBAPP_SYSTEM_CLASSES = {
    "com.mapr.fs.jni."
  };


  /** Set once {@link #load()} has completed successfully for this copy of ShimLoader. */
  private static volatile boolean isLoaded = false;
  /** Enables {@link #trace} output; driven by the {@code shimloader.debuglog} system property. */
  private static boolean debugLog;
  /** Current user name with '\', '/' and ':' replaced by '_', used in the extracted file name. */
  private static final String USER_NAME;
  /** Version read from this jar's manifest, or "unknown". */
  private static final String LIBRARY_VERSION;
  static {
    debugLog = System.getProperty("shimloader.debuglog") != null;
    // Fall back to a fixed token so a missing user.name cannot fail class initialization.
    USER_NAME = System.getProperty("user.name", "unknown").replaceAll("[\\\\/:]", "_");
    LIBRARY_VERSION = getLibraryVersion(ShimLoader.class);
  }

  /**
   * Returns the top of the class loader chain starting from the thread context
   * class loader (or this class's loader when the context loader is unset).
   */
  private static ClassLoader getRootClassLoader() {
    ClassLoader cl = Thread.currentThread().getContextClassLoader();
    if (cl == null)
      cl = ShimLoader.class.getClassLoader();
    trace("getRootClassLoader: thread classLoader is '%s'",
          cl.getClass().getCanonicalName());
    while (cl.getParent() != null) {
      cl = cl.getParent();
    }
    trace("getRootClassLoader: root classLoader is '%s'",
          cl.getClass().getCanonicalName());
    return cl;
  }

  /**
   * Reads a classpath resource fully into memory.
   *
   * @param resourcePath absolute resource path, e.g. "/com/mapr/fs/Foo.class"
   * @return the resource bytes
   * @throws IOException if the resource is missing or cannot be read
   */
  private static byte[] getByteCode(String resourcePath) throws IOException {
    InputStream in = ShimLoader.class.getResourceAsStream(resourcePath);
    if (in == null)
      throw new IOException(resourcePath + " is not found");
    try {
      byte[] buf = new byte[8192];
      ByteArrayOutputStream byteCodeBuf = new ByteArrayOutputStream();
      for (int readLength; (readLength = in.read(buf)) != -1;) {
        byteCodeBuf.write(buf, 0, readLength);
      }
      return byteCodeBuf.toByteArray();
    } finally {
      // Close even when a read fails part-way through.
      in.close();
    }
  }

  /** Returns whether {@link #load()} has completed successfully for this copy of ShimLoader. */
  public static boolean isNativeLibraryLoaded() {
    return isLoaded;
  }

  /**
   * Asks the (possibly already injected) native loader class whether the
   * MapRClient library has been loaded.
   *
   * @return the loader's answer, or false if the loader class or its
   *         {@code isMapRClntLibLoaded} method cannot be reached
   */
  private static boolean isMaprClntLibLoaded() {
    try {
      Class<?> loaderClass = Class.forName(NATIVE_LOADER_CLASS_NAME);
      // If loader is loaded, check if the library is loaded
      Method getMethod = loaderClass.getDeclaredMethod("isMapRClntLibLoaded");
      return (Boolean) getMethod.invoke(null);
    } catch (Exception e) {
      // Loader not available (or returned null): the library still has to be loaded.
      return false;
    }
  }

  /**
   * Load native library and its JNI native implementation using the root class
   * loader. This hack is for avoiding the JNI multi-loading issue when the same
   * JNI library is loaded by different class loaders.
   *
   * In order to load native code in the root class loader, this method first
   * injects the LibraryLoader class into the root class loader, because
   * {@link System#load(String)} method uses the class loader of the caller
   * class when loading native libraries.
   *
   * <pre>
   * (root class loader) -> [com.mapr.fs.jni.*]  (injected by this method)
   *    |
   *    |
   * (child class loader) -> Sees the above classes loaded by the root class loader.
   *   Then creates MapRFileSystem implementation.
   * </pre>
   *
   *
   * <pre>
   * (root class loader) -> [ShimLoader, MapRClient, etc]  -> native code is loaded only once, in this class loader
   *   |   \
   *   |    (child2 class loader)
   * (child1 class loader)
   *
   * child1 and child2 share the same com.mapr.fs.jni.* code loaded by the root class loader.
   * </pre>
   *
   * Note that Java's class loader first delegates the class lookup to its
   * parent class loader. So once com.mapr.fs.jni.* is loaded by the root class
   * loader, no child class loader initializes these classes again.
   *
   * Setting the {@code mapr.library.flatclass} system property skips the
   * injection and loads the library through {@link System} directly.
   *
   * @throws ExceptionInInitializerError if the native library cannot be loaded
   */
  public static synchronized void load() {

    if (isLoaded)
    {
      trace("MapR native classes already loaded");
      return;
    }

    boolean loadInRootClassloader =
        (System.getProperty("mapr.library.flatclass") == null);

    trace("Load in root Classloader: %s.", loadInRootClassloader);

    try {
      if (loadInRootClassloader) {
        if (!isMaprClntLibLoaded()) {
          trace("Injecting Native Loader");
          // The interned string is one object for the whole JVM, so copies of
          // ShimLoader defined by different class loaders all lock the same
          // monitor here (the method-level lock is per-copy only).
          String synchronizationString = NATIVE_LOADER_CLASS_NAME.intern();
          synchronized (synchronizationString)
          {
            // Re-check under the JVM-wide lock: another copy may have won the race.
            if (!isMaprClntLibLoaded())
            {
              Class<?> nativeLoader = injectNativeLoader();

              // Load the JNI code using the injected loader
              loadNativeLibrary(nativeLoader);

              trace("Native Loader injected");
            }
          }
        }

        addSystemClassesToWebApps(PRELOAD_CLASSES);
      } else {
        loadNativeLibrary(System.class);
      }

      isLoaded = true;

    } catch (Exception e) {
      trace("Unable to load libMapRClient.so native library.");
      e.printStackTrace(System.err);
      throw new ExceptionInInitializerError(e);
    }

  }


  /**
   * Defines the LibraryLoader class and every class in {@link #PRELOAD_CLASSES}
   * in the root class loader, using the byte code bundled in this jar.
   *
   * @return native code loader class initialized in the root class loader
   * @throws RuntimeException wrapping any failure to read or define the classes
   */
  private static Class<?> injectNativeLoader() {

    try {
      // Use parent class loader to load native supporting classes, since applications like Tomcat use
      // different class loaders for different apps and cannot load JNI interface
      // twice

      ClassLoader rootClassLoader = getRootClassLoader();

      // Read all byte code up front so a missing resource fails before any class is defined.
      byte[] libLoaderByteCode = getByteCode("/com/mapr/fs/shim/LibraryLoader.bytecode");

      List<byte[]> preloadClassByteCode = new ArrayList<byte[]>(PRELOAD_CLASSES.length);
      for (String each : PRELOAD_CLASSES) {
        preloadClassByteCode.add(getByteCode("/" + each.replace('.', '/') + ".class"));
      }

      // ClassLoader.defineClass is protected, so it is invoked reflectively.
      Method defineClass = ClassLoader.class.getDeclaredMethod("defineClass",
          new Class[] { String.class, byte[].class, int.class, int.class, ProtectionDomain.class });
      ProtectionDomain pd = System.class.getProtectionDomain();

      defineClass.setAccessible(true);
      try {
        trace("injectNativeLoader: Loading MapR native classes");

        // Create the LibraryLoader class from its byte code
        defineClass.invoke(rootClassLoader, NATIVE_LOADER_CLASS_NAME,
                           libLoaderByteCode, 0, libLoaderByteCode.length, pd);

        // And also define dependent classes in the root class loader
        for (int i = 0; i < PRELOAD_CLASSES.length; ++i) {
          byte[] b = preloadClassByteCode.get(i);
          defineClass.invoke(rootClassLoader, PRELOAD_CLASSES[i], b, 0,
              b.length, pd);
        }
      } finally {
        // Reset the accessibility to defineClass method
        defineClass.setAccessible(false);
      }

      // Load the LibraryLoader class
      return rootClassLoader.loadClass(NATIVE_LOADER_CLASS_NAME);

    } catch (Exception e) {
      e.printStackTrace(System.err);
      throw new RuntimeException("Failure loading MapRClient. ", e);
    }

  }

  /**
   * Load native code using the {@code loadLibrary}/{@code load} methods of the
   * given loader class (the injected LibraryLoader, or {@link System}).
   *
   * First tries the preinstalled library on {@code java.library.path}; if that
   * fails, extracts the library bundled in this jar and loads that copy.
   *
   * @param loaderClass class exposing static {@code loadLibrary(String)} and {@code load(String)}
   * @throws Exception the failure from the java.library.path attempt when both attempts fail;
   *         the jar-load failure (if any) is printed to stderr
   */
  private static void loadNativeLibrary(Class<?> loaderClass)
      throws Exception {
    if (loaderClass == null) throw new RuntimeException("Missing LibraryLoader native loader class");

    Exception straightLoadException;

    // Load preinstalled MapRClient (in the path -Djava.library.path)
    try {
      Method loadMethod = loaderClass.getDeclaredMethod("loadLibrary", new Class[] { String.class });
      loadMethod.invoke(null, "MapRClient");
      trace("Loaded native library from '%s'.", System.getProperty("java.library.path"));
      markMaprClntLibLoaded(loaderClass);
      return;
    } catch (Exception ex) {
      // fall through to jar load.
      straightLoadException = ex;
    }

    try {
      // Direct load didn't work.  Try loading from jar file.
      File nativeLib = findNativeLibrary();
      if (nativeLib != null) {
        // Load extracted native library.
        Method loadMethod = loaderClass.getDeclaredMethod("load", new Class[] { String.class });
        loadMethod.invoke(null, nativeLib.getAbsolutePath());
        trace("Native library loaded.");
        markMaprClntLibLoaded(loaderClass);
        return;
      }
      // nativeLib == null: extractLibraryFile already printed why extraction failed.
    } catch (Exception ex) {
      // Catch checked failures too (e.g. InvocationTargetException wrapping an
      // UnsatisfiedLinkError), so both root causes are always reported.
      System.err.println("==========Unable to find library on native path due to Exception. ==============");
      straightLoadException.printStackTrace(System.err);
      System.err.println("==========Unable to find library in jar due to exception. ==============");
      ex.printStackTrace(System.err);
    }
    throw straightLoadException;
  }

  /**
   * Tells the injected LibraryLoader that the library is now loaded. No-op for
   * any other loader class (e.g. {@link System}).
   */
  private static void markMaprClntLibLoaded(Class<?> loaderClass) throws Exception {
    if (NATIVE_LOADER_CLASS_NAME.equals(loaderClass.getName())) {
      Method setMethod = loaderClass.getDeclaredMethod("setMaprClntLibLoaded");
      setMethod.invoke(null);
    }
  }

  /**
   * Computes the MD5 value of the input stream and closes the stream.
   *
   * @param input stream to digest; must not be null
   * @return the digest as a lowercase hexadecimal string
   * @throws IOException if the stream is null or cannot be read
   * @throws IllegalStateException if the MD5 algorithm is unavailable
   */
  static String md5sum(InputStream input) throws IOException {
    if (input == null) {
      throw new IOException("Cannot compute MD5 of a null input stream");
    }
    BufferedInputStream in = new BufferedInputStream(input);
    try {
      MessageDigest digest = MessageDigest.getInstance("MD5");
      DigestInputStream digestInputStream = new DigestInputStream(in, digest);
      byte[] buffer = new byte[8192];
      while (digestInputStream.read(buffer) != -1) {
        // Reading through the DigestInputStream is what updates the digest.
      }
      // Hex-encode: decoding raw digest bytes as text is lossy and could make
      // two different digests compare equal.
      return toHex(digest.digest());
    } catch (NoSuchAlgorithmException e) {
      throw new IllegalStateException("MD5 algorithm is not available: " + e, e);
    } finally {
      in.close();
    }
  }

  /** Encodes bytes as a lowercase hexadecimal string (two characters per byte). */
  private static String toHex(byte[] bytes) {
    StringBuilder hex = new StringBuilder(bytes.length * 2);
    for (byte b : bytes) {
      hex.append(Character.forDigit((b >> 4) & 0xF, 16));
      hex.append(Character.forDigit(b & 0xF, 16));
    }
    return hex.toString();
  }

  /**
   * Extract the specified library file to the target folder, reusing an
   * existing extracted copy when its MD5 matches the bundled resource.
   *
   * NOTE(review): the target name is predictable and targetFolder is normally
   * the shared java.io.tmpdir, and the file is written in place, so another
   * process may observe a partially written file or pre-create the path.
   * Consider extracting to a temp file and renaming.
   *
   * @param libFolderForCurrentOS resource folder holding the library, e.g. "/com/mapr/fs/native/..."
   * @param libraryFileName platform library file name, e.g. "libMapRClient.so"
   * @param targetFolder directory to extract into (created if missing)
   * @return the extracted file, or null if extraction failed (the cause is printed to stderr)
   */
  private static File extractLibraryFile(String libFolderForCurrentOS,
      String libraryFileName, String targetFolder) {
    trace("Extracting native library to '%s'.", targetFolder);

    //  mapr-<username>-libMapRClient.<build-version>.so
    int extensionStart = libraryFileName.lastIndexOf('.');
    final String extractedLibFileName;
    if (extensionStart < 0) {
      // No extension: append the version instead of failing on substring(-1).
      extractedLibFileName = "mapr-" + USER_NAME + "-" + libraryFileName + "." + LIBRARY_VERSION;
    } else {
      extractedLibFileName = "mapr-" + USER_NAME + "-"
          + libraryFileName.substring(0, extensionStart + 1)
          + LIBRARY_VERSION
          + libraryFileName.substring(extensionStart);
    }
    final File extractedLibFile = new File(targetFolder, extractedLibFileName);

    trace("Native library for this platform is '%s'.", extractedLibFileName);
    try {
      String nativeLibraryFilePath = libFolderForCurrentOS + "/" + libraryFileName;

      if (extractedLibFile.exists()) {
        // test md5sum value
        trace("Target file '%s' already exists, verifying checksum.",
          extractedLibFile.getAbsolutePath());
        String md5sum1 = md5sum(ShimLoader.class.getResourceAsStream(nativeLibraryFilePath));
        String md5sum2 = md5sum(new FileInputStream(extractedLibFile));

        if (md5sum1.equals(md5sum2)) {
          trace("Checksum matches, will not extract from the JAR.");
          return extractedLibFile;
        } else {
          // remove old native library file
          trace("Checksum did not match, will replace existing file from the JAR.");
          if (!extractedLibFile.delete()) {
            throw new IOException("Failed to remove existing native library file: "
                    + extractedLibFile.getAbsolutePath());
          }
        }
      }

      trace("Target file '%s' does not exist, will extract from the JAR.",
        extractedLibFile);

      // Create target folder if it does not exist
      File targetFolderFile = new File(targetFolder);
      if (!targetFolderFile.exists()) {
        trace("Creating target folder %s", targetFolder);
        if (!targetFolderFile.mkdirs() && !targetFolderFile.isDirectory()) {
          throw new IOException("Failed to create target folder: "
              + targetFolderFile.getAbsolutePath());
        }
      }

      // Extract a native library file into the target directory
      InputStream reader = ShimLoader.class.getResourceAsStream(nativeLibraryFilePath);
      if (reader == null) {
        throw new IOException(nativeLibraryFilePath + " is not found");
      }
      try {
        FileOutputStream writer = new FileOutputStream(extractedLibFile);
        try {
          byte[] buffer = new byte[8192];
          int bytesRead;
          while ((bytesRead = reader.read(buffer)) != -1) {
            writer.write(buffer, 0, bytesRead);
          }
        } finally {
          writer.close();
        }
      } finally {
        reader.close();
      }

      // Set executable (x) flag to enable Java to load the native library
      if (!System.getProperty("os.name").contains("Windows")) {
        try {
          Runtime.getRuntime().exec(
            new String[] { "chmod", "755",
                extractedLibFile.getAbsolutePath() }).waitFor();
        } catch (InterruptedException e) {
          // Preserve the caller's interrupt status instead of swallowing it.
          Thread.currentThread().interrupt();
          trace("Error setting executable permission.\n%s.", e.getMessage());
        } catch (Throwable e) {
          trace("Error setting executable permission.\n%s.", e.getMessage());
        }
      }

      return extractedLibFile;
    } catch (IOException e) {
      e.printStackTrace(System.err);
      return null;
    }
  }

  /**
   * Reads the version of the jar (or class directory) containing {@code clazz}
   * from its META-INF/MANIFEST.MF, preferring Implementation-Version and then
   * Bundle-Version.
   *
   * @param clazz class whose code source is inspected
   * @return the version, or "unknown" if it cannot be determined (best effort, never throws)
   */
  public static String getLibraryVersion(Class<?> clazz) {
    String libVersion = "unknown";
    try {
      // Absolute resource path works for nested classes too (getSimpleName() does not).
      String qualifiedClassName = clazz.getName().replace('.', '/') + ".class";
      URL classResource = clazz.getResource("/" + qualifiedClassName);
      if (classResource == null) {
        return libVersion;
      }
      String classURL = classResource.toString();
      int endIndex = classURL.startsWith("jar:")
          ? classURL.lastIndexOf("!") + 1
          : classURL.lastIndexOf(qualifiedClassName) - 1;
      String manifestPath = classURL.substring(0, endIndex) + "/META-INF/MANIFEST.MF";
      InputStream manifestStream = new URL(manifestPath).openStream();
      Manifest manifest;
      try {
        manifest = new Manifest(manifestStream);
      } finally {
        manifestStream.close();
      }
      Attributes attr = manifest.getMainAttributes();
      Name attrName = new Name("Implementation-Version");
      if (attr.containsKey(attrName)) {
        libVersion = attr.getValue(attrName);
      } else {
        attrName = new Name("Bundle-Version");
        if (attr.containsKey(attrName)) {
          libVersion = attr.getValue(attrName);
        }
      }
    } catch (Throwable e) {
      // Best effort: a missing or unreadable manifest just means "unknown".
      trace("Unable to read library version for %s: %s", clazz, e);
    }

    return libVersion;
  }

  /**
   * Locates the MapRClient library bundled under /com/mapr/fs/native/ for the
   * current OS/arch and extracts it to java.io.tmpdir.
   *
   * @return the extracted library file, or null if extraction failed
   * @throws RuntimeException if no bundled library exists for this platform
   */
  static File findNativeLibrary() {
    String nativeLibraryName = System.mapLibraryName("MapRClient");
    String nativeLibraryPath = "/com/mapr/fs/native/" +
        OSInfo.getNativeLibFolderPathForCurrentOS();
    trace("Searching for native library '%s/%s'.", nativeLibraryPath, nativeLibraryName);

    boolean hasNativeLib = hasResource(nativeLibraryPath + "/" + nativeLibraryName);

    // Doublecheck for openjdk7 for Mac
    if (!hasNativeLib && "Mac".equals(OSInfo.getOSName())) {
        String altName = "libMapRClient.dylib";
        trace("Searching for alternative library '%s' on Mac.", altName);
        if (hasResource(nativeLibraryPath + "/" + altName)) {
          nativeLibraryName = altName;
          hasNativeLib = true;
        }
    }

    if (!hasNativeLib) {
      String errorMessage = String.format("no native library is found for os.name=%s and os.arch=%s",
          OSInfo.getOSName(), OSInfo.getArchName());
      trace(errorMessage);
      throw new RuntimeException(errorMessage);
    }

    String tempFolder = new File(System.getProperty("java.io.tmpdir")).getAbsolutePath();

    // Extract and load a native library inside the jar file
    return extractLibraryFile(nativeLibraryPath, nativeLibraryName, tempFolder);
  }

  /** Returns whether {@code path} resolves to a resource visible to ShimLoader's class loader. */
  private static boolean hasResource(String path) {
    return ShimLoader.class.getResource(path) != null;
  }

  /**
   * Best-effort: appends {@code systemClasses} to the system classes of the
   * current Jetty 6 (org.mortbay) web app context, so the webapp loader
   * delegates them to the parent loader. Does nothing if Jetty is absent.
   */
  private static void addSystemClassesToWebApps(String[] systemClasses) {
    try {
      // try jetty
      Class<?> jettyWebAppContextClass = Class.forName("org.mortbay.jetty.webapp.WebAppContext");
      Method getCurrentWebAppContextMethod = jettyWebAppContextClass.getMethod("getCurrentWebAppContext");
      Method getSystemClassesMethod = jettyWebAppContextClass.getMethod("getSystemClasses");
      Method setSystemClassesMethod = jettyWebAppContextClass.getMethod("setSystemClasses", String[].class);

      Object jettyCurrentWebAppContext = getCurrentWebAppContextMethod.invoke(null);
      if (jettyCurrentWebAppContext != null) {
        String[] currentSystemClasses = (String []) getSystemClassesMethod.invoke(jettyCurrentWebAppContext);
        List<String> newSystemClasses = new ArrayList<String>();
        if (currentSystemClasses != null) {
          Collections.addAll(newSystemClasses, currentSystemClasses);
        }
        Collections.addAll(newSystemClasses, systemClasses);

        // Wrap in Object[] so the String[] is passed as one argument, not spread as varargs.
        Object[] newSystemClassesAsObjectArray = {newSystemClasses.toArray(new String[0])};
        setSystemClassesMethod.invoke(jettyCurrentWebAppContext, newSystemClassesAsObjectArray);
      }

      // support more web containers in future
    } catch (ClassNotFoundException cnfe) {
      // Jetty is not on the classpath: nothing to do.
    } catch (Exception e) {
      e.printStackTrace(System.err);
    }
  }

  /**
   * Prints a timestamped, thread-tagged {@link String#format} message to
   * stderr when debug logging is enabled.
   */
  static void trace(String msg, Object... args) {
    if (debugLog) {
      // A new formatter per call: SimpleDateFormat is not thread-safe.
      SimpleDateFormat dateFormat =
        new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS");
      System.err.println(dateFormat.format( new Date()) + " [" +
                         Thread.currentThread().getId() + "] " +
                         String.format(msg, args));
    }
  }

  /**
   * Diagnostic entry point: prints the library version, then either loads the
   * library ({@code load} argument) or just reports the extracted library path.
   */
  public static void main(String[] args) {
    debugLog = true;
    trace("ShimLoader library version: %s.", LIBRARY_VERSION);
    if (args.length > 0 && args[0].equals("load")) {
      load();
    } else {
      trace("Native library path: '%s'.", findNativeLibrary());
    }
  }

}
