/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hive.ql.udf.generic;

import java.util.ArrayList;
import java.util.List;

import junit.framework.Assert;

import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.parse.TypeCheckProcFactory;
import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.testutil.BaseScalarUdfTest;
import org.apache.hadoop.hive.ql.testutil.DataBuilder;
import org.apache.hadoop.hive.ql.testutil.OperatorTestUtils;
import org.apache.hadoop.hive.serde2.io.ByteWritable;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.io.ShortWritable;
import org.apache.hadoop.hive.serde2.objectinspector.InspectableObject;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.junit.Test;

/**
 * Tests for the {@code round} function ({@link GenericUDFRound}).
 *
 * <p>The class contributes three pieces of data through the hooks it overrides from
 * {@link BaseScalarUdfTest}: an input table, a list of {@code round(column, scale)}
 * expressions (one per input column) and the rows those expressions are expected to
 * produce. How the base class evaluates and compares them is not visible from this file.
 *
 * <p>It additionally checks the decimal precision/scale reported by
 * {@link GenericUDFRound#initialize} for a few constant scale arguments.
 */
public class TestGenericUDFRound extends BaseScalarUdfTest {
  /**
   * Input column names. Position {@code i} here pairs with the {@code i}-th scale constant
   * built in {@link #getExpressionList()} and with output column {@code _col(i+1)}.
   */
  private static final String[] cols = {"s", "i", "d", "f", "b", "sh", "l", "dec"};

  /**
   * Builds the input rows: one column each of string, int, double, float, byte, short,
   * long and decimal(15,5), including nulls in the int and double columns.
   *
   * @return the input rows together with their object inspectors
   */
  @Override
  public InspectableObject[] getBaseTable() {
    DataBuilder db = new DataBuilder();
    db.setColumnNames(cols);
    db.setColumnTypes(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector,
        PrimitiveObjectInspectorFactory.javaIntObjectInspector,
        PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
        PrimitiveObjectInspectorFactory.javaFloatObjectInspector,
        PrimitiveObjectInspectorFactory.javaByteObjectInspector,
        PrimitiveObjectInspectorFactory.javaShortObjectInspector,
        PrimitiveObjectInspectorFactory.javaLongObjectInspector,
        PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(TypeInfoFactory.getDecimalTypeInfo(15, 5)));
    // valueOf(String) replaces the deprecated boxed-type constructors; the values are identical.
    db.addRow("one", 170, Double.valueOf("1.1"), Float.valueOf("32.1234"), Byte.valueOf("25"),
        Short.valueOf("1234"), 123456L, HiveDecimal.create("983.7235"));
    db.addRow("-234", null, null, Float.valueOf("0.347232"), Byte.valueOf("109"),
        Short.valueOf("551"), 923L, HiveDecimal.create("983723.005"));
    db.addRow("454.45", 22345, Double.valueOf("-23.00009"), Float.valueOf("-3.4"), Byte.valueOf("76"),
        Short.valueOf("2321"), 9232L, HiveDecimal.create("-932032.7"));
    return db.createRows();
  }

  /**
   * Builds the rows expected from applying the expressions of {@link #getExpressionList()}
   * to the rows of {@link #getBaseTable()}.
   *
   * <p>Per column, in order, the scale applied is: string -2, int 0, double 3, float 0,
   * byte -2, short 0, long -2, decimal 3.
   *
   * @return the expected output rows together with their object inspectors
   */
  @Override
  public InspectableObject[] getExpectedResult() {
    DataBuilder db = new DataBuilder();
    db.setColumnNames("_col1", "_col2", "_col3", "_col4", "_col5", "_col6", "_col7", "_col8");
    // The first column holds DoubleWritable values (the rounded string input), so it is
    // declared with the writable double inspector to match the data added below.
    db.setColumnTypes(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
        PrimitiveObjectInspectorFactory.writableIntObjectInspector,
        PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
        PrimitiveObjectInspectorFactory.writableFloatObjectInspector,
        PrimitiveObjectInspectorFactory.writableByteObjectInspector,
        PrimitiveObjectInspectorFactory.writableShortObjectInspector,
        PrimitiveObjectInspectorFactory.writableLongObjectInspector,
        PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector);
    // Row 1: the non-numeric string "one" is expected to round to null.
    db.addRow(null, new IntWritable(170), new DoubleWritable(1.1), new FloatWritable(32f),
        new ByteWritable((byte)0), new ShortWritable((short)1234), new LongWritable(123500L), new HiveDecimalWritable(HiveDecimal.create("983.724")));
    // Row 2: null int/double inputs are expected to stay null.
    db.addRow(new DoubleWritable(-200), null, null, new FloatWritable(0f),
        new ByteWritable((byte)100), new ShortWritable((short)551), new LongWritable(900L), new HiveDecimalWritable(HiveDecimal.create("983723.005")));
    db.addRow(new DoubleWritable(500), new IntWritable(22345),  new DoubleWritable(-23.000), new FloatWritable(-3f),
        new ByteWritable((byte)100), new ShortWritable((short)2321), new LongWritable(9200L), new HiveDecimalWritable(HiveDecimal.create("-932032.7")));
    return db.createRows();
  }

