/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.rest;

import com.google.common.base.Preconditions;
import io.confluent.rest.Application;
import io.confluent.rest.RestConfig;
import io.confluent.rest.TestMetricsReporter;
import io.confluent.rest.TestRestConfig;
import java.net.URI;
import java.time.Duration;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.core.Configurable;
import javax.ws.rs.core.Configuration;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.apache.kafka.common.metrics.KafkaMetric;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.servlets.DoSFilter;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.TestInfo;

/**
 * Integration test that starts a real server with the DoS filter enabled and two
 * {@link DoSFilter.Listener}s registered (one "non-global", one "global"), fires a burst of
 * HTTP requests at it, and checks which listener was notified about rejected requests as well
 * as the values of the HTTP 429 Jetty error metrics.
 *
 * <p>{@link #setUp(TestInfo)} picks the limit configuration from the test's display name: for
 * the test named {@link #GLOBAL_REJECTING_DISPLAY_NAME} the {@code max.requests.per.sec} limit
 * is the low one, for every other test the {@code max.requests.per.connection.per.sec} limit
 * is the low one.
 */
@Tag("IntegrationTest")
class JettyDosFilterMultiListenerIntegrationTest {
    /** Low (binding) value used for {@code dos.filter.max.requests.per.connection.per.sec}. */
    private static final int DOS_FILTER_MAX_REQUESTS_PER_CONNECTION_PER_SEC = 25;
    /** Low (binding) value used for {@code dos.filter.max.requests.per.sec}. */
    private static final int DOS_FILTER_MAX_REQUESTS_PER_SEC = 25;
    /** Value given to whichever of the two limits is not under test in the current scenario. */
    private static final int NON_BINDING_MAX_REQUESTS_PER_SEC = 100;

    /**
     * Display name of the test in which the {@code max.requests.per.sec} limit is the binding
     * one. Shared by the {@code @DisplayName} annotation and {@link #setUp(TestInfo)} so that
     * the two cannot drift apart.
     */
    private static final String GLOBAL_REJECTING_DISPLAY_NAME =
        "test_dosFilterMultiListener_withGlobalDosFilterRejecting_CheckRelevantListenersCalled";

    /** Resource path served by {@link PublicResource#hello()}. */
    private static final String HELLO_PATH = "/public/hello";
    /** Delay between the scheduled start times of two consecutive requests. */
    private static final Duration REQUEST_INTERVAL = Duration.ofMillis(1L);

    private static final String JETTY_METRICS_GROUP = "jetty-metrics";
    private static final String STATUS_CODE_TAG = "http_status_code";
    private static final String TOO_MANY_REQUESTS = "429";

    private ScheduledExecutorService executor;
    private Server server;
    private Client client;
    private final TestDosFilterListener nonGlobalDosFilterListener = new TestDosFilterListener();
    private final TestDosFilterListener globalDosFilterListener = new TestDosFilterListener();

    /**
     * Builds the configuration for the current test, starts the server and creates the HTTP
     * client and the executor used to schedule requests.
     *
     * @param testInfo supplies the display name used to choose which limit is the binding one
     * @throws Exception if the server cannot be created or started
     */
    @BeforeEach
    public void setUp(TestInfo testInfo) throws Exception {
        TestMetricsReporter.reset();
        this.nonGlobalDosFilterListener.rejectedCounter.set(0);
        this.globalDosFilterListener.rejectedCounter.set(0);

        Properties props = new Properties();
        props.setProperty("debug", "false");
        props.put("metric.reporters", "io.confluent.rest.TestMetricsReporter");
        props.put("dos.filter.enabled", true);
        // A negative delay is documented by Jetty's DoSFilter as "reject the request"; the
        // listeners below only count requests whose action is REJECT.
        props.put("dos.filter.delay.ms", -1L);

        boolean globalLimitIsBinding =
            testInfo.getDisplayName().contains(GLOBAL_REJECTING_DISPLAY_NAME);
        if (globalLimitIsBinding) {
            props.put("dos.filter.max.requests.per.connection.per.sec",
                NON_BINDING_MAX_REQUESTS_PER_SEC);
            props.put("dos.filter.max.requests.per.sec", DOS_FILTER_MAX_REQUESTS_PER_SEC);
        } else {
            props.put("dos.filter.max.requests.per.connection.per.sec",
                DOS_FILTER_MAX_REQUESTS_PER_CONNECTION_PER_SEC);
            props.put("dos.filter.max.requests.per.sec", NON_BINDING_MAX_REQUESTS_PER_SEC);
        }

        TestRestConfig config = TestRestConfig.maprCompatible(props);
        ApplicationWithDoSFilterEnabled app = new ApplicationWithDoSFilterEnabled(config);
        app.addNonGlobalDosfilterListener(this.nonGlobalDosFilterListener);
        app.addGlobalDosfilterListener(this.globalDosFilterListener);

        // The server is created exactly once; a second, discarded createServer() call served
        // no purpose.
        this.server = app.createServer();
        this.server.start();
        this.executor = Executors.newScheduledThreadPool(4);
        this.client = ClientBuilder.newClient(app.resourceConfig.getConfiguration());
    }

