package com.databricks.jdbc.api.impl.arrow;

import com.databricks.jdbc.api.impl.arrow.ArrowResultChunk;
import com.databricks.jdbc.api.internal.IDatabricksSession;
import com.databricks.jdbc.common.DatabricksClientType;
import com.databricks.jdbc.dbclient.impl.common.StatementId;
import com.databricks.jdbc.exception.DatabricksSQLException;
import com.databricks.jdbc.exception.DatabricksValidationException;
import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
import com.databricks.jdbc.model.core.ExternalLink;
import java.time.Instant;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

/* loaded from: input_file:com/databricks/jdbc/api/impl/arrow/ChunkLinkDownloadService.class */
/**
 * Supplies external download links for the chunks of one statement's result. Links are fetched in
 * batches from the Databricks client, and each link is exposed as a {@link CompletableFuture}
 * keyed by chunk index.
 *
 * <p>Batches are fetched with {@link CompletableFuture#runAsync(Runnable)}. After a non-empty
 * batch, the next request starts at the highest returned chunk index plus one. The chain continues
 * until that index reaches {@code totalChunks}.
 *
 * <p>Suppose a requested chunk's link has expired and the chunk has not yet been downloaded
 * successfully. Then every future from the smallest such index onward is replaced, and the chain
 * restarts from that index.
 *
 * <p>Shared state is kept in atomics, volatile fields and a {@link ConcurrentHashMap}. Resets are
 * serialized on {@code resetLock}.
 */
public class ChunkLinkDownloadService {
    private static final JdbcLogger LOGGER = JdbcLoggerFactory.getLogger(ChunkLinkDownloadService.class);
    private final IDatabricksSession session;
    private final StatementId statementId;
    private final long totalChunks;
    /** One future per chunk index in {@code [0, totalChunks)}, completed once its link is fetched. */
    private final Map<Long, CompletableFuture<ExternalLink>> chunkIndexToLinkFuture;
    /** Chunk index the next batch request starts from. */
    private final AtomicLong nextBatchStartIndex;
    /** True while a batch request has been scheduled and has not yet finished. */
    private final AtomicBoolean isDownloadInProgress;
    /** Set once the batch chain has been started; cleared again when a reset is prepared. */
    private final AtomicBoolean isDownloadChainStarted;
    /**
     * Incremented on every reset. A batch task compares it with the value captured when the task
     * was scheduled, and discards its results if the two differ. This is needed because
     * {@link CompletableFuture#cancel(boolean)} does not stop a task that is already running.
     */
    private final AtomicLong downloadGeneration = new AtomicLong();
    private volatile boolean isShutdown;
    private volatile CompletableFuture<Void> currentDownloadTask;
    private final Object resetLock = new Object();
    /** Result chunks by index; consulted so links of already-downloaded chunks are not reset. */
    private final Map<Long, ArrowResultChunk> chunkIndexToChunksMap;

    /**
     * Creates the service. For {@link DatabricksClientType#SEA} sessions it immediately starts
     * fetching links.
     *
     * @param session session whose client is used to fetch links
     * @param statementId statement whose result chunks are being fetched
     * @param totalChunks total number of chunks in the result
     * @param chunkIndexToChunksMap result chunks by index, used to check their download status
     * @param startIndex chunk index the first batch request starts from
     */
    public ChunkLinkDownloadService(
            IDatabricksSession session,
            StatementId statementId,
            long totalChunks,
            Map<Long, ArrowResultChunk> chunkIndexToChunksMap,
            long startIndex) {
        LOGGER.info(
                "Initializing ChunkLinkDownloadService for statement %s with total chunks: %d, starting at index: %d",
                statementId, totalChunks, startIndex);
        this.session = session;
        this.statementId = statementId;
        this.totalChunks = totalChunks;
        this.nextBatchStartIndex = new AtomicLong(startIndex);
        this.isDownloadInProgress = new AtomicBoolean(false);
        this.isDownloadChainStarted = new AtomicBoolean(false);
        this.isShutdown = false;
        this.chunkIndexToLinkFuture = new ConcurrentHashMap<>();
        for (long chunkIndex = 0; chunkIndex < totalChunks; chunkIndex++) {
            this.chunkIndexToLinkFuture.put(chunkIndex, new CompletableFuture<>());
        }
        this.chunkIndexToChunksMap = chunkIndexToChunksMap;
        // SEA sessions start fetching eagerly; other client types wait for the first getLinkForChunk.
        if (session.getConnectionContext().getClientType() == DatabricksClientType.SEA
                && this.isDownloadChainStarted.compareAndSet(false, true)) {
            LOGGER.info("Auto-triggering download chain for SEA client type");
            triggerNextBatchDownload();
        }
    }

