package com.datarobot.mlops.spooler.rabbitmq;

import com.datarobot.mlops.common.config.MappedConfig;
import com.datarobot.mlops.common.constants.ConfigConstants;
import com.datarobot.mlops.common.enums.DataFormat;
import com.datarobot.mlops.common.enums.SpoolerType;
import com.datarobot.mlops.common.exceptions.DRCommonException;
import com.datarobot.mlops.common.exceptions.DRQueueException;
import com.datarobot.mlops.common.records.Record;
import com.datarobot.mlops.common.spooler.RecordSpooler;
import com.rabbitmq.client.Channel;
import com.rabbitmq.client.Connection;
import com.rabbitmq.client.ConnectionFactory;
import com.rabbitmq.client.Envelope;
import com.rabbitmq.client.GetResponse;
import com.rabbitmq.client.MessageProperties;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.TimeoutException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/datarobot/mlops/spooler/rabbitmq/RabbitMQSpooler.class */
/**
 * {@link RecordSpooler} that publishes records to, and polls records from, a RabbitMQ queue.
 *
 * <p>Records are serialized to JSON (the only {@link DataFormat} accepted by
 * {@link #verifyConfig()}) and published as persistent text messages to the default exchange,
 * using the configured queue name as the routing key. {@link #dequeue()} polls with
 * {@code basicGet} and returns at most {@link #RABBITMQ_MAX_RECORDS_TO_DEQUEUE} records per call.
 * When {@code enableDequeueAckRecord} is set, messages are fetched without auto-ack and must be
 * settled through {@link #ackRecords(Collection)} or {@link #nackRecords(Collection)}.
 *
 * <p>TLS is used only when the CA certificate, client certificate and key file paths are all
 * configured.
 */
public class RabbitMQSpooler extends RecordSpooler {
    private static final Logger logger = LoggerFactory.getLogger(RabbitMQSpooler.class);

    /** Maximum number of messages fetched by a single {@link #dequeue()} call. */
    public static final int RABBITMQ_MAX_RECORDS_TO_DEQUEUE = 10;
    /** Message size limit reported by {@link #getMessageByteSizeLimit()}: 50 MiB. */
    private static final int RABBITMQ_MESSAGE_SIZE_LIMIT_IN_BYTE = 50 * 1024 * 1024;
    private static final String DEFAULT_SPOOLER_DATA_FORMAT = "JSON";
    private static final String DEFAULT_TLS_VERSION = "TLSv1.2";
    /** Replaces the password part of a queue URL wherever the URL is logged or reported. */
    private static final String REDACTED_PASSWORD = "****";

    private Channel channel;
    private String queueName;
    private String queueUrl;
    private Connection connection;
    private boolean sslEnabled;
    private String sslCaCertificatePath;
    private String sslCertificatePath;
    private String sslKeyfilePath;
    private String sslTlsVersion;

    /** Settles (acks or nacks) a single delivery on the current channel. */
    @FunctionalInterface
    private interface DeliveryTagAction {
        void apply(long deliveryTag) throws IOException;
    }

    public RabbitMQSpooler(MappedConfig mappedConfig) {
        super(mappedConfig);
        this.sslEnabled = false;
    }

    /** Returns {@link SpoolerType#RABBITMQ}. */
    @Override
    public SpoolerType getType() {
        return SpoolerType.RABBITMQ;
    }

    /** Returns a new mutable list with the queue URL and queue name keys. */
    @Override
    public List<String> getRequiredConfigKeys() {
        return new ArrayList<>(Arrays.asList(
                ConfigConstants.RABBITMQ_QUEUE_URL_STR,
                ConfigConstants.RABBITMQ_QUEUE_NAME_STR));
    }

    /** Returns a new mutable list with the TLS-related configuration keys. */
    @Override
    public List<String> getOptionalConfigKeys() {
        return new ArrayList<>(Arrays.asList(
                ConfigConstants.RABBITMQ_SSL_CA_CERTIFICATE_PATH_STR,
                ConfigConstants.RABBITMQ_SSL_CERTIFICATE_PATH_STR,
                ConfigConstants.RABBITMQ_SSL_KEYFILE_PATH_STR,
                ConfigConstants.RABBITMQ_SSL_TLS_VERSION_STR));
    }

