Skip to content

Distributions

categorical

ReparameterizedCategorical

Bases: Distribution

A categorical distribution with reparameterized sampling using the Gumbel-Softmax trick.

This class derives from Distribution and wraps a torch.distributions.Categorical instance, adding reparameterized sampling so that gradient-based optimization techniques can be applied.

Attributes:

Name Type Description
_categorical Categorical

The underlying categorical distribution.

temperature float

The temperature parameter for the Gumbel-Softmax distribution.

Source code in vambn/modelling/distributions/categorical.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
class ReparameterizedCategorical(Distribution):
    """
    A categorical distribution with reparameterized sampling using the Gumbel-Softmax trick.

    The class derives from ``Distribution`` and wraps a
    ``torch.distributions.Categorical`` instance. Shape queries, ``sample``,
    ``log_prob`` and the constraints are delegated to the wrapped object, while
    ``rsample`` draws a relaxed (soft) sample through ``gumbel_softmax`` so that
    gradients can flow back to the logits.

    Attributes:
        _categorical (torch.distributions.Categorical): The underlying categorical distribution.
        temperature (float): The temperature parameter for the Gumbel-Softmax distribution.
    """

    def __init__(
        self,
        logits: Optional[torch.Tensor] = None,
        probs: Optional[torch.Tensor] = None,
        temperature: float = 1.0,
    ):
        """
        Initialize the Reparameterized Categorical Distribution.

        Args:
            logits (Optional[torch.Tensor]): A tensor of logits (unnormalized log probabilities).
            probs (Optional[torch.Tensor]): A tensor of probabilities.
            temperature (float): A temperature parameter for the Gumbel-Softmax distribution.
        """
        # Both arguments are forwarded unchanged; the wrapped Categorical is
        # responsible for checking which of the two was supplied.
        self._categorical = torch.distributions.Categorical(
            logits=logits, probs=probs
        )
        self.temperature = temperature

    @property
    def param_shape(self) -> torch.Size:
        """
        Returns the shape of the parameter tensor.

        Returns:
            torch.Size: The shape of the parameter tensor of the wrapped distribution.
        """
        return self._categorical.param_shape

    @property
    def batch_shape(self) -> torch.Size:
        """
        Returns the shape of the batch of distributions.

        Returns:
            torch.Size: The batch shape of the wrapped distribution.
        """
        return self._categorical.batch_shape

    @property
    def event_shape(self) -> torch.Size:
        """
        Returns the shape of the event of the distribution.

        Returns:
            torch.Size: The event shape of the wrapped distribution.
        """
        return self._categorical.event_shape

    @property
    def support(self) -> torch.Tensor:
        """
        Returns the support of the distribution.

        Returns:
            The ``support`` attribute of the wrapped distribution.
        """
        return self._categorical.support

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """
        Draws a (non-differentiable) sample from the wrapped categorical distribution.

        Args:
            sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

        Returns:
            torch.Tensor: The drawn sample.
        """
        return self._categorical.sample(sample_shape)

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """
        Reparameterized sampling using Gumbel-Softmax trick.

        Args:
            sample_shape (torch.Size, optional): Accepted for interface
                compatibility only. It is not used: the noise is always
                requested with the shape of the stored logits.

        Returns:
            torch.Tensor: Soft (``hard=False``) Gumbel-Softmax sample.

        Raises:
            ValueError: If the stored logits contain NaN values.
        """
        logits = self._categorical.logits
        if torch.any(torch.isnan(logits)):
            # ValueError (a subclass of Exception) matches what gumbel_softmax
            # raises for the same condition.
            raise ValueError("NaN values in logits")
        samples = gumbel_softmax(
            logits=logits,
            shape=tuple(logits.shape),
            tau=self.temperature,
            hard=False,
        )
        return samples

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the log_prob of a value.

        Args:
            value (torch.Tensor): Input value.

        Returns:
            torch.Tensor: Log probability of the input value under the wrapped distribution.
        """

        return self._categorical.log_prob(value)

    @constraints.dependent_property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        """
        Returns the argument constraints of the distribution.

        Returns:
            Dict[str, Constraint]: Constraint dictionary of the wrapped distribution.
        """
        return self._categorical.arg_constraints

    @property
    def mode(self) -> torch.Tensor:
        """
        Returns the mode of the distribution.

        Returns:
            torch.Tensor: Mode of the wrapped distribution with an extra
            trailing dimension of size one appended.
        """
        return self._categorical.mode.unsqueeze(-1)

    def mean(self) -> torch.Tensor:
        """
        Returns the mean of the distribution.

        Note that this is a plain method (no ``@property``), so it has to be
        called as ``dist.mean()``.

        Returns:
            torch.Tensor: The ``mean`` attribute of the wrapped distribution.
            NOTE(review): torch's Categorical is believed to report NaN here -
            confirm before relying on the value.
        """
        return self._categorical.mean

batch_shape: torch.Size property

Returns the shape of the batch of distributions. Returns: torch.Size: The shape of the batch of distributions.

event_shape: torch.Size property

Returns the shape of the event of the distribution. Returns: torch.Size: The shape of the event of the distribution.

mode: torch.Tensor property

Returns the mode of the distribution.

param_shape: torch.Size property

Returns the shape of the parameter tensor. Returns: torch.Size: The shape of the parameter tensor.

support: torch.Tensor property

Returns the support of the distribution. Returns: torch.Tensor: The support of the distribution.

__init__(logits=None, probs=None, temperature=1.0)

Initialize the Reparameterized Categorical Distribution. Arguments: logits (Optional[torch.Tensor]) is a tensor of logits (unnormalized log probabilities); probs (Optional[torch.Tensor]) is a tensor of probabilities; temperature (float) is the temperature parameter for the Gumbel-Softmax distribution.

Source code in vambn/modelling/distributions/categorical.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def __init__(
    self,
    logits: Optional[torch.Tensor] = None,
    probs: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
):
    """
    Set up the wrapped categorical distribution and the sampling temperature.

    Args:
        logits (Optional[torch.Tensor]): Unnormalized log probabilities.
        probs (Optional[torch.Tensor]): Probabilities.
        temperature (float): Temperature used by the Gumbel-Softmax sampler.
    """
    # Both parameters are handed over as-is; the wrapped Categorical decides
    # whether the combination is valid.
    wrapped = torch.distributions.Categorical(probs=probs, logits=logits)
    self._categorical = wrapped
    self.temperature = temperature

arg_constraints()

Returns the argument constraints of the distribution.

Returns:

Type Description
Dict[str, Constraint]

Dict[str, Constraint]: Constraint dictionary.

Source code in vambn/modelling/distributions/categorical.py
114
115
116
117
118
119
120
121
122
@constraints.dependent_property
def arg_constraints(self) -> Dict[str, constraints.Constraint]:
    """
    Expose the argument constraints of the wrapped categorical distribution.

    Returns:
        Dict[str, Constraint]: The constraint dictionary taken from ``self._categorical``.
    """
    wrapped = self._categorical
    return wrapped.arg_constraints

log_prob(value)

Calculate the log_prob of a value.

Parameters:

Name Type Description Default
value Tensor

Input value.

required

Returns:

Type Description
Tensor

torch.Tensor: Log probability of the input value.

Source code in vambn/modelling/distributions/categorical.py
101
102
103
104
105
106
107
108
109
110
111
112
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the log probability of ``value`` under the wrapped categorical.

    Args:
        value (torch.Tensor): Input value.

    Returns:
        torch.Tensor: Log probability of the input value.
    """
    wrapped = self._categorical
    result = wrapped.log_prob(value)
    return result

