package com.amazonaws.athena.connector.lambda.data;

import com.amazonaws.athena.connector.lambda.data.BlockWriter;
import com.amazonaws.athena.connector.lambda.domain.predicate.ConstraintEvaluator;
import com.amazonaws.athena.connector.lambda.domain.spill.S3SpillLocation;
import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
import com.amazonaws.athena.connector.lambda.security.AesGcmBlockCrypto;
import com.amazonaws.athena.connector.lambda.security.BlockCrypto;
import com.amazonaws.athena.connector.lambda.security.EncryptionKey;
import com.amazonaws.athena.connector.lambda.security.NoOpBlockCrypto;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.io.ByteStreams;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.StampedLock;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;

/* loaded from: input_file:com/amazonaws/athena/connector/lambda/data/S3BlockSpiller.class */
/**
 * {@link BlockSpiller} implementation that accumulates rows into an in-progress {@link Block} and,
 * once that block grows past {@code SpillConfig.getMaxBlockBytes()}, encrypts it (when an
 * encryption key is configured) and writes it to S3.
 *
 * <p>Spilling happens on the calling thread when {@code SpillConfig.getNumSpillThreads() <= 0};
 * otherwise blocks are handed to a bounded thread pool. Each asynchronous spill holds the read side
 * of {@link #spillLock} until it finishes, so acquiring the write side waits for every outstanding
 * spill to complete.
 */
public class S3BlockSpiller implements AutoCloseable, BlockSpiller {
    private static final Logger logger = LoggerFactory.getLogger(S3BlockSpiller.class);
    /** How long {@link #close()} waits for queued spills before forcing the pool down. */
    private static final long ASYNC_SHUTDOWN_MILLIS = 10000;
    /** Default upper bound on the number of rows a single {@code writeRows} call may produce. */
    private static final int MAX_ROWS_PER_CALL = 100;
    private static final String SPILL_QUEUE_CAPACITY = "SPILL_QUEUE_CAPACITY";
    /** Lower-case spelling of {@link #SPILL_QUEUE_CAPACITY}; a literal avoids locale-sensitive case conversion. */
    private static final String SPILL_QUEUE_CAPACITY_LOWER = "spill_queue_capacity";
    private static final String SPILL_PUT_REQUEST_HEADERS_ENV = "spill_put_request_headers";
    /** Shared mapper used to parse the optional put-request headers JSON. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    private final S3Client amazonS3;
    private final BlockCrypto blockCrypto;
    private final BlockAllocator allocator;
    private final SpillConfig spillConfig;
    private final Schema schema;
    private final long maxRowsPerCall;
    /** Appended to from spill-pool threads that only share the read lock, hence synchronized. */
    private final List<SpillLocation> spillLocations;
    private final AtomicReference<Block> inProgressBlock;
    /** {@code null} when spilling is synchronous. */
    private final ExecutorService asyncSpillPool;
    private final ReadWriteLock spillLock;
    private final AtomicLong spillNumber;
    /** First failure recorded by {@link #write(Block)}; rethrown to the caller later. */
    private final AtomicReference<RuntimeException> asyncException;
    private final ConstraintEvaluator constraintEvaluator;
    private final AtomicLong totalBytesSpilled;
    private final long startTime;
    private final Map<String, String> configOptions;

    /**
     * Creates a spiller that allows at most {@value #MAX_ROWS_PER_CALL} rows per
     * {@link #writeRows(BlockWriter.RowWriter)} call.
     *
     * @param s3Client            client used to write (and read) spilled blocks; must not be null
     * @param spillConfig         spill settings (location, encryption key, sizes, threads); must not be null
     * @param blockAllocator      allocator used to create blocks; must not be null
     * @param schema              schema of the blocks to create; must not be null
     * @param constraintEvaluator evaluator attached to every block created by this spiller
     * @param map                 configuration options; {@code null} is treated as empty
     */
    public S3BlockSpiller(S3Client s3Client, SpillConfig spillConfig, BlockAllocator blockAllocator, Schema schema, ConstraintEvaluator constraintEvaluator, Map<String, String> map) {
        this(s3Client, spillConfig, blockAllocator, schema, constraintEvaluator, MAX_ROWS_PER_CALL, map);
    }

