package com.predic8.membrane.core.interceptor.jwt;

import com.bornium.security.oauth2openid.Constants;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableMap;
import com.predic8.membrane.annot.MCAttribute;
import com.predic8.membrane.annot.MCChildElement;
import com.predic8.membrane.annot.MCElement;
import com.predic8.membrane.core.Router;
import com.predic8.membrane.core.exchange.Exchange;
import com.predic8.membrane.core.http.Response;
import com.predic8.membrane.core.interceptor.AbstractInterceptor;
import com.predic8.membrane.core.interceptor.Outcome;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;
import org.jose4j.base64url.Base64Url;
import org.jose4j.jwk.RsaJsonWebKey;
import org.jose4j.jwt.consumer.JwtConsumer;
import org.jose4j.jwt.consumer.JwtConsumerBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.xml.BeanDefinitionParserDelegate;

/**
 * Interceptor that authenticates requests carrying a JSON Web Token (JWT) signed with one of a
 * configured set of RSA JSON Web Keys (JWKs).
 *
 * <p>For each request the token is obtained from the configured {@link JwtRetriever} (by default a
 * {@code HeaderJwtRetriever} on the {@code Authorization} header using
 * {@code Constants.PARAMETER_VALUE_BEARER}). The {@code kid} in the token header selects the
 * verification key; the token must carry an expiration time and a subject and, if an expected
 * audience was configured via {@link #setExpectedAud(String)}, that audience. On success the claims
 * map is stored in the exchange property {@code "jwt"} and processing continues. Every failure
 * produces a {@code 400 Bad Request} response whose JSON body has the form
 * {@code {"code": 400, "description": "<one of the ERROR_* constants>"}}.
 */
@MCElement(name = "jwtAuth")
public class JwtAuthInterceptor extends AbstractInterceptor {
    private static final Logger LOG = LoggerFactory.getLogger(JwtAuthInterceptor.class);

    /** Matches the literal '.' separating the parts of a JWT compact serialization. */
    private static final Pattern COMPACT_SERIALIZATION_SEPARATOR = Pattern.compile(Pattern.quote("."));

    private static final String ERROR_NO_JWKS =
            "No JWKs given or none resolvable - please specify at least one resolvable JWK";

    public static final String ERROR_JWT_NOT_FOUND = "Could not retrieve JWT";
    public static final String ERROR_MALFORMED_COMPACT_SERIALIZATION = "JWTs compact serialization not valid";
    public static final String ERROR_DECODED_HEADER_NOT_JSON = "JWT header is not valid JSON";
    public static final String ERROR_NO_KID_GIVEN = "JWT does not contain a kid";
    public static final String ERROR_UNKNOWN_KEY = "JWT signed by unknown key";
    public static final String ERROR_VALIDATION_FAILED = "JWT validation failed";

    /** Jackson mapper used to parse JWKs and token headers and to render error bodies. */
    ObjectMapper mapper = new ObjectMapper();
    /** Source of the raw JWT; defaulted in {@link #init(Router)} when not configured. */
    JwtRetriever jwtRetriever;
    /** Configured key set; required. */
    Jwks jwks;
    /** Audience the token must contain; not checked when {@code null} or empty. */
    String expectedAud;
    /** Verification keys indexed by key id ({@code kid}); built in {@link #init(Router)}. */
    volatile HashMap<String, RsaJsonWebKey> kidToKey;

    /**
     * Resolves the configured JWKs and indexes them by key id.
     *
     * @param router the router providing the resolver map and base location used to load the JWKs
     * @throws IllegalStateException if no {@code jwks} are configured or none of them yields a key
     *     with a key id
     * @throws RuntimeException wrapping the cause when a JWK cannot be fetched or parsed
     */
    @Override // com.predic8.membrane.core.interceptor.AbstractInterceptor, com.predic8.membrane.core.interceptor.Interceptor
    public void init(Router router) throws Exception {
        super.init(router);
        if (jwtRetriever == null) {
            jwtRetriever = new HeaderJwtRetriever("Authorization", Constants.PARAMETER_VALUE_BEARER);
        }
        if (jwks == null) {
            throw new IllegalStateException(ERROR_NO_JWKS);
        }
        jwks.init(router.getResolverMap(), router.getBaseLocation());

        HashMap<String, RsaJsonWebKey> keys = jwks.getJwks().stream()
                .map(jwk -> {
                    try {
                        return toRsaJsonWebKey(jwk.getJwk(router.getResolverMap(), router.getBaseLocation(), mapper));
                    } catch (Exception e) {
                        // Lambdas cannot throw checked exceptions; keep the original cause.
                        throw new RuntimeException(e);
                    }
                })
                .collect(HashMap::new, JwtAuthInterceptor::putByKeyId, HashMap::putAll);

        if (keys.isEmpty()) {
            throw new IllegalStateException(ERROR_NO_JWKS);
        }
        kidToKey = keys;
    }

    /** Parses a single JWK JSON document into an RSA key. */
    @SuppressWarnings("unchecked") // Jackson returns a raw Map for a JSON object; keys are Strings.
    private RsaJsonWebKey toRsaJsonWebKey(String jwkJson) throws Exception {
        return new RsaJsonWebKey((Map<String, Object>) mapper.readValue(jwkJson, Map.class));
    }

    /** Adds {@code key} to {@code keys} under its key id, skipping (with a warning) keys without one. */
    private static void putByKeyId(HashMap<String, RsaJsonWebKey> keys, RsaJsonWebKey key) {
        String kid = key.getKeyId();
        if (kid == null) {
            // A key without a kid can never be selected by a token header, so it is useless here.
            LOG.warn("Ignoring JWK without a 'kid'");
            return;
        }
        keys.put(kid, key);
    }