  /**
   * Builds one {@code round(column, scale)} expression per input column. The scale
   * constants deliberately use different integral types (int, byte, short, long) and
   * both negative and non-negative values.
   *
   * @return the expressions, in the same order as {@link #cols}
   * @throws UDFArgumentException declared because building the function expression may throw it
   */
  @Override
  public List<ExprNodeDesc> getExpressionList() throws UDFArgumentException {
    // scales[i] is the second argument of round() for cols[i].
    ExprNodeDesc[] scales = { new ExprNodeConstantDesc(TypeInfoFactory.intTypeInfo, -2),
        new ExprNodeConstantDesc(TypeInfoFactory.byteTypeInfo, (byte)0),
        new ExprNodeConstantDesc(TypeInfoFactory.shortTypeInfo, (short)3),
        new ExprNodeConstantDesc(TypeInfoFactory.intTypeInfo, 0),
        new ExprNodeConstantDesc(TypeInfoFactory.longTypeInfo, -2L),
        new ExprNodeConstantDesc(TypeInfoFactory.intTypeInfo, 0),
        new ExprNodeConstantDesc(TypeInfoFactory.intTypeInfo, -2),
        new ExprNodeConstantDesc(TypeInfoFactory.intTypeInfo, 3) };

    List<ExprNodeDesc> earr = new ArrayList<ExprNodeDesc>(cols.length);
    for (int j = 0; j < cols.length; j++) {
      ExprNodeDesc column = OperatorTestUtils.getStringColumn(cols[j]);
      earr.add(TypeCheckProcFactory.DefaultExprProcessor.getFuncExprNodeDesc("round", column, scales[j]));
    }

    return earr;
  }

  /** Rounding decimal(7,3) to 2 places is expected to report decimal(7,2). */
  @Test
  public void testDecimalRoundingMetaData() throws UDFArgumentException {
    assertDecimalRoundingType(2, 7, 2);
  }

  /** Rounding decimal(7,3) to -2 places is expected to report decimal(5,0). */
  @Test
  public void testDecimalRoundingMetaData1() throws UDFArgumentException {
    assertDecimalRoundingType(-2, 5, 0);
  }

  /** Rounding decimal(7,3) to 5 places is expected to report decimal(9,5). */
  @Test
  public void testDecimalRoundingMetaData2() throws UDFArgumentException {
    assertDecimalRoundingType(5, 9, 5);
  }

  /**
   * Initializes a fresh {@link GenericUDFRound} with a decimal(7,3) input and a constant
   * int scale, then asserts the decimal type reported by the returned object inspector.
   *
   * @param roundScale the constant scale passed as the second argument of {@code round}
   * @param expectedPrecision precision the output type is expected to have
   * @param expectedScale scale the output type is expected to have
   * @throws UDFArgumentException if the UDF rejects the argument inspectors
   */
  private static void assertDecimalRoundingType(int roundScale, int expectedPrecision, int expectedScale)
      throws UDFArgumentException {
    GenericUDFRound udf = new GenericUDFRound();
    ObjectInspector[] inputOIs = {
        PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(TypeInfoFactory.getDecimalTypeInfo(7, 3)),
        PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(TypeInfoFactory.intTypeInfo, new IntWritable(roundScale))
    };
    PrimitiveObjectInspector outputOI = (PrimitiveObjectInspector) udf.initialize(inputOIs);
    DecimalTypeInfo outputTypeInfo = (DecimalTypeInfo) outputOI.getTypeInfo();
    Assert.assertEquals(TypeInfoFactory.getDecimalTypeInfo(expectedPrecision, expectedScale), outputTypeInfo);
  }

}
