/*
 * Decompiled with CFR 0.152.
 */
package org.apache.nifi.web.security.oidc.client.web;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.net.URI;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Collectors;
import org.apache.nifi.authorization.user.NiFiUser;
import org.apache.nifi.authorization.user.NiFiUserUtils;
import org.apache.nifi.web.security.cookie.ApplicationCookieName;
import org.apache.nifi.web.security.cookie.ApplicationCookieService;
import org.apache.nifi.web.security.cookie.StandardApplicationCookieService;
import org.apache.nifi.web.security.jwt.provider.BearerTokenProvider;
import org.apache.nifi.web.security.jwt.provider.SupportedClaim;
import org.apache.nifi.web.security.oidc.client.web.OidcAuthorizedClient;
import org.apache.nifi.web.security.oidc.client.web.OidcRegistrationProperty;
import org.apache.nifi.web.security.token.LoginAuthenticationToken;
import org.apache.nifi.web.util.RequestUriBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.AbstractOAuth2AuthorizationGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.server.resource.web.BearerTokenResolver;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.web.filter.OncePerRequestFilter;

/**
 * Servlet filter that refreshes the application Bearer Token for users authenticated through OpenID Connect.
 * <p>
 * The filter only acts on requests matching {@code /flow/current-user}. When the Bearer Token resolved from the
 * request expires within the configured refresh window, the filter exchanges the stored OIDC Refresh Token for a
 * new Access Token, saves the refreshed authorized client, and writes a new Bearer Token session cookie to the
 * response. Refresh failures are logged and never prevent the request from continuing down the filter chain.
 */
public class OidcBearerTokenRefreshFilter extends OncePerRequestFilter {
    private static final String ROOT_PATH = "/";
    private static final Logger logger = LoggerFactory.getLogger(OidcBearerTokenRefreshFilter.class);
    private final AntPathRequestMatcher currentUserRequestMatcher = new AntPathRequestMatcher("/flow/current-user");
    private final ApplicationCookieService applicationCookieService = new StandardApplicationCookieService();
    private final Duration refreshWindow;
    // Identities with a refresh in progress, used to avoid concurrent refreshes for the same user
    private final ConcurrentMap<String, Instant> refreshRequests = new ConcurrentHashMap<>();
    private final BearerTokenProvider bearerTokenProvider;
    private final BearerTokenResolver bearerTokenResolver;
    private final JwtDecoder jwtDecoder;
    private final OAuth2AuthorizedClientRepository authorizedClientRepository;
    private final OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> refreshTokenResponseClient;

    /**
     * Create the filter with all required collaborators.
     *
     * @param refreshWindow Duration before Bearer Token expiration within which a refresh is attempted
     * @param bearerTokenProvider Provider used to issue the new application Bearer Token
     * @param bearerTokenResolver Resolver used to read the current Bearer Token from the request
     * @param jwtDecoder Decoder used to read the expiration of the current Bearer Token
     * @param authorizedClientRepository Repository holding the OIDC authorized client for the current user
     * @param refreshTokenResponseClient Client used to exchange the OIDC Refresh Token for new tokens
     * @throws NullPointerException when any argument is null
     */
    public OidcBearerTokenRefreshFilter(
            final Duration refreshWindow,
            final BearerTokenProvider bearerTokenProvider,
            final BearerTokenResolver bearerTokenResolver,
            final JwtDecoder jwtDecoder,
            final OAuth2AuthorizedClientRepository authorizedClientRepository,
            final OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> refreshTokenResponseClient
    ) {
        this.refreshWindow = Objects.requireNonNull(refreshWindow, "Refresh Window required");
        this.bearerTokenProvider = Objects.requireNonNull(bearerTokenProvider, "Bearer Token Provider required");
        this.bearerTokenResolver = Objects.requireNonNull(bearerTokenResolver, "Bearer Token Resolver required");
        this.jwtDecoder = Objects.requireNonNull(jwtDecoder, "JWT Decoder required");
        this.authorizedClientRepository = Objects.requireNonNull(authorizedClientRepository, "Authorized Client Repository required");
        this.refreshTokenResponseClient = Objects.requireNonNull(refreshTokenResponseClient, "Refresh Token Response Client required");
    }

