Blang

Examples

We start with a simple model for Markov chains:

package blang.validation.internals.fixtures

// A finite-state Markov chain: the first state is drawn from initialDistribution,
// and each subsequent state is drawn from the row of transitionProbabilities
// indexed by the previous state.
model MarkovChain {
  param Simplex initialDistribution
  param TransitionMatrix transitionProbabilities
  random List<IntVar> chain

  laws {
    // Initial distribution:
    chain.get(0) | initialDistribution ~ Categorical(initialDistribution)

    // Transitions:
    for (int step : 1 ..< chain.size) {
      // NOTE: "IntVar previous = ..." is the pre-computation construct; it restricts
      // the scope of this factor to a single neighbouring state (see text below).
      chain.get(step) |
          IntVar previous = chain.get(step - 1),
          transitionProbabilities
        ~ Categorical(
            // Guard against out-of-support values of the previous state; an invalid
            // index falls back to row 0 (the Categorical factor itself will assign
            // the appropriate probability to invalid configurations).
            if (previous >= 0 && previous < transitionProbabilities.nRows)
              transitionProbabilities.row(previous)
            else
              transitionProbabilities.row(0)
          )
    }
  }
}

Notice that we used the pre-computation construct, namely

IntVar previous = chain.get(step - 1)

Accessing an array element is not particularly expensive, so you may wonder why we bothered pre-computing this. It turns out there is actually an important speed gain to be made, of the order of the length of the chain. Why?

To understand, we need to outline how the blang inference engines work under the hood. Most of these engines exploit cases where conditional distributions only depend on subsets of the variables. To do so, blang inspects the model constituents (for example the Categorical constituents) to infer what the scope of each constituent is. A scope is simply the subset of the variables available at a given location of the code (e.g. the code in one function cannot access the local variables declared in another function; they are out of scope). So coming back to the Markov chain example, this means that by passing in the precomputed chain.get(step - 1) rather than all the latent variables, we make it possible for blang engines to infer that each time step in the HMM only has dependencies on the previous and next states rather than on all states. In graphical model parlance, this means sparsity patterns in the graphical model are inferred.

Let us now look at how we can use a Markov chain as a building block for an HMM:

package blang.validation.internals.fixtures

// An HMM with Gaussian emissions: a latent Markov chain of discrete states, where
// each state indexes the mean and variance of a Normal emission distribution.
// The "?:" construct supplies a default latent initialization when a variable is
// not provided by the user.
model DynamicNormalMixture {
  param int nLatentStates
  random List<RealVar> observations
  random List<IntVar> states ?: latentIntList(observations.size)
  random DenseSimplex initialDistribution ?: latentSimplex(nLatentStates)
  random DenseTransitionMatrix transitionProbabilities ?: latentTransitionMatrix(nLatentStates)
  random List<RealVar> means ?: latentRealList(nLatentStates),
         variances ?: latentRealList(nLatentStates)
  // Symmetric Dirichlet concentrations, one per latent state.
  param Matrix concentrations ?: ones(nLatentStates).readOnlyView

  laws {
    // Priors on initial and transition probabilities
    initialDistribution | concentrations ~ Dirichlet(concentrations)
    for (int latentStateIdx : 0 ..< means.size) {
      transitionProbabilities.row(latentStateIdx) | concentrations ~ Dirichlet(concentrations)
    }

    // Priors on means and variances
    for (int latentStateIdx : 0 ..< means.size) {
      means.get(latentStateIdx) ~ Normal(0.0, 1.0)
      variances.get(latentStateIdx) ~ Gamma(1.0, 1.0)
    }

    // The MarkovChain model defined earlier is reused here as a building block.
    states | initialDistribution, transitionProbabilities
      ~ MarkovChain(initialDistribution, transitionProbabilities)

    // Gaussian emissions
    for (int obsIdx : 0 ..< observations.size) {
      // Pre-computing states.get(obsIdx) narrows this factor's scope to one state.
      observations.get(obsIdx) |
          means, variances,
          IntVar curIndic = states.get(obsIdx)
        ~ Normal(means.get(curIndic), variances.get(curIndic))
    }
  }
}

