package io.confluent.rest;

import io.confluent.rest.annotations.PerformanceMetric;
import io.confluent.rest.mapr.test.MaprHomeSupport;
import io.confluent.rest.mapr.test.MaprTestLoginModule;
import io.confluent.rest.mapr.test.MaprTestLoginRule;
import java.util.Collections;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicReference;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.websocket.DeploymentException;
import javax.websocket.EndpointConfig;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.core.Configurable;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.SecurityContext;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.apache.kafka.common.metrics.KafkaMetric;
import org.asynchttpclient.AsyncHttpClient;
import org.asynchttpclient.BoundRequestBuilder;
import org.asynchttpclient.Dsl;
import org.asynchttpclient.ws.WebSocket;
import org.asynchttpclient.ws.WebSocketListener;
import org.asynchttpclient.ws.WebSocketUpgradeHandler;
import org.eclipse.jetty.websocket.jsr356.server.ServerContainer;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/confluent/rest/SaslTest.class */
public class SaslTest {
    private static final Logger log;
    private static final String NEHA_BASIC_AUTH = "bmVoYTpha2Zhaw==";
    private static final String JUN_BASIC_AUTH = "anVuOmthZmthLQ==";
    private static final String HTTP_URI = "http://localhost:8080";
    private static final String WS_URI = "ws://localhost:8080/ws";
    private static final Pattern WS_ERROR_PATTERN;
    private SaslTestApplication app;
    private CloseableHttpClient httpclient;

    @Rule
    public final TemporaryFolder tmpFolder = new TemporaryFolder();

    @Rule
    public final MaprTestLoginRule loginModule = MaprTestLoginRule.forHadoopSimpleAndJpam();

    /* loaded from: input_file:io/confluent/rest/SaslTest$SaslTestApplication.class */
    private static class SaslTestApplication extends Application<TestRestConfig> {
        private SaslTestApplication(TestRestConfig testRestConfig) {
            super(testRestConfig);
        }

        public void setupResources(Configurable<?> configurable, TestRestConfig testRestConfig) {
            configurable.register(new SaslTestResource());
        }

        protected void registerWebSocketEndpoints(ServerContainer serverContainer) {
            try {
                serverContainer.addEndpoint(ServerEndpointConfig.Builder.create(WSEndpoint.class, WSEndpoint.class.getAnnotation(ServerEndpoint.class).value()).build());
            } catch (DeploymentException e) {
                Assert.fail("Invalid test");
            }
        }

        public Map<String, String> getMetricsTags() {
            return Collections.singletonMap("instance-id", "1");
        }

        public /* bridge */ /* synthetic */ void setupResources(Configurable configurable, RestConfig restConfig) {
            setupResources((Configurable<?>) configurable, (TestRestConfig) restConfig);
        }
    }

    @Produces({"text/plain"})
    @Path("/")
    /* loaded from: input_file:io/confluent/rest/SaslTest$SaslTestResource.class */
    public static class SaslTestResource {
        @GET
        @Path("/principal")
        @PerformanceMetric("principal")
        public String principal(@Context SecurityContext securityContext) {
            return securityContext.getUserPrincipal().getName();
        }

        @GET
        @Path("/role/{role}")
        @PerformanceMetric("role")
        public boolean hello(@PathParam("role") String str, @Context SecurityContext securityContext) {
            return securityContext.isUserInRole(str);
        }
    }

    @ServerEndpoint("/test")
    /* loaded from: input_file:io/confluent/rest/SaslTest$WSEndpoint.class */
    public static class WSEndpoint {
        @OnOpen
        public void onOpen(Session session, EndpointConfig endpointConfig) {
            session.getAsyncRemote().sendText("Test message", sendResult -> {
                if (sendResult.isOK()) {
                    return;
                }
                SaslTest.log.warn("Error sending websocket message for session {}", session.getId(), sendResult.getException());
            });
        }
    }

    @Before
    public void setUp() throws Exception {
        MaprTestLoginModule.restrict("jay", str -> {
            return str.equals("kafka");
        });
        MaprTestLoginModule.restrict("neha", str2 -> {
            return str2.equals("akfak");
        });
        MaprTestLoginModule.restrict("jun", str3 -> {
            return str3.equals("another-password");
        });
        this.httpclient = HttpClients.createDefault();
        TestMetricsReporter.reset();
        Properties properties = new Properties();
        properties.put("listeners", HTTP_URI);
        properties.put("metric.reporters", "io.confluent.rest.TestMetricsReporter");
        configAuthentication(properties);
        this.app = new SaslTestApplication(TestRestConfig.maprCompatible(properties));
        this.app.start();
    }

    @After
    public void cleanup() throws Exception {
        assertMetricsCollected();
        this.httpclient.close();
        this.app.stop();
    }

    private void configAuthentication(Properties properties) {
        properties.put("authentication.enable", true);
        properties.put("authentication.realm", "c3");
        properties.put("authentication.roles", Collections.singletonList("Administrators"));
    }

