package smile.plot.swing;

import java.awt.Color;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import smile.data.DataFrame;
import smile.math.MathEx;

/**
 * A scatter plot that draws one or more groups of {@link Point} marks,
 * optionally accompanied by one legend entry per group.
 */
public class ScatterPlot extends Plot {
    /** The groups of points to draw. */
    final Point[] points;
    /** The legend entries of the point groups, or {@code null} if there are none. */
    final Legend[] legends;

    /**
     * Creates a scatter plot without legends.
     *
     * @param points the groups of points to draw.
     */
    public ScatterPlot(Point... points) {
        this(points, null);
    }

    /**
     * Creates a scatter plot.
     *
     * @param points  the groups of points to draw.
     * @param legends the legend entries of the groups, or {@code null} for none.
     */
    public ScatterPlot(Point[] points, Legend[] legends) {
        this.points = points;
        this.legends = legends;
    }

    /**
     * Paints every point group in order.
     *
     * @param g the graphics context to paint on.
     */
    @Override
    public void paint(Graphics g) {
        for (Point group : points) {
            group.paint(g);
        }
    }

    /**
     * Returns the legends of this plot.
     *
     * @return the legends, or an empty {@code Optional} if none were given.
     */
    @Override
    public Optional<Legend[]> legends() {
        return Optional.ofNullable(legends);
    }

    /**
     * Returns the coordinate-wise lower bound of all point groups.
     * The bound starts from {@code MathEx.colMin} of the first group. It is then
     * lowered coordinate by coordinate with the rows of the remaining groups.
     * NOTE(review): assumes at least one group, and that later groups' rows have
     * no more coordinates than the first group's rows. Confirm with callers.
     *
     * @return the lower bound of the data.
     */
    @Override
    public double[] getLowerBound() {
        double[] bound = MathEx.colMin(points[0].points);
        for (int g = 1; g < points.length; g++) {
            for (double[] row : points[g].points) {
                for (int j = 0; j < row.length; j++) {
                    // Explicit comparison (not Math.min) keeps the original NaN handling.
                    if (bound[j] > row[j]) {
                        bound[j] = row[j];
                    }
                }
            }
        }
        return bound;
    }

    /**
     * Returns the coordinate-wise upper bound of all point groups.
     * The bound starts from {@code MathEx.colMax} of the first group. It is then
     * raised coordinate by coordinate with the rows of the remaining groups.
     * NOTE(review): assumes at least one group, and that later groups' rows have
     * no more coordinates than the first group's rows. Confirm with callers.
     *
     * @return the upper bound of the data.
     */
    @Override
    public double[] getUpperBound() {
        double[] bound = MathEx.colMax(points[0].points);
        for (int g = 1; g < points.length; g++) {
            for (double[] row : points[g].points) {
                for (int j = 0; j < row.length; j++) {
                    // Explicit comparison (not Math.max) keeps the original NaN handling.
                    if (bound[j] < row[j]) {
                        bound[j] = row[j];
                    }
                }
            }
        }
        return bound;
    }

    /**
     * Creates a scatter plot of a single group with default mark and color.
     *
     * @param data the points, one row per point.
     * @return the scatter plot.
     */
    public static ScatterPlot of(double[][] data) {
        return new ScatterPlot(Point.of(data));
    }

    /**
     * Creates a scatter plot of a single group with the given color.
     *
     * @param data  the points, one row per point.
     * @param color the color of the marks.
     * @return the scatter plot.
     */
    public static ScatterPlot of(double[][] data, Color color) {
        return new ScatterPlot(Point.of(data, color));
    }

    /**
     * Creates a scatter plot of a single group with the given mark.
     *
     * @param data the points, one row per point.
     * @param mark the mark character.
     * @return the scatter plot.
     */
    public static ScatterPlot of(double[][] data, char mark) {
        return new ScatterPlot(Point.of(data, mark));
    }

    /**
     * Creates a scatter plot of a single group with the given mark and color.
     *
     * @param data  the points, one row per point.
     * @param mark  the mark character.
     * @param color the color of the marks.
     * @return the scatter plot.
     */
    public static ScatterPlot of(double[][] data, char mark, Color color) {
        return new ScatterPlot(new Point(data, mark, color));
    }

