package de.bioforscher.singa.structure.algorithms.superimposition.consensus;

import de.bioforscher.singa.core.utility.Pair;
import de.bioforscher.singa.mathematics.graphs.trees.BinaryTree;
import de.bioforscher.singa.mathematics.graphs.trees.BinaryTreeNode;
import de.bioforscher.singa.mathematics.matrices.LabeledSymmetricMatrix;
import de.bioforscher.singa.structure.algorithms.superimposition.SubstructureSuperimposer;
import de.bioforscher.singa.structure.algorithms.superimposition.SubstructureSuperimposition;
import de.bioforscher.singa.structure.algorithms.superimposition.consensus.ConsensusBuilder;
import de.bioforscher.singa.structure.algorithms.superimposition.fit3d.representations.RepresentationScheme;
import de.bioforscher.singa.structure.algorithms.superimposition.fit3d.representations.RepresentationSchemeFactory;
import de.bioforscher.singa.structure.algorithms.superimposition.fit3d.representations.RepresentationSchemeType;
import de.bioforscher.singa.structure.model.families.AminoAcidFamily;
import de.bioforscher.singa.structure.model.identifiers.LeafIdentifier;
import de.bioforscher.singa.structure.model.interfaces.Atom;
import de.bioforscher.singa.structure.model.interfaces.LeafSubstructure;
import de.bioforscher.singa.structure.model.oak.LeafSubstructureFactory;
import de.bioforscher.singa.structure.model.oak.OakAtom;
import de.bioforscher.singa.structure.model.oak.OakLeafSubstructure;
import de.bioforscher.singa.structure.model.oak.StructuralMotif;
import de.bioforscher.singa.structure.parser.pdb.structures.StructureWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/bioforscher/singa/structure/algorithms/superimposition/consensus/ConsensusAlignment.class */
/**
 * Hierarchical consensus alignment of a set of {@link StructuralMotif}s.
 * <p>
 * All pairwise superimpositions of the input motifs are computed first. Then, repeatedly, the pair with the lowest
 * RMSD is merged into a consensus motif whose atoms lie at the mean position of the two aligned structures. The merged
 * pair is replaced by its consensus, and the consensus is superimposed against every remaining structure. Every merge
 * appends a {@link BinaryTree} to {@link #getConsensusTrees()}; the last one spans all input structures. That
 * top-level tree is split into clusters wherever a child's accumulated consensus distance exceeds the configured
 * cutoff, and cluster members can optionally be re-aligned against their cluster consensus.
 */
public class ConsensusAlignment {
    private static final Logger logger = LoggerFactory.getLogger(ConsensusAlignment.class);

    /** Structures still waiting to be merged; shrinks by one per iteration (two removed, one consensus added). */
    private final List<ConsensusContainer> input;
    private final boolean idealSuperimposition;
    /** One tree per merge step, in merge order; the last element is the complete tree. */
    private final List<BinaryTree<ConsensusContainer>> consensusTrees;
    /** RMSD of the pair merged in each iteration. */
    private final List<Double> alignmentTrace;
    /** Number of structures left in {@link #input} at the start of each iteration. */
    private final List<Integer> alignmentCounts;
    private final Predicate<Atom> atomFilter;
    private final boolean alignWithinClusters;
    private final double clusterCutoff;
    /** Optional representation scheme; {@code null} means atoms are selected with {@link #atomFilter}. */
    private final RepresentationScheme representationScheme;
    private double consensusScore;
    private int iterationCounter;
    /**
     * Pending pairwise alignments. A list is used instead of an RMSD-keyed {@code TreeMap}: such a map treats two
     * alignments with identical RMSD as the same key, silently dropping one pair and mismatching key and value.
     */
    private List<Map.Entry<SubstructureSuperimposition, Pair<ConsensusContainer>>> alignments;
    private LabeledSymmetricMatrix<ConsensusContainer> distanceMatrix;
    /** One leaf node per original input structure. */
    private List<BinaryTreeNode<ConsensusContainer>> leaves;
    /** Consensus created by the most recent merge. */
    private ConsensusContainer currentConsensus;
    private List<BinaryTree<ConsensusContainer>> clusters;

