/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.hawkes;

import dr.evolution.util.Taxa;
import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities;
import dr.inference.hawkes.HawkesCore;
import dr.inference.hawkes.HawkesRateProvider;
import dr.inference.hawkes.MassivelyParallelHPHImpl;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.AbstractModel;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.FastMatrixParameter;
import dr.inference.model.Likelihood;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.HeapSort;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AndRule;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.Reportable;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.List;
import java.util.StringTokenizer;

/**
 * Likelihood of a set of event times and event locations under a Hawkes-process ("HPH") model.
 *
 * <p>All numerical work is delegated to a native core ({@link MassivelyParallelHPHImpl}); this class
 * pushes the current locations, times, model parameters and random rates into the core and caches
 * the resulting log-likelihood until a model or variable change invalidates it.
 *
 * <p>As a {@link GradientWrtParameterProvider} the differentiable parameter is the locations matrix
 * (see {@link #getParameter()}); {@link #getRandomRateGradient()} additionally exposes the gradient
 * with respect to the random-rate parameter held by the {@link HawkesRateProvider}.
 */
public class HawkesLikelihood
extends AbstractModelLikelihood
implements Reportable,
GradientWrtParameterProvider {
    /** System property from which decimal flags for the native core are read (see {@link #getCore()}). */
    private static final String REQUIRED_FLAGS_PROPERTY = "hph.required.flags";
    private static final String HAWKES_LIKELIHOOD = "hawkesLikelihood";

    /** Value of the {@code gradientCheckTolerance} attribute; stored but not read elsewhere in this class. */
    private final Double tolerance;

    /** XML parser producing a {@link HawkesLikelihood} from a {@code <hawkesLikelihood>} element. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        static final String LOCATIONS = "locations";
        static final String TIMES = "times";
        static final String TIME_ATTRIBUTE_NAME = "timeTrait";
        static final String LOCATION_ATTRIBUTE_NAME = "locationTrait";
        static final String LOCATION_MEAN = "locationMean";
        static final String LOCATION_MEAN_NAME = "locationMeanTrait";
        static final String LOCATION_VARIANCE = "locationVariance";
        static final String LOCATION_VARIANCE_NAME = "locationVarianceTrait";
        static final String LOCATION_VARIANCE_CONVERSION = "conversion";
        static final String BY_INCREMENT = "byIncrement";
        static final String HPH_DIMENSION = "hphDimension";
        static final String SIGMA_PRECISON = "sigmaXprec";
        static final String TAU_X_PRECISION = "tauXprec";
        static final String TAU_T_PRECISION = "tauTprec";
        static final String OMEGA = "omega";
        static final String THETA = "theta";
        static final String MU = "mu0";
        static final String RANDOM_RATES = "randomRates";
        static final String TOLERANCE = "gradientCheckTolerance";
        static final String JITTER = "jitter";

        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                AttributeRule.newIntegerRule(HPH_DIMENSION, false, "The dimension of the space for HPH"),
                new ElementRule(LOCATIONS, new XMLSyntaxRule[]{
                        new ElementRule(MatrixParameterInterface.class),
                        new ElementRule("Optional location prior related trait construction", new XMLSyntaxRule[]{
                                new AndRule(new XMLSyntaxRule[]{
                                        new ElementRule(LOCATION_MEAN, MatrixParameterInterface.class),
                                        AttributeRule.newStringRule(LOCATION_MEAN_NAME)}),
                                new AndRule(new XMLSyntaxRule[]{
                                        new ElementRule(LOCATION_VARIANCE, MatrixParameterInterface.class),
                                        AttributeRule.newStringRule(LOCATION_VARIANCE_NAME)}),
                                AttributeRule.newStringRule(LOCATION_VARIANCE_CONVERSION)}, true)}),
                new ElementRule(TIMES, Taxa.class),
                AttributeRule.newStringRule(TIME_ATTRIBUTE_NAME),
                AttributeRule.newStringRule(LOCATION_ATTRIBUTE_NAME),
                AttributeRule.newBooleanRule(BY_INCREMENT, true),
                new ElementRule(SIGMA_PRECISON, Parameter.class),
                new ElementRule(TAU_X_PRECISION, Parameter.class),
                new ElementRule(TAU_T_PRECISION, Parameter.class),
                new ElementRule(OMEGA, Parameter.class),
                new ElementRule(THETA, Parameter.class),
                new ElementRule(MU, Parameter.class),
                new ElementRule(RANDOM_RATES, Parameter.class, "The random rate parameter.", true),
                AttributeRule.newDoubleRule(TOLERANCE, true),
                TreeTraitParserUtilities.jitterRules(true)
        };

        @Override
        public String getParserName() {
            return HawkesLikelihood.HAWKES_LIKELIHOOD;
        }

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            final int hphDimension = xo.getIntegerAttribute(HPH_DIMENSION);

            final XMLObject locationsXo = xo.getChild(LOCATIONS);
            final MatrixParameterInterface locations =
                    (MatrixParameterInterface) locationsXo.getChild(MatrixParameterInterface.class);

            final String timeTrait = xo.getStringAttribute(TIME_ATTRIBUTE_NAME);
            final String locationTrait = xo.getStringAttribute(LOCATION_ATTRIBUTE_NAME);
            final Taxa taxa = (Taxa) xo.getElementFirstChild(TIMES);

            // Fills the location matrix in time order and returns the times shifted so the earliest is 0.
            final double[] times = parseTimes(taxa, timeTrait, locations, locationTrait);

            // The optional location prior is only read when both mean and variance are present; the
            // parsed matrices are filled in place and are not passed on to the likelihood itself.
            if (locationsXo.hasChildNamed(LOCATION_MEAN) && locationsXo.hasChildNamed(LOCATION_VARIANCE)) {
                parseLocationPrior(locationsXo, taxa, timeTrait);
            }

            if (xo.hasChildNamed(JITTER)) {
                final TreeTraitParserUtilities utilities = new TreeTraitParserUtilities();
                final List<Integer> missingIndices = TreeTraitParserUtilities.parseMissingIndices(
                        (Parameter) locations, locations.getParameterValues());
                utilities.jitter(xo, hphDimension, missingIndices, locations.getDimension());
            }

            final Parameter sigmaXprec = (Parameter) xo.getElementFirstChild(SIGMA_PRECISON);
            final Parameter tauXprec = (Parameter) xo.getElementFirstChild(TAU_X_PRECISION);
            final Parameter tauTprec = (Parameter) xo.getElementFirstChild(TAU_T_PRECISION);
            final Parameter omega = (Parameter) xo.getElementFirstChild(OMEGA);
            final Parameter theta = (Parameter) xo.getElementFirstChild(THETA);
            final Parameter mu0 = (Parameter) xo.getElementFirstChild(MU);

            final HawkesRateProvider rateProvider;
            if (xo.hasChildNamed(RANDOM_RATES)) {
                rateProvider = new HawkesRateProvider.Default((Parameter) xo.getElementFirstChild(RANDOM_RATES));
            } else {
                rateProvider = new HawkesRateProvider.None();
            }

            final boolean byIncrement = xo.getAttribute(BY_INCREMENT, false);
            final Double tolerance = xo.getAttribute(TOLERANCE, 1.0E-4);

            return new HawkesLikelihood(hphDimension, tauXprec, sigmaXprec, tauTprec, omega, theta, mu0,
                    rateProvider, locations, times, tolerance, byIncrement);
        }

        /**
         * Reads the optional location mean / variance matrices from their taxon traits and converts the
         * variance values with the unit conversion named by the {@code conversion} attribute.
         *
         * @throws XMLParseException if the conversion name is unknown or a taxon trait is malformed
         */
        private void parseLocationPrior(XMLObject locationsXo, Taxa taxa, String timeTrait)
                throws XMLParseException {
            final MatrixParameterInterface locationMean =
                    (MatrixParameterInterface) locationsXo.getElementFirstChild(LOCATION_MEAN);
            final String meanTrait = locationsXo.getChild(LOCATION_MEAN).getStringAttribute(LOCATION_MEAN_NAME);

            final XMLObject varianceXo = locationsXo.getChild(LOCATION_VARIANCE);
            final MatrixParameterInterface locationVariance =
                    (MatrixParameterInterface) locationsXo.getElementFirstChild(LOCATION_VARIANCE);
            final String varianceTrait = varianceXo.getStringAttribute(LOCATION_VARIANCE_NAME);

            final String conversionName = (String) varianceXo.getAttribute(LOCATION_VARIANCE_CONVERSION);
            final UnitConversion conversion = UnitConversion.factory(conversionName);
            if (conversion == null) {
                // Previously surfaced later as a bare NullPointerException.
                throw new XMLParseException("Unknown " + LOCATION_VARIANCE_CONVERSION + " '" + conversionName
                        + "' in " + LOCATION_VARIANCE + " element");
            }

            parseTimes(taxa, timeTrait, locationMean, meanTrait);

            // One raw (unconverted) variance value per taxon, in time order.
            final FastMatrixParameter rawVariance = new FastMatrixParameter("tmpLocationVariance", 1,
                    locationVariance.getColumnDimension(), 0.0, false);
            parseTimes(taxa, timeTrait, rawVariance, varianceTrait);

            // Every row of a column receives that taxon's converted variance.
            for (int row = 0; row < locationVariance.getRowDimension(); ++row) {
                for (int col = 0; col < locationVariance.getColumnDimension(); ++col) {
                    locationVariance.setParameterValue(row, col,
                            conversion.convert(rawVariance.getParameterValue(0, col)));
                }
            }
        }

        /**
         * Orders the taxa by the numeric value of {@code timeTrait} and fills column {@code k} of
         * {@code matrix} with the whitespace-separated numbers of {@code valueTrait} on the k-th
         * earliest taxon. Each column is renamed to that taxon's id.
         *
         * @return the sorted times shifted so that the earliest time is 0
         * @throws XMLParseException if there are no taxa, a trait is missing, or a trait has too few values
         */
        private double[] parseTimes(Taxa taxa, String timeTrait, MatrixParameterInterface matrix,
                                    String valueTrait) throws XMLParseException {
            final int taxonCount = taxa.getTaxonCount();
            if (taxonCount == 0) {
                throw new XMLParseException("The " + TIMES + " element must contain at least one taxon");
            }

            final double[] rawTimes = new double[taxonCount];
            for (int i = 0; i < taxonCount; ++i) {
                rawTimes[i] = Double.parseDouble(requireAttribute(taxa, i, timeTrait));
            }

            final int[] order = new int[taxonCount];
            HeapSort.sort(rawTimes, order);
            final double origin = rawTimes[order[0]];

            final double[] shiftedTimes = new double[taxonCount];
            for (int rank = 0; rank < taxonCount; ++rank) {
                final int taxonIndex = order[rank];
                shiftedTimes[rank] = rawTimes[taxonIndex] - origin;

                final Parameter column = matrix.getParameter(rank);
                column.setId(taxa.getTaxonId(taxonIndex));

                final StringTokenizer tokens = new StringTokenizer(requireAttribute(taxa, taxonIndex, valueTrait));
                for (int k = 0; k < column.getDimension(); ++k) {
                    if (!tokens.hasMoreTokens()) {
                        throw new XMLParseException("Attribute '" + valueTrait + "' of taxon '"
                                + taxa.getTaxonId(taxonIndex) + "' has fewer than " + column.getDimension()
                                + " values");
                    }
                    column.setParameterValue(k, Double.parseDouble(tokens.nextToken()));
                }
            }
            return shiftedTimes;
        }

        /** Returns the named attribute of the taxon at {@code index}, failing with a descriptive error if absent. */
        private String requireAttribute(Taxa taxa, int index, String attributeName) throws XMLParseException {
            final Object value = taxa.getTaxon(index).getAttribute(attributeName);
            if (value == null) {
                throw new XMLParseException("Taxon '" + taxa.getTaxonId(index) + "' has no '"
                        + attributeName + "' attribute");
            }
            return (String) value;
        }

        @Override
        public String getParserDescription() {
            return "Provides the likelihood of event times and locations under a Hawkes process model (HPH).";
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        @Override
        public Class getReturnType() {
            return HawkesLikelihood.class;
        }
    };

    private final int hphDimension;
    private final int locationCount;
    private HawkesCore hphCore;
    private HawkesModel hawkesModel;
    private boolean likelihoodKnown = false;
    private double logLikelihood;
    private double storedLogLikelihood;
    /** Flags handed to the native core; set by {@link #getCore()}. */
    private long flags = 0L;
    /** Lazily allocated buffer returned by {@link #getGradientLogDensity()}. */
    private double[] gradient;

    /**
     * Builds the likelihood and initializes the native core.
     *
     * @param hphDimension       number of rows of the locations matrix
     * @param tauXprec           scalar model parameter
     * @param sigmaXprec         scalar model parameter
     * @param tauTprec           scalar model parameter
     * @param omega              scalar model parameter
     * @param theta              scalar model parameter
     * @param mu0                scalar model parameter
     * @param rateProvider       supplies random rates to the core
     * @param locationsParameter locations, one column per event
     * @param times              event times, one per column of {@code locationsParameter}
     * @param tolerance          stored gradient-check tolerance
     * @param byIncrement        whether some parameters are interpreted as increments
     *                           (see {@link HawkesModel#getParameterValues()})
     */
    public HawkesLikelihood(int hphDimension, Parameter tauXprec, Parameter sigmaXprec, Parameter tauTprec,
                            Parameter omega, Parameter theta, Parameter mu0, HawkesRateProvider rateProvider,
                            MatrixParameterInterface locationsParameter, double[] times, Double tolerance,
                            boolean byIncrement) {
        super(HAWKES_LIKELIHOOD);
        this.hawkesModel = new HawkesModel(tauXprec, sigmaXprec, tauTprec, omega, theta, mu0, rateProvider,
                locationsParameter, times, byIncrement);
        this.hphDimension = hphDimension;
        this.locationCount = this.hawkesModel.getLocationCount();
        this.tolerance = tolerance;
        this.initialize(hphDimension, this.hawkesModel);
    }

    /**
     * Creates and initializes the native core, validates the locations matrix, and pushes the
     * initial parameters, random rates and locations into the core.
     *
     * @return the core's internal dimension
     */
    protected int initialize(int n, HawkesModel hawkesModel) {
        this.hphCore = this.getCore();
        System.err.println("Initializing with flags: " + this.flags);
        this.hphCore.initialize(n, this.locationCount, this.flags);
        this.hawkesModel = hawkesModel;
        final int internalDimension = this.hphCore.getInternalDimension();
        this.setupLocationsParameter(hawkesModel.getLocationsParameter());
        this.hphCore.setParameters(hawkesModel.getParameterValues());
        hawkesModel.getRateProvider().setRandomRates(this.hphCore);
        this.updateAllLocations(hawkesModel.getLocationsParameter());
        this.addModel(hawkesModel);
        return internalDimension;
    }

    @Override
    public String getReport() {
        final double value = this.getLogLikelihood();
        return new StringBuilder()
                .append(this.getClass().getName()).append("(").append(value).append(")")
                .toString();
    }

    @Override
    public Likelihood getLikelihood() {
        return this;
    }

    public HawkesModel getHawkesModel() {
        return this.hawkesModel;
    }

    /** The differentiable parameter: the locations matrix. */
    @Override
    public Parameter getParameter() {
        return this.hawkesModel.getLocationsParameter();
    }

    @Override
    public int getDimension() {
        return this.hawkesModel.getLocationsParameter().getDimension();
    }

    /**
     * Returns the gradient of the log-likelihood with respect to the locations, recomputing the
     * likelihood first if needed. The returned array is reused between calls.
     */
    @Override
    public double[] getGradientLogDensity() {
        if (this.gradient == null) {
            this.gradient = new double[this.hawkesModel.getLocationsParameter().getDimension()];
        }
        this.getLogLikelihood();
        this.hphCore.getLocationGradient(this.gradient);
        return this.gradient;
    }

    /** Returns a fresh array holding the gradient with respect to the random-rate parameter. */
    public double[] getRandomRateGradient() {
        final double[] rateGradient = new double[this.hawkesModel.getRateProvider().getParameter().getDimension()];
        this.getLogLikelihood();
        this.hphCore.getRandomRatesGradient(rateGradient);
        return rateGradient;
    }

    public MatrixParameterInterface getMatrixParameter() {
        return this.hawkesModel.getLocationsParameter();
    }

    /**
     * Creates the native core and records the flags to use with it, taken from the
     * {@code hph.required.flags} system property (default 0).
     */
    private HawkesCore getCore() {
        long requestedFlags = 0L;
        final String property = System.getProperty(REQUIRED_FLAGS_PROPERTY);
        if (property != null) {
            requestedFlags = Long.parseLong(property.trim());
        }
        System.err.println("Attempting to use a native HPH core with flag: " + requestedFlags + "; may the force be with you ....");
        final MassivelyParallelHPHImpl core = new MassivelyParallelHPHImpl();
        this.flags = requestedFlags;
        return core;
    }

    public int getHphDimension() {
        return this.hphDimension;
    }

    public int getLocationCount() {
        return this.locationCount;
    }

    /** Sends the complete locations matrix to the core (index -1 is passed for "all locations"). */
    private void updateAllLocations(MatrixParameterInterface locations) {
        this.hphCore.updateLocation(-1, locations.getParameterValues());
    }

    /**
     * Checks that the locations matrix is {@code hphDimension x locationCount} and gives every column
     * that has no bounds yet unbounded (-inf, +inf) bounds.
     */
    private void setupLocationsParameter(MatrixParameterInterface locations) {
        if (locations.getColumnDimension() <= 0) {
            throw new IllegalArgumentException("Dimensions on matrix must be set");
        }
        if (locations.getColumnDimension() != this.locationCount) {
            throw new RuntimeException("locationsParameter column dimension (" + locations.getColumnDimension() + ") is not equal to the locationCount (" + this.locationCount + ")");
        }
        if (locations.getRowDimension() != this.hphDimension) {
            throw new RuntimeException("locationsParameter row dimension (" + locations.getRowDimension() + ") is not equal to the hphDimension (" + this.hphDimension + ")");
        }
        for (int i = 0; i < locations.getColumnDimension(); ++i) {
            final Parameter column = locations.getParameter(i);
            try {
                column.getBounds();
            } catch (NullPointerException noBounds) {
                // getBounds() is relied upon to signal "no bounds set" with an NPE.
                column.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, column.getDimension()));
            }
        }
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        this.likelihoodKnown = false;
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        this.likelihoodKnown = false;
    }

    @Override
    protected void storeState() {
        this.storedLogLikelihood = this.logLikelihood;
        this.hphCore.storeState();
    }

    @Override
    protected void restoreState() {
        this.logLikelihood = this.storedLogLikelihood;
        this.likelihoodKnown = true;
        this.hphCore.restoreState();
    }

    @Override
    protected void acceptState() {
        this.hphCore.acceptState();
    }

    @Override
    public void makeDirty() {
        this.likelihoodKnown = false;
        this.hphCore.makeDirty();
    }

    @Override
    public Model getModel() {
        return this;
    }

    /**
     * Returns the cached log-likelihood, or refreshes the core with the current locations, times,
     * parameters and random rates and recomputes it when the cache is stale.
     */
    @Override
    public double getLogLikelihood() {
        if (!this.likelihoodKnown) {
            this.updateAllLocations(this.hawkesModel.getLocationsParameter());
            this.hphCore.setTimesData(this.hawkesModel.getTimes());
            this.hphCore.setParameters(this.hawkesModel.getParameterValues());
            this.hawkesModel.getRateProvider().setRandomRates(this.hphCore);
            this.logLikelihood = this.hphCore.calculateLogLikelihood();
            this.likelihoodKnown = true;
        }
        return this.logLikelihood;
    }

    /**
     * Holds the model's parameters, locations and times, and invalidates the enclosing likelihood
     * whenever any of them changes.
     */
    public class HawkesModel
    extends AbstractModel {
        final Parameter tauXprec;
        final Parameter sigmaXprec;
        final Parameter tauTprec;
        final Parameter omega;
        final Parameter theta;
        final Parameter mu0;
        final HawkesRateProvider rateProvider;
        final MatrixParameterInterface locationsParameter;
        final double[] times;
        static final String HAWKES_MODEL = "HawkesModel";
        final boolean byIncrement;

        public HawkesModel(Parameter tauXprec, Parameter sigmaXprec, Parameter tauTprec, Parameter omega,
                           Parameter theta, Parameter mu0, HawkesRateProvider rateProvider,
                           MatrixParameterInterface locationsParameter, double[] times, boolean byIncrement) {
            super(HAWKES_MODEL);
            this.byIncrement = byIncrement;
            this.tauXprec = tauXprec;
            this.sigmaXprec = sigmaXprec;
            this.tauTprec = tauTprec;
            this.omega = omega;
            this.theta = theta;
            this.mu0 = mu0;
            this.rateProvider = rateProvider;
            this.locationsParameter = locationsParameter;
            this.times = times;
            this.checkDimensions();
            this.addVariable(tauXprec);
            this.addVariable(tauTprec);
            this.addVariable(sigmaXprec);
            this.addVariable(omega);
            this.addVariable(theta);
            this.addVariable(mu0);
            this.addVariable(locationsParameter);
            if (rateProvider instanceof Model) {
                this.addModel((Model) rateProvider);
            }
        }

        /** Requires one time per location column and exactly one value for each of the six parameters. */
        private void checkDimensions() {
            if (this.times.length != this.getLocationCount()) {
                throw new RuntimeException("Times dimension doesn't match location count.");
            }
            if (this.getTotalDimension() != 6) {
                throw new RuntimeException("Parameter dimension is wrong.");
            }
        }

        public HawkesRateProvider getRateProvider() {
            return this.rateProvider;
        }

        public MatrixParameterInterface getLocationsParameter() {
            return this.locationsParameter;
        }

        public double[] getTimes() {
            return this.times;
        }

        private int getTotalDimension() {
            return this.sigmaXprec.getDimension() + this.tauXprec.getDimension() + this.tauTprec.getDimension()
                    + this.omega.getDimension() + this.theta.getDimension() + this.mu0.getDimension();
        }

        /**
         * Returns the values handed to the core, in the order
         * {@code [sigmaX, tauXprec, tauTprec, omega, theta, mu0]}. When {@code byIncrement} is set,
         * the first entry is {@code sigmaXprec + tauXprec} and the fourth is {@code tauTprec + omega}.
         */
        public double[] getParameterValues() {
            final double sigmaX = this.sigmaXprec.getParameterValue(0);
            final double tauX = this.tauXprec.getParameterValue(0);
            final double tauT = this.tauTprec.getParameterValue(0);
            final double omegaValue = this.omega.getParameterValue(0);
            return new double[]{
                    this.byIncrement ? sigmaX + tauX : sigmaX,
                    tauX,
                    tauT,
                    this.byIncrement ? tauT + omegaValue : omegaValue,
                    this.theta.getParameterValue(0),
                    this.mu0.getParameterValue(0)
            };
        }

        public int getLocationCount() {
            return this.locationsParameter.getColumnDimension();
        }

        @Override
        protected void handleModelChangedEvent(Model model, Object object, int n) {
            HawkesLikelihood.this.likelihoodKnown = false;
        }

        @Override
        protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
            HawkesLikelihood.this.likelihoodKnown = false;
        }

        // NOTE(review): HawkesLikelihood.storeState()/restoreState() also forward to hphCore. If both
        // this sub-model and the enclosing likelihood are stored/restored in the same step, the core
        // sees the call twice; confirm the native store/restore is idempotent.
        @Override
        protected void storeState() {
            HawkesLikelihood.this.storedLogLikelihood = HawkesLikelihood.this.logLikelihood;
            HawkesLikelihood.this.hphCore.storeState();
        }

        @Override
        protected void restoreState() {
            HawkesLikelihood.this.logLikelihood = HawkesLikelihood.this.storedLogLikelihood;
            HawkesLikelihood.this.likelihoodKnown = true;
            HawkesLikelihood.this.hphCore.restoreState();
        }

        @Override
        protected void acceptState() {
            HawkesLikelihood.this.hphCore.acceptState();
        }
    }

    /** Unit conversions applied to location-variance trait values, selected by XML name. */
    static enum UnitConversion {
        KM_TO_DEGREE("kmToDegree"){

            /**
             * Returns {@code 110.5^2 * 6 * pi / d}.
             * NOTE(review): 110.5 looks like km per degree; confirm the derivation of this formula.
             */
            @Override
            double convert(double d) {
                return Math.pow(110.5, 2.0) * 6.0 * Math.PI / d;
            }
        };

        private final String label;

        private UnitConversion(String label) {
            this.label = label;
        }

        abstract double convert(double value);

        /** Returns the conversion whose name matches {@code label} ignoring case, or {@code null} if none does. */
        public static UnitConversion factory(String label) {
            if (label == null) {
                return null;
            }
            for (UnitConversion candidate : UnitConversion.values()) {
                if (label.equalsIgnoreCase(candidate.label)) {
                    return candidate;
                }
            }
            return null;
        }
    }

    /** Observation kinds; not referenced elsewhere in this class. */
    public static enum ObservationType {
        POINT,
        UPPER_BOUND,
        LOWER_BOUND,
        MISSING;

    }
}