    /**
     * Creates a spiller with an explicit per-call row limit.
     *
     * @param s3Client            client used to write (and read) spilled blocks; must not be null
     * @param spillConfig         spill settings (location, encryption key, sizes, threads); must not be null
     * @param blockAllocator      allocator used to create blocks; must not be null
     * @param schema              schema of the blocks to create; must not be null
     * @param constraintEvaluator evaluator attached to every block created by this spiller
     * @param i                   maximum number of rows a single {@code writeRows} call may generate
     * @param map                 configuration options; {@code null} is treated as empty
     * @throws NullPointerException if a required argument is null
     */
    public S3BlockSpiller(S3Client s3Client, SpillConfig spillConfig, BlockAllocator blockAllocator, Schema schema, ConstraintEvaluator constraintEvaluator, int i, Map<String, String> map) {
        this.spillLocations = Collections.synchronizedList(new ArrayList<>());
        this.inProgressBlock = new AtomicReference<>();
        this.spillLock = new StampedLock().asReadWriteLock();
        this.spillNumber = new AtomicLong(0L);
        this.asyncException = new AtomicReference<>(null);
        this.totalBytesSpilled = new AtomicLong();
        this.startTime = System.currentTimeMillis();
        // Must be assigned before makeAsyncSpillPool(...) runs, because that method reads it.
        this.configOptions = map != null ? map : Collections.emptyMap();
        this.amazonS3 = Objects.requireNonNull(s3Client, "amazonS3 was null");
        this.spillConfig = Objects.requireNonNull(spillConfig, "spillConfig was null");
        this.allocator = Objects.requireNonNull(blockAllocator, "allocator was null");
        this.schema = Objects.requireNonNull(schema, "schema was null");
        this.blockCrypto = spillConfig.getEncryptionKey() != null ? new AesGcmBlockCrypto(blockAllocator) : new NoOpBlockCrypto(blockAllocator);
        this.asyncSpillPool = spillConfig.getNumSpillThreads() <= 0 ? null : makeAsyncSpillPool(spillConfig);
        this.maxRowsPerCall = i;
        this.constraintEvaluator = constraintEvaluator;
    }

    /**
     * @return the constraint evaluator supplied at construction time
     */
    @Override // com.amazonaws.athena.connector.lambda.data.BlockSpiller, com.amazonaws.athena.connector.lambda.data.BlockWriter
    public ConstraintEvaluator getConstraintEvaluator() {
        return this.constraintEvaluator;
    }

