package io.github.tfahub.libsvm;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Reader;
import java.io.Writer;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:io/github/tfahub/libsvm/SvmTrain.class */
/**
 * Fluent helper that collects training parameters and a training problem and then
 * trains a model or runs n-fold cross validation through {@link Svm}.
 *
 * <p>All text based input and output handled by this class uses
 * {@link StandardCharsets#ISO_8859_1}.
 *
 * <p>Instances hold mutable state (the current parameter object and problem) and
 * perform no synchronization of their own.
 */
public class SvmTrain {
    private static final Logger logger = Logger.getLogger(SvmTrain.class.getName());

    /** Training parameters; replaced wholesale by {@link #parameter(SvmParameter)}. */
    private SvmParameter parameter = new SvmParameter();

    /** Training data; stays {@code null} until one of the {@code problem(...)} methods is called. */
    private SvmProblem problem;

    /**
     * Sets the {@code svmType} field of the current parameters.
     *
     * @param svmType the SVM type to train
     * @return this instance, for chaining
     */
    public SvmTrain svmType(SvmType svmType) {
        this.parameter.svmType = svmType;
        return this;
    }

    /**
     * Sets the {@code kernelType} field of the current parameters.
     *
     * @param kernelType the kernel type to use
     * @return this instance, for chaining
     */
    public SvmTrain kernelType(KernelType kernelType) {
        this.parameter.kernelType = kernelType;
        return this;
    }

    /**
     * Sets the {@code degree} field of the current parameters.
     *
     * @param degree the degree value
     * @return this instance, for chaining
     */
    public SvmTrain degree(int degree) {
        this.parameter.degree = degree;
        return this;
    }

    /**
     * Sets the {@code gamma} field of the current parameters.
     *
     * <p>If gamma is left at {@code 0} and the problem's largest feature index is
     * positive, training uses {@code 1 / maxIndex} instead (applied to a copy of the
     * parameters, not to the stored object).
     *
     * @param gamma the gamma value
     * @return this instance, for chaining
     */
    public SvmTrain gamma(double gamma) {
        this.parameter.gamma = gamma;
        return this;
    }

    /**
     * Sets the {@code coef0} field of the current parameters.
     *
     * @param coef0 the coef0 value
     * @return this instance, for chaining
     */
    public SvmTrain coef0(double coef0) {
        this.parameter.coef0 = coef0;
        return this;
    }

    /**
     * Sets the {@code cacheSize} field of the current parameters.
     *
     * @param cacheSize the cache size value (unit is defined by {@link SvmParameter})
     * @return this instance, for chaining
     */
    public SvmTrain cacheSize(long cacheSize) {
        this.parameter.cacheSize = cacheSize;
        return this;
    }

    /**
     * Sets the {@code eps} field of the current parameters.
     *
     * @param eps the eps value
     * @return this instance, for chaining
     */
    public SvmTrain eps(double eps) {
        this.parameter.eps = eps;
        return this;
    }

    /**
     * Sets the {@code c} field of the current parameters.
     *
     * @param cost the cost value
     * @return this instance, for chaining
     */
    public SvmTrain cost(double cost) {
        this.parameter.c = cost;
        return this;
    }

    /**
     * Sets the {@code nrWeight}, {@code weightLabel} and {@code weight} fields of the
     * current parameters. The arrays are stored by reference, not copied.
     *
     * @param nrWeight    number of weight entries
     * @param weightLabel labels the weights apply to
     * @param weight      weight values
     * @return this instance, for chaining
     */
    public SvmTrain weight(int nrWeight, int[] weightLabel, double[] weight) {
        this.parameter.nrWeight = nrWeight;
        this.parameter.weightLabel = weightLabel;
        this.parameter.weight = weight;
        return this;
    }

    /**
     * Sets the {@code nu} field of the current parameters.
     *
     * @param nu the nu value
     * @return this instance, for chaining
     */
    public SvmTrain nu(double nu) {
        this.parameter.nu = nu;
        return this;
    }

    /**
     * Sets the {@code p} field of the current parameters.
     *
     * @param epsilon the value stored in {@code p}
     * @return this instance, for chaining
     */
    public SvmTrain epsilon(double epsilon) {
        this.parameter.p = epsilon;
        return this;
    }

    /**
     * Sets the {@code shrinking} field of the current parameters.
     *
     * @param shrinking the shrinking flag
     * @return this instance, for chaining
     */
    public SvmTrain shrinking(boolean shrinking) {
        this.parameter.shrinking = shrinking;
        return this;
    }

    /**
     * Sets the {@code probability} field of the current parameters.
     *
     * @param probability the probability flag
     * @return this instance, for chaining
     */
    public SvmTrain probability(boolean probability) {
        this.parameter.probability = probability;
        return this;
    }

    /**
     * Replaces the whole parameter object. Values set earlier through the individual
     * setters are discarded; later setter calls mutate the supplied object.
     *
     * @param svmParameter the parameters to use; training fails if this is {@code null}
     * @return this instance, for chaining
     */
    public SvmTrain parameter(SvmParameter svmParameter) {
        this.parameter = svmParameter;
        return this;
    }

