/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.hmc;

import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.BranchSpecificGradient;
import dr.evomodel.treedatalikelihood.hmc.MultivariateChainRule;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.interfaces.ConjugateWishartStatisticsProvider;

/**
 * Supplies a gradient with respect to a precision matrix, plus the corresponding gradient with
 * respect to a variance matrix.
 *
 * <p>Matrices are exchanged as flat {@code double[]} arrays. The Wishart implementation asserts a
 * length of {@code dim * dim}.
 *
 * <p>NOTE(review): the parameter roles (variance vs. precision vs. incoming gradient) are read off
 * from how the implementations below use their arguments. Confirm against callers.
 */
public interface GradientWrtPrecisionProvider {

    /**
     * Returns the gradient with respect to the precision matrix.
     *
     * @param variance            flattened matrix the implementation works from
     * @param gradientWrtVariance flattened incoming gradient (some implementations ignore it)
     * @return flattened gradient with respect to the precision matrix
     */
    double[] getGradientWrtPrecision(double[] variance, double[] gradientWrtVariance);

    /**
     * Returns the gradient with respect to the variance matrix.
     *
     * @param precision           flattened matrix used by implementations that apply an inverse
     *                            chain rule
     * @param variance            flattened matrix forwarded to {@link #getGradientWrtPrecision}
     *                            where needed
     * @param gradientWrtVariance flattened incoming gradient
     * @return flattened gradient with respect to the variance matrix
     */
    double[] getGradientWrtVariance(double[] precision, double[] variance, double[] gradientWrtVariance);

    /** Returns the Wishart statistics provider behind this instance, or {@code null} if there is none. */
    ConjugateWishartStatisticsProvider getWishartStatistic();

    /** Returns the branch-specific gradient behind this instance, or {@code null} if there is none. */
    BranchSpecificGradient getBranchSpecificGradient();

    /**
     * Common base class. It holds the matrix dimension and makes both accessors return
     * {@code null} by default.
     */
    abstract class AbstractGradientWrtPrecisionProvider implements GradientWrtPrecisionProvider {

        // Row/column count of the square matrices. Subclass constructors set it.
        int dim;

        @Override
        public ConjugateWishartStatisticsProvider getWishartStatistic() {
            return null;
        }

        @Override
        public BranchSpecificGradient getBranchSpecificGradient() {
            return null;
        }
    }

    /**
     * Provider backed by a {@link BranchSpecificGradient}.
     *
     * <p>The gradient with respect to precision is obtained by pushing the supplied
     * variance-gradient through {@code MultivariateChainRule.InverseGeneral}.
     */
    class BranchSpecificGradientWrtPrecisionProvider extends AbstractGradientWrtPrecisionProvider {

        private final BranchSpecificGradient branchSpecificGradient;

        /**
         * Creates a provider and takes its dimension from the tree likelihood's trait dimension.
         *
         * @param branchSpecificGradient gradient source. Its likelihood is cast to
         *                               {@link TreeDataLikelihood}.
         */
        public BranchSpecificGradientWrtPrecisionProvider(BranchSpecificGradient branchSpecificGradient) {
            this.branchSpecificGradient = branchSpecificGradient;
            TreeDataLikelihood treeLikelihood = (TreeDataLikelihood) branchSpecificGradient.getLikelihood();
            dim = treeLikelihood.getDataLikelihoodDelegate().getTraitDim();
        }

        @Override
        public double[] getGradientWrtPrecision(double[] variance, double[] gradientWrtVariance) {
            return new MultivariateChainRule.InverseGeneral(variance).chainGradient(gradientWrtVariance);
        }

        /**
         * Returns {@code gradientWrtVariance} itself: the same array instance, not a copy.
         * {@code precision} and {@code variance} are unused.
         */
        @Override
        public double[] getGradientWrtVariance(double[] precision, double[] variance, double[] gradientWrtVariance) {
            return gradientWrtVariance;
        }

        @Override
        public BranchSpecificGradient getBranchSpecificGradient() {
            return branchSpecificGradient;
        }
    }

    /**
     * Provider backed by Wishart sufficient statistics.
     *
     * <p>The gradient with respect to precision is computed element-wise as
     * {@code 0.5 * (df * variance[i] - scale[i])}.
     */
    class WishartGradientWrtPrecisionProvider extends AbstractGradientWrtPrecisionProvider {

        private final ConjugateWishartStatisticsProvider wishartStatistics;

        /**
         * Creates a provider and takes its dimension from the row count of the precision parameter.
         *
         * @param conjugateWishartStatisticsProvider source of the Wishart sufficient statistics
         */
        public WishartGradientWrtPrecisionProvider(
                ConjugateWishartStatisticsProvider conjugateWishartStatisticsProvider) {
            wishartStatistics = conjugateWishartStatisticsProvider;
            dim = conjugateWishartStatisticsProvider.getPrecisionParameter().getRowDimension();
        }

        /**
         * Computes the gradient from the current Wishart sufficient statistics.
         * {@code gradientWrtVariance} is ignored.
         */
        @Override
        public double[] getGradientWrtPrecision(double[] variance, double[] gradientWrtVariance) {
            WishartSufficientStatistics statistics = wishartStatistics.getWishartStatistics();
            final double[] scale = statistics.getScaleMatrix();
            final int degreesOfFreedom = statistics.getDf();
            return wishartGradient(variance, degreesOfFreedom, scale);
        }

        /** Element-wise {@code 0.5 * (degreesOfFreedom * variance[k] - scale[k])} over all entries. */
        private double[] wishartGradient(double[] variance, int degreesOfFreedom, double[] scale) {
            final int length = dim * dim;
            assert variance.length == length;
            assert scale.length == length;
            assert degreesOfFreedom > 0;

            final double[] gradient = new double[length];
            for (int k = 0; k < length; k++) {
                gradient[k] = 0.5 * (degreesOfFreedom * variance[k] - scale[k]);
            }
            return gradient;
        }

        /**
         * Pushes this provider's precision-gradient through
         * {@code MultivariateChainRule.InverseGeneral} built from {@code precision}.
         */
        @Override
        public double[] getGradientWrtVariance(double[] precision, double[] variance, double[] gradientWrtVariance) {
            // The rule is built before the inner gradient, matching the original evaluation order.
            MultivariateChainRule.InverseGeneral inverseRule = new MultivariateChainRule.InverseGeneral(precision);
            double[] gradientWrtPrecision = getGradientWrtPrecision(variance, gradientWrtVariance);
            return inverseRule.chainGradient(gradientWrtPrecision);
        }

        @Override
        public ConjugateWishartStatisticsProvider getWishartStatistic() {
            return wishartStatistics;
        }
    }
}

