package tech.molecules.leet.chem.descriptor.featurepair;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;

/**
 * A list of {@link FlexoDatapoint}s together with helpers that export them as fixed-width,
 * comma-separated integer rows: 32 encoded sphere strings followed by the datapoint's histogram
 * counts.
 */
public class FlexoDataset {

    /** Datapoints with more sphere entries than this are skipped; shorter ones are zero-padded. */
    private static final int MAX_SPHERES = 32;

    /** How many of the most frequent sphere strings get a dedicated code (codes 1..254). */
    private static final int MAX_ENCODED_STRINGS = 254;

    /** Code the "Standard256" encoder assigns to non-empty strings outside its vocabulary. */
    private static final int CODE_UNKNOWN = 255;

    /** Code used for empty sphere strings and for padding rows up to {@code MAX_SPHERES} columns. */
    private static final int CODE_PADDING = 0;

    private final List<FlexoDatapoint> data;

    /** A single sample: ordered sphere strings, a macrocycle score and histogram counts. */
    public static class FlexoDatapoint {
        private final List<String> spheres;
        private final double macrocycleScore;
        private final int[] histogramCounts;

        /**
         * Creates a datapoint. The arguments are stored as-is (no defensive copies).
         *
         * @param list sphere strings in order; empty strings are exported as code 0
         * @param d    macrocycle score (stored, but not read by the export methods of this class)
         * @param iArr histogram counts, appended verbatim after the sphere codes on export
         */
        public FlexoDatapoint(List<String> list, double d, int[] iArr) {
            this.spheres = list;
            this.macrocycleScore = d;
            this.histogramCounts = iArr;
        }
    }

    /**
     * Maps sphere strings to small integer codes. The empty string always maps to 0, strings in
     * the mapping map to their stored code, and every other string maps to {@code intUnknown}.
     */
    public static class FlexoPathEncoder {
        private final Map<String, Integer> encoding;
        private final int intUnknown;

        /**
         * @param map string-to-code mapping; used directly, not copied
         * @param i   code returned for non-empty strings that are not in {@code map}
         */
        public FlexoPathEncoder(Map<String, Integer> map, int i) {
            this.encoding = map;
            this.intUnknown = i;
        }

        /**
         * Encodes a single sphere string.
         *
         * @param str the string to encode; must not be {@code null}
         * @return 0 for the empty string, the mapped code if known, otherwise {@code intUnknown}
         */
        public int encode(String str) {
            if (str.isEmpty()) {
                return CODE_PADDING;
            }
            return this.encoding.getOrDefault(str, this.intUnknown);
        }

