package com.datarobot.mlops.common.spooler;

import com.amazonaws.AmazonClientException;
import com.amazonaws.AmazonServiceException;
import com.amazonaws.services.sqs.AmazonSQS;
import com.amazonaws.services.sqs.AmazonSQSClientBuilder;
import com.amazonaws.services.sqs.model.BatchResultErrorEntry;
import com.amazonaws.services.sqs.model.DeleteMessageBatchRequestEntry;
import com.amazonaws.services.sqs.model.Message;
import com.amazonaws.services.sqs.model.QueueAttributeName;
import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry;
import com.amazonaws.services.sqs.model.SendMessageBatchResult;
import com.amazonaws.services.sqs.model.SendMessageRequest;
import com.amazonaws.services.sqs.model.SetQueueAttributesRequest;
import com.datarobot.mlops.common.config.MappedConfig;
import com.datarobot.mlops.common.constants.ConfigConstants;
import com.datarobot.mlops.common.enums.DataFormat;
import com.datarobot.mlops.common.enums.SQSQueueType;
import com.datarobot.mlops.common.enums.SpoolerType;
import com.datarobot.mlops.common.exceptions.DRCommonException;
import com.datarobot.mlops.common.exceptions.DRQueueException;
import com.datarobot.mlops.common.records.Record;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/datarobot/mlops/common/spooler/SQSSpooler.class */
/**
 * {@link RecordSpooler} implementation backed by an AWS SQS queue.
 *
 * <p>Records are serialized to JSON and sent in batches of at most {@value #MAX_BATCH_NUMBER}
 * entries. The queue is addressed either by name (resolved to a URL in {@link #open()}) or
 * directly by URL. A queue whose URL ends in {@code .fifo} is treated as a FIFO queue and every
 * outgoing message is tagged with a message group id.
 */
public class SQSSpooler extends RecordSpooler {
    private static final Logger logger = LoggerFactory.getLogger(SQSSpooler.class);
    private static final String FIFO_SUFFIX = ".fifo";
    private static final String DEFAULT_MESSAGE_GROUP_ID = "MLOpsAgentGroup";
    /** Largest number of entries placed in a single SQS batch request. */
    private static final int MAX_BATCH_NUMBER = 10;
    private static final int DEFAULT_VISIBILITY_TIMEOUT_IN_SECOND = 120;
    private static final int DEFAULT_BATCH_NUM = 10;
    private static final int SQS_MESSAGE_SIZE_LIMIT_IN_BYTES = 256000;
    private static final String DEFAULT_SPOOLER_DATA_FORMAT = "JSON";
    private AmazonSQS sqsClient;
    private String queueName;
    private String queueUrl;
    private SQSQueueType queueType;
    private String messageGroupId;
    private int visibilityTimeout;

    /**
     * Creates the spooler; no connection is made until {@link #open()} is called.
     *
     * @param mappedConfig configuration handed to the base spooler
     */
    public SQSSpooler(MappedConfig mappedConfig) {
        super(mappedConfig);
    }

    /** @return {@link SpoolerType#SQS} */
    @Override
    public SpoolerType getType() {
        return SpoolerType.SQS;
    }

    /** @return an empty list; queue name and URL are alternatives, checked in {@link #verifyConfig()} */
    @Override
    public List<String> getRequiredConfigKeys() {
        return new ArrayList<>();
    }

    /** @return the configuration keys this spooler reads when present */
    @Override
    public List<String> getOptionalConfigKeys() {
        List<String> keys = new ArrayList<>();
        keys.add(ConfigConstants.SQS_QUEUE_NAME_STR);
        keys.add(ConfigConstants.SQS_QUEUE_URL_STR);
        keys.add(ConfigConstants.SQS_VISIBILITY_TIMEOUT_STR);
        return keys;
    }