    /**
     * Returns the future for the link of the given chunk. If the batch chain has not started yet,
     * this starts it. If the chunk's link has expired before its download succeeded, the chain is
     * reset first.
     *
     * @param chunkIndex zero-based chunk index
     * @return future completed with the chunk's link. The future is already failed with a
     *     {@link DatabricksValidationException} if the service is shut down or the index is outside
     *     {@code [0, totalChunks)}.
     * @throws ExecutionException declared for compatibility with existing callers; not thrown by
     *     this implementation
     * @throws InterruptedException declared for compatibility with existing callers; not thrown by
     *     this implementation
     */
    public CompletableFuture<ExternalLink> getLinkForChunk(long chunkIndex) throws ExecutionException, InterruptedException {
        if (isShutdown) {
            LOGGER.warn("Attempt to get link for chunk %d while chunk download service is shutdown", chunkIndex);
            return createExceptionalFuture(new DatabricksValidationException("Chunk Link Download Service is shutdown"));
        }
        if (chunkIndex < 0) {
            // Without this guard a negative index reached a null future lookup and threw an NPE.
            LOGGER.error("Requested chunk index %d is negative", chunkIndex);
            return createExceptionalFuture(new DatabricksValidationException("Chunk index must be non-negative"));
        }
        if (chunkIndex >= totalChunks) {
            LOGGER.error("Requested chunk index %d exceeds total chunks %d", chunkIndex, totalChunks);
            return createExceptionalFuture(new DatabricksValidationException("Chunk index exceeds total chunks"));
        }
        LOGGER.debug("Getting link for chunk %d", chunkIndex);
        handleExpiredLinksAndReset(chunkIndex);
        if (isDownloadChainStarted.compareAndSet(false, true)) {
            LOGGER.info("Initiating first download chain for chunk %d", chunkIndex);
            triggerNextBatchDownload();
        }
        return chunkIndexToLinkFuture.get(chunkIndex);
    }

    /**
     * Marks the service as shut down. Every link future not yet completed is failed with a
     * {@link DatabricksValidationException}.
     */
    public void shutdown() {
        LOGGER.info("Shutting down ChunkLinkDownloadService for statement %s", statementId);
        isShutdown = true;
        chunkIndexToLinkFuture.forEach((chunkIndex, linkFuture) -> {
            if (!linkFuture.isDone()) {
                LOGGER.debug("Completing future for chunk %d exceptionally due to shutdown", chunkIndex);
                linkFuture.completeExceptionally(new DatabricksValidationException("Service was shut down"));
            }
        });
    }

    /**
     * Schedules an asynchronous fetch of the next batch of links. Does nothing if the service is
     * shut down, a batch is already in progress, or every chunk index has been requested.
     */
    private void triggerNextBatchDownload() {
        if (isShutdown || !isDownloadInProgress.compareAndSet(false, true)) {
            LOGGER.debug("Skipping batch download - Service shutdown: %s, Download in progress: %s", isShutdown, isDownloadInProgress.get());
            return;
        }
        long batchStartIndex = nextBatchStartIndex.get();
        if (batchStartIndex >= totalChunks) {
            LOGGER.info("No more chunks to download. Current index: %d, Total chunks: %d", batchStartIndex, totalChunks);
            isDownloadInProgress.set(false);
            return;
        }
        LOGGER.info("Starting batch download from index %d", batchStartIndex);
        long generation = downloadGeneration.get();
        currentDownloadTask = CompletableFuture.runAsync(() -> downloadBatch(batchStartIndex, generation));
    }

