package org.canova.image.lfw;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.URL;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.filefilter.DirectoryFileFilter;
import org.apache.commons.io.filefilter.FileFileFilter;
import org.canova.api.util.ArchiveUtils;
import org.canova.image.loader.ImageLoader;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.FeatureUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/canova/image/lfw/LFWLoader.class */
/**
 * Loads the Labeled Faces in the Wild (LFW) data set.
 *
 * <p>The archive at {@link #LFW_URL} is downloaded to {@code ~/lfw/lfw.tgz} and extracted into the
 * user's home directory. Every subdirectory of {@code ~/lfw} is treated as one person (one label),
 * and the regular files directly inside it are that person's images. The downloaded tarball, which
 * also lives in {@code ~/lfw}, is therefore never mistaken for a person or an image. Person
 * directories and image files are always visited in sorted path order, so label indices agree
 * between {@link #getDataFor(int)}, {@link #getImagesAsList()} and {@link #getFeatureMatrix(int)}.
 *
 * <p>{@link #getIfNotExists()} must be called before {@link #getDataFor(int)},
 * {@link #getNumNames()}, {@link #getNumPixelColumns()} or the matrix-building methods return
 * meaningful results.
 */
public class LFWLoader {
    public static final String LFW = "lfw";
    public static final String LFW_URL = "http://vis-www.cs.umass.edu/lfw/lfw.tgz";
    private static final Logger log = LoggerFactory.getLogger(LFWLoader.class);
    /**
     * Maximum number of times {@link #getIfNotExists()} checks the data set. Each failed check
     * deletes {@code ~/lfw} and triggers a re-download.
     */
    private static final int MAX_DOWNLOAD_ATTEMPTS = 3;

    private final File baseDir;
    private final File lfwDir;
    private final File lfwTarFile;
    private final ImageLoader loader;
    /** Absolute paths of every image, populated by {@link #getIfNotExists()}. */
    private final List<String> images = new ArrayList<String>();
    /** Absolute paths of the person directories; a directory's index is its label. */
    private final List<String> outcomes = new ArrayList<String>();
    private int numNames;
    private int numPixelColumns;

    /** Creates a loader that reads images through {@code new ImageLoader(28, 28)}. */
    public LFWLoader() {
        this(28, 28);
    }

    /**
     * Creates a loader that reads images through {@code new ImageLoader(imageWidth, imageHeight)}.
     *
     * @param imageWidth first dimension passed to {@code ImageLoader}. NOTE(review): confirm that
     *     ImageLoader's argument order is (width, height).
     * @param imageHeight second dimension passed to {@code ImageLoader}
     */
    public LFWLoader(int imageWidth, int imageHeight) {
        this.baseDir = new File(System.getProperty("user.home"));
        this.lfwDir = new File(baseDir, LFW);
        this.lfwTarFile = new File(lfwDir, "lfw.tgz");
        this.loader = new ImageLoader(imageWidth, imageHeight);
    }

    /**
     * Downloads and extracts LFW if {@code ~/lfw} does not exist, then indexes the extracted images.
     *
     * <p>If no image can be found (for example, after a truncated download), {@code ~/lfw} is
     * deleted and the download is retried a bounded number of times. Calling this method again
     * re-indexes from scratch instead of appending duplicate entries.
     *
     * @throws IOException if the data set cannot be downloaded or listed, or still contains no
     *     images after all attempts
     * @throws Exception propagated from the image loader while reading the first image
     */
    public void getIfNotExists() throws Exception {
        getIfNotExists(MAX_DOWNLOAD_ATTEMPTS);
    }