    /**
     * Reads queue name, queue URL and visibility timeout from the configuration and validates them.
     *
     * @throws DRCommonException if a required key is missing, the data format is not JSON,
     *     neither queue name nor queue URL is configured, the URL is rejected by
     *     {@code MappedConfig.validateUrl}, or the visibility timeout is negative
     */
    @Override
    public void verifyConfig() throws DRCommonException {
        List<String> missingKeys = FindMissingConfigKeys();
        if (!missingKeys.isEmpty()) {
            throw new DRCommonException("Missing required configuration for: " + missingKeys);
        }
        this.queueName = this.config.getValueWithDefault(ConfigConstants.SQS_QUEUE_NAME_STR, "");
        this.queueUrl = this.config.getValueWithDefault(ConfigConstants.SQS_QUEUE_URL_STR, "");
        this.visibilityTimeout = this.config.getValueWithDefault(
                ConfigConstants.SQS_VISIBILITY_TIMEOUT_STR, DEFAULT_VISIBILITY_TIMEOUT_IN_SECOND);
        DataFormat dataFormat = DataFormat.fromString(this.config.getValueWithDefault(
                ConfigConstants.MLOPS_SPOOLER_DATA_FORMAT, DEFAULT_SPOOLER_DATA_FORMAT));
        if (dataFormat != DataFormat.JSON) {
            throw new DRCommonException("Data Format: '" + dataFormat + "' is not supported for the SQS spooler");
        }
        if (this.queueName.isEmpty() && this.queueUrl.isEmpty()) {
            throw new DRCommonException("Must configure either MLOPS_SQS_QUEUE_NAME or MLOPS_SQS_QUEUE_URL");
        }
        if (!this.queueUrl.isEmpty()) {
            MappedConfig.validateUrl(this.queueUrl);
        }
        if (this.visibilityTimeout < 0) {
            // Zero passes the check above, so the message states the accepted range as ">= 0".
            throw new DRCommonException("Visibility timeout should be >= 0; current value: " + this.visibilityTimeout);
        }
    }

    /**
     * Validates the configuration, builds the SQS client and prepares the queue for use.
     *
     * <p>Each stage reports its own failure; a failure after the client has been built shuts the
     * client down again so it is not leaked.
     *
     * @throws DRCommonException if the configuration is invalid, the client cannot be created,
     *     the queue URL cannot be resolved, or the visibility timeout cannot be updated
     */
    @Override
    public void open() throws DRCommonException {
        verifyConfig();
        try {
            this.sqsClient = AmazonSQSClientBuilder.standard().build();
        } catch (Exception e) {
            String message = "Failed to create AWS SQS client. Error: " + e.getMessage();
            logger.error(message);
            throw new DRCommonException(message);
        }
        try {
            connectToQueue();
        } catch (DRCommonException e) {
            close();
            throw e;
        }
    }

    /** Resolves the queue URL if needed, applies the visibility timeout and detects the queue type. */
    private void connectToQueue() throws DRCommonException {
        if (this.queueUrl == null || this.queueUrl.isEmpty()) {
            logger.debug("Looking up SQS queue name: '{}'", this.queueName);
            try {
                this.queueUrl = this.sqsClient.getQueueUrl(this.queueName).getQueueUrl();
                logger.info("SQS configured with queue name {}", this.queueName);
            } catch (Exception e) {
                String message = "Failed to find queue url for queueName '" + this.queueName + "'. Error: " + e.getMessage();
                logger.error(message);
                throw new DRCommonException(message);
            }
        }
        logger.info("SQS configured with queue url {}", this.queueUrl);
        try {
            updateVisibilityTimeout();
        } catch (Exception e) {
            String message = "Failed to update queue visibility. Error: " + e.getMessage();
            logger.error(message);
            throw new DRCommonException(message);
        }
        this.messageGroupId = this.config.getValueWithDefault(ConfigConstants.DEPLOYMENT_ID_STR, DEFAULT_MESSAGE_GROUP_ID);
        this.queueType = this.queueUrl.endsWith(FIFO_SUFFIX) ? SQSQueueType.FIFO : SQSQueueType.STANDARD;
        logger.info("Connection to SQS queue: '{}' successful", this.queueUrl);
    }