    /**
     * Creates a scatter plot with one group per distinct label.
     * Each group gets a color from {@code Palette.get(k)} and a legend entry with its label.
     * Groups are ordered by the first appearance of their label in {@code labels}.
     *
     * @param data   the points, one row per point.
     * @param labels the group label of each point.
     * @param mark   the mark character used for all groups.
     * @return the scatter plot.
     * @throws IllegalArgumentException if {@code data} and {@code labels} differ in length.
     */
    public static ScatterPlot of(double[][] data, String[] labels, char mark) {
        if (data.length != labels.length) {
            throw new IllegalArgumentException("The number of points and that of labels are not the same.");
        }

        // A LinkedHashMap makes the color and legend order follow the data order
        // instead of the unspecified HashMap iteration order.
        Map<String, List<Integer>> groups = IntStream.range(0, data.length).boxed()
                .collect(Collectors.groupingBy(i -> labels[i], LinkedHashMap::new, Collectors.toList()));

        Point[] points = new Point[groups.size()];
        Legend[] legends = new Legend[groups.size()];
        int k = 0;
        for (Map.Entry<String, List<Integer>> group : groups.entrySet()) {
            Color color = Palette.get(k);
            double[][] members = group.getValue().stream()
                    .map(i -> data[i])
                    .toArray(double[][]::new);
            points[k] = new Point(members, mark, color);
            legends[k] = new Legend(group.getKey(), color);
            k++;
        }
        return new ScatterPlot(points, legends);
    }

    /**
     * Creates a scatter plot with one group per distinct integer class label.
     * Each label {@code c} is shown in the legend as {@code "class c"}.
     *
     * @param data   the points, one row per point.
     * @param labels the class label of each point.
     * @param mark   the mark character used for all groups.
     * @return the scatter plot.
     */
    public static ScatterPlot of(double[][] data, int[] labels, char mark) {
        String[] names = Arrays.stream(labels)
                .mapToObj(label -> String.format("class %d", label))
                .toArray(String[]::new);
        return of(data, names, mark);
    }

    /**
     * Creates a 2D scatter plot of two data frame columns.
     *
     * @param data  the data frame.
     * @param x     the column for the first coordinate.
     * @param y     the column for the second coordinate.
     * @param mark  the mark character.
     * @param color the color of the marks.
     * @return the scatter plot.
     */
    public static ScatterPlot of(DataFrame data, String x, String y, char mark, Color color) {
        return of(columns(data, x, y), mark, color);
    }

    /**
     * Creates a 2D scatter plot of two data frame columns, grouped by a category column.
     *
     * @param data     the data frame.
     * @param x        the column for the first coordinate.
     * @param y        the column for the second coordinate.
     * @param category the column whose string values label the groups.
     * @param mark     the mark character.
     * @return the scatter plot.
     */
    public static ScatterPlot of(DataFrame data, String x, String y, String category, char mark) {
        return of(columns(data, x, y), data.column(category).toStringArray(), mark);
    }

    /**
     * Creates a 3D scatter plot of three data frame columns.
     *
     * @param data  the data frame.
     * @param x     the column for the first coordinate.
     * @param y     the column for the second coordinate.
     * @param z     the column for the third coordinate.
     * @param mark  the mark character.
     * @param color the color of the marks.
     * @return the scatter plot.
     */
    public static ScatterPlot of(DataFrame data, String x, String y, String z, char mark, Color color) {
        return of(columns(data, x, y, z), mark, color);
    }

    /**
     * Creates a 3D scatter plot of three data frame columns, grouped by a category column.
     *
     * @param data     the data frame.
     * @param x        the column for the first coordinate.
     * @param y        the column for the second coordinate.
     * @param z        the column for the third coordinate.
     * @param category the column whose string values label the groups.
     * @param mark     the mark character.
     * @return the scatter plot.
     */
    public static ScatterPlot of(DataFrame data, String x, String y, String z, String category, char mark) {
        return of(columns(data, x, y, z), data.column(category).toStringArray(), mark);
    }

    /**
     * Extracts the named columns of a data frame as a matrix of doubles,
     * with one row per data frame row and one entry per named column.
     */
    private static double[][] columns(DataFrame data, String... names) {
        // Resolve column indices once, before streaming the rows.
        int[] index = new int[names.length];
        for (int j = 0; j < names.length; j++) {
            index[j] = data.schema().indexOf(names[j]);
        }
        return data.stream().map(row -> {
            double[] values = new double[index.length];
            for (int j = 0; j < index.length; j++) {
                values[j] = row.getDouble(index[j]);
            }
            return values;
        }).toArray(double[][]::new);
    }
}
