dists
PyTorch probability distributions re-parameterized for convenience.
The output of networks designed to provide probabilistic predictions becomes much more interpretable when they predict the mean and the standard deviation of some probability distribution rather than its (often arcane) natural parameters. To that end, this module provides re-parameterized versions of probability distributions that match the parameterization of the respective log-likelihoods implemented as loss functions.
- class MeanStdLogNormal(mean, std, validate_args=False)
Bases: LogNormal
Subclass of the PyTorch LogNormal distribution.
Parameters:
mean (Tensor) – Mean value(s) of the Log-Normal distribution(s) in natural scale.
std (Tensor) – Standard deviation(s) of Log-Normal distribution(s) in natural scale.
validate_args (bool, optional) – Whether the parent class should validate the transformed parameters or not. Defaults to False.
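As an illustration of what this re-parameterization involves, the following minimal sketch assumes the standard moment-matching formulas for the Log-Normal distribution; it is not necessarily how MeanStdLogNormal itself is implemented. Given the natural-scale mean m and standard deviation s, the underlying Normal distribution has variance sigma^2 = log(1 + s^2 / m^2) and location loc = log(m) - sigma^2 / 2. The class name MeanStdLogNormalSketch is hypothetical.

```python
import torch
from torch.distributions import LogNormal


class MeanStdLogNormalSketch(LogNormal):
    """Hypothetical illustration of a mean/std-parameterized LogNormal."""

    def __init__(self, mean, std, validate_args=False):
        mean = torch.as_tensor(mean, dtype=torch.get_default_dtype())
        std = torch.as_tensor(std, dtype=torch.get_default_dtype())
        # Variance of the underlying Normal: log(1 + std^2 / mean^2).
        normal_var = torch.log1p((std / mean) ** 2)
        # Location of the underlying Normal: log(mean) - variance / 2.
        loc = torch.log(mean) - 0.5 * normal_var
        # LogNormal expects the loc and scale of the underlying Normal.
        super().__init__(loc, normal_var.sqrt(), validate_args=validate_args)


dist = MeanStdLogNormalSketch(torch.tensor([2.0, 5.0]), torch.tensor([1.0, 3.0]))
# The moments of the resulting distribution recover the inputs (up to rounding).
print(dist.mean, dist.stddev)
# Negative log-likelihood of some positive observations, e.g. as a loss term.
nll = -dist.log_prob(torch.tensor([1.5, 6.0])).mean()
```

Because only the constructor is overridden, sampling, log_prob and the other methods are inherited unchanged from LogNormal.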
See also
- class MuSigmaGamma(mu, sigma, validate_args=False)
Bases: Gamma
Subclass of the PyTorch Gamma distribution, parameterized by its mean and standard deviation.
Parameters:
mu (Tensor) – Mean value(s) of the Gamma distribution(s).
sigma (Tensor) – Standard deviation(s) of the Gamma distribution(s).
validate_args (bool, optional) – Whether the parent class should validate the transformed parameters or not. Defaults to False.
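A minimal sketch of the corresponding conversion, assuming moment matching against PyTorch's Gamma(concentration, rate), whose mean is concentration / rate and whose variance is concentration / rate^2. Solving for the parameters gives concentration = mu^2 / sigma^2 and rate = mu / sigma^2. This is an illustration, not necessarily the implementation of MuSigmaGamma; the class name MuSigmaGammaSketch is hypothetical.

```python
import torch
from torch.distributions import Gamma


class MuSigmaGammaSketch(Gamma):
    """Hypothetical illustration of a mean/std-parameterized Gamma."""

    def __init__(self, mu, sigma, validate_args=False):
        mu = torch.as_tensor(mu, dtype=torch.get_default_dtype())
        sigma = torch.as_tensor(sigma, dtype=torch.get_default_dtype())
        var = sigma**2
        # mean = concentration / rate and variance = concentration / rate^2
        # are solved for the natural parameters.
        concentration = mu**2 / var
        rate = mu / var
        super().__init__(concentration, rate, validate_args=validate_args)


dist = MuSigmaGammaSketch(torch.tensor([3.0, 10.0]), torch.tensor([1.5, 2.0]))
# Mean and standard deviation of the result match mu and sigma.
print(dist.mean, dist.stddev)
# Negative log-likelihood of some positive observations.
nll = -dist.log_prob(torch.tensor([2.5, 11.0])).mean()
```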
See also
- class MuSigmaNegativeBinomial(mu, sigma, validate_args=False)
Bases: NegativeBinomial
Subclass of the PyTorch NegativeBinomial distribution.
Parameters:
mu (Tensor) – Mean value(s) of the Negative Binomial distribution(s).
sigma (Tensor) – Standard deviation(s) of the Negative Binomial distribution(s).
validate_args (bool, optional) – Whether the parent class should validate the transformed parameters or not. Defaults to False.
Note
This parameterization only makes sense if the variance is strictly greater than the mean. This constraint is best enforced on the model output side already, but it can also be checked here by setting validate_args to True.
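The constraint in the note follows from the conversion itself. PyTorch's NegativeBinomial(total_count, probs) has mean total_count * probs / (1 - probs) and variance mean / (1 - probs), so matching mu and sigma gives probs = 1 - mu / sigma^2 and total_count = mu^2 / (sigma^2 - mu); both are only valid (probs in (0, 1), positive total_count) when sigma^2 > mu. The minimal sketch below assumes exactly this moment matching and is not necessarily how MuSigmaNegativeBinomial is implemented; the class name MuSigmaNegativeBinomialSketch is hypothetical.

```python
import torch
from torch.distributions import NegativeBinomial


class MuSigmaNegativeBinomialSketch(NegativeBinomial):
    """Hypothetical illustration of a mean/std-parameterized NegativeBinomial."""

    def __init__(self, mu, sigma, validate_args=False):
        mu = torch.as_tensor(mu, dtype=torch.get_default_dtype())
        sigma = torch.as_tensor(sigma, dtype=torch.get_default_dtype())
        var = sigma**2
        # PyTorch convention: mean = total_count * p / (1 - p) and
        # variance = mean / (1 - p), hence 1 - p = mu / var.
        probs = 1.0 - mu / var
        total_count = mu**2 / (var - mu)
        super().__init__(total_count, probs=probs, validate_args=validate_args)


dist = MuSigmaNegativeBinomialSketch(torch.tensor([4.0, 2.0]), torch.tensor([3.0, 2.0]))
print(dist.mean, dist.stddev)
# Negative log-likelihood of observed counts.
nll = -dist.log_prob(torch.tensor([3.0, 1.0])).mean()

# With validate_args=True, a variance below the mean produces invalid
# parameters (negative probs and total_count), so construction fails.
try:
    MuSigmaNegativeBinomialSketch(torch.tensor(4.0), torch.tensor(1.0), validate_args=True)
except ValueError as err:
    print("rejected:", err)
```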