package com.actelion.research.calc.regression.svm;

import com.actelion.research.calc.CorrelationCalculator;
import com.actelion.research.calc.Matrix;
import com.actelion.research.calc.ProgressController;
import com.actelion.research.calc.regression.ARegressionMethod;
import com.actelion.research.calc.regression.ParameterRegressionMethod;
import com.actelion.research.util.Formatter;
import com.actelion.research.util.datamodel.DoubleArray;
import com.actelion.research.util.datamodel.ModelXYIndex;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.machinelearning.svm.libsvm.svm_model;
import org.machinelearning.svm.libsvm.svm_problem;

/* loaded from: input_file:com/actelion/research/calc/regression/svm/SVMRegression.class */
/**
 * Support vector regression for a single response column.
 *
 * <p>Training and prediction are delegated to the {@code svm} class ({@code svm_check_parameter},
 * {@code svm_train}, {@code svm_predict}) using the libsvm data holders {@link svm_problem} and
 * {@link svm_model}. Hyper-parameters are held in a {@link ParameterSVM}.
 *
 * <p>Thread-safety: only {@link #calculateYHat(double[])} synchronizes (on this instance);
 * training and matrix prediction are not synchronized.
 */
public class SVMRegression extends ARegressionMethod<ParameterSVM> implements Comparable<SVMRegression> {

    /**
     * Maximum number of rows handed to {@code AnalyticalParameterCalculatorSVM}; larger training
     * sets are randomly subsampled to this size.
     */
    public static final int LIMIT_ROWS_ANALYTICAL = 1000;

    /** When {@code true}, training prints the training-set r2 and failure notices to the console. */
    public static boolean VERBOSE = false;

    /** Nu value 0.1. Not referenced within this class — NOTE(review): presumably used by callers; verify. */
    public static final double NU = 0.1d;

    /** Model from the last successful {@code svm_train} call in {@link #createModel}; {@code null} before that. */
    private svm_model modelSVM;

    /** Creates a regression configured with {@code SVMParameterHelper.regressionEpsilonSVR()} parameters. */
    public SVMRegression() {
        setParameterRegressionMethod(new ParameterSVM(SVMParameterHelper.regressionEpsilonSVR()));
    }

    /**
     * Creates a regression with the given hyper-parameters.
     *
     * @param parameterSVM the SVM parameters to use
     */
    public SVMRegression(ParameterSVM parameterSVM) {
        setParameterRegressionMethod(parameterSVM);
    }

    /**
     * Trains the SVM on {@code modelXYIndex} and returns the model's predictions for the training rows.
     *
     * <p>If the kernel {@code degree} is {@code -1}, the parameters are first replaced by values from
     * {@code AnalyticalParameterCalculatorSVM}, computed on at most {@link #LIMIT_ROWS_ANALYTICAL}
     * randomly chosen rows. A non-positive gamma is then replaced by {@code 1 / (number of X columns)}.
     * Both adjustments modify this instance's parameters.
     *
     * @param modelXYIndex training data; {@code Y} must have exactly one column
     * @return an {@code n x 1} matrix of fitted values for the training rows, or {@code null} if the
     *     parameter check failed, the progress controller requested cancellation, or training /
     *     prediction threw an exception (its stack trace is printed)
     * @throws IllegalArgumentException if {@code Y} does not have exactly one column
     */
    @Override // com.actelion.research.calc.regression.ICalculateModel
    public Matrix createModel(ModelXYIndex modelXYIndex) {
        // Validate first, so that invalid input does not trigger (or leave behind) parameter tuning.
        int yCols = modelXYIndex.Y.cols();
        if (yCols != 1) {
            throw new IllegalArgumentException("Only one y column allowed, but Y has " + yCols + " columns!");
        }
        if (getParameter().getSvmParameter().degree == -1) {
            calculateAnalyticalParameters(modelXYIndex);
        }

        int rows = modelXYIndex.X.rows();
        svm_problem problem = new svm_problem();
        problem.l = rows;
        problem.x = Matrix2SVMNodeConverter.convert(modelXYIndex.X);
        if (getParameter().getGamma() <= 0.0d) {
            getParameter().setGamma(1.0d / modelXYIndex.X.cols());
        }

        try {
            double[] y = modelXYIndex.Y.getColAsDouble(0);
            problem.y = y;
            String parameterError = svm.svm_check_parameter(problem, getParameter().getSvmParameter());
            if (parameterError != null) {
                // The fit would be discarded anyway, so do not spend time training with rejected parameters.
                System.err.println("SVMRegression svm_check_parameter error: " + parameterError);
                return null;
            }

            ProgressController progressController = getProgressController();
            this.modelSVM = svm.svm_train(problem, getParameter().getSvmParameter(), progressController);
            if (progressController != null) {
                progressController.startProgress("Calculate train data fit", 0, rows);
            }

            Matrix fitted = predictTrainingRows(problem, progressController);
            if (fitted == null) {
                return null; // cancelled by the progress controller
            }
            if (VERBOSE) {
                printTrainingR2(y, fitted);
            }
            return fitted;
        } catch (Exception e) {
            // Best-effort contract: a failed fit is reported as a null result rather than propagated.
            e.printStackTrace();
            if (VERBOSE) {
                System.err.println("SVMRegression break.");
            }
            return null;
        }
    }