    /**
     * Attempt a Bearer Token refresh for current-user requests, then always continue the filter chain.
     * <p>
     * Only one refresh per user identity runs at a time; concurrent requests for the same identity skip refresh
     * processing. Any exception raised during refresh processing is logged and suppressed.
     */
    @Override
    protected void doFilterInternal(final HttpServletRequest request, final HttpServletResponse response, final FilterChain filterChain) throws ServletException, IOException {
        if (currentUserRequestMatcher.matches(request)) {
            final String userIdentity = NiFiUserUtils.getNiFiUserIdentity();
            if (userIdentity == null) {
                // ConcurrentHashMap rejects null keys; without an identity there is nothing to refresh
                logger.debug("Bearer Token refresh skipped: user identity not found");
            } else if (refreshRequests.putIfAbsent(userIdentity, Instant.now()) == null) {
                logger.debug("Identity [{}] Bearer Token refresh processing started", userIdentity);
                try {
                    processRequest(userIdentity, request, response);
                } catch (final Exception e) {
                    logger.error("Identity [{}] Bearer Token refresh processing failed", userIdentity, e);
                } finally {
                    refreshRequests.remove(userIdentity);
                    logger.debug("Identity [{}] Bearer Token refresh processing completed", userIdentity);
                }
            }
        }

        filterChain.doFilter(request, response);
    }

    /**
     * Refresh tokens and write a new Bearer Token session cookie when the current Bearer Token is close to expiring.
     */
    private void processRequest(final String userIdentity, final HttpServletRequest request, final HttpServletResponse response) {
        if (!isRefreshRequired(userIdentity, request)) {
            return;
        }

        logger.info("Identity [{}] Bearer Token refresh required", userIdentity);
        final OidcAuthorizedClient authorizedClient = loadAuthorizedClient(request);
        if (authorizedClient == null) {
            logger.warn("Identity [{}] OIDC Authorized Client not found", userIdentity);
            return;
        }

        final OAuth2AccessTokenResponse tokenResponse = getRefreshTokenResponse(authorizedClient, request, response);
        if (tokenResponse == null) {
            logger.warn("Identity [{}] OpenID Connect Refresh Token not found", userIdentity);
            return;
        }

        final URI resourceUri = RequestUriBuilder.fromHttpServletRequest(request).path(ROOT_PATH).build();
        final String bearerToken = getBearerToken(userIdentity, tokenResponse);
        applicationCookieService.addSessionCookie(resourceUri, response, ApplicationCookieName.AUTHORIZATION_BEARER, bearerToken);
    }

    /**
     * Determine whether the request Bearer Token expires before now plus the refresh window.
     *
     * @return false when no Bearer Token is present on the request
     * @throws NullPointerException when the decoded Bearer Token has no expiration claim
     */
    private boolean isRefreshRequired(final String userIdentity, final HttpServletRequest request) {
        final String token = bearerTokenResolver.resolve(request);
        if (token == null) {
            logger.debug("Identity [{}] Bearer Token not found", userIdentity);
            return false;
        }

        final Jwt jwt = jwtDecoder.decode(token);
        final Instant expiresAt = Objects.requireNonNull(jwt.getExpiresAt(), "Bearer Token expiration claim not found");
        final Instant refreshRequired = Instant.now().plus(refreshWindow);
        return refreshRequired.isAfter(expiresAt);
    }

    /**
     * Load the OIDC authorized client for the authentication in the current Security Context.
     *
     * @return the authorized client, or null when the repository has none
     */
    private OidcAuthorizedClient loadAuthorizedClient(final HttpServletRequest request) {
        final SecurityContext context = SecurityContextHolder.getContext();
        final Authentication principal = context.getAuthentication();
        return authorizedClientRepository.loadAuthorizedClient(OidcRegistrationProperty.REGISTRATION_ID.getProperty(), principal, request);
    }