    /**
     * Validates the configuration and caches the queue and TLS settings on this instance.
     *
     * <p>SSL is enabled only when the CA certificate, certificate and key file paths are all
     * non-empty; a partial SSL configuration is reported with a warning and SSL stays disabled.
     *
     * @throws DRCommonException if a required key is missing or empty, or the configured spooler
     *     data format is not JSON
     */
    @Override
    public void verifyConfig() throws DRCommonException {
        List<String> missingKeys = FindMissingConfigKeys();
        if (!missingKeys.isEmpty()) {
            throw new DRCommonException("Missing required configuration for: " + missingKeys);
        }
        this.queueUrl = this.config.getStringValue(ConfigConstants.RABBITMQ_QUEUE_URL_STR);
        this.queueName = this.config.getStringValue(ConfigConstants.RABBITMQ_QUEUE_NAME_STR);
        if (isNullOrEmpty(this.queueUrl)) {
            throw new DRCommonException("Missing queueUrl for RabbitMQ channel");
        }
        if (isNullOrEmpty(this.queueName)) {
            throw new DRCommonException("Missing queueName for RabbitMQ channel");
        }
        DataFormat dataFormat = DataFormat.fromString(this.config.getValueWithDefault(
                ConfigConstants.MLOPS_SPOOLER_DATA_FORMAT, DEFAULT_SPOOLER_DATA_FORMAT));
        if (dataFormat != DataFormat.JSON) {
            // Concatenation instead of toString() keeps this path NPE-free should fromString
            // ever return null for an unrecognised format.
            throw new DRCommonException(
                    "Data Format: '" + dataFormat + "' is not supported for the Rabbit MQ spooler");
        }
        this.sslCaCertificatePath = this.config.getValueWithDefault(
                ConfigConstants.RABBITMQ_SSL_CA_CERTIFICATE_PATH_STR, (String) null);
        this.sslCertificatePath = this.config.getValueWithDefault(
                ConfigConstants.RABBITMQ_SSL_CERTIFICATE_PATH_STR, (String) null);
        this.sslKeyfilePath = this.config.getValueWithDefault(
                ConfigConstants.RABBITMQ_SSL_KEYFILE_PATH_STR, (String) null);
        this.sslTlsVersion = this.config.getValueWithDefault(
                ConfigConstants.RABBITMQ_SSL_TLS_VERSION_STR, DEFAULT_TLS_VERSION);

        boolean hasCaCertificate = !isNullOrEmpty(this.sslCaCertificatePath);
        boolean hasCertificate = !isNullOrEmpty(this.sslCertificatePath);
        boolean hasKeyfile = !isNullOrEmpty(this.sslKeyfilePath);
        this.sslEnabled = hasCaCertificate && hasCertificate && hasKeyfile;
        if (!this.sslEnabled && (hasCaCertificate || hasCertificate || hasKeyfile)) {
            // Previously a partial SSL setup silently fell back to a plain connection.
            logger.warn("Incomplete RabbitMQ SSL configuration: {}, {} and {} must all be set; SSL is disabled",
                    ConfigConstants.RABBITMQ_SSL_CA_CERTIFICATE_PATH_STR,
                    ConfigConstants.RABBITMQ_SSL_CERTIFICATE_PATH_STR,
                    ConfigConstants.RABBITMQ_SSL_KEYFILE_PATH_STR);
        }
    }

    /** Returns the per-message size limit (50 MiB). */
    @Override
    public int getMessageByteSizeLimit() {
        return RABBITMQ_MESSAGE_SIZE_LIMIT_IN_BYTE;
    }

    /**
     * Publishes each record as a persistent UTF-8 JSON message to the configured queue.
     *
     * <p>Records that fail to serialize are logged and skipped. On the first publish failure,
     * the failing record and every record not yet attempted are reported in the thrown exception.
     *
     * @param collection records to publish
     * @return number of records that serialized successfully (and were therefore published)
     * @throws DRQueueException if publishing fails; carries the unsent records
     */
    @Override
    public int enqueue(Collection<Record> collection) throws DRQueueException {
        int serializationFailures = 0;
        Iterator<Record> it = collection.iterator();
        while (it.hasNext()) {
            Record record = it.next();

            // Keep this try scoped to serialization only, so a DRQueueException raised for a
            // publish failure below can never be mistaken for a serialization error.
            String json;
            try {
                json = record.toJson();
            } catch (DRCommonException e) {
                logger.error(String.format("Failed to serialize data record. Error %s", e.getMessage()));
                serializationFailures++;
                continue;
            }

            byte[] payload = json.getBytes(StandardCharsets.UTF_8);
            try {
                logger.debug("Sending message: size {}", payload.length);
                this.channel.basicPublish("", this.queueName, MessageProperties.PERSISTENT_TEXT_PLAIN, payload);
            } catch (Exception e) {
                logger.error("Failed to send messages to RabbitMQ queue. Error: {}", e.getMessage(), e);
                List<Record> unsent = new ArrayList<>();
                unsent.add(record);
                it.forEachRemaining(unsent::add);
                throw new DRQueueException("Failed to enqueue records", unsent);
            }
        }
        return collection.size() - serializationFailures;
    }