Undirected graphical models (AKA Markov random fields) are supported: here is, for example, how a square Ising model is implemented:

package blang.validation.internals.fixtures

import briefj.collections.UnorderedPair
import static blang.validation.internals.fixtures.Functions.squareIsingEdges

// A square-lattice Ising model over N*N binary spins, expressed as an undirected
// graphical model with pairwise LogPotential factors and Bernoulli node potentials.
model Ising {
  param Double beta ?: log(1 + sqrt(2.0)) / 2.0 // critical point
  param Integer N ?: 5
  random List<IntVar> vertices ?: latentIntList(N*N)

  laws {
    // Pairwise potentials
    for (UnorderedPair<Integer, Integer> pair : squareIsingEdges(N)) {
      // Anonymous factor (no variable on the left of "|"): both endpoints of the
      // edge enter the scope via the pre-computation construct.
      | IntVar first = vertices.get(pair.getFirst),
        IntVar second = vertices.get(pair.getSecond),
        beta
      ~ LogPotential(
          // Spins outside {0, 1} are given zero probability; otherwise spins are
          // mapped from {0, 1} to {-1, +1} via (2*s - 1) for the standard coupling.
          if ((first < 0 || first > 1 || second < 0 || second > 1))
            return NEGATIVE_INFINITY
          else
            return beta*(2*first-1)*(2*second-1))
    }
    // Node potentials
    for (IntVar vertex : vertices) {
      vertex ~ Bernoulli(0.5)
    }
  }
}

Next, we show how to build a spike-and-slab prior. First, we create a custom data type:

package blang.validation.internals.fixtures

import org.eclipse.xtend.lib.annotations.Data
import blang.core.RealVar
import java.util.List
import xlinear.Matrix
import blang.core.WritableRealVar
import blang.types.StaticUtils
import blang.core.WritableIntVar
import blang.mcmc.Samplers

// A real-valued random variable with an atom at zero: a binary indicator isZero
// selects between exactly 0.0 and a continuous value realPart.
// The @Samplers annotation registers the custom MCMC kernel defined below.
@Data
@Samplers(SpikedRealVarSampler)
class SpikedRealVar implements RealVar {
  public val WritableRealVar realPart
  public val WritableIntVar isZero

  new() {
    realPart = StaticUtils::latentReal
    isZero = StaticUtils::latentInt
  }

  // The observed value: 0.0 when the indicator is set, realPart otherwise.
  override doubleValue() {
    if (isZero == 1) 0.0 else realPart.doubleValue
  }

  // Xtend operator overloading: dot product of a list of SpikedRealVar with a
  // vector, enabling the "coefficients * predictors" syntax used later.
  def public static double *(List<SpikedRealVar> vars, Matrix vector){
    var sum = 0.0
    if (vars.size != vector.nEntries) {
      throw new RuntimeException
    }
    for (int i : 0 ..< vars.size) {
      sum += vector.get(i) * vars.get(i).doubleValue
    }
    return sum
  }

  override String toString() {
    return Double.toString(doubleValue)
  }
}

Then a corresponding sampler:

package blang.validation.internals.fixtures;

import java.util.List;

import bayonet.distributions.Random;
import blang.core.Constrained;
import blang.core.LogScaleFactor;
import blang.mcmc.ConnectedFactor;
import blang.mcmc.IntSliceSampler;
import blang.mcmc.RealSliceSampler;
import blang.mcmc.SampledVariable;
import blang.mcmc.Sampler;
import blang.mcmc.internals.SamplerBuilderContext;

import static blang.types.ExtensionUtils.*;

/**
 * Custom MCMC kernel for {@link SpikedRealVar}: alternates a real slice-sampling
 * move on the continuous part with an integer slice-sampling move on the
 * zero/non-zero indicator. Fields annotated with {@code @SampledVariable} and
 * {@code @ConnectedFactor} are injected by the blang sampler-building machinery.
 */