    /**
     * Lets {@code rowWriter} append rows to the in-progress block and spills that block once its
     * size exceeds {@code SpillConfig.getMaxBlockBytes()}.
     *
     * @param rowWriter callback that writes rows starting at the block's current row count and
     *                  returns the number of rows it wrote
     * @throws RuntimeException if the writer reports more rows than the per-call limit, if spilling
     *                          fails, or (wrapping the cause) if the writer throws a checked exception
     */
    @Override // com.amazonaws.athena.connector.lambda.data.BlockWriter
    public void writeRows(BlockWriter.RowWriter rowWriter) {
        ensureInit();
        Block block = this.inProgressBlock.get();
        int rowCount = block.getRowCount();
        try {
            int rowsWritten = rowWriter.writeRows(block, rowCount);
            if (rowsWritten > this.maxRowsPerCall) {
                throw new RuntimeException("Call generated more than " + this.maxRowsPerCall + "rows. Generating too many rows per call to writeRows(...) can result in blocks that exceed the max size.");
            }
            if (rowsWritten > 0) {
                block.setRowCount(rowCount + rowsWritten);
            }
            if (block.getSize() > this.spillConfig.getMaxBlockBytes()) {
                logger.info("writeRow: Spilling block with {} rows and {} bytes and config {} bytes", block.getRowCount(), block.getSize(), this.spillConfig.getMaxBlockBytes());
                spillBlock(block);
                this.inProgressBlock.set(this.allocator.createBlock(this.schema));
                this.inProgressBlock.get().constrain(this.constraintEvaluator);
            }
        } catch (RuntimeException e) {
            // Unchecked failures (row-limit violation, spill errors) must reach the caller unchanged.
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reports whether results must be fetched via {@link #getSpillLocations()} rather than
     * {@link #getBlock()}. Blocks until outstanding asynchronous spills have released the lock.
     *
     * @return {@code true} if at least one block was spilled, or the in-progress block is at least
     *         {@code SpillConfig.getMaxInlineBlockSize()} bytes
     * @throws RuntimeException the recorded write failure, if any
     */
    @Override // com.amazonaws.athena.connector.lambda.data.BlockSpiller
    public boolean spilled() {
        RuntimeException failure = this.asyncException.get();
        if (failure != null) {
            throw failure;
        }
        Lock writeLock = this.spillLock.writeLock();
        // Acquire outside the try so a failed acquisition never triggers an unlock.
        writeLock.lock();
        try {
            ensureInit();
            Block block = this.inProgressBlock.get();
            return !this.spillLocations.isEmpty() || block.getSize() >= this.spillConfig.getMaxInlineBlockSize();
        } finally {
            writeLock.unlock();
        }
    }

    /**
     * @return the in-progress block, when nothing has spilled
     * @throws RuntimeException if {@link #spilled()} is {@code true}
     */
    @Override // com.amazonaws.athena.connector.lambda.data.BlockSpiller
    public Block getBlock() {
        if (spilled()) {
            throw new RuntimeException("Blocks have spilled, calls to getBlock not permitted. use getSpillLocations instead.");
        }
        Block block = this.inProgressBlock.get();
        logger.info("getBlock: Inline Block size[{}] bytes vs {}", block.getSize(), this.spillConfig.getMaxInlineBlockSize());
        return block;
    }

    /**
     * Flushes any rows left in the in-progress block, waits for outstanding spills, and returns
     * the locations written so far.
     *
     * @return the list of spill locations recorded by this spiller
     * @throws RuntimeException if nothing has spilled, or if a block write failed
     */
    @Override // com.amazonaws.athena.connector.lambda.data.BlockSpiller
    public List<SpillLocation> getSpillLocations() {
        if (!spilled()) {
            throw new RuntimeException("Blocks have not spilled, calls to getSpillLocations not permitted. use getBlock instead.");
        }
        // Flush BEFORE taking the write lock: in async mode spillBlock(...) takes the read lock on
        // this thread, and the StampedLock-backed views are not reentrant.
        Block block = this.inProgressBlock.get();
        if (block.getRowCount() > 0) {
            logger.info("getSpillLocations: Spilling final block with {} rows and {} bytes and config {} bytes", block.getRowCount(), block.getSize(), this.spillConfig.getMaxBlockBytes());
            spillBlock(block);
            this.inProgressBlock.set(this.allocator.createBlock(this.schema));
            this.inProgressBlock.get().constrain(this.constraintEvaluator);
        }
        Lock writeLock = this.spillLock.writeLock();
        // Blocks until every async spill has released its read lock.
        writeLock.lock();
        try {
            // A spill that failed after the spilled() check above would otherwise be silently
            // missing from the returned list.
            RuntimeException failure = this.asyncException.get();
            if (failure != null) {
                throw failure;
            }
            return this.spillLocations;
        } finally {
            writeLock.unlock();
        }
    }

    /**
     * Logs spill totals and, when an async pool exists, shuts it down: waits up to
     * {@value #ASYNC_SHUTDOWN_MILLIS} ms for queued spills, then forces termination.
     */
    @Override // java.lang.AutoCloseable, com.amazonaws.athena.connector.lambda.data.BlockSpiller
    public void close() {
        logger.info("close: Spilled a total of {} bytes in {} ms", this.totalBytesSpilled.get(), System.currentTimeMillis() - this.startTime);
        if (this.asyncSpillPool == null) {
            return;
        }
        this.asyncSpillPool.shutdown();
        try {
            if (!this.asyncSpillPool.awaitTermination(ASYNC_SHUTDOWN_MILLIS, TimeUnit.MILLISECONDS)) {
                this.asyncSpillPool.shutdownNow();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            this.asyncSpillPool.shutdownNow();
        }
    }

    /**
     * Parses the {@code spill_put_request_headers} config option as a JSON object of string pairs.
     *
     * @return the parsed map, or an empty map when the option is absent, empty, or not valid JSON
     */
    private Map<String, String> getRequestHeadersFromEnv() {
        String headersJson = this.configOptions.get(SPILL_PUT_REQUEST_HEADERS_ENV);
        if (headersJson == null || headersJson.isEmpty()) {
            return Collections.emptyMap();
        }
        try {
            return OBJECT_MAPPER.readValue(headersJson, new TypeReference<Map<String, String>>() {
            });
        } catch (JsonProcessingException e) {
            logger.error(String.format("Invalid value for environment variable: %s : %s", SPILL_PUT_REQUEST_HEADERS_ENV, headersJson), e);
            return Collections.emptyMap();
        }
    }

    /**
     * Encrypts {@code block} and uploads it to a freshly numbered spill location.
     *
     * @param block the block to spill
     * @return the location the block was written to
     * @throws RuntimeException on any failure; the first such failure is also recorded so that
     *                          later calls on the spiller can surface it
     */
    protected SpillLocation write(Block block) {
        try {
            S3SpillLocation spillLocation = makeSpillLocation();
            EncryptionKey encryptionKey = this.spillConfig.getEncryptionKey();
            logger.info("write: Started encrypting block for write to {}", spillLocation);
            byte[] bytes = this.blockCrypto.encrypt(encryptionKey, block);
            this.totalBytesSpilled.addAndGet(bytes.length);
            logger.info("write: Started spilling block of size {} bytes", bytes.length);
            PutObjectRequest request = PutObjectRequest.builder()
                    .bucket(spillLocation.getBucket())
                    .key(spillLocation.getKey())
                    .contentLength(Long.valueOf(bytes.length))
                    .metadata(getRequestHeadersFromEnv())
                    .build();
            this.amazonS3.putObject(request, RequestBody.fromBytes(bytes));
            logger.info("write: Completed spilling block of size {} bytes", bytes.length);
            return spillLocation;
        } catch (RuntimeException e) {
            // Keep only the first failure; it is rethrown by spilled()/getSpillLocations().
            this.asyncException.compareAndSet(null, e);
            logger.warn("write: Encountered error while writing block.", e);
            throw e;
        }
    }

    /**
     * Downloads and decrypts a previously spilled block.
     *
     * @param s3SpillLocation bucket/key of the spilled block
     * @param encryptionKey   key passed to the block crypto for decryption
     * @param schema          schema used to deserialize the block
     * @return the decrypted block
     * @throws RuntimeException wrapping any {@link IOException} raised while reading the object
     */
    protected Block read(S3SpillLocation s3SpillLocation, EncryptionKey encryptionKey, Schema schema) {
        logger.debug("read: Started reading block from S3");
        GetObjectRequest request = GetObjectRequest.builder()
                .bucket(s3SpillLocation.getBucket())
                .key(s3SpillLocation.getKey())
                .build();
        // try-with-resources so the response stream is always closed.
        try (ResponseInputStream<GetObjectResponse> object = this.amazonS3.getObject(request)) {
            logger.debug("read: Completed reading block from S3");
            Block decrypted = this.blockCrypto.decrypt(encryptionKey, ByteStreams.toByteArray(object), schema);
            logger.debug("read: Completed decrypting block of size.");
            return decrypted;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Writes {@code block} to S3 and closes it, either inline or on the async pool. In async mode
     * the read lock is taken here and released by the task when it finishes.
     */
    private void spillBlock(Block block) {
        if (this.asyncSpillPool == null) {
            this.spillLocations.add(write(block));
            safeClose(block);
            return;
        }
        Lock readLock = this.spillLock.readLock();
        readLock.lock();
        try {
            this.asyncSpillPool.submit(() -> {
                try {
                    this.spillLocations.add(write(block));
                    safeClose(block);
                } finally {
                    readLock.unlock();
                }
            });
        } catch (RuntimeException e) {
            // The task never started, so nobody else will release the read lock.
            readLock.unlock();
            throw e;
        }
    }

    /** Lazily creates the first in-progress block and attaches the constraint evaluator. */
    private void ensureInit() {
        if (this.inProgressBlock.get() == null) {
            this.inProgressBlock.set(this.allocator.createBlock(this.schema));
            this.inProgressBlock.get().constrain(this.constraintEvaluator);
        }
    }

    /**
     * Derives the next spill object location by appending {@code "." + sequenceNumber} to the
     * configured directory key.
     *
     * @throws RuntimeException if the configured spill location is not a directory
     */
    private S3SpillLocation makeSpillLocation() {
        S3SpillLocation base = (S3SpillLocation) this.spillConfig.getSpillLocation();
        if (!base.isDirectory()) {
            throw new RuntimeException("Split's SpillLocation must be a directory because multiple blocks may be spilled.");
        }
        String key = base.getKey() + "." + this.spillNumber.getAndIncrement();
        return new S3SpillLocation(base.getBucket(), key, false);
    }

    /** Closes {@code autoCloseable}, rethrowing any failure as a {@link RuntimeException}. */
    private void safeClose(AutoCloseable autoCloseable) {
        try {
            autoCloseable.close();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the fixed-size spill pool. The work queue is bounded: its capacity defaults to the
     * thread count and can be overridden with the {@code SPILL_QUEUE_CAPACITY} (or lower-case)
     * config option. When the queue is full the submitting thread blocks until space frees up.
     *
     * @throws NumberFormatException if the configured capacity is not an integer
     */
    private ThreadPoolExecutor makeAsyncSpillPool(SpillConfig spillConfig) {
        int numThreads = spillConfig.getNumSpillThreads();
        int queueCapacity = numThreads;
        String configured = this.configOptions.get(SPILL_QUEUE_CAPACITY);
        if (StringUtils.isBlank(configured)) {
            configured = this.configOptions.get(SPILL_QUEUE_CAPACITY_LOWER);
        }
        // Blank values fall back to the default instead of failing in parseInt("").
        if (StringUtils.isNotBlank(configured)) {
            queueCapacity = Integer.parseInt(configured.trim());
            logger.debug("Setting Spill Queue Capacity to {}", queueCapacity);
        }
        return new ThreadPoolExecutor(numThreads, numThreads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(queueCapacity), (runnable, executor) -> {
            // NOTE(review): a task rejected after shutdown is dropped silently, so the read lock
            // taken for it in spillBlock(...) is never released — confirm this is intended.
            if (executor.isShutdown()) {
                return;
            }
            try {
                // Apply back-pressure: block the submitter until the queue has room.
                executor.getQueue().put(runnable);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RejectedExecutionException("Received an exception while submitting spillBlock task: ", e);
            }
        });
    }
}
