/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous.cdi;

import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType;
import dr.evomodel.treedatalikelihood.continuous.cdi.SafeMultivariateIntegrator;
import dr.math.matrixAlgebra.WrappedVector;
import dr.math.matrixAlgebra.missingData.MissingOps;
import org.ejml.data.DenseMatrix64F;

/**
 * A {@link SafeMultivariateIntegrator} that additionally applies a per-buffer displacement
 * (drift) vector to the trait means.
 *
 * <p>Displacements are stored flattened in {@link #displacements}, {@code dimTrait} entries per
 * buffer index, and are (re)computed in {@link #updateBrownianDiffusionMatrices}. They are
 * subtracted from the stored partial means ({@code partials}) in {@link #computeDelta} and
 * {@link #partialMean}, and added to the pre-order means ({@code preOrderPartials}) in
 * {@link #scaleAndDriftMean}.
 */
public class SafeMultivariateWithDriftIntegrator extends SafeMultivariateIntegrator {

    /** Compile-time switch for diagnostic output on {@code System.err}. */
    private static final boolean DEBUG = false;

    /**
     * Scratch vector: partial mean of child i minus its displacement. Written by
     * {@link #partialMean} and read again by {@link #computeSS}.
     */
    private double[] vectorDispi;

    /**
     * Scratch vector: partial mean of child j minus its displacement. Written by
     * {@link #partialMean} and read again by {@link #computeSS}.
     */
    private double[] vectorDispj;

    /**
     * Flattened displacement vectors, {@code dimTrait} entries per buffer index; the vector for
     * buffer {@code b} starts at offset {@code b * dimTrait}.
     */
    double[] displacements;

    /**
     * Creates the integrator; all arguments are forwarded unchanged to the superclass
     * constructor, after which the displacement storage is allocated.
     */
    public SafeMultivariateWithDriftIntegrator(PrecisionType precisionType, int n, int n2, int n3, int n4, int n5) {
        super(precisionType, n, n2, n3, n4, n5);
        this.allocateStorage();
        System.err.println("Trying SafeMultivariateWithDriftIntegrator");
    }

    /**
     * Copies the stored displacement vector of one buffer into {@code displacement}.
     *
     * @param bufferIndex  buffer whose displacement is requested; {@code -1} is not supported
     * @param displacement destination; must hold at least {@code dimTrait} entries
     * @throws UnsupportedOperationException if {@code bufferIndex == -1}
     */
    @Override
    public void getBranchDisplacement(int bufferIndex, double[] displacement) {
        if (bufferIndex == -1) {
            // UnsupportedOperationException is a RuntimeException, so existing handlers still apply.
            throw new UnsupportedOperationException("Not yet implemented");
        }
        assert (displacement != null);
        assert (displacement.length >= this.dimTrait);
        System.arraycopy(this.displacements, bufferIndex * this.dimTrait, displacement, 0, this.dimTrait);
    }

    /**
     * Computes {@code expectation[i] = parentValue[i] + displacement[i]} for every
     * {@code i < dimTrait}.
     *
     * @param actualization not read by this implementation (NOTE(review): presumably consumed by
     *                      actualized subclasses — confirm)
     * @param parentValue   first summand, at least {@code dimTrait} entries
     * @param displacement  second summand, at least {@code dimTrait} entries
     * @param expectation   output, at least {@code dimTrait} entries
     */
    @Override
    public void getBranchExpectation(double[] actualization, double[] parentValue, double[] displacement, double[] expectation) {
        assert (expectation != null);
        assert (expectation.length >= this.dimTrait);
        assert (parentValue != null);
        assert (parentValue.length >= this.dimTrait);
        assert (displacement != null);
        assert (displacement.length >= this.dimTrait);
        for (int i = 0; i < this.dimTrait; ++i) {
            expectation[i] = parentValue[i] + displacement[i];
        }
    }

    /** Allocates the displacement table and the two per-child scratch vectors. */
    private void allocateStorage() {
        this.displacements = new double[this.dimTrait * this.bufferCount];
        this.vectorDispi = new double[this.dimTrait];
        this.vectorDispj = new double[this.dimTrait];
    }

    /**
     * Delegates to the superclass and then, if {@code driftRates} is non-null, recomputes the
     * displacement of every updated buffer from its slice of {@code driftRates} and its scale
     * factor, using the inherited {@code scale} helper (NOTE(review): assumed to write
     * {@code driftRates[slice] * edgeLengths[i]} into the destination — confirm). When
     * {@code driftRates} is {@code null}, the stored displacements are left untouched.
     *
     * @param precisionIndex forwarded to the superclass
     * @param bufferIndices  buffer index of each update
     * @param edgeLengths    per-update scale factor (presumably branch lengths — confirm)
     * @param driftRates     {@code updateCount * dimTrait} drift values, or {@code null}
     * @param updateCount    number of buffers to update
     */
    @Override
    public void updateBrownianDiffusionMatrices(int precisionIndex, int[] bufferIndices, double[] edgeLengths, double[] driftRates, int updateCount) {
        super.updateBrownianDiffusionMatrices(precisionIndex, bufferIndices, edgeLengths, driftRates, updateCount);
        if (DEBUG) {
            System.err.println("Matrices (safe with drift):");
        }
        if (driftRates != null) {
            assert (this.displacements != null);
            assert (driftRates.length >= updateCount * this.dimTrait);
            int driftOffset = 0;
            for (int up = 0; up < updateCount; ++up) {
                final double edgeLength = edgeLengths[up];
                final int displacementOffset = this.dimTrait * bufferIndices[up];
                SafeMultivariateWithDriftIntegrator.scale(driftRates, driftOffset, edgeLength, this.displacements, displacementOffset, this.dimTrait);
                driftOffset += this.dimTrait;
            }
        }
    }

