package cc.mallet.grmm.inference.gbp;

import cc.mallet.grmm.types.DiscreteFactor;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.Factors;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.TableFactor;
import java.util.Iterator;

/* loaded from: input_file:cc/mallet/grmm/inference/gbp/SparseMessageSender.class */
/**
 * Message strategy for generalized belief propagation that sparsifies messages sent into
 * regions with no children, using {@link Factors#retainMass(DiscreteFactor, double)}.
 *
 * <p>NOTE(review): the exact pruning semantics come from {@code Factors.retainMass}; presumably
 * it keeps only the entries needed to cover the mass allowed by {@code epsilon} — confirm.
 */
public class SparseMessageSender extends AbstractMessageStrategy {
    /** Threshold handed to {@code Factors.retainMass} when a message is pruned. */
    private final double epsilon;

    /**
     * Creates a sparse message sender.
     *
     * @param d pruning threshold passed to {@code Factors.retainMass}
     */
    public SparseMessageSender(double d) {
        this.epsilon = d;
    }

    /**
     * Computes the message along {@code regionEdge} and records it in {@code newMessages}.
     *
     * <p>The message is the product from {@code msgProduct(regionEdge)} multiplied by every
     * factor in {@code regionEdge.factorsToSend}. That product is marginalized onto the
     * destination region's variables and normalized. If the destination region has no children,
     * the message is then pruned with {@code Factors.retainMass} and normalized again.
     *
     * @param regionEdge the edge whose message is being sent
     */
    @Override // cc.mallet.grmm.inference.gbp.MessageStrategy
    public void sendMessage(RegionEdge regionEdge) {
        Factor product = msgProduct(regionEdge);
        Iterator it = regionEdge.factorsToSend.iterator();
        while (it.hasNext()) {
            product.multiplyBy((Factor) it.next());
        }
        TableFactor message = (TableFactor) product.marginalize(regionEdge.to.vars);
        message.normalize();
        if (shouldPruneMessage(regionEdge, message)) {
            message = Factors.retainMass(message, this.epsilon);
            // retainMass may discard entries, so renormalize the pruned message.
            message.normalize();
        }
        this.newMessages.setMessage(regionEdge.from, regionEdge.to, message);
    }

    /**
     * Averages two message arrays edge by edge, pruning messages into childless regions, and
     * prints a sparsity summary to standard output.
     *
     * <p>Edges whose message in {@code messageArray} is {@code null} are skipped, so no
     * averaged message is stored for them.
     *
     * @param regionGraph   graph whose edges are iterated
     * @param messageArray  first message array; its {@code null} entries suppress averaging
     * @param messageArray2 second message array
     * @param d             weight passed to {@code Factors.average} (presumably the
     *                      inertia/damping weight — confirm against {@code Factors.average})
     * @return a new {@link MessageArray} holding the averaged (and possibly pruned) messages
     */
    @Override // cc.mallet.grmm.inference.gbp.MessageStrategy
    public MessageArray averageMessages(RegionGraph regionGraph, MessageArray messageArray, MessageArray messageArray2, double d) {
        MessageArray averaged = new MessageArray(regionGraph);
        Iterator edgeIterator = regionGraph.edgeIterator();
        while (edgeIterator.hasNext()) {
            RegionEdge edge = (RegionEdge) edgeIterator.next();
            DiscreteFactor first = messageArray.getMessage(edge.from, edge.to);
            DiscreteFactor second = messageArray2.getMessage(edge.from, edge.to);
            if (first != null) {
                TableFactor mixed = (TableFactor) Factors.average(first, second, d);
                // NOTE(review): unlike sendMessage, the pruned message is not renormalized
                // here — confirm this asymmetry is intended.
                averaged.setMessage(edge.from, edge.to,
                        shouldPruneMessage(edge, mixed) ? Factors.retainMass(mixed, this.epsilon) : mixed);
            }
        }
        reportSparsity(regionGraph, averaged);
        return averaged;
    }

    /**
     * Prints the total number of stored locations across all messages versus the total
     * {@code HashVarSet.weight()} of their variable sets.
     *
     * <p>Edges with no stored message are skipped. The averaging loop can leave such edges
     * empty, and dereferencing them previously threw a {@link NullPointerException}.
     *
     * @param regionGraph graph whose edges are iterated
     * @param messages    messages to summarize
     */
    private static void reportSparsity(RegionGraph regionGraph, MessageArray messages) {
        int stored = 0;
        int full = 0;
        Iterator edgeIterator = regionGraph.edgeIterator();
        while (edgeIterator.hasNext()) {
            RegionEdge edge = (RegionEdge) edgeIterator.next();
            DiscreteFactor message = messages.getMessage(edge.from, edge.to);
            if (message == null) {
                continue;
            }
            stored += message.numLocations();
            full += new HashVarSet(message.varSet()).weight();
        }
        System.out.println("Sparsity quotient = " + stored + " of " + full);
    }

    /**
     * Decides whether a message along {@code regionEdge} should be pruned.
     *
     * @param regionEdge edge carrying the message
     * @param factor     the message itself (currently unused by the decision)
     * @return {@code true} iff the destination region has no children
     */
    private boolean shouldPruneMessage(RegionEdge regionEdge, Factor factor) {
        return regionEdge.to.children.isEmpty();
    }
}
