PyTorch defines two components for probabilistic modeling: the Distribution and the Transform. A distribution object represents the probability distribution $p(X)$ of a random variable $X$. A distribution must implement the sample and log_prob methods, meaning that we can draw realizations $x \sim p(X)$ from the distribution and evaluate the log-likelihood $\log p(X = x)$ of realizations.
    import torch

    distribution = torch.distributions.Normal(torch.tensor(0.0), torch.tensor(1.0))
    x = distribution.sample()  # x ~ p(X)
    log_p = distribution.log_prob(x)  # log p(X = x)
    x, log_p
Output:
(tensor(1.5410), tensor(-2.1063))
A transform object represents a bijective transformation $f: X \mapsto Y$ from a domain to a co-domain. A transformation must implement a forward call $y = f(x)$, an inverse call $x = f^{-1}(y)$ and the log_abs_det_jacobian method to compute the log-absolute-determinant of the transformation's Jacobian $\log \left| \det \frac{\partial f(x)}{\partial x} \right|$.
Combining a base distribution $p(Z)$ and a transformation $f: X \mapsto Z$ defines a new distribution $p(X)$. The likelihood is given by the change of random variables formula:

$$ p(X = x) = p(Z = f(x)) \left| \det \frac{\partial f(x)}{\partial x} \right| $$

Sampling from $p(X)$ can be performed by first drawing realizations $z \sim p(Z)$ and then applying the inverse transformation $x = f^{-1}(z)$. Such a combination of a base distribution and a bijective transformation is sometimes called a normalizing flow. The term normalizing refers to the fact that the base distribution is often a (standard) normal distribution.
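The change of variables formula can be checked numerically with PyTorch's own Distribution and Transform objects. The snippet below is a minimal sketch, assuming an affine map $f(x) = (x - \mu) / \sigma$ with illustrative values for `mu` and `sigma` (these are not taken from the text above). With a standard normal base, the resulting density should coincide with that of $\mathcal{N}(\mu, \sigma^2)$.

```python
import torch
from torch.distributions import Normal
from torch.distributions.transforms import AffineTransform

mu, sigma = 2.0, 0.5  # illustrative values, chosen arbitrarily

# Base distribution p(Z): standard normal.
base = Normal(torch.tensor(0.0), torch.tensor(1.0))

# f maps X to Z: z = (x - mu) / sigma = loc + scale * x.
f = AffineTransform(loc=-mu / sigma, scale=1 / sigma)

# Sampling: draw z ~ p(Z), then apply the inverse x = f^{-1}(z).
z = base.sample((4,))
x = f.inv(z)

# Likelihood: log p(X = x) = log p(Z = f(x)) + log |det df(x)/dx|.
log_p = base.log_prob(f(x)) + f.log_abs_det_jacobian(x, f(x))

# Closed-form reference: the density of N(mu, sigma^2).
expected = Normal(torch.tensor(mu), torch.tensor(sigma)).log_prob(x)
print(torch.allclose(log_p, expected, atol=1e-5))
```

Note that the transformation here maps samples to the base space, which is the convention used in this text. PyTorch's TransformedDistribution uses the opposite direction (base to samples).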
When designing the distributions module, the PyTorch team decided that distributions and transformations should be lightweight objects that are used as part of computations but destroyed afterwards. Consequently, the Distribution and Transform classes are not subclasses of torch.nn.Module, which means that we cannot retrieve their parameters with .parameters(), send their internal tensors to GPU with .cuda() or train them as regular neural networks. In addition, the concepts of conditional distribution and transformation, which are essential for probabilistic inference, are impossible to express with the current interface.

To solve these problems, Zuko defines two concepts: the LazyDistribution and the LazyTransform, which are modules whose forward pass returns a distribution or transformation, respectively. These components hold the parameters of the distributions/transformations as well as the recipe to build them. This way, the actual distribution/transformation objects are lazily constructed and destroyed when necessary. Importantly, because the creation of the distribution/transformation object is delayed, a condition, if there is one, can easily be taken into account. This design enables lazy distributions to act like distributions while retaining features inherent to modules, such as trainable parameters.
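The idea can be illustrated without Zuko at all. The following is a minimal sketch in plain PyTorch, where `LazyNormal` is a hypothetical name introduced for illustration and is not part of either library. The module owns the parameters and its forward pass builds a fresh Normal from their current values.

```python
import torch
from torch.distributions import Normal


class LazyNormal(torch.nn.Module):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        # Trainable parameters live in the module, not in the distribution.
        self.loc = torch.nn.Parameter(torch.zeros(()))
        self.log_scale = torch.nn.Parameter(torch.zeros(()))

    def forward(self) -> Normal:
        # The distribution object is built on demand from the current parameters.
        return Normal(self.loc, self.log_scale.exp())


lazy = LazyNormal()

# A plain distribution is not a module, but the lazy wrapper is.
print(isinstance(Normal(0.0, 1.0), torch.nn.Module))
print(len(list(lazy.parameters())))

dist = lazy()  # build the distribution when it is needed
log_p = dist.log_prob(torch.tensor(0.5))
print(log_p.requires_grad)  # gradients can flow back to the module parameters
```

Because the wrapper is a regular module, the usual tools (.parameters(), .to(device), optimizers, state dicts) apply to it unchanged.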
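Let's say we have a dataset of pairs $(x, c) \sim p(X, C)$ and want to model the distribution of $X$ given $c$, that is $p(X \mid c)$. The goal of variational inference is to find the model $q_{\phi^*}(X \mid c)$ that is most similar to $p(X \mid c)$ among a family of (conditional) distributions $q_\phi(X \mid c)$ distinguished by their parameters $\phi$. Expressing the dissimilarity between two distributions as their Kullback-Leibler (KL) divergence, the variational inference objective becomes:

$$
\begin{align}
\phi^* &= \arg\min_\phi \, \mathrm{KL}\big( p(x, c) \,\|\, q_\phi(x \mid c) \, p(c) \big) \\
&= \arg\min_\phi \, \mathbb{E}_{p(x, c)} \left[ \log \frac{p(x, c)}{q_\phi(x \mid c) \, p(c)} \right] \\
&= \arg\min_\phi \, \mathbb{E}_{p(x, c)} \big[ -\log q_\phi(x \mid c) \big]
\end{align}
$$

The last step holds because $p(x, c)$ and $p(c)$ do not depend on $\phi$. For example, let $X$ be a standard Gaussian variable and $C$ be a vector of three unit-variance Gaussian variables $C_i$ centered at $X$.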
    x = torch.distributions.Normal(0, 1).sample((1024,))
    c = torch.distributions.Normal(x, 1).sample((3,)).T

    for i in range(3):
        print(x[i], c[i])
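A conditional Gaussian model for this data could look like the sketch below. This is an assumption rather than the tutorial's own cell: the class name `GaussianModel`, the layer sizes, and the use of a plain torch.nn.Module are all illustrative choices (in Zuko, such a model would presumably subclass LazyDistribution instead). A small hyper-network maps the context $c$ to the mean and log-standard-deviation of a Normal.

```python
import torch
from torch.distributions import Normal


class GaussianModel(torch.nn.Module):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        # Hyper-network mapping the 3 context features to (mu, log sigma).
        self.hyper = torch.nn.Sequential(
            torch.nn.Linear(3, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 2),
        )

    def forward(self, c: torch.Tensor) -> Normal:
        mu, log_sigma = self.hyper(c).unbind(dim=-1)
        return Normal(mu, log_sigma.exp())  # q(X | c)


model = GaussianModel()

# Same data-generating process as above.
x = torch.distributions.Normal(0, 1).sample((1024,))
c = torch.distributions.Normal(x, 1).sample((3,)).T

distribution = model(c=c[0])  # q(X | c[0])
sample = distribution.sample()  # a realization of X given c[0]
log_p = distribution.log_prob(x[0])  # log q(X = x[0] | c[0])
```

Predicting the log of the standard deviation and exponentiating it keeps the scale positive without constraining the network output.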
Calling the forward method of the model with a context c returns a distribution object, which we can use to draw realizations or evaluate the likelihood of realizations. In the code below, model(c=c[0]) calls the forward method as implemented above.
The result of log_prob is part of a computation graph (it has a grad_fn) and therefore it can be used to train the parameters of the model by variational inference. Importantly, when the parameters of the model are modified, for example due to a gradient descent step, you must remember to call the forward method again to re-build the distribution with the new parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for _ in range(64):
        loss = -model(c).log_prob(x).mean()  # E[-log q(x | c)]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
Following the same spirit, a parametric normalizing flow in Zuko is a special LazyDistribution that contains a LazyTransform and a base LazyDistribution. To increase expressivity, the transformation is usually the composition of a sequence of "simple" transformations

$$ f(x) = f_n \circ \dots \circ f_2 \circ f_1(x) $$

for which we can compute the determinant of the Jacobian as

$$ \det \frac{\partial f(x)}{\partial x} = \prod_{i = 1}^{n} \det \frac{\partial f_i(x_{i-1})}{\partial x_{i-1}} $$

where $x_0 = x$ and $x_i = f_i(x_{i-1})$.

In the univariate case, finding a bijective transformation whose determinant of the Jacobian is tractable is easy: any differentiable monotonic function works. In the multivariate case, the most common way to make the determinant easy to compute is to enforce a triangular Jacobian. This is achieved by a transformation $y = f(x)$ where each element $y_i$ is a monotonic function of $x_i$, conditioned on the preceding elements $x_{<i}$.

$$ y_i = f(x_i \mid x_{<i}) $$

Autoregressive and coupling transformations are notable examples of this class of transformations.
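To see why a triangular Jacobian is convenient, the sketch below builds a toy three-dimensional autoregressive map by hand (the particular shift and scale functions are arbitrary choices for illustration) and inspects its Jacobian with torch.autograd.functional.jacobian. The entries above the diagonal vanish, so the log-absolute-determinant reduces to a sum over the diagonal.

```python
import torch
from torch.autograd.functional import jacobian


def f(x: torch.Tensor) -> torch.Tensor:
    # Toy autoregressive map: y_i = shift(x_<i) + exp(log_scale(x_<i)) * x_i,
    # which is monotonic in x_i because the scale is always positive.
    y0 = x[0]
    y1 = torch.sin(x[0]) + torch.exp(0.5 * x[0]) * x[1]
    y2 = x[0] * x[1] + torch.exp(torch.tanh(x[0] + x[1])) * x[2]
    return torch.stack([y0, y1, y2])


x = torch.tensor([0.3, -1.2, 0.7])
J = jacobian(f, x)  # J[i, j] = dy_i / dx_j

# y_i does not depend on x_>i, so the entries above the diagonal are zero.
upper = torch.triu(J, diagonal=1)
print(upper)

# For a triangular matrix, the determinant is the product of the diagonal.
log_det_diag = torch.diagonal(J).abs().log().sum()
log_det_full = torch.linalg.slogdet(J).logabsdet
print(torch.allclose(log_det_diag, log_det_full, atol=1e-5))
```

In practice the diagonal terms are available in closed form from the univariate transformations, so the full Jacobian never needs to be materialized.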
    transform = zuko.flows.MaskedAutoregressiveTransform(
        features=5,
        context=0,  # no context
        univariate=zuko.transforms.MonotonicRQSTransform,  # rational-quadratic spline
        shapes=([8], [8], [7]),  # shapes of the spline parameters (8 bins)
        hidden_features=(64, 128, 256),  # size of the hyper-network
    )  # fmt: skip

    transform
Zuko provides many pre-built flow architectures, including NICE, MAF, NSF and CNF. We recommend trying MAF and NSF first, as they are efficient baselines. In the following cell, we instantiate a conditional flow (5 sample features and 8 context features) with 3 affine autoregressive transformations.
Alternatively, a flow can be built as a custom Flow object given a sequence of lazy transformations and a base lazy distribution. The following demonstrates a condensed example of many things that are possible in Zuko. But remember, with great power comes great responsibility (and great bugs).
    from zuko.distributions import BoxUniform
    from zuko.flows import (
        GeneralCouplingTransform,
        MaskedAutoregressiveTransform,
    )
    from zuko.lazy import (
        Flow,
        UnconditionalDistribution,
        UnconditionalTransform,
    )
    from zuko.transforms import (
        AffineTransform,
        MonotonicRQSTransform,
        RotationTransform,
        SigmoidTransform,
    )

    flow = Flow(
        transform=[
            UnconditionalTransform(  # [0, 255] to ]0, 1[
                AffineTransform,  # y = loc + scale * x
                torch.tensor(1 / 512),  # loc
                torch.tensor(1 / 256),  # scale
                buffer=True,  # not trainable
            ),
            UnconditionalTransform(lambda: SigmoidTransform().inv),  # y = logit(x)
            MaskedAutoregressiveTransform(  # autoregressive transform (affine by default)
                features=5,
                context=8,
                passes=5,  # fully-autoregressive
            ),
            UnconditionalTransform(RotationTransform, torch.randn(5, 5)),  # trainable rotation
            GeneralCouplingTransform(  # coupling transform
                features=5,
                context=8,
                univariate=MonotonicRQSTransform,  # rational-quadratic spline
                shapes=([8], [8], [7]),  # shapes of the spline parameters (8 bins)
                hidden_features=(256, 256),  # size of the hyper-network
                activation=torch.nn.ELU,  # ELU activation in hyper-network
            ).inv,  # inverse
        ],
        base=UnconditionalDistribution(  # ignore context
            BoxUniform,
            torch.full([5], -3.0),  # lower bound
            torch.full([5], +3.0),  # upper bound
            buffer=True,  # not trainable
        ),
    )  # fmt: skip

    flow