    /**
     * Computes the consensus alignment for the motifs and settings held by the given builder.
     * <p>
     * NOTE(review): the decompiler reports that this constructor was originally package-private; it stays public so
     * existing callers keep compiling.
     *
     * @param builder the configured {@link ConsensusBuilder.Builder}
     * @throws ConsensusException if fewer than two motifs are given, or if the motifs do not all contain the same
     *                            number of leaf substructures
     */
    public ConsensusAlignment(ConsensusBuilder.Builder builder) {
        // later steps remove from and add to this list, so collect into an explicitly mutable list
        input = builder.structuralMotifs.stream()
                .map(ConsensusAlignment::toContainer)
                .collect(Collectors.toCollection(ArrayList::new));
        logger.info("consensus alignment initialized with {} structures", input.size());
        clusterCutoff = builder.clusterCutoff;
        alignWithinClusters = builder.alignWithinClusters;
        atomFilter = builder.atomFilter;
        RepresentationSchemeType representationSchemeType = builder.representationSchemeType;
        if (representationSchemeType != null) {
            logger.info("using representation scheme {}", representationSchemeType);
            representationScheme = RepresentationSchemeFactory.createRepresentationScheme(representationSchemeType);
        } else {
            representationScheme = null;
        }
        idealSuperimposition = builder.idealSuperimposition;
        // without at least one pair nothing can be merged and no consensus tree would exist
        if (input.size() < 2) {
            throw new ConsensusException("at least two structures are required to calculate a consensus alignment, but "
                    + input.size() + " were given");
        }
        long distinctLeafCounts = input.stream()
                .map(container -> container.getStructuralMotif().getAllLeafSubstructures().size())
                .distinct()
                .count();
        if (distinctLeafCounts != 1) {
            throw new ConsensusException("all substructures must contain the same number of leaf structures to calculate a consensus alignment");
        }
        iterationCounter = 0;
        alignmentTrace = new ArrayList<>();
        alignmentCounts = new ArrayList<>();
        consensusTrees = new ArrayList<>();
        calculateInitialAlignments();
        logger.info("{} initial alignment pairs were computed, in total we have to compute {} alignments",
                alignments.size(), alignments.size() * (input.size() - 1));
        createTreeLeaves();
        calculateConsensusAlignment();
        splitTopLevelTree();
        if (alignWithinClusters) {
            alignWithinClusters();
        }
    }

    /** Wraps a copy of the given motif so that the caller's motif is never modified. */
    private static ConsensusContainer toContainer(StructuralMotif structuralMotif) {
        return new ConsensusContainer(structuralMotif.getCopy(), false);
    }

    /**
     * Returns the RMSD of the pair merged in each iteration, in merge order.
     *
     * @return the live (mutable) trace list
     */
    public List<Double> getAlignmentTrace() {
        return alignmentTrace;
    }

    /**
     * Returns the clusters obtained by splitting the top-level consensus tree at the cluster cutoff.
     *
     * @return the live (mutable) cluster list
     */
    public List<BinaryTree<ConsensusContainer>> getClusters() {
        return clusters;
    }

    /**
     * Writes every cluster to its own {@code cluster_<n>/} sub-path of {@code path}. Clusters with more than one
     * member also get their consensus written as {@code consensus_<n>.pdb}. Each member is written in its
     * superimposed form if a superimposition was computed for it, otherwise as-is.
     *
     * @param path the output directory; it is created if missing
     * @throws IOException if the output directory cannot be created or a structure cannot be written
     */
    public void writeClusters(Path path) throws IOException {
        logger.info("writing {} clusters to {}", clusters.size(), path);
        Files.createDirectories(path);
        for (int i = 0; i < clusters.size(); i++) {
            int clusterNumber = i + 1;
            String clusterDirectory = "cluster_" + clusterNumber + "/";
            BinaryTree<ConsensusContainer> cluster = clusters.get(i);
            if (cluster.getLeafNodes().size() > 1) {
                StructureWriter.writeLeafSubstructures(cluster.getRoot().getData().getStructuralMotif().getAllLeafSubstructures(),
                        path.resolve(clusterDirectory + "consensus_" + clusterNumber + ".pdb"));
            }
            for (BinaryTreeNode<ConsensusContainer> member : cluster.getLeafNodes()) {
                ConsensusContainer container = member.getData();
                List<LeafSubstructure<?>> leafSubstructures = container.getSuperimposition() != null
                        ? container.getSuperimposition().getMappedFullCandidate()
                        : container.getStructuralMotif().getAllLeafSubstructures();
                StructureWriter.writeLeafSubstructures(leafSubstructures, path.resolve(clusterDirectory + container.toString() + ".pdb"));
            }
        }
    }

