package org.datavec.python;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.BooleanWritable;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.FloatWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.json.JSONObject;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.holder.ObjectMapperHolder;
import org.nd4j.shade.jackson.core.JsonProcessingException;

/* loaded from: input_file:org/datavec/python/PythonTransform.class */
/**
 * A DataVec {@link Transform} that runs a snippet of Python code on each record.
 *
 * <p>For every record, the columns named by {@link #getInputs()} are looked up in the input schema,
 * converted into Python variables, and passed to a {@link PythonJob}. The job's output variables
 * are then converted back into {@link Writable}s.
 *
 * <p>NOTE(review): {@link #map(List)} passes the shared {@code outputs} instance to the job and
 * reassigns {@code outputSchema} on every call, so instances do not look safe for concurrent use.
 */
public class PythonTransform implements Transform {
    /** Python source executed for every record; required (non-null). */
    private String code;
    /** Variables fed to the code; derived from the input schema when not supplied (or empty). */
    private PythonVariables inputs;
    /** Variables read back after execution; derived from outputDict / output schema when not supplied. */
    private PythonVariables outputs;
    /** Transform name; defaults to a random UUID. */
    private String name = UUID.randomUUID().toString();
    /** Schema used to locate each input variable's column in an incoming record. */
    private Schema inputSchema;
    /** Output schema; replaced by a schema inferred from the actual outputs on every map call. */
    private Schema outputSchema;
    /** If non-null, the name of a single dict output variable expanded via PythonUtils.expandInnerDict. */
    private String outputDict;
    /** If true, map returns every variable reported by the job instead of only {@code outputs}. */
    private boolean returnAllVariables;
    /** Forwarded to the job as its setup/run mode. */
    private boolean setupAndRun = false;
    /** The job that actually executes {@link #code}. */
    private PythonJob pythonJob;

    /**
     * Builder for {@link PythonTransform}. All properties are optional except {@code code}, which the
     * constructor requires to be non-null. {@code returnAllInputs} maps to {@code returnAllVariables}.
     */
    public static class PythonTransformBuilder {
        private String code;
        private PythonVariables inputs;
        private PythonVariables outputs;
        private String name;
        private Schema inputSchema;
        private Schema outputSchema;
        private String outputDict;
        private boolean returnAllInputs;
        private boolean setupAndRun;

        PythonTransformBuilder() {
        }

        public PythonTransformBuilder code(String code) {
            this.code = code;
            return this;
        }

        public PythonTransformBuilder inputs(PythonVariables inputs) {
            this.inputs = inputs;
            return this;
        }

        public PythonTransformBuilder outputs(PythonVariables outputs) {
            this.outputs = outputs;
            return this;
        }

        public PythonTransformBuilder name(String name) {
            this.name = name;
            return this;
        }

        public PythonTransformBuilder inputSchema(Schema inputSchema) {
            this.inputSchema = inputSchema;
            return this;
        }

        public PythonTransformBuilder outputSchema(Schema outputSchema) {
            this.outputSchema = outputSchema;
            return this;
        }

        public PythonTransformBuilder outputDict(String outputDict) {
            this.outputDict = outputDict;
            return this;
        }

        public PythonTransformBuilder returnAllInputs(boolean returnAllInputs) {
            this.returnAllInputs = returnAllInputs;
            return this;
        }

        public PythonTransformBuilder setupAndRun(boolean setupAndRun) {
            this.setupAndRun = setupAndRun;
            return this;
        }

        /** Creates the transform; see the {@link PythonTransform} constructor for how properties combine. */
        public PythonTransform build() {
            return new PythonTransform(this.code, this.inputs, this.outputs, this.name, this.inputSchema,
                    this.outputSchema, this.outputDict, this.returnAllInputs, this.setupAndRun);
        }

        @Override
        public String toString() {
            return "PythonTransform.PythonTransformBuilder(code=" + this.code + ", inputs=" + this.inputs + ", outputs=" + this.outputs + ", name=" + this.name + ", inputSchema=" + this.inputSchema + ", outputSchema=" + this.outputSchema + ", outputDict=" + this.outputDict + ", returnAllInputs=" + this.returnAllInputs + ", setupAndRun=" + this.setupAndRun + ")";
        }
    }