    /**
     * Polls up to {@link #RABBITMQ_MAX_RECORDS_TO_DEQUEUE} messages and deserializes them.
     *
     * <p>Polling stops early once {@code basicGet} returns no message. When ack mode is enabled,
     * each returned record's envelope is registered via {@code addPendingRecord} for a later
     * ack/nack. A message that cannot be deserialized is logged and skipped. In ack mode it is
     * also rejected without requeue, so it is not left unacknowledged on the channel; in
     * auto-ack mode the broker has already removed it.
     *
     * @return the records read in this call (possibly empty)
     * @throws DRQueueException if talking to the broker fails; carries the records read so far
     */
    @Override
    public Collection<Record> dequeue() throws DRQueueException {
        List<Record> records = new ArrayList<>();
        boolean autoAck = !this.enableDequeueAckRecord;
        for (int i = 0; i < RABBITMQ_MAX_RECORDS_TO_DEQUEUE; i++) {
            try {
                GetResponse response = this.channel.basicGet(this.queueName, autoAck);
                if (response == null) {
                    break;
                }
                byte[] body = response.getBody();
                logger.debug("Got a message from queue of size: {}", body.length);

                Record record;
                try {
                    record = Record.fromJson(new String(body, StandardCharsets.UTF_8));
                } catch (DRCommonException e) {
                    logger.error("Failed to de-serialized dequeued record: " + e.getMessage());
                    if (!autoAck) {
                        // Undecodable payload: redelivering it would only fail again.
                        this.channel.basicReject(response.getEnvelope().getDeliveryTag(), false);
                    }
                    continue;
                }
                records.add(record);
                addPendingRecord(record.getId(), response.getEnvelope());
            } catch (Exception e) {
                String message = "Failed to receive messages from RabbitMQ queue. Error: " + e.getMessage();
                logger.error(message, e);
                throw new DRQueueException(message, records);
            }
        }
        updateEmptyCount(records.size());
        return records;
    }

    /**
     * Acknowledges the given dequeued records (no-op unless ack mode is enabled).
     *
     * @param collection ids of records previously returned by {@link #dequeue()}
     * @throws DRQueueException if an ack fails; carries the failing id and all ids not yet acked
     */
    @Override
    public void ackRecords(Collection<String> collection) throws DRQueueException {
        settleRecords(collection, "ack", deliveryTag -> this.channel.basicAck(deliveryTag, false));
    }

    /**
     * Negatively acknowledges the given dequeued records with requeue, so the broker redelivers
     * them (no-op unless ack mode is enabled).
     *
     * @param collection ids of records previously returned by {@link #dequeue()}
     * @throws DRQueueException if a nack fails; carries the failing id and all ids not yet nacked
     */
    @Override
    public void nackRecords(Collection<String> collection) throws DRQueueException {
        settleRecords(collection, "nack", deliveryTag -> this.channel.basicNack(deliveryTag, false, true));
    }

    /**
     * Applies {@code action} to the delivery tag of each pending record and drops it from the
     * pending map on success. Ids with no pending envelope are ignored.
     */
    private void settleRecords(Collection<String> recordIds, String actionName, DeliveryTagAction action)
            throws DRQueueException {
        if (!this.enableDequeueAckRecord) {
            return;
        }
        Iterator<String> it = recordIds.iterator();
        while (it.hasNext()) {
            String recordId = it.next();
            try {
                Envelope envelope = (Envelope) this.recordsPendingAck.get(recordId);
                if (envelope != null) {
                    action.apply(envelope.getDeliveryTag());
                    logger.debug("Confirm {} for: {}", actionName, recordId);
                    this.recordsPendingAck.remove(recordId);
                }
            } catch (Exception e) {
                logger.error("Failed to send {} for record {}", actionName, recordId, e);
                List<String> unsettled = new ArrayList<>();
                unsettled.add(recordId);
                it.forEachRemaining(unsettled::add);
                throw new DRQueueException(unsettled, "Failed to send " + actionName);
            }
        }
    }