        /**
         * Writes the mapping as UTF-8 text, one {@code key,code} line per entry. The unknown code
         * is not written; {@link #load(String)} reconstructs it as (largest stored code + 1).
         * NOTE(review): keys containing line breaks cannot be round-tripped by this format.
         *
         * @param str output file path (created, or truncated if it exists)
         * @throws IOException if the file cannot be written
         */
        public void store(String str) throws IOException {
            try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(str), StandardCharsets.UTF_8)) {
                for (Map.Entry<String, Integer> entry : this.encoding.entrySet()) {
                    writer.write(entry.getKey() + "," + entry.getValue() + "\n");
                }
            }
        }

        /**
         * Reads a mapping written by {@link #store(String)}.
         *
         * <p>Each non-blank line is split at its LAST comma, so keys that themselves contain
         * commas survive a round trip. The unknown code becomes (largest stored code + 1), or 0
         * for a file without entries. For an encoder written by the Standard256 export this
         * equals 255 only when all 254 vocabulary slots were filled.
         *
         * @param str path of the file to read
         * @return the reconstructed encoder
         * @throws IOException if the file cannot be read or a line contains no comma
         * @throws NumberFormatException if a stored code is not a valid integer
         */
        public static FlexoPathEncoder load(String str) throws IOException {
            Map<String, Integer> encoding = new HashMap<>();
            int maxCode = -1;
            try (BufferedReader reader = Files.newBufferedReader(Paths.get(str), StandardCharsets.UTF_8)) {
                String line;
                while ((line = reader.readLine()) != null) {
                    if (line.trim().isEmpty()) {
                        continue; // tolerate blank (e.g. trailing) lines
                    }
                    int sep = line.lastIndexOf(',');
                    if (sep < 0) {
                        throw new IOException("Malformed encoder line (expected 'key,code'): " + line);
                    }
                    int code = Integer.parseInt(line.substring(sep + 1).trim());
                    encoding.put(line.substring(0, sep), code);
                    maxCode = Math.max(maxCode, code);
                }
            }
            return new FlexoPathEncoder(encoding, maxCode + 1);
        }
    }

    /**
     * @param list the datapoints; stored as-is (not copied)
     */
    public FlexoDataset(List<FlexoDatapoint> list) {
        this.data = list;
    }

    /** @return the datapoint list passed to the constructor (not a copy) */
    public List<FlexoDatapoint> getData() {
        return this.data;
    }

    /**
     * Builds a "Standard256" encoder from the dataset, stores it, and then exports the dataset
     * with it.
     *
     * <p>Only datapoints with at most 32 spheres are considered. The 254 most frequent non-empty
     * sphere strings get codes 1..254 in order of descending frequency; ties are broken by
     * natural string order so the vocabulary is reproducible. All other non-empty strings encode
     * as 255, and empty strings encode as 0.
     *
     * @param flexoDataset dataset to encode and export
     * @param str          output path of the CSV data file
     * @param str2         output path of the encoder file (format of {@link FlexoPathEncoder#store})
     * @throws IOException if either file cannot be written
     */
    public static void createPyTorchInputData_Standard256(FlexoDataset flexoDataset, String str, String str2) throws IOException {
        FlexoPathEncoder encoder = buildStandard256Encoder(flexoDataset);
        encoder.store(str2);
        createPyTorchInputData_Standard256(flexoDataset, str, encoder);
    }

    /**
     * Writes one CSV line per datapoint with at most 32 spheres; larger datapoints are skipped.
     *
     * <p>Each line holds 32 sphere codes (padded with 0), followed by the datapoint's histogram
     * counts. There is no trailing comma, even when the histogram is empty.
     *
     * @param flexoDataset     dataset to export
     * @param str              output file path (created, or truncated if it exists)
     * @param flexoPathEncoder encoder used to turn sphere strings into codes
     * @throws IOException if the file cannot be written
     */
    public static void createPyTorchInputData_Standard256(FlexoDataset flexoDataset, String str, FlexoPathEncoder flexoPathEncoder) throws IOException {
        try (BufferedWriter writer = Files.newBufferedWriter(Paths.get(str), StandardCharsets.UTF_8)) {
            for (FlexoDatapoint datapoint : flexoDataset.getData()) {
                if (datapoint.spheres.size() > MAX_SPHERES) {
                    continue;
                }
                writer.write(toCsvRow(datapoint, flexoPathEncoder));
                writer.write("\n");
            }
        }
    }

    /** Counts non-empty sphere strings and assigns codes 1..254 by descending frequency. */
    private static FlexoPathEncoder buildStandard256Encoder(FlexoDataset dataset) {
        Map<String, Integer> counts = new HashMap<>();
        for (FlexoDatapoint datapoint : dataset.getData()) {
            if (datapoint.spheres.size() > MAX_SPHERES) {
                continue; // must match the filter used when writing rows
            }
            for (String sphere : datapoint.spheres) {
                if (!sphere.isEmpty()) {
                    counts.merge(sphere, 1, Integer::sum);
                }
            }
        }

        List<Map.Entry<String, Integer>> ranked = new ArrayList<>(counts.entrySet());
        ranked.sort((left, right) -> {
            int byCount = Integer.compare(right.getValue(), left.getValue());
            return byCount != 0 ? byCount : left.getKey().compareTo(right.getKey());
        });

        Map<String, Integer> encoding = new HashMap<>();
        int limit = Math.min(MAX_ENCODED_STRINGS, ranked.size());
        for (int rank = 0; rank < limit; rank++) {
            encoding.put(ranked.get(rank).getKey(), rank + 1); // code 0 is reserved for padding
        }
        return new FlexoPathEncoder(encoding, CODE_UNKNOWN);
    }

    /** Formats one datapoint as 32 padded sphere codes followed by its histogram counts. */
    private static String toCsvRow(FlexoDatapoint datapoint, FlexoPathEncoder encoder) {
        StringJoiner row = new StringJoiner(",");
        for (int i = 0; i < MAX_SPHERES; i++) {
            int code = i < datapoint.spheres.size()
                    ? encoder.encode(datapoint.spheres.get(i))
                    : CODE_PADDING;
            row.add(Integer.toString(code));
        }
        for (int count : datapoint.histogramCounts) {
            row.add(Integer.toString(count));
        }
        return row.toString();
    }
}