    /**
     * Validates the request's JWT.
     *
     * <p>Each step maps to its own error description:
     * <ol>
     *   <li>retrieve the token: {@link #ERROR_JWT_NOT_FOUND}</li>
     *   <li>require at least three compact-serialization parts: {@link #ERROR_MALFORMED_COMPACT_SERIALIZATION}</li>
     *   <li>Base64URL-decode and parse the header as a JSON object: {@link #ERROR_DECODED_HEADER_NOT_JSON}</li>
     *   <li>require a {@code kid}: {@link #ERROR_NO_KID_GIVEN}</li>
     *   <li>look up the key: {@link #ERROR_UNKNOWN_KEY}</li>
     *   <li>verify signature and claims: {@link #ERROR_VALIDATION_FAILED}</li>
     * </ol>
     *
     * @return {@link Outcome#CONTINUE} with the claims map in exchange property {@code "jwt"}, or
     *     {@link Outcome#RETURN} with a 400 JSON error response
     */
    @Override // com.predic8.membrane.core.interceptor.AbstractInterceptor, com.predic8.membrane.core.interceptor.Interceptor
    public Outcome handleRequest(Exchange exchange) throws Exception {
        String jwt;
        try {
            jwt = jwtRetriever.get(exchange);
        } catch (Exception e) {
            return setJsonErrorAndReturn(e, exchange, 400, ERROR_JWT_NOT_FOUND);
        }
        if (jwt == null) {
            return setJsonErrorAndReturn(null, exchange, 400, ERROR_JWT_NOT_FOUND);
        }

        String[] parts = COMPACT_SERIALIZATION_SEPARATOR.split(jwt);
        if (parts.length < 3) {
            return setJsonErrorAndReturn(null, exchange, 400, ERROR_MALFORMED_COMPACT_SERIALIZATION);
        }

        Map<?, ?> header;
        try {
            // Hand Jackson the raw bytes so it detects the encoding instead of using the platform charset.
            header = mapper.readValue(Base64Url.decode(parts[0]), Map.class);
        } catch (Exception e) {
            return setJsonErrorAndReturn(e, exchange, 400, ERROR_DECODED_HEADER_NOT_JSON);
        }

        // A JSON 'null' header parses to null; treat it like a header without a kid.
        Object kid = header == null ? null : header.get("kid");
        if (kid == null) {
            return setJsonErrorAndReturn(null, exchange, 400, ERROR_NO_KID_GIVEN);
        }

        RsaJsonWebKey key = kidToKey.get(kid.toString());
        if (key == null) {
            return setJsonErrorAndReturn(null, exchange, 400, ERROR_UNKNOWN_KEY);
        }

        try {
            exchange.getProperties().put("jwt", createValidator(key).processToClaims(jwt).getClaimsMap());
        } catch (Exception e) {
            return setJsonErrorAndReturn(e, exchange, 400, ERROR_VALIDATION_FAILED);
        }
        return Outcome.CONTINUE;
    }

    /**
     * Builds a consumer that verifies the signature with {@code key}'s RSA public key and requires
     * an expiration time and a subject, allowing 30 seconds of clock skew. The audience is
     * enforced only when {@code expectedAud} is non-empty.
     */
    private JwtConsumer createValidator(RsaJsonWebKey key) {
        JwtConsumerBuilder builder = new JwtConsumerBuilder()
                .setRequireExpirationTime()
                .setAllowedClockSkewInSeconds(30)
                .setRequireSubject()
                .setVerificationKey(key.getRsaPublicKey());
        if (expectedAud != null && !expectedAud.isEmpty()) {
            builder.setExpectedAudience(expectedAud);
        }
        return builder.build();
    }

    /**
     * Sets a JSON error response on the exchange and tells the chain to return it.
     *
     * @param exc cause to log, or {@code null} when there is nothing worth logging
     * @param exchange exchange that receives the response
     * @param status HTTP status code; the reason phrase is always "Bad Request"
     * @param description value of the {@code "description"} field in the JSON body
     * @return always {@link Outcome#RETURN}
     */
    private Outcome setJsonErrorAndReturn(Exception exc, Exchange exchange, int status, String description) {
        if (exc != null) {
            LOG.error(description, exc);
        }
        String body;
        try {
            body = mapper.writeValueAsString(ImmutableMap.<String, Object>of("code", status, "description", description));
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
        exchange.setResponse(Response.ResponseBuilder.newInstance().status(status, "Bad Request").body(body).build());
        return Outcome.RETURN;
    }

    /** Returns the configured token retriever (after {@link #init(Router)}, never {@code null}). */
    public JwtRetriever getJwtRetriever() {
        return this.jwtRetriever;
    }

    /** Sets how the JWT is extracted from the request; optional. */
    @MCChildElement
    public void setJwtRetriever(JwtRetriever jwtRetriever) {
        this.jwtRetriever = jwtRetriever;
    }

    /** Returns the configured key set. */
    public Jwks getJwks() {
        return this.jwks;
    }

    /** Sets the key set used to verify token signatures; required. */
    @MCChildElement(order = 1)
    public void setJwks(Jwks jwks) {
        this.jwks = jwks;
    }

    /** Returns the audience tokens must contain, or {@code null} if not checked. */
    public String getExpectedAud() {
        return this.expectedAud;
    }

    /**
     * Sets the audience tokens must contain; {@code null} or empty disables the check.
     *
     * @return this interceptor, for chaining
     */
    @MCAttribute
    public JwtAuthInterceptor setExpectedAud(String str) {
        this.expectedAud = str;
        return this;
    }
}
