package net.snowflake.ingest.internal.apache.parquet.crypto;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import javax.crypto.AEADBadTagException;
import javax.crypto.Cipher;
import javax.crypto.spec.GCMParameterSpec;
import net.snowflake.ingest.internal.apache.parquet.format.BlockCipher;

/* loaded from: input_file:net/snowflake/ingest/internal/apache/parquet/crypto/AesGcmDecryptor.class */
/**
 * AES-GCM implementation of {@link BlockCipher.Decryptor}.
 *
 * <p>Each ciphertext handled here is laid out as
 * {@code nonce (12 bytes) || encrypted data || GCM tag (16 bytes)}. The two-argument
 * {@code byte[]} and {@code ByteBuffer} entry points expect an extra 4-byte length prefix in front
 * of that layout and skip it without reading it. The {@code InputStream} entry point decodes the
 * prefix as a little-endian int and uses it to size the read.
 *
 * <p>NOTE(review): every call re-initialises the inherited {@code cipher} and overwrites the
 * inherited {@code localNonce}, so an instance is presumably not safe for concurrent use — confirm
 * before sharing one across threads.
 */
public class AesGcmDecryptor extends AesCipher implements BlockCipher.Decryptor {

    /** Size of the length prefix that precedes the nonce in the two-argument entry points. */
    private static final int LENGTH_PREFIX_BYTES = 4;

    /** Size of the nonce stored at the start of every ciphertext. */
    private static final int NONCE_BYTES = 12;

    /** Size of the GCM authentication tag that follows the encrypted data. */
    private static final int TAG_BYTES = 16;

    /** Tag size expressed in bits, the unit {@link GCMParameterSpec} expects. */
    private static final int TAG_BITS = TAG_BYTES * 8;

    /**
     * Creates a decryptor bound to the given AES key.
     *
     * <p>The decompiler reported that this constructor was originally package-private; it is kept
     * {@code public} so that existing callers continue to compile.
     *
     * @param keyBytes raw key material, passed unchanged to the {@code AesCipher} constructor
     * @throws ParquetCryptoRuntimeException if the JCE cannot provide the GCM cipher
     */
    public AesGcmDecryptor(byte[] keyBytes) {
        super(AesMode.GCM, keyBytes);
        try {
            cipher = Cipher.getInstance(AesMode.GCM.getCipherName());
        } catch (GeneralSecurityException e) {
            throw new ParquetCryptoRuntimeException("Failed to create GCM cipher", e);
        }
    }

    /**
     * Decrypts a length-prefixed ciphertext held in a single array.
     *
     * @param lengthAndCiphertext 4-byte prefix (skipped, not read) followed by nonce, data and tag
     * @param aad additional authenticated data, or {@code null} for none
     * @return a freshly allocated array holding the plaintext
     */
    @Override
    public byte[] decrypt(byte[] lengthAndCiphertext, byte[] aad) {
        int payloadLength = lengthAndCiphertext.length - LENGTH_PREFIX_BYTES;
        return decrypt(lengthAndCiphertext, LENGTH_PREFIX_BYTES, payloadLength, aad);
    }

    /**
     * Decrypts {@code length} bytes of {@code ciphertext} starting at {@code offset}.
     *
     * @param ciphertext array containing nonce, encrypted data and tag
     * @param offset index of the first nonce byte
     * @param length number of bytes covering nonce + encrypted data + tag
     * @param aad additional authenticated data, or {@code null} for none
     * @return a freshly allocated array holding the plaintext
     * @throws ParquetCryptoRuntimeException if {@code length} leaves no plaintext bytes, or the JCE
     *     reports any failure other than a tag mismatch
     * @throws TagVerificationException if the GCM tag does not verify
     */
    public byte[] decrypt(byte[] ciphertext, int offset, int length, byte[] aad) {
        int plaintextLength = plaintextLengthOf(length);
        System.arraycopy(ciphertext, offset, localNonce, 0, NONCE_BYTES);
        byte[] plaintext = new byte[plaintextLength];
        try {
            beginDecryption(aad);
            // Everything after the nonce (encrypted data + trailing tag) goes to the cipher.
            cipher.doFinal(ciphertext, offset + NONCE_BYTES, length - NONCE_BYTES, plaintext, 0);
        } catch (AEADBadTagException e) {
            throw new TagVerificationException("GCM tag check failed", e);
        } catch (GeneralSecurityException e) {
            throw new ParquetCryptoRuntimeException("Failed to decrypt", e);
        }
        return plaintext;
    }

