/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.speciation;

import dr.evolution.coalescent.IntervalType;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.bigfasttree.BigFastTreeIntervals;
import dr.evomodel.speciation.EfficientSpeciationLikelihood;
import dr.evomodel.speciation.SpeciationModel;
import dr.evomodel.speciation.SpeciationModelGradientProvider;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Variable;

/**
 * Whole-tree trait that lazily computes, and then caches, the gradient array
 * produced by a {@link SpeciationModelGradientProvider}.
 *
 * <p>The cached value is invalidated whenever a model-changed event arrives from
 * either the tree intervals or the speciation model, both of which are registered
 * as sub-models in the constructor. The cache takes part in the store/restore
 * cycle through {@link #storeState()} and {@link #restoreState()}.
 */
class CachedGradientDelegate
extends AbstractModel
implements TreeTrait<double[]> {
    private final SpeciationModelGradientProvider provider;
    private final BigFastTreeIntervals treeIntervals;
    private final SpeciationModel speciationModel;

    /** Timing switch; it is not read anywhere in this class. */
    public static final boolean MEASURE_RUN_TIME = false;
    /** Set to zero in the constructor and never updated by the code in this class. */
    public double gradientTime;
    /** Set to zero in the constructor and never updated by the code in this class. */
    public int gradientCounts;

    /** Most recently computed gradient; {@code null} until the first computation. */
    private double[] gradient;
    /** Buffer holding the gradient captured by the last {@link #storeState()} call. */
    private double[] storedGradient;
    /** {@code true} while {@link #gradient} is valid for the current model state. */
    private boolean gradientKnown;
    /** Value of {@link #gradientKnown} captured by the last {@link #storeState()} call. */
    private boolean storedGradientKnown;

    /**
     * Builds the delegate from the gradient provider, tree intervals and speciation
     * model exposed by the given likelihood, and registers the latter two as
     * sub-models of this model.
     *
     * @param efficientSpeciationLikelihood likelihood supplying the three collaborators
     */
    CachedGradientDelegate(EfficientSpeciationLikelihood efficientSpeciationLikelihood) {
        super("cachedGradientDelegate");

        provider = efficientSpeciationLikelihood.getGradientProvider();
        treeIntervals = efficientSpeciationLikelihood.getTreeIntervals();
        speciationModel = efficientSpeciationLikelihood.getSpeciationModel();

        addModel(treeIntervals);
        addModel(speciationModel);

        gradientTime = 0.0;
        gradientCounts = 0;
        gradientKnown = false;
    }

    /**
     * Computes the gradient from scratch by walking the tree intervals in order and
     * handing each piece (model-segment break point, interval, sampling event,
     * coalescent event, origin, conditioning probability) to the provider, which
     * accumulates into the result array.
     *
     * @return a newly allocated array of length {@code provider.getGradientLength()}
     * @throws RuntimeException if an interval is neither a sampling nor a coalescent event
     */
    private double[] getGradientLogDensityImpl() {
        final double[] result = new double[provider.getGradientLength()];

        provider.precomputeGradientConstants();
        provider.updateGradientModelValues(0);

        final double[] breakPoints = provider.getBreakPoints();
        // The final break point must be +infinity so the segment-advance loops terminate.
        assert (breakPoints[breakPoints.length - 1] == Double.POSITIVE_INFINITY);

        // Skip every model segment that ends at or before the start time of the intervals.
        // NOTE(review): this loop advances through speciationModel.updateLikelihoodModelValues,
        // whereas the interval loop below advances through provider.updateGradientModelValues
        // -- confirm the asymmetry is intended.
        int segment = 0;
        while (treeIntervals.getStartTime() >= breakPoints[segment]) {
            ++segment;
            speciationModel.updateLikelihoodModelValues(segment);
        }

        provider.processGradientSampling(result, segment, treeIntervals.getStartTime());

        for (int interval = 0; interval < treeIntervals.getIntervalCount(); ++interval) {
            double pieceStart = treeIntervals.getIntervalTime(interval);
            final double intervalEnd = pieceStart + treeIntervals.getInterval(interval);
            final int lineageCount = treeIntervals.getLineageCount(interval);

            // Split the interval at each break point it reaches, moving to the next segment each time.
            while (intervalEnd >= breakPoints[segment]) {
                final double breakPoint = breakPoints[segment];
                provider.processGradientModelSegmentBreakPoint(result, segment, pieceStart, breakPoint, lineageCount);
                pieceStart = breakPoint;
                ++segment;
                provider.updateGradientModelValues(segment);
            }

            // Remaining piece of the interval inside the current segment, if it has positive length.
            if (intervalEnd > pieceStart) {
                provider.processGradientInterval(result, segment, pieceStart, intervalEnd, lineageCount);
            }

            // Event that closes the interval.
            if (treeIntervals.getIntervalType(interval) == IntervalType.SAMPLE) {
                provider.processGradientSampling(result, segment, intervalEnd);
            } else if (treeIntervals.getIntervalType(interval) == IntervalType.COALESCENT) {
                provider.processGradientCoalescence(result, segment, intervalEnd);
            } else {
                throw new RuntimeException("Birth-death tree includes non birth/death/sampling event.");
            }
        }

        provider.processGradientOrigin(result, segment, treeIntervals.getTotalDuration());
        provider.logConditioningProbability(segment, result);
        return result;
    }

    /** @return the fixed trait name {@code "speciationGradient"} */
    @Override
    public String getTraitName() {
        return "speciationGradient";
    }

    /** @return {@link TreeTrait.Intent#WHOLE_TREE} */
    @Override
    public TreeTrait.Intent getIntent() {
        return TreeTrait.Intent.WHOLE_TREE;
    }

    /** @return {@code double[].class} */
    @Override
    public Class getTraitClass() {
        return double[].class;
    }

    /**
     * Returns the cached gradient, computing it first when the cache is invalid.
     * Neither argument is used. The returned array is the internal cache itself,
     * not a copy.
     *
     * @param tree    ignored
     * @param nodeRef ignored
     * @return the cached gradient array
     */
    @Override
    public double[] getTrait(Tree tree, NodeRef nodeRef) {
        if (gradientKnown) {
            return gradient;
        }
        gradient = getGradientLogDensityImpl();
        gradientKnown = true;
        return gradient;
    }

    /** @return always {@code null}; this trait has no string form */
    @Override
    public String getTraitString(Tree tree, NodeRef nodeRef) {
        return null;
    }

    /** @return always {@code false} */
    @Override
    public boolean getLoggable() {
        return false;
    }

    /**
     * Invalidates the cached gradient when the tree intervals or the speciation
     * model report a change.
     *
     * @throws IllegalArgumentException if the event comes from any other model
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        final boolean fromTrackedModel = model == treeIntervals || model == speciationModel;
        if (!fromTrackedModel) {
            throw new IllegalArgumentException("Unknown model: " + model.getId());
        }
        gradientKnown = false;
    }

    /**
     * Always rejects the event; this class registers no variables in the code shown here.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        throw new IllegalArgumentException("Unknown variable: " + variable.getId());
    }

    /**
     * Copies the current gradient (when one exists) into the stored buffer,
     * allocating the buffer on first use, and records the current validity flag.
     */
    @Override
    protected void storeState() {
        if (gradient != null) {
            if (storedGradient == null) {
                storedGradient = new double[gradient.length];
            }
            System.arraycopy(gradient, 0, storedGradient, 0, gradient.length);
        }
        storedGradientKnown = gradientKnown;
    }

    /**
     * Swaps the current and stored gradient buffers and reinstates the stored
     * validity flag.
     */
    @Override
    protected void restoreState() {
        final double[] discarded = gradient;
        gradient = storedGradient;
        storedGradient = discarded;
        gradientKnown = storedGradientKnown;
    }

    /** Nothing to do on acceptance. */
    @Override
    protected void acceptState() {
    }

    /** @return the current value of {@link #gradientTime} */
    public double getGradientTime() {
        return gradientTime;
    }
}

