package io.datarouter.web.user.authenticate.saml;

import io.datarouter.util.Require;
import java.io.ByteArrayInputStream;
import java.io.StringWriter;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Stream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.xml.namespace.QName;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerException;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;
import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
import net.shibboleth.utilities.java.support.security.impl.RandomIdentifierGenerationStrategy;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.messaging.context.MessageContext;
import org.opensaml.messaging.decoder.MessageDecodingException;
import org.opensaml.messaging.encoder.MessageEncodingException;
import org.opensaml.messaging.handler.MessageHandler;
import org.opensaml.messaging.handler.MessageHandlerException;
import org.opensaml.messaging.handler.impl.BasicMessageHandlerChain;
import org.opensaml.saml.common.SAMLObject;
import org.opensaml.saml.common.SignableSAMLObject;
import org.opensaml.saml.common.binding.security.impl.MessageLifetimeSecurityHandler;
import org.opensaml.saml.common.binding.security.impl.ReceivedEndpointSecurityHandler;
import org.opensaml.saml.common.messaging.context.SAMLBindingContext;
import org.opensaml.saml.common.messaging.context.SAMLEndpointContext;
import org.opensaml.saml.common.messaging.context.SAMLPeerEntityContext;
import org.opensaml.saml.saml2.binding.decoding.impl.HTTPPostDecoder;
import org.opensaml.saml.saml2.binding.encoding.impl.HTTPRedirectDeflateEncoder;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.AuthnContextClassRef;
import org.opensaml.saml.saml2.core.AuthnContextComparisonTypeEnumeration;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.NameIDPolicy;
import org.opensaml.saml.saml2.core.RequestedAuthnContext;
import org.opensaml.saml.saml2.core.Response;
import org.opensaml.saml.saml2.core.Scoping;
import org.opensaml.saml.saml2.core.impl.AuthnContextClassRefBuilder;
import org.opensaml.saml.saml2.metadata.Endpoint;
import org.opensaml.saml.saml2.metadata.SingleSignOnService;
import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator;
import org.opensaml.security.credential.BasicCredential;
import org.opensaml.security.credential.Credential;
import org.opensaml.security.credential.CredentialSupport;
import org.opensaml.security.crypto.KeySupport;
import org.opensaml.security.x509.BasicX509Credential;
import org.opensaml.xmlsec.SignatureSigningParameters;
import org.opensaml.xmlsec.context.SecurityParametersContext;
import org.opensaml.xmlsec.signature.Signature;
import org.opensaml.xmlsec.signature.support.SignatureException;
import org.opensaml.xmlsec.signature.support.SignatureValidator;
import org.opensaml.xmlsec.signature.support.Signer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Element;

/* loaded from: input_file:io/datarouter/web/user/authenticate/saml/SamlTool.class */
/**
 * Static helpers for building SAML authentication requests, decoding and validating SAML responses,
 * signing and verifying SAML objects, and a few related servlet/URL utilities.
 *
 * <p>All checked OpenSAML, JCA and URL exceptions are rethrown as {@link RuntimeException} with the
 * original exception attached as the cause.
 */
public class SamlTool {
    private static final Logger logger = LoggerFactory.getLogger(SamlTool.class);
    private static final RandomIdentifierGenerationStrategy secureRandomIdGenerator = new RandomIdentifierGenerationStrategy();
    public static final String ROLE_GROUP_ATTRIBUTE_NAME = "groupAttributes";
    public static final String ROLE_ATTRIBUTE_NAME = "roleAttributes";
    public static final String DEFAULT_ENTITY_ID = "https://datarouter.io";

    private static final String HTTP_POST_BINDING_URI = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST";
    private static final String HTTP_REDIRECT_BINDING_URI = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect";
    private static final String EMAIL_NAME_ID_FORMAT = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress";
    private static final String PASSWORD_AUTHN_CONTEXT_CLASS = "urn:oasis:names:tc:SAML:2.0:ac:classes:Password";
    // NOTE(review): SHA-1 based signatures are widely considered weak. The value is kept unchanged so that
    // existing identity-provider integrations keep working; consider moving to an RSA-SHA256 algorithm URI.
    private static final String RSA_SHA1_SIGNATURE_ALGORITHM = "http://www.w3.org/2000/09/xmldsig#rsa-sha1";
    private static final String EXCLUSIVE_C14N_ALGORITHM = "http://www.w3.org/2001/10/xml-exc-c14n#";