    /** Superimposes every leaf of each multi-node cluster onto that cluster's root consensus and stores the result. */
    private void alignWithinClusters() {
        for (BinaryTree<ConsensusContainer> cluster : clusters) {
            if (cluster.size() <= 1) {
                continue;
            }
            ConsensusContainer clusterConsensus = cluster.getRoot().getData();
            for (BinaryTreeNode<ConsensusContainer> member : cluster.getLeafNodes()) {
                ConsensusContainer memberContainer = member.getData();
                memberContainer.setSuperimposition(superimpose(clusterConsensus.getStructuralMotif(), memberContainer.getStructuralMotif()));
            }
        }
    }

    /**
     * Splits the top-level tree into clusters. A tree is replaced by its two subtrees whenever either child's
     * consensus distance exceeds {@link #clusterCutoff}; the subtrees are themselves examined again.
     */
    private void splitTopLevelTree() {
        clusters = new ArrayList<>();
        clusters.add(getTopConsensusTree());
        ListIterator<BinaryTree<ConsensusContainer>> iterator = clusters.listIterator();
        while (iterator.hasNext()) {
            BinaryTreeNode<ConsensusContainer> root = iterator.next().getRoot();
            BinaryTreeNode<ConsensusContainer> left = root.getLeft();
            BinaryTreeNode<ConsensusContainer> right = root.getRight();
            double leftDistance = left != null ? left.getData().getConsensusDistance() : 0.0;
            double rightDistance = right != null ? right.getData().getConsensusDistance() : 0.0;
            if (leftDistance > clusterCutoff || rightDistance > clusterCutoff) {
                // Replace the current tree by its two subtrees. Stepping back after each add places the cursor
                // in front of the inserted trees, so both are visited by the following next() calls.
                iterator.remove();
                iterator.add(new BinaryTree<>(left));
                iterator.previous();
                iterator.add(new BinaryTree<>(right));
                iterator.previous();
            }
        }
    }

    /**
     * Returns the tree created by the final merge, which contains every input structure as a leaf.
     *
     * @return the top-level consensus tree
     */
    public BinaryTree<ConsensusContainer> getTopConsensusTree() {
        return consensusTrees.get(consensusTrees.size() - 1);
    }

    /**
     * Returns the trees created by each merge step, in merge order.
     *
     * @return the live (mutable) list of consensus trees
     */
    public List<BinaryTree<ConsensusContainer>> getConsensusTrees() {
        return consensusTrees;
    }

    /** Merges closest pairs until no pending alignment remains, i.e. until a single structure is left. */
    private void calculateConsensusAlignment() {
        while (!alignments.isEmpty()) {
            findAndMergeClosestPair();
        }
    }

    /**
     * Returns the sum of the RMSD values of all merged pairs.
     *
     * @return the accumulated consensus score
     */
    public double getConsensusScore() {
        return consensusScore;
    }

    /**
     * Returns the consensus score divided by the number of merge iterations times the motif size.
     *
     * @return the normalized consensus score
     */
    public double getNormalizedConsensusScore() {
        return consensusScore / (iterationCounter * input.get(0).getStructuralMotif().size());
    }

    /** Performs one merge iteration on the alignment with the lowest RMSD. */
    private void findAndMergeClosestPair() {
        iterationCounter++;
        Map.Entry<SubstructureSuperimposition, Pair<ConsensusContainer>> closest = findClosestAlignment();
        Pair<ConsensusContainer> closestPair = closest.getValue();
        double rmsd = closest.getKey().getRmsd();
        alignmentTrace.add(rmsd);
        alignmentCounts.add(input.size());
        logger.debug("closest pair for iteration {} is {} with RMSD {}", new Object[]{iterationCounter, closestPair, rmsd});
        consensusScore += rmsd;
        createConsensus(closest);
        updateAlignments(closest);
    }