    /** @return the per-message size limit used by this spooler, in bytes */
    @Override
    public int getMessageByteSizeLimit() {
        return SQS_MESSAGE_SIZE_LIMIT_IN_BYTES;
    }

    /**
     * Serializes the records to JSON and sends them to the queue in batches.
     *
     * <p>Records that cannot be serialized are logged and skipped. Records rejected by SQS, and
     * all records not yet sent when a batch request throws, are reported through the thrown
     * {@link DRQueueException}.
     *
     * @param collection records to send
     * @return the number of records in {@code collection} minus those that failed to serialize
     * @throws DRQueueException if a batch request throws or SQS rejects at least one entry; the
     *     exception carries the records that were not enqueued
     */
    @Override
    public int enqueue(Collection<Record> collection) throws DRQueueException {
        ArrayList<Record> failedRecords = new ArrayList<>();
        List<BatchResultErrorEntry> sqsErrors = new ArrayList<>();
        List<SendMessageBatchRequestEntry> batch = new ArrayList<>(MAX_BATCH_NUMBER);
        HashMap<String, Record> recordsByEntryId = new HashMap<>();
        int serializationFailures = 0;

        Iterator<Record> remaining = collection.iterator();
        while (remaining.hasNext()) {
            Record item = remaining.next();
            String json;
            try {
                json = item.toJson();
            } catch (DRCommonException e) {
                serializationFailures++;
                logger.error("Failed to serialize data record, Error: " + e.getMessage());
                continue;
            }
            SendMessageBatchRequestEntry entry = getSendMessageBatchRequestEntry(json);
            recordsByEntryId.put(entry.getId(), item);
            batch.add(entry);
            if (batch.size() == MAX_BATCH_NUMBER) {
                flushBatch(batch, recordsByEntryId, remaining, sqsErrors, failedRecords);
            }
        }
        // Flush the trailing partial batch here rather than by counting records inside the loop:
        // a skipped (unserializable) record must not prevent the last batch from being sent.
        if (!batch.isEmpty()) {
            flushBatch(batch, recordsByEntryId, remaining, sqsErrors, failedRecords);
        }

        int sqsFailures = sqsErrors.size();
        if (sqsFailures > 0 || serializationFailures > 0) {
            int totalFailures = serializationFailures + sqsFailures;
            logger.error("Failed to send " + totalFailures + " messages to SQS queue - " + this.queueUrl);
            logger.debug("Failed to send {} messages to SQS queue - {}, {} of them failed is because failed to serialize the data record, {} of them failed is because errors in AWS SQS",
                    totalFailures, this.queueUrl, serializationFailures, sqsFailures);
            if (sqsFailures > 0) {
                logger.debug("Detailed error about each message");
                for (BatchResultErrorEntry error : sqsErrors) {
                    logger.debug("Message id {}, with error code {} and error is {}",
                            error.getId(), error.getCode(), error.getMessage());
                }
                throw new DRQueueException("Failed to enqueue records", failedRecords);
            }
        }
        return collection.size() - serializationFailures;
    }