    /**
     * Reads the training problem from a file.
     *
     * @param file the file to read
     * @return this instance, for chaining
     * @throws IOException if the file cannot be opened or read
     */
    public SvmTrain problem(File file) throws IOException {
        return problem(file.toPath());
    }

    /**
     * Reads the training problem from a path, decoding it as ISO-8859-1. The reader
     * opened here is always closed before this method returns.
     *
     * @param path the path to read
     * @return this instance, for chaining
     * @throws IOException if the file cannot be opened or read; a failure while closing
     *                     is attached as a suppressed exception when reading already failed
     */
    public SvmTrain problem(Path path) throws IOException {
        try (BufferedReader reader = Files.newBufferedReader(path, StandardCharsets.ISO_8859_1)) {
            problem(reader);
        }
        return this;
    }

    /**
     * Reads the training problem from a URL. The stream opened here is always closed
     * before this method returns.
     *
     * @param url the location to read
     * @return this instance, for chaining
     * @throws IOException if the stream cannot be opened or read; a failure while closing
     *                     is attached as a suppressed exception when reading already failed
     */
    public SvmTrain problem(URL url) throws IOException {
        try (InputStream in = url.openStream()) {
            problem(in);
        }
        return this;
    }

    /**
     * Reads the training problem from a stream, decoding it as ISO-8859-1. This method
     * does not close the stream itself.
     *
     * @param inputStream the stream to read
     * @return this instance, for chaining
     * @throws IOException if reading fails
     */
    public SvmTrain problem(InputStream inputStream) throws IOException {
        return problem(new InputStreamReader(inputStream, StandardCharsets.ISO_8859_1));
    }

    /**
     * Reads the training problem from a reader via {@code Utils.readProblem}. This
     * method does not close the reader itself.
     *
     * @param reader the reader to parse
     * @return this instance, for chaining
     * @throws IOException if reading fails
     */
    public SvmTrain problem(Reader reader) throws IOException {
        return problem(Utils.readProblem(reader));
    }

    /**
     * Uses an already constructed problem. The object is stored by reference.
     *
     * @param svmProblem the training problem; training fails if this is {@code null}
     * @return this instance, for chaining
     */
    public SvmTrain problem(SvmProblem svmProblem) {
        this.problem = svmProblem;
        return this;
    }

    /**
     * Validates the current parameters and problem, then trains a model.
     *
     * @return the trained model
     */
    public SvmModel train() {
        return Svm.svmTrain(this.problem, prepareTrain(this.parameter, this.problem));
    }

    /**
     * Trains a model and saves it to a file.
     *
     * @param file the file to write the model to
     * @return the trained model
     * @throws IOException if the file cannot be opened or written
     */
    public SvmModel train(File file) throws IOException {
        return train(file.toPath());
    }

    /**
     * Trains a model and saves it to a path, encoded as ISO-8859-1 and using the default
     * open options of {@link Files#newBufferedWriter}. The writer opened here is always
     * closed before this method returns.
     *
     * @param path the path to write the model to
     * @return the trained model
     * @throws IOException if the file cannot be opened or written; a failure while closing
     *                     is attached as a suppressed exception when writing already failed
     */
    public SvmModel train(Path path) throws IOException {
        try (BufferedWriter writer = Files.newBufferedWriter(path, StandardCharsets.ISO_8859_1)) {
            return train(writer);
        }
    }

    /**
     * Trains a model and saves it to a stream, encoded as ISO-8859-1. This method does
     * not close the stream itself.
     *
     * @param outputStream the stream to write the model to
     * @return the trained model
     * @throws IOException if writing fails
     */
    public SvmModel train(OutputStream outputStream) throws IOException {
        return train(new OutputStreamWriter(outputStream, StandardCharsets.ISO_8859_1));
    }

    /**
     * Trains a model and saves it through the given writer. A writer that is not already
     * a {@link BufferedWriter} is wrapped in one.
     *
     * @param writer the writer to save the model to
     * @return the trained model
     * @throws IOException if writing fails
     */
    public SvmModel train(Writer writer) throws IOException {
        SvmModel model = train();
        BufferedWriter buffered =
                writer instanceof BufferedWriter ? (BufferedWriter) writer : new BufferedWriter(writer);
        // NOTE(review): a wrapper created here is never flushed by this method, so the output
        // only reaches `writer` if Svm.svmSaveModel flushes or closes it - confirm against Svm.
        Svm.svmSaveModel(buffered, model);
        return model;
    }

