package com.databricks.jdbc.auth;

import com.databricks.jdbc.TestConstants;
import com.databricks.jdbc.auth.JwtPrivateKeyClientCredentials;
import com.databricks.jdbc.dbclient.IDatabricksHttpClient;
import com.databricks.jdbc.exception.DatabricksHttpException;
import com.databricks.sdk.core.DatabricksException;
import com.databricks.sdk.core.oauth.Token;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jwt.SignedJWT;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.spec.ECParameterSpec;
import java.util.Map;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpUriRequest;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;

@ExtendWith({MockitoExtension.class})
/* loaded from: input_file:com/databricks/jdbc/auth/JwtPrivateKeyClientCredentialsTest.class */
public class JwtPrivateKeyClientCredentialsTest {

    @Mock
    CloseableHttpResponse httpResponse;

    @Mock
    HttpEntity httpEntity;

    @Mock
    RSAPrivateKey rsaPrivateKey;

    @Mock
    ECPrivateKey ecPrivateKey;

    @Mock
    ECParameterSpec parameterSpec;

    @Mock
    IDatabricksHttpClient httpClient;
    private JwtPrivateKeyClientCredentials clientCredentials = new JwtPrivateKeyClientCredentials.Builder().withHttpClient(this.httpClient).withClientId(TestConstants.TEST_CLIENT_ID).withJwtKid(TestConstants.TEST_JWT_KID).withJwtKeyFile(TestConstants.TEST_JWT_KEY_FILE).withJwtAlgorithm(TestConstants.TEST_JWT_ALGORITHM).withTokenUrl(TestConstants.TEST_TOKEN_URL).build();

    @ParameterizedTest
    @CsvSource({"RS384,RS384", "RS512,RS512", "PS256,PS256", "PS384,PS384", "PS512,PS512", "RS256,RS256", "ES384,ES384", "ES512,ES512", "ES256,ES256", "null,RS256", "HS256,RS256"})
    public void testDetermineSignatureAlgorithm(String str, JWSAlgorithm jWSAlgorithm) {
        Assertions.assertEquals(jWSAlgorithm, this.clientCredentials.determineSignatureAlgorithm(str));
    }

    @Test
    public void testRetrieveTokenExceptionHandling() throws DatabricksHttpException {
        Mockito.when(this.httpClient.execute((HttpUriRequest) ArgumentMatchers.any())).thenThrow(new Throwable[]{new DatabricksHttpException("Network error")});
        Assertions.assertTrue(((Exception) Assertions.assertThrows(DatabricksException.class, () -> {
            JwtPrivateKeyClientCredentials jwtPrivateKeyClientCredentials = this.clientCredentials;
            JwtPrivateKeyClientCredentials.retrieveToken(this.httpClient, TestConstants.TEST_TOKEN_URL, Map.of(), Map.of());
        })).getMessage().contains("Failed to retrieve custom M2M token"));
    }

    @Test
    public void testRetrieveToken() throws DatabricksHttpException, IOException {
        Mockito.when(this.httpClient.execute((HttpUriRequest) ArgumentMatchers.any())).thenReturn(this.httpResponse);
        Mockito.when(this.httpResponse.getEntity()).thenReturn(this.httpEntity);
        Mockito.when(this.httpEntity.getContent()).thenReturn(new ByteArrayInputStream(TestConstants.TEST_OAUTH_RESPONSE.getBytes()));
        JwtPrivateKeyClientCredentials jwtPrivateKeyClientCredentials = this.clientCredentials;
        Token retrieveToken = JwtPrivateKeyClientCredentials.retrieveToken(this.httpClient, TestConstants.TEST_TOKEN_URL, Map.of(), Map.of());
        Assertions.assertEquals(retrieveToken.getAccessToken(), TestConstants.TEST_ACCESS_TOKEN);
        Assertions.assertEquals(retrieveToken.getTokenType(), "Bearer");
    }

    @Test
    void testFetchSignedJWTWithRSAKey() throws Exception {
        Mockito.when(this.rsaPrivateKey.getAlgorithm()).thenReturn("RSA");
        Mockito.when(this.rsaPrivateKey.getModulus()).thenReturn(new BigInteger(2048, new SecureRandom()).setBit(2047));
        Mockito.when(this.rsaPrivateKey.getPrivateExponent()).thenReturn(new BigInteger(10, new SecureRandom()));
        SignedJWT fetchSignedJWT = this.clientCredentials.fetchSignedJWT(this.rsaPrivateKey);
        Assertions.assertNotNull(fetchSignedJWT);
        Assertions.assertEquals(TestConstants.TEST_CLIENT_ID, fetchSignedJWT.getJWTClaimsSet().getSubject());
        Assertions.assertEquals(TestConstants.TEST_CLIENT_ID, fetchSignedJWT.getJWTClaimsSet().getIssuer());
        Assertions.assertEquals(TestConstants.TEST_TOKEN_URL, fetchSignedJWT.getJWTClaimsSet().getAudience().get(0));
    }
}
