dists
PyTorch probability distributions re-parameterized for convenience.
The output of networks designed to provide probabilistic predictions becomes much more interpretable when they predict the mean and the standard deviation of some probability distribution rather than its (often arcane) natural parameters. To that end, this module provides re-parameterized versions of probability distributions that match the parameterization of the respective log-likelihoods implemented as loss functions.
- class MuSigmaGamma(mu, sigma, validate_args=False)
Bases: Gamma
Subclass of the PyTorch Gamma distribution with intuitive parameters.
- Parameters:
mu (Tensor) – Mean value(s) of the Gamma distribution(s).
sigma (Tensor) – Standard deviation(s) of the Gamma distribution(s).
validate_args (bool, optional) – Whether the parent class should validate the transformed parameters or not. Defaults to False.
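For reference, a Gamma distribution with concentration α and rate β (the parameterization used by torch.distributions.Gamma) has mean α/β and variance α/β², so α = mu²/sigma² and β = mu/sigma². The plain-Python sketch below illustrates that mapping under the assumption that MuSigmaGamma performs this conversion before calling the parent constructor; the helper name gamma_natural_params is hypothetical and not part of this module.

```python
import math
from typing import Tuple


def gamma_natural_params(mu: float, sigma: float) -> Tuple[float, float]:
    """Hypothetical helper: map (mean, std) to (concentration, rate).

    Assumes the (concentration, rate) parameterization of
    torch.distributions.Gamma, where mean = concentration / rate
    and variance = concentration / rate**2.
    """
    if mu <= 0 or sigma <= 0:
        raise ValueError("mu and sigma must be strictly positive")
    variance = sigma ** 2
    concentration = mu ** 2 / variance  # alpha = mu^2 / sigma^2
    rate = mu / variance                # beta = mu / sigma^2
    return concentration, rate


# Round trip: recover mean and standard deviation from the natural parameters.
alpha, beta = gamma_natural_params(mu=3.0, sigma=1.5)
recovered_mean = alpha / beta
recovered_std = math.sqrt(alpha) / beta
```

With mu=3.0 and sigma=1.5 this gives a concentration of 4.0 and a rate of 4/3, and the round trip recovers the requested mean and standard deviation.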
- class MuSigmaNegativeBinomial(mu, sigma, validate_args=False)
Bases: NegativeBinomial
Subclass of the PyTorch NegativeBinomial distribution.
- Parameters:
mu (Tensor) – Mean value(s) of the NegativeBinomial distribution(s).
sigma (Tensor) – Standard deviation(s) of the NegativeBinomial distribution(s).
validate_args (bool, optional) – Whether the parent class should validate the transformed parameters or not. Defaults to False.
Note
This parameterization only makes sense if the variance (sigma squared) is strictly greater than the mean (mu). This is best enforced on the model output side, but it can also be checked here by setting validate_args to True.
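The constraint in the note follows from the moments of torch.distributions.NegativeBinomial, which has mean = total_count · probs / (1 − probs) and variance = mean / (1 − probs). Solving for the natural parameters gives probs = 1 − mu/sigma² and total_count = mu²/(sigma² − mu), which are only valid when sigma² > mu. The plain-Python sketch below assumes MuSigmaNegativeBinomial uses this mapping; the helper name negative_binomial_natural_params is hypothetical, and unlike the class it raises explicitly instead of relying on the parent's validate_args.

```python
import math
from typing import Tuple


def negative_binomial_natural_params(mu: float, sigma: float) -> Tuple[float, float]:
    """Hypothetical helper: map (mean, std) to (total_count, probs).

    Assumes the parameterization of torch.distributions.NegativeBinomial:
    mean = total_count * probs / (1 - probs), variance = mean / (1 - probs).
    """
    variance = sigma ** 2
    # The mapping only exists when the variance strictly exceeds the mean.
    if mu <= 0 or variance <= mu:
        raise ValueError("requires mu > 0 and sigma**2 > mu")
    probs = 1.0 - mu / variance              # from variance = mean / (1 - probs)
    total_count = mu ** 2 / (variance - mu)  # from the mean equation
    return total_count, probs


total_count, probs = negative_binomial_natural_params(mu=2.0, sigma=2.0)
```

For mu=2.0 and sigma=2.0 (variance 4.0 exceeds the mean 2.0) this yields total_count 2.0 and probs 0.5, whereas sigma=1.0 raises because the variance would not exceed the mean.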