    /**
     * Writes {@code partials[partialOffset + i] - displacements[displacementOffset + i]} into
     * {@code delta[i]} for every {@code i < dimTrait}.
     */
    @Override
    void computeDelta(int partialOffset, int displacementOffset, double[] delta) {
        for (int i = 0; i < this.dimTrait; ++i) {
            delta[i] = this.partials[partialOffset + i] - this.displacements[displacementOffset + i];
        }
    }

    /**
     * Adds the displacement at {@code displacementOffset} to the pre-order mean at
     * {@code meanOffset}, in place.
     *
     * @param unusedOffset not read here (NOTE(review): presumably used by overriding
     *                     subclasses — confirm)
     */
    @Override
    void scaleAndDriftMean(int meanOffset, int unusedOffset, int displacementOffset) {
        for (int i = 0; i < this.dimTrait; ++i) {
            this.preOrderPartials[meanOffset + i] += this.displacements[displacementOffset + i];
        }
    }

    /**
     * Computes the combined partial mean of node k from children i and j.
     *
     * <p>The children's displacement-corrected means are stored in {@link #vectorDispi} and
     * {@link #vectorDispj}, combined by {@link #computeWeightedSum} into {@code vectorPMk}, and
     * then passed with {@code matrixPk} to {@code MissingOps.safeSolve}, whose target is the
     * slice of {@code partials} at {@code kPartialOffset}. The scratch vectors are intentionally
     * left populated because {@link #computeSS} reads them afterwards.
     */
    @Override
    void partialMean(int iPartialOffset, int jPartialOffset, int kPartialOffset, int iDisplacementOffset, int jDisplacementOffset) {
        final double[] dispI = this.vectorDispi;
        final double[] dispJ = this.vectorDispj;
        for (int i = 0; i < this.dimTrait; ++i) {
            dispI[i] = this.partials[iPartialOffset + i] - this.displacements[iDisplacementOffset + i];
            dispJ[i] = this.partials[jPartialOffset + i] - this.displacements[jDisplacementOffset + i];
        }
        final double[] weightedSum = this.vectorPMk;
        this.computeWeightedSum(dispI, dispJ, this.dimTrait, weightedSum);
        final WrappedVector.Raw kMean = new WrappedVector.Raw(this.partials, kPartialOffset, this.dimTrait);
        final WrappedVector.Raw rhs = new WrappedVector.Raw(weightedSum, 0, this.dimTrait);
        MissingOps.safeSolve(this.matrixPk, rhs, kMean, false);
        if (DEBUG) {
            System.err.print("\t\tdisp i:");
            for (int e = 0; e < this.dimTrait; ++e) {
                System.err.print(" " + this.displacements[iDisplacementOffset + e]);
            }
            System.err.println("");
            System.err.print("\t\tdisp j:");
            for (int e = 0; e < this.dimTrait; ++e) {
                System.err.print(" " + this.displacements[jDisplacementOffset + e]);
            }
            // Terminate the line so later diagnostics do not run onto it.
            System.err.println("");
        }
    }

    /**
     * Returns {@code MissingOps.weightedThreeInnerProductNormalized} evaluated on the scratch
     * vectors {@link #vectorDispi}/{@link #vectorDispj}, the k partial mean at
     * {@code kPartialOffset} and {@code vectorPMk}.
     *
     * <p>The i/j offsets and {@code pk} are not read: the displacement-corrected child means come
     * from the scratch vectors, so this must follow the matching {@link #partialMean} call.
     */
    @Override
    double computeSS(int iPartialOffset, DenseMatrix64F pip, int jPartialOffset, DenseMatrix64F pjp, int kPartialOffset, DenseMatrix64F pk, int dim) {
        return MissingOps.weightedThreeInnerProductNormalized(this.vectorDispi, 0, pip, this.vectorDispj, 0, pjp, this.partials, kPartialOffset, this.vectorPMk, 0, dim);
    }

    /**
     * Writes the {@code matrixPip}/{@code matrixPjp}-weighted sum of {@code iVector} and
     * {@code jVector} (first {@code dim} entries, via {@code MissingOps.weightedSum}) into
     * {@code out}.
     */
    void computeWeightedSum(double[] iVector, double[] jVector, int dim, double[] out) {
        MissingOps.weightedSum(iVector, 0, this.matrixPip, jVector, 0, this.matrixPjp, dim, out);
    }
}