    /**
     * Stops the server, closes the client and shuts the executor down.
     *
     * <p>Each resource is released even if releasing an earlier one throws, and resources that
     * were never created (because {@link #setUp(TestInfo)} failed partway) are skipped so that
     * the original failure is not masked by a {@link NullPointerException}.
     *
     * @throws Exception if stopping or joining the server fails
     */
    @AfterEach
    public void tearDown() throws Exception {
        try {
            if (this.server != null) {
                this.server.stop();
                this.server.join();
            }
        } finally {
            try {
                if (this.client != null) {
                    this.client.close();
                }
            } finally {
                if (this.executor != null) {
                    this.awaitTerminationAfterShutdown(this.executor);
                }
            }
        }
    }

    /**
     * Sends fewer requests than either limit allows and expects every counted request to
     * succeed, all 429 metrics to be zero and no listener to be notified of a rejection.
     */
    @RepeatedTest(value = 5, name = "{displayName} :: repetition {currentRepetition} of {totalRepetitions}")
    @DisplayName("test_dosFilterMultiListener_noRequestRejected_CheckNoListenerCalled")
    public void test_dosFilterMultiListener_noRequestRejected_CheckNoListenerCalled() {
        int warmupRequests = 10;
        int totalRequests = 20;

        int response200s = this.hammerAtConstantRate(
            this.server.getURI(), HELLO_PATH, REQUEST_INTERVAL, warmupRequests, totalRequests);

        // Every request after the warm-up must have been answered with HTTP 200.
        Assertions.assertEquals(totalRequests - warmupRequests, response200s);
        assertTooManyRequestsMetrics(0.0, 0.0, 0.0, false);
        Assertions.assertEquals(0, this.nonGlobalDosFilterListener.rejectedCounter.get());
        Assertions.assertEquals(0, this.globalDosFilterListener.rejectedCounter.get());
    }

    /**
     * Exceeds the per-connection limit and expects only the non-global listener to be notified
     * of the rejected requests.
     */
    @RepeatedTest(value = 5, name = "{displayName} :: repetition {currentRepetition} of {totalRepetitions}")
    @DisplayName("test_dosFilterMultiListener_withNonGlobalDosFilterRejecting_CheckRelevantListenersCalled")
    public void test_dosFilterMultiListener_withNonGlobalDosFilterRejecting_CheckRelevantListenersCalled() {
        int warmupRequests = 20;
        int totalRequests = 100;
        int acceptedRequests = DOS_FILTER_MAX_REQUESTS_PER_CONNECTION_PER_SEC;
        int expectedRejections = totalRequests - acceptedRequests;

        int response200s = this.hammerAtConstantRate(
            this.server.getURI(), HELLO_PATH, REQUEST_INTERVAL, warmupRequests, totalRequests);

        // Requests up to the limit are accepted; the warm-up ones among them are not counted.
        Assertions.assertEquals(acceptedRequests - warmupRequests, response200s);
        // NOTE(review): 2.5 presumably is the 75 rejections averaged over a 30 s rate window -
        // confirm against the metrics configuration. Only the whole part is compared.
        assertTooManyRequestsMetrics(expectedRejections, expectedRejections, Math.floor(2.5), true);
        Assertions.assertEquals(expectedRejections, this.nonGlobalDosFilterListener.rejectedCounter.get());
        Assertions.assertEquals(0, this.globalDosFilterListener.rejectedCounter.get());
    }