    /**
     * Runs n-fold cross validation on the current problem and logs the result at INFO
     * level: accuracy for classification, mean squared error and squared correlation
     * coefficient for {@code EPSILON_SVR} / {@code NU_SVR}.
     *
     * @param nrFold number of folds; must be at least 2
     * @throws IllegalParameterException if {@code nrFold < 2}
     */
    public void crossValidation(int nrFold) {
        // Parameter/problem validation deliberately runs before the fold check so that
        // the failure reported for a doubly invalid call stays the same as before.
        SvmParameter prepared = prepareTrain(this.parameter, this.problem);
        if (nrFold < 2) {
            throw new IllegalParameterException("n-fold cross validation: n must >= 2; nrFold=" + nrFold);
        }
        double[] predicted = new double[this.problem.l];
        Svm.svmCrossValidation(this.problem, prepared, nrFold, predicted);
        if (prepared.svmType == SvmType.EPSILON_SVR || prepared.svmType == SvmType.NU_SVR) {
            logRegressionMetrics(predicted);
        } else {
            logClassificationAccuracy(predicted);
        }
    }

    /** Logs the percentage of samples whose predicted value equals the expected label. */
    private void logClassificationAccuracy(double[] predicted) {
        if (!logger.isLoggable(Level.INFO)) {
            return;
        }
        int correct = 0;
        for (int i = 0; i < this.problem.l; i++) {
            // Exact comparison is intended: both sides are expected to carry label values.
            if (predicted[i] == this.problem.y[i]) {
                correct++;
            }
        }
        logger.info("Cross Validation Accuracy = " + ((100.0d * correct) / this.problem.l) + "%");
    }

    /** Logs the mean squared error and squared correlation coefficient of the predictions. */
    private void logRegressionMetrics(double[] predicted) {
        if (!logger.isLoggable(Level.INFO)) {
            return;
        }
        double totalError = 0.0d;
        double sumV = 0.0d;
        double sumY = 0.0d;
        double sumVV = 0.0d;
        double sumYY = 0.0d;
        double sumVY = 0.0d;
        for (int i = 0; i < this.problem.l; i++) {
            double y = this.problem.y[i];
            double v = predicted[i];
            totalError += (v - y) * (v - y);
            sumV += v;
            sumY += y;
            sumVV += v * v;
            sumYY += y * y;
            sumVY += v * y;
        }
        logger.info("Cross Validation Mean squared error = " + (totalError / this.problem.l));
        logger.info("Cross Validation Squared correlation coefficient = " + ((((this.problem.l * sumVY) - (sumV * sumY)) * ((this.problem.l * sumVY) - (sumV * sumY))) / (((this.problem.l * sumVV) - (sumV * sumV)) * ((this.problem.l * sumYY) - (sumY * sumY)))));
    }

    /**
     * Validates the parameter/problem pair and returns the parameters to train with.
     *
     * <p>For a {@code PRECOMPUTED} kernel every row must start with a node of index 0
     * whose value, truncated to an int, lies in {@code [1, maxIndex]}. When gamma is 0 and
     * the largest feature index is positive, a copy of the parameters with
     * {@code gamma = 1 / maxIndex} is returned so the caller's object is left untouched.
     *
     * @param svmParameter the configured parameters; must not be {@code null}
     * @param svmProblem   the training problem; must not be {@code null}
     * @return {@code svmParameter} itself, or a copy with the default gamma applied
     * @throws IllegalParameterException if the precomputed kernel matrix is malformed
     */
    private static SvmParameter prepareTrain(SvmParameter svmParameter, SvmProblem svmProblem) {
        Utils.requireNonNull(svmParameter, "parameter should not be null");
        Utils.requireNonNull(svmProblem, "problem should not be null");
        int maxIndex = getMaxIndex(svmProblem);
        if (svmParameter.kernelType == KernelType.PRECOMPUTED) {
            for (int i = 0; i < svmProblem.l; i++) {
                SvmNode first = svmProblem.x[i][0];
                if (first.index != 0) {
                    throw new IllegalParameterException("Wrong kernel matrix: first column must be 0:sample_serial_number; column[%d][0]=%s", Integer.valueOf(i), first);
                }
                int serialNumber = (int) first.value;
                if (serialNumber <= 0 || serialNumber > maxIndex) {
                    throw new IllegalParameterException("Wrong input format: sample_serial_number out of range; column[%d][0]=%s", Integer.valueOf(i), first);
                }
            }
        }
        SvmParameter effective = svmParameter;
        if (svmParameter.gamma == 0.0d && maxIndex > 0) {
            // m7clone() is the name this method has in the SvmParameter sources this file
            // is compiled against (presumably a decompiler-renamed clone() - verify).
            effective = svmParameter.m7clone();
            effective.gamma = 1.0d / maxIndex;
        }
        Svm.svmCheckParameter(svmProblem, effective);
        return effective;
    }

    /**
     * Returns the largest index found in the last node of any non-empty row, or 0 when
     * there is none. Only the last node of each row is inspected, so this presumably
     * relies on rows being sorted by ascending index - verify against the reader.
     */
    private static int getMaxIndex(SvmProblem svmProblem) {
        int maxIndex = 0;
        for (SvmNode[] row : svmProblem.x) {
            if (row.length > 0) {
                maxIndex = Math.max(maxIndex, row[row.length - 1].index);
            }
        }
        return maxIndex;
    }
}