    /**
     * Sends the pending batch and records which of its entries SQS rejected. The batch and its
     * id-to-record map are always emptied afterwards.
     *
     * @param batch entries to send
     * @param recordsByEntryId source record of each entry in {@code batch}, keyed by entry id
     * @param unsent iterator over the records that have not been batched yet
     * @param sqsErrors receives the per-entry errors reported by SQS
     * @param failedRecords receives the records that were not enqueued
     * @throws DRQueueException if the batch request throws; carries the records of this batch and
     *     every record still left in {@code unsent}
     */
    private void flushBatch(List<SendMessageBatchRequestEntry> batch, HashMap<String, Record> recordsByEntryId,
            Iterator<Record> unsent, List<BatchResultErrorEntry> sqsErrors, ArrayList<Record> failedRecords)
            throws DRQueueException {
        try {
            List<BatchResultErrorEntry> rejected = sendMessageBatch(batch).getFailed();
            sqsErrors.addAll(rejected);
            // Only this batch's errors are resolved: the id map holds entries of the current batch only.
            for (BatchResultErrorEntry error : rejected) {
                failedRecords.add(recordsByEntryId.get(error.getId()));
            }
        } catch (Exception e) {
            logger.error(describeFailure(e, "send batch messages to", "Enqueue"));
            failedRecords.addAll(recordsByEntryId.values());
            unsent.forEachRemaining(failedRecords::add);
            throw new DRQueueException("Failed to enqueue records", failedRecords);
        } finally {
            batch.clear();
            recordsByEntryId.clear();
        }
    }

    /**
     * Receives messages from the queue and converts them to records.
     *
     * <p>Messages whose body cannot be deserialized are logged and skipped. Each converted record
     * is registered with {@code addPendingRecord} together with its originating message.
     *
     * @return the records read from the queue; possibly empty
     * @throws DRQueueException if receiving fails; carries the records converted so far
     */
    @Override
    public Collection<Record> dequeue() throws DRQueueException {
        ArrayList<Record> records = new ArrayList<>();
        try {
            for (Message message : receiveMessages()) {
                try {
                    Record parsed = Record.fromJson(message.getBody());
                    records.add(parsed);
                    addPendingRecord(parsed.getId(), message);
                } catch (DRCommonException e) {
                    logger.error("Failed to deserialize data record, Error: " + e.getMessage());
                }
            }
            return records;
        } catch (Exception e) {
            String message = describeFailure(e, "receive messages from", "Dequeue");
            logger.error(message);
            throw new DRQueueException(message, records);
        }
    }

    /**
     * Deletes from the queue the messages behind the given record ids, when acknowledgement is
     * enabled. Ids with no pending message are ignored.
     *
     * @param collection ids of the records to acknowledge
     */
    @Override
    public void ackRecords(Collection<String> collection) throws DRQueueException {
        if (!this.enableDequeueAckRecord) {
            return;
        }
        List<Message> acknowledged = collection.stream()
                .map(recordId -> (Message) this.recordsPendingAck.remove(recordId))
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
        deleteMessageBatch(acknowledged);
    }

    /** Shuts down the SQS client if one is open; safe to call repeatedly. */
    @Override
    public void close() {
        if (this.sqsClient != null) {
            logger.info("Shutting down SQS client");
            this.sqsClient.shutdown();
            this.sqsClient = null;
        }
    }

    /** @return always {@code true} */
    @Override
    public boolean needsRetry() {
        return true;
    }

    /** Sends a single message body to the queue. */
    private void sendMessage(String str) {
        this.sqsClient.sendMessage(getSendMessageRequest(str));
    }

    /** Sends the given entries to the queue in one batch request. */
    private SendMessageBatchResult sendMessageBatch(List<SendMessageBatchRequestEntry> list) {
        return this.sqsClient.sendMessageBatch(this.queueUrl, list);
    }

    /**
     * Issues one receive request against the queue. When acknowledgement is disabled the received
     * messages are deleted from the queue immediately.
     */
    private List<Message> receiveMessages() {
        List<Message> messages = this.sqsClient.receiveMessage(new ReceiveMessageRequest(this.queueUrl)).getMessages();
        if (!this.enableDequeueAckRecord) {
            deleteMessageBatch(messages);
        }
        return messages;
    }

    /** Sets the queue's {@code VisibilityTimeout} attribute to the given value. */
    private void setVisibilityTimeOut(String str) {
        SetQueueAttributesRequest request = new SetQueueAttributesRequest()
                .withQueueUrl(this.queueUrl)
                .addAttributesEntry(QueueAttributeName.VisibilityTimeout.toString(), str);
        this.sqsClient.setQueueAttributes(request);
    }