    /**
     * Replaces the parameters with analytically derived ones, computed on a random subset of at most
     * {@link #LIMIT_ROWS_ANALYTICAL} rows to bound the cost on large training sets.
     */
    private void calculateAnalyticalParameters(ModelXYIndex modelXYIndex) {
        int rows = modelXYIndex.X.rows();
        ModelXYIndex subset = modelXYIndex;
        if (rows > LIMIT_ROWS_ANALYTICAL) {
            List<Integer> indices = new ArrayList<>(rows);
            for (int i = 0; i < rows; i++) {
                indices.add(i);
            }
            Collections.shuffle(indices);
            subset = modelXYIndex.sub(new ArrayList<Integer>(indices.subList(0, LIMIT_ROWS_ANALYTICAL)));
        }
        setParameterRegressionMethod(AnalyticalParameterCalculatorSVM.calculate(subset));
    }

    /**
     * Predicts every training row with the freshly trained model.
     *
     * @return an {@code l x 1} matrix of predictions, or {@code null} if the progress controller
     *     requested cancellation before all rows were predicted
     */
    private Matrix predictTrainingRows(svm_problem problem, ProgressController progressController) {
        Matrix fitted = new Matrix(problem.l, 1);
        for (int row = 0; row < problem.l; row++) {
            fitted.set(row, 0, svm.svm_predict(this.modelSVM, problem.x[row]));
            if (progressController != null) {
                progressController.updateProgress(row);
                if (progressController.threadMustDie()) {
                    return null;
                }
            }
        }
        return fitted;
    }

    /**
     * Prints the square of the correlation (type {@code 0} of {@code CorrelationCalculator}) between
     * observed and fitted training values, labelled "r2".
     */
    private static void printTrainingR2(double[] observed, Matrix fitted) {
        double r = new CorrelationCalculator().calculateCorrelation(
                new DoubleArray(observed), new DoubleArray(fitted.getColAsDouble(0)), 0);
        System.out.println("SVMRegression model r2 " + Formatter.format4(Double.valueOf(r * r)) + ".");
    }

    /**
     * Predicts one value per row of {@code matrix} with the trained model.
     *
     * @param matrix descriptor rows, converted via {@code Matrix2SVMNodeConverter.convert}
     * @return an {@code n x 1} matrix of predictions, one per input row
     */
    @Override // com.actelion.research.calc.regression.ICalculateYHat
    public Matrix calculateYHat(Matrix matrix) {
        int rows = matrix.rows();
        svm_problem problem = new svm_problem();
        problem.l = rows;
        problem.x = Matrix2SVMNodeConverter.convert(matrix);
        Matrix yHat = new Matrix(rows, 1);
        for (int row = 0; row < rows; row++) {
            yHat.set(row, 0, svm.svm_predict(this.modelSVM, problem.x[row]));
        }
        return yHat;
    }

    /**
     * Predicts a single value for one descriptor row. Synchronized on this instance.
     *
     * @param dArr the descriptor values, converted via {@code Matrix2SVMNodeConverter.convertSingleRow}
     * @return the predicted value
     */
    @Override // com.actelion.research.calc.regression.ICalculateYHat
    public synchronized double calculateYHat(double[] dArr) {
        svm_problem problem = new svm_problem();
        problem.l = 1;
        problem.x = Matrix2SVMNodeConverter.convertSingleRow(dArr);
        return svm.svm_predict(this.modelSVM, problem.x[0]);
    }

    /**
     * Sets the nu parameter on the underlying {@link ParameterSVM}.
     *
     * @param d the new nu value
     */
    public void setNu(double d) {
        getParameter().setNu(d);
    }

    /**
     * Sets the kernel gamma on the underlying {@link ParameterSVM}. A value {@code <= 0} is replaced
     * by {@code 1 / (number of X columns)} at the next {@link #createModel} call.
     *
     * @param d the new gamma value
     */
    public void setGamma(double d) {
        getParameter().setGamma(d);
    }

    /**
     * Orders regressions by their parameters only (delegates to the parameters' {@code compareTo}).
     * NOTE(review): this ordering is presumably not consistent with {@code equals}; confirm before
     * using instances in sorted sets/maps.
     */
    @Override // java.lang.Comparable
    public int compareTo(SVMRegression sVMRegression) {
        return getParameter().compareTo((ParameterRegressionMethod) sVMRegression.getParameter());
    }
}