public class SpikedRealVarSampler implements Sampler {
  @SampledVariable SpikedRealVar variable;
  @ConnectedFactor Constrained constrained;
  @ConnectedFactor List<LogScaleFactor> numericFactors;

  RealSliceSampler sliceSampler;
  IntSliceSampler intSampler;

  @Override
  public void execute(Random rand) {
    // When the variable currently sits on the spike, the continuous part does not
    // affect the target density, so only the indicator move is performed.
    boolean onSpike = asBool(variable.isZero.intValue());
    if (!onSpike) {
      sliceSampler.execute(rand);
    }
    intSampler.execute(rand);
  }

  @Override
  public boolean setup(SamplerBuilderContext context) {
    // Both sub-samplers share the same connected numeric factors.
    sliceSampler = RealSliceSampler.build(variable.realPart, numericFactors);
    intSampler = IntSliceSampler.build(variable.isZero, numericFactors);
    return true; // setup always succeeds for this sampler
  }
}

We can now define a distribution for this type:

package blang.validation.internals.fixtures

// Spike-and-slab prior over a list of SpikedRealVar: each variable is exactly
// zero with probability zeroProbability, and otherwise follows nonZeroLogDensity.
model SpikeAndSlab {
  random List<SpikedRealVar> variables
  param RealVar zeroProbability
  param RealDistribution nonZeroLogDensity

  laws {
    for (int index : 0 ..< variables.size) {
      // Mixture log-density: log p(spike) at zero, log p(slab) * density elsewhere.
      logf(zeroProbability, nonZeroLogDensity, RealVar variable = variables.get(index)) {
        if (zeroProbability < 0.0 || zeroProbability > 1.0)
          return NEGATIVE_INFINITY
        // Xtend expression semantics: each branch's last expression is the value.
        if (variable == 0.0) {
          log(zeroProbability)
        } else {
          log(1.0 - zeroProbability) + nonZeroLogDensity.logDensity(variable)
        }
      }
      // Constrain the indicator to the support {0, 1}.
      indicator(SpikedRealVar variable = variables.get(index)) {
        variable.isZero.isBool
      }
      // Mark the variable as requiring its custom sampler (SpikedRealVarSampler).
      variables.get(index) is Constrained
    }
  }

  // Forward simulation, used e.g. for testing and generating synthetic data.
  generate(rand) {
    for (SpikedRealVar variable : variables) {
      variable.isZero.set(Generators::bernoulli(rand, zeroProbability).asInt)
      // NOTE(review): realPart is sampled even when isZero is set; it is then
      // ignored by doubleValue — confirm this matches the intended generative process.
      variable.realPart.set(nonZeroLogDensity.sample(rand))
    }
  }
}

Afterwards, it is easy to incorporate Spike and Slab priors in more complicated models, for example a naive GLM here:

package blang.validation.internals.fixtures

// Brings the overloaded "*" (dot product) operator of SpikedRealVar into scope.
import static extension blang.validation.internals.fixtures.SpikedRealVar.*

// A logistic regression with a spike-and-slab prior on the coefficients:
// binary outputs, design matrix of predictors, and sparsity induced by the
// shared zeroProbability parameter.
model SpikedGLM {
  param Matrix designMatrix // n by p
  random List<IntVar> output // size n
  random List<SpikedRealVar> coefficients ?: {
    // Default initialization: one latent SpikedRealVar per column of the design matrix.
    val p = designMatrix.nCols
    return new ArrayList(p) => [
      for (int i : 0 ..< p)
        add(new SpikedRealVar)
    ]
  }
  random RealVar zeroProbability ?: latentReal

  laws {
    zeroProbability ~ Beta(1,1)
    coefficients | zeroProbability
      ~ SpikeAndSlab(zeroProbability, Normal::distribution(0, 1))
    for (int index : 0 ..< output.size) {
      // "coefficients * predictors" uses the static extension operator above.
      output.get(index) |
          coefficients,
          Matrix predictors = designMatrix.row(index)
        ~ Bernoulli(logistic(coefficients * predictors))
    }
  }
}