/*
 * Decompiled with CFR 0.152.
 */
package com.nvidia.spark.rapids.iceberg.spark.source;

import com.nvidia.spark.rapids.RapidsConf;
import com.nvidia.spark.rapids.iceberg.spark.Spark3Util;
import com.nvidia.spark.rapids.iceberg.spark.SparkFilters;
import com.nvidia.spark.rapids.iceberg.spark.SparkReadConf;
import com.nvidia.spark.rapids.iceberg.spark.SparkSchemaUtil;
import com.nvidia.spark.rapids.iceberg.spark.source.GpuSparkScan;
import com.nvidia.spark.rapids.shims.ShimSupportsRuntimeFiltering;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.iceberg.CombinedScanTask;
import org.apache.iceberg.FileScanTask;
import org.apache.iceberg.PartitionField;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.Table;
import org.apache.iceberg.TableScan;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.expressions.Binder;
import org.apache.iceberg.expressions.Evaluator;
import org.apache.iceberg.expressions.Expression;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.expressions.Projections;
import org.apache.iceberg.expressions.True;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.relocated.com.google.common.collect.Sets;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.SnapshotUtil;
import org.apache.iceberg.util.TableScanUtil;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.read.Scan;
import org.apache.spark.sql.connector.read.Statistics;
import org.apache.spark.sql.sources.Filter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * GPU version of the Iceberg Spark batch query scan.
 *
 * <p>Instances are normally created from the CPU {@link Scan} via {@link #fromCpu}, which copies the
 * CPU scan's private state through reflection. Planned files ({@code files()}), the partition spec ids
 * of those files ({@code specIds()}) and the combined split tasks ({@link #tasks()}) are computed
 * lazily and cached. {@link #filter(Filter[])} can later prune the cached files with runtime filters.
 */
public class GpuSparkBatchQueryScan
extends GpuSparkScan
implements ShimSupportsRuntimeFiltering {
    private static final Logger LOG = LoggerFactory.getLogger(GpuSparkBatchQueryScan.class);

    /** Iceberg scan used to plan files; {@code null} means this scan yields no files or tasks. */
    private final TableScan scan;
    private final Long snapshotId;
    private final Long startSnapshotId;
    private final Long endSnapshotId;
    private final Long asOfTimestamp;
    /** Runtime filter expressions accepted by {@link #filter(Filter[])}; part of equals/hashCode. */
    private final List<Expression> runtimeFilterExpressions;

    // Lazily computed caches. filter(...) replaces `files` and resets the other two when it prunes files.
    private Set<Integer> specIds = null;
    private List<FileScanTask> files = null;
    private List<CombinedScanTask> tasks = null;

    /**
     * Checks whether a CPU Iceberg scan should be treated as a metadata scan.
     *
     * <p>Reads the CPU scan's private {@code tasks} field reflectively. The scan counts as a metadata
     * scan when that list is {@code null} or empty, when its first combined task contains no file
     * tasks, or when the first file task of the first combined task reports
     * {@link FileScanTask#isDataTask()}. Only the first task is inspected.
     *
     * @param cpuInstance the CPU scan to inspect
     * @return {@code true} if the scan should be treated as a metadata scan
     * @throws IllegalAccessException if the {@code tasks} field cannot be read
     */
    public static boolean isMetadataScan(Scan cpuInstance) throws IllegalAccessException {
        List<?> cpuTasks = (List<?>) FieldUtils.readField(cpuInstance, "tasks", true);
        if (cpuTasks == null || cpuTasks.isEmpty()) {
            return true;
        }
        Iterator<FileScanTask> fileIter = ((CombinedScanTask) cpuTasks.get(0)).files().iterator();
        return !fileIter.hasNext() || fileIter.next().isDataTask();
    }

    /**
     * Builds a GPU scan from a CPU Iceberg scan by reflectively copying its private state.
     *
     * <p>Reads the {@code table}, {@code readConf}, {@code expectedSchema} and
     * {@code filterExpressions} fields. It also reads the CPU scan's {@code scan} field. If that field
     * cannot be located, an equivalent {@link TableScan} is rebuilt from the CPU scan's other fields.
     *
     * @param cpuInstance the CPU scan to convert
     * @param rapidsConf RAPIDS configuration forwarded to the GPU scan
     * @return the equivalent GPU scan, bound to the active {@link SparkSession}
     * @throws IllegalAccessException if a required field cannot be read
     */
    public static GpuSparkBatchQueryScan fromCpu(Scan cpuInstance, RapidsConf rapidsConf) throws IllegalAccessException {
        Table table = (Table) FieldUtils.readField(cpuInstance, "table", true);
        SparkReadConf readConf = SparkReadConf.fromReflect(FieldUtils.readField(cpuInstance, "readConf", true));
        Schema expectedSchema = (Schema) FieldUtils.readField(cpuInstance, "expectedSchema", true);
        // Reflection loses the element type; the CPU field holds Iceberg filter expressions.
        @SuppressWarnings("unchecked")
        List<Expression> filters = (List<Expression>) FieldUtils.readField(cpuInstance, "filterExpressions", true);
        TableScan scan;
        try {
            scan = (TableScan) FieldUtils.readField(cpuInstance, "scan", true);
        } catch (IllegalArgumentException ignored) {
            // FieldUtils.readField throws IllegalArgumentException when the field cannot be located.
            // NOTE(review): presumably some Iceberg versions' CPU scan has no "scan" field, so rebuild it.
            scan = buildScan(cpuInstance, table, readConf, expectedSchema, filters);
        }
        return new GpuSparkBatchQueryScan(SparkSession.active(), table, scan, readConf, expectedSchema, filters, rapidsConf);
    }

    /**
     * Rebuilds the Iceberg {@link TableScan} from the CPU scan's individual settings.
     *
     * <p>Applies, when present: snapshot id, as-of timestamp, incremental append range,
     * split size, split lookback, split open-file cost, and every filter expression.
     */
    private static TableScan buildScan(Scan cpuInstance, Table table, SparkReadConf readConf, Schema expectedSchema, List<Expression> filterExpressions) throws IllegalAccessException {
        Long snapshotId = (Long) FieldUtils.readField(cpuInstance, "snapshotId", true);
        Long startSnapshotId = (Long) FieldUtils.readField(cpuInstance, "startSnapshotId", true);
        Long endSnapshotId = (Long) FieldUtils.readField(cpuInstance, "endSnapshotId", true);
        Long asOfTimestamp = (Long) FieldUtils.readField(cpuInstance, "asOfTimestamp", true);
        Long splitSize = (Long) FieldUtils.readField(cpuInstance, "splitSize", true);
        Integer splitLookback = (Integer) FieldUtils.readField(cpuInstance, "splitLookback", true);
        Long splitOpenFileCost = (Long) FieldUtils.readField(cpuInstance, "splitOpenFileCost", true);

        TableScan scan = table.newScan().caseSensitive(readConf.caseSensitive()).project(expectedSchema);
        if (snapshotId != null) {
            scan = scan.useSnapshot(snapshotId);
        }
        if (asOfTimestamp != null) {
            scan = scan.asOfTime(asOfTimestamp);
        }
        if (startSnapshotId != null) {
            scan = endSnapshotId != null
                    ? scan.appendsBetween(startSnapshotId, endSnapshotId)
                    : scan.appendsAfter(startSnapshotId);
        }
        if (splitSize != null) {
            scan = scan.option("read.split.target-size", splitSize.toString());
        }
        if (splitLookback != null) {
            scan = scan.option("read.split.planning-lookback", splitLookback.toString());
        }
        if (splitOpenFileCost != null) {
            scan = scan.option("read.split.open-file-cost", splitOpenFileCost.toString());
        }
        for (Expression filter : filterExpressions) {
            scan = scan.filter(filter);
        }
        return scan;
    }

    /**
     * Creates the scan.
     *
     * <p>Snapshot selection values are taken from {@code readConf}. When {@code scan} is {@code null},
     * the caches are pre-populated with empty collections, so no planning ever happens.
     */
    GpuSparkBatchQueryScan(SparkSession spark, Table table, TableScan scan, SparkReadConf readConf, Schema expectedSchema, List<Expression> filters, RapidsConf rapidsConf) {
        super(spark, table, readConf, expectedSchema, filters, rapidsConf);
        this.scan = scan;
        this.snapshotId = readConf.snapshotId();
        this.startSnapshotId = readConf.startSnapshotId();
        this.endSnapshotId = readConf.endSnapshotId();
        this.asOfTimestamp = readConf.asOfTimestamp();
        this.runtimeFilterExpressions = Lists.newArrayList();
        if (scan == null) {
            this.specIds = Collections.emptySet();
            this.files = Collections.emptyList();
            this.tasks = Collections.emptyList();
        }
    }

    /** Returns the snapshot id from the read configuration, or {@code null} if none was set. */
    Long snapshotId() {
        return this.snapshotId;
    }

    /** Returns (and caches) the distinct partition spec ids of the currently planned files. */
    private Set<Integer> specIds() {
        if (this.specIds == null) {
            Set<Integer> ids = Sets.newHashSet();
            for (FileScanTask file : this.files()) {
                ids.add(file.spec().specId());
            }
            this.specIds = ids;
        }
        return this.specIds;
    }

    /**
     * Returns (and caches) the planned file tasks.
     *
     * @throws UncheckedIOException if closing the planning iterable fails
     */
    private List<FileScanTask> files() {
        if (this.files == null) {
            try (CloseableIterable<FileScanTask> filesIterable = this.scan.planFiles()) {
                this.files = Lists.newArrayList(filesIterable);
            } catch (IOException e) {
                throw new UncheckedIOException("Failed to close table scan: " + this.scan, e);
            }
        }
        return this.files;
    }

    /**
     * Returns (and caches) the combined tasks.
     *
     * <p>The cached files are split and packed into tasks using the scan's target split size,
     * split lookback and open-file cost.
     */
    @Override
    protected List<CombinedScanTask> tasks() {
        if (this.tasks == null) {
            long targetSplitSize = this.scan.targetSplitSize();
            CloseableIterable<FileScanTask> splitFiles =
                    TableScanUtil.splitFiles(CloseableIterable.withNoopClose(this.files()), targetSplitSize);
            CloseableIterable<CombinedScanTask> scanTasks = TableScanUtil.planTasks(
                    splitFiles, targetSplitSize, this.scan.splitLookback(), this.scan.splitOpenFileCost());
            this.tasks = Lists.newArrayList(scanTasks);
        }
        return this.tasks;
    }

    /**
     * Returns references to the columns usable for runtime filtering.
     *
     * <p>These are the source columns of every partition field, across all specs of the planned
     * files, that are also present in the expected schema.
     */
    public NamedReference[] filterAttributes() {
        Set<Integer> partitionFieldSourceIds = Sets.newHashSet();
        for (Integer specId : this.specIds()) {
            PartitionSpec spec = this.table().specs().get(specId);
            for (PartitionField field : spec.fields()) {
                partitionFieldSourceIds.add(field.sourceId());
            }
        }
        Map<Integer, String> quotedNameById = SparkSchemaUtil.indexQuotedNameById(this.expectedSchema());
        return partitionFieldSourceIds.stream()
                .filter(fieldId -> this.expectedSchema().findField(fieldId.intValue()) != null)
                .map(fieldId -> Spark3Util.toNamedReference(quotedNameById.get(fieldId)))
                .toArray(NamedReference[]::new);
    }

    /**
     * Applies Spark runtime filters by pruning planned files.
     *
     * <p>Each file is kept only if its partition satisfies the inclusive projection of the combined
     * runtime filter for its partition spec. When files are pruned, the cached file list is replaced
     * and the spec-id and task caches are reset. Every usable runtime filter is recorded, even when
     * it prunes nothing.
     *
     * @param filters Spark filters supplied at runtime
     */
    public void filter(Filter[] filters) {
        Expression runtimeFilterExpr = this.convertRuntimeFilters(filters);
        // Identity check: convertRuntimeFilters returns the alwaysTrue() value untouched when no
        // filter is usable (assumes alwaysTrue() returns a shared instance).
        if (runtimeFilterExpr == Expressions.alwaysTrue()) {
            return;
        }
        Map<Integer, Evaluator> evaluatorsBySpecId = Maps.newHashMap();
        for (Integer specId : this.specIds()) {
            PartitionSpec spec = this.table().specs().get(specId);
            Expression inclusiveExpr = Projections.inclusive(spec, this.caseSensitive()).project(runtimeFilterExpr);
            evaluatorsBySpecId.put(specId, new Evaluator(spec.partitionType(), inclusiveExpr));
        }
        List<FileScanTask> allFiles = this.files();
        LOG.info("Trying to filter {} files using runtime filter {}", allFiles.size(), runtimeFilterExpr);
        List<FileScanTask> filteredFiles = allFiles.stream()
                .filter(file -> evaluatorsBySpecId.get(file.spec().specId()).eval(file.file().partition()))
                .collect(Collectors.toList());
        LOG.info("{}/{} files matched runtime filter {}", new Object[]{filteredFiles.size(), allFiles.size(), runtimeFilterExpr});
        if (filteredFiles.size() < allFiles.size()) {
            this.specIds = null;
            this.files = filteredFiles;
            this.tasks = null;
        }
        this.runtimeFilterExpressions.add(runtimeFilterExpr);
    }

    /**
     * Converts Spark filters to one Iceberg expression (AND of all usable filters).
     *
     * <p>Filters that cannot be converted, or that fail to bind to the expected schema, are logged
     * and skipped.
     *
     * @return the conjunction, or {@link Expressions#alwaysTrue()} if no filter was usable
     */
    private Expression convertRuntimeFilters(Filter[] filters) {
        Expression runtimeFilterExpr = Expressions.alwaysTrue();
        for (Filter filter : filters) {
            Expression expr = SparkFilters.convert(filter);
            if (expr == null) {
                LOG.warn("Unsupported runtime filter {}", filter);
                continue;
            }
            try {
                // Binding validates the expression against the schema; the bound result is not needed.
                Binder.bind(this.expectedSchema().asStruct(), expr, this.caseSensitive());
                runtimeFilterExpr = Expressions.and(runtimeFilterExpr, expr);
            } catch (ValidationException e) {
                LOG.warn("Failed to bind {} to expected schema, skipping runtime filter", expr, e);
            }
        }
        return runtimeFilterExpr;
    }

    /**
     * Estimates statistics from a snapshot.
     *
     * <p>The snapshot is chosen in this order: the configured snapshot id, the snapshot as of the
     * configured timestamp, otherwise the table's current snapshot. A {@code null} scan passes a
     * {@code null} snapshot.
     */
    @Override
    public Statistics estimateStatistics() {
        if (this.scan == null) {
            return this.estimateStatistics(null);
        }
        Snapshot snapshot;
        if (this.snapshotId != null) {
            snapshot = this.table().snapshot(this.snapshotId);
        } else if (this.asOfTimestamp != null) {
            long snapshotIdAsOfTime = SnapshotUtil.snapshotIdAsOfTime(this.table(), this.asOfTimestamp);
            snapshot = this.table().snapshot(snapshotIdAsOfTime);
        } else {
            snapshot = this.table().currentSnapshot();
        }
        return this.estimateStatistics(snapshot);
    }

    /**
     * Compares the following fields:
     * <ul>
     *   <li>table name</li>
     *   <li>Spark read schema</li>
     *   <li>string forms of the filters and runtime filters</li>
     *   <li>snapshot selection fields</li>
     * </ul>
     * Note that {@link #filter(Filter[])} mutates the runtime filters.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        GpuSparkBatchQueryScan that = (GpuSparkBatchQueryScan) o;
        return this.table().name().equals(that.table().name())
                && this.readSchema().equals(that.readSchema())
                && this.filterExpressions().toString().equals(that.filterExpressions().toString())
                && this.runtimeFilterExpressions.toString().equals(that.runtimeFilterExpressions.toString())
                && Objects.equals(this.snapshotId, that.snapshotId)
                && Objects.equals(this.startSnapshotId, that.startSnapshotId)
                && Objects.equals(this.endSnapshotId, that.endSnapshotId)
                && Objects.equals(this.asOfTimestamp, that.asOfTimestamp);
    }

    /** Hash over the same fields compared by {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return Objects.hash(this.table().name(), this.readSchema(), this.filterExpressions().toString(), this.runtimeFilterExpressions.toString(), this.snapshotId, this.startSnapshotId, this.endSnapshotId, this.asOfTimestamp);
    }

    @Override
    public String toString() {
        return String.format("IcebergScan(table=%s, type=%s, filters=%s, runtimeFilters=%s, caseSensitive=%s)", this.table(), this.expectedSchema().asStruct(), this.filterExpressions(), this.runtimeFilterExpressions, this.caseSensitive());
    }
}