mean()

Returns the mean of the distribution.

Source code in vambn/modelling/distributions/categorical.py
131
132
133
134
135
def mean(self) -> torch.Tensor:
    """
    Return the ``mean`` attribute of the wrapped categorical distribution.

    This is a plain method rather than a property, so it must be called.
    """
    wrapped = self._categorical
    return wrapped.mean

rsample(sample_shape=torch.Size())

Reparameterized sampling using Gumbel-Softmax trick.

Source code in vambn/modelling/distributions/categorical.py
87
88
89
90
91
92
93
94
95
96
97
98
99
def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
    """
    Reparameterized sampling using Gumbel-Softmax trick.

    Args:
        sample_shape (torch.Size, optional): Accepted for interface
            compatibility only. It is not used: the noise is always requested
            with the shape of the stored logits.

    Returns:
        torch.Tensor: Soft (``hard=False``) Gumbel-Softmax sample.

    Raises:
        ValueError: If the stored logits contain NaN values.
    """
    logits = self._categorical.logits
    if torch.any(torch.isnan(logits)):
        # ValueError (a subclass of Exception) matches what gumbel_softmax
        # raises for the same condition.
        raise ValueError("NaN values in logits")
    samples = gumbel_softmax(
        logits=logits,
        shape=tuple(logits.shape),
        tau=self.temperature,
        hard=False,
    )
    return samples

