package org.datavec.spark.transform;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter;
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.spark.transform.model.Base64NDArrayBody;
import org.datavec.spark.transform.model.BatchCSVRecord;
import org.datavec.spark.transform.model.SequenceBatchCSVRecord;
import org.datavec.spark.transform.model.SingleCSVRecord;
import org.nd4j.serde.base64.Nd4jBase64;

/* loaded from: input_file:org/datavec/spark/transform/CSVSparkTransform.class */
/**
 * Applies a DataVec {@link TransformProcess} to CSV-style string records.
 *
 * <p>Input strings are converted to Arrow-backed writables with {@link ArrowConverter} using the
 * transform's initial schema, then executed locally with {@link LocalTransformExecutor}. Results are
 * returned either as string records or as a Base64-encoded NDArray ({@link Base64NDArrayBody}).
 */
public class CSVSparkTransform {
    // NOTE(review): one process-wide allocator with an effectively unbounded limit. The Arrow
    // columns allocated from it in this class are never explicitly closed here — confirm they are
    // released elsewhere, otherwise off-heap memory may accumulate across requests.
    private static final BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);

    private final TransformProcess transformProcess;

    /**
     * Creates a wrapper that applies {@code transformProcess} to incoming CSV records.
     *
     * @param transformProcess the process to apply; its initial schema is used to parse the input
     *     strings
     */
    public CSVSparkTransform(TransformProcess transformProcess) {
        this.transformProcess = transformProcess;
    }

    /**
     * Returns the transform process applied by this instance.
     *
     * @return the transform process passed to the constructor
     */
    public TransformProcess getTransformProcess() {
        return this.transformProcess;
    }

    /**
     * Transforms every row of {@code batchCSVRecord} and encodes the result as a Base64 NDArray.
     *
     * @param batchCSVRecord the rows to transform
     * @return the transformed batch, converted with {@link ArrowConverter#toArray} and Base64-encoded
     * @throws IOException if Base64 encoding of the array fails
     */
    public Base64NDArrayBody toArray(BatchCSVRecord batchCSVRecord) throws IOException {
        return new Base64NDArrayBody(
                Nd4jBase64.base64String(ArrowConverter.toArray(executeBatch(batchCSVRecord))));
    }

    /**
     * Transforms a single row and encodes the result as a Base64 NDArray.
     *
     * @param singleCSVRecord the row to transform
     * @return the transformed row, converted with {@link RecordConverter#toArray} and Base64-encoded
     * @throws IOException if Base64 encoding of the array fails
     */
    public Base64NDArrayBody toArray(SingleCSVRecord singleCSVRecord) throws IOException {
        return new Base64NDArrayBody(
                Nd4jBase64.base64String(RecordConverter.toArray(executeSingle(singleCSVRecord))));
    }

    /**
     * Transforms every row of {@code batchCSVRecord} and returns the results as strings.
     *
     * @param batchCSVRecord the rows to transform
     * @return one {@link SingleCSVRecord} per output row, each value rendered with
     *     {@code Writable.toString()}; empty if the transform produced no rows
     */
    public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
        BatchCSVRecord result = new BatchCSVRecord();
        // Each row is sized by its own width rather than the first row's, and an empty result no
        // longer triggers an IndexOutOfBoundsException.
        for (List<Writable> row : executeBatch(batchCSVRecord)) {
            result.add(toCsvRecord(row));
        }
        return result;
    }

    /**
     * Transforms a single row and returns the result as strings.
     *
     * @param singleCSVRecord the row to transform
     * @return the transformed row, each value rendered with {@code Writable.toString()}
     */
    public SingleCSVRecord transform(SingleCSVRecord singleCSVRecord) {
        return toCsvRecord(executeSingle(singleCSVRecord));
    }

    /**
     * Runs the transform on a batch of rows with {@link LocalTransformExecutor#executeToSequence}.
     *
     * <p>For the Arrow conversion the batch's rows are wrapped as a single time series. Each
     * sequence produced by the executor becomes one {@link BatchCSVRecord} entry in the result.
     *
     * @param batchCSVRecord the rows to transform
     * @return the transformed sequences
     */
    public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord batchCSVRecord) {
        List<List<List<Writable>>> sequences =
                LocalTransformExecutor.executeToSequence(
                        ArrowConverter.toArrowWritables(
                                ArrowConverter.toArrowColumnsStringTimeSeries(
                                        bufferAllocator,
                                        transformProcess.getInitialSchema(),
                                        Arrays.asList(batchCSVRecord.getRecordsAsString())),
                                transformProcess.getInitialSchema()),
                        transformProcess);
        SequenceBatchCSVRecord result = new SequenceBatchCSVRecord();
        for (List<List<Writable>> sequence : sequences) {
            result.add(Arrays.asList(BatchCSVRecord.fromWritables(sequence)));
        }
        return result;
    }

    /**
     * Applies the transform sequence-to-sequence to every sequence in the batch.
     *
     * @param sequenceBatchCSVRecord the sequences to transform
     * @return the transformed sequences; an empty record if the input batch is empty
     */
    public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord sequenceBatchCSVRecord) {
        return transformSequenceBatch(sequenceBatchCSVRecord);
    }

    /**
     * Applies the transform sequence-to-sequence and encodes the result as a Base64 NDArray.
     *
     * @param sequenceBatchCSVRecord the sequences to transform
     * @return the transformed tensor, reshaped and Base64-encoded
     * @throws IllegalArgumentException if the batch is empty or its first sequence has no steps,
     *     since the output shape is derived from them
     * @throws IllegalStateException if Base64 encoding of the array fails
     */
    public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord sequenceBatchCSVRecord) {
        List<List<List<String>>> sequences = sequenceBatchCSVRecord.getRecordsAsString();
        if (sequences.isEmpty() || sequences.get(0).isEmpty()) {
            throw new IllegalArgumentException(
                    "Cannot convert an empty sequence batch to an array: "
                            + "at least one non-empty sequence is required");
        }
        // Shape is {number of sequences, fields in the first step, steps in the first sequence}.
        // NOTE(review): these come from the INPUT strings, so a transform that changes the column
        // count, or a ragged batch, would make reshape fail — confirm whether the final schema's
        // column count should be used instead.
        int[] shape = {sequences.size(), sequences.get(0).get(0).size(), sequences.get(0).size()};
        try {
            return new Base64NDArrayBody(
                    Nd4jBase64.base64String(
                            RecordConverter.toTensor(executeSequences(sequences)).reshape(shape)));
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Runs the transform on a batch of rows with {@link LocalTransformExecutor#executeToSequence}
     * and encodes the resulting sequences as a Base64 NDArray.
     *
     * @param batchCSVRecord the rows to transform
     * @return the transformed tensor, Base64-encoded
     * @throws IllegalStateException if Base64 encoding of the array fails (previously this printed
     *     the stack trace and returned {@code null}, unlike {@link #transformSequenceArray})
     */
    public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord batchCSVRecord) {
        List<List<List<Writable>>> sequences =
                LocalTransformExecutor.executeToSequence(toArrowWritables(batchCSVRecord), transformProcess);
        try {
            return new Base64NDArrayBody(Nd4jBase64.base64String(RecordConverter.toTensor(sequences)));
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Applies the transform sequence-to-sequence to every sequence in the batch.
     *
     * <p>Behaves exactly like {@link #transformSequence(SequenceBatchCSVRecord)}.
     *
     * @param sequenceBatchCSVRecord the sequences to transform
     * @return the transformed sequences; an empty record if the input batch is empty
     */
    public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord sequenceBatchCSVRecord) {
        return transformSequenceBatch(sequenceBatchCSVRecord);
    }

    /** Shared implementation of the two sequence-batch transform entry points. */
    private SequenceBatchCSVRecord transformSequenceBatch(SequenceBatchCSVRecord sequenceBatchCSVRecord) {
        List<List<List<String>>> sequences = sequenceBatchCSVRecord.getRecordsAsString();
        if (sequences.isEmpty()) {
            // Nothing to transform; previously this failed with IndexOutOfBoundsException.
            return new SequenceBatchCSVRecord();
        }
        return SequenceBatchCSVRecord.fromWritables(executeSequences(sequences));
    }

    /**
     * Runs the transform sequence-to-sequence over string sequences.
     *
     * <p>Batches whose sequences all have the same (non-zero) number of steps take the
     * Arrow-columnar path. Ragged batches fall back to
     * {@link LocalTransformExecutor#convertStringInputTimeSeries}.
     */
    private List<List<List<Writable>>> executeSequences(List<List<List<String>>> sequences) {
        if (hasUniformNonEmptyLength(sequences)) {
            // Passed as the batch's time-series stride: the field count of the first step of the
            // first sequence (same value the original implementation used).
            int timeSeriesStride = sequences.get(0).get(0).size();
            return LocalTransformExecutor.executeSequenceToSequence(
                    new ArrowWritableRecordTimeSeriesBatch(
                            ArrowConverter.toArrowColumnsStringTimeSeries(
                                    bufferAllocator, transformProcess.getInitialSchema(), sequences),
                            transformProcess.getInitialSchema(),
                            timeSeriesStride),
                    transformProcess);
        }
        return LocalTransformExecutor.executeSequenceToSequence(
                LocalTransformExecutor.convertStringInputTimeSeries(
                        sequences, transformProcess.getInitialSchema()),
                transformProcess);
    }

    /**
     * Returns true when the batch is non-empty, its first sequence is non-empty, and every sequence
     * has the same number of steps.
     */
    private static boolean hasUniformNonEmptyLength(List<List<List<String>>> sequences) {
        if (sequences.isEmpty() || sequences.get(0).isEmpty()) {
            return false;
        }
        int expectedLength = sequences.get(0).size();
        for (List<List<String>> sequence : sequences) {
            if (sequence.size() != expectedLength) {
                return false;
            }
        }
        return true;
    }

    /** Parses a batch of string rows into Arrow-backed writables using the initial schema. */
    private List<List<Writable>> toArrowWritables(BatchCSVRecord batchCSVRecord) {
        return ArrowConverter.toArrowWritables(
                ArrowConverter.toArrowColumnsString(
                        bufferAllocator,
                        transformProcess.getInitialSchema(),
                        batchCSVRecord.getRecordsAsString()),
                transformProcess.getInitialSchema());
    }

    /** Parses and transforms a batch of rows with {@link LocalTransformExecutor#execute}. */
    private List<List<Writable>> executeBatch(BatchCSVRecord batchCSVRecord) {
        return LocalTransformExecutor.execute(toArrowWritables(batchCSVRecord), transformProcess);
    }

    /** Parses and transforms a single row, returning the first (only) output row. */
    private List<Writable> executeSingle(SingleCSVRecord singleCSVRecord) {
        return LocalTransformExecutor.execute(
                        Arrays.asList(
                                ArrowConverter.toArrowWritablesSingle(
                                        ArrowConverter.toArrowColumnsStringSingle(
                                                bufferAllocator,
                                                transformProcess.getInitialSchema(),
                                                singleCSVRecord.getValues()),
                                        transformProcess.getInitialSchema())),
                        transformProcess)
                .get(0);
    }

    /** Renders each writable of {@code row} with {@code toString()} into a {@link SingleCSVRecord}. */
    private static SingleCSVRecord toCsvRecord(List<Writable> row) {
        String[] values = new String[row.size()];
        for (int i = 0; i < values.length; i++) {
            values[i] = row.get(i).toString();
        }
        return new SingleCSVRecord(values);
    }
}