    /**
     * Issue a new application Bearer Token whose expiration matches the refreshed OIDC Access Token.
     */
    private String getBearerToken(final String userIdentity, final OAuth2AccessTokenResponse tokenResponse) {
        final OAuth2AccessToken accessToken = tokenResponse.getAccessToken();
        final Instant sessionExpiration = getSessionExpiration(accessToken);
        final Set<? extends GrantedAuthority> providerAuthorities = getProviderAuthorities();
        final LoginAuthenticationToken loginAuthenticationToken = new LoginAuthenticationToken(userIdentity, sessionExpiration, providerAuthorities);
        return bearerTokenProvider.getBearerToken(loginAuthenticationToken);
    }

    /**
     * Read the expiration of the OIDC Access Token.
     *
     * @throws IllegalArgumentException when the Access Token has no expiration
     */
    private Instant getSessionExpiration(final OAuth2AccessToken accessToken) {
        final Instant tokenExpiration = accessToken.getExpiresAt();
        if (tokenExpiration == null) {
            throw new IllegalArgumentException("OpenID Connect Access Token expiration claim not found");
        }
        return tokenExpiration;
    }

    /**
     * Exchange the Refresh Token and save the refreshed authorized client in the repository.
     * <p>
     * When the token response does not include a new Refresh Token, the existing Refresh Token is retained.
     *
     * @return the token response, or null when the authorized client has no Refresh Token
     */
    private OAuth2AccessTokenResponse getRefreshTokenResponse(final OidcAuthorizedClient authorizedClient, final HttpServletRequest request, final HttpServletResponse response) {
        if (authorizedClient.getRefreshToken() == null) {
            return null;
        }

        final OAuth2AccessTokenResponse tokenResponse = getRefreshTokenResponse(authorizedClient);
        final OAuth2RefreshToken responseRefreshToken = tokenResponse.getRefreshToken();
        final OAuth2RefreshToken refreshToken = responseRefreshToken == null ? authorizedClient.getRefreshToken() : responseRefreshToken;
        final OidcAuthorizedClient refreshedAuthorizedClient = new OidcAuthorizedClient(
                authorizedClient.getClientRegistration(),
                authorizedClient.getPrincipalName(),
                tokenResponse.getAccessToken(),
                refreshToken,
                authorizedClient.getIdToken()
        );
        final OAuth2AuthenticationToken authenticationToken = getAuthenticationToken(authorizedClient);
        authorizedClientRepository.saveAuthorizedClient(refreshedAuthorizedClient, authenticationToken, request, response);
        return tokenResponse;
    }

    /**
     * Send a Refresh Token grant request using the client registration and tokens of the authorized client.
     *
     * @throws NullPointerException when the authorized client has no Refresh Token
     */
    private OAuth2AccessTokenResponse getRefreshTokenResponse(final OidcAuthorizedClient authorizedClient) {
        final ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
        final OAuth2AccessToken accessToken = authorizedClient.getAccessToken();
        final OAuth2RefreshToken refreshToken = Objects.requireNonNull(authorizedClient.getRefreshToken(), "Refresh Token required");
        final OAuth2RefreshTokenGrantRequest grantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, accessToken, refreshToken);
        return refreshTokenResponseClient.getTokenResponse(grantRequest);
    }

    /**
     * Build an OAuth2 authentication token from the ID Token of the authorized client, using the subject claim
     * as the name attribute and no granted authorities.
     */
    private OAuth2AuthenticationToken getAuthenticationToken(final OidcAuthorizedClient authorizedClient) {
        final ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
        final OidcIdToken idToken = authorizedClient.getIdToken();
        final DefaultOidcUser oidcUser = new DefaultOidcUser(Collections.emptyList(), idToken, SupportedClaim.SUBJECT.getClaim());
        return new OAuth2AuthenticationToken(oidcUser, Collections.emptyList(), clientRegistration.getRegistrationId());
    }

    /**
     * Convert the identity provider groups of the current NiFi user into granted authorities.
     *
     * @return an empty set when the user has no identity provider groups
     * @throws NullPointerException when no NiFi user is available
     */
    private Set<? extends GrantedAuthority> getProviderAuthorities() {
        final NiFiUser user = Objects.requireNonNull(NiFiUserUtils.getNiFiUser(), "NiFi User required");
        final Set<String> providerGroups = user.getIdentityProviderGroups();
        if (providerGroups == null) {
            return Collections.emptySet();
        }
        return providerGroups.stream().map(SimpleGrantedAuthority::new).collect(Collectors.toSet());
    }
}

