diff --git a/src/main/java/mascot/app/beauti/NeDynamicsListInputEditor.java b/src/main/java/mascot/app/beauti/NeDynamicsListInputEditor.java index 45546bf..edcc191 100644 --- a/src/main/java/mascot/app/beauti/NeDynamicsListInputEditor.java +++ b/src/main/java/mascot/app/beauti/NeDynamicsListInputEditor.java @@ -445,17 +445,22 @@ private void addParameter(RealParameter parameter, String pId, String dynamics, } private void removeParameters(NeDynamics neDynamics, MCMC mcmc) { + // The spec interface inputs (RealScalar / RealVector) don't extend + // StateNode, but at runtime the XML always binds a RealScalarParam / + // RealVectorParam (which do). Cast to StateNode at the call site so + // the BEAUti disconnect logic can use getID(). + CompoundDistribution posterior = (CompoundDistribution) mcmc.posteriorInput.get(); if (neDynamics instanceof ConstantNe) { - removeParameter(((CompoundDistribution) mcmc.posteriorInput.get()), ((ConstantNe) neDynamics).NeInput.get()); + removeParameter(posterior, (StateNode) ((ConstantNe) neDynamics).NeInput.get()); }else if (neDynamics instanceof ExponentialNe) { - removeParameter(((CompoundDistribution) mcmc.posteriorInput.get()), ((ExponentialNe) neDynamics).growthRateInput.get()); - removeParameter(((CompoundDistribution) mcmc.posteriorInput.get()), ((ExponentialNe) neDynamics).logNeNullInput.get()); + removeParameter(posterior, (StateNode) ((ExponentialNe) neDynamics).growthRateInput.get()); + removeParameter(posterior, (StateNode) ((ExponentialNe) neDynamics).logNeNullInput.get()); }else if (neDynamics instanceof Skygrowth) { - removeParameter(((CompoundDistribution) mcmc.posteriorInput.get()), ((Skygrowth) neDynamics).NeInput.get()); + removeParameter(posterior, (StateNode) ((Skygrowth) neDynamics).NeInput.get()); }else if (neDynamics instanceof LogisticNe) { - removeParameter(((CompoundDistribution) mcmc.posteriorInput.get()), ((LogisticNe) neDynamics).capacityInput.get()); - removeParameter(((CompoundDistribution) mcmc.posteriorInput.get()), ((LogisticNe) neDynamics).carryingProportionInput.get()); - removeParameter(((CompoundDistribution) mcmc.posteriorInput.get()), ((LogisticNe) neDynamics).growthRateInput.get()); + removeParameter(posterior, (StateNode) ((LogisticNe) neDynamics).capacityInput.get()); + removeParameter(posterior, (StateNode) ((LogisticNe) neDynamics).carryingProportionInput.get()); + removeParameter(posterior, (StateNode) ((LogisticNe) neDynamics).growthRateInput.get()); } } diff --git a/src/main/java/mascot/glmmodel/ErrorSmoothing.java b/src/main/java/mascot/glmmodel/ErrorSmoothing.java index a531acc..4859238 100644 --- a/src/main/java/mascot/glmmodel/ErrorSmoothing.java +++ b/src/main/java/mascot/glmmodel/ErrorSmoothing.java @@ -6,9 +6,11 @@ import beast.base.inference.Distribution; import beast.base.inference.State; import beast.base.inference.StateNode; -import beast.base.inference.distribution.ParametricDistribution; +import beast.base.spec.domain.Real; +import beast.base.spec.inference.distribution.ScalarDistribution; import beast.base.spec.inference.parameter.IntVectorParam; import beast.base.spec.inference.parameter.RealVectorParam; +import beast.base.spec.type.RealVector; import java.util.ArrayList; import java.util.List; @@ -19,14 +21,14 @@ "If x is multidimensional, the components of x are assumed to be independent, " + "so the sum of log probabilities of all elements of x is returned as the prior.") public class ErrorSmoothing extends Distribution { - final public Input m_x = new Input<>("x", "point at which the density is calculated", Validate.REQUIRED); + final public Input> m_x = new Input<>("x", "point at which the density is calculated", Validate.REQUIRED); - final public Input distInput = new Input<>("distr", "distribution used to calculate prior, e.g. normal, beta, gamma.", Validate.REQUIRED); + final public Input> distInput = new Input<>("distr", "distribution used to calculate prior, e.g. normal, beta, gamma.", Validate.REQUIRED); /** * shadows distInput * */ - protected ParametricDistribution dist; + protected ScalarDistribution dist; @Override public void initAndValidate() { @@ -36,9 +38,14 @@ public void initAndValidate() { @Override public double calculateLogP() { - Function x = m_x.get(); - // spec types enforce bounds via domain, so no explicit check needed - logP = dist.calcLogP(x); + RealVector x = m_x.get(); + // Components of x are treated as iid draws from `dist` — sum the + // scalar log-densities. The spec ScalarDistribution doesn't have a + // single-shot logP-over-a-vector method; loop manually. + logP = 0; + for (int i = 0; i < x.size(); i++) { + logP += dist.logDensity(x.get(i)); + } if (logP == Double.POSITIVE_INFINITY) { logP = Double.NEGATIVE_INFINITY; } @@ -67,19 +74,17 @@ public void sample(State state, Random random) { sampleConditions(state, random); // sample distribution parameters - Function x = m_x.get(); + RealVector x = m_x.get(); - Double[] newx; try { - newx = dist.sample(1)[0]; - + // Spec ScalarDistribution.sample() returns one value per call; + // for a vector x we draw size(x) iid samples. m_x is now typed + // RealVector, so only the real-vector branch + // applies (the legacy code had an IntVectorParam branch that's + // unreachable under the new Input type). if (x instanceof RealVectorParam rvp) { - for (int i = 0; i < newx.length; i++) { - rvp.set(i, newx[i]); - } - } else if (x instanceof IntVectorParam ivp) { - for (int i = 0; i < newx.length; i++) { - ivp.set(i, (int)Math.round(newx[i])); + for (int i = 0; i < rvp.size(); i++) { + rvp.set(i, dist.sample().get(0)); } } diff --git a/src/main/java/mascot/glmmodel/GlmModel.java b/src/main/java/mascot/glmmodel/GlmModel.java index 6f90f80..5c40c58 100644 --- a/src/main/java/mascot/glmmodel/GlmModel.java +++ b/src/main/java/mascot/glmmodel/GlmModel.java @@ -4,12 +4,16 @@ import beast.base.core.Input.Validate; import beast.base.core.Loggable; import beast.base.inference.CalculationNode; +import beast.base.inference.StateNode; +import beast.base.inference.StateNodeInitialiser; import beast.base.spec.inference.parameter.BoolVectorParam; import beast.base.spec.domain.Real; import beast.base.spec.inference.parameter.RealScalarParam; import beast.base.spec.inference.parameter.RealVectorParam; -public abstract class GlmModel extends CalculationNode implements Loggable { +import java.util.List; + +public abstract class GlmModel extends CalculationNode implements Loggable, StateNodeInitialiser { public Input covariateListInput = new Input<>("covariateList", "input of covariates", Validate.REQUIRED); public Input> scalerInput = new Input<>("scaler", "input of covariates scaler", Validate.REQUIRED); @@ -84,6 +88,21 @@ public void setNrDummy(){ verticalEntries = 0; } + // Subclasses (LogLinear, etc.) fix up the dimensions of scaler / + // indicator / error parameters during their own initAndValidate. The + // default initStateNodes is a no-op; getInitialisedStateNodes reports + // the parameters this class takes responsibility for so the framework + // can dedupe. + @Override + public void initStateNodes() { + } + @Override + public void getInitialisedStateNodes(List stateNodes) { + stateNodes.add(scalerInput.get()); + stateNodes.add(indicatorInput.get()); + if (errorInput.get() != null) stateNodes.add(errorInput.get()); + if (constantErrorInput.get() != null) stateNodes.add(constantErrorInput.get()); + } } diff --git a/src/main/java/mascot/glmmodel/MaxRate.java b/src/main/java/mascot/glmmodel/MaxRate.java index 545179f..818b5aa 100644 --- a/src/main/java/mascot/glmmodel/MaxRate.java +++ b/src/main/java/mascot/glmmodel/MaxRate.java @@ -6,9 +6,7 @@ import beast.base.core.Input.Validate; import beast.base.inference.Distribution; import beast.base.inference.State; -import beast.base.inference.distribution.ParametricDistribution; -import beast.base.spec.domain.Real; -import beast.base.spec.inference.parameter.RealVectorParam; +import beast.base.spec.inference.distribution.ScalarDistribution; import mascot.dynamics.GLM; import java.util.ArrayList; @@ -21,14 +19,14 @@ "so the sum of log probabilities of all elements of x is returned as the prior.") public class MaxRate extends Distribution { final public Input GLMStepwiseModelInput = new Input<>("GLMmodel", "glm model input"); - final public Input distInput = new Input<>("distr", "distribution used to calculate prior, e.g. normal, beta, gamma.", Validate.REQUIRED); + final public Input> distInput = new Input<>("distr", "distribution used to calculate prior, e.g. normal, beta, gamma.", Validate.REQUIRED); final public Input migrationOnlyInput = new Input<>("migrationOnly", "put prior only on migration rates", false); final public Input NeOnlyInput = new Input<>("NeOnly", "put prior only on migration rates", false); /** * shadows distInput * */ - protected ParametricDistribution dist; + protected ScalarDistribution dist; @Override @@ -42,16 +40,13 @@ public double calculateLogP() { Double[] mig = GLMStepwiseModelInput.get().getAllCoalescentRate(); Double[] coal = GLMStepwiseModelInput.get().getAllBackwardsMigration(); - RealVectorParam dCoal = new RealVectorParam<>(unbox(coal), Real.INSTANCE); - RealVectorParam dMig = new RealVectorParam<>(unbox(mig), Real.INSTANCE); - logP = 0.0; if (migrationOnlyInput.get()){ - logP += dist.calcLogP(dMig); + logP += sumLogDensity(mig); }else{ - logP += dist.calcLogP(dCoal); - logP += dist.calcLogP(dMig); + logP += sumLogDensity(coal); + logP += sumLogDensity(mig); } if (logP == Double.POSITIVE_INFINITY) { logP = Double.NEGATIVE_INFINITY; @@ -59,12 +54,14 @@ public double calculateLogP() { return logP; } - private static double[] unbox(Double[] values) { - double[] result = new double[values.length]; - for (int i = 0; i < values.length; i++) { - result[i] = values[i]; - } - return result; + private double sumLogDensity(Double[] values) { + // Each entry is treated as an iid draw from `dist` (matches what + // legacy ParametricDistribution.calcLogP(Function) did over a + // multidimensional sample). Spec ScalarDistribution exposes + // logDensity per scalar; sum manually. + double sum = 0; + for (Double v : values) sum += dist.logDensity(v); + return sum; } /** diff --git a/src/main/java/mascot/logger/StructuredTreeLogger.java b/src/main/java/mascot/logger/StructuredTreeLogger.java index 20a5a52..8db8c96 100644 --- a/src/main/java/mascot/logger/StructuredTreeLogger.java +++ b/src/main/java/mascot/logger/StructuredTreeLogger.java @@ -10,7 +10,7 @@ import beast.base.evolution.tree.Tree; import beast.base.evolution.tree.TreeInterface; import beast.base.inference.StateNode; -import beast.base.spec.inference.parameter.BoolVectorParam; +import beast.base.spec.type.BoolVector; import mascot.distribution.Mascot; import mascot.ode.Euler2ndOrderTransitions; import mascot.ode.MascotODEUpDown; @@ -48,7 +48,7 @@ public class StructuredTreeLogger extends Tree implements Loggable { public Input clockModelInput = new Input("branchratemodel", "rate to be logged with branches of the tree"); public Input> parameterInput = new Input>("metadata", "meta data to be logged with the tree nodes", new ArrayList<>()); public Input maxStateInput = new Input("maxState", "report branch lengths as substitutions (branch length times clock rate for the branch)", false); - public Input conditionalStateProbsInput = new Input<>("conditionalStateProbs", "report branch lengths as substitutions (branch length times clock rate for the branch)"); + public Input conditionalStateProbsInput = new Input<>("conditionalStateProbs", "report branch lengths as substitutions (branch length times clock rate for the branch)"); public Input substitutionsInput = new Input("substitutions", "report branch lengths as substitutions (branch length times clock rate for the branch)", false); public Input decimalPlacesInput = new Input("dp", "the number of decimal places to use writing branch lengths and rates, use -1 for full precision (default = full precision)", -1); diff --git a/src/main/java/mascot/logger/mappedProbLogger.java b/src/main/java/mascot/logger/mappedProbLogger.java index f5a8d7a..4dd99cc 100644 --- a/src/main/java/mascot/logger/mappedProbLogger.java +++ b/src/main/java/mascot/logger/mappedProbLogger.java @@ -7,7 +7,7 @@ import beast.base.evolution.branchratemodel.BranchRateModel; import beast.base.evolution.tree.Node; import beast.base.inference.CalculationNode; -import beast.base.spec.inference.parameter.BoolVectorParam; +import beast.base.spec.type.BoolVector; import mascot.distribution.MappedMascot; import java.io.PrintStream; @@ -34,7 +34,7 @@ public class mappedProbLogger extends CalculationNode implements Loggable { "meta data to be logged with the tree nodes", new ArrayList<>()); public Input maxStateInput = new Input("maxState", "report branch lengths as substitutions (branch length times clock rate for the branch)", false); - public Input conditionalStateProbsInput = new Input<>("conditionalStateProbs", + public Input conditionalStateProbsInput = new Input<>("conditionalStateProbs", "report branch lengths as substitutions (branch length times clock rate for the branch)"); public Input substitutionsInput = new Input("substitutions", "report branch lengths as substitutions (branch length times clock rate for the branch)", false); diff --git a/src/main/java/mascot/mapped/AncestralStateTreeLikelihood.java b/src/main/java/mascot/mapped/AncestralStateTreeLikelihood.java index a270f4b..287cd4f 100644 --- a/src/main/java/mascot/mapped/AncestralStateTreeLikelihood.java +++ b/src/main/java/mascot/mapped/AncestralStateTreeLikelihood.java @@ -5,7 +5,7 @@ import beast.base.core.Loggable; import beast.base.evolution.alignment.Alignment; import beast.base.evolution.datatype.DataType; -import beast.base.evolution.likelihood.TreeLikelihood; +import beast.base.spec.evolution.likelihood.TreeLikelihood; import beast.base.evolution.tree.Node; import beast.base.evolution.tree.Tree; import beast.base.evolution.tree.TreeInterface; diff --git a/src/main/java/mascot/parameterdynamics/ConstantNe.java b/src/main/java/mascot/parameterdynamics/ConstantNe.java index 6f073aa..59aee25 100644 --- a/src/main/java/mascot/parameterdynamics/ConstantNe.java +++ b/src/main/java/mascot/parameterdynamics/ConstantNe.java @@ -3,14 +3,14 @@ import beast.base.core.Input; import beast.base.core.Input.Validate; import beast.base.spec.domain.Real; -import beast.base.spec.inference.parameter.RealScalarParam; +import beast.base.spec.type.RealScalar; public class ConstantNe extends NeDynamics { - public Input> NeInput = new Input<>( + public Input> NeInput = new Input<>( "logNe", "input of the Ne at the time of the most recent sampled ancestor", Validate.REQUIRED); - RealScalarParam Ne; + RealScalar Ne; @Override public void initAndValidate() { @@ -31,9 +31,6 @@ public double getNeTime(double t) { @Override public boolean isDirty() { - if (Ne.somethingIsDirty()) - return true; - - return false; + return isDirtyInput(Ne); } } diff --git a/src/main/java/mascot/parameterdynamics/ExponentialNe.java b/src/main/java/mascot/parameterdynamics/ExponentialNe.java index dce0d92..b5a54a6 100644 --- a/src/main/java/mascot/parameterdynamics/ExponentialNe.java +++ b/src/main/java/mascot/parameterdynamics/ExponentialNe.java @@ -3,20 +3,20 @@ import beast.base.core.Input; import beast.base.core.Input.Validate; import beast.base.spec.domain.Real; -import beast.base.spec.inference.parameter.RealScalarParam; +import beast.base.spec.type.RealScalar; public class ExponentialNe extends NeDynamics { - public Input> logNeNullInput = new Input<>( + public Input> logNeNullInput = new Input<>( "NeNull", "input of the Ne at the time of the most recent sampled ancestor", Validate.REQUIRED); - public Input> growthRateInput = new Input<>( + public Input> growthRateInput = new Input<>( "growthRate", "input of the growth rate", Validate.REQUIRED); public Input minNeInput = new Input<>( "minNe", "input of the minimal Ne", 0.0); - RealScalarParam logNeNull; - RealScalarParam growthRate; + RealScalar logNeNull; + RealScalar growthRate; @Override public void initAndValidate() { @@ -39,13 +39,7 @@ public double getNeTime(double t) { @Override public boolean isDirty() { - if (logNeNull.somethingIsDirty()) - return true; - - if (growthRate.somethingIsDirty()) - return true; - - return false; + return isDirtyInput(logNeNull) || isDirtyInput(growthRate); } diff --git a/src/main/java/mascot/parameterdynamics/LogLinearGLM.java b/src/main/java/mascot/parameterdynamics/LogLinearGLM.java index ffc91c9..888800b 100644 --- a/src/main/java/mascot/parameterdynamics/LogLinearGLM.java +++ b/src/main/java/mascot/parameterdynamics/LogLinearGLM.java @@ -2,13 +2,17 @@ import beast.base.core.Input; import beast.base.core.Input.Validate; +import beast.base.inference.StateNode; +import beast.base.inference.StateNodeInitialiser; import beast.base.spec.inference.parameter.BoolVectorParam; import beast.base.spec.domain.Real; import beast.base.spec.inference.parameter.RealScalarParam; import beast.base.spec.inference.parameter.RealVectorParam; import mascot.glmmodel.CovariateList; -public class LogLinearGLM extends NeDynamics { +import java.util.List; + +public class LogLinearGLM extends NeDynamics implements StateNodeInitialiser { public Input covariateListInput = new Input<>("covariateList", "input of covariates", Validate.REQUIRED); public Input> scalerInput = new Input<>("scaler", "input of covariates scaler", Validate.REQUIRED); public Input indicatorInput = new Input<>("indicator", "input of covariates scaler", Validate.REQUIRED); @@ -131,5 +135,15 @@ public void restore() { valuesKnown = false; } + @Override + public void initStateNodes() { + } + + @Override + public void getInitialisedStateNodes(List stateNodes) { + stateNodes.add(scalerInput.get()); + stateNodes.add(indicatorInput.get()); + if (errorInput.get() != null) stateNodes.add(errorInput.get()); + } } diff --git a/src/main/java/mascot/parameterdynamics/LogisticNe.java b/src/main/java/mascot/parameterdynamics/LogisticNe.java index 7bbec9b..cccc195 100644 --- a/src/main/java/mascot/parameterdynamics/LogisticNe.java +++ b/src/main/java/mascot/parameterdynamics/LogisticNe.java @@ -3,23 +3,23 @@ import beast.base.core.Input; import beast.base.core.Input.Validate; import beast.base.spec.domain.Real; -import beast.base.spec.inference.parameter.RealScalarParam; +import beast.base.spec.type.RealScalar; public class LogisticNe extends NeDynamics { - public Input> carryingProportionInput = new Input<>( + public Input> carryingProportionInput = new Input<>( "carryingProportion", "the proportion of the current Ne of the maximial Ne (capactity)", Validate.REQUIRED); - public Input> capacityInput = new Input<>( + public Input> capacityInput = new Input<>( "capacity", "input of the maximal Ne", Validate.REQUIRED); - public Input> growthRateInput = new Input<>( + public Input> growthRateInput = new Input<>( "growthRate", "input of the growth rate", Validate.REQUIRED); public Input minNeInput = new Input<>( "minNe", "input of the minimal Ne", 0.0); - RealScalarParam cP; - RealScalarParam capacity; - RealScalarParam growthRate; + RealScalar cP; + RealScalar capacity; + RealScalar growthRate; @Override public void initAndValidate() { @@ -43,16 +43,7 @@ public double getNeTime(double t) { @Override public boolean isDirty() { - if (cP.somethingIsDirty()) - return true; - - if (capacity.somethingIsDirty()) - return true; - - if (growthRate.somethingIsDirty()) - return true; - - return false; + return isDirtyInput(cP) || isDirtyInput(capacity) || isDirtyInput(growthRate); } } diff --git a/src/main/java/mascot/parameterdynamics/NeDynamics.java b/src/main/java/mascot/parameterdynamics/NeDynamics.java index 9d538bf..c970a0d 100644 --- a/src/main/java/mascot/parameterdynamics/NeDynamics.java +++ b/src/main/java/mascot/parameterdynamics/NeDynamics.java @@ -1,16 +1,17 @@ package mascot.parameterdynamics; import beast.base.inference.CalculationNode; +import beast.base.inference.StateNode; public abstract class NeDynamics extends CalculationNode { - + public boolean isTime; - + @Override public void initAndValidate() { } - + /** * recalculate the dynamics */ @@ -24,7 +25,7 @@ public void initAndValidate() { public double getNeTime(double t) { throw new IllegalArgumentException("Function not implemented. Class of parametric function not correctly recognized"); } - + /** * returns the effective population size at interval i @@ -34,13 +35,26 @@ public double getNeTime(double t) { public double getNeInterval(int i) { throw new IllegalArgumentException("Function not implemented. Class of parametric function not correctly recognized"); } - + public void setNrIntervals(int intervals) {} - + public boolean isDirty() { return true; }; - - + + /** + * True if the value bound by {@code input} has been marked dirty by the + * MCMC framework. Spec interface inputs (RealScalar, RealVector, …) + * don't expose dirtiness directly — they may be backed either by a + * {@link StateNode} (proposal-dirty signal) or by an upstream + * {@link CalculationNode} (recalculation-dirty signal). Subclasses use + * this helper inside {@link #isDirty()} so concrete-spec-param Inputs + * can be replaced by the interface without losing the dirty check. + */ + protected static boolean isDirtyInput(Object input) { + if (input instanceof StateNode s && s.somethingIsDirty()) return true; + if (input instanceof CalculationNode c && c.isDirtyCalculation()) return true; + return false; + } } diff --git a/src/main/java/mascot/parameterdynamics/NeSplineInterpolation.java b/src/main/java/mascot/parameterdynamics/NeSplineInterpolation.java index 7407a5e..39b7368 100644 --- a/src/main/java/mascot/parameterdynamics/NeSplineInterpolation.java +++ b/src/main/java/mascot/parameterdynamics/NeSplineInterpolation.java @@ -2,18 +2,22 @@ import beast.base.core.Description; import beast.base.core.Input; +import beast.base.inference.StateNode; +import beast.base.inference.StateNodeInitialiser; import beast.base.spec.domain.Real; import beast.base.spec.inference.parameter.RealVectorParam; import mascot.dynamics.RateShifts; import org.apache.commons.math4.legacy.analysis.interpolation.SplineInterpolator; import org.apache.commons.math4.legacy.analysis.polynomials.PolynomialSplineFunction; +import java.util.List; + /** * @author Nicola F. Mueller */ @Description("Populaiton function with values at certain time points that are interpolated in between. Parameter has to be in log space") -public class NeSplineInterpolation extends NeDynamics { +public class NeSplineInterpolation extends NeDynamics implements StateNodeInitialiser { final public Input> NeInput = new Input<>("logNe", "Nes over time in log space", Input.Validate.REQUIRED); @@ -116,10 +120,17 @@ public void recalculate() { public boolean isDirty() { if (Ne.isDirty(0)) return true; - + return false; } + @Override + public void initStateNodes() { + } + + @Override + public void getInitialisedStateNodes(List stateNodes) { + stateNodes.add(NeInput.get()); + } - } \ No newline at end of file diff --git a/src/main/java/mascot/parameterdynamics/Skygrowth.java b/src/main/java/mascot/parameterdynamics/Skygrowth.java index 529798e..7fbd8fc 100644 --- a/src/main/java/mascot/parameterdynamics/Skygrowth.java +++ b/src/main/java/mascot/parameterdynamics/Skygrowth.java @@ -2,16 +2,20 @@ import beast.base.core.Description; import beast.base.core.Input; +import beast.base.inference.StateNode; +import beast.base.inference.StateNodeInitialiser; import beast.base.spec.domain.Real; import beast.base.spec.inference.parameter.RealVectorParam; import mascot.dynamics.RateShifts; +import java.util.List; + /** * @author Nicola F. Mueller */ @Description("Populaiton function with values at certain time points that are interpolated in between. Parameter has to be in log space") -public class Skygrowth extends NeDynamics { +public class Skygrowth extends NeDynamics implements StateNodeInitialiser { final public Input> NeInput = new Input<>("logNe", "Nes over time in log space", Input.Validate.REQUIRED); @@ -112,10 +116,20 @@ public void recalculate() { public boolean isDirty() { if (Ne.isDirty(0)) return true; - + return false; } + // Dimension fixing is already done in initAndValidate; nothing further + // needed here. The interface is implemented so the framework knows this + // class owns the dimension of NeInput. + @Override + public void initStateNodes() { + } + + @Override + public void getInitialisedStateNodes(List stateNodes) { + stateNodes.add(NeInput.get()); + } - } \ No newline at end of file diff --git a/src/main/java/mascot/parameterdynamics/StructuredSkygrid.java b/src/main/java/mascot/parameterdynamics/StructuredSkygrid.java index 35af55a..95a6d45 100644 --- a/src/main/java/mascot/parameterdynamics/StructuredSkygrid.java +++ b/src/main/java/mascot/parameterdynamics/StructuredSkygrid.java @@ -3,12 +3,16 @@ import beast.base.core.Description; import beast.base.core.Input; import beast.base.core.Input.Validate; +import beast.base.inference.StateNode; +import beast.base.inference.StateNodeInitialiser; import beast.base.spec.domain.Real; import beast.base.spec.inference.parameter.RealVectorParam; +import java.util.List; + @Description("Skygrid style dynamics for MASCOT. These assume constant effective population sizes within one state "+ "and that these effective population sizes only change within one interval") -public class StructuredSkygrid extends NeDynamics { +public class StructuredSkygrid extends NeDynamics implements StateNodeInitialiser { public Input> NeLogInput = new Input<>( "NeLog", "input of the log effective population sizes", Validate.REQUIRED); @@ -44,7 +48,16 @@ public boolean isDirty() { if(NeLogInput.get().isDirty(i)) return true; - return false; + return false; + } + + @Override + public void initStateNodes() { + } + + @Override + public void getInitialisedStateNodes(List stateNodes) { + stateNodes.add(NeLogInput.get()); } } diff --git a/src/main/java/mascot/skyline/GLMPrior.java b/src/main/java/mascot/skyline/GLMPrior.java index 2f0d723..e420eaa 100644 --- a/src/main/java/mascot/skyline/GLMPrior.java +++ b/src/main/java/mascot/skyline/GLMPrior.java @@ -4,7 +4,9 @@ import beast.base.core.Input; import beast.base.inference.Distribution; import beast.base.inference.State; -import beast.base.inference.distribution.ParametricDistribution; +import beast.base.inference.StateNode; +import beast.base.inference.StateNodeInitialiser; +import beast.base.spec.inference.distribution.ScalarDistribution; import beast.base.spec.inference.parameter.BoolVectorParam; import beast.base.spec.domain.Real; import beast.base.spec.inference.parameter.RealVectorParam; @@ -22,7 +24,7 @@ " to enhance phylogeographic reconstructions\n"+ " PLOS Computational Biology 21(9):e1013421,\n"+ " https://doi.org/10.1371/journal.pcbi.1013421") -public class GLMPrior extends Distribution { +public class GLMPrior extends Distribution implements StateNodeInitialiser { public Input covariateListInput = new Input<>("covariateList", "input of covariates", Input.Validate.REQUIRED); public Input> scalerInput = new Input<>("scaler", "input of covariates scaler", Input.Validate.REQUIRED); @@ -33,11 +35,11 @@ public class GLMPrior extends Distribution { public Input NeFunctionInput = new Input<>( "NeDynamics", "input of the log effective population sizes", Input.Validate.REQUIRED); - final public Input distInput = new Input<>("distr", + final public Input> distInput = new Input<>("distr", "distribution used to on the error terms of the GLM.", Input.Validate.REQUIRED); - protected ParametricDistribution dist; + protected ScalarDistribution dist; double[] intTimes; @@ -167,4 +169,14 @@ public List getConditions() { public void sample(State state, Random random) { } + + @Override + public void initStateNodes() { + } + + @Override + public void getInitialisedStateNodes(List stateNodes) { + stateNodes.add(scalerInput.get()); + stateNodes.add(indicatorInput.get()); + } } diff --git a/src/main/java/mascot/skyline/GrowthRateSmoothingPrior.java b/src/main/java/mascot/skyline/GrowthRateSmoothingPrior.java index b9a0ee6..b62b4ab 100644 --- a/src/main/java/mascot/skyline/GrowthRateSmoothingPrior.java +++ b/src/main/java/mascot/skyline/GrowthRateSmoothingPrior.java @@ -5,9 +5,9 @@ import beast.base.core.Input.Validate; import beast.base.inference.Distribution; import beast.base.inference.State; -import beast.base.inference.distribution.ParametricDistribution; import beast.base.spec.domain.Real; -import beast.base.spec.inference.parameter.RealVectorParam; +import beast.base.spec.inference.distribution.ScalarDistribution; +import beast.base.spec.type.RealVector; import mascot.dynamics.RateShifts; import java.util.List; @@ -19,54 +19,53 @@ " PLOS Computational Biology 21(9):e1013421,\n"+ " https://doi.org/10.1371/journal.pcbi.1013421") public class GrowthRateSmoothingPrior extends Distribution { - - public Input> NeLogInput = new Input<>( - "NeLog", "input of effective population sizes"); - - final public Input distInput = new Input<>("distr", - "distribution used to calculate prior on the difference between intervals, e.g. normal, beta, gamma.", + + public Input> NeLogInput = new Input<>( + "NeLog", "input of effective population sizes"); + + final public Input> distInput = new Input<>("distr", + "distribution used to calculate prior on the difference between intervals, e.g. normal, beta, gamma.", Validate.REQUIRED); - - final public Input initDistrInput = new Input<>("initialDistr", + + final public Input> initDistrInput = new Input<>("initialDistr", "distribution used to calculate prior on the difference between intervals, e.g. normal, beta, gamma.", Input.Validate.OPTIONAL); - - final public Input finalDistrInput = new Input<>("finalDistr", - "distribution used to calculate prior on the difference between intervals, e.g. normal, beta, gamma.", + + final public Input> finalDistrInput = new Input<>("finalDistr", + "distribution used to calculate prior on the difference between intervals, e.g. normal, beta, gamma.", Input.Validate.OPTIONAL); - - final public Input meanDistrInput = new Input<>("meanDistr", - "distribution used to calculate prior on the difference between intervals, e.g. normal, beta, gamma.", + + final public Input> meanDistrInput = new Input<>("meanDistr", + "distribution used to calculate prior on the difference between intervals, e.g. normal, beta, gamma.", Input.Validate.OPTIONAL); - + public Input rateShiftsInput = new Input<>( - "rateShifts", "timing of the rate shifts", Validate.REQUIRED); - - - - - private RealVectorParam NeLog; - - protected ParametricDistribution dist; - protected ParametricDistribution initDistr; - protected ParametricDistribution finalDistr; - protected ParametricDistribution meanDistr; - + "rateShifts", "timing of the rate shifts", Validate.REQUIRED); + + + + private RealVector NeLog; + + protected ScalarDistribution dist; + protected ScalarDistribution initDistr; + protected ScalarDistribution finalDistr; + protected ScalarDistribution meanDistr; + RateShifts rateShifts; - + @Override public void initAndValidate() { - NeLog = NeLogInput.get(); + NeLog = NeLogInput.get(); dist = distInput.get(); rateShifts = rateShiftsInput.get(); - + if (initDistrInput.get()!=null) initDistr = initDistrInput.get(); if (finalDistrInput.get()!=null) finalDistr = finalDistrInput.get(); if (meanDistrInput.get()!=null) meanDistr = meanDistrInput.get(); - + } @@ -86,44 +85,44 @@ public List getConditions() { @Override public void sample(State state, Random random) { // TODO Auto-generated method stub - + } - + public double calculateLogP() { logP = 0; - + double[] growthRates = new double[NeLog.size()-1]; //loop over all time points for (int j = 1; j < NeLog.size(); j++){ double timediff = rateShifts.getValue(j) - rateShifts.getValue(j-1); - double logdiff = NeLog.get(j) - NeLog.get(j-1); - growthRates[j-1] = logdiff/timediff; + double logdiff = NeLog.get(j) - NeLog.get(j-1); + growthRates[j-1] = logdiff/timediff; } - - + + for (int j = 1; j < growthRates.length; j++){ double diff = growthRates[j]-growthRates[j-1]; logP += dist.logDensity(diff); } - + // add contribution from first or last entry if (initDistrInput.get()!=null) logP += initDistr.logDensity(growthRates[0]); if (finalDistrInput.get()!=null) { logP += finalDistr.logDensity(growthRates[growthRates.length-1]); } - + if (meanDistrInput.get()!=null) { double mean=0.0; for (int j = 0; j < growthRates.length; j++){ - mean += growthRates[j]; + mean += growthRates[j]; } mean /= growthRates.length; logP += meanDistr.logDensity(mean); } - + return logP; } } diff --git a/src/main/java/mascot/skyline/LogSmoothingPrior.java b/src/main/java/mascot/skyline/LogSmoothingPrior.java index aefde17..7b17cb0 100644 --- a/src/main/java/mascot/skyline/LogSmoothingPrior.java +++ b/src/main/java/mascot/skyline/LogSmoothingPrior.java @@ -5,9 +5,9 @@ import beast.base.core.Input.Validate; import beast.base.inference.Distribution; import beast.base.inference.State; -import beast.base.inference.distribution.ParametricDistribution; import beast.base.spec.domain.Real; -import beast.base.spec.inference.parameter.RealVectorParam; +import beast.base.spec.inference.distribution.ScalarDistribution; +import beast.base.spec.type.RealVector; import java.util.List; import java.util.Random; @@ -18,39 +18,39 @@ " PLOS Computational Biology 21(9):e1013421,\n"+ " https://doi.org/10.1371/journal.pcbi.1013421") public class LogSmoothingPrior extends Distribution { - - public Input> NeLogInput = new Input<>( - "NeLog", "input of effective population sizes"); - - final public Input distInput = new Input<>("distr", - "distribution used to calculate prior on the difference between intervals, e.g. normal, beta, gamma.", + + public Input> NeLogInput = new Input<>( + "NeLog", "input of effective population sizes"); + + final public Input> distInput = new Input<>("distr", + "distribution used to calculate prior on the difference between intervals, e.g. normal, beta, gamma.", Validate.REQUIRED); - - final public Input initDistrInput = new Input<>("initialDistr", + + final public Input> initDistrInput = new Input<>("initialDistr", + "distribution used to calculate prior on the difference between intervals, e.g. normal, beta, gamma.", + Input.Validate.OPTIONAL); + + final public Input> finalDistrInput = new Input<>("finalDistr", "distribution used to calculate prior on the difference between intervals, e.g. normal, beta, gamma.", Input.Validate.OPTIONAL); - - final public Input finalDistrInput = new Input<>("finalDistr", - "distribution used to calculate prior on the difference between intervals, e.g. normal, beta, gamma.", + + final public Input> meanDistrInput = new Input<>("meanDistr", + "distribution used to calculate prior on the difference between intervals, e.g. normal, beta, gamma.", Input.Validate.OPTIONAL); - - final public Input meanDistrInput = new Input<>("meanDistr", - "distribution used to calculate prior on the difference between intervals, e.g. normal, beta, gamma.", - Input.Validate.OPTIONAL); - - private RealVectorParam NeLog; - - protected ParametricDistribution dist; - protected ParametricDistribution initDistr; - protected ParametricDistribution finalDistr; - protected ParametricDistribution meanDistr; - + + private RealVector NeLog; + + protected ScalarDistribution dist; + protected ScalarDistribution initDistr; + protected ScalarDistribution finalDistr; + protected ScalarDistribution meanDistr; + @Override public void initAndValidate() { - NeLog = NeLogInput.get(); + NeLog = NeLogInput.get(); dist = distInput.get(); - - + + if (initDistrInput.get()!=null) initDistr = initDistrInput.get(); if (finalDistrInput.get()!=null) @@ -77,14 +77,14 @@ public List getConditions() { @Override public void sample(State state, Random random) { // TODO Auto-generated method stub - + } - + public double calculateLogP() { logP = 0; - - - + + + //loop over all time points for (int j = 1; j < NeLog.size(); j++){ double diff = NeLog.get(j) - NeLog.get(j-1); @@ -106,7 +106,7 @@ public double calculateLogP() { mean /= NeLog.size(); logP += meanDistr.logDensity(mean); } - + return logP; } } diff --git a/src/main/java/mascot/util/LargerThan.java b/src/main/java/mascot/util/LargerThan.java index eb65dab..367c785 100644 --- a/src/main/java/mascot/util/LargerThan.java +++ b/src/main/java/mascot/util/LargerThan.java @@ -5,16 +5,17 @@ import java.util.List; import java.util.Random; import beast.base.core.Description; -import beast.base.core.Function; import beast.base.core.Input; import beast.base.core.Input.Validate; import beast.base.inference.*; +import beast.base.spec.domain.Real; +import beast.base.spec.type.RealVector; @Description("returns 0 if condition is met and negative infinity if not") public class LargerThan extends Distribution { - final public Input largerInput = new Input<>("larger", "argument for which the differences for entries is calculated", Validate.REQUIRED); - final public Input smallerInput = new Input<>("smaller", "argument for which the differences for entries is calculated", Validate.REQUIRED); + final public Input> largerInput = new Input<>("larger", "argument for which the differences for entries is calculated", Validate.REQUIRED); + final public Input> smallerInput = new Input<>("smaller", "argument for which the differences for entries is calculated", Validate.REQUIRED); @Override @@ -24,10 +25,10 @@ public void initAndValidate() { @Override public double calculateLogP() { - Function larger = largerInput.get(); - Function smaller = smallerInput.get(); - for (int i = 0; i < larger.getDimension(); i++) { - if (larger.getArrayValue(i) <= smaller.getArrayValue(i)) { + RealVector larger = largerInput.get(); + RealVector smaller = smallerInput.get(); + for (int i = 0; i < larger.size(); i++) { + if (larger.get(i) <= smaller.get(i)) { logP = Double.NEGATIVE_INFINITY; return Double.NEGATIVE_INFINITY; } @@ -40,7 +41,7 @@ public double calculateLogP() { @Override public void sample(State state, Random random) { // do nothing - + }