/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tree.TreeParameterModel;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.discrete.DiscreteTraitBranchRateGradient;
import dr.evomodel.treedatalikelihood.discrete.DiscreteTraitNodeHeightDelegate;
import dr.evomodel.treedatalikelihood.discrete.NodeHeightProxyParameter;
import dr.evomodel.treedatalikelihood.preorder.ProcessSimulationDelegate;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.loggers.Loggable;
import dr.inference.model.Parameter;
import dr.xml.Reportable;
import java.util.Arrays;
import java.util.Objects;

public class NodeHeightGradientForDiscreteTrait
extends DiscreteTraitBranchRateGradient
implements GradientWrtParameterProvider,
HessianWrtParameterProvider,
Reportable,
Loggable {
    private final TreeModel treeModel;
    protected TreeParameterModel indexHelper;
    private final NodeHeightProxyParameter nodeHeightProxyParameter;
    private final double tolerance = 0.01;
    private final double smallValueThreshold = 0.001;
    private static final boolean DEBUG = true;

    public NodeHeightGradientForDiscreteTrait(String string, TreeDataLikelihood treeDataLikelihood, BeagleDataLikelihoodDelegate beagleDataLikelihoodDelegate, Parameter parameter) {
        super(string, treeDataLikelihood, beagleDataLikelihoodDelegate, parameter, false);
        if (!(treeDataLikelihood.getTree() instanceof TreeModel)) {
            throw new IllegalArgumentException("Must provide a TreeModel");
        }
        this.treeModel = (TreeModel)treeDataLikelihood.getTree();
        this.indexHelper = new TreeParameterModel((MutableTreeModel)this.treeModel, (Parameter)new Parameter.Default(this.tree.getNodeCount() - 1), false);
        this.nodeHeightProxyParameter = new NodeHeightProxyParameter("internalNodeHeights", this.treeModel, true);
    }

    @Override
    protected String getTraitName(String string) {
        return "NodeHeightGradient";
    }

    @Override
    protected ProcessSimulationDelegate makeGradientDelegate(String string, Tree tree, BeagleDataLikelihoodDelegate beagleDataLikelihoodDelegate) {
        return new DiscreteTraitNodeHeightDelegate(string, tree, beagleDataLikelihoodDelegate, this.branchRateModel);
    }

    @Override
    public Parameter getParameter() {
        return this.nodeHeightProxyParameter;
    }

    @Override
    protected double getChainGradient(Tree tree, NodeRef nodeRef) {
        return this.branchRateModel.getBranchRate(tree, nodeRef);
    }

    @Override
    public double[] getGradientLogDensity() {
        if (this.treeTraitProvider.getTraitName() == super.getTraitName(null)) {
            double[] dArray = new double[this.tree.getInternalNodeCount()];
            Arrays.fill(dArray, 0.0);
            double[] dArray2 = super.getGradientLogDensity();
            for (int i = 0; i < this.tree.getInternalNodeCount(); ++i) {
                int n;
                NodeRef nodeRef = this.tree.getNode(i + this.tree.getExternalNodeCount());
                for (n = 0; n < this.tree.getChildCount(nodeRef); ++n) {
                    NodeRef nodeRef2 = this.tree.getChild(nodeRef, n);
                    int n2 = this.indexHelper.getParameterIndexFromNodeNumber(nodeRef2.getNumber());
                    int n3 = i;
                    dArray[n3] = dArray[n3] + dArray2[n2];
                }
                if (this.tree.isRoot(nodeRef)) continue;
                n = this.indexHelper.getParameterIndexFromNodeNumber(nodeRef.getNumber());
                int n4 = i;
                dArray[n4] = dArray[n4] - dArray2[n];
            }
            return dArray;
        }
        double[] dArray = (double[])this.treeTraitProvider.getTrait(this.tree, null);
        return Arrays.copyOf(dArray, this.tree.getInternalNodeCount());
    }

    @Override
    public double[] getDiagonalHessianLogDensity() {
        double[] dArray = new double[this.tree.getInternalNodeCount()];
        double[] dArray2 = (double[])this.treeDataLikelihood.getTreeTrait("NodeHeightHessian").getTrait(this.tree, null);
        return dArray2;
    }

    @Override
    protected int getParameterIndexFromNode(NodeRef nodeRef) {
        return this.indexHelper.getParameterIndexFromNodeNumber(nodeRef.getNumber());
    }

    @Override
    public String getReport() {
        String string = GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0, Double.POSITIVE_INFINITY, 0.01, 0.001) + HessianWrtParameterProvider.getReportAndCheckForError(this, 0.01);
        return string;
    }
}