    /** Implements {@link #getIfNotExists()} with a bounded number of remaining attempts. */
    private void getIfNotExists(int attemptsLeft) throws Exception {
        if (!lfwDir.exists()) {
            download();
        }

        File firstImage = findFirstImage();
        if (firstImage == null) {
            // Nothing usable on disk: most likely an interrupted or corrupt download.
            FileUtils.deleteDirectory(lfwDir);
            if (attemptsLeft <= 1) {
                throw new IOException("No LFW images found under " + lfwDir + " after "
                        + MAX_DOWNLOAD_ATTEMPTS + " attempts");
            }
            log.warn("Error opening first image; probably corrupt download...trying again");
            // The retry performs all of the indexing below; returning here avoids doing it twice
            // (and avoids reading a null first image).
            getIfNotExists(attemptsLeft - 1);
            return;
        }

        numPixelColumns = ArrayUtil.flatten(loader.fromFile(firstImage)).length;

        List<File> people = personDirectories();
        numNames = people.size();
        images.clear();
        outcomes.clear();
        for (File person : people) {
            outcomes.add(person.getAbsolutePath());
            for (File image : filesIn(person)) {
                images.add(image.getAbsolutePath());
            }
        }
    }

    /** Downloads {@link #LFW_URL} to {@code lfwTarFile} and extracts it into {@code baseDir}. */
    private void download() throws IOException {
        if (!lfwDir.mkdirs() && !lfwDir.isDirectory()) {
            throw new IOException("Unable to create directory " + lfwDir);
        }
        log.info("Grabbing LFW...");
        ReadableByteChannel in = Channels.newChannel(new URL(LFW_URL).openStream());
        try {
            // FileOutputStream creates (or truncates) the target file itself.
            FileOutputStream out = new FileOutputStream(lfwTarFile);
            try {
                out.getChannel().transferFrom(in, 0L, Long.MAX_VALUE);
            } finally {
                out.close();
            }
        } finally {
            in.close();
        }
        log.info("Downloaded lfw");
        untarFile(baseDir, lfwTarFile);
    }

    /**
     * Returns the first image of the first non-empty person directory, in sorted order.
     *
     * @return that image, or {@code null} if no person directory contains a regular file
     */
    private File findFirstImage() throws IOException {
        for (File person : personDirectories()) {
            List<File> files = filesIn(person);
            if (!files.isEmpty()) {
                return files.get(0);
            }
        }
        return null;
    }

    /**
     * Returns the absolute person subdirectories of {@code lfwDir}, sorted by path. Plain files,
     * such as the downloaded tarball, are excluded.
     *
     * @throws IOException if {@code lfwDir} cannot be listed (for example, it does not exist)
     */
    private List<File> personDirectories() throws IOException {
        File dir = lfwDir.getAbsoluteFile();
        File[] entries = dir.listFiles();
        if (entries == null) {
            throw new IOException("Unable to list " + dir);
        }
        List<File> people = new ArrayList<File>(entries.length);
        for (File entry : entries) {
            if (entry.isDirectory()) {
                people.add(entry);
            }
        }
        Collections.sort(people);
        return people;
    }

    /**
     * Returns the regular files directly inside {@code dir}, sorted by path.
     *
     * @throws IOException if {@code dir} cannot be listed (for example, it is not a directory)
     */
    private static List<File> filesIn(File dir) throws IOException {
        File[] entries = dir.listFiles();
        if (entries == null) {
            throw new IOException("Unable to list files in " + dir);
        }
        List<File> files = new ArrayList<File>(entries.length);
        for (File entry : entries) {
            if (entry.isFile()) {
                files.add(entry);
            }
        }
        Collections.sort(files);
        return files;
    }

    /**
     * Stacks individual examples into a single DataSet.
     *
     * @param list examples whose features are expected to have {@link #getNumPixelColumns()}
     *     columns and whose labels are expected to have {@link #getNumNames()} columns
     * @return a DataSet with one feature row and one label row per element of {@code list}
     */
    public DataSet convertListPairs(List<DataSet> list) {
        int rows = list.size();
        INDArray features = Nd4j.create(rows, numPixelColumns);
        INDArray labels = Nd4j.create(rows, numNames);
        for (int row = 0; row < rows; row++) {
            DataSet example = list.get(row);
            features.putRow(row, example.getFeatureMatrix());
            labels.putRow(row, example.getLabels());
        }
        return new DataSet(features, labels);
    }