    /**
     * Exceeds the {@code max.requests.per.sec} limit and expects only the global listener to be
     * notified of the rejected requests.
     */
    @RepeatedTest(value = 5, name = "{displayName} :: repetition {currentRepetition} of {totalRepetitions}")
    @DisplayName(GLOBAL_REJECTING_DISPLAY_NAME)
    public void test_dosFilterMultiListener_withGlobalDosFilterRejecting_CheckRelevantListenersCalled() {
        int warmupRequests = 20;
        int totalRequests = 100;
        int acceptedRequests = DOS_FILTER_MAX_REQUESTS_PER_SEC;
        int expectedRejections = totalRequests - acceptedRequests;

        int response200s = this.hammerAtConstantRate(
            this.server.getURI(), HELLO_PATH, REQUEST_INTERVAL, warmupRequests, totalRequests);

        // Requests up to the limit are accepted; the warm-up ones among them are not counted.
        Assertions.assertEquals(acceptedRequests - warmupRequests, response200s);
        // NOTE(review): 2.5 presumably is the 75 rejections averaged over a 30 s rate window -
        // confirm against the metrics configuration. Only the whole part is compared.
        assertTooManyRequestsMetrics(expectedRejections, expectedRejections, Math.floor(2.5), true);
        Assertions.assertEquals(expectedRejections, this.globalDosFilterListener.rejectedCounter.get());
        Assertions.assertEquals(0, this.nonGlobalDosFilterListener.rejectedCounter.get());
    }

    /**
     * Verifies every reported HTTP 429 Jetty error metric ({@code request-error-count},
     * {@code request-error-total}, {@code request-error-rate}).
     *
     * <p>For each matching metric the measurable's type (by the prefix of its lower-cased
     * {@code toString()}) and the {@code Double} type of its value are asserted before the
     * value itself is compared. If no matching metric is reported, nothing is asserted.
     *
     * @param expectedCount expected value of {@code request-error-count}
     * @param expectedTotal expected value of {@code request-error-total}
     * @param expectedRate expected value of {@code request-error-rate}
     * @param floorRate if {@code true} the measured rate is rounded down with
     *     {@link Math#floor(double)} before being compared to {@code expectedRate}; if
     *     {@code false} the rate is compared exactly
     */
    private static void assertTooManyRequestsMetrics(
            double expectedCount, double expectedTotal, double expectedRate, boolean floorRate) {
        for (KafkaMetric metric : TestMetricsReporter.getMetricTimeseries()) {
            if (isTooManyRequestsMetric(metric, "request-error-count")) {
                double errorCountValue = measuredValue(
                    metric, "sampledstat", "Error count metrics should be measurable");
                Assertions.assertEquals(expectedCount, errorCountValue, "Actual: " + errorCountValue);
            }
            if (isTooManyRequestsMetric(metric, "request-error-total")) {
                double errorTotalValue = measuredValue(
                    metric, "cumulativesum", "Error total metrics should be measurable");
                Assertions.assertEquals(expectedTotal, errorTotalValue, "Actual: " + errorTotalValue);
            }
            if (isTooManyRequestsMetric(metric, "request-error-rate")) {
                double errorRateValue = measuredValue(
                    metric, "rate", "Error rate metrics should be measurable");
                double comparedRate = floorRate ? Math.floor(errorRateValue) : errorRateValue;
                // The message reports the unrounded rate to ease debugging.
                Assertions.assertEquals(expectedRate, comparedRate, "Actual: " + errorRateValue);
            }
        }
    }

    /**
     * Tells whether {@code metric} has the given name, belongs to the {@code jetty-metrics}
     * group and is tagged with HTTP status code 429.
     */
    private static boolean isTooManyRequestsMetric(KafkaMetric metric, String name) {
        return metric.metricName().name().equals(name)
            && metric.metricName().group().equals(JETTY_METRICS_GROUP)
            && metric.metricName().tags().getOrDefault(STATUS_CODE_TAG, "").equals(TOO_MANY_REQUESTS);
    }

    /**
     * Asserts that the metric's measurable has the expected type prefix and that its value is
     * a {@code Double}, then returns that value.
     *
     * @param metric the metric to read
     * @param measurablePrefix expected prefix of the measurable's lower-cased {@code toString()}
     * @param notDoubleMessage assertion message used when the value is not a {@code Double}
     * @return the metric's current value
     */
    private static double measuredValue(
            KafkaMetric metric, String measurablePrefix, String notDoubleMessage) {
        Assertions.assertTrue(
            metric.measurable().toString().toLowerCase().startsWith(measurablePrefix));
        Object metricValue = metric.metricValue();
        Assertions.assertTrue(metricValue instanceof Double, notDoubleMessage);
        return (Double) metricValue;
    }