    /**
     * Validates the configuration, connects to the broker, opens a channel, declares the queue
     * (durable, non-exclusive, non-auto-delete) and sets a prefetch of 1.
     *
     * <p>On failure any connection/channel opened so far is closed before the exception is
     * thrown. The password of the queue URL is redacted in all log and error messages.
     *
     * @throws DRCommonException if the configuration is invalid or the connection cannot be set up
     */
    @Override
    public void open() throws DRCommonException {
        verifyConfig();
        String displayUrl = redactCredentials(this.queueUrl);
        ConnectionFactory connectionFactory = new ConnectionFactory();
        try {
            logger.info("Connecting to " + displayUrl + " with queueName '" + this.queueName + "'");
            connectionFactory.setUri(this.queueUrl);
            if (this.sslEnabled) {
                connectionFactory.useSslProtocol(SSLContextBuilder.buildContextFromCertificates(
                        UUID.randomUUID().toString(),
                        this.sslCaCertificatePath,
                        this.sslCertificatePath,
                        this.sslKeyfilePath,
                        this.sslTlsVersion));
                connectionFactory.enableHostnameVerification();
            }
            this.connection = connectionFactory.newConnection();
            this.channel = this.connection.createChannel();
            this.channel.queueDeclare(this.queueName, true, false, false, null);
            this.channel.basicQos(1);
            logger.info("Connection to " + displayUrl + " with queueName '" + this.queueName + "' successful");
            logger.info("Enable ack dequeue: " + this.enableDequeueAckRecord);
        } catch (IOException | URISyntaxException | KeyManagementException | NoSuchAlgorithmException
                | TimeoutException e) {
            // Don't leak a connection that was opened before a later setup step failed.
            close();
            String message = "Error initializing RabbitMQ Spooler with queueUrl '" + displayUrl
                    + "' and queueName '" + this.queueName + "'."
                    + (e.getMessage() != null ? " " + e.getMessage() : "");
            logger.error(message);
            logger.debug("Details:", e);
            throw new DRCommonException(message);
        }
    }

    /**
     * Closes the channel and then the connection, best-effort: failures are logged as warnings.
     * The connection is closed even if closing the channel throws.
     */
    @Override
    public void close() {
        try {
            closeChannel();
        } finally {
            closeConnection();
        }
    }

    /** Closes and forgets the channel if one is open; close errors are logged, not thrown. */
    private void closeChannel() {
        Channel openChannel = this.channel;
        this.channel = null;
        if (openChannel == null || !openChannel.isOpen()) {
            return;
        }
        logger.info("Closing channel for queue: '{}'", this.queueName);
        try {
            openChannel.close();
        } catch (IOException | TimeoutException e) {
            logCloseFailure("channel", e);
        }
    }

    /** Closes and forgets the connection if one is open; close errors are logged, not thrown. */
    private void closeConnection() {
        Connection openConnection = this.connection;
        this.connection = null;
        if (openConnection == null || !openConnection.isOpen()) {
            return;
        }
        logger.info("Closing connection to: '{}'", redactCredentials(this.queueUrl));
        try {
            openConnection.close();
        } catch (IOException e) {
            logCloseFailure("connection", e);
        }
    }

    /** Logs a failure to close a RabbitMQ resource, distinguishing timeouts. */
    private static void logCloseFailure(String resource, Exception e) {
        if (e instanceof TimeoutException) {
            logger.warn("RabbitMQ {} closed with timeout error: {}", resource, e.getMessage());
        } else {
            logger.warn("RabbitMQ {} closed with error: {}", resource, e.getMessage());
        }
    }

    /** Always {@code true}: failed RabbitMQ operations should be retried by the caller. */
    @Override
    public boolean needsRetry() {
        return true;
    }

    private static boolean isNullOrEmpty(String value) {
        return value == null || value.isEmpty();
    }

    /**
     * Returns {@code url} with the password of its user-info part (if any) replaced by
     * {@link #REDACTED_PASSWORD}, e.g. {@code amqp://user:secret@host/vh} becomes
     * {@code amqp://user:****@host/vh}. URLs without a password are returned unchanged.
     * Works on the raw string, so URLs that are not valid URIs are still redacted.
     */
    private static String redactCredentials(String url) {
        if (url == null) {
            return null;
        }
        int schemeSeparator = url.indexOf("://");
        if (schemeSeparator < 0) {
            return url;
        }
        int authorityStart = schemeSeparator + 3;
        // The authority ends at the first path, query or fragment delimiter.
        int authorityEnd = url.length();
        for (char delimiter : new char[] {'/', '?', '#'}) {
            int index = url.indexOf(delimiter, authorityStart);
            if (index >= 0 && index < authorityEnd) {
                authorityEnd = index;
            }
        }
        // Last '@' so that an unencoded '@' inside the password is still masked.
        int userInfoEnd = url.lastIndexOf('@', authorityEnd - 1);
        if (userInfoEnd < authorityStart) {
            return url;
        }
        int passwordSeparator = url.indexOf(':', authorityStart);
        if (passwordSeparator < 0 || passwordSeparator > userInfoEnd) {
            return url;
        }
        return url.substring(0, passwordSeparator + 1) + REDACTED_PASSWORD + url.substring(userInfoEnd);
    }
}