    /**
     * Creates a transform and its backing {@link PythonJob}.
     *
     * @param code            Python source to run; must not be null
     * @param inputs          input variables, or null/empty to derive them from {@code inputSchema}
     * @param outputs         output variables, or null/empty to derive them from {@code outputSchema}
     * @param name            transform name, or null to keep the random default
     * @param inputSchema     schema of incoming records, may be null
     * @param outputSchema    schema of produced records, may be null
     * @param outputDict      name of a dict output variable; when set it replaces {@code outputs}
     * @param returnAllInputs if true, map returns every variable reported by the job
     * @param setupAndRun     forwarded to the job's setup/run mode
     * @throws IllegalStateException if schema conversion or job creation fails
     */
    public PythonTransform(String code, PythonVariables inputs, PythonVariables outputs, String name,
                           Schema inputSchema, Schema outputSchema, String outputDict,
                           boolean returnAllInputs, boolean setupAndRun) {
        Preconditions.checkNotNull(code, "No code found to run!");
        this.code = code;
        this.returnAllVariables = returnAllInputs;
        this.setupAndRun = setupAndRun;
        if (inputs != null) {
            this.inputs = inputs;
        }
        if (outputs != null) {
            this.outputs = outputs;
        }
        if (name != null) {
            this.name = name;
        }
        if (outputDict != null) {
            // A dict output overrides any explicitly supplied output variables.
            this.outputDict = outputDict;
            this.outputs = new PythonVariables();
            this.outputs.addDict(outputDict);
        }
        // Both schema conversions share one handler (previously only the input side was guarded).
        try {
            if (inputSchema != null) {
                this.inputSchema = inputSchema;
                if (inputs == null || inputs.isEmpty()) {
                    this.inputs = PythonUtils.schemaToPythonVariables(inputSchema);
                }
            }
            if (outputSchema != null) {
                this.outputSchema = outputSchema;
                if (outputs == null || outputs.isEmpty()) {
                    // NOTE(review): this also replaces the dict variable installed for outputDict above.
                    this.outputs = PythonUtils.schemaToPythonVariables(outputSchema);
                }
            }
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
        try {
            // "a" prefix and '_' instead of '-' presumably keep the job name a valid Python identifier.
            this.pythonJob = PythonJob.builder()
                    .name("a" + UUID.randomUUID().toString().replace("-", "_"))
                    .code(code)
                    .setupRunMode(setupAndRun)
                    .build();
        } catch (Exception e) {
            throw new IllegalStateException("Error creating python job: " + e, e);
        }
    }

    /** No-arg constructor; leaves every property unset except the random name. */
    public PythonTransform() {
    }

    /**
     * Sets the input schema and re-derives the input variables from it. If neither an output schema
     * nor an output dict is configured, the input schema is also used as the output schema.
     *
     * @param schema input schema; must not be null
     * @throws RuntimeException wrapping any failure while converting the schema
     */
    public void setInputSchema(Schema schema) {
        Preconditions.checkNotNull(schema, "No input schema found!");
        this.inputSchema = schema;
        try {
            this.inputs = PythonUtils.schemaToPythonVariables(schema);
            if (this.outputSchema == null && this.outputDict == null) {
                this.outputSchema = schema;
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    public Schema getInputSchema() {
        return this.inputSchema;
    }

    /**
     * Applies {@link #map(List)} to every step of a sequence.
     *
     * @param list the sequence of records
     * @return the mapped records, in the same order
     */
    public List<List<Writable>> mapSequence(List<List<Writable>> list) {
        List<List<Writable>> mapped = new ArrayList<>(list.size());
        for (List<Writable> step : list) {
            mapped.add(map(step));
        }
        return mapped;
    }

    /** Not supported for arbitrary objects. */
    public Object map(Object obj) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Not supported for arbitrary objects. */
    public Object mapSequence(Object obj) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Runs the Python code on a single record.
     *
     * <p>Side effect: {@link #getOutputSchema()} is replaced with a schema inferred from the result.
     *
     * @param list the record, laid out according to the input schema
     * @return the converted output variables
     * @throws RuntimeException wrapping any failure raised while executing the job
     */
    public List<Writable> map(List<Writable> list) {
        PythonVariables pyInputs = getPyInputsFromWritables(list);
        Preconditions.checkNotNull(pyInputs, "Inputs must not be null!");
        try {
            if (this.returnAllVariables) {
                return getWritablesFromPyOutputs(this.pythonJob.execAndReturnAllVariables(pyInputs));
            }
            this.pythonJob.exec(pyInputs, this.outputs);
            PythonVariables results = (this.outputDict != null)
                    ? PythonUtils.expandInnerDict(this.outputs, this.outputDict)
                    : this.outputs;
            return getWritablesFromPyOutputs(results);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the names of the configured output variables. */
    public String[] outputColumnNames() {
        return this.outputs.getVariables();
    }

    /** Returns the first configured output variable name. */
    public String outputColumnName() {
        return outputColumnNames()[0];
    }

    /** Returns the configured output variable names (same as {@link #outputColumnNames()}). */
    public String[] columnNames() {
        return this.outputs.getVariables();
    }

    /** Returns the first name from {@link #columnNames()}. */
    public String columnName() {
        return columnNames()[0];
    }

    /**
     * Returns the current output schema; the argument is ignored. Because each map call replaces the
     * output schema, the value returned here may reflect the last processed record.
     */
    public Schema transform(Schema schema) {
        return this.outputSchema;
    }

    /**
     * Converts the columns of one record into Python variables, using each input variable's declared
     * type to choose the Writable cast.
     *
     * @throws IllegalStateException if inputs or the input schema have not been configured
     */
    private PythonVariables getPyInputsFromWritables(List<Writable> list) {
        Preconditions.checkState(this.inputs != null,
                "No input variables defined: set inputs or an input schema before mapping");
        Preconditions.checkState(this.inputSchema != null,
                "No input schema set: call setInputSchema(...) or pass an inputSchema before mapping");
        PythonVariables result = new PythonVariables();
        for (String varName : this.inputs.getVariables()) {
            Writable writable = list.get(this.inputSchema.getIndexOfColumn(varName));
            PythonType type = this.inputs.getType(varName);
            switch (type.getName()) {
                case INT:
                    if (writable instanceof LongWritable) {
                        result.addInt(varName, ((LongWritable) writable).get());
                    } else {
                        result.addInt(varName, ((IntWritable) writable).get());
                    }
                    break;
                case FLOAT:
                    if (writable instanceof DoubleWritable) {
                        result.addFloat(varName, ((DoubleWritable) writable).get());
                    } else {
                        result.addFloat(varName, ((FloatWritable) writable).get());
                    }
                    break;
                case STR:
                    result.addStr(varName, writable.toString());
                    break;
                case NDARRAY:
                    result.addNDArray(varName, ((NDArrayWritable) writable).get());
                    break;
                case BOOL:
                    result.addBool(varName, ((BooleanWritable) writable).get());
                    break;
                default:
                    throw new RuntimeException("Unsupported input type:" + type);
            }
        }
        return result;
    }

    /**
     * Converts Python output variables into Writables. The output schema is inferred and stored
     * first (so an unsupported type fails before any value is converted).
     */
    private List<Writable> getWritablesFromPyOutputs(PythonVariables pythonVariables) {
        String[] variables = pythonVariables.getVariables();
        this.outputSchema = inferOutputSchema(pythonVariables, variables);
        List<Writable> result = new ArrayList<>(variables.length);
        for (String varName : variables) {
            result.add(toWritable(pythonVariables, varName));
        }
        return result;
    }

    /** Builds a schema with one column per variable; dicts and lists are stored as string (JSON) columns. */
    private static Schema inferOutputSchema(PythonVariables pythonVariables, String[] variables) {
        Schema.Builder builder = new Schema.Builder();
        for (String varName : variables) {
            PythonType type = pythonVariables.getType(varName);
            switch (type.getName()) {
                case INT:
                    builder.addColumnLong(varName);
                    break;
                case FLOAT:
                    builder.addColumnDouble(varName);
                    break;
                case STR:
                case DICT:
                case LIST:
                    builder.addColumnString(varName);
                    break;
                case NDARRAY:
                    builder.addColumnNDArray(varName, pythonVariables.getNDArrayValue(varName).shape());
                    break;
                case BOOL:
                    builder.addColumnBoolean(varName);
                    break;
                default:
                    throw new IllegalStateException("Unable to support type " + type.getName());
            }
        }
        return builder.build();
    }

    /** Converts a single Python variable to the Writable matching {@link #inferOutputSchema}. */
    private static Writable toWritable(PythonVariables pythonVariables, String varName) {
        PythonType type = pythonVariables.getType(varName);
        switch (type.getName()) {
            case INT:
                return new LongWritable(pythonVariables.getIntValue(varName).longValue());
            case FLOAT:
                return new DoubleWritable(pythonVariables.getFloatValue(varName).doubleValue());
            case STR:
                return new Text(pythonVariables.getStrValue(varName));
            case NDARRAY:
                return new NDArrayWritable(pythonVariables.getNDArrayValue(varName));
            case BOOL:
                return new BooleanWritable(pythonVariables.getBooleanValue(varName));
            case DICT: {
                Map<?, ?> dictValue = pythonVariables.getDictValue(varName);
                Map<Object, Object> filtered = new HashMap<>();
                for (Map.Entry<?, ?> entry : dictValue.entrySet()) {
                    // Entries holding the JSONObject.NULL sentinel are omitted from the serialized JSON.
                    if (entry.getValue() != JSONObject.NULL) {
                        filtered.put(entry.getKey(), entry.getValue());
                    }
                }
                try {
                    return new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(filtered));
                } catch (JsonProcessingException e) {
                    throw new IllegalStateException("Unable to serialize dictionary " + varName + " to json!", e);
                }
            }
            case LIST:
                try {
                    return new Text(ObjectMapperHolder.getJsonMapper()
                            .writeValueAsString(pythonVariables.getListValue(varName).toArray()));
                } catch (JsonProcessingException e) {
                    throw new IllegalStateException("Unable to serialize list value " + varName + " to json!", e);
                }
            default:
                throw new IllegalStateException("Unable to support type " + type.getName());
        }
    }

    public static PythonTransformBuilder builder() {
        return new PythonTransformBuilder();
    }

    public String getCode() {
        return this.code;
    }

    public PythonVariables getInputs() {
        return this.inputs;
    }

    public PythonVariables getOutputs() {
        return this.outputs;
    }

    public String getName() {
        return this.name;
    }

    public Schema getOutputSchema() {
        return this.outputSchema;
    }

    public String getOutputDict() {
        return this.outputDict;
    }

    public boolean isReturnAllVariables() {
        return this.returnAllVariables;
    }

    public boolean isSetupAndRun() {
        return this.setupAndRun;
    }

    public PythonJob getPythonJob() {
        return this.pythonJob;
    }

    public void setCode(String code) {
        this.code = code;
    }

    public void setInputs(PythonVariables inputs) {
        this.inputs = inputs;
    }

    public void setOutputs(PythonVariables outputs) {
        this.outputs = outputs;
    }

    public void setName(String name) {
        this.name = name;
    }

    public void setOutputSchema(Schema outputSchema) {
        this.outputSchema = outputSchema;
    }

    public void setOutputDict(String outputDict) {
        this.outputDict = outputDict;
    }

    public void setReturnAllVariables(boolean returnAllVariables) {
        this.returnAllVariables = returnAllVariables;
    }

    public void setSetupAndRun(boolean setupAndRun) {
        this.setupAndRun = setupAndRun;
    }

    public void setPythonJob(PythonJob pythonJob) {
        this.pythonJob = pythonJob;
    }

    /** Field-by-field equality over every property, including the backing job. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof PythonTransform)) {
            return false;
        }
        PythonTransform other = (PythonTransform) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(getCode(), other.getCode())
                && Objects.equals(getInputs(), other.getInputs())
                && Objects.equals(getOutputs(), other.getOutputs())
                && Objects.equals(getName(), other.getName())
                && Objects.equals(getInputSchema(), other.getInputSchema())
                && Objects.equals(getOutputSchema(), other.getOutputSchema())
                && Objects.equals(getOutputDict(), other.getOutputDict())
                && isReturnAllVariables() == other.isReturnAllVariables()
                && isSetupAndRun() == other.isSetupAndRun()
                && Objects.equals(getPythonJob(), other.getPythonJob());
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof PythonTransform;
    }

    /** Hash consistent with {@link #equals(Object)}; formula kept identical to the original. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + hashOrDefault(getCode());
        result = result * prime + hashOrDefault(getInputs());
        result = result * prime + hashOrDefault(getOutputs());
        result = result * prime + hashOrDefault(getName());
        result = result * prime + hashOrDefault(getInputSchema());
        result = result * prime + hashOrDefault(getOutputSchema());
        result = result * prime + hashOrDefault(getOutputDict());
        result = result * prime + (isReturnAllVariables() ? 79 : 97);
        result = result * prime + (isSetupAndRun() ? 79 : 97);
        result = result * prime + hashOrDefault(getPythonJob());
        return result;
    }

    /** Returns the value's hash code, or 43 for null. */
    private static int hashOrDefault(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    @Override
    public String toString() {
        return "PythonTransform(code=" + getCode() + ", inputs=" + getInputs() + ", outputs=" + getOutputs() + ", name=" + getName() + ", inputSchema=" + getInputSchema() + ", outputSchema=" + getOutputSchema() + ", outputDict=" + getOutputDict() + ", returnAllVariables=" + isReturnAllVariables() + ", setupAndRun=" + isSetupAndRun() + ", pythonJob=" + getPythonJob() + ")";
    }
}