    @Test
    public void testNoAuthAttempt() throws Exception {
        CloseableHttpResponse makeGetRequest = makeGetRequest("/test");
        try {
            Assert.assertEquals(Response.Status.UNAUTHORIZED.getStatusCode(), makeGetRequest.getStatusLine().getStatusCode());
            if (makeGetRequest != null) {
                makeGetRequest.close();
            }
        } catch (Throwable th) {
            if (makeGetRequest != null) {
                try {
                    makeGetRequest.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    @Test
    public void testNoAuthAttemptOnWs() throws Exception {
        Assert.assertEquals(Response.Status.UNAUTHORIZED.getStatusCode(), makeWsGetRequest(null));
    }

    @Test
    public void testBadLoginAttempt() throws Exception {
        CloseableHttpResponse makeGetRequest = makeGetRequest("/test", "dGVzdA==");
        try {
            Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), makeGetRequest.getStatusLine().getStatusCode());
            if (makeGetRequest != null) {
                makeGetRequest.close();
            }
        } catch (Throwable th) {
            if (makeGetRequest != null) {
                try {
                    makeGetRequest.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    @Test
    public void testBadLoginAttemptOnWs() throws Exception {
        Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), makeWsGetRequest("dGVzdA=="));
    }

    @Test
    public void testAuthorizedAttempt() throws Exception {
        CloseableHttpResponse makeGetRequest = makeGetRequest("/principal", NEHA_BASIC_AUTH);
        try {
            Assert.assertEquals(Response.Status.OK.getStatusCode(), makeGetRequest.getStatusLine().getStatusCode());
            Assert.assertEquals("neha", EntityUtils.toString(makeGetRequest.getEntity()));
            if (makeGetRequest != null) {
                makeGetRequest.close();
            }
            CloseableHttpResponse makeGetRequest2 = makeGetRequest("/role/Administrators", NEHA_BASIC_AUTH);
            try {
                Assert.assertEquals(Response.Status.OK.getStatusCode(), makeGetRequest2.getStatusLine().getStatusCode());
                Assert.assertEquals("false", EntityUtils.toString(makeGetRequest2.getEntity()));
                if (makeGetRequest2 != null) {
                    makeGetRequest2.close();
                }
                makeGetRequest = makeGetRequest("/role/blah", NEHA_BASIC_AUTH);
                try {
                    Assert.assertEquals(Response.Status.OK.getStatusCode(), makeGetRequest.getStatusLine().getStatusCode());
                    Assert.assertEquals("false", EntityUtils.toString(makeGetRequest.getEntity()));
                    if (makeGetRequest != null) {
                        makeGetRequest.close();
                    }
                } finally {
                }
            } finally {
            }
        } finally {
            if (makeGetRequest != null) {
                try {
                    makeGetRequest.close();
                } catch (Throwable th) {
                    th.addSuppressed(th);
                }
            }
        }
    }

    @Test
    public void testAuthorizedAttemptOnWs() throws Exception {
        Assert.assertEquals(Response.Status.OK.getStatusCode(), makeWsGetRequest(NEHA_BASIC_AUTH));
    }

    @Test
    public void testUnauthorizedAttempt() throws Exception {
        CloseableHttpResponse makeGetRequest = makeGetRequest("/principal", JUN_BASIC_AUTH);
        try {
            Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), makeGetRequest.getStatusLine().getStatusCode());
            if (makeGetRequest != null) {
                makeGetRequest.close();
            }
        } catch (Throwable th) {
            if (makeGetRequest != null) {
                try {
                    makeGetRequest.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    @Test
    public void testUnAuthorizedAttemptOnWs() throws Exception {
        Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), makeWsGetRequest(JUN_BASIC_AUTH));
    }

    private void assertMetricsCollected() {
        Assert.assertNotEquals("Expected to have metrics.", 0L, TestMetricsReporter.getMetricTimeseries().size());
        for (KafkaMetric kafkaMetric : TestMetricsReporter.getMetricTimeseries()) {
            if (kafkaMetric.metricName().name().equals("request-latency-max")) {
                Assert.assertTrue("Metrics should be collected (max latency shouldn't be 0)", kafkaMetric.value() != 0.0d);
            }
        }
    }

    private CloseableHttpResponse makeGetRequest(String str, String str2) throws Exception {
        log.debug("Making GET http://localhost:8080" + str);
        HttpGet httpGet = new HttpGet(HTTP_URI + str);
        if (str2 != null) {
            httpGet.setHeader("Authorization", "Basic " + str2);
        }
        return this.httpclient.execute(httpGet);
    }

    private CloseableHttpResponse makeGetRequest(String str) throws Exception {
        return makeGetRequest(str, null);
    }

    private int makeWsGetRequest(String str) throws Exception {
        log.debug("Making WebSocket GET ws://localhost:8080/ws/test");
        final AtomicReference atomicReference = new AtomicReference();
        WebSocketUpgradeHandler build = new WebSocketUpgradeHandler.Builder().addWebSocketListener(new WebSocketListener() { // from class: io.confluent.rest.SaslTest.1
            public void onOpen(WebSocket webSocket) {
            }

            public void onClose(WebSocket webSocket, int i, String str2) {
            }

            public void onError(Throwable th) {
                SaslTest.log.info("Websocket failed", th);
                atomicReference.set(th);
            }
        }).build();
        BoundRequestBuilder prepareGet = Dsl.asyncHttpClient().prepareGet("ws://localhost:8080/ws/test");
        if (str != null) {
            prepareGet = (BoundRequestBuilder) prepareGet.addHeader("Authorization", "Basic " + str);
        }
        WebSocket webSocket = (WebSocket) prepareGet.setRequestTimeout(10000).execute(build).get();
        if (atomicReference.get() != null) {
            return extractStatusCode(((Throwable) atomicReference.get()).getMessage());
        }
        webSocket.sendCloseFrame();
        return Response.Status.OK.getStatusCode();
    }

    private static int extractStatusCode(String str) {
        Matcher matcher = WS_ERROR_PATTERN.matcher(str);
        Assert.assertTrue("Test invalid", matcher.matches());
        return Integer.parseInt(matcher.group(1));
    }

    static {
        MaprHomeSupport.activate();
        log = LoggerFactory.getLogger(SaslTest.class);
        WS_ERROR_PATTERN = Pattern.compile(".*code=(\\d+).*");
    }
}