    /** Returns the pending alignment with the lowest RMSD; among equal RMSDs the earliest-added one wins. */
    private Map.Entry<SubstructureSuperimposition, Pair<ConsensusContainer>> findClosestAlignment() {
        Map.Entry<SubstructureSuperimposition, Pair<ConsensusContainer>> closest = null;
        for (Map.Entry<SubstructureSuperimposition, Pair<ConsensusContainer>> alignment : alignments) {
            if (closest == null || alignment.getKey().getRmsd() < closest.getKey().getRmsd()) {
                closest = alignment;
            }
        }
        return closest;
    }

    /**
     * Replaces the merged pair by {@link #currentConsensus}. All alignments involving either merged structure are
     * dropped, the two structures leave {@link #input}, and the consensus is aligned against every remaining
     * structure before it joins {@link #input} itself.
     */
    private void updateAlignments(Map.Entry<SubstructureSuperimposition, Pair<ConsensusContainer>> merged) {
        ConsensusContainer mergedFirst = merged.getValue().getFirst();
        ConsensusContainer mergedSecond = merged.getValue().getSecond();
        alignments.removeIf(alignment -> involves(alignment.getValue(), mergedFirst, mergedSecond));
        input.removeIf(container -> container.equals(mergedFirst) || container.equals(mergedSecond));
        for (ConsensusContainer remaining : input) {
            SubstructureSuperimposition superimposition = superimpose(currentConsensus.getStructuralMotif(), remaining.getStructuralMotif());
            alignments.add(alignmentEntry(superimposition, currentConsensus, remaining));
        }
        input.add(currentConsensus);
    }

    /** Tells whether either side of {@code pair} equals one of the two given containers. */
    private static boolean involves(Pair<ConsensusContainer> pair, ConsensusContainer first, ConsensusContainer second) {
        return pair.getFirst().equals(first) || pair.getFirst().equals(second)
                || pair.getSecond().equals(first) || pair.getSecond().equals(second);
    }

    /** Bundles a superimposition with the pair of containers it was computed for. */
    private static Map.Entry<SubstructureSuperimposition, Pair<ConsensusContainer>> alignmentEntry(
            SubstructureSuperimposition superimposition, ConsensusContainer reference, ConsensusContainer candidate) {
        return new AbstractMap.SimpleImmutableEntry<>(superimposition, new Pair<>(reference, candidate));
    }

    /**
     * Superimposes {@code candidate} onto {@code reference}. The atom filter is used unless a representation scheme
     * is configured, and the ideal or the plain superimposition is chosen according to the builder setting.
     */
    private SubstructureSuperimposition superimpose(StructuralMotif reference, StructuralMotif candidate) {
        if (representationScheme == null) {
            return idealSuperimposition
                    ? SubstructureSuperimposer.calculateIdealSubstructureSuperimposition(reference, candidate, atomFilter)
                    : SubstructureSuperimposer.calculateSubstructureSuperimposition(reference.getAllLeafSubstructures(), candidate.getAllLeafSubstructures(), atomFilter);
        }
        return idealSuperimposition
                ? SubstructureSuperimposer.calculateIdealSubstructureSuperimposition(reference, candidate, representationScheme)
                : SubstructureSuperimposer.calculateSubstructureSuperimposition(reference.getAllLeafSubstructures(), candidate.getAllLeafSubstructures(), representationScheme);
    }