sample(sample_shape=torch.Size())

Draws a sample from the distribution.

Parameters:

Name Type Description Default
sample_shape Size

The shape of the sample to draw. Defaults to torch.Size().

Size()

Returns:

Type Description
Tensor

torch.Tensor: The drawn sample.

Source code in vambn/modelling/distributions/categorical.py
75
76
77
78
79
80
81
82
83
84
85
def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
    """
    Draw a sample by delegating to the wrapped categorical distribution.

    Args:
        sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

    Returns:
        torch.Tensor: The drawn sample.
    """
    wrapped = self._categorical
    drawn = wrapped.sample(sample_shape)
    return drawn

gumbel_distribution

GumbelDistribution

Bases: ExpRelaxedCategorical

Gumbel distribution based on the ExpRelaxedCategorical distribution.

Source code in vambn/modelling/distributions/gumbel_distribution.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
class GumbelDistribution(ExpRelaxedCategorical):
    """
    Gumbel distribution built on top of ``ExpRelaxedCategorical``.

    Probabilities are obtained by exponentiating the logits and clipping them
    to ``[1e-6, 1 - 1e-6]``. Discrete quantities (``sample``, ``mode``,
    ``log_prob``) are computed from a ``OneHotCategorical`` built on those
    probabilities, while ``rsample`` exponentiates the relaxed sample of the
    parent class.
    """

    @property
    def probs(self):
        """
        Clipped probabilities derived from the logits.

        Returns:
            torch.Tensor: ``exp(logits)`` clipped to ``[1e-6, 1 - 1e-6]``.
        """
        unclipped = torch.exp(self.logits)
        return unclipped.clip(1e-6, 1 - 1e-6)

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        """
        Draw a one-hot sample without tracking gradients.

        Args:
            sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

        Returns:
            torch.Tensor: The drawn sample.
        """
        one_hot = OneHotCategorical(probs=self.probs.clip(1e-6, 1 - 1e-6))
        return one_hot.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw a reparameterized sample.

        Args:
            sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

        Returns:
            torch.Tensor: Exponential of the parent's reparameterized sample.
        """
        relaxed = super().rsample(sample_shape)
        return torch.exp(relaxed)

    @property
    def mean(self):
        """
        Mean of the distribution, reported as the clipped probabilities.

        Returns:
            torch.Tensor: The mean of the distribution.
        """
        clipped = self.probs.clip(1e-6, 1 - 1e-6)
        return clipped

    @property
    def mode(self):
        """
        Mode of the distribution.

        Returns:
            torch.Tensor: Mode of a ``OneHotCategorical`` built from the clipped probabilities.
        """
        one_hot = OneHotCategorical(probs=self.probs.clip(1e-6, 1 - 1e-6))
        return one_hot.mode

    def expand(self, batch_shape, _instance=None):
        """
        Expand the distribution to a new batch shape.

        Args:
            batch_shape (torch.Size): The desired shape; its last entry is
                dropped before it is passed to the parent implementation.
            _instance: The instance to expand.

        Returns:
            GumbelDistribution: The expanded Gumbel distribution.
        """
        leading_dims = batch_shape[:-1]
        return super().expand(leading_dims, _instance)

    def log_prob(self, value):
        """
        Log probability of ``value`` under a ``OneHotCategorical`` with the clipped probabilities.

        Args:
            value (torch.Tensor): The value for which to calculate the log probability.

        Returns:
            torch.Tensor: The log probability of the value.
        """
        one_hot = OneHotCategorical(probs=self.probs.clip(1e-6, 1 - 1e-6))
        return one_hot.log_prob(value)

mean property

Returns the mean of the Gumbel distribution.

Returns:

Type Description

torch.Tensor: The mean of the distribution.

mode property

Returns the mode of the Gumbel distribution.

Returns:

Type Description

torch.Tensor: The mode of the distribution.

probs property

Returns the probabilities associated with the Gumbel distribution.

Returns:

Type Description

torch.Tensor: The probabilities.

expand(batch_shape, _instance=None)

Expands the Gumbel distribution to the given batch shape.

Parameters:

Name Type Description Default
batch_shape Size

The desired batch shape.

required
_instance

The instance to expand.

None

Returns:

Name Type Description
GumbelDistribution

The expanded Gumbel distribution.

Source code in vambn/modelling/distributions/gumbel_distribution.py
70
71
72
73
74
75
76
77
78
79
80
81
def expand(self, batch_shape, _instance=None):
    """
    Expand the distribution to a new batch shape.

    Args:
        batch_shape (torch.Size): The desired shape; its last entry is dropped
            before it is passed to the parent implementation.
        _instance: The instance to expand.

    Returns:
        GumbelDistribution: The expanded Gumbel distribution.
    """
    leading_dims = batch_shape[:-1]
    return super().expand(leading_dims, _instance)

log_prob(value)

Calculates the log probability of a value under the Gumbel distribution.

Parameters:

Name Type Description Default
value Tensor

The value for which to calculate the log probability.

required

Returns:

Type Description

torch.Tensor: The log probability of the value.

Source code in vambn/modelling/distributions/gumbel_distribution.py
83
84
85
86
87
88
89
90
91
92
93
94
def log_prob(self, value):
    """
    Log probability of ``value`` under a ``OneHotCategorical`` built from the clipped probabilities.

    Args:
        value (torch.Tensor): The value for which to calculate the log probability.

    Returns:
        torch.Tensor: The log probability of the value.
    """
    # Keep probabilities strictly inside (0, 1) before building the helper distribution.
    one_hot = OneHotCategorical(probs=self.probs.clip(1e-6, 1 - 1e-6))
    return one_hot.log_prob(value)

rsample(sample_shape=torch.Size())

Reparameterized sampling for the Gumbel distribution.

Parameters:

Name Type Description Default
sample_shape Size

The shape of the sample to draw. Defaults to torch.Size().

Size()

Returns:

Type Description

torch.Tensor: The reparameterized sample.

Source code in vambn/modelling/distributions/gumbel_distribution.py
37
38
39
40
41
42
43
44
45
46
47
def rsample(self, sample_shape=torch.Size()):
    """
    Draw a reparameterized sample.

    Args:
        sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

    Returns:
        torch.Tensor: Exponential of the parent's reparameterized sample.
    """
    relaxed = super().rsample(sample_shape)
    return torch.exp(relaxed)

sample(sample_shape=torch.Size())

Draws a sample from the Gumbel distribution.

Parameters:

Name Type Description Default
sample_shape Size

The shape of the sample to draw. Defaults to torch.Size().

Size()

Returns:

Type Description

torch.Tensor: The drawn sample.

Source code in vambn/modelling/distributions/gumbel_distribution.py
23
24
25
26
27
28
29
30
31
32
33
34
35
@torch.no_grad()
def sample(self, sample_shape=torch.Size()):
    """
    Draw a one-hot sample without tracking gradients.

    Args:
        sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

    Returns:
        torch.Tensor: The drawn sample.
    """
    # Keep probabilities strictly inside (0, 1) before building the helper distribution.
    one_hot = OneHotCategorical(probs=self.probs.clip(1e-6, 1 - 1e-6))
    return one_hot.sample(sample_shape)

gumbel_softmax

gumbel_softmax(logits, shape, tau=1.0, hard=False)

Gumbel-Softmax implementation.

Parameters:

Name Type Description Default
logits Tensor

Logits to be used for the Gumbel-Softmax.

required
shape Tuple[int, int]

Shape of the logits. Required for torchscript.

required
tau float

Temperature factor. Defaults to 1.0.

1.0
hard bool

Hard sampling or soft. Defaults to False.

False

Returns:

Type Description
Tensor

torch.Tensor: Sampled categorical distribution.

Raises:

Type Description
ValueError

If logits contain NaN values.

Source code in vambn/modelling/distributions/gumbel_softmax.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def gumbel_softmax(
    logits: torch.Tensor,
    shape: Tuple[int, int],
    tau: float = 1.0,
    hard: bool = False,
) -> torch.Tensor:
    """
    Gumbel-Softmax implementation.

    Gumbel noise is added to the logits and a temperature-scaled softmax is
    taken over the last dimension. With ``hard=True`` the forward value is the
    one-hot arg-max while the gradient is that of the soft sample
    (straight-through estimator).

    Args:
        logits (torch.Tensor): Logits to be used for the Gumbel-Softmax.
        shape (Tuple[int, int]): Shape of the noise to draw; expected to be the
            shape of the logits. Required for torchscript.
        tau (float, optional): Temperature factor. Values below 1e-9 are
            raised to 1e-9 to avoid division by zero. Defaults to 1.0.
        hard (bool, optional): Hard sampling or soft. Defaults to False.

    Returns:
        torch.Tensor: Sampled categorical distribution.

    Raises:
        ValueError: If logits contain NaN values.
    """
    if torch.isnan(logits).any():
        raise ValueError("Logits contain NaN values")

    gumbel_noise = sample_gumbel(shape, device=logits.device)
    y = logits + gumbel_noise
    tau = max(tau, 1e-9)
    y_soft = torch.softmax(y / (tau), dim=-1)

    if hard:
        # Scatter along the last axis so the one-hot tensor has the same shape
        # as y_soft for any rank, not only for 2-D input.
        _, ind = y_soft.max(dim=-1, keepdim=True)
        y_hard = torch.zeros_like(y_soft).scatter_(-1, ind, 1.0)
        # Forward pass yields y_hard, backward pass uses the gradient of y_soft.
        y_soft = (y_hard - y_soft).detach() + y_soft

    return y_soft

sample_gumbel(shape, eps=1e-09, device=torch.device('cpu'))

Generate a sample from the Gumbel distribution.

Parameters:

Name Type Description Default
shape Tuple[int, int]

Shape of the sample.

required
eps float

Value to be added to avoid numerical issues. Defaults to 1e-9.

1e-09
device device

The device to generate the sample on. Defaults to torch.device("cpu").

device('cpu')

Returns:

Type Description
Tensor

torch.Tensor: Sample from the Gumbel distribution.

Source code in vambn/modelling/distributions/gumbel_softmax.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def sample_gumbel(
    shape: Tuple[int, int],
    eps: float = 1e-9,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """
    Draw Gumbel noise by transforming uniform random numbers.

    The result is ``-log(-log(U + eps) + eps)`` with ``U`` drawn from
    ``torch.rand``.

    Args:
        shape (Tuple[int, int]): Shape of the sample.
        eps (float, optional): Offset added inside both logarithms to avoid numerical issues. Defaults to 1e-9.
        device (torch.device, optional): The device to generate the sample on. Defaults to torch.device("cpu").

    Returns:
        torch.Tensor: Sample from the Gumbel distribution.
    """
    uniform = torch.rand(shape, device=device)
    inner = -torch.log(uniform + eps)
    return -torch.log(inner + eps)

parameters

CategoricalParameters dataclass

Bases: Parameters

Dataclass for categorical output

Source code in vambn/modelling/distributions/parameters.py
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
@dataclass
class CategoricalParameters(Parameters):
    """Parameter container for a categorical output, stored as logits."""

    # Logits of the categorical output.
    logits: Tensor
    # Keyword-only label of this parameter set.
    name: str = field(kw_only=True, default="cat")

    @property
    def device(self) -> torch.device:
        """Device on which the logits tensor lives."""
        logits = self.logits
        return logits.device

    @property
    def probs(self) -> Tensor:
        """Probabilities obtained from the logits via ``logits_to_probs``."""
        logits = self.logits
        return logits_to_probs(logits)

    def __str__(self) -> str:
        """Short summary with the average logit and the tensor shape."""
        return f"CatOutput: Avg. logits {self.logits.mean()}; shape {self.logits.shape}"

NormalParameters dataclass

Bases: Parameters

Dataclass for real output

Source code in vambn/modelling/distributions/parameters.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
@dataclass
class NormalParameters(Parameters):
    """Dataclass for real output"""

    # Location (mean) of the normal output.
    loc: Tensor
    # Scale (standard deviation) of the normal output.
    scale: Tensor
    # Keyword-only label of this parameter set.
    name: str = field(kw_only=True, default="real")

    @property
    def device(self) -> torch.device:
        """Device on which the ``loc`` tensor lives."""
        return self.loc.device

    def __str__(self) -> str:
        """Short summary with the average mean, average std and the tensor shape."""
        # The stray ")" that followed the std value has been removed.
        return f"RealOutput: Avg. Mean {self.loc.mean()}; Avg. Std {self.scale.mean()}; shape {self.loc.shape}"

Parameters dataclass

Bases: ABC

Dataclass for parameter output

Source code in vambn/modelling/distributions/parameters.py
21
22
23
24
25
26
27
28
29
30
@dataclass
class Parameters(ABC):
    """Dataclass for parameter output"""

    name: str = field(kw_only=True, default="")

    @property
    @abstractmethod
    def device(self) -> torch.device:
        pass

PoissonParameters dataclass

Bases: Parameters

Dataclass for count output

Source code in vambn/modelling/distributions/parameters.py
69
70
71
72
73
74
75
76
77
78
79
80
81
@dataclass
class PoissonParameters(Parameters):
    """Parameter container for a count output, stored as a Poisson rate."""

    # Rate (lambda) of the count output.
    rate: Tensor
    # Keyword-only label of this parameter set.
    name: str = field(kw_only=True, default="count")

    @property
    def device(self) -> torch.device:
        """Device on which the rate tensor lives."""
        rate = self.rate
        return rate.device

    def __str__(self) -> str:
        """Short summary with the average rate and the tensor shape."""
        return f"CountOutput: Avg. Lambda {self.rate.mean()}; shape {self.rate.shape}"

truncated_normal

TruncatedNormal

Bases: Normal

TruncatedNormal distribution with support [low, high].

This class extends the Normal distribution to support truncation, i.e., the distribution is limited to the range [low, high].

Source code in vambn/modelling/distributions/truncated_normal.py
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class TruncatedNormal(Normal):
    """
    TruncatedNormal distribution with support [low, high].

    This class extends the Normal distribution to support truncation, i.e., the
    distribution is limited to the range [low, high].

    Samples are produced by drawing from the parent normal and clamping the
    result into [low, high]; ``log_prob``, ``cdf`` and ``icdf`` rescale the
    parent's quantities by the probability mass the parent puts on [low, high].
    """

    arg_constraints = {
        "loc": constraints.real,
        "scale": constraints.positive,
        "low": constraints.real,
        "high": constraints.real,
    }

    def __init__(
        self,
        loc: torch.Tensor,
        scale: torch.Tensor,
        low: Optional[torch.Tensor] = -torch.tensor(float("inf")),
        high: Optional[torch.Tensor] = torch.tensor(float("inf")),
        validate_args: Optional[bool] = None,
    ):
        """
        Initialize the TruncatedNormal distribution.

        Args:
            loc (torch.Tensor): The mean of the normal distribution.
            scale (torch.Tensor): The standard deviation of the normal distribution.
            low (Optional[torch.Tensor]): The lower bound of the truncation.
                Defaults to -inf; ``None`` is treated as -inf.
            high (Optional[torch.Tensor]): The upper bound of the truncation.
                Defaults to inf; ``None`` is treated as inf.
            validate_args (Optional[bool]): Whether to validate arguments. Defaults to None.
        """
        # The signature advertises Optional bounds, so map None to "unbounded"
        # instead of failing later in comparisons against None.
        if low is None:
            low = -torch.tensor(float("inf"))
        if high is None:
            high = torch.tensor(float("inf"))
        # Bounds are set before the parent constructor runs because they are
        # listed in arg_constraints.
        self.low = low
        self.high = high
        super().__init__(loc, scale, validate_args=validate_args)

    def _clamp(self, x):
        """
        Clamp the values to the range [low, high].

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The clamped tensor.
        """
        clamped_x = torch.clamp(x, self.low, self.high)
        # Out-of-range entries are additionally replaced by the violated bound.
        x_mask = (x < self.low) | (x > self.high)
        x_fill = torch.where(x < self.low, self.low, self.high)
        return torch.where(x_mask, x_fill, clamped_x)

    def sample(self, sample_shape=torch.Size()):
        """
        Draws a sample from the distribution.

        Args:
            sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

        Returns:
            torch.Tensor: The drawn sample, clamped to [low, high]; no gradients are tracked.
        """
        with torch.no_grad():
            return self._clamp(super().sample(sample_shape))

    def rsample(self, sample_shape=torch.Size()):
        """
        Draws a reparameterized sample from the distribution.

        Unlike ``sample``, gradient tracking stays enabled so the result is
        differentiable with respect to ``loc`` and ``scale``.

        Args:
            sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

        Returns:
            torch.Tensor: The reparameterized sample, clamped to [low, high].
        """
        # No torch.no_grad() here: disabling autograd would defeat the purpose
        # of a reparameterized sample.
        return self._clamp(super().rsample(sample_shape))

    def log_prob(self, value):
        """
        Calculate the log probability of a value.

        Inside [low, high] this is the parent normal's log density minus the
        log of the probability mass the parent assigns to [low, high]. Outside
        the bounds the fixed floor ``log(1e-12)`` is returned.

        Args:
            value (torch.Tensor): Input value.

        Returns:
            torch.Tensor: Log probability of the input value.
        """
        if self._validate_args:
            self._validate_sample(value)
        base_log_prob = super().log_prob(value)
        # The parent's cdf must be used here: self.cdf is already rescaled to
        # the truncated range, so self.cdf(high) - self.cdf(low) would always
        # be 1 and the normalizer would vanish.
        mass = super().cdf(self.high) - super().cdf(self.low)
        normalized = base_log_prob - torch.log(mass)
        outside = (value < self.low) | (value > self.high)
        return torch.where(
            outside,
            torch.log(torch.tensor(1e-12)),
            normalized,
        )

    def cdf(self, value):
        """
        Calculate the cumulative distribution function (CDF) of a value.

        Args:
            value (torch.Tensor): Input value.

        Returns:
            torch.Tensor: CDF of the input value, limited to [0, 1].
        """
        if self._validate_args:
            self._validate_sample(value)
        cdf = super().cdf(value)
        low_cdf = super().cdf(self.low)
        high_cdf = super().cdf(self.high)
        rescaled = (cdf - low_cdf) / (high_cdf - low_cdf)
        # Values below low / above high would otherwise yield results outside
        # [0, 1], which is not a valid CDF value.
        return torch.clamp(rescaled, 0.0, 1.0)

    def icdf(self, value):
        """
        Calculate the inverse cumulative distribution function (ICDF) of a value.

        Args:
            value (torch.Tensor): Input value.

        Returns:
            torch.Tensor: ICDF of the input value.
        """
        if self._validate_args:
            self._validate_sample(value)
        low_cdf = super().cdf(self.low)
        high_cdf = super().cdf(self.high)
        # Map the probability onto the parent's CDF range covered by [low, high].
        rescaled_value = low_cdf + (high_cdf - low_cdf) * value
        return super().icdf(rescaled_value)

__init__(loc, scale, low=-torch.tensor(float('inf')), high=torch.tensor(float('inf')), validate_args=None)

Initialize the TruncatedNormal distribution.

Parameters:

Name Type Description Default
loc Tensor

The mean of the normal distribution.

required
scale Tensor

The standard deviation of the normal distribution.

required
low Optional[Tensor]

The lower bound of the truncation. Defaults to -inf.

-tensor(float('inf'))
high Optional[Tensor]

The upper bound of the truncation. Defaults to inf.

tensor(float('inf'))
validate_args Optional[bool]

Whether to validate arguments. Defaults to None.

None
Source code in vambn/modelling/distributions/truncated_normal.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def __init__(
    self,
    loc: torch.Tensor,
    scale: torch.Tensor,
    low: Optional[torch.Tensor] = -torch.tensor(float("inf")),
    high: Optional[torch.Tensor] = torch.tensor(float("inf")),
    validate_args: Optional[bool] = None,
):
    """
    Initialize the TruncatedNormal distribution.

    Args:
        loc (torch.Tensor): The mean of the normal distribution.
        scale (torch.Tensor): The standard deviation of the normal distribution.
        low (Optional[torch.Tensor]): The lower bound of the truncation.
            Defaults to -inf; ``None`` is treated as -inf.
        high (Optional[torch.Tensor]): The upper bound of the truncation.
            Defaults to inf; ``None`` is treated as inf.
        validate_args (Optional[bool]): Whether to validate arguments. Defaults to None.
    """
    # The signature advertises Optional bounds, so map None to "unbounded"
    # instead of storing None on the instance.
    if low is None:
        low = -torch.tensor(float("inf"))
    if high is None:
        high = torch.tensor(float("inf"))
    # Bounds are stored before the parent constructor is invoked.
    self.low = low
    self.high = high
    super().__init__(loc, scale, validate_args=validate_args)

cdf(value)

Calculate the cumulative distribution function (CDF) of a value.

Parameters:

Name Type Description Default
value Tensor

Input value.

required

Returns:

Type Description

torch.Tensor: CDF of the input value.

Source code in vambn/modelling/distributions/truncated_normal.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def cdf(self, value):
    """
    Calculate the cumulative distribution function (CDF) of a value.

    The parent's CDF is rescaled by the probability mass it assigns to
    [low, high].

    Args:
        value (torch.Tensor): Input value.

    Returns:
        torch.Tensor: CDF of the input value, limited to [0, 1].
    """
    if self._validate_args:
        self._validate_sample(value)
    cdf = super().cdf(value)
    low_cdf = super().cdf(self.low)
    high_cdf = super().cdf(self.high)
    rescaled = (cdf - low_cdf) / (high_cdf - low_cdf)
    # Values below low / above high would otherwise yield results outside
    # [0, 1], which is not a valid CDF value.
    return torch.clamp(rescaled, 0.0, 1.0)

icdf(value)

Calculate the inverse cumulative distribution function (ICDF) of a value.

Parameters:

Name Type Description Default
value Tensor

Input value.

required

Returns:

Type Description

torch.Tensor: ICDF of the input value.

Source code in vambn/modelling/distributions/truncated_normal.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def icdf(self, value):
    """
    Evaluate the inverse CDF (quantile function) at ``value``.

    The probability is first mapped onto the part of the parent's CDF range
    that corresponds to [low, high]; the parent's inverse CDF is then applied.

    Args:
        value (torch.Tensor): Input value.

    Returns:
        torch.Tensor: ICDF of the input value.
    """
    if self._validate_args:
        self._validate_sample(value)
    lower = super().cdf(self.low)
    upper = super().cdf(self.high)
    span = upper - lower
    target = lower + span * value
    return super().icdf(target)

log_prob(value)

Calculate the log probability of a value.

Parameters:

Name Type Description Default
value Tensor

Input value.

required

Returns:

Type Description

torch.Tensor: Log probability of the input value.

Source code in vambn/modelling/distributions/truncated_normal.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
def log_prob(self, value):
    """
    Calculate the log probability of a value.

    Inside [low, high] this is the parent normal's log density minus the log
    of the probability mass the parent assigns to [low, high]. Outside the
    bounds the fixed floor ``log(1e-12)`` is returned.

    Args:
        value (torch.Tensor): Input value.

    Returns:
        torch.Tensor: Log probability of the input value.
    """
    if self._validate_args:
        self._validate_sample(value)
    base_log_prob = super().log_prob(value)
    # The parent's cdf must be used here: self.cdf is already rescaled to the
    # truncated range, so self.cdf(high) - self.cdf(low) would always be 1 and
    # the normalizer would vanish.
    mass = super().cdf(self.high) - super().cdf(self.low)
    normalized = base_log_prob - torch.log(mass)
    outside = (value < self.low) | (value > self.high)
    return torch.where(
        outside,
        torch.log(torch.tensor(1e-12)),
        normalized,
    )

rsample(sample_shape=torch.Size())

Draws a reparameterized sample from the distribution.

Parameters:

Name Type Description Default
sample_shape Size

The shape of the sample to draw. Defaults to torch.Size().

Size()

Returns:

Type Description

torch.Tensor: The reparameterized sample.

Source code in vambn/modelling/distributions/truncated_normal.py
71
72
73
74
75
76
77
78
79
80
81
82
def rsample(self, sample_shape=torch.Size()):
    """
    Draws a reparameterized sample from the distribution.

    Gradient tracking stays enabled so the result is differentiable with
    respect to the distribution parameters.

    Args:
        sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

    Returns:
        torch.Tensor: The reparameterized sample, clamped to [low, high].
    """
    # No torch.no_grad() here: disabling autograd would defeat the purpose of
    # a reparameterized sample.
    return self._clamp(super().rsample(sample_shape))

sample(sample_shape=torch.Size())

Draws a sample from the distribution.

Parameters:

Name Type Description Default
sample_shape Size

The shape of the sample to draw. Defaults to torch.Size().

Size()

Returns:

Type Description

torch.Tensor: The drawn sample.

Source code in vambn/modelling/distributions/truncated_normal.py
58
59
60
61
62
63
64
65
66
67
68
69
def sample(self, sample_shape=torch.Size()):
    """
    Draw a sample from the parent distribution and clamp it to [low, high].

    Args:
        sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

    Returns:
        torch.Tensor: The drawn sample; no gradients are tracked.
    """
    with torch.no_grad():
        raw = super().sample(sample_shape)
        bounded = self._clamp(raw)
    return bounded