    /**
     * Builds an {@link AuthnRequest} from the given config and wraps it in a {@link MessageContext} that
     * also carries the relay state, the identity provider endpoint and, when a signing key pair is
     * configured, the signature signing parameters.
     *
     * @param authnRequestMessageConfig source of the URLs, entity id, relay state, optional proxy count
     *        and optional signing key pair
     * @return a message context whose message is the new authentication request
     */
    public static MessageContext buildAuthnRequestAndContext(AuthnRequestMessageConfig authnRequestMessageConfig) {
        AuthnRequest authnRequest = build(AuthnRequest.DEFAULT_ELEMENT_NAME);
        authnRequest.setIssueInstant(Instant.now());
        authnRequest.setDestination(authnRequestMessageConfig.identityProviderSingleSignOnServiceUrl);
        // The request itself is sent via redirect (see buildIdpEndpoint); this asks for the response via POST.
        authnRequest.setProtocolBinding(HTTP_POST_BINDING_URI);
        authnRequest.setAssertionConsumerServiceURL(authnRequestMessageConfig.serviceProviderAssertionConsumerServiceUrl);
        authnRequest.setID(generateSecureRandomId());
        authnRequest.setIssuer(buildIssuer(authnRequestMessageConfig.serviceProviderEntityId));
        authnRequest.setNameIDPolicy(buildNameIdPolicy());
        authnRequest.setRequestedAuthnContext(buildRequestedAuthnContext());
        authnRequestMessageConfig.proxyCount.ifPresent(proxyCount -> authnRequest.setScoping(buildScoping(proxyCount)));
        logSamlObject("SamlTool.buildAuthnRequestAndContext", authnRequest);

        MessageContext messageContext = new MessageContext();
        messageContext.setMessage(authnRequest);
        messageContext.getSubcontext(SAMLBindingContext.class, true)
                .setRelayState(authnRequestMessageConfig.relayState);
        messageContext.getSubcontext(SAMLPeerEntityContext.class, true)
                .getSubcontext(SAMLEndpointContext.class, true)
                .setEndpoint(buildIdpEndpoint(authnRequestMessageConfig.identityProviderSingleSignOnServiceUrl));
        authnRequestMessageConfig.signingKeyPair.ifPresent(keyPair -> {
            SignatureSigningParameters signingParameters = new SignatureSigningParameters();
            signingParameters.setSigningCredential(new BasicCredential(keyPair.getPublic(), keyPair.getPrivate()));
            signingParameters.setSignatureAlgorithm(RSA_SHA1_SIGNATURE_ALGORITHM);
            messageContext.getSubcontext(SecurityParametersContext.class, true)
                    .setSignatureSigningParameters(signingParameters);
        });
        return messageContext;
    }

    /** Builds a {@link Scoping} element carrying the given proxy count. */
    private static Scoping buildScoping(Integer proxyCount) {
        Scoping scoping = build(Scoping.DEFAULT_ELEMENT_NAME);
        scoping.setProxyCount(proxyCount);
        return scoping;
    }

    /**
     * Encodes the message held by {@code messageContext} onto the servlet response using
     * {@link HTTPRedirectDeflateEncoder}.
     *
     * @param httpServletResponse response the encoder writes to
     * @param messageContext context holding the message to encode
     * @throws RuntimeException wrapping any encoder initialization or encoding failure
     */
    public static void redirectWithAuthnRequestContext(HttpServletResponse httpServletResponse, MessageContext messageContext) {
        HTTPRedirectDeflateEncoder encoder = new HTTPRedirectDeflateEncoder();
        encoder.setHttpServletResponse(httpServletResponse);
        encoder.setMessageContext(messageContext);
        try {
            encoder.initialize();
            encoder.encode();
        } catch (ComponentInitializationException | MessageEncodingException e) {
            throw new RuntimeException(e);
        }
    }