    /**
     * Builds the consensus of the given alignment and records it in a new consensus tree.
     * <p>
     * The i-th reference leaf is paired with the i-th mapped candidate leaf. For each pair the selected atoms are
     * averaged position by position. The two merged nodes become the children of the new tree, and each child's
     * consensus distance grows by half the alignment RMSD.
     */
    private void createConsensus(Map.Entry<SubstructureSuperimposition, Pair<ConsensusContainer>> alignment) {
        List<LeafSubstructure<?>> reference = alignment.getValue().getFirst().getStructuralMotif().getAllLeafSubstructures();
        List<LeafSubstructure<?>> candidate = alignment.getKey().getMappedFullCandidate();

        // Pair up aligned leaves in order. Each position gets a mutable set that defineIntersectingAtoms fills with
        // the atom names considered for that position.
        Map<Pair<LeafSubstructure<?>>, Set<String>> atomNamesPerPosition = new LinkedHashMap<>();
        for (int i = 0; i < reference.size(); i++) {
            atomNamesPerPosition.put(new Pair<LeafSubstructure<?>>(reference.get(i), candidate.get(i)), new HashSet<>());
        }
        atomNamesPerPosition.entrySet().forEach(this::defineIntersectingAtoms);

        List<LeafSubstructure<?>> consensusLeaves = new ArrayList<>();
        int atomIdentifier = 1;
        int leafIdentifier = 1;
        for (Map.Entry<Pair<LeafSubstructure<?>>, Set<String>> position : atomNamesPerPosition.entrySet()) {
            LeafSubstructure<?> referenceLeaf = position.getKey().getFirst();
            LeafSubstructure<?> candidateLeaf = position.getKey().getSecond();
            List<Atom> referenceAtoms = selectAtoms(referenceLeaf, position.getValue());
            List<Atom> candidateAtoms = selectAtoms(candidateLeaf, position.getValue());
            OakLeafSubstructure<?> consensusLeaf = LeafSubstructureFactory.createLeafSubstructure(
                    new LeafIdentifier(leafIdentifier), consensusFamily(referenceLeaf, candidateLeaf));
            // atoms are matched by index; both lists are sorted by atom name (or hold one representative atom)
            for (int j = 0; j < referenceAtoms.size(); j++) {
                Atom referenceAtom = referenceAtoms.get(j);
                Atom candidateAtom = candidateAtoms.get(j);
                consensusLeaf.addAtom(new OakAtom(atomIdentifier, referenceAtom.getElement(), referenceAtom.getAtomName(),
                        referenceAtom.getPosition().add(candidateAtom.getPosition()).divide(2.0)));
                atomIdentifier++;
            }
            consensusLeaves.add(consensusLeaf);
            leafIdentifier++;
        }
        currentConsensus = new ConsensusContainer(StructuralMotif.fromLeafSubstructures(consensusLeaves), true);

        BinaryTreeNode<ConsensusContainer> root = new BinaryTreeNode<>(currentConsensus,
                resolveNode(alignment.getValue().getFirst()), resolveNode(alignment.getValue().getSecond()));
        BinaryTree<ConsensusContainer> consensusTree = new BinaryTree<>(root);
        currentConsensus.setConsensusTree(consensusTree);
        consensusTrees.add(consensusTree);
        double halfRmsd = alignment.getKey().getRmsd() / 2.0;
        consensusTree.getRoot().getLeft().getData().addToConsensusDistance(halfRmsd);
        consensusTree.getRoot().getRight().getData().addToConsensusDistance(halfRmsd);
    }

    /**
     * Returns the atoms of {@code leafSubstructure} that take part in averaging. With a representation scheme this is
     * the single representing atom. Otherwise it is every atom passing {@link #atomFilter} whose name is in
     * {@code atomNames}, sorted by atom name.
     */
    private List<Atom> selectAtoms(LeafSubstructure<?> leafSubstructure, Set<String> atomNames) {
        if (representationScheme != null) {
            List<Atom> representingAtom = new ArrayList<>();
            representingAtom.add(representationScheme.determineRepresentingAtom(leafSubstructure));
            return representingAtom;
        }
        return leafSubstructure.getAllAtoms().stream()
                .filter(atomFilter)
                .filter(atom -> atomNames.contains(atom.getAtomName()))
                .sorted(Comparator.comparing(Atom::getAtomName))
                .collect(Collectors.toList());
    }

    /**
     * Returns the shared amino acid family of two aligned leaves, or {@link AminoAcidFamily#UNKNOWN} if the families
     * differ or are not amino acid families.
     */
    private static AminoAcidFamily consensusFamily(LeafSubstructure<?> reference, LeafSubstructure<?> candidate) {
        Object referenceFamily = reference.getFamily();
        Object candidateFamily = candidate.getFamily();
        if (referenceFamily.equals(candidateFamily) && candidateFamily instanceof AminoAcidFamily) {
            return (AminoAcidFamily) candidateFamily;
        }
        return AminoAcidFamily.UNKNOWN;
    }