    /**
     * Decrypts the length-prefixed ciphertext between the buffer's position and limit, in place.
     *
     * <p>The returned buffer is a slice that shares storage with {@code lengthAndCiphertext}. It
     * starts at position 0, and its limit equals the plaintext length. The input buffer's position
     * is advanced as it is consumed.
     *
     * @param lengthAndCiphertext 4-byte prefix (skipped, not read) followed by nonce, data and tag
     * @param aad additional authenticated data, or {@code null} for none
     * @return a view over the decrypted bytes
     */
    @Override
    public ByteBuffer decrypt(ByteBuffer lengthAndCiphertext, byte[] aad) {
        int start = lengthAndCiphertext.position();
        int available = lengthAndCiphertext.limit() - start;
        int plaintextLength = plaintextLengthOf(available - LENGTH_PREFIX_BYTES);

        lengthAndCiphertext.position(start + LENGTH_PREFIX_BYTES);
        lengthAndCiphertext.get(localNonce);

        // Output view over the remaining input: decryption writes over the ciphertext it reads.
        ByteBuffer plaintext = lengthAndCiphertext.slice();
        plaintext.limit(plaintextLength);
        try {
            beginDecryption(aad);
            cipher.doFinal(lengthAndCiphertext, plaintext);
        } catch (AEADBadTagException e) {
            throw new TagVerificationException("GCM tag check failed", e);
        } catch (GeneralSecurityException e) {
            throw new ParquetCryptoRuntimeException("Failed to decrypt", e);
        }
        plaintext.flip();
        return plaintext;
    }

    /**
     * Reads a little-endian 4-byte length followed by that many ciphertext bytes, then decrypts them.
     *
     * @param from stream positioned at the length prefix
     * @param aad additional authenticated data, or {@code null} for none
     * @return a freshly allocated array holding the plaintext
     * @throws IOException if the stream ends early or the decoded length is below 1
     */
    @Override
    public byte[] decrypt(InputStream from, byte[] aad) throws IOException {
        byte[] prefix = new byte[LENGTH_PREFIX_BYTES];
        int prefixRead = fill(from, prefix);
        if (prefixRead < prefix.length) {
            throw new IOException("Tried to read int (4 bytes), but only got " + prefixRead + " bytes.");
        }

        int ciphertextLength = littleEndianInt(prefix);
        if (ciphertextLength < 1) {
            throw new IOException("Wrong length of encrypted metadata: " + ciphertextLength);
        }

        byte[] ciphertext = new byte[ciphertextLength];
        int ciphertextRead = fill(from, ciphertext);
        if (ciphertextRead < ciphertextLength) {
            throw new IOException(
                    "Tried to read " + ciphertextLength + " bytes, but only got " + ciphertextRead + " bytes.");
        }
        return decrypt(ciphertext, 0, ciphertextLength, aad);
    }

    /**
     * Computes the plaintext size for a ciphertext of {@code ciphertextLength} bytes.
     *
     * @throws ParquetCryptoRuntimeException if that leaves fewer than one plaintext byte
     */
    private static int plaintextLengthOf(int ciphertextLength) {
        int plaintextLength = ciphertextLength - TAG_BYTES - NONCE_BYTES;
        if (plaintextLength < 1) {
            throw new ParquetCryptoRuntimeException("Wrong input length " + plaintextLength);
        }
        return plaintextLength;
    }

    /**
     * Re-initialises the inherited cipher for GCM decryption with the current {@code localNonce}.
     * Also feeds {@code aad} to the cipher when it is non-null.
     */
    private void beginDecryption(byte[] aad) throws GeneralSecurityException {
        cipher.init(Cipher.DECRYPT_MODE, aesKey, new GCMParameterSpec(TAG_BITS, localNonce));
        if (aad != null) {
            cipher.updateAAD(aad);
        }
    }

    /**
     * Reads from {@code in} into {@code target} until it is full or a read returns a non-positive count.
     *
     * @return the number of bytes stored into {@code target}
     */
    private static int fill(InputStream in, byte[] target) throws IOException {
        int filled = 0;
        while (filled < target.length) {
            int n = in.read(target, filled, target.length - filled);
            if (n <= 0) {
                break;
            }
            filled += n;
        }
        return filled;
    }

    /** Decodes the first four bytes of {@code bytes} as a little-endian 32-bit int. */
    private static int littleEndianInt(byte[] bytes) {
        int value = 0;
        for (int i = 3; i >= 0; i--) {
            value = (value << 8) | (bytes[i] & 0xFF);
        }
        return value;
    }
}