    /** Builds a {@link NameIDPolicy} requesting the email-address format with {@code AllowCreate} false. */
    private static NameIDPolicy buildNameIdPolicy() {
        NameIDPolicy nameIdPolicy = build(NameIDPolicy.DEFAULT_ELEMENT_NAME);
        nameIdPolicy.setAllowCreate(false);
        nameIdPolicy.setFormat(EMAIL_NAME_ID_FORMAT);
        return nameIdPolicy;
    }

    /** Builds a {@link RequestedAuthnContext} with MINIMUM comparison against the Password class. */
    private static RequestedAuthnContext buildRequestedAuthnContext() {
        AuthnContextClassRef passwordClassRef = new AuthnContextClassRefBuilder().buildObject();
        passwordClassRef.setURI(PASSWORD_AUTHN_CONTEXT_CLASS);
        RequestedAuthnContext requestedAuthnContext = build(RequestedAuthnContext.DEFAULT_ELEMENT_NAME);
        requestedAuthnContext.setComparison(AuthnContextComparisonTypeEnumeration.MINIMUM);
        requestedAuthnContext.getAuthnContextClassRefs().add(passwordClassRef);
        return requestedAuthnContext;
    }

    /** Builds a {@link SingleSignOnService} endpoint with the HTTP-Redirect binding at the given location. */
    private static Endpoint buildIdpEndpoint(String location) {
        SingleSignOnService singleSignOnService = build(SingleSignOnService.DEFAULT_ELEMENT_NAME);
        singleSignOnService.setBinding(HTTP_REDIRECT_BINDING_URI);
        singleSignOnService.setLocation(location);
        return singleSignOnService;
    }

    /**
     * Decodes the SAML response posted in the request, runs the message lifetime and received endpoint
     * handlers on it, and verifies the signature of the response and of each of its assertions.
     *
     * @param httpServletRequest request carrying the posted SAML response
     * @param credential credential the signatures are validated against
     * @return the decoded message context
     * @throws RuntimeException if decoding or a handler fails, or if the response or any assertion is
     *         unsigned or fails signature validation
     */
    public static MessageContext getAndValidateResponseMessageContext(HttpServletRequest httpServletRequest, Credential credential) {
        MessageContext messageContext = decodeResponse(httpServletRequest);
        logSamlObject("SamlTool.getAndValidateResponseMessageContext", (SAMLObject) messageContext.getMessage());
        validateMessageContext(messageContext, httpServletRequest);
        Response response = (Response) messageContext.getMessage();
        verifySignature(response, credential);
        for (Assertion assertion : response.getAssertions()) {
            verifySignature(assertion, credential);
        }
        return messageContext;
    }

