/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.bigfasttree.thorney;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Taxon;
import dr.evomodel.tree.TreeChangedEvent;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * A likelihood that enforces the topology of a constraints tree on a target tree.
 * <p>
 * Every internal node of the constraints tree defines a clade (the set of tips below it). The
 * log-likelihood is {@code 0.0} when each of these clades is also the clade of some node of the
 * target tree, and {@link Double#NEGATIVE_INFINITY} otherwise. Target-tree clades are computed
 * over the constrained tips only: target tips whose taxon does not occur in the constraints tree
 * contribute nothing, so the constraints tree may cover a subset of the target taxa.
 * <p>
 * The clade of each target node is cached. When every target tip is constrained, a node-level
 * tree change only invalidates the changed node and its ancestors. Otherwise every node is
 * recomputed.
 */
public class ConstraintsTreeLikelihood extends AbstractModelLikelihood {
    /** The tree whose topology is constrained. */
    private final Tree targetTree;
    /** Clades of the constraints tree's internal nodes, as bit sets of target-tree taxon indices. */
    private final Set<BitSet> constraintsClades = new HashSet<BitSet>();
    /** Numbers of the target-tree external nodes whose taxon occurs in the constraints tree. */
    private final Set<Integer> constrainedTips = new HashSet<Integer>();
    private boolean likelihoodKnown = false;
    private boolean storedLikelihoodKnown = false;
    private double logLikelihood = 0.0;
    private double storedLogLikelihood = 0.0;
    /** Per target node (indexed by node number): whether its cached clade must be recomputed. */
    private final boolean[] updateNode;
    /** Constraint clades that are not currently matched by any cached target-node clade. */
    private Set<BitSet> lostClades;
    private Set<BitSet> storedLostClades = new HashSet<BitSet>();
    /**
     * Cached clade of each target node as it was at the last {@link #storeState()}. A node's entry
     * is captured the first time that node is invalidated afterwards, and is written back by
     * {@link #restoreState()}.
     */
    private final Map<Integer, BitSet> restoreCache = new HashMap<Integer, BitSet>();
    /** True when every target tip is constrained, which makes incremental updates safe. */
    private final boolean uniqueClades;
    /** Cached clade (over constrained tips) of each target node, indexed by node number. */
    private final BitSet[] targetTreeNodeCladeMap;

    /**
     * Creates the constraint likelihood.
     *
     * @param name            the model name, passed to the superclass
     * @param targetTree      the tree whose topology is constrained; if it is a {@link TreeModel}
     *                        it is registered via {@code addModel} (presumably so that its change
     *                        events reach {@link #handleModelChangedEvent})
     * @param constraintsTree the constraints tree; each of its taxa must occur in {@code targetTree}
     * @throws TreeUtils.MissingTaxonException if a taxon of {@code constraintsTree} is not found in
     *                                         {@code targetTree}
     */
    public ConstraintsTreeLikelihood(String name, Tree targetTree, Tree constraintsTree) throws TreeUtils.MissingTaxonException {
        super(name);
        for (int i = 0; i < constraintsTree.getTaxonCount(); ++i) {
            int targetIndex = targetTree.getTaxonIndex(constraintsTree.getTaxonId(i));
            if (targetIndex == -1) {
                throw new TreeUtils.MissingTaxonException(constraintsTree.getTaxon(i));
            }
            NodeRef tip = findExternalNode(targetTree, targetTree.getTaxon(targetIndex));
            if (tip != null) {
                this.constrainedTips.add(tip.getNumber());
            }
        }
        this.setupClades(constraintsTree, constraintsTree.getRoot(), targetTree);
        this.uniqueClades = constraintsTree.getExternalNodeCount() == targetTree.getExternalNodeCount();
        this.updateNode = new boolean[targetTree.getNodeCount()];
        this.targetTreeNodeCladeMap = new BitSet[targetTree.getNodeCount()];
        // Nothing is cached yet: every constraint clade starts out lost and every node dirty.
        this.lostClades = new HashSet<BitSet>(this.constraintsClades);
        for (int i = 0; i < this.updateNode.length; ++i) {
            this.updateNode[i] = true;
        }
        if (targetTree instanceof TreeModel) {
            this.addModel((TreeModel) targetTree);
        }
        this.targetTree = targetTree;
    }

    /**
     * Returns the first external node of {@code tree} whose taxon equals {@code taxon}, or
     * {@code null} if there is none.
     */
    private static NodeRef findExternalNode(Tree tree, Taxon taxon) {
        for (int i = 0; i < tree.getExternalNodeCount(); ++i) {
            NodeRef node = tree.getExternalNode(i);
            if (tree.getNodeTaxon(node).equals(taxon)) {
                return node;
            }
        }
        return null;
    }

    /**
     * Records node {@code n}'s cached clade for {@link #restoreState()}, unless it has already
     * been recorded since the last store. The first capture is the stored-state value; later
     * ones could be intermediate values.
     */
    private void saveForRestore(int n) {
        if (!this.restoreCache.containsKey(n)) {
            this.restoreCache.put(n, this.targetTreeNodeCladeMap[n]);
        }
    }

    /**
     * Marks every target node for recomputation and treats every constraint clade as lost until
     * it is found again. The current cache is saved first, so a subsequent restore can roll back
     * the full recomputation.
     */
    private void updateAllNodes() {
        for (int i = 0; i < this.updateNode.length; ++i) {
            this.saveForRestore(i);
            this.updateNode[i] = true;
        }
        this.lostClades = new HashSet<BitSet>(this.constraintsClades);
        this.likelihoodKnown = false;
    }

    /**
     * Marks {@code nodeRef} and all of its ancestors for recomputation. Any constraint clade
     * they currently hold becomes lost until recomputation finds it again.
     */
    private void updateNodeAndAncestors(NodeRef nodeRef) {
        for (NodeRef node = nodeRef; node != null; node = this.targetTree.getParent(node)) {
            int n = node.getNumber();
            this.updateNode[n] = true;
            BitSet clade = this.targetTreeNodeCladeMap[n];
            if (this.constraintsClades.contains(clade)) {
                this.lostClades.add(clade);
            }
            this.saveForRestore(n);
        }
        this.likelihoodKnown = false;
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        if (!(object instanceof TreeChangedEvent)) {
            return;
        }
        TreeChangedEvent event = (TreeChangedEvent) object;
        if (event.isNodeChanged()) {
            // With unconstrained tips, several nodes can share the same clade bit set. The
            // incremental path never revisits clean nodes, so a clade still held by a clean node
            // would wrongly stay in lostClades. Hence a full recompute in that case.
            if (this.uniqueClades) {
                this.updateNodeAndAncestors(event.getNode());
            } else {
                this.updateAllNodes();
            }
        } else if (event.isTreeChanged()) {
            this.updateAllNodes();
        }
    }

    @Override
    protected void storeState() {
        this.storedLikelihoodKnown = this.likelihoodKnown;
        this.storedLogLikelihood = this.logLikelihood;
        this.storedLostClades = new HashSet<BitSet>(this.lostClades);
        this.restoreCache.clear();
    }

    @Override
    protected void restoreState() {
        this.likelihoodKnown = this.storedLikelihoodKnown;
        this.logLikelihood = this.storedLogLikelihood;
        for (Map.Entry<Integer, BitSet> entry : this.restoreCache.entrySet()) {
            this.targetTreeNodeCladeMap[entry.getKey()] = entry.getValue();
        }
        // The cache now equals the stored state, so the captured values are no longer needed.
        this.restoreCache.clear();
        this.lostClades = new HashSet<BitSet>(this.storedLostClades);
    }

    @Override
    protected void acceptState() {
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
    }

    @Override
    public Model getModel() {
        return this;
    }

    /**
     * Returns {@code 0.0} if every constraint clade is present in the target tree and
     * {@link Double#NEGATIVE_INFINITY} otherwise. The result is cached until the next change.
     */
    @Override
    public double getLogLikelihood() {
        if (!this.likelihoodKnown) {
            this.logLikelihood = this.isCompatible() ? 0.0 : Double.NEGATIVE_INFINITY;
            this.likelihoodKnown = true;
        }
        return this.logLikelihood;
    }

    @Override
    public void makeDirty() {
        this.updateAllNodes();
    }

    /** Brings the clade cache up to date and reports whether no constraint clade is lost. */
    private boolean isCompatible() {
        this.getClades(this.targetTree.getRoot());
        return this.lostClades.isEmpty();
    }

    /**
     * Recursively collects the clade of every internal node of {@code constraintsTree} into
     * {@link #constraintsClades}. Tips are encoded by their taxon index in {@code target}.
     *
     * @return the clade of {@code node}
     */
    private BitSet setupClades(Tree constraintsTree, NodeRef node, Tree target) {
        BitSet clade = new BitSet();
        if (constraintsTree.isExternal(node)) {
            clade.set(target.getTaxonIndex(constraintsTree.getNodeTaxon(node).getId()));
        } else {
            for (int i = 0; i < constraintsTree.getChildCount(node); ++i) {
                clade.or(this.setupClades(constraintsTree, constraintsTree.getChild(node, i), target));
            }
            this.constraintsClades.add(clade);
        }
        return clade;
    }

    /**
     * Returns the clade of {@code nodeRef} over constrained tips. Only nodes marked for update
     * are recomputed; each recomputed clade is removed from {@link #lostClades}.
     */
    private BitSet getClades(NodeRef nodeRef) {
        int n = nodeRef.getNumber();
        if (this.updateNode[n]) {
            BitSet clade = new BitSet();
            if (this.targetTree.isExternal(nodeRef)) {
                // NOTE(review): this sets the tip's node number, whereas setupClades uses the
                // target taxon index. That assumes external node number == taxon index in the
                // target tree; confirm this holds for every Tree implementation used here.
                if (this.constrainedTips.contains(n)) {
                    clade.set(n);
                }
            } else {
                for (int i = 0; i < this.targetTree.getChildCount(nodeRef); ++i) {
                    clade.or(this.getClades(this.targetTree.getChild(nodeRef, i)));
                }
            }
            this.updateNode[n] = false;
            this.targetTreeNodeCladeMap[n] = clade;
            this.lostClades.remove(clade);
        }
        return this.targetTreeNodeCladeMap[n];
    }
}

