/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.coalescent.hmc;

import dr.evolution.coalescent.TreeIntervals;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.coalescent.GMRFMultilocusSkyrideLikelihood;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.discrete.NodeHeightProxyParameter;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.Likelihood;
import dr.inference.model.MatrixParameter;
import dr.inference.model.MatrixVectorProductParameter;
import dr.inference.model.Parameter;
import dr.util.ComparableDouble;
import dr.util.HeapSort;
import dr.xml.Reportable;
import dr.xml.XMLParseException;
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;

public class GMRFGradient
implements GradientWrtParameterProvider,
HessianWrtParameterProvider,
Reportable,
Loggable {
    private final GMRFMultilocusSkyrideLikelihood skygridLikelihood;
    private final WrtParameter wrtParameter;
    private final Parameter parameter;
    private final Double tolerance;

    public GMRFGradient(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood, WrtParameter wrtParameter, Double d) {
        this.skygridLikelihood = gMRFMultilocusSkyrideLikelihood;
        this.wrtParameter = wrtParameter;
        this.parameter = wrtParameter.getParameter(gMRFMultilocusSkyrideLikelihood);
        this.tolerance = d;
    }

    @Override
    public Likelihood getLikelihood() {
        return this.skygridLikelihood;
    }

    @Override
    public Parameter getParameter() {
        return this.parameter;
    }

    @Override
    public int getDimension() {
        return this.parameter.getDimension();
    }

    @Override
    public double[] getGradientLogDensity() {
        return this.wrtParameter.getGradientLogDensity(this.skygridLikelihood);
    }

    @Override
    public double[] getDiagonalHessianLogDensity() {
        return this.wrtParameter.getDiagonalHessianLogDensity(this.skygridLikelihood);
    }

    @Override
    public double[][] getHessianLogDensity() {
        throw new RuntimeException("Not yet implemented");
    }

    @Override
    public String getReport() {
        String string = this.skygridLikelihood + "." + this.wrtParameter.name + "\n";
        string = string + GradientWrtParameterProvider.getReportAndCheckForError(this, this.wrtParameter.getParameterLowerBound(), Double.POSITIVE_INFINITY, this.tolerance) + " \n";
        if (this.wrtParameter != WrtParameter.NODE_HEIGHT && this.wrtParameter != WrtParameter.DETERMINISTIC_SKYGRID) {
            string = string + HessianWrtParameterProvider.getReportAndCheckForError(this, this.tolerance) + "\n";
        }
        return string;
    }

    @Override
    public LogColumn[] getColumns() {
        return Loggable.getColumnsFromReport(this, "GMRFGradient report");
    }

    public static enum WrtParameter {
        LOG_POPULATION_SIZES("logPopulationSizes"){

            @Override
            Parameter getParameter(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                return gMRFMultilocusSkyrideLikelihood.getPopSizeParameter();
            }

            @Override
            double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                return gMRFMultilocusSkyrideLikelihood.getGradientWrtLogPopulationSize();
            }

            @Override
            double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                return gMRFMultilocusSkyrideLikelihood.getDiagonalHessianWrtLogPopulationSize();
            }

            @Override
            double getParameterLowerBound() {
                return Double.NEGATIVE_INFINITY;
            }

            @Override
            public void getWarning(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) throws XMLParseException {
                if (gMRFMultilocusSkyrideLikelihood.getPopSizeParameter() instanceof MatrixVectorProductParameter) {
                    throw new XMLParseException("Cannot use 'logPopulationSizes' with deterministic skygrid");
                }
            }
        }
        ,
        DETERMINISTIC_SKYGRID("deterministicSkygrid"){

            @Override
            Parameter getParameter(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                Parameter parameter = gMRFMultilocusSkyrideLikelihood.getPopSizeParameter();
                if (parameter instanceof MatrixVectorProductParameter) {
                    return ((MatrixVectorProductParameter)parameter).getVector();
                }
                return parameter;
            }

            @Override
            double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                Parameter parameter = gMRFMultilocusSkyrideLikelihood.getPopSizeParameter();
                if (parameter instanceof MatrixVectorProductParameter) {
                    return this.multiplyMatrixByDifferential(gMRFMultilocusSkyrideLikelihood.getGradientWrtLogPopulationSize(), (MatrixVectorProductParameter)gMRFMultilocusSkyrideLikelihood.getPopSizeParameter());
                }
                return gMRFMultilocusSkyrideLikelihood.getGradientWrtLogPopulationSize();
            }

            @Override
            double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                MatrixVectorProductParameter matrixVectorProductParameter = (MatrixVectorProductParameter)gMRFMultilocusSkyrideLikelihood.getPopSizeParameter();
                double[] dArray = gMRFMultilocusSkyrideLikelihood.getDiagonalHessianWrtLogPopulationSize();
                throw new RuntimeException("Not yet implemented");
            }

            private double[] multiplyMatrixByDifferential(double[] dArray, MatrixVectorProductParameter matrixVectorProductParameter) {
                MatrixParameter matrixParameter = matrixVectorProductParameter.getMatrix();
                Parameter parameter = matrixVectorProductParameter.getVector();
                int n = matrixParameter.getRowDimension();
                int n2 = matrixParameter.getColumnDimension();
                assert (n == dArray.length);
                assert (n2 == parameter.getDimension());
                double[] dArray2 = new double[n2];
                for (int i = 0; i < n2; ++i) {
                    double d = 0.0;
                    for (int j = 0; j < n; ++j) {
                        d += matrixParameter.getParameterValue(j, i) * dArray[j];
                    }
                    dArray2[i] = d;
                }
                return dArray2;
            }

            @Override
            double getParameterLowerBound() {
                return Double.NEGATIVE_INFINITY;
            }

            @Override
            public void getWarning(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) throws XMLParseException {
                if (!(gMRFMultilocusSkyrideLikelihood.getPopSizeParameter() instanceof MatrixVectorProductParameter)) {
                    throw new XMLParseException("Cannot use 'deterministicSkygrid' with stochastic skygrid");
                }
            }
        }
        ,
        PRECISION("precision"){

            @Override
            Parameter getParameter(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                return gMRFMultilocusSkyrideLikelihood.getPrecisionParameter();
            }

            @Override
            double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                return gMRFMultilocusSkyrideLikelihood.getGradientWrtPrecision();
            }

            @Override
            double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                return gMRFMultilocusSkyrideLikelihood.getDiagonalHessianWrtPrecision();
            }

            @Override
            double getParameterLowerBound() {
                return 0.0;
            }

            @Override
            public void getWarning(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
            }
        }
        ,
        REGRESSION_COEFFICIENTS("regressionCoefficients"){

            @Override
            Parameter getParameter(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                List<Parameter> list = gMRFMultilocusSkyrideLikelihood.getBetaListParameter();
                if (list.size() > 1) {
                    throw new RuntimeException("This is not the correct way of handling multidimensional parameters");
                }
                return list.get(0);
            }

            @Override
            double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                return gMRFMultilocusSkyrideLikelihood.getGradientWrtRegressionCoefficients();
            }

            @Override
            double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                return gMRFMultilocusSkyrideLikelihood.getDiagonalHessianWrtRegressionCoefficients();
            }

            @Override
            double getParameterLowerBound() {
                return Double.NEGATIVE_INFINITY;
            }

            @Override
            public void getWarning(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) throws XMLParseException {
                if (gMRFMultilocusSkyrideLikelihood.getBetaParameter() == null) {
                    throw new XMLParseException("Cannot use 'regressionCoefficients' with deterministic skygrid");
                }
            }
        }
        ,
        NODE_HEIGHT("nodeHeight"){
            Parameter parameter;

            @Override
            Parameter getParameter(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                if (this.parameter == null) {
                    TreeModel treeModel = (TreeModel)gMRFMultilocusSkyrideLikelihood.getTree(0);
                    this.parameter = new NodeHeightProxyParameter("allInternalNode", treeModel, true);
                }
                return this.parameter;
            }

            @Override
            double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                return this.getGradientWrtNodeHeights(gMRFMultilocusSkyrideLikelihood);
            }

            @Override
            double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                return new double[gMRFMultilocusSkyrideLikelihood.getTree(0).getInternalNodeCount()];
            }

            @Override
            double getParameterLowerBound() {
                return 0.0;
            }

            @Override
            public void getWarning(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) throws XMLParseException {
                if (gMRFMultilocusSkyrideLikelihood.nLoci() > 1) {
                    throw new XMLParseException("Not yet implemented for multiple loci.");
                }
            }

            private double[] getGradientWrtNodeHeights(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood) {
                int n;
                gMRFMultilocusSkyrideLikelihood.getLogLikelihood();
                Tree tree = gMRFMultilocusSkyrideLikelihood.getTree(0);
                double[] dArray = new double[tree.getInternalNodeCount()];
                double[] dArray2 = gMRFMultilocusSkyrideLikelihood.getPopSizeParameter().getParameterValues();
                double d = 1.0 / gMRFMultilocusSkyrideLikelihood.getPopulationFactor(0);
                TreeIntervals treeIntervals = gMRFMultilocusSkyrideLikelihood.getTreeIntervals(0);
                int[] nArray = new int[tree.getInternalNodeCount()];
                int[] nArray2 = new int[tree.getInternalNodeCount()];
                this.getGridIndexForInternalNodes(gMRFMultilocusSkyrideLikelihood, 0, nArray, nArray2);
                for (int i = 0; i < tree.getInternalNodeCount(); ++i) {
                    NodeRef nodeRef = tree.getNode(i + tree.getExternalNodeCount());
                    n = this.getNodeHeightParameterIndex(nodeRef, tree);
                    int n2 = treeIntervals.getLineageCount(nArray[i]);
                    double d2 = Math.exp(-dArray2[nArray2[n]]);
                    int n3 = n;
                    dArray[n3] = dArray[n3] + -d2 * (double)n2 * (double)(n2 - 1);
                    if (tree.isRoot(nodeRef)) continue;
                    int n4 = treeIntervals.getLineageCount(nArray[i] + 1);
                    int n5 = n;
                    dArray[n5] = dArray[n5] - -d2 * (double)n4 * (double)(n4 - 1);
                }
                double d3 = 0.5 * d;
                n = 0;
                while (n < dArray.length) {
                    int n6 = n++;
                    dArray[n6] = dArray[n6] * d3;
                }
                return dArray;
            }

            private int getNodeHeightParameterIndex(NodeRef nodeRef, Tree tree) {
                return nodeRef.getNumber() - tree.getExternalNodeCount();
            }

            private void getGridIndexForInternalNodes(GMRFMultilocusSkyrideLikelihood gMRFMultilocusSkyrideLikelihood, int n, int[] nArray, int[] nArray2) {
                Tree tree = gMRFMultilocusSkyrideLikelihood.getTree(n);
                double[] dArray = new double[tree.getInternalNodeCount()];
                double[] dArray2 = new double[tree.getInternalNodeCount()];
                int[] nArray3 = new int[tree.getInternalNodeCount()];
                5.sortNodeHeights(tree, dArray, dArray2, nArray3);
                int n2 = 0;
                double[] dArray3 = gMRFMultilocusSkyrideLikelihood.getGridPoints();
                int n3 = 0;
                TreeIntervals treeIntervals = gMRFMultilocusSkyrideLikelihood.getTreeIntervals(n);
                for (int i = 0; i < tree.getInternalNodeCount(); ++i) {
                    while (n2 < dArray3.length && dArray3[n2] < dArray[i]) {
                        ++n2;
                    }
                    nArray2[nArray3[i]] = n2;
                    while (n3 < treeIntervals.getIntervalCount() - 1 && treeIntervals.getIntervalTime(n3) < dArray[i]) {
                        ++n3;
                    }
                    nArray[nArray3[i]] = n3;
                }
            }
        };

        private final String name;

        public static void sortNodeHeights(Tree tree, double[] dArray, double[] dArray2, int[] nArray) {
            int n;
            ArrayList<ComparableDouble> arrayList = new ArrayList<ComparableDouble>();
            for (n = 0; n < nArray.length; ++n) {
                double d = tree.getNodeHeight(tree.getNode(tree.getExternalNodeCount() + n));
                arrayList.add(new ComparableDouble(d));
                dArray2[n] = d;
            }
            HeapSort.sort(arrayList, nArray);
            for (n = 0; n < nArray.length; ++n) {
                dArray[n] = dArray2[nArray[n]];
            }
        }

        private WrtParameter(String string2) {
            this.name = string2;
        }

        abstract Parameter getParameter(GMRFMultilocusSkyrideLikelihood var1);

        abstract double[] getGradientLogDensity(GMRFMultilocusSkyrideLikelihood var1);

        abstract double[] getDiagonalHessianLogDensity(GMRFMultilocusSkyrideLikelihood var1);

        abstract double getParameterLowerBound();

        public abstract void getWarning(GMRFMultilocusSkyrideLikelihood var1) throws XMLParseException;

        public static WrtParameter factory(String string) {
            for (WrtParameter wrtParameter : WrtParameter.values()) {
                if (!string.equalsIgnoreCase(wrtParameter.name)) continue;
                return wrtParameter;
            }
            return null;
        }
    }
}