    /**
     * Body of one batch task. It fetches the links starting at {@code batchStartIndex`}, completes
     * the matching futures, then chains the next batch.
     *
     * @param batchStartIndex chunk index the batch request starts from
     * @param generation value of {@code downloadGeneration} when the task was scheduled
     */
    private void downloadBatch(long batchStartIndex, long generation) {
        try {
            Collection<ExternalLink> links = session.getDatabricksClient().getResultChunks(statementId, batchStartIndex);
            LOGGER.info("Retrieved %d links for batch starting at %d for statement id %s", links.size(), batchStartIndex, statementId);
            // A reset happened while this fetch was running. The restarted chain owns the futures
            // and indices now. A task that passed this check just before a reset may still publish.
            if (isSuperseded(generation)) {
                LOGGER.debug("Discarding links for batch starting at %d: superseded by a reset", batchStartIndex);
                return;
            }
            if (links.isEmpty()) {
                // Clear the flag so the chain is not left marked as in progress forever.
                LOGGER.warn("No links returned for batch starting at %d for statement id %s", batchStartIndex, statementId);
                isDownloadInProgress.set(false);
                return;
            }
            long maxChunkIndex = Long.MIN_VALUE;
            for (ExternalLink link : links) {
                CompletableFuture<ExternalLink> linkFuture = chunkIndexToLinkFuture.get(link.getChunkIndex());
                if (linkFuture != null) {
                    LOGGER.debug("Completing future for chunk %d for statement id %s", link.getChunkIndex(), statementId);
                    linkFuture.complete(link);
                }
                maxChunkIndex = Math.max(maxChunkIndex, link.getChunkIndex());
            }
            long nextIndex = maxChunkIndex + 1;
            nextBatchStartIndex.set(nextIndex);
            LOGGER.debug("Updated next batch start index to %d", nextIndex);
            isDownloadInProgress.set(false);
            if (nextIndex < totalChunks) {
                LOGGER.debug("Triggering next batch download");
                triggerNextBatchDownload();
            }
        } catch (Exception e) {
            // Top of an async task. Anything not handled here would be swallowed by the discarded
            // future, leaving link futures pending and the in-progress flag set.
            handleBatchDownloadError(batchStartIndex, generation, e);
        }
    }

    /**
     * Fails every pending link future with {@code error} and clears the in-progress flag. Does
     * nothing except log if the failing task was superseded by a reset.
     */
    private void handleBatchDownloadError(long batchStartIndex, long generation, Exception error) {
        if (isSuperseded(generation)) {
            LOGGER.debug("Ignoring failure of superseded batch starting at %d: %s", batchStartIndex, error.getMessage());
            return;
        }
        LOGGER.error(error, "Failed to download links for batch starting at %d: %s", batchStartIndex, error.getMessage());
        chunkIndexToLinkFuture.forEach((chunkIndex, linkFuture) -> {
            if (!linkFuture.isDone()) {
                LOGGER.debug("Completing future for chunk %d exceptionally due to batch download error", chunkIndex);
                linkFuture.completeExceptionally(error);
            }
        });
        isDownloadInProgress.set(false);
    }

    /** Returns true if a reset has happened since a task captured {@code generation}. */
    private boolean isSuperseded(long generation) {
        return downloadGeneration.get() != generation;
    }

    /** Returns a future that is already completed exceptionally with {@code exc}. */
    private CompletableFuture<ExternalLink> createExceptionalFuture(Exception exc) {
        CompletableFuture<ExternalLink> failed = new CompletableFuture<>();
        failed.completeExceptionally(exc);
        return failed;
    }

    /**
     * Handles a requested chunk whose link has expired before its download succeeded. It finds
     * the smallest chunk index in that state, invalidates the running batch task, replaces the
     * futures from that index onward, and prepares the chain to restart there.
     */
    private void handleExpiredLinksAndReset(long chunkIndex) {
        synchronized (resetLock) {
            if (!isChunkLinkExpiredForPendingDownload(chunkIndex)) {
                return;
            }
            LOGGER.info("Detected expired link for chunk %d, re-triggering batch download from the smallest index with the expired link", chunkIndex);
            // Start at 0 so an expired link on the very first chunk is refreshed as well.
            for (long index = 0; index < totalChunks; index++) {
                if (isChunkLinkExpiredForPendingDownload(index)) {
                    LOGGER.info("Found the smallest index %d with the expired link, initiating reset", index);
                    downloadGeneration.incrementAndGet();
                    cancelCurrentDownloadTask();
                    resetFuturesFromIndex(index);
                    prepareNewBatchDownload(index);
                    return;
                }
            }
        }
    }