    /** Returns the existing tree node for {@code container}, falling back to its original leaf node. */
    private BinaryTreeNode<ConsensusContainer> resolveNode(ConsensusContainer container) {
        BinaryTreeNode<ConsensusContainer> node = findNode(container);
        return node != null ? node : findLeave(container);
    }

    /**
     * Finds the leaf node created for an original input structure.
     *
     * @throws ConsensusException if no leaf holds an equal container
     */
    private BinaryTreeNode<ConsensusContainer> findLeave(ConsensusContainer consensusContainer) {
        for (BinaryTreeNode<ConsensusContainer> leaf : leaves) {
            if (leaf.getData().equals(consensusContainer)) {
                return leaf;
            }
        }
        throw new ConsensusException("failed during tree construction");
    }

    /** Searches all consensus trees built so far for a node holding {@code consensusContainer}; {@code null} if absent. */
    private BinaryTreeNode<ConsensusContainer> findNode(ConsensusContainer consensusContainer) {
        for (BinaryTree<ConsensusContainer> consensusTree : consensusTrees) {
            BinaryTreeNode<ConsensusContainer> node = consensusTree.findNode(consensusContainer);
            if (node != null) {
                return node;
            }
        }
        return null;
    }

    /**
     * Fills the atom-name set of one aligned leaf pair. Without a representation scheme it is the intersection of
     * the filtered atom names of both leaves; with one it holds the names of both representing atoms.
     */
    private void defineIntersectingAtoms(Map.Entry<Pair<LeafSubstructure<?>>, Set<String>> position) {
        LeafSubstructure<?> first = position.getKey().getFirst();
        LeafSubstructure<?> second = position.getKey().getSecond();
        Set<String> atomNames = position.getValue();
        if (representationScheme == null) {
            atomNames.addAll(filteredAtomNames(first));
            atomNames.retainAll(filteredAtomNames(second));
        } else {
            atomNames.add(representationScheme.determineRepresentingAtom(first).getAtomName());
            atomNames.add(representationScheme.determineRepresentingAtom(second).getAtomName());
        }
    }

    /** Returns the names of all atoms of {@code leafSubstructure} that pass {@link #atomFilter}. */
    private Set<String> filteredAtomNames(LeafSubstructure<?> leafSubstructure) {
        return leafSubstructure.getAllAtoms().stream()
                .filter(atomFilter)
                .map(Atom::getAtomName)
                .collect(Collectors.toSet());
    }

    /** Creates one childless tree node per input structure. */
    private void createTreeLeaves() {
        leaves = new ArrayList<>(input.size());
        for (ConsensusContainer container : input) {
            leaves.add(new BinaryTreeNode<>(container));
        }
    }

    /**
     * Computes the superimposition of every unordered pair of input structures. Each alignment is stored as pending,
     * and the RMSDs are recorded in a symmetric distance matrix labelled with the input containers.
     */
    private void calculateInitialAlignments() {
        int size = input.size();
        alignments = new ArrayList<>();
        double[][] distances = new double[size][size];
        List<ConsensusContainer> labels = new ArrayList<>(input);
        // computed in long arithmetic: n * ((n - 1) / 2) truncates for even n
        long totalAlignments = (long) size * (size - 1) / 2;
        int computedAlignments = 0;
        for (int i = 0; i < size - 1; i++) {
            ConsensusContainer reference = input.get(i);
            for (int j = i + 1; j < size; j++) {
                ConsensusContainer candidate = input.get(j);
                SubstructureSuperimposition superimposition = superimpose(reference.getStructuralMotif(), candidate.getStructuralMotif());
                alignments.add(alignmentEntry(superimposition, reference, candidate));
                distances[i][j] = superimposition.getRmsd();
                distances[j][i] = superimposition.getRmsd();
                computedAlignments++;
                if (computedAlignments % 1000 == 0) {
                    logger.info("computed {} of {} initial alignments ", computedAlignments, totalAlignments);
                }
            }
        }
        distanceMatrix = new LabeledSymmetricMatrix<>(distances);
        distanceMatrix.setColumnLabels(labels);
    }
}