    /**
     * Schedules {@code totalRequests} GET requests, request {@code i} being scheduled
     * {@code i * rate} after this call, waits for all of them and counts the successful ones.
     *
     * <p>The test fails if any response has a status other than 200 or 429.
     *
     * @param server base URI of the server under test
     * @param path resource path to request
     * @param rate delay between the scheduled start times of consecutive requests; must not be
     *     negative
     * @param warmupRequests number of leading responses excluded from the returned count; must
     *     be at most {@code totalRequests}
     * @param totalRequests total number of requests to send
     * @return the number of HTTP 200 responses among the requests after the warm-up
     * @throws IllegalArgumentException if {@code rate} is negative or {@code warmupRequests}
     *     exceeds {@code totalRequests}
     */
    private int hammerAtConstantRate(URI server, String path, Duration rate, int warmupRequests, int totalRequests) {
        Preconditions.checkArgument(!rate.isNegative(), "rate must be non-negative");
        Preconditions.checkArgument(warmupRequests <= totalRequests, "warmupRequests must be at most totalRequests");

        List<Response> responses = IntStream.range(0, totalRequests)
            .mapToObj(i -> this.executor.schedule(
                () -> this.client.target(server).path(path).request(MediaType.APPLICATION_JSON_TYPE).get(),
                i * rate.toMillis(),
                TimeUnit.MILLISECONDS))
            // Materialize the futures first so that every request is scheduled before this
            // thread starts blocking on the results.
            .collect(Collectors.toList())
            .stream()
            .map(future -> {
                try {
                    return future.get();
                } catch (InterruptedException e) {
                    // Restore the interrupt flag before surfacing the failure.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                } catch (ExecutionException e) {
                    throw new RuntimeException(e);
                }
            })
            .collect(Collectors.toList());

        for (Response response : responses) {
            int status = response.getStatus();
            if (status == 200 || status == 429) {
                continue;
            }
            Assertions.fail(String.format("Expected HTTP 200 or HTTP 429, but got HTTP %d instead: %s", status, response.readEntity(String.class)));
        }

        return (int) responses.subList(warmupRequests, responses.size()).stream()
            .filter(response -> response.getStatus() == Response.Status.OK.getStatusCode())
            .count();
    }

    /**
     * Shuts the pool down and waits up to 60 seconds for running tasks to finish, forcing
     * termination on timeout or if the waiting thread is interrupted (in which case the
     * interrupt flag is restored).
     *
     * @param threadPool the executor to shut down
     */
    private void awaitTerminationAfterShutdown(ExecutorService threadPool) {
        threadPool.shutdown();
        try {
            if (!threadPool.awaitTermination(60L, TimeUnit.SECONDS)) {
                threadPool.shutdownNow();
            }
        } catch (InterruptedException ex) {
            threadPool.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    /**
     * DoS filter listener that counts how many over-limit requests resulted in a
     * {@link DoSFilter.Action#REJECT} action.
     */
    static class TestDosFilterListener
    extends DoSFilter.Listener {
        /** Number of over-limit notifications for which the resulting action was REJECT. */
        final AtomicInteger rejectedCounter = new AtomicInteger(0);

        /**
         * Derives the action from the filter's configured delay, counts it if it is a
         * rejection and returns it.
         */
        public DoSFilter.Action onRequestOverLimit(HttpServletRequest request, DoSFilter.OverLimit overlimit, DoSFilter dosFilter) {
            DoSFilter.Action action = DoSFilter.Action.fromDelay(dosFilter.getDelayMs());
            if (action == DoSFilter.Action.REJECT) {
                this.rejectedCounter.incrementAndGet();
            }
            return action;
        }
    }

    /** Minimal JSON resource exposing {@code GET /public/hello}. */
    @Produces(MediaType.APPLICATION_JSON)
    @Path("/public/")
    public static class PublicResource {
        /** Returns the constant string {@code "hello"}. */
        @GET
        @Path("/hello")
        public String hello() {
            return "hello";
        }
    }

    /**
     * Application under test; registers {@link PublicResource} and remembers the resource
     * configuration so that the test can build its HTTP client from it.
     */
    private static class ApplicationWithDoSFilterEnabled
    extends Application<TestRestConfig> {
        /** Configuration object received by the most recent {@code setupResources} call. */
        Configurable<?> resourceConfig;

        ApplicationWithDoSFilterEnabled(TestRestConfig props) {
            super(props);
        }

        /**
         * Stores {@code config} in {@link #resourceConfig}, registers {@link PublicResource}
         * and sets {@code jersey.config.server.response.setStatusOverSendError} to true.
         */
        public void setupResources(Configurable<?> config, TestRestConfig appConfig) {
            this.resourceConfig = config;
            config.register(PublicResource.class);
            config.property("jersey.config.server.response.setStatusOverSendError", true);
        }
    }
}