    /**
     * Returns true if the chunk's link future holds an expired link and the chunk has not been
     * downloaded successfully. Pending or failed futures hold no link to inspect and yield false.
     */
    private boolean isChunkLinkExpiredForPendingDownload(long chunkIndex) {
        CompletableFuture<ExternalLink> linkFuture = chunkIndexToLinkFuture.get(chunkIndex);
        // Reading a failed or cancelled future would throw, so such futures are skipped.
        if (linkFuture == null || !linkFuture.isDone() || linkFuture.isCompletedExceptionally()) {
            return false;
        }
        if (!isChunkLinkExpired(linkFuture.getNow(null))) {
            return false;
        }
        ArrowResultChunk chunk = chunkIndexToChunksMap.get(chunkIndex);
        return chunk != null && chunk.getStatus() != ArrowResultChunk.ChunkStatus.DOWNLOAD_SUCCEEDED;
    }

    /**
     * Cancels the batch task currently tracked, if any, and stops tracking it.
     *
     * <p>{@link CompletableFuture#cancel(boolean)} only marks the future done. The running fetch
     * keeps going, which is why the caller also bumps the download generation.
     */
    private void cancelCurrentDownloadTask() {
        CompletableFuture<Void> task = currentDownloadTask;
        if (task == null || task.isDone()) {
            return;
        }
        LOGGER.debug("Cancelling current download task");
        task.cancel(true);
        try {
            task.get(100L, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            // Preserve the caller's interrupt status instead of silently clearing it.
            Thread.currentThread().interrupt();
            LOGGER.trace("Expected exception while cancelling download task: %s", e.getMessage());
        } catch (Exception e) {
            LOGGER.trace("Expected exception while cancelling download task: %s", e.getMessage());
        }
        currentDownloadTask = null;
    }

    /**
     * Replaces the link futures for every chunk index from {@code startIndex} up to
     * {@code totalChunks} with fresh ones. Any replaced future that was still pending is cancelled.
     */
    private void resetFuturesFromIndex(long startIndex) {
        LOGGER.info("Resetting futures from index %d", startIndex);
        for (long index = startIndex; index < totalChunks; index++) {
            CompletableFuture<ExternalLink> previous = chunkIndexToLinkFuture.get(index);
            if (previous != null && !previous.isDone()) {
                LOGGER.debug("Cancelling future for chunk %d", index);
                previous.cancel(true);
            }
            chunkIndexToLinkFuture.put(index, new CompletableFuture<>());
        }
    }

    /**
     * Points the next batch at {@code startIndex} and clears the in-progress and chain-started
     * flags. The next {@link #getLinkForChunk(long)} call then restarts the chain.
     */
    private void prepareNewBatchDownload(long startIndex) {
        LOGGER.info("Preparing new batch download from index %d", startIndex);
        nextBatchStartIndex.set(startIndex);
        isDownloadInProgress.set(false);
        isDownloadChainStarted.set(false);
    }

    /**
     * Returns true if the link is null, has no expiration, or expires within
     * {@code ArrowResultChunk.SECONDS_BUFFER_FOR_EXPIRY} seconds from now. The expiration string
     * is parsed with {@link Instant#parse(CharSequence)}.
     */
    private boolean isChunkLinkExpired(ExternalLink externalLink) {
        if (externalLink == null || externalLink.getExpiration() == null) {
            LOGGER.warn("Link or expiration is null, assuming link is expired");
            return true;
        }
        Instant expiryWithBuffer = Instant.parse(externalLink.getExpiration())
                .minusSeconds(ArrowResultChunk.SECONDS_BUFFER_FOR_EXPIRY.intValue());
        return expiryWithBuffer.isBefore(Instant.now());
    }
}