    /**
     * Runs a {@link MessageLifetimeSecurityHandler} (1 second clock skew, 1 minute lifetime, rule
     * required) followed by a {@link ReceivedEndpointSecurityHandler} against the message context.
     */
    private static void validateMessageContext(MessageContext messageContext, HttpServletRequest request) {
        MessageLifetimeSecurityHandler lifetimeHandler = new MessageLifetimeSecurityHandler();
        lifetimeHandler.setClockSkew(Duration.ofMillis(1000L));
        lifetimeHandler.setMessageLifetime(Duration.ofMinutes(1L));
        lifetimeHandler.setRequiredRule(true);

        ReceivedEndpointSecurityHandler receivedEndpointHandler = new ReceivedEndpointSecurityHandler();
        receivedEndpointHandler.setHttpServletRequest(request);

        List<MessageHandler> handlers = new ArrayList<>();
        handlers.add(lifetimeHandler);
        handlers.add(receivedEndpointHandler);

        BasicMessageHandlerChain handlerChain = new BasicMessageHandlerChain();
        handlerChain.setHandlers(handlers);
        try {
            handlerChain.initialize();
            handlerChain.doInvoke(messageContext);
        } catch (ComponentInitializationException | MessageHandlerException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Generates a 4096-bit RSA key pair using the default provider.
     *
     * @return the new key pair
     * @throws RuntimeException if the algorithm or provider is unavailable
     */
    public static KeyPair generateKeyPair() {
        try {
            return KeySupport.generateKeyPair("RSA", 4096, (String) null);
        } catch (NoSuchAlgorithmException | NoSuchProviderException e) {
            throw new RuntimeException("Failed to generate signingKeyPair", e);
        }
    }

    /**
     * Attaches a new signature to the object, marshalls the object and then signs it with the key pair.
     *
     * @param signableSAMLObject object to sign; its signature is replaced
     * @param keyPair key pair used as the signing credential
     * @throws RuntimeException wrapping any marshalling or signing failure
     */
    public static void signSamlObject(SignableSAMLObject signableSAMLObject, KeyPair keyPair) {
        BasicCredential signingCredential = CredentialSupport.getSimpleCredential(keyPair.getPublic(), keyPair.getPrivate());
        Signature signature = build(Signature.DEFAULT_ELEMENT_NAME);
        signature.setSigningCredential(signingCredential);
        signature.setSignatureAlgorithm(RSA_SHA1_SIGNATURE_ALGORITHM);
        signature.setCanonicalizationAlgorithm(EXCLUSIVE_C14N_ALGORITHM);
        signableSAMLObject.setSignature(signature);
        try {
            // The object is marshalled with the signature attached before Signer.signObject is called.
            XMLObjectProviderRegistrySupport.getMarshallerFactory()
                    .getMarshaller(signableSAMLObject)
                    .marshall(signableSAMLObject);
            Signer.signObject(signature);
        } catch (MarshallingException | SignatureException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Requires the object to be signed, then validates the signature against the SAML signature
     * profile and against the given credential.
     */
    private static void verifySignature(SignableSAMLObject signedObject, Credential credential) {
        if (!signedObject.isSigned()) {
            throw new RuntimeException("The SAML object was not signed.");
        }
        Signature signature = signedObject.getSignature();
        try {
            new SAMLSignatureProfileValidator().validate(signature);
            SignatureValidator.validate(signature, credential);
        } catch (SignatureException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates a credential from a Base64 string holding an X.509-encoded ({@link X509EncodedKeySpec})
     * RSA public key.
     *
     * @param str Base64-encoded public key
     * @return a credential wrapping the decoded public key
     * @throws RuntimeException if the RSA key factory is unavailable or the key spec is invalid
     */
    public static Credential getCredentialFromEncodedRsaPublicKey(String str) {
        try {
            byte[] encodedKey = Base64.getDecoder().decode(str);
            return new BasicCredential(KeyFactory.getInstance("RSA").generatePublic(new X509EncodedKeySpec(encodedKey)));
        } catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
            logger.error("SAML KeyFactory failure", e);
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates a credential from a Base64 string holding an encoded X.509 certificate.
     *
     * @param str Base64-encoded certificate
     * @return a credential wrapping the parsed certificate
     * @throws RuntimeException if the certificate cannot be parsed
     */
    public static Credential getCredentialFromEncodedX509Certificate(String str) {
        try {
            byte[] encodedCertificate = Base64.getDecoder().decode(str);
            X509Certificate certificate = (X509Certificate) CertificateFactory.getInstance("X.509")
                    .generateCertificate(new ByteArrayInputStream(encodedCertificate));
            return new BasicX509Credential(certificate);
        } catch (CertificateException e) {
            logger.error("SAML CertificateFactory failure", e);
            throw new RuntimeException(e);
        }
    }

    /** Decodes the request with {@link HTTPPostDecoder} and returns the resulting message context. */
    private static MessageContext decodeResponse(HttpServletRequest request) {
        HTTPPostDecoder decoder = new HTTPPostDecoder();
        decoder.setHttpServletRequest(request);
        try {
            decoder.initialize();
            decoder.decode();
            return decoder.getMessageContext();
        } catch (ComponentInitializationException | MessageDecodingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds an {@link Issuer} element with the given value.
     *
     * @param str issuer value
     * @return the new issuer
     */
    public static Issuer buildIssuer(String str) {
        Issuer issuer = build(Issuer.DEFAULT_ELEMENT_NAME);
        issuer.setValue(str);
        return issuer;
    }

    /**
     * Streams the string values of every attribute named {@code str} across all attribute statements
     * of the assertion. Attribute values that are not {@link XSString} are skipped.
     *
     * @param str attribute name to match; must not be null
     * @param assertion assertion whose attribute statements are searched
     * @return a lazy stream of matching attribute values
     */
    public static Stream<String> streamAttributeValuesByName(String str, Assertion assertion) {
        return assertion.getAttributeStatements().stream()
                .flatMap(statement -> statement.getAttributes().stream())
                .filter(attribute -> str.equals(attribute.getName()))
                .flatMap(attribute -> attribute.getAttributeValues().stream())
                .filter(XSString.class::isInstance)
                .map(XSString.class::cast)
                .map(XSString::getValue);
    }

    /**
     * Resolves {@code contextPath + str} against the request URL and returns the resulting URL string.
     *
     * @param httpServletRequest request supplying the base URL and the context path
     * @param str path appended to the context path
     * @return the resolved URL as a string
     * @throws RuntimeException if either URL is malformed; the {@link MalformedURLException} is the cause
     */
    public static String getUrlInRequestContext(HttpServletRequest httpServletRequest, String str) {
        String requestUrl = httpServletRequest.getRequestURL().toString();
        try {
            return new URL(new URL(requestUrl), httpServletRequest.getContextPath() + str).toString();
        } catch (MalformedURLException e) {
            // Keep the cause so the underlying parse failure is not lost.
            throw new RuntimeException("Failed to build URL. context: " + requestUrl + " path " + str, e);
        }
    }

    /**
     * Logs the indented XML form of a SAML object at debug level, prefixed with {@code str}. The object
     * is marshalled first if it has no DOM yet. Does nothing when debug logging is disabled, so the
     * marshalling and XML transformation cost is only paid when the output is actually wanted.
     *
     * @param str prefix identifying the call site in the log line
     * @param sAMLObject object to log; may be null
     */
    public static void logSamlObject(String str, SAMLObject sAMLObject) {
        if (!logger.isDebugEnabled()) {
            return;
        }
        if (sAMLObject == null) {
            logger.debug("{} - SAMLObject is null", str);
            return;
        }
        Element dom = sAMLObject.getDOM();
        if (dom == null) {
            try {
                XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(sAMLObject).marshall(sAMLObject);
                dom = sAMLObject.getDOM();
            } catch (MarshallingException e) {
                logger.error("{} - Failed to marshall SAMLObject", str, e);
                return;
            }
        }
        try {
            Transformer transformer = TransformerFactory.newInstance().newTransformer();
            transformer.setOutputProperty("indent", "yes");
            StringWriter xml = new StringWriter();
            transformer.transform(new DOMSource(dom), new StreamResult(xml));
            logger.debug("{} - {}", str, xml);
        } catch (TransformerException e) {
            logger.error("{} - Failed to log SAML object.", str, e);
        }
    }

    /**
     * Builds a new XML object for the given element name using the registered OpenSAML builder.
     *
     * @param qName element name whose registered builder is used
     * @param <T> type the caller expects the built object to have
     * @return the newly built object
     */
    @SuppressWarnings("unchecked") // the caller chooses T to match the builder registered for qName
    public static <T> T build(QName qName) {
        return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName);
    }

    /**
     * Returns a new identifier from the shared {@link RandomIdentifierGenerationStrategy}.
     *
     * @return the generated identifier
     */
    public static String generateSecureRandomId() {
        return secureRandomIdGenerator.generateIdentifier();
    }

    /**
     * Checks via {@link Require#equals} that the lower-cased request scheme is {@code https}.
     *
     * @param httpServletRequest request whose scheme is checked
     */
    public static void throwUnlessHttps(HttpServletRequest httpServletRequest) {
        Require.equals("https", httpServletRequest.getScheme().toLowerCase(), "https is required for SAML authentication.");
    }
}