    /** Deletes the given messages from the queue in batches of at most {@value #MAX_BATCH_NUMBER}. */
    private void deleteMessageBatch(List<Message> list) {
        if (list == null || list.isEmpty()) {
            return;
        }
        List<DeleteMessageBatchRequestEntry> entries = new ArrayList<>(MAX_BATCH_NUMBER);
        for (Message message : list) {
            entries.add(new DeleteMessageBatchRequestEntry(message.getMessageId(), message.getReceiptHandle()));
            if (entries.size() == MAX_BATCH_NUMBER) {
                this.sqsClient.deleteMessageBatch(this.queueUrl, entries);
                entries.clear();
            }
        }
        if (!entries.isEmpty()) {
            this.sqsClient.deleteMessageBatch(this.queueUrl, entries);
        }
    }

    /**
     * Applies the configured visibility timeout to the queue, but only when the attributes
     * returned for the queue contain no {@code VisibilityTimeout} entry.
     *
     * <p>NOTE(review): if SQS always reports this attribute for an existing queue, the configured
     * value is never applied - confirm whether that is intended.
     *
     * @throws DRCommonException if the attribute lookup or update fails
     */
    public void updateVisibilityTimeout() throws DRCommonException {
        String attributeName = QueueAttributeName.VisibilityTimeout.toString();
        try {
            boolean alreadySet = this.sqsClient
                    .getQueueAttributes(this.queueUrl, Collections.singletonList(attributeName))
                    .getAttributes()
                    .containsKey(attributeName);
            if (!alreadySet) {
                setVisibilityTimeOut(String.valueOf(this.visibilityTimeout));
            }
        } catch (AmazonServiceException e) {
            String message = "Failed to set message visibility timeout to AWS SQS queue (" + this.queueUrl + "), Error: " + e.getErrorMessage();
            logger.error(message);
            throw new DRCommonException(message);
        } catch (AmazonClientException e) {
            String message = "Failed to connect to AWS SQS queue (" + this.queueUrl + "), Error: " + e.getMessage();
            logger.error(message);
            throw new DRCommonException(message);
        }
    }

    /**
     * Builds the error text shared by enqueue and dequeue failures.
     *
     * @param e the failure
     * @param action what was being attempted, phrased to fit "Failed to ACTION AWS SQS queue"
     * @param operation operation name used for failures that are not AWS exceptions
     */
    private String describeFailure(Exception e, String action, String operation) {
        // AmazonServiceException is tested first, as in the original branch order.
        if (e instanceof AmazonServiceException) {
            AmazonServiceException serviceError = (AmazonServiceException) e;
            return "Failed to " + action + " AWS SQS queue (" + this.queueUrl + "), Error: "
                    + serviceError.getMessage() + ", " + serviceError.getErrorMessage();
        }
        if (e instanceof AmazonClientException) {
            return "Failed to connect to AWS SQS queue (" + this.queueUrl + "), Error: " + e.getMessage();
        }
        return operation + " Failed: " + e.getMessage();
    }

    /** Builds a single-message request, adding the message group id for FIFO queues. */
    private SendMessageRequest getSendMessageRequest(String str) {
        SendMessageRequest request = new SendMessageRequest(this.queueUrl, str);
        if (this.queueType == SQSQueueType.FIFO) {
            request.setMessageGroupId(this.messageGroupId);
        }
        return request;
    }

    /** Builds a batch entry with a random UUID as id, adding the message group id for FIFO queues. */
    private SendMessageBatchRequestEntry getSendMessageBatchRequestEntry(String str) {
        SendMessageBatchRequestEntry entry = new SendMessageBatchRequestEntry(UUID.randomUUID().toString(), str);
        if (this.queueType == SQSQueueType.FIFO) {
            entry.setMessageGroupId(this.messageGroupId);
        }
        return entry;
    }
}