    /**
     * Loads the {@code i}-th indexed image as a single example.
     *
     * <p>The label is the index of the image's parent directory among the person directories.
     *
     * @param i index into the image list built by {@link #getIfNotExists()}
     * @throws IndexOutOfBoundsException if {@code i} is not a valid image index
     * @throws IllegalStateException if the image cannot be loaded; the original failure is
     *     attached as the cause
     */
    public DataSet getDataFor(int i) {
        File file = new File(images.get(i));
        try {
            INDArray features = loader.asRowVector(file);
            int label = outcomes.indexOf(file.getParentFile().getAbsolutePath());
            return new DataSet(features, FeatureUtil.toOutcomeVector(label, outcomes.size()));
        } catch (Exception e) {
            throw new IllegalStateException("Unable to getFromOrigin data for image " + i + " for path " + images.get(i), e);
        }
    }

    /**
     * Loads whole person directories, in label order, until at least {@code i} examples are
     * collected.
     *
     * @param i minimum number of examples wanted. The last directory is always loaded in full, so
     *     the result may contain more than {@code i} examples. It contains fewer only when the
     *     data set runs out.
     * @return the loaded examples, labelled by person-directory index
     */
    public List<DataSet> getFeatureMatrix(int i) throws Exception {
        List<DataSet> result = new ArrayList<DataSet>();
        List<File> people = personDirectories();
        for (int label = 0; label < people.size(); label++) {
            result.addAll(getImages(label, people.get(label)));
            if (result.size() >= i) {
                break;
            }
        }
        return result;
    }

    /** Loads every image of every person into a single DataSet. */
    public DataSet getAllImagesAsMatrix() throws Exception {
        return convertListPairs(getImagesAsList());
    }

    /**
     * Loads the first {@code i} examples, in label order, into a single DataSet.
     *
     * @param i number of examples wanted. If the data set is smaller, every example is returned.
     */
    public DataSet getAllImagesAsMatrix(int i) throws Exception {
        // getFeatureMatrix visits directories in the same order and with the same labels as
        // getImagesAsList, so its first i examples are identical; only the directories actually
        // needed are loaded.
        List<DataSet> examples = getFeatureMatrix(i);
        return convertListPairs(examples.subList(0, Math.min(i, examples.size())));
    }

    /** Loads every image of every person, labelled by person-directory index. */
    public List<DataSet> getImagesAsList() throws Exception {
        List<DataSet> result = new ArrayList<DataSet>();
        List<File> people = personDirectories();
        for (int label = 0; label < people.size(); label++) {
            result.addAll(getImages(label, people.get(label)));
        }
        return result;
    }

    /**
     * Loads every regular file directly inside {@code directory}, in sorted order, as an example
     * with the given label.
     *
     * @param label label index assigned to every image
     * @param directory one person's directory
     * @throws IOException if {@code directory} cannot be listed
     */
    public List<DataSet> getImages(int label, File directory) throws Exception {
        List<DataSet> result = new ArrayList<DataSet>();
        for (File image : filesIn(directory)) {
            result.add(fromImageFile(label, image));
        }
        return result;
    }

    /**
     * Loads one image as an example.
     *
     * @param label label index, one-hot encoded over {@link #getNumNames()} outcomes
     * @param file image file
     */
    public DataSet fromImageFile(int label, File file) throws Exception {
        return new DataSet(ArrayUtil.toNDArray(loader.flattenedImageFromFile(file)), FeatureUtil.toOutcomeVector(label, numNames));
    }

    /**
     * Extracts {@code archive} into {@code destinationDir} using {@code ArchiveUtils.unzipFileTo}.
     *
     * @param destinationDir directory to extract into
     * @param archive archive file to extract
     */
    public void untarFile(File destinationDir, File archive) throws IOException {
        log.info("Untaring File: {}", archive);
        ArchiveUtils.unzipFileTo(archive.getAbsolutePath(), destinationDir.getAbsolutePath());
    }

    /** Returns the number of person directories (labels) found by {@link #getIfNotExists()}. */
    public int getNumNames() {
        return numNames;
    }

    /** Returns the flattened length of the first image read by {@link #getIfNotExists()}. */
    public int getNumPixelColumns() {
        return numPixelColumns;
    }
}
