Distributions
categorical
ReparameterizedCategorical
Bases: Distribution
A categorical distribution with reparameterized sampling using the Gumbel-Softmax trick.
This class wraps a torch.distributions.Categorical distribution and adds reparameterized sampling, so that gradients can flow through samples and the distribution can be used with gradient-based optimization.
Attributes:
Name | Type | Description |
---|---|---|
_categorical | Categorical | The underlying categorical distribution. |
temperature | float | The temperature parameter for the Gumbel-Softmax distribution. |
Source code in vambn/modelling/distributions/categorical.py
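The Gumbel-Softmax trick can be illustrated without torch. The sketch below is plain Python and is not the vambn implementation; `relaxed_one_hot` is a hypothetical helper. It shows the two properties that matter here: the randomness (Gumbel noise `g`) is drawn independently of the parameters, so the sample softmax((logits + g) / temperature) is a deterministic, differentiable function of the logits; and the temperature controls how close the sample is to a one-hot vector.

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def relaxed_one_hot(logits: List[float], gumbel_noise: List[float],
                    temperature: float) -> List[float]:
    """Hypothetical helper: y = softmax((logits + g) / temperature) for fixed noise g."""
    return softmax([(lg + g) / temperature for lg, g in zip(logits, gumbel_noise)])


logits = [1.0, 2.0, 0.5]
noise = [0.3, -0.2, 0.1]  # fixed Gumbel noise, normally drawn as -log(-log(u))

cold = relaxed_one_hot(logits, noise, temperature=0.1)   # nearly one-hot at index 1
warm = relaxed_one_hot(logits, noise, temperature=10.0)  # close to uniform
print([round(p, 3) for p in cold])
print([round(p, 3) for p in warm])
```

Lower temperatures give samples closer to discrete one-hot vectors, typically at the cost of higher-variance gradients; higher temperatures give smoother samples.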
batch_shape: torch.Size (property)
Returns the shape of the batch of distributions.
Returns:
Type | Description |
---|---|
torch.Size | The shape of the batch of distributions. |
event_shape: torch.Size (property)
Returns the shape of the event of the distribution.
Returns:
Type | Description |
---|---|
torch.Size | The shape of the event of the distribution. |
mode: torch.Tensor (property)
Returns the mode of the distribution.
param_shape: torch.Size (property)
Returns the shape of the parameter tensor.
Returns:
Type | Description |
---|---|
torch.Size | The shape of the parameter tensor. |
support: torch.Tensor (property)
Returns the support of the distribution.
Returns:
Type | Description |
---|---|
torch.Tensor | The support of the distribution. |
__init__(logits=None, probs=None, temperature=1.0)
Initialize the reparameterized categorical distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logits | Optional[torch.Tensor] | A tensor of logits (unnormalized log probabilities). | None |
probs | Optional[torch.Tensor] | A tensor of probabilities. | None |
temperature | float | The temperature parameter for the Gumbel-Softmax distribution. | 1.0 |
arg_constraints()
Returns the argument constraints of the distribution.
Returns:
Type | Description |
---|---|
Dict[str, Constraint] | Constraint dictionary. |
log_prob(value)
Calculate the log probability of a value.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | Tensor | Input value. | required |
Returns:
Type | Description |
---|---|
Tensor | Log probability of the input value. |
mean()
Returns the mean of the distribution.
rsample(sample_shape=torch.Size())
Draws a reparameterized sample using the Gumbel-Softmax trick.
sample(sample_shape=torch.Size())
Draws a sample from the distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sample_shape | Size | The shape of the sample to draw. | Size() |
Returns:
Type | Description |
---|---|
Tensor | The drawn sample. |
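By PyTorch convention, `sample` returns a draw that is not differentiated through. One standard way to obtain an exact categorical draw from the same kind of Gumbel noise is the Gumbel-max trick, argmax(logits + g); whether `ReparameterizedCategorical.sample` uses it or simply delegates to the wrapped `_categorical` is not stated here. A plain-Python sketch with a hypothetical function name:

```python
import math
import random
from collections import Counter
from typing import List


def gumbel_max_sample(logits: List[float], rng: random.Random) -> int:
    """Exact categorical draw via the Gumbel-max trick: argmax_i(logits_i + g_i)."""
    scores = []
    for lg in logits:
        u = rng.random()
        while u == 0.0:  # random() may return 0.0, and log(0) is undefined
            u = rng.random()
        g = -math.log(-math.log(u))  # standard Gumbel(0, 1) noise
        scores.append(lg + g)
    return max(range(len(scores)), key=scores.__getitem__)


probs = [0.2, 0.5, 0.3]
logits = [math.log(p) for p in probs]
rng = random.Random(0)
n = 20_000
counts = Counter(gumbel_max_sample(logits, rng) for _ in range(n))
freqs = [counts[i] / n for i in range(len(probs))]
print(freqs)  # empirical frequencies, approximately [0.2, 0.5, 0.3]
```

Replacing the argmax with a temperature-scaled softmax turns this exact but non-differentiable draw into the relaxed Gumbel-Softmax sample described for `rsample`.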
gumbel_distribution
GumbelDistribution
Bases: ExpRelaxedCategorical
Gumbel distribution based on the ExpRelaxedCategorical distribution.
Source code in vambn/modelling/distributions/gumbel_distribution.py
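The base class, torch.distributions.ExpRelaxedCategorical, returns the log of a point on the probability simplex rather than the point itself, which helps numerical stability at low temperatures. The plain-Python sketch below mirrors that idea by normalizing the scores (logits + g) / temperature with a stable log-sum-exp. The function names are hypothetical, and how `GumbelDistribution` builds on its parent is not shown in this reference.

```python
import math
import random
from typing import List


def logsumexp(xs: List[float]) -> float:
    # Stable log(sum(exp(x))): factor out the maximum first.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_relaxed_sample(logits: List[float], temperature: float,
                       rng: random.Random) -> List[float]:
    """Log of a relaxed (Gumbel-Softmax) sample, in the style of ExpRelaxedCategorical."""
    scores = []
    for lg in logits:
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        g = -math.log(-math.log(u))  # standard Gumbel(0, 1) noise
        scores.append((lg + g) / temperature)
    norm = logsumexp(scores)
    return [s - norm for s in scores]  # every entry is <= 0


log_y = log_relaxed_sample([0.0, 1.0, -1.0], temperature=0.5, rng=random.Random(42))
y = [math.exp(v) for v in log_y]  # exponentiating gives a point on the simplex
print(y, sum(y))
```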
mean (property)
Returns the mean of the Gumbel distribution.
Returns:
Type | Description |
---|---|
Tensor | The mean of the distribution. |
mode (property)
Returns the mode of the Gumbel distribution.
Returns:
Type | Description |
---|---|
Tensor | The mode of the distribution. |
probs (property)
Returns the probabilities associated with the Gumbel distribution.
Returns:
Type | Description |
---|---|
Tensor | The probabilities. |
expand(batch_shape, _instance=None)
Expands the Gumbel distribution to the given batch shape.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch_shape | Size | The desired batch shape. | required |
_instance |  | The instance to expand. | None |
Returns:
Type | Description |
---|---|
GumbelDistribution | The expanded Gumbel distribution. |
log_prob(value)
Calculates the log probability of a value under the Gumbel distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | Tensor | The value for which to calculate the log probability. | required |
Returns:
Type | Description |
---|---|
Tensor | The log probability of the value. |
rsample(sample_shape=torch.Size())
Reparameterized sampling for the Gumbel distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sample_shape | Size | The shape of the sample to draw. | Size() |
Returns:
Type | Description |
---|---|
Tensor | The reparameterized sample. |
sample(sample_shape=torch.Size())
Draws a sample from the Gumbel distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sample_shape | Size | The shape of the sample to draw. | Size() |
Returns:
Type | Description |
---|---|
Tensor | The drawn sample. |
gumbel_softmax
gumbel_softmax(logits, shape, tau=1.0, hard=False)
Draw a sample using the Gumbel-Softmax trick.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logits | Tensor | Logits to be used for the Gumbel-Softmax. | required |
shape | Tuple[int, int] | Shape of the logits. Required for TorchScript. | required |
tau | float | Temperature factor. | 1.0 |
hard | bool | If True, return hard (one-hot) samples; otherwise return soft (relaxed) samples. | False |
Returns:
Type | Description |
---|---|
Tensor | Sample from the categorical distribution (relaxed unless hard=True). |
Raises:
Type | Description |
---|---|
ValueError | If logits contain NaN values. |
Source code in vambn/modelling/distributions/gumbel_softmax.py
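A plain-Python sketch of what a function with this signature typically computes. It assumes `hard=True` means replacing the soft sample by a one-hot vector at its argmax. Torch code such as `torch.nn.functional.gumbel_softmax` combines the two with a straight-through estimator (`y_hard - y_soft.detach() + y_soft`) so gradients still flow; plain Python has no autograd, so only forward values are computed. The TorchScript-only `shape` argument is omitted, and `gumbel_softmax_sketch` is a hypothetical name.

```python
import math
import random
from typing import List, Optional


def gumbel_softmax_sketch(logits: List[float], tau: float = 1.0, hard: bool = False,
                          rng: Optional[random.Random] = None) -> List[float]:
    """Forward values of a Gumbel-Softmax sample (soft, or one-hot if hard=True)."""
    if any(math.isnan(x) for x in logits):
        raise ValueError("logits contain NaN values")
    if rng is None:
        rng = random.Random()
    scores = []
    for x in logits:
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        scores.append((x - math.log(-math.log(u))) / tau)  # (logit + Gumbel noise) / tau
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    y_soft = [e / total for e in exps]
    if not hard:
        return y_soft
    # Hard sample: one-hot vector at the argmax of the soft sample.
    k = max(range(len(y_soft)), key=y_soft.__getitem__)
    return [1.0 if i == k else 0.0 for i in range(len(y_soft))]


soft = gumbel_softmax_sketch([1.0, 2.0, 0.5], tau=0.5, rng=random.Random(1))
hard = gumbel_softmax_sketch([1.0, 2.0, 0.5], tau=0.5, hard=True, rng=random.Random(1))
print(soft, hard)  # same noise, so hard is one-hot at the argmax of soft
```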
sample_gumbel(shape, eps=1e-09, device=torch.device('cpu'))
Generate a sample from the Gumbel distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
shape | Tuple[int, int] | Shape of the sample. | required |
eps | float | Small constant added to avoid numerical issues. | 1e-09 |
device | device | The device to generate the sample on. | device('cpu') |
Returns:
Type | Description |
---|---|
Tensor | Sample from the Gumbel distribution. |
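Gumbel noise is usually generated by inverse-transform sampling: if u ~ Uniform(0, 1), then -log(-log(u)) follows the standard Gumbel(0, 1) distribution, whose mean is the Euler-Mascheroni constant (about 0.5772). In this plain-Python sketch `eps` keeps u away from 0 and 1 by clamping; the docstring above only says `eps` avoids numerical issues, so vambn may apply it differently (for example by adding it inside the logarithms). The function name is hypothetical.

```python
import math
import random
from typing import List


def sample_gumbel_sketch(n: int, rng: random.Random, eps: float = 1e-9) -> List[float]:
    """Draw n standard Gumbel(0, 1) samples via the inverse CDF, -log(-log(u))."""
    samples = []
    for _ in range(n):
        u = min(max(rng.random(), eps), 1.0 - eps)  # keep u in [eps, 1 - eps]
        samples.append(-math.log(-math.log(u)))
    return samples


draws = sample_gumbel_sketch(200_000, random.Random(0))
mean = sum(draws) / len(draws)
print(mean)  # close to the Euler-Mascheroni constant, about 0.5772
```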
parameters
CategoricalParameters (dataclass)
Bases: Parameters
Dataclass for categorical output.
Source code in vambn/modelling/distributions/parameters.py
NormalParameters (dataclass)
Bases: Parameters
Dataclass for real-valued output.
Parameters (dataclass)
Bases: ABC
Abstract base dataclass for parameter output.
PoissonParameters (dataclass)
Bases: Parameters
Dataclass for count output.
truncated_normal
TruncatedNormal
Bases: Normal
Normal distribution truncated to the interval [low, high].
This class extends the Normal distribution so that all probability mass lies in [low, high]: the density is zero outside the interval and, inside it, is renormalized by the probability that the untruncated normal assigns to [low, high].
Source code in vambn/modelling/distributions/truncated_normal.py
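For reference, these are the standard definitions for a normal distribution with location μ (`loc`) and scale σ (`scale`) truncated to [a, b] (`low`, `high`), where φ and Φ denote the standard normal density and CDF. They are textbook formulas rather than a transcription of the vambn source, so treat them as what `log_prob`, `cdf` and `icdf` are expected to compute. The density is zero outside [a, b].

```latex
\begin{aligned}
\alpha &= \frac{a - \mu}{\sigma}, \quad \beta = \frac{b - \mu}{\sigma}, \quad \xi = \frac{x - \mu}{\sigma}, \quad Z = \Phi(\beta) - \Phi(\alpha) \\
f(x) &= \frac{\varphi(\xi)}{\sigma Z}, \qquad a \le x \le b \\
\log f(x) &= -\tfrac{1}{2}\xi^{2} - \tfrac{1}{2}\log(2\pi) - \log\sigma - \log Z \\
F(x) &= \frac{\Phi(\xi) - \Phi(\alpha)}{Z} \\
F^{-1}(p) &= \mu + \sigma\,\Phi^{-1}\bigl(\Phi(\alpha) + p\,Z\bigr), \qquad 0 < p < 1
\end{aligned}
```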
__init__(loc, scale, low=-torch.tensor(float('inf')), high=torch.tensor(float('inf')), validate_args=None)
Initialize the TruncatedNormal distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loc | Tensor | The mean of the underlying (untruncated) normal distribution. | required |
scale | Tensor | The standard deviation of the underlying (untruncated) normal distribution. | required |
low | Optional[Tensor] | The lower bound of the truncation (-inf by default). | -tensor(float('inf')) |
high | Optional[Tensor] | The upper bound of the truncation (inf by default). | tensor(float('inf')) |
validate_args | Optional[bool] | Whether to validate arguments. | None |
cdf(value)
Calculate the cumulative distribution function (CDF) of a value.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | Tensor | Input value. | required |
Returns:
Type | Description |
---|---|
Tensor | CDF of the input value. |
icdf(value)
Calculate the inverse cumulative distribution function (ICDF) of a value.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | Tensor | Input value. | required |
Returns:
Type | Description |
---|---|
Tensor | ICDF of the input value. |
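The `cdf` and `icdf` formulas can be checked numerically with the standard library's `statistics.NormalDist` (Python 3.8+) in place of torch. This is a sketch of the math, not the vambn code; `tn_cdf` and `tn_icdf` are hypothetical names, and the clamping to the support and to the open interval (0, 1) is an assumption about how edge cases might be handled.

```python
from statistics import NormalDist

STD = NormalDist()  # standard normal: cdf is Phi, inv_cdf is Phi^{-1}


def tn_cdf(x: float, loc: float, scale: float, low: float, high: float) -> float:
    """CDF of Normal(loc, scale) truncated to [low, high]."""
    x = min(max(x, low), high)  # F is 0 below low and 1 above high
    a = STD.cdf((low - loc) / scale)
    z = STD.cdf((high - loc) / scale) - a  # mass of [low, high] before truncation
    return (STD.cdf((x - loc) / scale) - a) / z


def tn_icdf(p: float, loc: float, scale: float, low: float, high: float) -> float:
    """Inverse CDF: map p in (0, 1) to a quantile inside [low, high]."""
    a = STD.cdf((low - loc) / scale)
    z = STD.cdf((high - loc) / scale) - a
    q = min(max(a + p * z, 1e-16), 1.0 - 1e-16)  # inv_cdf requires 0 < q < 1
    x = loc + scale * STD.inv_cdf(q)
    return min(max(x, low), high)  # guard against rounding just outside the support


params = dict(loc=1.0, scale=2.0, low=-1.0, high=4.0)
for p in (0.1, 0.5, 0.9):
    x = tn_icdf(p, **params)
    print(p, x, tn_cdf(x, **params))  # the round trip recovers p
```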
log_prob(value)
Calculate the log probability of a value.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | Tensor | Input value. | required |
Returns:
Type | Description |
---|---|
Tensor | Log probability of the input value. |
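The log-density follows the same pattern: the standard normal log-density of the standardized value, minus log σ and minus the log of the mass Z that the untruncated normal assigns to [low, high], and -inf outside that interval. A plain-Python sketch (hypothetical name `tn_log_prob`), checked against the half-normal case loc = 0, scale = 1, low = 0, high = inf, whose density is exactly twice the standard normal density:

```python
import math
from statistics import NormalDist

STD = NormalDist()


def tn_log_prob(x: float, loc: float, scale: float, low: float, high: float) -> float:
    """Log-density of Normal(loc, scale) truncated to [low, high]."""
    if x < low or x > high:
        return -math.inf  # zero density outside the support
    xi = (x - loc) / scale
    z = STD.cdf((high - loc) / scale) - STD.cdf((low - loc) / scale)
    return -0.5 * xi * xi - 0.5 * math.log(2 * math.pi) - math.log(scale) - math.log(z)


# Half-normal check: the truncated density is 2 * phi(x) on [0, inf).
lp = tn_log_prob(1.0, loc=0.0, scale=1.0, low=0.0, high=math.inf)
print(lp, math.log(2 * STD.pdf(1.0)))  # the two values agree
```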
rsample(sample_shape=torch.Size())
Draws a reparameterized sample from the distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sample_shape | Size | The shape of the sample to draw. | Size() |
Returns:
Type | Description |
---|---|
Tensor | The reparameterized sample. |
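One common way to make truncated-normal sampling reparameterizable is inverse-transform sampling: draw u ~ Uniform(0, 1) independently of the parameters and return the inverse CDF at u, which is a deterministic (and, in torch, differentiable) function of `loc` and `scale`. Whether vambn's `rsample` takes this route is not stated here. The standalone sketch below uses only the standard library and a hypothetical function name; on the half-normal case the sample mean should be close to sqrt(2/pi), about 0.798.

```python
import math
import random
from statistics import NormalDist
from typing import List

STD = NormalDist()


def tn_rsample(n: int, loc: float, scale: float, low: float, high: float,
               rng: random.Random) -> List[float]:
    """Inverse-transform samples from Normal(loc, scale) truncated to [low, high]."""
    a = STD.cdf((low - loc) / scale)
    z = STD.cdf((high - loc) / scale) - a
    out = []
    for _ in range(n):
        u = rng.random()  # the noise does not depend on loc or scale
        q = min(max(a + u * z, 1e-16), 1.0 - 1e-16)  # inv_cdf requires 0 < q < 1
        x = loc + scale * STD.inv_cdf(q)
        out.append(min(max(x, low), high))  # guard against rounding outside the support
    return out


draws = tn_rsample(50_000, loc=0.0, scale=1.0, low=0.0, high=math.inf, rng=random.Random(0))
print(sum(draws) / len(draws))  # close to sqrt(2 / pi), about 0.798
```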
sample(sample_shape=torch.Size())
Draws a sample from the distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sample_shape | Size | The shape of the sample to draw. | Size() |
Returns:
Type | Description |
---|---|
Tensor | The drawn sample. |