Modelling

distributions

categorical

ReparameterizedCategorical

Bases: Distribution

A categorical distribution with reparameterized sampling using the Gumbel-Softmax trick.

This class extends the torch.distributions.Categorical distribution to allow for reparameterized sampling, which enables gradient-based optimization techniques.

Attributes:

_categorical (Categorical): The underlying categorical distribution.
temperature (float): The temperature parameter for the Gumbel-Softmax distribution.

Source code in vambn/modelling/distributions/categorical.py
class ReparameterizedCategorical(Distribution):
    """
    A categorical distribution with reparameterized sampling using the Gumbel-Softmax trick.

    This class extends the torch.distributions.Categorical distribution to allow for
    reparameterized sampling, which enables gradient-based optimization techniques.

    Attributes:
        _categorical (torch.distributions.Categorical): The underlying categorical distribution.
        temperature (float): The temperature parameter for the Gumbel-Softmax distribution.
    """

    def __init__(
        self,
        logits: Optional[torch.Tensor] = None,
        probs: Optional[torch.Tensor] = None,
        temperature: float = 1.0,
    ):
        """
        Initialize the Reparameterized Categorical Distribution.
        Args:
            logits (Optional[torch.Tensor]): A tensor of logits (unnormalized log probabilities).
            probs (Optional[torch.Tensor]): A tensor of probabilities.
            temperature (float): A temperature parameter for the Gumbel-Softmax distribution.
        """
        self._categorical = torch.distributions.Categorical(
            logits=logits, probs=probs
        )
        self.temperature = temperature

    @property
    def param_shape(self) -> torch.Size:
        """
        Returns the shape of the parameter tensor.
        Returns:
            torch.Size: The shape of the parameter tensor.
        """
        return self._categorical.param_shape

    @property
    def batch_shape(self) -> torch.Size:
        """
        Returns the shape of the batch of distributions.
        Returns:
            torch.Size: The shape of the batch of distributions.
        """
        return self._categorical.batch_shape

    @property
    def event_shape(self) -> torch.Size:
        """
        Returns the shape of the event of the distribution.
        Returns:
            torch.Size: The shape of the event of the distribution.
        """
        return self._categorical.event_shape

    @property
    def support(self) -> torch.Tensor:
        """
        Returns the support of the distribution.
        Returns:
            torch.Tensor: The support of the distribution.
        """
        return self._categorical.support

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """
        Draws a sample from the distribution.

        Args:
            sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

        Returns:
            torch.Tensor: The drawn sample.
        """
        return self._categorical.sample(sample_shape)

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """
        Reparameterized sampling using Gumbel-Softmax trick.
        """
        if torch.any(torch.isnan(self._categorical.logits)):
            raise Exception("NaN values in logits")
        samples = gumbel_softmax(
            logits=self._categorical.logits,
            shape=tuple(self._categorical.logits.shape),
            tau=self.temperature,
            hard=False,
        )
        return samples

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Calculate the log_prob of a value.

        Args:
            value (torch.Tensor): Input value.

        Returns:
            torch.Tensor: Log probability of the input value.
        """

        return self._categorical.log_prob(value)

    @constraints.dependent_property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        """
        Returns the argument constraints of the distribution.

        Returns:
            Dict[str, Constraint]: Constraint dictionary.
        """
        return self._categorical.arg_constraints

    @property
    def mode(self) -> torch.Tensor:
        """
        Returns the mode of the distribution.
        """
        return self._categorical.mode.unsqueeze(-1)

    def mean(self) -> torch.Tensor:
        """
        Returns the mean of the distribution.
        """
        return self._categorical.mean
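
A minimal usage sketch (assuming the package is importable as vambn and the module path matches the source path shown above):

import torch

from vambn.modelling.distributions.categorical import ReparameterizedCategorical

# Three categories, batch of two; logits are unnormalized log probabilities.
logits = torch.randn(2, 3, requires_grad=True)
dist = ReparameterizedCategorical(logits=logits, temperature=0.5)

hard = dist.sample()    # integer class indices, not differentiable
soft = dist.rsample()   # relaxed Gumbel-Softmax sample, differentiable w.r.t. logits

soft.sum().backward()
print(hard.shape, soft.shape, logits.grad.shape)  # (2,), (2, 3), (2, 3)

The relaxed sample is what enables gradient-based optimization through the categorical draw; the temperature controls how close it is to a hard one-hot vector.
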
batch_shape: torch.Size property

Returns the shape of the batch of distributions.

event_shape: torch.Size property

Returns the shape of the event of the distribution.

mode: torch.Tensor property

Returns the mode of the distribution.

param_shape: torch.Size property

Returns the shape of the parameter tensor.

support: torch.Tensor property

Returns the support of the distribution.

__init__(logits=None, probs=None, temperature=1.0)

Initialize the Reparameterized Categorical Distribution.

Parameters:

logits (Optional[torch.Tensor]): A tensor of logits (unnormalized log probabilities).
probs (Optional[torch.Tensor]): A tensor of probabilities.
temperature (float): A temperature parameter for the Gumbel-Softmax distribution.

Source code in vambn/modelling/distributions/categorical.py
def __init__(
    self,
    logits: Optional[torch.Tensor] = None,
    probs: Optional[torch.Tensor] = None,
    temperature: float = 1.0,
):
    """
    Initialize the Reparameterized Categorical Distribution.
    Args:
        logits (Optional[torch.Tensor]): A tensor of logits (unnormalized log probabilities).
        probs (Optional[torch.Tensor]): A tensor of probabilities.
        temperature (float): A temperature parameter for the Gumbel-Softmax distribution.
    """
    self._categorical = torch.distributions.Categorical(
        logits=logits, probs=probs
    )
    self.temperature = temperature
arg_constraints()

Returns the argument constraints of the distribution.

Returns:

Dict[str, Constraint]: Constraint dictionary.

Source code in vambn/modelling/distributions/categorical.py
@constraints.dependent_property
def arg_constraints(self) -> Dict[str, constraints.Constraint]:
    """
    Returns the argument constraints of the distribution.

    Returns:
        Dict[str, Constraint]: Constraint dictionary.
    """
    return self._categorical.arg_constraints
log_prob(value)

Calculate the log_prob of a value.

Parameters:

value (Tensor): Input value. Required.

Returns:

torch.Tensor: Log probability of the input value.

Source code in vambn/modelling/distributions/categorical.py
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
    """
    Calculate the log_prob of a value.

    Args:
        value (torch.Tensor): Input value.

    Returns:
        torch.Tensor: Log probability of the input value.
    """

    return self._categorical.log_prob(value)
mean()

Returns the mean of the distribution.

Source code in vambn/modelling/distributions/categorical.py
def mean(self) -> torch.Tensor:
    """
    Returns the mean of the distribution.
    """
    return self._categorical.mean
rsample(sample_shape=torch.Size())

Reparameterized sampling using Gumbel-Softmax trick.

Source code in vambn/modelling/distributions/categorical.py
def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
    """
    Reparameterized sampling using Gumbel-Softmax trick.
    """
    if torch.any(torch.isnan(self._categorical.logits)):
        raise Exception("NaN values in logits")
    samples = gumbel_softmax(
        logits=self._categorical.logits,
        shape=tuple(self._categorical.logits.shape),
        tau=self.temperature,
        hard=False,
    )
    return samples
sample(sample_shape=torch.Size())

Draws a sample from the distribution.

Parameters:

sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

Returns:

torch.Tensor: The drawn sample.

Source code in vambn/modelling/distributions/categorical.py
def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
    """
    Draws a sample from the distribution.

    Args:
        sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

    Returns:
        torch.Tensor: The drawn sample.
    """
    return self._categorical.sample(sample_shape)

gumbel_distribution

GumbelDistribution

Bases: ExpRelaxedCategorical

Gumbel distribution based on the ExpRelaxedCategorical distribution.

Source code in vambn/modelling/distributions/gumbel_distribution.py
class GumbelDistribution(ExpRelaxedCategorical):
    """
    Gumbel distribution based on the ExpRelaxedCategorical distribution.
    """

    @property
    def probs(self):
        """
        Returns the probabilities associated with the Gumbel distribution.

        Returns:
            torch.Tensor: The probabilities.
        """
        return torch.exp(self.logits).clip(1e-6, 1 - 1e-6)

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        """
        Draws a sample from the Gumbel distribution.

        Args:
            sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

        Returns:
            torch.Tensor: The drawn sample.
        """
        probs = self.probs.clip(1e-6, 1 - 1e-6)
        return OneHotCategorical(probs=probs).sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """
        Reparameterized sampling for the Gumbel distribution.

        Args:
            sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

        Returns:
            torch.Tensor: The reparameterized sample.
        """
        return torch.exp(super().rsample(sample_shape))

    @property
    def mean(self):
        """
        Returns the mean of the Gumbel distribution.

        Returns:
            torch.Tensor: The mean of the distribution.
        """
        return self.probs.clip(1e-6, 1 - 1e-6)

    @property
    def mode(self):
        """
        Returns the mode of the Gumbel distribution.

        Returns:
            torch.Tensor: The mode of the distribution.
        """
        probs = self.probs.clip(1e-6, 1 - 1e-6)
        return OneHotCategorical(probs=probs).mode

    def expand(self, batch_shape, _instance=None):
        """
        Expands the Gumbel distribution to the given batch shape.

        Args:
            batch_shape (torch.Size): The desired batch shape.
            _instance: The instance to expand.

        Returns:
            GumbelDistribution: The expanded Gumbel distribution.
        """
        return super().expand(batch_shape[:-1], _instance)

    def log_prob(self, value):
        """
        Calculates the log probability of a value under the Gumbel distribution.

        Args:
            value (torch.Tensor): The value for which to calculate the log probability.

        Returns:
            torch.Tensor: The log probability of the value.
        """
        probs = self.probs.clip(1e-6, 1 - 1e-6)
        return OneHotCategorical(probs=probs).log_prob(value)
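
A short sketch of how the class can be used (assuming the module path matches the source path above); note that the ExpRelaxedCategorical base class expects the temperature as its first argument:

import torch

from vambn.modelling.distributions.gumbel_distribution import GumbelDistribution

logits = torch.log(torch.tensor([0.1, 0.2, 0.3, 0.4]))
dist = GumbelDistribution(temperature=torch.tensor(0.5), logits=logits)

one_hot = dist.sample()   # hard one-hot draw via OneHotCategorical (no gradient)
relaxed = dist.rsample()  # relaxed sample on the probability simplex, differentiable
print(one_hot, relaxed.sum())  # a one-hot vector; relaxed entries sum to ~1
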
mean property

Returns the mean of the Gumbel distribution.

Returns:

torch.Tensor: The mean of the distribution.

mode property

Returns the mode of the Gumbel distribution.

Returns:

torch.Tensor: The mode of the distribution.

probs property

Returns the probabilities associated with the Gumbel distribution.

Returns:

torch.Tensor: The probabilities.

expand(batch_shape, _instance=None)

Expands the Gumbel distribution to the given batch shape.

Parameters:

batch_shape (torch.Size): The desired batch shape. Required.
_instance: The instance to expand. Defaults to None.

Returns:

GumbelDistribution: The expanded Gumbel distribution.

Source code in vambn/modelling/distributions/gumbel_distribution.py
def expand(self, batch_shape, _instance=None):
    """
    Expands the Gumbel distribution to the given batch shape.

    Args:
        batch_shape (torch.Size): The desired batch shape.
        _instance: The instance to expand.

    Returns:
        GumbelDistribution: The expanded Gumbel distribution.
    """
    return super().expand(batch_shape[:-1], _instance)
log_prob(value)

Calculates the log probability of a value under the Gumbel distribution.

Parameters:

value (Tensor): The value for which to calculate the log probability. Required.

Returns:

torch.Tensor: The log probability of the value.

Source code in vambn/modelling/distributions/gumbel_distribution.py
def log_prob(self, value):
    """
    Calculates the log probability of a value under the Gumbel distribution.

    Args:
        value (torch.Tensor): The value for which to calculate the log probability.

    Returns:
        torch.Tensor: The log probability of the value.
    """
    probs = self.probs.clip(1e-6, 1 - 1e-6)
    return OneHotCategorical(probs=probs).log_prob(value)
rsample(sample_shape=torch.Size())

Reparameterized sampling for the Gumbel distribution.

Parameters:

sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

Returns:

torch.Tensor: The reparameterized sample.

Source code in vambn/modelling/distributions/gumbel_distribution.py
def rsample(self, sample_shape=torch.Size()):
    """
    Reparameterized sampling for the Gumbel distribution.

    Args:
        sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

    Returns:
        torch.Tensor: The reparameterized sample.
    """
    return torch.exp(super().rsample(sample_shape))
sample(sample_shape=torch.Size())

Draws a sample from the Gumbel distribution.

Parameters:

sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

Returns:

torch.Tensor: The drawn sample.

Source code in vambn/modelling/distributions/gumbel_distribution.py
@torch.no_grad()
def sample(self, sample_shape=torch.Size()):
    """
    Draws a sample from the Gumbel distribution.

    Args:
        sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

    Returns:
        torch.Tensor: The drawn sample.
    """
    probs = self.probs.clip(1e-6, 1 - 1e-6)
    return OneHotCategorical(probs=probs).sample(sample_shape)

gumbel_softmax

gumbel_softmax(logits, shape, tau=1.0, hard=False)

Gumbel-Softmax implementation.

Parameters:

logits (Tensor): Logits to be used for the Gumbel-Softmax. Required.
shape (Tuple[int, int]): Shape of the logits. Required for torchscript.
tau (float, optional): Temperature factor. Defaults to 1.0.
hard (bool, optional): Hard sampling or soft. Defaults to False.

Returns:

torch.Tensor: Sampled categorical distribution.

Raises:

ValueError: If logits contain NaN values.

Source code in vambn/modelling/distributions/gumbel_softmax.py
def gumbel_softmax(
    logits: torch.Tensor,
    shape: Tuple[int, int],
    tau: float = 1.0,
    hard: bool = False,
) -> torch.Tensor:
    """
    Gumbel-Softmax implementation.

    Args:
        logits (torch.Tensor): Logits to be used for the Gumbel-Softmax.
        shape (Tuple[int, int]): Shape of the logits. Required for torchscript.
        tau (float, optional): Temperature factor. Defaults to 1.0.
        hard (bool, optional): Hard sampling or soft. Defaults to False.

    Returns:
        torch.Tensor: Sampled categorical distribution.

    Raises:
        ValueError: If logits contain NaN values.
    """
    if torch.isnan(logits).any():
        raise ValueError("Logits contain NaN values")

    gumbel_noise = sample_gumbel(shape, device=logits.device)
    y = logits + gumbel_noise
    tau = max(tau, 1e-9)
    y_soft = torch.softmax(y / (tau), dim=-1)

    if hard:
        _, ind = y_soft.max(dim=-1)
        y_hard = torch.zeros_like(y_soft).view(-1, shape[-1])
        y_hard.scatter_(1, ind.view(-1, 1), 1)
        y_hard = y_hard.view(shape[0], shape[1])
        y_soft = (y_hard - y_soft).detach() + y_soft

    return y_soft
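
A minimal sketch of soft versus hard (straight-through) sampling with this function, assuming the module path matches the source path above:

import torch

from vambn.modelling.distributions.gumbel_softmax import gumbel_softmax

logits = torch.randn(8, 5, requires_grad=True)  # batch of 8, 5 categories

soft = gumbel_softmax(logits, shape=(8, 5), tau=0.5, hard=False)  # rows sum to ~1
hard = gumbel_softmax(logits, shape=(8, 5), tau=0.5, hard=True)   # exact one-hot rows

# The hard path uses the straight-through estimator, so gradients still reach the logits.
hard.sum().backward()
print(soft.shape, hard.sum(dim=-1), logits.grad.shape)
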

sample_gumbel(shape, eps=1e-09, device=torch.device('cpu'))

Generate a sample from the Gumbel distribution.

Parameters:

shape (Tuple[int, int]): Shape of the sample. Required.
eps (float, optional): Value to be added to avoid numerical issues. Defaults to 1e-9.
device (torch.device, optional): The device to generate the sample on. Defaults to torch.device("cpu").

Returns:

torch.Tensor: Sample from the Gumbel distribution.

Source code in vambn/modelling/distributions/gumbel_softmax.py
def sample_gumbel(
    shape: Tuple[int, int],
    eps: float = 1e-9,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """
    Generate a sample from the Gumbel distribution.

    Args:
        shape (Tuple[int, int]): Shape of the sample.
        eps (float, optional): Value to be added to avoid numerical issues. Defaults to 1e-9.
        device (torch.device, optional): The device to generate the sample on. Defaults to torch.device("cpu").

    Returns:
        torch.Tensor: Sample from the Gumbel distribution.
    """
    U = torch.rand(shape, device=device)
    return -torch.log(-torch.log(U + eps) + eps)
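
For reference, the function implements the standard inverse-transform trick G = -log(-log(U)) with U ~ Uniform(0, 1); a tiny sketch (module path assumed from the source path above):

import torch

from vambn.modelling.distributions.gumbel_softmax import sample_gumbel

noise = sample_gumbel((1000, 5))
print(noise.shape, noise.mean())  # torch.Size([1000, 5]); mean is roughly 0.577 (the Euler-Mascheroni constant)
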

parameters

CategoricalParameters dataclass

Bases: Parameters

Dataclass for categorical output

Source code in vambn/modelling/distributions/parameters.py
@dataclass
class CategoricalParameters(Parameters):
    """Dataclass for categorical output"""

    logits: Tensor
    name: str = field(kw_only=True, default="cat")

    @property
    def device(self) -> torch.device:
        return self.logits.device

    @property
    def probs(self) -> Tensor:
        return logits_to_probs(self.logits)

    def __str__(self) -> str:
        return f"CatOutput: Avg. logits {self.logits.mean()}; shape {self.logits.shape}"

NormalParameters dataclass

Bases: Parameters

Dataclass for real output

Source code in vambn/modelling/distributions/parameters.py
@dataclass
class NormalParameters(Parameters):
    """Dataclass for real output"""

    loc: Tensor
    scale: Tensor
    name: str = field(kw_only=True, default="real")

    @property
    def device(self) -> torch.device:
        return self.loc.device

    def __str__(self) -> str:
        return f"RealOutput: Avg. Mean {self.loc.mean()}; Avg. Std {self.scale.mean()}); shape {self.loc.shape}"

Parameters dataclass

Bases: ABC

Dataclass for parameter output

Source code in vambn/modelling/distributions/parameters.py
@dataclass
class Parameters(ABC):
    """Dataclass for parameter output"""

    name: str = field(kw_only=True, default="")

    @property
    @abstractmethod
    def device(self) -> torch.device:
        pass

PoissonParameters dataclass

Bases: Parameters

Dataclass for count output

Source code in vambn/modelling/distributions/parameters.py
@dataclass
class PoissonParameters(Parameters):
    """Dataclass for count output"""

    rate: Tensor
    name: str = field(kw_only=True, default="count")

    @property
    def device(self) -> torch.device:
        return self.rate.device

    def __str__(self) -> str:
        return f"CountOutput: Avg. Lambda {self.rate.mean()}; shape {self.rate.shape}"

truncated_normal

TruncatedNormal

Bases: Normal

TruncatedNormal distribution with support [low, high].

This class extends the Normal distribution to support truncation, i.e., the distribution is limited to the range [low, high].

Source code in vambn/modelling/distributions/truncated_normal.py
class TruncatedNormal(Normal):
    """
    TruncatedNormal distribution with support [low, high].

    This class extends the Normal distribution to support truncation, i.e., the
    distribution is limited to the range [low, high].
    """

    arg_constraints = {
        "loc": constraints.real,
        "scale": constraints.positive,
        "low": constraints.real,
        "high": constraints.real,
    }

    def __init__(
        self,
        loc: torch.Tensor,
        scale: torch.Tensor,
        low: Optional[torch.Tensor] = -torch.tensor(float("inf")),
        high: Optional[torch.Tensor] = torch.tensor(float("inf")),
        validate_args: Optional[bool] = None,
    ):
        """
        Initialize the TruncatedNormal distribution.

        Args:
            loc (torch.Tensor): The mean of the normal distribution.
            scale (torch.Tensor): The standard deviation of the normal distribution.
            low (Optional[torch.Tensor]): The lower bound of the truncation. Defaults to -inf.
            high (Optional[torch.Tensor]): The upper bound of the truncation. Defaults to inf.
            validate_args (Optional[bool]): Whether to validate arguments. Defaults to None.
        """
        self.low = low
        self.high = high
        super().__init__(loc, scale, validate_args=validate_args)

    def _clamp(self, x):
        """
        Clamp the values to the range [low, high].

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The clamped tensor.
        """
        clamped_x = torch.clamp(x, self.low, self.high)
        x_mask = (x < self.low) | (x > self.high)
        x_fill = torch.where(x < self.low, self.low, self.high)
        return torch.where(x_mask, x_fill, clamped_x)

    def sample(self, sample_shape=torch.Size()):
        """
        Draws a sample from the distribution.

        Args:
            sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

        Returns:
            torch.Tensor: The drawn sample.
        """
        with torch.no_grad():
            return self._clamp(super().sample(sample_shape))

    def rsample(self, sample_shape=torch.Size()):
        """
        Draws a reparameterized sample from the distribution.

        Args:
            sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

        Returns:
            torch.Tensor: The reparameterized sample.
        """
        with torch.no_grad():
            return self._clamp(super().rsample(sample_shape))

    def log_prob(self, value):
        """
        Calculate the log probability of a value.

        Args:
            value (torch.Tensor): Input value.

        Returns:
            torch.Tensor: Log probability of the input value.
        """
        if self._validate_args:
            self._validate_sample(value)
        log_prob = super().log_prob(value)
        log_prob = torch.where(
            (value < self.low) | (value > self.high),
            torch.log(torch.tensor(1e-12)),
            log_prob,
        )
        normalizer = torch.log(self.cdf(self.high) - self.cdf(self.low))
        return log_prob - normalizer

    def cdf(self, value):
        """
        Calculate the cumulative distribution function (CDF) of a value.

        Args:
            value (torch.Tensor): Input value.

        Returns:
            torch.Tensor: CDF of the input value.
        """
        if self._validate_args:
            self._validate_sample(value)
        cdf = super().cdf(value)
        low_cdf = super().cdf(self.low)
        high_cdf = super().cdf(self.high)
        return (cdf - low_cdf) / (high_cdf - low_cdf)

    def icdf(self, value):
        """
        Calculate the inverse cumulative distribution function (ICDF) of a value.

        Args:
            value (torch.Tensor): Input value.

        Returns:
            torch.Tensor: ICDF of the input value.
        """
        if self._validate_args:
            self._validate_sample(value)
        low_cdf = super().cdf(self.low)
        high_cdf = super().cdf(self.high)
        rescaled_value = low_cdf + (high_cdf - low_cdf) * value
        return super().icdf(rescaled_value)
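
A minimal usage sketch (module path assumed from the source path above):

import torch

from vambn.modelling.distributions.truncated_normal import TruncatedNormal

dist = TruncatedNormal(
    loc=torch.tensor(0.0),
    scale=torch.tensor(1.0),
    low=torch.tensor(-1.0),
    high=torch.tensor(2.0),
)

x = dist.sample(torch.Size([1000]))
print(x.min() >= -1.0, x.max() <= 2.0)   # samples are clamped to [low, high]
print(dist.cdf(torch.tensor(2.0)))       # ~1.0 at the upper bound
print(dist.log_prob(torch.tensor(5.0)))  # heavily penalized outside the support
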
__init__(loc, scale, low=-torch.tensor(float('inf')), high=torch.tensor(float('inf')), validate_args=None)

Initialize the TruncatedNormal distribution.

Parameters:

loc (Tensor): The mean of the normal distribution. Required.
scale (Tensor): The standard deviation of the normal distribution. Required.
low (Optional[Tensor]): The lower bound of the truncation. Defaults to -inf.
high (Optional[Tensor]): The upper bound of the truncation. Defaults to inf.
validate_args (Optional[bool]): Whether to validate arguments. Defaults to None.

Source code in vambn/modelling/distributions/truncated_normal.py
def __init__(
    self,
    loc: torch.Tensor,
    scale: torch.Tensor,
    low: Optional[torch.Tensor] = -torch.tensor(float("inf")),
    high: Optional[torch.Tensor] = torch.tensor(float("inf")),
    validate_args: Optional[bool] = None,
):
    """
    Initialize the TruncatedNormal distribution.

    Args:
        loc (torch.Tensor): The mean of the normal distribution.
        scale (torch.Tensor): The standard deviation of the normal distribution.
        low (Optional[torch.Tensor]): The lower bound of the truncation. Defaults to -inf.
        high (Optional[torch.Tensor]): The upper bound of the truncation. Defaults to inf.
        validate_args (Optional[bool]): Whether to validate arguments. Defaults to None.
    """
    self.low = low
    self.high = high
    super().__init__(loc, scale, validate_args=validate_args)
cdf(value)

Calculate the cumulative distribution function (CDF) of a value.

Parameters:

value (Tensor): Input value. Required.

Returns:

torch.Tensor: CDF of the input value.

Source code in vambn/modelling/distributions/truncated_normal.py
def cdf(self, value):
    """
    Calculate the cumulative distribution function (CDF) of a value.

    Args:
        value (torch.Tensor): Input value.

    Returns:
        torch.Tensor: CDF of the input value.
    """
    if self._validate_args:
        self._validate_sample(value)
    cdf = super().cdf(value)
    low_cdf = super().cdf(self.low)
    high_cdf = super().cdf(self.high)
    return (cdf - low_cdf) / (high_cdf - low_cdf)
icdf(value)

Calculate the inverse cumulative distribution function (ICDF) of a value.

Parameters:

value (Tensor): Input value. Required.

Returns:

torch.Tensor: ICDF of the input value.

Source code in vambn/modelling/distributions/truncated_normal.py
def icdf(self, value):
    """
    Calculate the inverse cumulative distribution function (ICDF) of a value.

    Args:
        value (torch.Tensor): Input value.

    Returns:
        torch.Tensor: ICDF of the input value.
    """
    if self._validate_args:
        self._validate_sample(value)
    low_cdf = super().cdf(self.low)
    high_cdf = super().cdf(self.high)
    rescaled_value = low_cdf + (high_cdf - low_cdf) * value
    return super().icdf(rescaled_value)
log_prob(value)

Calculate the log probability of a value.

Parameters:

value (Tensor): Input value. Required.

Returns:

torch.Tensor: Log probability of the input value.

Source code in vambn/modelling/distributions/truncated_normal.py
def log_prob(self, value):
    """
    Calculate the log probability of a value.

    Args:
        value (torch.Tensor): Input value.

    Returns:
        torch.Tensor: Log probability of the input value.
    """
    if self._validate_args:
        self._validate_sample(value)
    log_prob = super().log_prob(value)
    log_prob = torch.where(
        (value < self.low) | (value > self.high),
        torch.log(torch.tensor(1e-12)),
        log_prob,
    )
    normalizer = torch.log(self.cdf(self.high) - self.cdf(self.low))
    return log_prob - normalizer
rsample(sample_shape=torch.Size())

Draws a reparameterized sample from the distribution.

Parameters:

sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

Returns:

torch.Tensor: The reparameterized sample.

Source code in vambn/modelling/distributions/truncated_normal.py
def rsample(self, sample_shape=torch.Size()):
    """
    Draws a reparameterized sample from the distribution.

    Args:
        sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

    Returns:
        torch.Tensor: The reparameterized sample.
    """
    with torch.no_grad():
        return self._clamp(super().rsample(sample_shape))
sample(sample_shape=torch.Size())

Draws a sample from the distribution.

Parameters:

sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

Returns:

torch.Tensor: The drawn sample.

Source code in vambn/modelling/distributions/truncated_normal.py
def sample(self, sample_shape=torch.Size()):
    """
    Draws a sample from the distribution.

    Args:
        sample_shape (torch.Size, optional): The shape of the sample to draw. Defaults to torch.Size().

    Returns:
        torch.Tensor: The drawn sample.
    """
    with torch.no_grad():
        return self._clamp(super().sample(sample_shape))

models

config

DataModuleConfig dataclass

Configuration for the data module.

Attributes:

name (str): The name of the data module.
variable_types (VarTypes): Types of variables used in the data module.
num_timepoints (int): The number of timepoints. Must be at least 1.
n_layers (Optional[int]): The number of layers. Defaults to None.
noise_size (Optional[int]): The size of the noise. Defaults to None.

Source code in vambn/modelling/models/config.py
@dataclass
class DataModuleConfig:
    """Configuration for the data module.

    Attributes:
        name (str): The name of the data module.
        variable_types (VarTypes): Types of variables used in the data module.
        num_timepoints (int): The number of timepoints. Must be at least 1.
        n_layers (Optional[int]): The number of layers. Defaults to None.
        noise_size (Optional[int]): The size of the noise. Defaults to None.
    """

    name: str
    variable_types: VarTypes
    num_timepoints: int
    n_layers: Optional[int] = None
    noise_size: Optional[int] = None

    def __post_init__(self):
        """Validate that the number of timepoints is at least 1.

        Raises:
            Exception: If the number of timepoints is less than 1.
        """
        if self.num_timepoints < 1:
            raise Exception("Number of timepoints must be at least 1")

    @cached_property
    def input_dim(self):
        """Get the input dimension based on variable types.

        Returns:
            int: The input dimension.
        """
        return get_input_dim(self.variable_types)

    @cached_property
    def is_longitudinal(self) -> bool:
        """Check if the data is longitudinal.

        Returns:
            bool: True if the number of timepoints is greater than 1, False otherwise.
        """
        return self.num_timepoints > 1
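
A minimal sketch of the validation and is_longitudinal logic. Note that variable_types=None is used here purely as a stand-in, since constructing a real VarTypes instance is outside this section; with None, input_dim must not be accessed, because it delegates to get_input_dim(variable_types).

from vambn.modelling.models.config import DataModuleConfig

# variable_types=None is a placeholder for illustration only; pass a real VarTypes in practice.
config = DataModuleConfig(name="clinical", variable_types=None, num_timepoints=3)
print(config.is_longitudinal)  # True, since num_timepoints > 1

try:
    DataModuleConfig(name="static", variable_types=None, num_timepoints=0)
except Exception as err:
    print(err)  # "Number of timepoints must be at least 1"
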
input_dim cached property

Get the input dimension based on variable types.

Returns:

int: The input dimension.

is_longitudinal: bool cached property

Check if the data is longitudinal.

Returns:

bool: True if the number of timepoints is greater than 1, False otherwise.

__post_init__()

Validate that the number of timepoints is at least 1.

Raises:

Exception: If the number of timepoints is less than 1.

Source code in vambn/modelling/models/config.py
def __post_init__(self):
    """Validate that the number of timepoints is at least 1.

    Raises:
        Exception: If the number of timepoints is less than 1.
    """
    if self.num_timepoints < 1:
        raise Exception("Number of timepoints must be at least 1")

ModelConfig dataclass

Configuration for the model.

Source code in vambn/modelling/models/config.py
@dataclass
class ModelConfig:
    """Configuration for the model."""

    pass

conversion

Conversion

Source code in vambn/modelling/models/conversion.py
class Conversion:
    @staticmethod
    def _encode_categorical(
        x: Tensor,
        mask: Tensor,
        n: int,
    ) -> Tuple[Tensor, Tensor]:
        """Encode categorical data into one-hot vectors.

        Args:
            x (Tensor): The input tensor containing categorical data.
            mask (Tensor): The mask tensor indicating valid data points.
            n (int): The number of categories.

        Returns:
            Tuple[Tensor, Tensor]: A tuple containing the one-hot encoded data tensor and the updated mask tensor.
        """
        ohe_data = torch.zeros((mask.shape[0], n), device=x.device)
        ohe_data[mask == 1, :] = torch.nn.functional.one_hot(
            torch.masked_select(x, mask.bool()).long(), num_classes=n
        ).float()
        ohe_mask = mask.view(-1, 1).repeat(1, n)
        return ohe_data, ohe_mask

    @staticmethod
    def cat_to_one_hot(
        variable_types: List[VariableType],
        x: Tensor,
        mask: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Normalize the input data batch-wise for various data types.

        Args:
            variable_types (List[VariableType]): List of variable types indicating the data type of each variable.
            x (Tensor): The input tensor containing data to be normalized.
            mask (Tensor): The mask tensor indicating valid data points.

        Returns:
            Tuple[Tensor, Tensor]: A tuple containing the normalized data tensor and the updated mask tensor.
        """
        normalized_data, normalized_mask = [], []
        for i, (var_type, d, m) in enumerate(zip(variable_types, x.T, mask.T)):
            if var_type.data_type == "cat":
                normalized, mask_i = Conversion._encode_categorical(
                    x=d, mask=m, n=var_type.input_dim
                )
            else:
                normalized, mask_i = d, m

            normalized_data.append(normalized)
            normalized_mask.append(mask_i)

        if any([x.ndim == 2 for x in normalized_data]):
            normalized_data = [
                x if x.ndim == 2 else x.unsqueeze(1) for x in normalized_data
            ]
            normalized_mask = [
                x if x.ndim == 2 else x.unsqueeze(1) for x in normalized_mask
            ]
            new_x = torch.cat(normalized_data, dim=1).float()
            new_mask = torch.cat(normalized_mask, dim=1)
        else:
            new_x = torch.stack(normalized_data, dim=1).float()
            new_mask = torch.stack(normalized_mask, dim=1)
        new_x *= new_mask

        return new_x, new_mask
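
The categorical branch of cat_to_one_hot is handled by _encode_categorical; a small sketch of that building block (module path assumed from the source path above):

import torch

from vambn.modelling.models.conversion import Conversion

x = torch.tensor([2.0, 0.0, 1.0, 0.0])     # categorical codes in {0, 1, 2}
mask = torch.tensor([1.0, 1.0, 0.0, 1.0])  # third observation is missing

ohe, ohe_mask = Conversion._encode_categorical(x=x, mask=mask, n=3)
print(ohe)       # one-hot rows for observed entries, an all-zero row where mask == 0
print(ohe_mask)  # the mask repeated across the one-hot width, shape (4, 3)
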
cat_to_one_hot(variable_types, x, mask) staticmethod

Normalize the input data batch-wise for various data types.

Parameters:

variable_types (List[VariableType]): List of variable types indicating the data type of each variable. Required.
x (Tensor): The input tensor containing data to be normalized. Required.
mask (Tensor): The mask tensor indicating valid data points. Required.

Returns:

Tuple[Tensor, Tensor]: A tuple containing the normalized data tensor and the updated mask tensor.

Source code in vambn/modelling/models/conversion.py
@staticmethod
def cat_to_one_hot(
    variable_types: List[VariableType],
    x: Tensor,
    mask: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Normalize the input data batch-wise for various data types.

    Args:
        variable_types (List[VariableType]): List of variable types indicating the data type of each variable.
        x (Tensor): The input tensor containing data to be normalized.
        mask (Tensor): The mask tensor indicating valid data points.

    Returns:
        Tuple[Tensor, Tensor]: A tuple containing the normalized data tensor and the updated mask tensor.
    """
    normalized_data, normalized_mask = [], []
    for i, (var_type, d, m) in enumerate(zip(variable_types, x.T, mask.T)):
        if var_type.data_type == "cat":
            normalized, mask_i = Conversion._encode_categorical(
                x=d, mask=m, n=var_type.input_dim
            )
        else:
            normalized, mask_i = d, m

        normalized_data.append(normalized)
        normalized_mask.append(mask_i)

    if any([x.ndim == 2 for x in normalized_data]):
        normalized_data = [
            x if x.ndim == 2 else x.unsqueeze(1) for x in normalized_data
        ]
        normalized_mask = [
            x if x.ndim == 2 else x.unsqueeze(1) for x in normalized_mask
        ]
        new_x = torch.cat(normalized_data, dim=1).float()
        new_mask = torch.cat(normalized_mask, dim=1)
    else:
        new_x = torch.stack(normalized_data, dim=1).float()
        new_mask = torch.stack(normalized_mask, dim=1)
    new_x *= new_mask

    return new_x, new_mask

gan

Classifier

Bases: Module

Source code in vambn/modelling/models/gan.py
class Classifier(nn.Module):
    @typeguard.typechecked
    def __init__(
        self,
        input_dim: int,
        layer_sizes: Tuple[int, ...],
        output_dim: int = 1,
        activation: nn.Module = nn.ReLU(),
    ) -> None:
        """Initialize the Classifier network.

        Args:
            input_dim (int): The dimension of the input tensor.
            layer_sizes (Tuple[int, ...]): A tuple specifying the sizes of the hidden layers.
            output_dim (int, optional): The dimension of the output tensor. Defaults to 1.
            activation (nn.Module, optional): The activation function to use. Defaults to nn.ReLU().
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layer_sizes = layer_sizes
        self.activation = activation

        layers = [
            nn.Linear(input_dim, layer_sizes[0]),
            nn.BatchNorm1d(layer_sizes[0], eps=1e-5, momentum=0.1),
            activation,
        ]

        for i in range(len(layer_sizes) - 1):
            layers += [
                nn.Linear(layer_sizes[i], layer_sizes[i + 1]),
                nn.BatchNorm1d(layer_sizes[i + 1], eps=1e-5, momentum=0.1),
                activation,
            ]

        layers += [
            nn.Linear(layer_sizes[-1], output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the Classifier network.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor with sigmoid activation.
        """
        return F.sigmoid(self.net(x))
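
A minimal instantiation sketch (module path assumed from the source path above):

import torch

from vambn.modelling.models.gan import Classifier

clf = Classifier(input_dim=16, layer_sizes=(32, 16), output_dim=1)
clf.eval()  # BatchNorm1d layers use running statistics in eval mode

x = torch.randn(4, 16)
probs = clf(x)
print(probs.shape, probs.min() > 0, probs.max() < 1)  # torch.Size([4, 1]); sigmoid outputs in (0, 1)
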
__init__(input_dim, layer_sizes, output_dim=1, activation=nn.ReLU())

Initialize the Classifier network.

Parameters:

input_dim (int): The dimension of the input tensor. Required.
layer_sizes (Tuple[int, ...]): A tuple specifying the sizes of the hidden layers. Required.
output_dim (int, optional): The dimension of the output tensor. Defaults to 1.
activation (nn.Module, optional): The activation function to use. Defaults to nn.ReLU().

Source code in vambn/modelling/models/gan.py
@typeguard.typechecked
def __init__(
    self,
    input_dim: int,
    layer_sizes: Tuple[int, ...],
    output_dim: int = 1,
    activation: nn.Module = nn.ReLU(),
) -> None:
    """Initialize the Classifier network.

    Args:
        input_dim (int): The dimension of the input tensor.
        layer_sizes (Tuple[int, ...]): A tuple specifying the sizes of the hidden layers.
        output_dim (int, optional): The dimension of the output tensor. Defaults to 1.
        activation (nn.Module, optional): The activation function to use. Defaults to nn.ReLU().
    """
    super().__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.layer_sizes = layer_sizes
    self.activation = activation

    layers = [
        nn.Linear(input_dim, layer_sizes[0]),
        nn.BatchNorm1d(layer_sizes[0], eps=1e-5, momentum=0.1),
        activation,
    ]

    for i in range(len(layer_sizes) - 1):
        layers += [
            nn.Linear(layer_sizes[i], layer_sizes[i + 1]),
            nn.BatchNorm1d(layer_sizes[i + 1], eps=1e-5, momentum=0.1),
            activation,
        ]

    layers += [
        nn.Linear(layer_sizes[-1], output_dim),
    ]
    self.net = nn.Sequential(*layers)
forward(x)

Forward pass through the Classifier network.

Parameters:

x (Tensor): The input tensor. Required.

Returns:

torch.Tensor: The output tensor with sigmoid activation.

Source code in vambn/modelling/models/gan.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass through the Classifier network.

    Args:
        x (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The output tensor with sigmoid activation.
    """
    return F.sigmoid(self.net(x))

Discriminator

Bases: Module

Source code in vambn/modelling/models/gan.py
class Discriminator(nn.Module):
    @typeguard.typechecked
    def __init__(
        self,
        input_dim: int,
        layer_sizes: Tuple[int, ...],
        output_dim: int = 1,
        activation: nn.Module = nn.LeakyReLU(),
    ) -> None:
        """Initialize the Discriminator network.

        Args:
            input_dim (int): The dimension of the input tensor.
            layer_sizes (Tuple[int, ...]): A tuple specifying the sizes of the hidden layers.
            output_dim (int, optional): The dimension of the output tensor. Defaults to 1.
            activation (nn.Module, optional): The activation function to use. Defaults to nn.LeakyReLU().
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layer_sizes = layer_sizes
        self.activation = activation

        layers = [
            nn.Linear(input_dim, layer_sizes[0]),
            nn.BatchNorm1d(layer_sizes[0], eps=1e-5, momentum=0.1),
            activation,
        ]

        for i in range(len(layer_sizes) - 1):
            layers += [
                nn.Linear(layer_sizes[i], layer_sizes[i + 1]),
                nn.BatchNorm1d(layer_sizes[i + 1], eps=1e-5, momentum=0.1),
                activation,
            ]

        layers += [
            nn.Linear(layer_sizes[-1], output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the Discriminator network.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor.
        """
        return torch.mean(self.net(x))
__init__(input_dim, layer_sizes, output_dim=1, activation=nn.LeakyReLU())

Initialize the Discriminator network.

Parameters:

input_dim (int): The dimension of the input tensor. Required.
layer_sizes (Tuple[int, ...]): A tuple specifying the sizes of the hidden layers. Required.
output_dim (int, optional): The dimension of the output tensor. Defaults to 1.
activation (nn.Module, optional): The activation function to use. Defaults to nn.LeakyReLU().

Source code in vambn/modelling/models/gan.py
@typeguard.typechecked
def __init__(
    self,
    input_dim: int,
    layer_sizes: Tuple[int, ...],
    output_dim: int = 1,
    activation: nn.Module = nn.LeakyReLU(),
) -> None:
    """Initialize the Discriminator network.

    Args:
        input_dim (int): The dimension of the input tensor.
        layer_sizes (Tuple[int, ...]): A tuple specifying the sizes of the hidden layers.
        output_dim (int, optional): The dimension of the output tensor. Defaults to 1.
        activation (nn.Module, optional): The activation function to use. Defaults to nn.LeakyReLU().
    """
    super().__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.layer_sizes = layer_sizes
    self.activation = activation

    layers = [
        nn.Linear(input_dim, layer_sizes[0]),
        nn.BatchNorm1d(layer_sizes[0], eps=1e-5, momentum=0.1),
        activation,
    ]

    for i in range(len(layer_sizes) - 1):
        layers += [
            nn.Linear(layer_sizes[i], layer_sizes[i + 1]),
            nn.BatchNorm1d(layer_sizes[i + 1], eps=1e-5, momentum=0.1),
            activation,
        ]

    layers += [
        nn.Linear(layer_sizes[-1], output_dim),
    ]
    self.net = nn.Sequential(*layers)
forward(x)

Forward pass through the Discriminator network.

Parameters:

x (Tensor): The input tensor. Required.

Returns:

torch.Tensor: The output tensor.

Source code in vambn/modelling/models/gan.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass through the Discriminator network.

    Args:
        x (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The output tensor.
    """
    return torch.mean(self.net(x))

Generator

Bases: Module

Source code in vambn/modelling/models/gan.py
class Generator(nn.Module):
    @typeguard.typechecked
    def __init__(
        self,
        input_dim: int,
        layer_sizes: Tuple[int, ...],
        output_dim: int,
        activation: nn.Module = nn.ReLU(),
    ) -> None:
        """Initialize the Generator network.

        Args:
            input_dim (int): The dimension of the input tensor.
            layer_sizes (Tuple[int, ...]): A tuple specifying the sizes of the hidden layers.
            output_dim (int): The dimension of the output tensor.
            activation (nn.Module, optional): The activation function to use. Defaults to nn.ReLU().
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layer_sizes = layer_sizes
        self.activation = activation

        layers = [
            nn.Linear(input_dim, layer_sizes[0]),
            nn.BatchNorm1d(layer_sizes[0], eps=1e-5, momentum=0.1),
            activation,
        ]

        for i in range(len(layer_sizes) - 1):
            layers += [
                nn.Linear(layer_sizes[i], layer_sizes[i + 1]),
                nn.BatchNorm1d(layer_sizes[i + 1], eps=1e-5, momentum=0.1),
                activation,
            ]

        layers += [
            nn.Linear(layer_sizes[-1], output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the Generator network.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor.
        """
        return self.net(x)
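
A short sketch wiring the Generator and Discriminator together (module path assumed from the source path above); note that Discriminator.forward reduces its output to a scalar via torch.mean:

import torch

from vambn.modelling.models.gan import Discriminator, Generator

gen = Generator(input_dim=8, layer_sizes=(32, 32), output_dim=20)
disc = Discriminator(input_dim=20, layer_sizes=(32, 16), output_dim=1)
gen.eval()
disc.eval()

noise = torch.randn(4, 8)
fake = gen(noise)   # shape (4, 20)
score = disc(fake)  # 0-dimensional tensor (mean over the batch)
print(fake.shape, score.shape)  # torch.Size([4, 20]) torch.Size([])
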
__init__(input_dim, layer_sizes, output_dim, activation=nn.ReLU())

Initialize the Generator network.

Parameters:

input_dim (int): The dimension of the input tensor. Required.
layer_sizes (Tuple[int, ...]): A tuple specifying the sizes of the hidden layers. Required.
output_dim (int): The dimension of the output tensor. Required.
activation (nn.Module, optional): The activation function to use. Defaults to nn.ReLU().

Source code in vambn/modelling/models/gan.py
@typeguard.typechecked
def __init__(
    self,
    input_dim: int,
    layer_sizes: Tuple[int, ...],
    output_dim: int,
    activation: nn.Module = nn.ReLU(),
) -> None:
    """Initialize the Generator network.

    Args:
        input_dim (int): The dimension of the input tensor.
        layer_sizes (Tuple[int, ...]): A tuple specifying the sizes of the hidden layers.
        output_dim (int): The dimension of the output tensor.
        activation (nn.Module, optional): The activation function to use. Defaults to nn.ReLU().
    """
    super().__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.layer_sizes = layer_sizes
    self.activation = activation

    layers = [
        nn.Linear(input_dim, layer_sizes[0]),
        nn.BatchNorm1d(layer_sizes[0], eps=1e-5, momentum=0.1),
        activation,
    ]

    for i in range(len(layer_sizes) - 1):
        layers += [
            nn.Linear(layer_sizes[i], layer_sizes[i + 1]),
            nn.BatchNorm1d(layer_sizes[i + 1], eps=1e-5, momentum=0.1),
            activation,
        ]

    layers += [
        nn.Linear(layer_sizes[-1], output_dim),
    ]
    self.net = nn.Sequential(*layers)
forward(x)

Forward pass through the Generator network.

Parameters:

x (Tensor): The input tensor. Required.

Returns:

torch.Tensor: The output tensor.

Source code in vambn/modelling/models/gan.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass through the Generator network.

    Args:
        x (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The output tensor.
    """
    return self.net(x)

hivae

config

HivaeConfig dataclass

Bases: ModelConfig

Configuration class for HIVAE models.

Attributes:

name (str): Name of the configuration. Typically the module name.
variable_types (VarTypes): Definition of the variable types. See VarTypes.
dim_s (int): Dimension of the latent space S.
dim_y (int): Dimension of the Y space.
dim_z (int): Dimension of the latent space Z.
mtl_methods (Tuple[str, ...]): Methods for multi-task learning. Tested possibilities are combinations of "identity", "gradnorm", "graddrop". Further implementations and details can be found in the mtl.py file.
use_imputation_layer (bool): Flag to use imputation layer.
dropout (float): Dropout rate.
n_layers (Optional[int]): Number of layers. Needed for longitudinal data.
num_timepoints (int): Number of timepoints for longitudinal data. Default is 1.

Source code in vambn/modelling/models/hivae/config.py
@dataclass
class HivaeConfig(ModelConfig):
    """Configuration class for HIVAE models.

    Attributes:
        name (str): Name of the configuration. Typically the module name.
        variable_types (VarTypes): Definition of the variable types. See VarTypes.
        dim_s (int): Dimension of the latent space S.
        dim_y (int): Dimension of the Y space.
        dim_z (int): Dimension of the latent space Z.
        mtl_methods (Tuple[str, ...]): Methods for multi-task learning. Tested 
            possibilities are combinations of "identity", "gradnorm", "graddrop". 
            Further implementations and details can be found in the mtl.py file.
        use_imputation_layer (bool): Flag to use imputation layer.
        dropout (float): Dropout rate.
        n_layers (Optional[int]): Number of layers. Needed for longitudinal data.
        num_timepoints (int): Number of timepoints for longitudinal data. Default is 1.
    """
    name: str
    variable_types: VarTypes
    dim_s: int
    dim_y: int
    dim_z: int
    mtl_methods: Tuple[str, ...]
    use_imputation_layer: bool
    dropout: float

    # Only needed for longitudinal data
    n_layers: Optional[int] = None
    num_timepoints: int = 1

    def __post_init__(self):
        """Converts mtl_methods to a tuple if it is a list."""
        if isinstance(self.mtl_methods, List):
            self.mtl_methods = tuple(self.mtl_methods)

    @cached_property
    def input_dim(self):
        """Gets the input dimension based on variable types.

        Returns:
            int: The input dimension.
        """
        return get_input_dim(self.variable_types)

    @cached_property
    def is_longitudinal(self) -> bool:
        """Checks if the data is longitudinal.

        Returns:
            bool: True if the data has more than one timepoint, False otherwise.
        """
        return self.num_timepoints > 1
input_dim cached property

Gets the input dimension based on variable types.

Returns:

int: The input dimension.

is_longitudinal: bool cached property

Checks if the data is longitudinal.

Returns:

bool: True if the data has more than one timepoint, False otherwise.

__post_init__()

Converts mtl_methods to a tuple if it is a list.

Source code in vambn/modelling/models/hivae/config.py
def __post_init__(self):
    """Converts mtl_methods to a tuple if it is a list."""
    if isinstance(self.mtl_methods, List):
        self.mtl_methods = tuple(self.mtl_methods)
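
A quick check of this behaviour (a sketch reusing the import and the placeholder variable_types from the example further above; other values are illustrative):

config = HivaeConfig(
    name="module_a",
    variable_types=variable_types,
    dim_s=4,
    dim_y=8,
    dim_z=2,
    mtl_methods=["identity", "gradnorm"],   # a list is accepted on input ...
    use_imputation_layer=True,
    dropout=0.1,
)
assert config.mtl_methods == ("identity", "gradnorm")  # ... and stored as a tuple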
ModularHivaeConfig dataclass

Bases: ModelConfig

Configuration class for Modular HIVAE models.

Attributes:

Name Type Description
module_config Tuple[DataModuleConfig]

Configuration for the data modules. See DataModuleConfig.

dim_s int | Tuple[int, ...] | Dict[str, int]

Dimension of the latent space S.

dim_z int

Dimension of the latent space Z.

dim_ys int

Dimension of the YS space.

dim_y int | Tuple[int, ...] | Dict[str, int]

Dimension of the Y space.

mtl_method Tuple[str, ...]

Methods for multi-task learning. Tested possibilities are combinations of "identity", "gradnorm", "graddrop". Further implementations and details can be found in the mtl.py file.

use_imputation_layer bool

Flag to use imputation layer.

dropout float

Dropout rate.

n_layers int

Number of layers.

shared_element str

Shared element type. Possible values are "none", "sharedLinear", "concatMtl", "concatIndiv", "avgMtl", "maxMtl", "encoder", "encoderMtl". Default is "none".
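
A hypothetical instantiation (import path inferred from the source location shown below; module_a_cfg and module_b_cfg stand in for DataModuleConfig objects built elsewhere; values are illustrative):

from vambn.modelling.models.hivae.config import ModularHivaeConfig  # path assumed from the source location below

modular_config = ModularHivaeConfig(
    module_config=(module_a_cfg, module_b_cfg),  # pre-built DataModuleConfig instances
    dim_s=4,
    dim_z=2,
    dim_ys=3,
    dim_y=8,
    mtl_method=("identity",),
    use_imputation_layer=True,
    dropout=0.1,
    n_layers=2,                # must be >= 1, otherwise __post_init__ raises
    shared_element="avgMtl",
)
# __post_init__ propagates n_layers to every module config:
assert all(m.n_layers == 2 for m in modular_config.module_config)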

Source code in vambn/modelling/models/hivae/config.py
@dataclass
class ModularHivaeConfig(ModelConfig):
    """Configuration class for Modular HIVAE models.

    Attributes:
        module_config (Tuple[DataModuleConfig]): Configuration for the data modules. See DataModuleConfig.
        dim_s (int | Tuple[int, ...] | Dict[str, int]): Dimension of the latent space S.
        dim_z (int): Dimension of the latent space Z.
        dim_ys (int): Dimension of the YS space.
        dim_y (int | Tuple[int, ...] | Dict[str, int]): Dimension of the Y space.         
        mtl_method (Tuple[str, ...]): Methods for multi-task learning. Tested
            possibilities are combinations of "identity", "gradnorm", "graddrop".
            Further implementations and details can be found in the mtl.py file.
        use_imputation_layer (bool): Flag to use imputation layer.
        dropout (float): Dropout rate.
        n_layers (int): Number of layers.
        shared_element (str): Shared element type. Possible values are "none", 
            "sharedLinear", "concatMtl", "concatIndiv", "avgMtl", "maxMtl", "encoder", "encoderMtl".
            Default is "none".
    """    
    module_config: Tuple[DataModuleConfig]
    dim_s: int | Tuple[int, ...] | Dict[str, int]
    dim_z: int
    dim_ys: int
    dim_y: int | Tuple[int, ...] | Dict[str, int]
    mtl_method: Tuple[str, ...]
    use_imputation_layer: bool
    dropout: float
    n_layers: int
    shared_element: str = "none"

    def __post_init__(self):
        """Validates and sets up the configuration after initialization.

        Raises:
            Exception: If the number of layers is less than 1.
        """
        if self.n_layers < 1:
            raise Exception("Number of layers must be at least 1")

        for module in self.module_config:
            module.n_layers = self.n_layers
__post_init__()

Validates and sets up the configuration after initialization.

Raises:

Type Description
Exception

If the number of layers is less than 1.

Source code in vambn/modelling/models/hivae/config.py
def __post_init__(self):
    """Validates and sets up the configuration after initialization.

    Raises:
        Exception: If the number of layers is less than 1.
    """
    if self.n_layers < 1:
        raise Exception("Number of layers must be at least 1")

    for module in self.module_config:
        module.n_layers = self.n_layers

decoder

Decoder

Bases: Module

HIVAE Decoder class.

Parameters:

Name Type Description Default
variable_types VarTypes

List of VariableType objects. See VarTypes in data/dataclasses.py.

required
s_dim int

Dimension of s space.

required
z_dim int

Dimension of z space.

required
y_dim int

Dimension of y space.

required
mtl_method Tuple[str, ...]

List of methods to use for multi-task learning. Assessed possibilities are combinations of "identity", "gradnorm", "graddrop". Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).

('identity',)
decoder_shared Module

Shared decoder module. Defaults to nn.Identity().

Identity()

Attributes:

Name Type Description
prior_s_val Parameter

Prior distribution for s values.

prior_loc_z ModifiedLinear

Linear layer for z prior distribution.

s_dim int

Dimension of s space.

z_dim int

Dimension of z space.

y_dim int

Dimension of y space.

variable_types VarTypes

List of variable types.

decoder_shared Module

Shared decoder module.

internal_layer_norm Module

Layer normalization module.

heads ModuleList

List of head modules for each variable type.

mtl_methods Tuple[str, ...]

Methods for multi-task learning.

_mtl_module_y Module

Multi-task learning module for y.

_mtl_module_s Module

Multi-task learning module for s.

moo_block MultiMOOForLoop

Multi-task learning block.

_decoding bool

Decoding flag.
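
The prior structure used by the decoder can be reproduced in isolation (a sketch in plain torch; nn.Linear stands in for the ModifiedLinear layer used in the actual class):

import torch
import torch.nn as nn

s_dim, z_dim, batch = 3, 2, 4
prior_s_val = torch.ones(s_dim) / s_dim                       # uniform prior over s
prior_s = torch.distributions.OneHotCategorical(probs=prior_s_val)
prior_loc_z = nn.Linear(s_dim, z_dim)                         # stand-in for ModifiedLinear
samples_s = prior_s.sample((batch,))                          # (batch, s_dim) one-hot samples
prior_z = torch.distributions.Normal(                         # z prior: N(loc(s), 1)
    prior_loc_z(samples_s), torch.ones(batch, z_dim)
)

Each variable is then handled by its own likelihood head (selected from HEAD_DICT by the variable's data type), and the per-head losses are combined through the multi-task learning block.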

Source code in vambn/modelling/models/hivae/decoder.py
class Decoder(nn.Module):
    """HIVAE Decoder class.

    Args:
        variable_types (VarTypes): List of VariableType objects. See VarTypes in
            data/dataclasses.py.
        s_dim (int): Dimension of s space.
        z_dim (int): Dimension of z space.
        y_dim (int): Dimension of y space.
        mtl_method (Tuple[str, ...], optional): List of methods to use for multi-task learning.
            Assessed possibilities are combinations of "identity", "gradnorm", "graddrop".
            Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).
        decoder_shared (nn.Module, optional): Shared decoder module. Defaults to nn.Identity().

    Attributes:
        prior_s_val (nn.Parameter): Prior distribution for s values.
        prior_loc_z (ModifiedLinear): Linear layer for z prior distribution.
        s_dim (int): Dimension of s space.
        z_dim (int): Dimension of z space.
        y_dim (int): Dimension of y space.
        variable_types (VarTypes): List of variable types.
        decoder_shared (nn.Module): Shared decoder module.
        internal_layer_norm (nn.Module): Layer normalization module.
        heads (nn.ModuleList): List of head modules for each variable type.
        mtl_methods (Tuple[str, ...]): Methods for multi-task learning.
        _mtl_module_y (nn.Module): Multi-task learning module for y.
        _mtl_module_s (nn.Module): Multi-task learning module for s.
        moo_block (moo.MultiMOOForLoop): Multi-task learning block.
        _decoding (bool): Decoding flag.
    """

    def __init__(
        self,
        variable_types: VarTypes,
        s_dim: int,
        z_dim: int,
        y_dim: int,
        mtl_method: Tuple[str, ...] = ("identity",),
        decoder_shared: nn.Module = nn.Identity(),
    ):
        """
        Initialize the HIVAE Decoder.

        Args:
            variable_types (List[VariableType]): List of VariableType objects.
            s_dim (int): Dimension of s space.
            z_dim (int): Dimension of z space.
            y_dim (int): Dimension of y space.
            mtl_method (Tuple[str]): List of methods to use for multi-task learning.
        """
        super().__init__()

        # prior distributions
        self.prior_s_val = nn.Parameter(
            torch.ones(s_dim) / s_dim, requires_grad=False
        )
        self.prior_loc_z = ModifiedLinear(s_dim, z_dim, bias=True)

        self.s_dim = s_dim
        self.z_dim = z_dim
        self.y_dim = y_dim
        self.variable_types = variable_types

        self.decoder_shared = decoder_shared
        self.internal_layer_norm = nn.Identity()  # nn.LayerNorm(self.z_dim)
        self.heads: List[
            RealHead | PosHead | CountHead | CatHead
        ] | nn.ModuleList = nn.ModuleList(
            [
                HEAD_DICT[variable_types[i].data_type](
                    variable_types[i], s_dim, z_dim, y_dim
                )
                for i in range(len(variable_types))
            ]
        )

        self.mtl_methods = mtl_method
        self._mtl_module_y: nn.Module = moo.setup_moo(
            [MtlMethodParams(x) for x in mtl_method],
            num_tasks=len(self.heads),
        )
        self._mtl_module_s: nn.Module = moo.setup_moo(
            [MtlMethodParams(x) for x in mtl_method],
            num_tasks=len(self.heads),
        )
        self.moo_block = moo.MultiMOOForLoop(
            len(self.heads),
            moo_methods=(self._mtl_module_y, self._mtl_module_s),
        )

        self._decoding = True

    @property
    def decoding(self) -> bool:
        """bool: Flag indicating whether the model is in decoding mode."""
        return self._decoding

    @decoding.setter
    def decoding(self, value: bool) -> None:
        """Sets the decoding flag.

        Args:
            value (bool): Decoding flag.
        """
        self._decoding = value

    @cached_property
    def colnames(self) -> List[str]:
        """Gets the column names of the data.

        Returns:
            List[str]: List of column names.
        """
        return [var.name for var in self.variable_types]

    @property
    def prior_s(self) -> torch.distributions.OneHotCategorical:
        """Gets the prior distribution for s.

        Returns:
            torch.distributions.OneHotCategorical: Prior distribution for s.
        """
        return torch.distributions.OneHotCategorical(
            probs=self.prior_s_val, validate_args=False
        )

    def prior_z(self, loc: torch.Tensor) -> torch.distributions.Normal:
        """Gets the prior distribution for z.

        Args:
            loc (torch.Tensor): Location parameter for z.

        Returns:
            torch.distributions.Normal: Prior distribution for z.
        """
        return torch.distributions.Normal(loc, torch.ones_like(loc))

    def kl_s(self, encoder_output: EncoderOutput) -> torch.Tensor:
        """Computes the KL divergence for s.

        Args:
            encoder_output (EncoderOutput): Encoder output.

        Returns:
            torch.Tensor: KL divergence for s.
        """
        return torch.distributions.kl.kl_divergence(
            torch.distributions.OneHotCategorical(
                logits=encoder_output.logits_s, validate_args=False
            ),
            self.prior_s,
        )

    def kl_z(
        self,
        mean_qz: torch.Tensor,
        std_qz: torch.Tensor,
        mean_pz: torch.Tensor,
        std_pz: torch.Tensor,
    ) -> torch.Tensor:
        """Computes the KL divergence for z.

        Args:
            mean_qz (torch.Tensor): Mean of the posterior distribution.
            std_qz (torch.Tensor): Standard deviation of the posterior distribution.
            mean_pz (torch.Tensor): Mean of the prior distribution.
            std_pz (torch.Tensor): Standard deviation of the prior distribution.

        Returns:
            torch.Tensor: KL divergence for z.
        """
        return torch.distributions.kl.kl_divergence(
            torch.distributions.Normal(mean_qz, std_qz),
            torch.distributions.Normal(mean_pz, std_pz),
        ).sum(dim=-1)

    def _cat_samples(self, samples: List[torch.Tensor]) -> torch.Tensor:
        """Concatenates samples.

        Args:
            samples (List[torch.Tensor]): List of samples.

        Returns:
            torch.Tensor: Concatenated samples.

        Raises:
            ValueError: If no samples were drawn or if samples have an incorrect shape.
        """
        if len(samples) == 0 or all(x is None for x in samples):
            raise ValueError("No samples were drawn")
        else:
            sample_stack = torch.stack(samples, dim=1)
            if sample_stack.ndim == 2:
                return sample_stack
            elif sample_stack.ndim == 3 and sample_stack.shape[2] == 1:
                return sample_stack.squeeze(2)
            else:
                raise ValueError(
                    "Samples should be of shape (batch, features) or (batch, features, 1)"
                )

    def forward(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
        encoder_output: EncoderOutput,
        normalization_parameters: NormalizationParameters,
    ) -> HivaeOutput:
        """Forward pass of the decoder.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Mask for the data.
            encoder_output (EncoderOutput): Output from the encoder.
            normalization_parameters (NormalizationParameters): Parameters for normalization.

        Returns:
            HivaeOutput: Output from the decoder.
        """
        samples_z = encoder_output.samples_z
        decoder_representation = encoder_output.decoder_representation
        samples_s = encoder_output.samples_s

        # Obtaining the parameters for the decoder
        # decoder representation of shape batch x dim_z
        interim_decoder_representation = self.decoder_shared(
            decoder_representation
        )  # identity
        interim_decoder_representation = self.internal_layer_norm(
            interim_decoder_representation
        )  # identity
        if self.training:
            moo_out_s, moo_out_z = self.moo_block(
                samples_s, interim_decoder_representation
            )
            x_params = []
            for head, s_i, k_i in zip(self.heads, moo_out_s, moo_out_z):
                x_params.append(head(k_i, s_i))
            x_params = tuple(x_params)
        else:
            x_params = tuple(
                [
                    head(interim_decoder_representation, samples_s)
                    for head in self.heads
                ]
            )

        x_params = Normalization.denormalize_params(
            x_params, self.variable_types, normalization_parameters
        )

        # Compute the likelihood and kl divergences
        log_probs: List[torch.Tensor] = [torch.Tensor([-1])] * len(
            self.variable_types
        )
        samples: List[torch.Tensor] = [torch.Tensor([-1])] * len(
            self.variable_types
        )
        for i, (x_i, m_i, head_i, params_i) in enumerate(
            zip(data.T, mask.T, self.heads, x_params)
        ):
            head_i.dist(params_i)
            log_probs[i] = head_i.log_prob(x_i) * m_i

            if not self.training and self.decoding:
                # draw samples for evaluation and decoding
                samples[i] = head_i.sample()
            elif not self.training and not self.decoding:
                # draw samples for evaluation
                samples[i] = head_i.sample()

        # Stack the log likelihoods
        log_prob = torch.stack(log_probs, dim=1)  # batch, features
        log_prob = log_prob.sum(dim=1)  # / (mask.sum(dim=1) + 1e-6)  # batch
        cat_samples = self._cat_samples(samples)  # batch, features

        # Compute the KL divergences
        # KL divergence for s
        # samples_s of shape (batch, dim_s)
        kl_s = self.kl_s(encoder_output)  # shape (batch,)

        # KL divergence for z
        pz_loc = self.prior_loc_z(encoder_output.samples_s)  # batch, dim_z
        mean_pz, std_pz = pz_loc, torch.ones_like(pz_loc)
        mean_qz, std_qz = encoder_output.mean_z, encoder_output.scale_z
        kl_z = self.kl_z(mean_qz, std_qz, mean_pz, std_pz)  # shape (batch,)

        loss = -torch.sum(log_prob - kl_s - kl_z) / (
            (torch.sum(mask) / mask.shape[-1]) + 1e-6
        )
        print(f"Loss: {loss}, num_samples: {mask.sum()}")
        return HivaeOutput(
            enc_s=samples_s,
            enc_z=samples_z,
            samples=cat_samples,
            loss=loss,
        )

    @torch.no_grad()
    def decode(
        self,
        encoding_z: torch.Tensor,
        encoding_s: torch.Tensor,
        normalization_params: NormalizationParameters,
    ) -> torch.Tensor:
        """Decoding logic for the decoder.

        Args:
            encoding_z (torch.Tensor): Encoding for z.
            encoding_s (torch.Tensor): Encoding for s.
            normalization_params (NormalizationParameters): Parameters for normalization.

        Returns:
            torch.Tensor: Decoded samples.

        Raises:
            ValueError: If no samples were drawn.
        """
        # Implement the decoding logic here
        assert not self.training, "Model should be in eval mode"
        # shared_y of shape (batch, dim_y)
        decoder_interim_representation = self.decoder_shared(
            encoding_z
        )  # identity
        decoder_interim_representation = self.internal_layer_norm(
            decoder_interim_representation
        )  # identity

        if encoding_s.shape[1] != self.s_dim:
            encoding_s = F.one_hot(
                encoding_s.squeeze(1).long(), num_classes=self.s_dim
            ).float()  # batch, dim_s

        x_params = tuple(
            [
                head(decoder_interim_representation, encoding_s)
                for head in self.heads
            ]
        )
        x_params = Normalization.denormalize_params(
            x_params, self.variable_types, normalization_params
        )
        _ = [
            head.dist(
                params=params,
            )
            for i, (head, params) in enumerate(zip(self.heads, x_params))
        ]
        if self.decoding:
            samples = [head.sample() for head in self.heads]
        else:
            samples = [head.sample() for head in self.heads]
        cat_samples = self._cat_samples(samples)
        if cat_samples is not None:
            return cat_samples
        else:
            raise ValueError("No samples were drawn")
colnames: List[str] cached property

Gets the column names of the data.

Returns:

Type Description
List[str]

List[str]: List of column names.

decoding: bool property writable

bool: Flag indicating whether the model is in decoding mode.

prior_s: torch.distributions.OneHotCategorical property

Gets the prior distribution for s.

Returns:

Type Description
OneHotCategorical

torch.distributions.OneHotCategorical: Prior distribution for s.

__init__(variable_types, s_dim, z_dim, y_dim, mtl_method=('identity',), decoder_shared=nn.Identity())

Initialize the HIVAE Decoder.

Parameters:

Name Type Description Default
variable_types List[VariableType]

List of VariableType objects.

required
s_dim int

Dimension of s space.

required
z_dim int

Dimension of z space.

required
y_dim int

Dimension of y space.

required
mtl_method Tuple[str]

List of methods to use for multi-task learning.

('identity',)
Source code in vambn/modelling/models/hivae/decoder.py
def __init__(
    self,
    variable_types: VarTypes,
    s_dim: int,
    z_dim: int,
    y_dim: int,
    mtl_method: Tuple[str, ...] = ("identity",),
    decoder_shared: nn.Module = nn.Identity(),
):
    """
    Initialize the HIVAE Decoder.

    Args:
        variable_types (List[VariableType]): List of VariableType objects.
        s_dim (int): Dimension of s space.
        z_dim (int): Dimension of z space.
        y_dim (int): Dimension of y space.
        mtl_method (Tuple[str]): List of methods to use for multi-task learning.
    """
    super().__init__()

    # prior distributions
    self.prior_s_val = nn.Parameter(
        torch.ones(s_dim) / s_dim, requires_grad=False
    )
    self.prior_loc_z = ModifiedLinear(s_dim, z_dim, bias=True)

    self.s_dim = s_dim
    self.z_dim = z_dim
    self.y_dim = y_dim
    self.variable_types = variable_types

    self.decoder_shared = decoder_shared
    self.internal_layer_norm = nn.Identity()  # nn.LayerNorm(self.z_dim)
    self.heads: List[
        RealHead | PosHead | CountHead | CatHead
    ] | nn.ModuleList = nn.ModuleList(
        [
            HEAD_DICT[variable_types[i].data_type](
                variable_types[i], s_dim, z_dim, y_dim
            )
            for i in range(len(variable_types))
        ]
    )

    self.mtl_methods = mtl_method
    self._mtl_module_y: nn.Module = moo.setup_moo(
        [MtlMethodParams(x) for x in mtl_method],
        num_tasks=len(self.heads),
    )
    self._mtl_module_s: nn.Module = moo.setup_moo(
        [MtlMethodParams(x) for x in mtl_method],
        num_tasks=len(self.heads),
    )
    self.moo_block = moo.MultiMOOForLoop(
        len(self.heads),
        moo_methods=(self._mtl_module_y, self._mtl_module_s),
    )

    self._decoding = True
decode(encoding_z, encoding_s, normalization_params)

Decoding logic for the decoder.

Parameters:

Name Type Description Default
encoding_z Tensor

Encoding for z.

required
encoding_s Tensor

Encoding for s.

required
normalization_params NormalizationParameters

Parameters for normalization.

required

Returns:

Type Description
Tensor

torch.Tensor: Decoded samples.

Raises:

Type Description
ValueError

If no samples were drawn.

Source code in vambn/modelling/models/hivae/decoder.py
@torch.no_grad()
def decode(
    self,
    encoding_z: torch.Tensor,
    encoding_s: torch.Tensor,
    normalization_params: NormalizationParameters,
) -> torch.Tensor:
    """Decoding logic for the decoder.

    Args:
        encoding_z (torch.Tensor): Encoding for z.
        encoding_s (torch.Tensor): Encoding for s.
        normalization_params (NormalizationParameters): Parameters for normalization.

    Returns:
        torch.Tensor: Decoded samples.

    Raises:
        ValueError: If no samples were drawn.
    """
    # Implement the decoding logic here
    assert not self.training, "Model should be in eval mode"
    # shared_y of shape (batch, dim_y)
    decoder_interim_representation = self.decoder_shared(
        encoding_z
    )  # identity
    decoder_interim_representation = self.internal_layer_norm(
        decoder_interim_representation
    )  # identity

    if encoding_s.shape[1] != self.s_dim:
        encoding_s = F.one_hot(
            encoding_s.squeeze(1).long(), num_classes=self.s_dim
        ).float()  # batch, dim_s

    x_params = tuple(
        [
            head(decoder_interim_representation, encoding_s)
            for head in self.heads
        ]
    )
    x_params = Normalization.denormalize_params(
        x_params, self.variable_types, normalization_params
    )
    _ = [
        head.dist(
            params=params,
        )
        for i, (head, params) in enumerate(zip(self.heads, x_params))
    ]
    if self.decoding:
        samples = [head.sample() for head in self.heads]
    else:
        samples = [head.sample() for head in self.heads]
    cat_samples = self._cat_samples(samples)
    if cat_samples is not None:
        return cat_samples
    else:
        raise ValueError("No samples were drawn")
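
The one-hot conversion applied above when encoding_s arrives as class indices rather than one-hot vectors can be illustrated on its own (a standalone sketch of that F.one_hot call):

import torch
import torch.nn.functional as F

s_dim = 4
encoding_s = torch.tensor([[2], [0], [3]])      # (batch, 1) class indices
one_hot_s = F.one_hot(encoding_s.squeeze(1).long(), num_classes=s_dim).float()
print(one_hot_s.shape)                          # torch.Size([3, 4])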
forward(data, mask, encoder_output, normalization_parameters)

Forward pass of the decoder.

Parameters:

Name Type Description Default
data Tensor

Input data.

required
mask Tensor

Mask for the data.

required
encoder_output EncoderOutput

Output from the encoder.

required
normalization_parameters NormalizationParameters

Parameters for normalization.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

Output from the decoder.

Source code in vambn/modelling/models/hivae/decoder.py
def forward(
    self,
    data: torch.Tensor,
    mask: torch.Tensor,
    encoder_output: EncoderOutput,
    normalization_parameters: NormalizationParameters,
) -> HivaeOutput:
    """Forward pass of the decoder.

    Args:
        data (torch.Tensor): Input data.
        mask (torch.Tensor): Mask for the data.
        encoder_output (EncoderOutput): Output from the encoder.
        normalization_parameters (NormalizationParameters): Parameters for normalization.

    Returns:
        HivaeOutput: Output from the decoder.
    """
    samples_z = encoder_output.samples_z
    decoder_representation = encoder_output.decoder_representation
    samples_s = encoder_output.samples_s

    # Obtaining the parameters for the decoder
    # decoder representation of shape batch x dim_z
    interim_decoder_representation = self.decoder_shared(
        decoder_representation
    )  # identity
    interim_decoder_representation = self.internal_layer_norm(
        interim_decoder_representation
    )  # identity
    if self.training:
        moo_out_s, moo_out_z = self.moo_block(
            samples_s, interim_decoder_representation
        )
        x_params = []
        for head, s_i, k_i in zip(self.heads, moo_out_s, moo_out_z):
            x_params.append(head(k_i, s_i))
        x_params = tuple(x_params)
    else:
        x_params = tuple(
            [
                head(interim_decoder_representation, samples_s)
                for head in self.heads
            ]
        )

    x_params = Normalization.denormalize_params(
        x_params, self.variable_types, normalization_parameters
    )

    # Compute the likelihood and kl divergences
    log_probs: List[torch.Tensor] = [torch.Tensor([-1])] * len(
        self.variable_types
    )
    samples: List[torch.Tensor] = [torch.Tensor([-1])] * len(
        self.variable_types
    )
    for i, (x_i, m_i, head_i, params_i) in enumerate(
        zip(data.T, mask.T, self.heads, x_params)
    ):
        head_i.dist(params_i)
        log_probs[i] = head_i.log_prob(x_i) * m_i

        if not self.training and self.decoding:
            # draw samples for evaluation and decoding
            samples[i] = head_i.sample()
        elif not self.training and not self.decoding:
            # draw samples for evaluation
            samples[i] = head_i.sample()

    # Stack the log likelihoods
    log_prob = torch.stack(log_probs, dim=1)  # batch, features
    log_prob = log_prob.sum(dim=1)  # / (mask.sum(dim=1) + 1e-6)  # batch
    cat_samples = self._cat_samples(samples)  # batch, features

    # Compute the KL divergences
    # KL divergence for s
    # samples_s of shape (batch, dim_s)
    kl_s = self.kl_s(encoder_output)  # shape (batch,)

    # KL divergence for z
    pz_loc = self.prior_loc_z(encoder_output.samples_s)  # batch, dim_z
    mean_pz, std_pz = pz_loc, torch.ones_like(pz_loc)
    mean_qz, std_qz = encoder_output.mean_z, encoder_output.scale_z
    kl_z = self.kl_z(mean_qz, std_qz, mean_pz, std_pz)  # shape (batch,)

    loss = -torch.sum(log_prob - kl_s - kl_z) / (
        (torch.sum(mask) / mask.shape[-1]) + 1e-6
    )
    print(f"Loss: {loss}, num_samples: {mask.sum()}")
    return HivaeOutput(
        enc_s=samples_s,
        enc_z=samples_z,
        samples=cat_samples,
        loss=loss,
    )
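
The returned loss is a masked negative ELBO: the per-sample reconstruction log-likelihood minus both KL terms, summed over the batch and normalized by the effective number of observed samples (total observed entries divided by the number of features). A standalone sketch of that final reduction, with random stand-ins for the quantities computed above:

import torch

batch, n_features = 4, 6
log_prob = torch.randn(batch)           # per-sample reconstruction log-likelihood
kl_s = torch.rand(batch)                # per-sample KL divergence for s
kl_z = torch.rand(batch)                # per-sample KL divergence for z
mask = torch.randint(0, 2, (batch, n_features)).float()

elbo = log_prob - kl_s - kl_z
loss = -torch.sum(elbo) / ((torch.sum(mask) / mask.shape[-1]) + 1e-6)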
kl_s(encoder_output)

Computes the KL divergence for s.

Parameters:

Name Type Description Default
encoder_output EncoderOutput

Encoder output.

required

Returns:

Type Description
Tensor

torch.Tensor: KL divergence for s.

Source code in vambn/modelling/models/hivae/decoder.py
def kl_s(self, encoder_output: EncoderOutput) -> torch.Tensor:
    """Computes the KL divergence for s.

    Args:
        encoder_output (EncoderOutput): Encoder output.

    Returns:
        torch.Tensor: KL divergence for s.
    """
    return torch.distributions.kl.kl_divergence(
        torch.distributions.OneHotCategorical(
            logits=encoder_output.logits_s, validate_args=False
        ),
        self.prior_s,
    )
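
This amounts to the KL divergence between the encoder's categorical posterior over s (given as logits) and the uniform prior prior_s. A standalone sketch:

import torch

s_dim, batch = 3, 5
logits_s = torch.randn(batch, s_dim)    # corresponds to encoder_output.logits_s
q_s = torch.distributions.OneHotCategorical(logits=logits_s)
p_s = torch.distributions.OneHotCategorical(probs=torch.ones(s_dim) / s_dim)
kl_s = torch.distributions.kl.kl_divergence(q_s, p_s)   # shape (batch,)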
kl_z(mean_qz, std_qz, mean_pz, std_pz)

Computes the KL divergence for z.

Parameters:

Name Type Description Default
mean_qz Tensor

Mean of the posterior distribution.

required
std_qz Tensor

Standard deviation of the posterior distribution.

required
mean_pz Tensor

Mean of the prior distribution.

required
std_pz Tensor

Standard deviation of the prior distribution.

required

Returns:

Type Description
Tensor

torch.Tensor: KL divergence for z.

Source code in vambn/modelling/models/hivae/decoder.py
def kl_z(
    self,
    mean_qz: torch.Tensor,
    std_qz: torch.Tensor,
    mean_pz: torch.Tensor,
    std_pz: torch.Tensor,
) -> torch.Tensor:
    """Computes the KL divergence for z.

    Args:
        mean_qz (torch.Tensor): Mean of the posterior distribution.
        std_qz (torch.Tensor): Standard deviation of the posterior distribution.
        mean_pz (torch.Tensor): Mean of the prior distribution.
        std_pz (torch.Tensor): Standard deviation of the prior distribution.

    Returns:
        torch.Tensor: KL divergence for z.
    """
    return torch.distributions.kl.kl_divergence(
        torch.distributions.Normal(mean_qz, std_qz),
        torch.distributions.Normal(mean_pz, std_pz),
    ).sum(dim=-1)
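
For diagonal Gaussians this KL divergence has a closed form, which the torch call evaluates per dimension before the final sum. A quick numerical check:

import torch

mean_qz, std_qz = torch.randn(5, 2), torch.rand(5, 2) + 0.1
mean_pz, std_pz = torch.zeros(5, 2), torch.ones(5, 2)

kl = torch.distributions.kl.kl_divergence(
    torch.distributions.Normal(mean_qz, std_qz),
    torch.distributions.Normal(mean_pz, std_pz),
).sum(dim=-1)

# per-dimension closed form: log(std_p / std_q) + (std_q**2 + (mean_q - mean_p)**2) / (2 * std_p**2) - 1/2
manual = (
    torch.log(std_pz / std_qz)
    + (std_qz ** 2 + (mean_qz - mean_pz) ** 2) / (2 * std_pz ** 2)
    - 0.5
).sum(dim=-1)
assert torch.allclose(kl, manual, atol=1e-5)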
prior_z(loc)

Gets the prior distribution for z.

Parameters:

Name Type Description Default
loc Tensor

Location parameter for z.

required

Returns:

Type Description
Normal

torch.distributions.Normal: Prior distribution for z.

Source code in vambn/modelling/models/hivae/decoder.py
def prior_z(self, loc: torch.Tensor) -> torch.distributions.Normal:
    """Gets the prior distribution for z.

    Args:
        loc (torch.Tensor): Location parameter for z.

    Returns:
        torch.distributions.Normal: Prior distribution for z.
    """
    return torch.distributions.Normal(loc, torch.ones_like(loc))
LstmDecoder

Bases: Decoder

LSTM-based HIVAE Decoder class.

Parameters:

Name Type Description Default
variable_types VarTypes

List of VariableType objects. See VarTypes in data/dataclasses.py.

required
s_dim int

Dimension of s space.

required
z_dim int

Dimension of z space.

required
y_dim int

Dimension of y space.

required
num_timepoints int

Number of timepoints.

required
n_layers int

Number of LSTM layers. Defaults to 1.

1
mtl_method Tuple[str, ...]

List of methods to use for multi-task learning. Assessed possibilities are combinations of "identity", "gradnorm", "graddrop".
Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).

('identity',)
decoder_shared Module

Shared decoder module. Defaults to nn.Identity().

Identity()

Attributes:

Name Type Description
num_timepoints int

Number of timepoints.

n_layers int

Number of LSTM layers.

lstm_decoder LSTM

LSTM decoder module.

fc Linear

Fully connected layer.
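
The temporal unrolling used by this decoder can be reproduced in isolation: the latent representation initializes the LSTM hidden state through the fully connected layer and is also repeated across timepoints as the LSTM input (a sketch with dummy tensors):

import torch
import torch.nn as nn

batch, z_dim, n_layers, num_timepoints = 8, 4, 1, 3
fc = nn.Linear(z_dim, z_dim)
lstm = nn.LSTM(input_size=z_dim, hidden_size=z_dim, num_layers=n_layers, batch_first=True)

z = torch.randn(batch, z_dim)                              # decoder representation
h0 = fc(z).repeat(n_layers, 1, 1)                          # (n_layers, batch, z_dim)
c0 = torch.zeros_like(h0)
y_repeated = z.unsqueeze(1).repeat(1, num_timepoints, 1)   # (batch, T, z_dim)
out, _ = lstm(y_repeated, (h0, c0))                        # (batch, T, z_dim), one slice per timepoint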

Source code in vambn/modelling/models/hivae/decoder.py
class LstmDecoder(Decoder):
    """LSTM-based HIVAE Decoder class.

    Args:
        variable_types (VarTypes): List of VariableType objects. See VarTypes in
            data/dataclasses.py.
        s_dim (int): Dimension of s space.
        z_dim (int): Dimension of z space.
        y_dim (int): Dimension of y space.
        num_timepoints (int): Number of timepoints.
        n_layers (int, optional): Number of LSTM layers. Defaults to 1.
        mtl_method (Tuple[str, ...], optional): List of methods to use for multi-task learning. 
            Assessed possibilities are combinations of "identity", "gradnorm", "graddrop".  
            Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).
        decoder_shared (nn.Module, optional): Shared decoder module. Defaults to nn.Identity().

    Attributes:
        num_timepoints (int): Number of timepoints.
        n_layers (int): Number of LSTM layers.
        lstm_decoder (nn.LSTM): LSTM decoder module.
        fc (nn.Linear): Fully connected layer.
    """

    def __init__(
        self,
        variable_types: VarTypes,
        s_dim: int,
        z_dim: int,
        y_dim: int,
        num_timepoints: int,
        n_layers: int = 1,
        mtl_method: Tuple[str, ...] = ("identity",),
        decoder_shared: nn.Module = nn.Identity(),
    ) -> None:
        """HIVAE Decoder

        Args:
            type_array (np.ndarray): Array containing the data type information (type, class, ndim)
            s_dim (int): Dimension of s space
            z_dim (int): Dimension of z space
            y_dim (int): Dimension of y space
            num_timepoints (int): Number of num_timepoints.
            mtl_method (Tuple[str], optional): List of methods to use for multi-task learning. Defaults to
                ["gradnorm", "pcgrad"].
        """
        super().__init__(
            variable_types=variable_types,
            s_dim=s_dim,
            z_dim=z_dim,
            y_dim=y_dim,
            mtl_method=mtl_method,
            decoder_shared=decoder_shared,
        )
        self.num_timepoints = num_timepoints
        self.n_layers = n_layers
        self.lstm_decoder = nn.LSTM(
            input_size=z_dim,
            hidden_size=z_dim,
            num_layers=self.n_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(z_dim, z_dim)

    def forward(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
        encoder_output: EncoderOutput,
        normalization_parameters: NormalizationParameters,
    ) -> HivaeOutput:
        """Forward pass of the LSTM decoder.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Mask for the data.
            encoder_output (EncoderOutput): Output from the encoder.
            normalization_parameters (NormalizationParameters): Parameters for normalization.

        Returns:
            HivaeOutput: Output from the decoder.
        """
        samples_z = encoder_output.samples_z
        decoder_representation = encoder_output.decoder_representation
        samples_s = encoder_output.samples_s

        # Obtaining the parameters for the decoder
        # decoder representation of shape batch x dim_z
        decoder_interim_representation = self.decoder_shared(
            decoder_representation
        )  # identity
        decoder_interim_representation = self.internal_layer_norm(
            decoder_interim_representation
        )  # identity

        h0 = self.fc(decoder_interim_representation).repeat(self.n_layers, 1, 1)
        c0 = torch.zeros_like(h0)
        y_repeated = decoder_interim_representation.unsqueeze(1).repeat(
            1, self.num_timepoints, 1
        )
        decoder_interim_representation, _ = self.lstm_decoder(
            y_repeated, (h0, c0)
        )

        log_probs = [None] * self.num_timepoints
        samples = [None] * self.num_timepoints
        for t in range(self.num_timepoints):
            sub_data = data[:, t, :]
            sub_mask = mask[:, t, :]
            if self.training:
                x_params = tuple(
                    [
                        head(y_i, s_i)
                        for head, s_i, y_i in zip(  # type: ignore
                            self.heads,
                            *self.moo_block.forward(
                                samples_s,
                                decoder_interim_representation[:, t, :],
                            ),
                        )
                    ]
                )
            else:
                x_params = tuple(
                    [
                        head(decoder_interim_representation[:, t, :], samples_s)
                        for head in self.heads
                    ]
                )

            x_params = Normalization.denormalize_params(
                x_params, self.variable_types, normalization_parameters[t]
            )

            log_probs[t] = torch.stack(
                [
                    head_i.log_prob(
                        params=params_i,
                        data=d_i,
                    )
                    * m_i
                    for i, (head_i, params_i, d_i, m_i) in enumerate(
                        zip(self.heads, x_params, sub_data.T, sub_mask.T)
                    )
                ],
                dim=1,
            ).sum(-1)  # batch
            assert isinstance(
                log_probs[t], torch.Tensor
            ), f"Log probs: {log_probs[t]}"
            assert log_probs[t].shape == (
                encoder_output.samples_s.shape[0],
            ), f"Log probs shape: {log_probs[t].shape}"

            if not self.training and self.decoding:
                samples[t] = self._cat_samples(
                    [
                        head.sample()
                        for head, params in zip(self.heads, x_params)
                    ]
                )
            elif not self.training and not self.decoding:
                samples[t] = self._cat_samples(
                    [
                        head.sample()
                        for head, params in zip(self.heads, x_params)
                    ]
                )

        log_prob = torch.stack(log_probs, dim=1)  # batch, timepoints
        log_prob = log_prob.sum(
            dim=1
        )  # / (mask.sum(dim=(1, 2)) + 1e-9)  # batch
        if torch.isnan(log_prob).any() or torch.isinf(log_prob).any():
            raise ValueError("Log likelihood is nan or inf")
        samples = (
            torch.stack(samples, dim=1) if samples[0] is not None else None
        )

        # Compute the KL divergences
        # # KL divergence for s
        kl_s = self.kl_s(encoder_output)
        if torch.isnan(kl_s).any() or torch.isinf(kl_s).any():
            raise ValueError("KL divergence for s is nan or inf")

        # KL divergence for z
        pz_loc = self.prior_loc_z(encoder_output.samples_s)
        mean_pz, std_pz = pz_loc, torch.ones_like(pz_loc)
        mean_qz, std_qz = encoder_output.mean_z, encoder_output.scale_z
        kl_z = self.kl_z(mean_qz, std_qz, mean_pz, std_pz)

        assert kl_z.shape == (encoder_output.samples_s.shape[0],)
        if torch.isnan(kl_z).any() or torch.isinf(kl_z).any():
            raise ValueError("KL divergence for z is nan or inf")

        # Compute the loss
        loss = -torch.sum(log_prob - kl_s - kl_z) / (
            torch.sum(mask) / mask.shape[-1]
        )

        return HivaeOutput(
            enc_s=encoder_output.samples_s,
            enc_z=samples_z,
            samples=samples,
            loss=loss,
        )

    @torch.no_grad()
    def decode(
        self,
        encoding_z: torch.Tensor,
        encoding_s: torch.Tensor,
        normalization_params: NormalizationParameters,
    ):
        """Decoding logic for 3D data.

        Args:
            encoding_z (torch.Tensor): Encoding for z.
            encoding_s (torch.Tensor): Encoding for s.
            normalization_params (NormalizationParameters): Parameters for normalization.

        Returns:
            torch.Tensor: Decoded samples.
        """
        assert not self.training, "Model should be in eval mode"

        # shared_y of shape (batch, dim_y)
        decoder_interim_representation = self.decoder_shared(
            encoding_z
        )  # identity
        decoder_interim_representation = self.internal_layer_norm(
            decoder_interim_representation
        )  # identity

        decoder_interim_representation = (
            decoder_interim_representation.unsqueeze(1).repeat(
                1, self.num_timepoints, 1
            )
        )
        decoder_interim_representation, _ = self.lstm_decoder(
            decoder_interim_representation
        )

        if encoding_s.shape[1] != self.s_dim:
            encoding_s = F.one_hot(
                encoding_s.squeeze(1).long(), num_classes=self.s_dim
            ).float()

        samples = [None] * self.num_timepoints
        for t in range(self.num_timepoints):
            x_params = tuple(
                [
                    head(decoder_interim_representation[:, t, :], encoding_s)
                    for head in self.heads
                ]
            )
            x_params = Normalization.denormalize_params(
                x_params, self.variable_types, normalization_params[t]
            )
            [head.dist(params) for head, params in zip(self.heads, x_params)]
            if self.decoding:
                samples[t] = self._cat_samples(
                    [head.sample() for head in self.heads]
                )
            else:
                samples[t] = self._cat_samples(
                    [
                        head.sample()
                        for head, params in zip(self.heads, x_params)
                    ]
                )
        time_dim_samples = torch.stack(samples, dim=1)
        return time_dim_samples
__init__(variable_types, s_dim, z_dim, y_dim, num_timepoints, n_layers=1, mtl_method=('identity',), decoder_shared=nn.Identity())

Initialize the LSTM-based HIVAE Decoder.

Parameters:

Name Type Description Default
variable_types VarTypes

List of VariableType objects. See VarTypes in data/dataclasses.py.

required
s_dim int

Dimension of s space.

required
z_dim int

Dimension of z space.

required
y_dim int

Dimension of y space.

required
num_timepoints int

Number of timepoints.

required
n_layers int

Number of LSTM layers. Defaults to 1.

1
mtl_method Tuple[str, ...]

List of methods to use for multi-task learning. Defaults to ("identity",).

('identity',)
decoder_shared Module

Shared decoder module. Defaults to nn.Identity().

Identity()
Source code in vambn/modelling/models/hivae/decoder.py
def __init__(
    self,
    variable_types: VarTypes,
    s_dim: int,
    z_dim: int,
    y_dim: int,
    num_timepoints: int,
    n_layers: int = 1,
    mtl_method: Tuple[str, ...] = ("identity",),
    decoder_shared: nn.Module = nn.Identity(),
) -> None:
    """HIVAE Decoder

    Args:
        type_array (np.ndarray): Array containing the data type information (type, class, ndim)
        s_dim (int): Dimension of s space
        z_dim (int): Dimension of z space
        y_dim (int): Dimension of y space
        num_timepoints (int): Number of num_timepoints.
        mtl_method (Tuple[str], optional): List of methods to use for multi-task learning. Defaults to
            ["gradnorm", "pcgrad"].
    """
    super().__init__(
        variable_types=variable_types,
        s_dim=s_dim,
        z_dim=z_dim,
        y_dim=y_dim,
        mtl_method=mtl_method,
        decoder_shared=decoder_shared,
    )
    self.num_timepoints = num_timepoints
    self.n_layers = n_layers
    self.lstm_decoder = nn.LSTM(
        input_size=z_dim,
        hidden_size=z_dim,
        num_layers=self.n_layers,
        batch_first=True,
    )
    self.fc = nn.Linear(z_dim, z_dim)
decode(encoding_z, encoding_s, normalization_params)

Decoding logic for 3D data.

Parameters:

Name Type Description Default
encoding_z Tensor

Encoding for z.

required
encoding_s Tensor

Encoding for s.

required
normalization_params NormalizationParameters

Parameters for normalization.

required

Returns:

Type Description

torch.Tensor: Decoded samples.

Source code in vambn/modelling/models/hivae/decoder.py
@torch.no_grad()
def decode(
    self,
    encoding_z: torch.Tensor,
    encoding_s: torch.Tensor,
    normalization_params: NormalizationParameters,
):
    """Decoding logic for 3D data.

    Args:
        encoding_z (torch.Tensor): Encoding for z.
        encoding_s (torch.Tensor): Encoding for s.
        normalization_params (NormalizationParameters): Parameters for normalization.

    Returns:
        torch.Tensor: Decoded samples.
    """
    assert not self.training, "Model should be in eval mode"

    # shared_y of shape (batch, dim_y)
    decoder_interim_representation = self.decoder_shared(
        encoding_z
    )  # identity
    decoder_interim_representation = self.internal_layer_norm(
        decoder_interim_representation
    )  # identity

    decoder_interim_representation = (
        decoder_interim_representation.unsqueeze(1).repeat(
            1, self.num_timepoints, 1
        )
    )
    decoder_interim_representation, _ = self.lstm_decoder(
        decoder_interim_representation
    )

    if encoding_s.shape[1] != self.s_dim:
        encoding_s = F.one_hot(
            encoding_s.squeeze(1).long(), num_classes=self.s_dim
        ).float()

    samples = [None] * self.num_timepoints
    for t in range(self.num_timepoints):
        x_params = tuple(
            [
                head(decoder_interim_representation[:, t, :], encoding_s)
                for head in self.heads
            ]
        )
        x_params = Normalization.denormalize_params(
            x_params, self.variable_types, normalization_params[t]
        )
        [head.dist(params) for head, params in zip(self.heads, x_params)]
        if self.decoding:
            samples[t] = self._cat_samples(
                [head.sample() for head in self.heads]
            )
        else:
            samples[t] = self._cat_samples(
                [
                    head.sample()
                    for head, params in zip(self.heads, x_params)
                ]
            )
    time_dim_samples = torch.stack(samples, dim=1)
    return time_dim_samples
forward(data, mask, encoder_output, normalization_parameters)

Forward pass of the LSTM decoder.

Parameters:

Name Type Description Default
data Tensor

Input data.

required
mask Tensor

Mask for the data.

required
encoder_output EncoderOutput

Output from the encoder.

required
normalization_parameters NormalizationParameters

Parameters for normalization.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

Output from the decoder.

Source code in vambn/modelling/models/hivae/decoder.py
def forward(
    self,
    data: torch.Tensor,
    mask: torch.Tensor,
    encoder_output: EncoderOutput,
    normalization_parameters: NormalizationParameters,
) -> HivaeOutput:
    """Forward pass of the LSTM decoder.

    Args:
        data (torch.Tensor): Input data.
        mask (torch.Tensor): Mask for the data.
        encoder_output (EncoderOutput): Output from the encoder.
        normalization_parameters (NormalizationParameters): Parameters for normalization.

    Returns:
        HivaeOutput: Output from the decoder.
    """
    samples_z = encoder_output.samples_z
    decoder_representation = encoder_output.decoder_representation
    samples_s = encoder_output.samples_s

    # Obtaining the parameters for the decoder
    # decoder representation of shape batch x dim_z
    decoder_interim_representation = self.decoder_shared(
        decoder_representation
    )  # identity
    decoder_interim_representation = self.internal_layer_norm(
        decoder_interim_representation
    )  # identity

    h0 = self.fc(decoder_interim_representation).repeat(self.n_layers, 1, 1)
    c0 = torch.zeros_like(h0)
    y_repeated = decoder_interim_representation.unsqueeze(1).repeat(
        1, self.num_timepoints, 1
    )
    decoder_interim_representation, _ = self.lstm_decoder(
        y_repeated, (h0, c0)
    )

    log_probs = [None] * self.num_timepoints
    samples = [None] * self.num_timepoints
    for t in range(self.num_timepoints):
        sub_data = data[:, t, :]
        sub_mask = mask[:, t, :]
        if self.training:
            x_params = tuple(
                [
                    head(y_i, s_i)
                    for head, s_i, y_i in zip(  # type: ignore
                        self.heads,
                        *self.moo_block.forward(
                            samples_s,
                            decoder_interim_representation[:, t, :],
                        ),
                    )
                ]
            )
        else:
            x_params = tuple(
                [
                    head(decoder_interim_representation[:, t, :], samples_s)
                    for head in self.heads
                ]
            )

        x_params = Normalization.denormalize_params(
            x_params, self.variable_types, normalization_parameters[t]
        )

        log_probs[t] = torch.stack(
            [
                head_i.log_prob(
                    params=params_i,
                    data=d_i,
                )
                * m_i
                for i, (head_i, params_i, d_i, m_i) in enumerate(
                    zip(self.heads, x_params, sub_data.T, sub_mask.T)
                )
            ],
            dim=1,
        ).sum(-1)  # batch
        assert isinstance(
            log_probs[t], torch.Tensor
        ), f"Log probs: {log_probs[t]}"
        assert log_probs[t].shape == (
            encoder_output.samples_s.shape[0],
        ), f"Log probs shape: {log_probs[t].shape}"

        if not self.training and self.decoding:
            samples[t] = self._cat_samples(
                [
                    head.sample()
                    for head, params in zip(self.heads, x_params)
                ]
            )
        elif not self.training and not self.decoding:
            samples[t] = self._cat_samples(
                [
                    head.sample()
                    for head, params in zip(self.heads, x_params)
                ]
            )

    log_prob = torch.stack(log_probs, dim=1)  # batch, timepoints
    log_prob = log_prob.sum(
        dim=1
    )  # / (mask.sum(dim=(1, 2)) + 1e-9)  # batch
    if torch.isnan(log_prob).any() or torch.isinf(log_prob).any():
        raise ValueError("Log likelihood is nan or inf")
    samples = (
        torch.stack(samples, dim=1) if samples[0] is not None else None
    )

    # Compute the KL divergences
    # # KL divergence for s
    kl_s = self.kl_s(encoder_output)
    if torch.isnan(kl_s).any() or torch.isinf(kl_s).any():
        raise ValueError("KL divergence for s is nan or inf")

    # KL divergence for z
    pz_loc = self.prior_loc_z(encoder_output.samples_s)
    mean_pz, std_pz = pz_loc, torch.ones_like(pz_loc)
    mean_qz, std_qz = encoder_output.mean_z, encoder_output.scale_z
    kl_z = self.kl_z(mean_qz, std_qz, mean_pz, std_pz)

    assert kl_z.shape == (encoder_output.samples_s.shape[0],)
    if torch.isnan(kl_z).any() or torch.isinf(kl_z).any():
        raise ValueError("KL divergence for z is nan or inf")

    # Compute the loss
    loss = -torch.sum(log_prob - kl_s - kl_z) / (
        torch.sum(mask) / mask.shape[-1]
    )

    return HivaeOutput(
        enc_s=encoder_output.samples_s,
        enc_z=samples_z,
        samples=samples,
        loss=loss,
    )

encoder

Encoder

Bases: Module

HIVAE Encoder.

Parameters:

Name Type Description Default
input_dim int

Dimension of input data (e.g., columns in dataframe).

required
dim_s int

Dimension of s space.

required
dim_z int

Dimension of z space.

required

Attributes:

Name Type Description
input_dim int

Dimension of input data.

dim_s int

Dimension of s space.

dim_z int

Dimension of z space.

encoder_s ModifiedLinear

Linear layer for s encoding.

encoder_z Module

Identity layer for z encoding.

param_z ModifiedLinear

Linear layer for z parameterization.

_tau float

Temperature parameter for Gumbel softmax.

_decoding bool

Flag indicating whether the model is in decoding mode.
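
A minimal usage sketch (the import path is inferred from the source location shown below; the input is random):

import torch
from vambn.modelling.models.hivae.encoder import Encoder  # path assumed from the source location below

encoder = Encoder(input_dim=10, dim_s=3, dim_z=2)
encoder.tau = 0.5                       # optionally anneal the Gumbel-softmax temperature
output = encoder(torch.randn(8, 10))    # EncoderOutput with samples_s, samples_z, logits_s, ...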

Source code in vambn/modelling/models/hivae/encoder.py
class Encoder(nn.Module):
    """HIVAE Encoder.

    Args:
        input_dim (int): Dimension of input data (e.g., columns in dataframe).
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.

    Attributes:
        input_dim (int): Dimension of input data.
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        encoder_s (ModifiedLinear): Linear layer for s encoding.
        encoder_z (nn.Module): Identity layer for z encoding.
        param_z (ModifiedLinear): Linear layer for z parameterization.
        _tau (float): Temperature parameter for Gumbel softmax.
        _decoding (bool): Flag indicating whether the model is in decoding mode.
    """


    def __init__(self, input_dim: int, dim_s: int, dim_z: int):
        """HIVAE Encoder

        Args:
            input_dim (int): Dimension of input data (e.g. columns in dataframe)
            dim_s (int): Dimension of s space
            dim_z (int): Dimension of z space
        """
        super().__init__()
        if input_dim <= 0:
            raise ValueError(
                f"Input dimension must be positive, got {input_dim}"
            )

        if dim_s <= 0:
            raise ValueError(f"S dimension must be positive, got {dim_s}")

        if dim_z <= 0:
            raise ValueError(f"Z dimension must be positive, got {dim_z}")

        self.input_dim = input_dim
        self.dim_s = dim_s
        self.dim_z = dim_z

        self.encoder_s = ModifiedLinear(self.input_dim, dim_s, bias=True)
        self.encoder_z = nn.Identity()
        self.param_z = ModifiedLinear(
            self.input_dim + dim_s, dim_z * 2, bias=True
        )

        self._tau = 1.0
        self._decoding = True

    @property
    def decoding(self) -> bool:
        """bool: Flag indicating whether the model is in decoding mode."""
        return self._decoding

    @decoding.setter
    def decoding(self, value: bool) -> None:
        """Sets the decoding flag.

        Args:
            value (bool): Decoding flag.
        """
        self._decoding = value

    @property
    def tau(self) -> float:
        """float: Temperature parameter for Gumbel softmax."""
        return self._tau

    @tau.setter
    def tau(self, value: float) -> None:
        """Sets the temperature parameter for Gumbel softmax.

        Args:
            value (float): Temperature value.

        Raises:
            ValueError: If the temperature value is not positive.
        """
        if value <= 0:
            raise ValueError(f"Tau must be positive, got {value}")
        self._tau = value

    @staticmethod
    def q_z(loc, scale: Tensor) -> dists.Normal:
        """Creates a normal distribution for z.

        Args:
            loc (Tensor): Mean of the distribution.
            scale (Tensor): Standard deviation of the distribution.

        Returns:
            dists.Normal: Normal distribution.
        """
        return dists.Normal(loc, scale)

    def q_s(self, probs: Tensor) -> GumbelDistribution:
        """Creates a Gumbel distribution for s.

        Args:
            probs (Tensor): Probabilities for the Gumbel distribution.

        Returns:
            GumbelDistribution: Gumbel distribution.
        """
        return GumbelDistribution(probs=probs, temperature=self.tau)

    def forward(self, x: Tensor) -> EncoderOutput:
        """Forward pass of the encoder.

        Args:
            x (Tensor): Normalized input data.

        Raises:
            Exception: If samples contain NaN values.

        Returns:
            EncoderOutput: Contains samples, logits, and parameters.
        """
        logits_s = self.encoder_s(x)
        probs_s = F.softmax(logits_s, dim=-1).clamp(1e-6, 1 - 1e-6)

        if self.training:
            samples_s = self.q_s(probs_s).rsample()
        elif self.decoding:
            # NOTE:
            # The idea behind this was that we can use e.g. the mode during decoding
            # Experiments indicate that this is not helpful; the correlation between
            # the real and decoded data is improved, but the JSD is significantly worse.
            samples_s = self.q_s(probs_s).sample()
        else:
            samples_s = self.q_s(probs_s).sample()

        x_and_s = torch.cat((x, samples_s), dim=-1)

        h = self.encoder_z(x_and_s)

        loc_z, scale_z_unnorm = torch.chunk(self.param_z(h), 2, dim=-1)
        scale_z = F.softplus(scale_z_unnorm) + 1e-6

        samples_z = (
            self.q_z(loc_z, scale_z).rsample()
            if self.training
            else self.q_z(loc_z, scale_z).sample()
            if self.decoding
            else self.q_z(loc_z, scale_z).sample()
        )

        return EncoderOutput(
            samples_s=samples_s,
            samples_z=samples_z,
            logits_s=logits_s,
            mean_z=loc_z,
            scale_z=scale_z,
        )
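
A short usage sketch of the class above. The import path follows the "Source code in" note; the shapes are illustrative, and it is assumed that GumbelDistribution.rsample returns a relaxed one-hot tensor with the same shape as probs (as in the standard Gumbel-Softmax trick).

import torch
from vambn.modelling.models.hivae.encoder import Encoder  # path as listed above

encoder = Encoder(input_dim=10, dim_s=3, dim_z=4)
encoder.tau = 0.5      # optionally anneal the Gumbel-softmax temperature
encoder.train()        # training mode -> reparameterized (rsample) draws

x = torch.randn(32, 10)          # batch of 32 normalized rows, 10 features
out = encoder(x)
print(out.samples_s.shape)       # (32, 3): relaxed one-hot over the s categories
print(out.samples_z.shape)       # (32, 4): reparameterized Gaussian sample
print(out.mean_z.shape, out.scale_z.shape)
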
decoding: bool property writable

bool: Flag indicating whether the model is in decoding mode.

tau: float property writable

float: Temperature parameter for Gumbel softmax.

__init__(input_dim, dim_s, dim_z)

HIVAE Encoder

Parameters:

Name Type Description Default
input_dim int

Dimension of input data (e.g. columns in dataframe)

required
dim_s int

Dimension of s space

required
dim_z int

Dimension of z space

required
Source code in vambn/modelling/models/hivae/encoder.py
def __init__(self, input_dim: int, dim_s: int, dim_z: int):
    """HIVAE Encoder

    Args:
        input_dim (int): Dimension of input data (e.g. columns in dataframe)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
    """
    super().__init__()
    if input_dim <= 0:
        raise ValueError(
            f"Input dimension must be positive, got {input_dim}"
        )

    if dim_s <= 0:
        raise ValueError(f"S dimension must be positive, got {dim_s}")

    if dim_z <= 0:
        raise ValueError(f"Z dimension must be positive, got {dim_z}")

    self.input_dim = input_dim
    self.dim_s = dim_s
    self.dim_z = dim_z

    self.encoder_s = ModifiedLinear(self.input_dim, dim_s, bias=True)
    self.encoder_z = nn.Identity()
    self.param_z = ModifiedLinear(
        self.input_dim + dim_s, dim_z * 2, bias=True
    )

    self._tau = 1.0
    self._decoding = True
forward(x)

Forward pass of the encoder.

Parameters:

Name Type Description Default
x Tensor

Normalized input data.

required

Raises:

Type Description
Exception

If samples contain NaN values.

Returns:

Name Type Description
EncoderOutput EncoderOutput

Contains samples, logits, and parameters.

Source code in vambn/modelling/models/hivae/encoder.py
def forward(self, x: Tensor) -> EncoderOutput:
    """Forward pass of the encoder.

    Args:
        x (Tensor): Normalized input data.

    Raises:
        Exception: If samples contain NaN values.

    Returns:
        EncoderOutput: Contains samples, logits, and parameters.
    """
    logits_s = self.encoder_s(x)
    probs_s = F.softmax(logits_s, dim=-1).clamp(1e-6, 1 - 1e-6)

    if self.training:
        samples_s = self.q_s(probs_s).rsample()
    elif self.decoding:
        # NOTE:
        # The idea behind this was that we can use e.g. the mode during decoding
        # Experiments indicate that this is not helpful; the correlation between
        # the real and decoded data is improved, but the JSD is significantly worse.
        samples_s = self.q_s(probs_s).sample()
    else:
        samples_s = self.q_s(probs_s).sample()

    x_and_s = torch.cat((x, samples_s), dim=-1)

    h = self.encoder_z(x_and_s)

    loc_z, scale_z_unnorm = torch.chunk(self.param_z(h), 2, dim=-1)
    scale_z = F.softplus(scale_z_unnorm) + 1e-6

    samples_z = (
        self.q_z(loc_z, scale_z).rsample()
        if self.training
        else self.q_z(loc_z, scale_z).sample()
        if self.decoding
        else self.q_z(loc_z, scale_z).sample()
    )

    return EncoderOutput(
        samples_s=samples_s,
        samples_z=samples_z,
        logits_s=logits_s,
        mean_z=loc_z,
        scale_z=scale_z,
    )
q_s(probs)

Creates a Gumbel distribution for s.

Parameters:

Name Type Description Default
probs Tensor

Probabilities for the Gumbel distribution.

required

Returns:

Name Type Description
GumbelDistribution GumbelDistribution

Gumbel distribution.

Source code in vambn/modelling/models/hivae/encoder.py
def q_s(self, probs: Tensor) -> GumbelDistribution:
    """Creates a Gumbel distribution for s.

    Args:
        probs (Tensor): Probabilities for the Gumbel distribution.

    Returns:
        GumbelDistribution: Gumbel distribution.
    """
    return GumbelDistribution(probs=probs, temperature=self.tau)
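
q_s wraps the clamped softmax probabilities in a GumbelDistribution parameterized by the encoder's temperature; conceptually this is the Gumbel-Softmax trick, which draws an approximately one-hot sample that remains differentiable. The snippet below illustrates the idea with torch.nn.functional.gumbel_softmax rather than the package's GumbelDistribution class:

import torch
import torch.nn.functional as F

probs = torch.tensor([[0.7, 0.2, 0.1]])
logits = probs.log()

# Differentiable, approximately one-hot draw; lower tau -> closer to one-hot.
soft_sample = F.gumbel_softmax(logits, tau=0.5, hard=False)
print(soft_sample, soft_sample.sum(dim=-1))  # each row sums to 1
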
q_z(loc, scale) staticmethod

Creates a normal distribution for z.

Parameters:

Name Type Description Default
loc Tensor

Mean of the distribution.

required
scale Tensor

Standard deviation of the distribution.

required

Returns:

Type Description
Normal

dists.Normal: Normal distribution.

Source code in vambn/modelling/models/hivae/encoder.py
@staticmethod
def q_z(loc, scale: Tensor) -> dists.Normal:
    """Creates a normal distribution for z.

    Args:
        loc (Tensor): Mean of the distribution.
        scale (Tensor): Standard deviation of the distribution.

    Returns:
        dists.Normal: Normal distribution.
    """
    return dists.Normal(loc, scale)
LstmEncoder

Bases: Module

Encoder for longitudinal input.

Parameters:

Name Type Description Default
input_dimension int

Dimension of input data.

required
dim_s int

Dimension of s space.

required
dim_z int

Dimension of z space.

required
n_layers int

Number of LSTM layers.

required
hidden_size Optional[int]

Size of the hidden layer. Defaults to None.

None

Attributes:

Name Type Description
dim_s int

Dimension of s space.

dim_z int

Dimension of z space.

hidden_size int

Size of the hidden layer.

lstm LSTM

LSTM layer.

encoder Encoder

Encoder module.

Source code in vambn/modelling/models/hivae/encoder.py
class LstmEncoder(nn.Module):
    """Encoder for longitudinal input.

    Args:
        input_dimension (int): Dimension of input data.
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        n_layers (int): Number of LSTM layers.
        hidden_size (Optional[int], optional): Size of the hidden layer. Defaults to None.

    Attributes:
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        hidden_size (int): Size of the hidden layer.
        lstm (nn.LSTM): LSTM layer.
        encoder (Encoder): Encoder module.
    """

    @typeguard.typechecked
    def __init__(
        self,
        input_dimension: int,
        dim_s: int,
        dim_z: int,
        n_layers: int,
        hidden_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.dim_s = dim_s
        self.dim_z = dim_z

        if hidden_size is None:
            hidden_size = input_dimension

        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(
            input_size=input_dimension,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
        )
        self.encoder = Encoder(input_dim=hidden_size, dim_s=dim_s, dim_z=dim_z)

    def forward(self, input_data: Tensor) -> EncoderOutput:
        """Forward pass of the LSTM encoder.

        Args:
            input_data (Tensor): Batch size x time points/visits x variables/input size (batch_first).

        Returns:
            EncoderOutput: Encoder output computed from the last time point of the sequence.
        """
        out, (_, _) = self.lstm(input_data)
        # out: batch_size x time points x hidden size
        last_out = out[:, -1, :]
        encoder_output = self.encoder.forward(last_out)

        return encoder_output

    def q_z(self, loc, scale: Tensor) -> dists.Normal:
        """Creates a normal distribution for z.

        Args:
            loc (Tensor): Mean of the distribution.
            scale (Tensor): Standard deviation of the distribution.

        Returns:
            dists.Normal: Normal distribution.
        """
        return self.encoder.q_z(loc, scale)

    def q_s(self, probs: Tensor) -> GumbelDistribution:
        """Creates a Gumbel distribution for s.

        Args:
            probs (Tensor): Probabilities for the Gumbel distribution.

        Returns:
            GumbelDistribution: Gumbel distribution.
        """
        return self.encoder.q_s(probs)
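
A usage sketch for the longitudinal encoder. The LSTM runs batch_first, so the input is (batch, time points, features); only the hidden state of the last time point is passed to the static Encoder, yielding one EncoderOutput per subject. The import path is assumed from the "Source code in" note.

import torch
from vambn.modelling.models.hivae.encoder import LstmEncoder  # path as listed above

enc = LstmEncoder(input_dimension=6, dim_s=3, dim_z=4, n_layers=1)

seq = torch.randn(8, 5, 6)       # (batch, time points, features), batch_first
out = enc(seq)
print(out.samples_z.shape)       # (8, 4): one latent code per sequence, not per visit
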
forward(input_data)

Forward pass of the LSTM encoder.

Parameters:

Name Type Description Default
input_data Tensor

Batch size x time points/visits x variables/input size (batch_first).

required

Returns:

Name Type Description
EncoderOutput EncoderOutput

Encoder output computed from the last time point of the sequence.

Source code in vambn/modelling/models/hivae/encoder.py
def forward(self, input_data: Tensor) -> EncoderOutput:
    """Forward pass of the LSTM encoder.

    Args:
        input_data (Tensor): Batch size x time points/visits x variables/input size (batch_first).

    Returns:
        EncoderOutput: Encoder output computed from the last time point of the sequence.
    """
    out, (_, _) = self.lstm(input_data)
    # out: batch_size x time points x hidden size
    last_out = out[:, -1, :]
    encoder_output = self.encoder.forward(last_out)

    return encoder_output
q_s(probs)

Creates a Gumbel distribution for s.

Parameters:

Name Type Description Default
probs Tensor

Probabilities for the Gumbel distribution.

required

Returns:

Name Type Description
GumbelDistribution GumbelDistribution

Gumbel distribution.

Source code in vambn/modelling/models/hivae/encoder.py
def q_s(self, probs: Tensor) -> GumbelDistribution:
    """Creates a Gumbel distribution for s.

    Args:
        probs (Tensor): Probabilities for the Gumbel distribution.

    Returns:
        GumbelDistribution: Gumbel distribution.
    """
    return self.encoder.q_s(probs)
q_z(loc, scale)

Creates a normal distribution for z.

Parameters:

Name Type Description Default
loc Tensor

Mean of the distribution.

required
scale Tensor

Standard deviation of the distribution.

required

Returns:

Type Description
Normal

dists.Normal: Normal distribution.

Source code in vambn/modelling/models/hivae/encoder.py
def q_z(self, loc, scale: Tensor) -> dists.Normal:
    """Creates a normal distribution for z.

    Args:
        loc (Tensor): Mean of the distribution.
        scale (Tensor): Standard deviation of the distribution.

    Returns:
        dists.Normal: Normal distribution.
    """
    return self.encoder.q_z(loc, scale)

gan_hivae

GanHivae

Bases: AbstractGanModel[Tensor, Tensor, HivaeOutput, HivaeEncoding]

GAN-enhanced HIVAE model.

Parameters:

Name Type Description Default
variable_types VarTypes

Types of variables.

required
input_dim int

Dimension of input data.

required
dim_s int

Dimension of s space.

required
dim_z int

Dimension of z space.

required
dim_y int

Dimension of y space.

required
n_layers int

Number of layers.

required
noise_size int

Size of the noise vector. Defaults to 10.

10
num_timepoints Optional[int]

Number of time points. Defaults to None.

None
module_name Optional[str]

Name of the module. Defaults to "GanHivae".

'GanHivae'
mtl_method Tuple[str, ...]

Methods for multi-task learning. Defaults to ("identity",).

('identity',)
use_imputation_layer bool

Whether to use an imputation layer. Defaults to False.

False
individual_model bool

Whether to use individual models. Defaults to True.

True

Attributes:

Name Type Description
noise_size int

Size of the noise vector.

is_longitudinal bool

Flag for longitudinal data.

model Hivae or LstmHivae

HIVAE model.

generator Generator

GAN generator.

discriminator Discriminator

GAN discriminator.

device device

Device to run the model on.

one Parameter

Parameter for GAN training.

mone Parameter

Parameter for GAN training.

Source code in vambn/modelling/models/hivae/gan_hivae.py
class GanHivae(
    AbstractGanModel[torch.Tensor, torch.Tensor, HivaeOutput, HivaeEncoding]
):
    """GAN-enhanced HIVAE model.

    Args:
        variable_types (VarTypes): Types of variables.
        input_dim (int): Dimension of input data.
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        dim_y (int): Dimension of y space.
        n_layers (int): Number of layers.
        noise_size (int, optional): Size of the noise vector. Defaults to 10.
        num_timepoints (Optional[int], optional): Number of time points. Defaults to None.
        module_name (Optional[str], optional): Name of the module. Defaults to "GanHivae".
        mtl_method (Tuple[str, ...], optional): Methods for multi-task learning. Defaults to ("identity",).
        use_imputation_layer (bool, optional): Whether to use an imputation layer. Defaults to False.
        individual_model (bool, optional): Whether to use individual models. Defaults to True.

    Attributes:
        noise_size (int): Size of the noise vector.
        is_longitudinal (bool): Flag for longitudinal data.
        model (Hivae or LstmHivae): HIVAE model.
        generator (Generator): GAN generator.
        discriminator (Discriminator): GAN discriminator.
        device (torch.device): Device to run the model on.
        one (nn.Parameter): Parameter for GAN training.
        mone (nn.Parameter): Parameter for GAN training.
    """

    def __init__(
        self,
        variable_types: VarTypes,
        input_dim: int,
        dim_s: int,
        dim_z: int,
        dim_y: int,
        n_layers: int,
        noise_size: int = 10,
        num_timepoints: Optional[int] = None,
        module_name: Optional[str] = "GanHivae",
        mtl_method: Tuple[str, ...] = ("identity",),
        use_imputation_layer: bool = False,
        individual_model: bool = True,
    ):
        super().__init__()

        self.noise_size = noise_size
        if num_timepoints is not None and num_timepoints > 1:
            self.is_longitudinal = True
            self.model = LstmHivae(
                variable_types=variable_types,
                input_dim=input_dim,
                dim_s=dim_s,
                dim_z=dim_z,
                dim_y=dim_y,
                n_layers=n_layers,
                num_timepoints=num_timepoints,
                module_name=module_name,
                mtl_method=mtl_method,
                use_imputation_layer=use_imputation_layer,
                individual_model=individual_model,
            )
        else:
            self.is_longitudinal = False
            self.model = Hivae(
                variable_types=variable_types,
                input_dim=input_dim,
                dim_s=dim_s,
                dim_z=dim_z,
                dim_y=dim_y,
                module_name=module_name,
                mtl_method=mtl_method,
                use_imputation_layer=use_imputation_layer,
                individual_model=individual_model,
            )

        self.generator = Generator(noise_size, (8, 4), dim_z)
        self.discriminator = Discriminator(dim_z, (8, 4), 1)
        #
        self.device = torch.device("cpu")
        self.one = nn.Parameter(
            torch.FloatTensor([1.0])[0], requires_grad=False
        )
        self.mone = nn.Parameter(
            torch.FloatTensor([-1.0])[0], requires_grad=False
        )

    def _train_gan_generator_step(
        self, data: torch.Tensor, mask: torch.Tensor, optimizer: Optimizer
    ) -> torch.Tensor:
        """Performs a training step for the GAN generator.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.
            optimizer (Optimizer): Optimizer for the generator.

        Returns:
            torch.Tensor: Generator loss.
        """
        optimizer.zero_grad()
        noise = torch.randn(data.shape[0], self.noise_size)
        fake_hidden = self.generator(noise)
        errG = self.discriminator(fake_hidden)
        self.fabric.backward(errG, self.one)
        optimizer.step()
        return errG

    def _train_gan_discriminator_step(
        self, data: torch.Tensor, mask: torch.Tensor, optimizer: Optimizer
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Performs a training step for the GAN discriminator.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.
            optimizer (Optimizer): Optimizer for the discriminator.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Discriminator loss, real loss, and fake loss.
        """
        optimizer.zero_grad()
        _, _, encoder_output = self.model.encoder_part(data, mask)
        real_hidden = encoder_output.samples_z
        errD_real = self.discriminator(real_hidden.detach())
        self.fabric.backward(errD_real, self.one)

        noise = torch.randn(data.shape[0], self.noise_size)
        fake_hidden = self.generator(noise)
        errD_fake = self.discriminator(fake_hidden.detach())
        self.fabric.backward(errD_fake, self.mone)

        gradient_penalty = self._calc_gradient_penalty(
            real_hidden, fake_hidden
        )
        self.fabric.backward(gradient_penalty)
        optimizer.step()
        errD = -(errD_real - errD_fake)
        return errD, errD_real, errD_fake

    def _train_model_step(
        self, data: torch.Tensor, mask: torch.Tensor, optimizer: Optimizer
    ) -> HivaeOutput:
        """Performs a training step for the HIVAE model.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.
            optimizer (Optimizer): Optimizer for the model.

        Returns:
            HivaeOutput: Model output.
        """
        optimizer.zero_grad()
        output = self.model.forward(data, mask)
        loss = output.loss
        # if loss > 1e10:
        #     raise optuna.TrialPruned("Unsuited hyperparameters. High loss")
        self.fabric.backward(loss)
        optimizer.step()
        return output

    def _train_model_from_discriminator_step(
        self, data: torch.Tensor, mask: torch.Tensor, optimizer: Optimizer
    ) -> torch.Tensor:
        """Performs a training step for the model from the discriminator's perspective.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.
            optimizer (Optimizer): Optimizer for the model.

        Returns:
            torch.Tensor: Discriminator real loss.
        """
        optimizer.zero_grad()
        _, _, encoder_output = self.model.encoder_part(data, mask)
        real_hidden = encoder_output.samples_z
        errD_real = self.discriminator(real_hidden)
        self.fabric.backward(errD_real, self.mone)
        optimizer.step()
        return errD_real

    def _get_loss_from_output(self, output: HivaeOutput) -> torch.Tensor:
        """Extracts the loss from the model output.

        Args:
            output (HivaeOutput): Model output.

        Returns:
            torch.Tensor: Loss value.
        """
        return output.loss

    def _get_number_of_items(self, mask: torch.Tensor) -> int:
        """Gets the number of items in the mask.

        Args:
            mask (torch.Tensor): Data mask.

        Returns:
            int: Number of items.
        """
        return mask.sum()

    @cached_property
    def colnames(self) -> Tuple[str]:
        """Gets the column names of the data.

        Returns:
            Tuple[str]: Column names.
        """
        return self.model.colnames

    @property
    def decoding(self) -> bool:
        """bool: Flag indicating whether the model is in decoding mode."""
        return self.model.decoding

    @decoding.setter
    def decoding(self, value: bool) -> None:
        """Sets the decoding flag.

        Args:
            value (bool): Decoding flag.
        """
        self.model.decoding = value

    @property
    def tau(self) -> float:
        """float: Temperature parameter for Gumbel softmax."""
        return self.model.tau

    @tau.setter
    def tau(self, value: float) -> None:
        """Sets the temperature parameter for Gumbel softmax.

        Args:
            value (float): Temperature value.
        """
        self.model.tau = value

    @property
    def normalization_parameters(self) -> NormalizationParameters:
        """NormalizationParameters: Parameters for normalization."""
        return self.model.normalization_parameters

    @normalization_parameters.setter
    def normalization_parameters(self, value: NormalizationParameters) -> None:
        """Sets the normalization parameters.

        Args:
            value (NormalizationParameters): Normalization parameters.
        """
        self.model.normalization_parameters = value

    def forward(self, data: torch.Tensor, mask: torch.Tensor) -> HivaeOutput:
        """Forward pass of the model.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.

        Returns:
            HivaeOutput: Model output.
        """
        data, mask, encoder_output = self.encoder_part(data, mask)
        decoder_output = self.decoder_part(data, mask, encoder_output)
        return decoder_output

    def decoder_part(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
        encoder_output: EncoderOutput,
    ) -> HivaeOutput:
        """Performs the decoder part of the forward pass.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.
            encoder_output (EncoderOutput): Output from the encoder.

        Returns:
            HivaeOutput: Model output.
        """
        return self.model.decoder_part(data, mask, encoder_output)

    def encoder_part(
        self, data: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, EncoderOutput]:
        """Performs the encoder part of the forward pass.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Encoder output and modified inputs.
        """
        return self.model.encoder_part(data, mask)

    def decode(self, encoding: HivaeEncoding) -> torch.Tensor:
        """Decodes the given encoding.

        Args:
            encoding (HivaeEncoding): Encoding to decode.

        Returns:
            torch.Tensor: Decoded data.
        """
        return self.model.decode(encoding)

    def _training_step(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
        optimizer: Tuple[Optimizer],
    ) -> float:
        """Training step is not needed for this class."""
        raise Exception(f"Method not needed for {self.__class__.__name__}")

    def _validation_step(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
    ) -> float:
        """Performs a validation step.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.

        Returns:
            float: Validation loss.
        """
        return self.model._validation_step(data, mask)

    def _test_step(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
    ) -> float:
        """Performs a test step.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.

        Returns:
            float: Test loss.
        """
        return self.model._test_step(data, mask)

    def _predict_step(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
    ) -> HivaeOutput:
        """Performs a prediction step.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.

        Returns:
            HivaeOutput: Prediction output.
        """
        return self.model._predict_step(data, mask)
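
The three GAN-related steps above alternate a critic (discriminator) update with a gradient penalty, a generator update, and the usual HIVAE update, all routed through Lightning Fabric. The sketch below is a self-contained WGAN-GP-style analogue using plain nn.Linear stand-ins and ordinary .backward(); the stand-in networks, optimizers, and the penalty formulation are assumptions for illustration and do not reproduce _calc_gradient_penalty or the Fabric sign conventions exactly.

import torch
import torch.nn as nn

noise_size, dim_z, batch = 10, 4, 16
generator = nn.Sequential(nn.Linear(noise_size, 8), nn.ReLU(), nn.Linear(8, dim_z))
critic = nn.Sequential(nn.Linear(dim_z, 8), nn.ReLU(), nn.Linear(8, 1))
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(critic.parameters(), lr=1e-4)

real_hidden = torch.randn(batch, dim_z)  # stands in for encoder_output.samples_z

def gradient_penalty(real, fake):
    # WGAN-GP style penalty on interpolated latent codes (one common formulation).
    alpha = torch.rand(real.size(0), 1)
    inter = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    score = critic(inter).sum()
    grads = torch.autograd.grad(score, inter, create_graph=True)[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

# Critic step: errD = -(errD_real - errD_fake) plus the gradient penalty.
opt_d.zero_grad()
fake_hidden = generator(torch.randn(batch, noise_size)).detach()
errD_real = critic(real_hidden).mean()
errD_fake = critic(fake_hidden).mean()
d_loss = -(errD_real - errD_fake) + gradient_penalty(real_hidden, fake_hidden)
d_loss.backward()
opt_d.step()

# Generator step: push the critic score of generated latent codes up.
opt_g.zero_grad()
g_loss = -critic(generator(torch.randn(batch, noise_size))).mean()
g_loss.backward()
opt_g.step()
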
colnames: Tuple[str] cached property

Gets the column names of the data.

Returns:

Type Description
Tuple[str]

Tuple[str]: Column names.

decoding: bool property writable

bool: Flag indicating whether the model is in decoding mode.

normalization_parameters: NormalizationParameters property writable

NormalizationParameters: Parameters for normalization.

tau: float property writable

float: Temperature parameter for Gumbel softmax.

decode(encoding)

Decodes the given encoding.

Parameters:

Name Type Description Default
encoding HivaeEncoding

Encoding to decode.

required

Returns:

Type Description
Tensor

torch.Tensor: Decoded data.

Source code in vambn/modelling/models/hivae/gan_hivae.py
def decode(self, encoding: HivaeEncoding) -> torch.Tensor:
    """Decodes the given encoding.

    Args:
        encoding (HivaeEncoding): Encoding to decode.

    Returns:
        torch.Tensor: Decoded data.
    """
    return self.model.decode(encoding)
decoder_part(data, mask, encoder_output)

Performs the decoder part of the forward pass.

Parameters:

Name Type Description Default
data Tensor

Input data.

required
mask Tensor

Data mask.

required
encoder_output EncoderOutput

Output from the encoder.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

Model output.

Source code in vambn/modelling/models/hivae/gan_hivae.py
def decoder_part(
    self,
    data: torch.Tensor,
    mask: torch.Tensor,
    encoder_output: EncoderOutput,
) -> HivaeOutput:
    """Performs the decoder part of the forward pass.

    Args:
        data (torch.Tensor): Input data.
        mask (torch.Tensor): Data mask.
        encoder_output (EncoderOutput): Output from the encoder.

    Returns:
        HivaeOutput: Model output.
    """
    return self.model.decoder_part(data, mask, encoder_output)
encoder_part(data, mask)

Performs the encoder part of the forward pass.

Parameters:

Name Type Description Default
data Tensor

Input data.

required
mask Tensor

Data mask.

required

Returns:

Type Description
Tuple[Tensor, Tensor, EncoderOutput]

Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Encoder output and modified inputs.

Source code in vambn/modelling/models/hivae/gan_hivae.py
def encoder_part(
    self, data: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, EncoderOutput]:
    """Performs the encoder part of the forward pass.

    Args:
        data (torch.Tensor): Input data.
        mask (torch.Tensor): Data mask.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Encoder output and modified inputs.
    """
    return self.model.encoder_part(data, mask)
forward(data, mask)

Forward pass of the model.

Parameters:

Name Type Description Default
data Tensor

Input data.

required
mask Tensor

Data mask.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

Model output.

Source code in vambn/modelling/models/hivae/gan_hivae.py
def forward(self, data: torch.Tensor, mask: torch.Tensor) -> HivaeOutput:
    """Forward pass of the model.

    Args:
        data (torch.Tensor): Input data.
        mask (torch.Tensor): Data mask.

    Returns:
        HivaeOutput: Model output.
    """
    data, mask, encoder_output = self.encoder_part(data, mask)
    decoder_output = self.decoder_part(data, mask, encoder_output)
    return decoder_output
GanModularHivae

Bases: AbstractGanModularModel[Tuple[Tensor, ...], Tuple[Tensor, ...], ModularHivaeOutput, ModularHivaeEncoding]

GAN-enhanced Modular HIVAE model.

Parameters:

Name Type Description Default
module_config Tuple[DataModuleConfig]

Configuration for the data modules.

required
dim_s int | Dict[str, int]

Dimension of s space.

required
dim_z int

Dimension of z space.

required
dim_ys int

Dimension of YS space.

required
dim_y int | Dict[str, int]

Dimension of y space.

required
noise_size int

Size of the noise vector. Defaults to 10.

10
shared_element_type str

Type of shared element. Defaults to "none".

'none'
mtl_method Tuple[str, ...]

Methods for multi-task learning. Defaults to ("identity",).

('identity',)
use_imputation_layer bool

Whether to use an imputation layer. Defaults to False.

False

Attributes:

Name Type Description
module_configs Tuple[DataModuleConfig]

Configuration for the data modules.

mtl_method Tuple[str, ...]

Methods for multi-task learning.

use_imputation_layer bool

Whether to use an imputation layer.

dim_s int | Dict[str, int]

Dimension of s space.

dim_z int

Dimension of z space.

dim_ys int

Dimension of YS space.

dim_y int | Dict[str, int]

Dimension of y space.

model ModularHivae

Modular HIVAE model.

generators ModuleList

List of GAN generators.

discriminators ModuleList

List of GAN discriminators.

device device

Device to run the model on.

one Parameter

Parameter for GAN training.

mone Parameter

Parameter for GAN training.

noise_size int

Size of the noise vector.

Source code in vambn/modelling/models/hivae/gan_hivae.py
class GanModularHivae(
    AbstractGanModularModel[
        Tuple[Tensor, ...],
        Tuple[Tensor, ...],
        ModularHivaeOutput,
        ModularHivaeEncoding,
    ]
):
    """GAN-enhanced Modular HIVAE model.

    Args:
        module_config (Tuple[DataModuleConfig]): Configuration for the data modules.
        dim_s (int | Dict[str, int]): Dimension of s space.
        dim_z (int): Dimension of z space.
        dim_ys (int): Dimension of YS space.
        dim_y (int | Dict[str, int]): Dimension of y space.
        noise_size (int, optional): Size of the noise vector. Defaults to 10.
        shared_element_type (str, optional): Type of shared element. Defaults to "none".
        mtl_method (Tuple[str, ...], optional): Methods for multi-task learning. Defaults to ("identity",).
        use_imputation_layer (bool, optional): Whether to use an imputation layer. Defaults to False.

    Attributes:
        module_configs (Tuple[DataModuleConfig]): Configuration for the data modules.
        mtl_method (Tuple[str, ...]): Methods for multi-task learning.
        use_imputation_layer (bool): Whether to use an imputation layer.
        dim_s (int | Dict[str, int]): Dimension of s space.
        dim_z (int): Dimension of z space.
        dim_ys (int): Dimension of YS space.
        dim_y (int | Dict[str, int]): Dimension of y space.
        model (ModularHivae): Modular HIVAE model.
        generators (nn.ModuleList): List of GAN generators.
        discriminators (nn.ModuleList): List of GAN discriminators.
        device (torch.device): Device to run the model on.
        one (nn.Parameter): Parameter for GAN training.
        mone (nn.Parameter): Parameter for GAN training.
        noise_size (int): Size of the noise vector.
    """

    def __init__(
        self,
        module_config: Tuple[DataModuleConfig],
        dim_s: int | Dict[str, int],
        dim_z: int,
        dim_ys: int,
        dim_y: int | Dict[str, int],
        noise_size: int = 10,
        shared_element_type: str = "none",
        mtl_method: Tuple[str, ...] = ("identity",),
        use_imputation_layer: bool = False,
    ):
        super().__init__()
        self.module_configs = module_config
        self.mtl_method = mtl_method
        self.use_imputation_layer = use_imputation_layer
        self.dim_s = dim_s
        self.dim_z = dim_z
        self.dim_ys = dim_ys
        self.dim_y = dim_y

        self.model = ModularHivae(
            module_config=module_config,
            dim_s=dim_s,
            dim_z=dim_z,
            dim_ys=dim_ys,
            dim_y=dim_y,
            shared_element_type=shared_element_type,
            mtl_method=mtl_method,
            use_imputation_layer=use_imputation_layer,
        )

        self.generators = nn.ModuleList(
            [
                Generator(noise_size, (8, 4), dim_z)
                for _ in range(len(module_config))
            ]
        )
        self.discriminators = nn.ModuleList(
            [Discriminator(dim_z, (8, 4), 1) for _ in range(len(module_config))]
        )

        self.device = torch.device("cpu")
        self.one = nn.Parameter(
            torch.FloatTensor([1.0])[0], requires_grad=False
        )
        self.mone = nn.Parameter(
            torch.FloatTensor([-1.0])[0], requires_grad=False
        )
        self.noise_size = noise_size

    def _train_gan_generator_step(
        self,
        data: Tuple[Tensor],
        mask: Tuple[Tensor],
        optimizer: Tuple[Optimizer],
    ) -> Tuple[Tensor]:
        """Performs a training step for the GAN generators.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.
            optimizer (Tuple[Optimizer]): Optimizers for the generators.

        Returns:
            Tuple[Tensor]: Generator losses.
        """
        errG_list = []
        for data, mask, generator, discriminator, opt in zip(
            data, mask, self.generators, self.discriminators, optimizer
        ):
            opt.zero_grad()
            noise = torch.randn(data.shape[0], self.noise_size)
            fake_hidden = generator(noise)
            errG = discriminator(fake_hidden)
            self.fabric.backward(errG, self.one)
            opt.step()
            errG_list.append(errG.detach())
        return tuple(errG_list)

    def _train_gan_discriminator_step(
        self,
        data: Tuple[Tensor],
        mask: Tuple[Tensor],
        optimizer: Tuple[Optimizer],
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Performs a training step for the GAN discriminators.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.
            optimizer (Tuple[Optimizer]): Optimizers for the discriminators.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: Discriminator loss, real loss, and fake loss.
        """
        errD_list = []
        errD_real_list = []
        errD_fake_list = []
        for data, mask, generator, discriminator, opt, module in zip(
            data,
            mask,
            self.generators,
            self.discriminators,
            optimizer,
            self.model.module_models.values(),
        ):
            opt.zero_grad()
            _, _, encoder_output = module.encoder_part(data, mask)
            real_hidden = encoder_output.samples_z
            errD_real = discriminator(real_hidden.detach())
            self.fabric.backward(errD_real, self.one)

            noise = torch.randn((data.shape[0], self.noise_size))
            fake_hidden = generator(noise)
            errD_fake = discriminator(fake_hidden.detach())
            self.fabric.backward(errD_fake, self.mone)

            gradient_penalty = self._calc_gradient_penalty(
                real_hidden, fake_hidden, discriminator
            )
            self.fabric.backward(gradient_penalty)
            opt.step()
            errD = -(errD_real - errD_fake)
            errD_list.append(errD.detach())
            errD_real_list.append(errD_real.detach())
            errD_fake_list.append(errD_fake.detach())
        return tuple(errD_list), tuple(errD_real_list), tuple(errD_fake_list)

    def _train_model_step(
        self,
        data: Tuple[Tensor],
        mask: Tuple[Tensor],
        optimizer: Tuple[Optimizer],
    ) -> ModularHivaeOutput:
        """Performs a training step for the Modular HIVAE model.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.
            optimizer (Tuple[Optimizer]): Optimizers for the model.

        Returns:
            ModularHivaeOutput: Model output.
        """
        for opt in optimizer:
            if opt is None:
                continue
            opt.zero_grad()
        output = self.forward(data, mask)
        self.fabric.backward(output.loss)
        # if output.loss > 1e10:
        #     raise optuna.TrialPruned("Unsuited hyperparameters. High loss")
        for opt in optimizer:
            if opt is None:
                continue
            opt.step()
        return output

    def _train_model_from_discriminator_step(
        self,
        data: Tuple[Tensor],
        mask: Tuple[Tensor],
        optimizer: Tuple[Optimizer],
    ) -> Tuple[Tensor]:
        """Performs a training step for the model from the discriminator's perspective.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.
            optimizer (Tuple[Optimizer]): Optimizers for the model.

        Returns:
            Tuple[Tensor]: Discriminator real losses.
        """
        errD_real_list = []
        for data, mask, discriminator, opt, module in zip(
            data,
            mask,
            self.discriminators,
            optimizer,
            self.model.module_models.values(),
        ):
            opt.zero_grad()
            _, _, encoder_output = module.encoder_part(data, mask)
            real_hidden = encoder_output.samples_z
            errD_real = discriminator(real_hidden)
            self.fabric.backward(errD_real, self.mone)
            opt.step()
            errD_real_list.append(errD_real.detach())
        return tuple(errD_real_list)

    def _get_loss_from_output(self, output: ModularHivaeOutput) -> float:
        """Extracts the loss from the model output.

        Args:
            output (ModularHivaeOutput): Model output.

        Returns:
            float: Loss value.
        """
        return output.avg_loss

    def _get_number_of_items(self, mask: Tuple[Tensor]) -> int:
        """Gets the number of items in the mask.

        Args:
            mask (Tuple[Tensor]): Data mask.

        Returns:
            int: Number of items.
        """
        return sum([m.sum() for m in mask])

    def colnames(self, module_name: str) -> Tuple[str, ...]:
        """Gets the column names for a specific module.

        Args:
            module_name (str): Name of the module.

        Returns:
            Tuple[str, ...]: Column names.
        """
        return self.model.module_models[module_name].colnames

    @property
    def decoding(self) -> bool:
        """bool: Flag indicating whether the model is in decoding mode."""
        return self.model.decoding

    @decoding.setter
    def decoding(self, value: bool) -> None:
        """Sets the decoding flag.

        Args:
            value (bool): Decoding flag.
        """
        self.model.decoding = value

    @property
    def tau(self) -> float:
        """float: Temperature parameter for Gumbel softmax."""
        return self.model.tau

    @tau.setter
    def tau(self, value: float) -> None:
        """Sets the temperature parameter for Gumbel softmax.

        Args:
            value (float): Temperature value.
        """
        self.model.tau = value

    def forward(
        self, data: Tuple[Tensor], mask: Tuple[Tensor]
    ) -> ModularHivaeOutput:
        """Forward pass of the model.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.

        Returns:
            ModularHivaeOutput: Model output.
        """
        return self.model.forward(data, mask)

    def decode(self, encoding: ModularHivaeEncoding) -> Tuple[Tensor, ...]:
        """Decodes the given encoding.

        Args:
            encoding (ModularHivaeEncoding): Encoding to decode.

        Returns:
            Tuple[Tensor, ...]: Decoded data.
        """
        return self.model.decode(encoding)

    def _training_step(
        self,
        data: Tuple[Tensor],
        mask: Tuple[Tensor],
        optimizer: Tuple[Optimizer],
    ) -> float:
        """Training step is not needed for this class."""
        raise Exception(f"Method not needed for {self.__class__.__name__}")

    def _validation_step(
        self, data: Tuple[Tensor], mask: Tuple[Tensor]
    ) -> float:
        """Performs a validation step.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.

        Returns:
            float: Validation loss.
        """
        return self.model._validation_step(data, mask)

    def _test_step(self, data: Tuple[Tensor], mask: Tuple[Tensor]) -> float:
        """Performs a test step.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.

        Returns:
            float: Test loss.
        """
        return self.model._test_step(data, mask)

    def _predict_step(
        self, data: Tuple[Tensor], mask: Tuple[Tensor]
    ) -> ModularHivaeOutput:
        """Performs a prediction step.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.

        Returns:
            ModularHivaeOutput: Prediction output.
        """
        return self.model._predict_step(data, mask)
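
Each training step of the modular variant pairs module-wise data and masks with that module's own generator, discriminator, and optimizer via zip and collects the per-module losses into a tuple. A schematic sketch of that pairing with hypothetical stand-in modules:

import torch
import torch.nn as nn

# Two hypothetical modules of different width, each with its own critic and optimizer.
data = (torch.randn(16, 5), torch.randn(16, 3))
mask = (torch.ones(16, 5), torch.ones(16, 3))
dim_z = 4
discriminators = nn.ModuleList([nn.Linear(dim_z, 1) for _ in data])
optimizers = [torch.optim.Adam(d.parameters(), lr=1e-4) for d in discriminators]

per_module_losses = []
for d, m, critic, opt in zip(data, mask, discriminators, optimizers):
    opt.zero_grad()
    latent = torch.randn(d.shape[0], dim_z)  # stands in for module.encoder_part(d, m)
    err = critic(latent).mean()
    err.backward()
    opt.step()
    per_module_losses.append(err.detach())
print(tuple(per_module_losses))
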
colnames(module_name)

Gets the column names for a specific module.

Parameters:

Name Type Description Default
module_name str

Name of the module.

required

Returns:

Type Description
Tuple[str, ...]

Tuple[str, ...]: Column names.

decoding: bool property writable

bool: Flag indicating whether the model is in decoding mode.

tau: float property writable

float: Temperature parameter for Gumbel softmax.

decode(encoding)

Decodes the given encoding.

Parameters:

Name Type Description Default
encoding ModularHivaeEncoding

Encoding to decode.

required

Returns:

Type Description
Tuple[Tensor, ...]

Tuple[Tensor, ...]: Decoded data.

Source code in vambn/modelling/models/hivae/gan_hivae.py
def decode(self, encoding: ModularHivaeEncoding) -> Tuple[Tensor, ...]:
    """Decodes the given encoding.

    Args:
        encoding (ModularHivaeEncoding): Encoding to decode.

    Returns:
        Tuple[Tensor, ...]: Decoded data.
    """
    return self.model.decode(encoding)
forward(data, mask)

Forward pass of the model.

Parameters:

Name Type Description Default
data Tuple[Tensor]

Input data.

required
mask Tuple[Tensor]

Data mask.

required

Returns:

Name Type Description
ModularHivaeOutput ModularHivaeOutput

Model output.

Source code in vambn/modelling/models/hivae/gan_hivae.py
def forward(
    self, data: Tuple[Tensor], mask: Tuple[Tensor]
) -> ModularHivaeOutput:
    """Forward pass of the model.

    Args:
        data (Tuple[Tensor]): Input data.
        mask (Tuple[Tensor]): Data mask.

    Returns:
        ModularHivaeOutput: Model output.
    """
    return self.model.forward(data, mask)

heads

BaseModuleHead

Bases: Generic[ParameterType, DistributionType], Module, ABC

Base class for different data types.

Parameters:

Name Type Description Default
variable_type VariableType

Dataclass containing the type information.

required
dim_s int

Dimension of s space.

required
dim_z int

Dimension of z space.

required
dim_y int

Dimension of y space.

required

Attributes:

Name Type Description
types VariableType

Dataclass containing the type information.

dim_s int

Dimension of s space.

dim_z int

Dimension of z space.

dim_y int

Dimension of y space.

internal_pass ModifiedLinear

Internal pass module.

Properties:

num_parameters (int): Number of parameters.

Methods:

Name Description
forward(samples_z: Tensor, samples_s: Tensor) -> ParameterType

Forward pass of the module.

dist(params: ParameterType) -> DistributionType

Compute the distribution given the parameters.

log_prob(data: Tensor, params: Optional[ParameterType] = None) -> Tensor

Compute the log probability of the data given the parameters.

sample() -> Tensor

Sample from the distribution.

rsample() -> Tensor

Sample using the reparameterization trick.

mode() -> Tensor

Compute the mode of the distribution.

Raises:

Type Description
RuntimeError

If the distribution is not initialized.

Source code in vambn/modelling/models/hivae/heads.py
class BaseModuleHead(
    Generic[ParameterType, DistributionType], nn.Module, ABC
):
    """Base class for different data types.

    Args:
        variable_type (VariableType): Dataclass containing the type information.
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        dim_y (int): Dimension of y space.

    Attributes:
        types (VariableType): Dataclass containing the type information.
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        dim_y (int): Dimension of y space.
        internal_pass (ModifiedLinear): Internal pass module.

    Properties:
        num_parameters (int): Number of parameters.

    Methods:
        forward(samples_z: Tensor, samples_s: Tensor) -> ParameterType: Forward pass of the module.
        dist(params: ParameterType) -> DistributionType: Compute the distribution given the parameters.
        log_prob(data: Tensor, params: Optional[ParameterType] = None) -> Tensor: Compute the log probability of the data given the parameters.
        sample() -> Tensor: Sample from the distribution.
        rsample() -> Tensor: Sample using the reparameterization trick.
        mode() -> Tensor: Compute the mode of the distribution.

    Raises:
        RuntimeError: If the distribution is not initialized.

    """

    def __init__(
        self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
    ) -> None:
        """Base class for the different datatypes

        Args:
            types (VariableType): Array containing the data type information (type, class, ndim)
            dim_s (int): Dimension of s space
            dim_z (int): Dimension of z space
            dim_y (int): Dimension of y space
        """
        super().__init__()
        # General parameters
        self.types = variable_type
        self.dim_s = dim_s
        self.dim_z = dim_z
        self.dim_y = dim_y
        self.internal_pass = ModifiedLinear(self.dim_z, self.dim_y, bias=True)

        # Specific parameters
        self._dist = None
        self._n_pars = None

    @abstractmethod
    def forward(self, samples_z: Tensor, samples_s: Tensor) -> ParameterType:
        """Forward pass of the module.

        Args:
            samples_z (Tensor): Samples from the z space.
            samples_s (Tensor): Samples from the s space.

        Returns:
            ParameterType: The output of the forward pass.

        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError

    @abstractmethod
    def dist(self, params: ParameterType) -> DistributionType:
        """Compute the distribution given the parameters.

        Args:
            params (ParameterType): The parameters of the distribution.

        Returns:
            DistributionType: The computed distribution.

        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError

    def log_prob(
        self, data: Tensor, params: Optional[ParameterType] = None
    ) -> Tensor:
        """Compute the log probability of the data given the parameters.

        Args:
            data (Tensor): The input data.
            params (Optional[ParameterType], optional): The parameters of the distribution. Defaults to None.

        Returns:
            Tensor: The log probability of the data.

        Raises:
            RuntimeError: If the distribution is not initialized.
        """
        if self._dist is None and params is None:
            raise RuntimeError("Distribution is not initialized.")
        elif params is not None:
            self.dist(params)

        return self._dist.log_prob(data)

    def sample(self) -> Tensor:
        """Sample from the distribution.

        Returns:
            Tensor: The sampled data.

        Raises:
            RuntimeError: If the distribution is not initialized.
        """
        if self._dist is None:
            raise RuntimeError("Distribution is not initialized.")

        gen_sample = self._dist.sample()
        if gen_sample.ndim == 1:
            return gen_sample.unsqueeze(1)
        elif gen_sample.ndim == 0:
            # ensure 2 dim output
            return gen_sample.unsqueeze(0).unsqueeze(1)
        return gen_sample

    def rsample(self) -> Tensor:
        """Sample using the reparameterization trick.

        Returns:
            Tensor: The sampled data.

        Raises:
            RuntimeError: If the distribution is not initialized.
        """
        if self._dist is None:
            raise RuntimeError("Distribution is not initialized.")

        gen_sample = self._dist.rsample()
        if gen_sample.ndim == 1:
            return gen_sample.unsqueeze(1)
        return gen_sample

    @property
    def mode(self) -> Tensor:
        """Compute the mode of the distribution.

        Returns:
            Tensor: The mode of the distribution.

        Raises:
            RuntimeError: If the distribution is not initialized.
        """
        if self._dist is None:
            raise RuntimeError("Distribution is not initialized.")
        res = self._dist.mode
        if res.ndim == 1:
            return res.unsqueeze(1)
        return res

    @property
    def num_parameters(self) -> int:
        """Number of parameters.

        Returns:
            int: The number of parameters.
        """
        return self._n_pars
mode: Tensor property

Compute the mode of the distribution.

Returns:
    Tensor: The mode of the distribution.

Raises:
    RuntimeError: If the distribution is not initialized.

num_parameters: int property

Number of parameters.

Returns:
    int: The number of parameters.

__init__(variable_type, dim_s, dim_z, dim_y)

Base class for the different data types.

Parameters:
    variable_type (VariableType): Array containing the data type information (type, class, ndim). Required.
    dim_s (int): Dimension of s space. Required.
    dim_z (int): Dimension of z space. Required.
    dim_y (int): Dimension of y space. Required.
Source code in vambn/modelling/models/hivae/heads.py
def __init__(
    self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
) -> None:
    """Base class for the different datatypes

    Args:
        types (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space
    """
    super().__init__()
    # General parameters
    self.types = variable_type
    self.dim_s = dim_s
    self.dim_z = dim_z
    self.dim_y = dim_y
    self.internal_pass = ModifiedLinear(self.dim_z, self.dim_y, bias=True)

    # Specific parameters
    self._dist = None
    self._n_pars = None
dist(params) abstractmethod

Compute the distribution given the parameters.

Parameters:
    params (ParameterType): The parameters of the distribution. Required.

Returns:
    DistributionType: The computed distribution.

Raises:
    NotImplementedError: If the method is not implemented.

Source code in vambn/modelling/models/hivae/heads.py
@abstractmethod
def dist(self, params: ParameterType) -> DistributionType:
    """Compute the distribution given the parameters.

    Args:
        params (ParameterType): The parameters of the distribution.

    Returns:
        DistributionType: The computed distribution.

    Raises:
        NotImplementedError: If the method is not implemented.
    """
    raise NotImplementedError
forward(samples_z, samples_s) abstractmethod

Forward pass of the module.

Parameters:
    samples_z (Tensor): Samples from the z space. Required.
    samples_s (Tensor): Samples from the s space. Required.

Returns:
    ParameterType: The output of the forward pass.

Raises:
    NotImplementedError: If the method is not implemented.

Source code in vambn/modelling/models/hivae/heads.py
@abstractmethod
def forward(self, samples_z: Tensor, samples_s: Tensor) -> ParameterType:
    """Forward pass of the module.

    Args:
        samples_z (Tensor): Samples from the z space.
        samples_s (Tensor): Samples from the s space.

    Returns:
        ParameterType: The output of the forward pass.

    Raises:
        NotImplementedError: If the method is not implemented.
    """
    raise NotImplementedError
log_prob(data, params=None)

Compute the log probability of the data given the parameters.

Parameters:
    data (Tensor): The input data. Required.
    params (Optional[ParameterType]): The parameters of the distribution. Defaults to None.

Returns:
    Tensor: The log probability of the data.

Raises:
    RuntimeError: If the distribution is not initialized.

Source code in vambn/modelling/models/hivae/heads.py
def log_prob(
    self, data: Tensor, params: Optional[ParameterType] = None
) -> Tensor:
    """Compute the log probability of the data given the parameters.

    Args:
        data (Tensor): The input data.
        params (Optional[ParameterType], optional): The parameters of the distribution. Defaults to None.

    Returns:
        Tensor: The log probability of the data.

    Raises:
        RuntimeError: If the distribution is not initialized.
    """
    if self._dist is None and params is None:
        raise RuntimeError("Distribution is not initialized.")
    elif params is not None:
        self.dist(params)

    return self._dist.log_prob(data)
rsample()

Sample using the reparameterization trick.

Returns:
    Tensor: The sampled data.

Raises:
    RuntimeError: If the distribution is not initialized.

Source code in vambn/modelling/models/hivae/heads.py
def rsample(self) -> Tensor:
    """Sample using the reparameterization trick.

    Returns:
        Tensor: The sampled data.

    Raises:
        RuntimeError: If the distribution is not initialized.
    """
    if self._dist is None:
        raise RuntimeError("Distribution is not initialized.")

    gen_sample = self._dist.rsample()
    if gen_sample.ndim == 1:
        return gen_sample.unsqueeze(1)
    return gen_sample
sample()

Sample from the distribution.

Returns:
    Tensor: The sampled data.

Raises:
    RuntimeError: If the distribution is not initialized.

Source code in vambn/modelling/models/hivae/heads.py
def sample(self) -> Tensor:
    """Sample from the distribution.

    Returns:
        Tensor: The sampled data.

    Raises:
        RuntimeError: If the distribution is not initialized.
    """
    if self._dist is None:
        raise RuntimeError("Distribution is not initialized.")

    gen_sample = self._dist.sample()
    if gen_sample.ndim == 1:
        return gen_sample.unsqueeze(1)
    elif gen_sample.ndim == 0:
        # ensure 2 dim output
        return gen_sample.unsqueeze(0).unsqueeze(1)
    return gen_sample
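The contract above is deliberately small: forward maps (samples_z, samples_s) to distribution parameters, dist turns those parameters into a torch distribution and caches it in _dist, and the shared log_prob/sample/rsample/mode helpers then operate on that cache. The following sketch shows the pattern with a hypothetical Bernoulli head; the ToyBernoulliHead class, the plain nn.Linear layers standing in for ModifiedLinear, and all shapes are illustrative assumptions, not library code.

import torch
from torch import Tensor, nn
from torch.distributions import Bernoulli


class ToyBernoulliHead(nn.Module):
    """Illustrative stand-in for a BaseModuleHead subclass (not library code)."""

    def __init__(self, dim_s: int, dim_z: int, dim_y: int) -> None:
        super().__init__()
        # plain nn.Linear layers stand in for the library's ModifiedLinear
        self.internal_pass = nn.Linear(dim_z, dim_y, bias=True)
        self.logit_layer = nn.Linear(dim_y + dim_s, 1, bias=False)
        self._dist = None

    def forward(self, samples_z: Tensor, samples_s: Tensor) -> Tensor:
        y = self.internal_pass(samples_z)
        return self.logit_layer(torch.cat([y, samples_s], dim=-1))

    def dist(self, logits: Tensor) -> Bernoulli:
        self._dist = Bernoulli(logits=logits.squeeze(-1))
        return self._dist

    def log_prob(self, data: Tensor, logits: Tensor | None = None) -> Tensor:
        if self._dist is None and logits is None:
            raise RuntimeError("Distribution is not initialized.")
        if logits is not None:
            self.dist(logits)  # side effect: refreshes self._dist, as in the base class
        return self._dist.log_prob(data)


head = ToyBernoulliHead(dim_s=3, dim_z=4, dim_y=5)
z, s = torch.randn(8, 4), torch.randn(8, 3)
params = head(z, s)                          # forward: (z, s) -> distribution parameters
print(head.log_prob(torch.ones(8), params))  # dist() is invoked internally, then log_prob
print(head.dist(params).sample().shape)      # torch.Size([8])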
CatHead

Bases: BaseModuleHead[CategoricalParameters, ReparameterizedCategorical]

Class representing the categorical head of a model.

Attributes:
    variable_type (VariableType): Array containing the data type information (type, class, ndim)
    dim_s (int): Dimension of s space
    dim_z (int): Dimension of z space
    dim_y (int): Dimension of y space
    logit_layer (ModifiedLinear): Linear layer for computing logits
    _dist_class (ReparameterizedCategorical): Class for representing the reparameterized categorical distribution

Source code in vambn/modelling/models/hivae/heads.py
class CatHead(
    BaseModuleHead[
        CategoricalParameters, ReparameterizedCategorical
    ]
):
    """Class representing the categorical head of a model.

    Attributes:
        variable_type (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space
        logit_layer (ModifiedLinear): Linear layer for computing logits
        _dist_class (ReparameterizedCategorical): Class for representing reparameterized categorical distribution
    """

    # @typechecked
    def __init__(
        self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
    ) -> None:
        """Initialize the CatHead class.

        Args:
            variable_type (VariableType): Array containing the data type information (type, class, ndim)
            dim_s (int): Dimension of s space
            dim_z (int): Dimension of z space
            dim_y (int): Dimension of y space
        """
        super().__init__(variable_type, dim_s, dim_z, dim_y)
        self.logit_layer = ModifiedLinear(
            self.dim_y + self.dim_s, variable_type.n_parameters, bias=False
        )
        self._dist_class = ReparameterizedCategorical

    def forward(self, samples_z: Tensor, samples_s: Tensor) -> CategoricalParameters:
        """Forward pass of the CatHead.

        Args:
            samples_z (Tensor): Samples from the z space
            samples_s (Tensor): Samples from the s space

        Returns:
            CategoricalParameters: Categorical parameters
        """
        y = self.internal_pass(samples_z)
        s_and_y = torch.cat([y, samples_s], dim=-1)
        logits = self.logit_layer(s_and_y)
        return CategoricalParameters(logits)

    def dist(self, params: CategoricalParameters) -> ReparameterizedCategorical:
        """Compute the reparameterized categorical distribution.

        Args:
            params (CategoricalParameters): Categorical parameters

        Returns:
            ReparameterizedCategorical: Reparameterized categorical distribution
        """
        self._dist = self._dist_class(logits=params.logits)
        return self._dist
__init__(variable_type, dim_s, dim_z, dim_y)

Initialize the CatHead class.

Parameters:
    variable_type (VariableType): Array containing the data type information (type, class, ndim). Required.
    dim_s (int): Dimension of s space. Required.
    dim_z (int): Dimension of z space. Required.
    dim_y (int): Dimension of y space. Required.
Source code in vambn/modelling/models/hivae/heads.py
def __init__(
    self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
) -> None:
    """Initialize the CatHead class.

    Args:
        variable_type (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space
    """
    super().__init__(variable_type, dim_s, dim_z, dim_y)
    self.logit_layer = ModifiedLinear(
        self.dim_y + self.dim_s, variable_type.n_parameters, bias=False
    )
    self._dist_class = ReparameterizedCategorical
dist(params)

Compute the reparameterized categorical distribution.

Parameters:
    params (CategoricalParameters): Categorical parameters. Required.

Returns:
    ReparameterizedCategorical: Reparameterized categorical distribution.

Source code in vambn/modelling/models/hivae/heads.py
def dist(self, params: CategoricalParameters) -> ReparameterizedCategorical:
    """Compute the reparameterized categorical distribution.

    Args:
        params (CategoricalParameters): Categorical parameters

    Returns:
        ReparameterizedCategorical: Reparameterized categorical distribution
    """
    self._dist = self._dist_class(logits=params.logits)
    return self._dist
forward(samples_z, samples_s)

Forward pass of the CatHead.

Parameters:
    samples_z (Tensor): Samples from the z space. Required.
    samples_s (Tensor): Samples from the s space. Required.

Returns:
    CategoricalParameters: Categorical parameters.

Source code in vambn/modelling/models/hivae/heads.py
def forward(self, samples_z: Tensor, samples_s: Tensor) -> CategoricalParameters:
    """Forward pass of the CatHead.

    Args:
        samples_z (Tensor): Samples from the z space
        samples_s (Tensor): Samples from the s space

    Returns:
        CategoricalParameters: Categorical parameters
    """
    y = self.internal_pass(samples_z)
    s_and_y = torch.cat([y, samples_s], dim=-1)
    logits = self.logit_layer(s_and_y)
    return CategoricalParameters(logits)
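CatHead thus derives its logits from the concatenation of the shared y representation (internal_pass of z) and s, and wraps them in the Gumbel-Softmax-reparameterized categorical distribution described earlier. A minimal sketch of the same computation with plain tensors, using torch.nn.functional.gumbel_softmax as a stand-in for the reparameterized sample; the dimensions and the logit_layer definition are illustrative assumptions.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, dim_s, dim_y, n_classes = 8, 3, 5, 4

y = torch.randn(batch, dim_y)                    # internal_pass(samples_z)
s = torch.randn(batch, dim_s)                    # samples_s
logit_layer = torch.nn.Linear(dim_y + dim_s, n_classes, bias=False)

logits = logit_layer(torch.cat([y, s], dim=-1))  # the logits returned as CategoricalParameters
hard_sample = torch.distributions.Categorical(logits=logits).sample()
soft_sample = F.gumbel_softmax(logits, tau=1.0, hard=False)  # differentiable relaxation

print(hard_sample.shape, soft_sample.shape)      # torch.Size([8]) torch.Size([8, 4])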
CountHead

Bases: BaseModuleHead[PoissonParameters, Poisson]

Head module for the Poisson distribution (count data).

Parameters:
    variable_type (VariableType): Array containing the data type information (type, class, ndim). Required.
    dim_s (int): Dimension of s space. Required.
    dim_z (int): Dimension of z space. Required.
    dim_y (int): Dimension of y space. Required.

Attributes:
    lambda_layer (ModifiedLinear): Linear layer for computing the rate parameter
    _dist_class (Poisson): Class for representing the Poisson distribution

Source code in vambn/modelling/models/hivae/heads.py
class CountHead(BaseModuleHead[PoissonParameters, dists.Poisson]):
    """Head module for the Poisson distribution (Count data).

    Args:
        variable_type (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space

    Attributes:
        lambda_layer (ModifiedLinear): Linear layer for computing the rate parameter
        _dist_class (dists.Poisson): Class for representing the Poisson distribution

    """

    def __init__(
        self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
    ) -> None:
        """Initializes the CountHead class.

        Args:
            variable_type (VariableType): Array containing the data type information (type, class, ndim)
            dim_s (int): Dimension of s space
            dim_z (int): Dimension of z space
            dim_y (int): Dimension of y space
        """
        super().__init__(variable_type, dim_s, dim_z, dim_y)
        self.lambda_layer = ModifiedLinear(
            self.dim_y + self.dim_s, 1, bias=False
        )
        self._dist_class = dists.Poisson

    def forward(self, samples_z: Tensor, samples_s: Tensor) -> PoissonParameters:
        """Performs the forward pass of the CountHead.

        Args:
            samples_z (Tensor): Samples from the z space
            samples_s (Tensor): Samples from the s space

        Returns:
            PoissonParameters: The Poisson parameter
        """
        y = self.internal_pass(samples_z)
        s_and_y = torch.cat([y, samples_s], dim=-1)
        rate = self.lambda_layer(s_and_y)
        return PoissonParameters(rate)

    def dist(self, params: PoissonParameters) -> dists.Poisson:
        """Creates a Poisson distribution from the given parameters.

        Args:
            params (PoissonParameters): The Poisson parameters

        Returns:
            dists.Poisson: The Poisson distribution
        """
        self._dist = self._dist_class(params.rate.squeeze())
        return self._dist
__init__(variable_type, dim_s, dim_z, dim_y)

Initializes the CountHead class.

Parameters:
    variable_type (VariableType): Array containing the data type information (type, class, ndim). Required.
    dim_s (int): Dimension of s space. Required.
    dim_z (int): Dimension of z space. Required.
    dim_y (int): Dimension of y space. Required.
Source code in vambn/modelling/models/hivae/heads.py
def __init__(
    self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
) -> None:
    """Initializes the CountHead class.

    Args:
        variable_type (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space
    """
    super().__init__(variable_type, dim_s, dim_z, dim_y)
    self.lambda_layer = ModifiedLinear(
        self.dim_y + self.dim_s, 1, bias=False
    )
    self._dist_class = dists.Poisson
dist(params)

Creates a Poisson distribution from the given parameters.

Parameters:
    params (PoissonParameters): The Poisson parameters. Required.

Returns:
    dists.Poisson: The Poisson distribution.

Source code in vambn/modelling/models/hivae/heads.py
def dist(self, params: PoissonParameters) -> dists.Poisson:
    """Creates a Poisson distribution from the given parameters.

    Args:
        params (PoissonParameters): The Poisson parameters

    Returns:
        dists.Poisson: The Poisson distribution
    """
    self._dist = self._dist_class(params.rate.squeeze())
    return self._dist
forward(samples_z, samples_s)

Performs the forward pass of the CountHead.

Parameters:
    samples_z (Tensor): Samples from the z space. Required.
    samples_s (Tensor): Samples from the s space. Required.

Returns:
    PoissonParameters: The Poisson parameter.

Source code in vambn/modelling/models/hivae/heads.py
def forward(self, samples_z: Tensor, samples_s: Tensor) -> PoissonParameters:
    """Performs the forward pass of the CountHead.

    Args:
        samples_z (Tensor): Samples from the z space
        samples_s (Tensor): Samples from the s space

    Returns:
        PoissonParameters: The Poisson parameter
    """
    y = self.internal_pass(samples_z)
    s_and_y = torch.cat([y, samples_s], dim=-1)
    rate = self.lambda_layer(s_and_y)
    return PoissonParameters(rate)
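The rate produced by lambda_layer must be strictly positive before it can parameterize a Poisson likelihood; presumably PoissonParameters takes care of that constraint internally (an assumption here). A minimal sketch of the count likelihood with a softplus-constrained rate:

import torch
from torch.distributions import Poisson

raw_rate = torch.randn(8, 1)                   # output of lambda_layer([y; s])
rate = torch.nn.functional.softplus(raw_rate)  # keep the rate strictly positive (assumed constraint)

dist = Poisson(rate.squeeze(-1))
counts = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
print(dist.log_prob(counts))                   # per-sample log-likelihood of observed counts
print(dist.sample().shape)                     # torch.Size([8])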
PosHead

Bases: BaseModuleHead[LogNormalParameters, LogNormal]

Head module for the LogNormal (pos) distribution.

Attributes:
    variable_type (VariableType): The type of variable.
    dim_s (int): The dimension of s.
    dim_z (int): The dimension of z.
    dim_y (int): The dimension of y.
    loc_layer (ModifiedLinear): The linear layer for computing the location parameter.
    scale_layer (ModifiedLinear): The linear layer for computing the scale parameter.
    _dist_class (LogNormal): The class representing the LogNormal distribution.
    _n_pars (int): The number of parameters in the distribution.

Source code in vambn/modelling/models/hivae/heads.py
class PosHead(
    BaseModuleHead[LogNormalParameters, dists.LogNormal]
):
    """
    Head module for the LogNormal (pos) distribution

    Attributes:
        variable_type (VariableType): The type of variable.
        dim_s (int): The dimension of s.
        dim_z (int): The dimension of z.
        dim_y (int): The dimension of y.
        loc_layer (ModifiedLinear): The linear layer for computing the location parameter.
        scale_layer (ModifiedLinear): The linear layer for computing the scale parameter.
        _dist_class (dists.LogNormal): The class representing the LogNormal distribution.
        _n_pars (int): The number of parameters in the distribution.

    """

    def __init__(
        self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
    ) -> None:
        """
        Initializes the PosHead class.

        Args:
            variable_type (VariableType): The type of variable.
            dim_s (int): The dimension of s.
            dim_z (int): The dimension of z.
            dim_y (int): The dimension of y.
        """
        super().__init__(variable_type, dim_s, dim_z, dim_y)
        self.loc_layer = ModifiedLinear(self.dim_y + self.dim_s, 1, bias=False)
        self.scale_layer = ModifiedLinear(self.dim_s, 1, bias=False)

        self._dist_class = dists.LogNormal
        self._n_pars = 2

    def forward(self, samples_z: Tensor, samples_s: Tensor) -> LogNormalParameters:
        """
        Performs the forward pass of the PosHead.

        Args:
            samples_z (Tensor): The z samples.
            samples_s (Tensor): The s samples.

        Returns:
            LogNormalParameters: The output of the forward pass.
        """
        y = self.internal_pass(samples_z)
        s_and_y = torch.cat([y, samples_s], dim=-1)

        loc = self.loc_layer(s_and_y)
        scale = nn.functional.softplus(self.scale_layer(samples_s))
        return LogNormalParameters(loc, scale)

    def dist(self, params: LogNormalParameters) -> dists.LogNormal:
        """
        Creates a LogNormal distribution based on the given parameters.

        Args:
            params (LogNormalParameters): The parameters of the distribution.

        Returns:
            dists.LogNormal: The LogNormal distribution.
        """
        self._dist = self._dist_class(
            params.loc.squeeze(), params.scale.squeeze()
        )
        return self._dist

    def log_prob(
        self, data: Tensor, params: LogNormalParameters | None = None
    ) -> Tensor:
        """
        Computes the log probability of the data given the parameters.

        Args:
            data (Tensor): The input data.
            params (LogNormalParameters | None): The parameters of the distribution.

        Returns:
            Tensor: The log probability of the data.
        """
        return super().log_prob(data.clamp(min=1e-3), params)
__init__(variable_type, dim_s, dim_z, dim_y)

Initializes the PosHead class.

Parameters:
    variable_type (VariableType): The type of variable. Required.
    dim_s (int): The dimension of s. Required.
    dim_z (int): The dimension of z. Required.
    dim_y (int): The dimension of y. Required.
Source code in vambn/modelling/models/hivae/heads.py
def __init__(
    self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
) -> None:
    """
    Initializes the PosHead class.

    Args:
        variable_type (VariableType): The type of variable.
        dim_s (int): The dimension of s.
        dim_z (int): The dimension of z.
        dim_y (int): The dimension of y.
    """
    super().__init__(variable_type, dim_s, dim_z, dim_y)
    self.loc_layer = ModifiedLinear(self.dim_y + self.dim_s, 1, bias=False)
    self.scale_layer = ModifiedLinear(self.dim_s, 1, bias=False)

    self._dist_class = dists.LogNormal
    self._n_pars = 2
dist(params)

Creates a LogNormal distribution based on the given parameters.

Parameters:
    params (LogNormalParameters): The parameters of the distribution. Required.

Returns:
    dists.LogNormal: The LogNormal distribution.

Source code in vambn/modelling/models/hivae/heads.py
def dist(self, params: LogNormalParameters) -> dists.LogNormal:
    """
    Creates a LogNormal distribution based on the given parameters.

    Args:
        params (LogNormalParameters): The parameters of the distribution.

    Returns:
        dists.LogNormal: The LogNormal distribution.
    """
    self._dist = self._dist_class(
        params.loc.squeeze(), params.scale.squeeze()
    )
    return self._dist
forward(samples_z, samples_s)

Performs the forward pass of the PosHead.

Parameters:
    samples_z (Tensor): The z samples. Required.
    samples_s (Tensor): The s samples. Required.

Returns:
    LogNormalParameters: The output of the forward pass.

Source code in vambn/modelling/models/hivae/heads.py
def forward(self, samples_z: Tensor, samples_s: Tensor) -> LogNormalParameters:
    """
    Performs the forward pass of the PosHead.

    Args:
        samples_z (Tensor): The z samples.
        samples_s (Tensor): The s samples.

    Returns:
        LogNormalParameters: The output of the forward pass.
    """
    y = self.internal_pass(samples_z)
    s_and_y = torch.cat([y, samples_s], dim=-1)

    loc = self.loc_layer(s_and_y)
    scale = nn.functional.softplus(self.scale_layer(samples_s))
    return LogNormalParameters(loc, scale)
log_prob(data, params=None)

Computes the log probability of the data given the parameters.

Parameters:
    data (Tensor): The input data. Required.
    params (LogNormalParameters | None): The parameters of the distribution. Defaults to None.

Returns:
    Tensor: The log probability of the data.

Source code in vambn/modelling/models/hivae/heads.py
def log_prob(
    self, data: Tensor, params: LogNormalParameters | None = None
) -> Tensor:
    """
    Computes the log probability of the data given the parameters.

    Args:
        data (Tensor): The input data.
        params (LogNormalParameters | None): The parameters of the distribution.

    Returns:
        Tensor: The log probability of the data.
    """
    return super().log_prob(data.clamp(min=1e-3), params)
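PosHead.log_prob clamps the data to a minimum of 1e-3 because the LogNormal density is only defined for strictly positive values; an exact zero would otherwise produce a non-finite log-probability. A small illustration with plain torch (the loc and scale values are arbitrary):

import torch
from torch.distributions import LogNormal

# validate_args=False so the density can be evaluated outside its support for illustration
dist = LogNormal(loc=torch.tensor(0.0), scale=torch.tensor(1.0), validate_args=False)

x = torch.tensor([0.0, 0.5, 2.0])
print(dist.log_prob(x))                  # first entry is not finite: the density is undefined at 0
print(dist.log_prob(x.clamp(min=1e-3)))  # clamping keeps every term finite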
RealHead

Bases: BaseModuleHead[NormalParameters, Normal]

Class representing the RealHead module.

This module is used for data of type real or pos.

Parameters:
    variable_type (VariableType): Array containing the data type information (type, class, ndim). Required.
    dim_s (int): Dimension of s space. Required.
    dim_z (int): Dimension of z space. Required.
    dim_y (int): Dimension of y space. Required.

Attributes:
    loc_layer (ModifiedLinear): Linear layer for computing the location parameter
    scale_layer (ModifiedLinear): Linear layer for computing the scale parameter
    _dist_class (type): Class representing the distribution
    _n_pars (int): Number of parameters in the distribution

Source code in vambn/modelling/models/hivae/heads.py
class RealHead(BaseModuleHead[NormalParameters, dists.Normal]):
    """Class representing the RealHead module.

    This module is used for data of type real or pos.

    Args:
        variable_type (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space

    Attributes:
        loc_layer (ModifiedLinear): Linear layer for computing the location parameter
        scale_layer (ModifiedLinear): Linear layer for computing the scale parameter
        _dist_class (type): Class representing the distribution
        _n_pars (int): Number of parameters in the distribution

    """

    def __init__(
        self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
    ) -> None:
        """Initialize the RealHead module.

        Args:
            variable_type (VariableType): Array containing the data type information (type, class, ndim)
            dim_s (int): Dimension of s space
            dim_z (int): Dimension of z space
            dim_y (int): Dimension of y space
        """
        super().__init__(variable_type, dim_s, dim_z, dim_y)
        self.loc_layer = ModifiedLinear(self.dim_y + self.dim_s, 1, bias=False)
        self.scale_layer = ModifiedLinear(self.dim_s, 1, bias=False)

        self._dist_class = dists.Normal
        self._n_pars = 2

    def forward(self, samples_z: Tensor, samples_s: Tensor) -> NormalParameters:
        """Forward pass of the RealHead module.

        Args:
            samples_z (Tensor): Samples from the z space
            samples_s (Tensor): Samples from the s space

        Returns:
            NormalParameters: Parameters of the Normal distribution
        """
        y = self.internal_pass(samples_z)
        s_and_y = torch.cat([y, samples_s], dim=-1)
        loc = self.loc_layer(s_and_y)
        scale = nn.functional.softplus(self.scale_layer(samples_s))
        return NormalParameters(loc, scale)

    def dist(self, params: NormalParameters) -> dists.Normal:
        """Create a Normal distribution based on the given parameters.

        Args:
            params (NormalParameters): Parameters of the Normal distribution

        Returns:
            dists.Normal: Normal distribution
        """
        self._dist = self._dist_class(
            params.loc.squeeze(), params.scale.squeeze()
        )
        return self._dist
__init__(variable_type, dim_s, dim_z, dim_y)

Initialize the RealHead module.

Parameters:
    variable_type (VariableType): Array containing the data type information (type, class, ndim). Required.
    dim_s (int): Dimension of s space. Required.
    dim_z (int): Dimension of z space. Required.
    dim_y (int): Dimension of y space. Required.
Source code in vambn/modelling/models/hivae/heads.py
def __init__(
    self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
) -> None:
    """Initialize the RealHead module.

    Args:
        variable_type (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space
    """
    super().__init__(variable_type, dim_s, dim_z, dim_y)
    self.loc_layer = ModifiedLinear(self.dim_y + self.dim_s, 1, bias=False)
    self.scale_layer = ModifiedLinear(self.dim_s, 1, bias=False)

    self._dist_class = dists.Normal
    self._n_pars = 2
dist(params)

Create a Normal distribution based on the given parameters.

Parameters:
    params (NormalParameters): Parameters of the Normal distribution. Required.

Returns:
    dists.Normal: Normal distribution.

Source code in vambn/modelling/models/hivae/heads.py
def dist(self, params: NormalParameters) -> dists.Normal:
    """Create a Normal distribution based on the given parameters.

    Args:
        params (NormalParameters): Parameters of the Normal distribution

    Returns:
        dists.Normal: Normal distribution
    """
    self._dist = self._dist_class(
        params.loc.squeeze(), params.scale.squeeze()
    )
    return self._dist
forward(samples_z, samples_s)

Forward pass of the RealHead module.

Parameters:
    samples_z (Tensor): Samples from the z space. Required.
    samples_s (Tensor): Samples from the s space. Required.

Returns:
    NormalParameters: Parameters of the Normal distribution.

Source code in vambn/modelling/models/hivae/heads.py
def forward(self, samples_z: Tensor, samples_s: Tensor) -> NormalParameters:
    """Forward pass of the RealHead module.

    Args:
        samples_z (Tensor): Samples from the z space
        samples_s (Tensor): Samples from the s space

    Returns:
        NormalParameters: Parameters of the Normal distribution
    """
    y = self.internal_pass(samples_z)
    s_and_y = torch.cat([y, samples_s], dim=-1)
    loc = self.loc_layer(s_and_y)
    scale = nn.functional.softplus(self.scale_layer(samples_s))
    return NormalParameters(loc, scale)
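As in the other continuous heads, the location is computed from the concatenated [y, s] representation while the scale depends on s alone and is passed through a softplus so it stays positive. A minimal sketch of that parameterization (shapes and layer names are illustrative):

import torch
from torch import nn
from torch.distributions import Normal

batch, dim_s, dim_y = 8, 3, 5
y = torch.randn(batch, dim_y)                               # internal_pass(samples_z)
s = torch.randn(batch, dim_s)                               # samples_s

loc_layer = nn.Linear(dim_y + dim_s, 1, bias=False)
scale_layer = nn.Linear(dim_s, 1, bias=False)

loc = loc_layer(torch.cat([y, s], dim=-1)).squeeze(-1)
scale = nn.functional.softplus(scale_layer(s)).squeeze(-1)  # softplus keeps the scale positive

dist = Normal(loc, scale)
print(dist.log_prob(torch.randn(batch)).shape)              # torch.Size([8])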
TruncatedNormalHead

Bases: BaseModuleHead[NormalParameters, TruncatedNormal]

Class representing the TruncatedNormalHead module.

This module is used for data of type real or pos.

Parameters:
    variable_type (VariableType): Array containing the data type information (type, class, ndim). Required.
    dim_s (int): Dimension of s space. Required.
    dim_z (int): Dimension of z space. Required.
    dim_y (int): Dimension of y space. Required.

Attributes:
    loc_layer (ModifiedLinear): Linear layer for computing the location parameter
    scale_layer (ModifiedLinear): Linear layer for computing the scale parameter
    _dist_class (type): Class representing the distribution
    _n_pars (int): Number of parameters in the distribution

Source code in vambn/modelling/models/hivae/heads.py
class TruncatedNormalHead(
    BaseModuleHead[NormalParameters, TruncatedNormal]
):
    """
    Class representing the TruncatedNormalHead module.

    This module is used for data of type real or pos.

    Args:
        variable_type (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space

    Attributes:
        loc_layer (ModifiedLinear): Linear layer for computing the location parameter
        scale_layer (ModifiedLinear): Linear layer for computing the scale parameter
        _dist_class (type): Class representing the distribution
        _n_pars (int): Number of parameters in the distribution

    """

    def __init__(
        self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
    ) -> None:
        """
        Initializes a TruncatedNormalHead object.

        Args:
            variable_type (VariableType): The type of variable.
            dim_s (int): The dimensionality of s.
            dim_z (int): The dimensionality of z.
            dim_y (int): The dimensionality of y.
        """
        super().__init__(variable_type, dim_s, dim_z, dim_y)
        self.loc_layer = ModifiedLinear(self.dim_y + self.dim_s, 1, bias=False)
        self.scale_layer = ModifiedLinear(self.dim_s, 1, bias=False)

        self._dist_class = TruncatedNormal
        self._n_pars = 2

    def forward(self, samples_z: Tensor, samples_s: Tensor) -> NormalParameters:
        """
        Performs a forward pass through the TruncatedNormalHead.

        Args:
            samples_z (Tensor): The z samples.
            samples_s (Tensor): The s samples.

        Returns:
            NormalParameters: The output of the forward pass.
        """
        y = self.internal_pass(samples_z)
        s_and_y = torch.cat([y, samples_s], dim=-1)
        loc = self.loc_layer(s_and_y)
        scale = nn.functional.softplus(self.scale_layer(samples_s))
        return NormalParameters(loc, scale)

    def dist(self, params: NormalParameters) -> TruncatedNormal:
        """
        Creates a TruncatedNormal distribution based on the given parameters.

        Args:
            params (NormalParameters): The parameters of the distribution.

        Returns:
            TruncatedNormal: The created TruncatedNormal distribution.
        """
        self._dist = self._dist_class(
            params.loc.squeeze(), params.scale.squeeze(), low=torch.tensor(0.0)
        )
        return self._dist

    def log_prob(
        self, data: Tensor, params: NormalParameters | None = None
    ) -> Tensor:
        """
        Computes the log probability of the data given the parameters.

        Args:
            data (Tensor): The input data.
            params (NormalParameters | None): The parameters of the distribution.

        Returns:
            Tensor: The log probability of the data.
        """
        return super().log_prob(data.clamp(min=1e-3), params)
__init__(variable_type, dim_s, dim_z, dim_y)

Initializes a TruncatedNormalHead object.

Parameters:
    variable_type (VariableType): The type of variable. Required.
    dim_s (int): The dimensionality of s. Required.
    dim_z (int): The dimensionality of z. Required.
    dim_y (int): The dimensionality of y. Required.
Source code in vambn/modelling/models/hivae/heads.py
def __init__(
    self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
) -> None:
    """
    Initializes a TruncatedNormalHead object.

    Args:
        variable_type (VariableType): The type of variable.
        dim_s (int): The dimensionality of s.
        dim_z (int): The dimensionality of z.
        dim_y (int): The dimensionality of y.
    """
    super().__init__(variable_type, dim_s, dim_z, dim_y)
    self.loc_layer = ModifiedLinear(self.dim_y + self.dim_s, 1, bias=False)
    self.scale_layer = ModifiedLinear(self.dim_s, 1, bias=False)

    self._dist_class = TruncatedNormal
    self._n_pars = 2
dist(params)

Creates a TruncatedNormal distribution based on the given parameters.

Parameters:
    params (NormalParameters): The parameters of the distribution. Required.

Returns:
    TruncatedNormal: The created TruncatedNormal distribution.

Source code in vambn/modelling/models/hivae/heads.py
def dist(self, params: NormalParameters) -> TruncatedNormal:
    """
    Creates a TruncatedNormal distribution based on the given parameters.

    Args:
        params (NormalParameters): The parameters of the distribution.

    Returns:
        TruncatedNormal: The created TruncatedNormal distribution.
    """
    self._dist = self._dist_class(
        params.loc.squeeze(), params.scale.squeeze(), low=torch.tensor(0.0)
    )
    return self._dist
forward(samples_z, samples_s)

Performs a forward pass through the TruncatedNormalHead.

Parameters:
    samples_z (Tensor): The z samples. Required.
    samples_s (Tensor): The s samples. Required.

Returns:
    NormalParameters: The output of the forward pass.

Source code in vambn/modelling/models/hivae/heads.py
def forward(self, samples_z: Tensor, samples_s: Tensor) -> NormalParameters:
    """
    Performs a forward pass through the TruncatedNormalHead.

    Args:
        samples_z (Tensor): The z samples.
        samples_s (Tensor): The s samples.

    Returns:
        NormalParameters: The output of the forward pass.
    """
    y = self.internal_pass(samples_z)
    s_and_y = torch.cat([y, samples_s], dim=-1)
    loc = self.loc_layer(s_and_y)
    scale = nn.functional.softplus(self.scale_layer(samples_s))
    return NormalParameters(loc, scale)
log_prob(data, params=None)

Computes the log probability of the data given the parameters.

Parameters:
    data (Tensor): The input data. Required.
    params (NormalParameters | None): The parameters of the distribution. Defaults to None.

Returns:
    Tensor: The log probability of the data.

Source code in vambn/modelling/models/hivae/heads.py
def log_prob(
    self, data: Tensor, params: NormalParameters | None = None
) -> Tensor:
    """
    Computes the log probability of the data given the parameters.

    Args:
        data (Tensor): The input data.
        params (NormalParameters | None): The parameters of the distribution.

    Returns:
        Tensor: The log probability of the data.
    """
    return super().log_prob(data.clamp(min=1e-3), params)
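The TruncatedNormal used by this head restricts a Normal to values above low=0, which amounts to renormalizing the density by the probability mass remaining above the truncation point. The library ships its own TruncatedNormal implementation; purely as a sketch of the underlying math, the log-density for x >= low can be written with a plain Normal as log N(x; loc, scale) - log(1 - Phi((low - loc) / scale)):

import torch
from torch.distributions import Normal

loc, scale = torch.tensor(0.5), torch.tensor(1.0)
base = Normal(loc, scale)

def truncated_log_prob(x: torch.Tensor, low: float = 0.0) -> torch.Tensor:
    # subtract the log of the base mass that survives the truncation at `low`
    log_mass_above_low = torch.log1p(-base.cdf(torch.tensor(low)))
    return base.log_prob(x) - log_mass_above_low

x = torch.tensor([0.1, 1.0, 2.5]).clamp(min=1e-3)  # same clamping as in log_prob above
print(truncated_log_prob(x))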

hivae

Hivae

Bases: AbstractNormalModel[Tensor, Tensor, HivaeOutput, HivaeEncoding]

Entire HIVAE model containing Encoder and Decoder structure.

Parameters:
    variable_types (VarTypes): List of VariableType objects defining the types of the variables in the data. Required.
    input_dim (int): Dimension of input data (number of columns in the dataframe). If the data contains categorical variables, the input dimension is larger than the number of features. Required.
    dim_s (int): Dimension of s space. Required.
    dim_z (int): Dimension of z space. Required.
    dim_y (int): Dimension of y space. Required.
    module_name (str, optional): Name of the module this HIVAE is associated with. Defaults to 'HIVAE'.
    mtl_method (Tuple[str], optional): List of methods to use for multi-task learning. Assessed possibilities are combinations of "identity", "gradnorm", "graddrop". Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).
    use_imputation_layer (bool, optional): Flag to indicate if imputation layer should be used. Defaults to False.
    individual_model (bool, optional): Flag to indicate if the current model is applied individually or as part of e.g. a modular HIVAE. Defaults to True.
Source code in vambn/modelling/models/hivae/hivae.py
class Hivae(
    AbstractNormalModel[torch.Tensor, torch.Tensor, HivaeOutput, HivaeEncoding]
):
    """
    Entire HIVAE model containing Encoder and Decoder structure.

    Args:
        variable_types (VarTypes): List of VariableType objects defining the types
            of the variables in the data.
        input_dim (int): Dimension of input data (number of columns in the dataframe).
            If the data contains categorical variables, the input dimension is
            larger than the number of features.
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        dim_y (int): Dimension of y space.
        module_name (str, optional): Name of the module this HIVAE is associated with. Defaults to 'HIVAE'.
        mtl_method (Tuple[str], optional): List of methods to use for multi-task learning.
            Assessed possibilities are combinations of "identity", "gradnorm", "graddrop".
            Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).
        use_imputation_layer (bool, optional): Flag to indicate if imputation layer should be used. Defaults to False.
        individual_model (bool, optional): Flag to indicate if the current model
            is applied individually or as part of e.g. a modular HIVAE. Defaults to True.
    """


    def __init__(
        self,
        variable_types: VarTypes,
        input_dim: int,
        dim_s: int,
        dim_z: int,
        dim_y: int,
        module_name: Optional[str] = "HIVAE",
        mtl_method: Tuple[str, ...] = ("identity",),
        use_imputation_layer: bool = False,
        individual_model: bool = True,
    ) -> None:
        super().__init__()

        self.variable_types = variable_types
        self.input_dim = input_dim
        self.dim_s = dim_s
        self.dim_z = dim_z
        self.dim_y = dim_y
        self.individual_model = individual_model

        self.use_imp_layer = use_imputation_layer
        self.module_name = module_name

        # Imputation layer
        if use_imputation_layer:
            self.imputation_layer = ImputationLayer(self.input_dim)
        else:
            self.imputation_layer = None

        # Normalization parameters
        self.register_buffer(
            "_mean_data",
            torch.zeros(len(self.variable_types), requires_grad=False),
        )
        self.register_buffer(
            "_std_data",
            torch.ones(len(self.variable_types), requires_grad=False),
        )
        self._batch_mean_data = self._batch_std_data = None

        self.encoder = Encoder(input_dim=input_dim, dim_s=dim_s, dim_z=dim_z)
        self.decoder = Decoder(
            variable_types=variable_types,
            s_dim=dim_s,
            z_dim=dim_z,
            y_dim=dim_y,
            mtl_method=mtl_method,
            decoder_shared=nn.Identity(),
        )
        self._temperature = 1.0
        self.tau = self._temperature

        # set to cpu by default
        self.device = torch.device("cpu")
        self.module_name = module_name

    @cached_property
    def colnames(self) -> Tuple[str, ...]:
        """
        Tuple of column names derived from variable types.

        Returns:
            Tuple[str, ...]: A tuple containing column names.
        """

        return tuple([v.name for v in self.variable_types])

    @property
    def decoding(self) -> bool:
        """
        Decoding flag indicating if the encoder and decoder are in decoding mode.

        Returns:
            bool: Decoding flag.
        """
        assert self.encoder.decoding == self.decoder.decoding
        return self.encoder.decoding

    @decoding.setter
    def decoding(self, value: bool) -> None:
        """
        Sets the decoding flag for both encoder and decoder.

        Args:
            value (bool): The decoding flag to set.
        """
        self.encoder.decoding = value
        self.decoder.decoding = value

    @property
    def tau(self) -> float:
        """
        Gets the temperature parameter for the model.

        Returns:
            float: The temperature parameter.
        """
        return self._temperature

    @tau.setter
    def tau(self, value: float) -> None:
        """
        Sets the temperature parameter for the model.

        Args:
            value (float): The temperature parameter to set.

        Raises:
            ValueError: If the value is not positive.
        """

        if value <= 0:
            raise ValueError(f"Tau must be positive, got {value}")
        self._temperature = value
        self.encoder.tau = value

    @property
    def normalization_parameters(self) -> NormalizationParameters:
        """
        Gets the normalization parameters (mean and standard deviation).

        Returns:
            NormalizationParameters: The normalization parameters.
        """

        if self._batch_std_data is not None:
            return NormalizationParameters(
                mean=self._batch_mean_data, std=self._batch_std_data
            )
        else:
            return NormalizationParameters(
                mean=self._mean_data, std=self._std_data
            )

    @normalization_parameters.setter
    def normalization_parameters(self, value: NormalizationParameters) -> None:
        """
        Sets the normalization parameters (mean and standard deviation).

        Args:
            value (NormalizationParameters): The normalization parameters to set.
        """

        if self.training:
            # calculate the mean and std of the data with momentum
            momentum = 0.01
            # Choosing a momentum of 0.01 to update the mean and std of the data
            # this leads to a smoother update of the mean and std of the data
            new_mean = (1 - momentum) * self._mean_data + momentum * value.mean
            new_std = (1 - momentum) * self._std_data + momentum * value.std
            self._mean_data = new_mean
            self._std_data = new_std
            self._batch_mean_data = value.mean
            self._batch_std_data = value.std
        else:
            logger.warning(
                "Running parameters are not updated during evaluation"
            )
            self._batch_mean_data = value.mean
            self._batch_std_data = value.std

    def forward(self, data: torch.Tensor, mask: torch.Tensor) -> HivaeOutput:
        """
        Forward pass through the HIVAE model.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.

        Returns:
            HivaeOutput: The output of the HIVAE model.
        """

        data, mask, encoder_output = self.encoder_part(data, mask)
        decoder_output = self.decoder_part(data, mask, encoder_output)
        return decoder_output

    def decoder_part(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
        encoder_output: EncoderOutput,
    ) -> HivaeOutput:
        """
        Pass through the decoder part of the model.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.
            encoder_output (EncoderOutput): Output from the encoder.

        Returns:
            HivaeOutput: The output of the decoder.
        """

        decoder_output = self.decoder(
            data=data,
            mask=mask,
            encoder_output=encoder_output,
            normalization_parameters=self.normalization_parameters,
        )

        return decoder_output

    def encoder_part(
        self, data: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, EncoderOutput]:
        """
        Pass through the encoder part of the model.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Processed data, mask, and encoder output.

        Raises:
            ValueError: If data tensor has invalid shape.
        """

        if data.ndim == 3 and data.shape[1] == 1:
            data = data.squeeze(1)
            mask = mask.squeeze(1)
        elif data.ndim == 3:
            raise ValueError(
                f"Data should have shape (batch_size, n_features), got {data.shape}"
            )

        self._batch_mean_data = None
        self._batch_std_data = None
        (
            new_x,
            new_mask,
            self.normalization_parameters,
        ) = Normalization.normalize_data(
            data, mask, variable_types=self.variable_types, prior_parameters=self.normalization_parameters
        )

        if self.use_imp_layer and self.imputation_layer is not None:
            new_x = self.imputation_layer(new_x, new_mask)

        encoder_output = self.encoder(new_x)
        return data, mask, encoder_output

    def decode(self, encoding: HivaeEncoding) -> torch.Tensor:
        """
        Decode the given encoding to reconstruct the input data.

        Args:
            encoding (HivaeEncoding): The encoding to decode.

        Returns:
            torch.Tensor: The reconstructed data tensor.
        """

        return self.decoder.decode(
            encoding_s=encoding.s,
            encoding_z=encoding.decoder_representation,
            normalization_params=self.normalization_parameters,
        )

    def _predict_step(
        self, data: torch.Tensor, mask: torch.Tensor
    ) -> HivaeOutput:
        """
        Prediction step without gradient calculation.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.

        Returns:
            HivaeOutput: The output of the HIVAE model.
        """

        with torch.no_grad():
            return self.forward(data, mask)

    def _test_step(self, data: torch.Tensor, mask: torch.Tensor) -> float:
        """
        Test step to evaluate the model on test data.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.

        Returns:
            float: The test loss.
        """

        with torch.no_grad():
            return self.forward(data, mask).loss.detach()

    def _training_step(
        self, data: torch.Tensor, mask: torch.Tensor, optimizer: Optimizer
    ) -> float:
        """
        Training step to update the model parameters.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.
            optimizer (Optimizer): Optimizer for updating model parameters.

        Returns:
            float: The training loss.
        """

        optimizer.zero_grad()
        output = self.forward(data, mask)
        self.fabric.backward(output.loss)
        optimizer.step()
        return output.loss.detach()

    def _validation_step(self, data: torch.Tensor, mask: torch.Tensor) -> float:
        """
        Validation step to evaluate the model on validation data.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.

        Returns:
            float: The validation loss.
        """

        with torch.no_grad():
            return self.forward(data, mask).loss.detach()
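The normalization_parameters setter above maintains running feature statistics with a momentum of 0.01, so each training batch nudges the stored mean and std only slightly while the raw batch statistics are kept separately for the current pass. A minimal sketch of that exponential moving average (the buffer size and batch values are made up):

import torch

momentum = 0.01
running_mean = torch.zeros(4)  # plays the role of the _mean_data buffer
running_std = torch.ones(4)    # plays the role of the _std_data buffer

batch_mean = torch.tensor([0.5, -1.0, 2.0, 0.0])
batch_std = torch.tensor([1.2, 0.8, 2.5, 1.0])

# same update rule as the setter: (1 - momentum) * old + momentum * new
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
running_std = (1 - momentum) * running_std + momentum * batch_std
print(running_mean)  # tensor([ 0.0050, -0.0100,  0.0200,  0.0000])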
colnames: Tuple[str, ...] cached property

Tuple of column names derived from variable types.

Returns:
    Tuple[str, ...]: A tuple containing column names.

decoding: bool property writable

Decoding flag indicating if the encoder and decoder are in decoding mode.

Returns:
    bool: Decoding flag.

normalization_parameters: NormalizationParameters property writable

Gets the normalization parameters (mean and standard deviation).

Returns:
    NormalizationParameters: The normalization parameters.

tau: float property writable

Gets the temperature parameter for the model.

Returns:
    float: The temperature parameter.

decode(encoding)

Decode the given encoding to reconstruct the input data.

Parameters:
    encoding (HivaeEncoding): The encoding to decode. Required.

Returns:
    torch.Tensor: The reconstructed data tensor.

Source code in vambn/modelling/models/hivae/hivae.py
def decode(self, encoding: HivaeEncoding) -> torch.Tensor:
    """
    Decode the given encoding to reconstruct the input data.

    Args:
        encoding (HivaeEncoding): The encoding to decode.

    Returns:
        torch.Tensor: The reconstructed data tensor.
    """

    return self.decoder.decode(
        encoding_s=encoding.s,
        encoding_z=encoding.decoder_representation,
        normalization_params=self.normalization_parameters,
    )
decoder_part(data, mask, encoder_output)

Pass through the decoder part of the model.

Parameters:
    data (Tensor): Input data tensor. Required.
    mask (Tensor): Mask tensor indicating missing values. Required.
    encoder_output (EncoderOutput): Output from the encoder. Required.

Returns:
    HivaeOutput: The output of the decoder.

Source code in vambn/modelling/models/hivae/hivae.py
def decoder_part(
    self,
    data: torch.Tensor,
    mask: torch.Tensor,
    encoder_output: EncoderOutput,
) -> HivaeOutput:
    """
    Pass through the decoder part of the model.

    Args:
        data (torch.Tensor): Input data tensor.
        mask (torch.Tensor): Mask tensor indicating missing values.
        encoder_output (EncoderOutput): Output from the encoder.

    Returns:
        HivaeOutput: The output of the decoder.
    """

    decoder_output = self.decoder(
        data=data,
        mask=mask,
        encoder_output=encoder_output,
        normalization_parameters=self.normalization_parameters,
    )

    return decoder_output
encoder_part(data, mask)

Pass through the encoder part of the model.

Parameters:
    data (Tensor): Input data tensor. Required.
    mask (Tensor): Mask tensor indicating missing values. Required.

Returns:
    Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Processed data, mask, and encoder output.

Raises:
    ValueError: If the data tensor has an invalid shape.

Source code in vambn/modelling/models/hivae/hivae.py
def encoder_part(
    self, data: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, EncoderOutput]:
    """
    Pass through the encoder part of the model.

    Args:
        data (torch.Tensor): Input data tensor.
        mask (torch.Tensor): Mask tensor indicating missing values.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Processed data, mask, and encoder output.

    Raises:
        ValueError: If data tensor has invalid shape.
    """

    if data.ndim == 3 and data.shape[1] == 1:
        data = data.squeeze(1)
        mask = mask.squeeze(1)
    elif data.ndim == 3:
        raise ValueError(
            f"Data should have shape (batch_size, n_features), got {data.shape}"
        )

    self._batch_mean_data = None
    self._batch_std_data = None
    (
        new_x,
        new_mask,
        self.normalization_parameters,
    ) = Normalization.normalize_data(
        data, mask, variable_types=self.variable_types, prior_parameters=self.normalization_parameters
    )

    if self.use_imp_layer and self.imputation_layer is not None:
        new_x = self.imputation_layer(new_x, new_mask)

    encoder_output = self.encoder(new_x)
    return data, mask, encoder_output
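
The shape handling at the top of `encoder_part` is the usual stumbling block when single-visit data still carries a time axis: a `(batch, 1, features)` tensor is flattened, any other 3-D shape is rejected. A minimal standalone sketch of that check (plain PyTorch, independent of the Hivae class):

import torch

def squeeze_single_timepoint(data: torch.Tensor, mask: torch.Tensor):
    # (batch, 1, features) is accepted and flattened to (batch, features);
    # any other 3-D shape is rejected, mirroring Hivae.encoder_part above.
    if data.ndim == 3 and data.shape[1] == 1:
        return data.squeeze(1), mask.squeeze(1)
    if data.ndim == 3:
        raise ValueError(
            f"Data should have shape (batch_size, n_features), got {data.shape}"
        )
    return data, mask

x = torch.randn(8, 1, 5)
m = torch.ones(8, 1, 5)
x2, m2 = squeeze_single_timepoint(x, m)
print(x2.shape, m2.shape)  # torch.Size([8, 5]) torch.Size([8, 5])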
forward(data, mask)

Forward pass through the HIVAE model.

Parameters:

Name Type Description Default
data Tensor

Input data tensor.

required
mask Tensor

Mask tensor indicating missing values.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

The output of the HIVAE model.

Source code in vambn/modelling/models/hivae/hivae.py
def forward(self, data: torch.Tensor, mask: torch.Tensor) -> HivaeOutput:
    """
    Forward pass through the HIVAE model.

    Args:
        data (torch.Tensor): Input data tensor.
        mask (torch.Tensor): Mask tensor indicating missing values.

    Returns:
        HivaeOutput: The output of the HIVAE model.
    """

    data, mask, encoder_output = self.encoder_part(data, mask)
    decoder_output = self.decoder_part(data, mask, encoder_output)
    return decoder_output
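
Throughout this module, `mask` follows the same convention: 1 marks an observed value, 0 a missing one, with the same shape as `data`. One common way to derive such a mask from raw data containing NaNs is sketched below; the exact preprocessing used by the VAMBN pipeline may differ.

import torch

raw = torch.tensor([[1.0, float("nan"), 3.0],
                    [4.0, 5.0, float("nan")]])

# 1 where observed, 0 where missing
mask = (~torch.isnan(raw)).float()
# replace NaNs so downstream arithmetic stays finite;
# masked positions are zeroed again during normalization (new_x * mask)
data = torch.nan_to_num(raw, nan=0.0)

print(mask)
print(data)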
LstmHivae

Bases: Hivae

LSTM-based HIVAE model with Encoder and Decoder structure for temporal data.

Parameters:

Name Type Description Default
variable_types VarTypes

List of VariableType objects defining the types of the variables in the data.

required
input_dim int

Dimension of input data (number of columns in the dataframe). If the data contains categorical variables, the input dimension is larger than the number of features.

required
dim_s int

Dimension of s space.

required
dim_z int

Dimension of z space.

required
dim_y int

Dimension of y space.

required
n_layers int

Number of layers in the LSTM.

required
num_timepoints int

Number of time points in the temporal data.

required
module_name str

Name of the module this HIVAE is associated with. Defaults to 'HIVAE'.

'HIVAE'
mtl_method Tuple[str]

List of methods to use for multi-task learning. Assessed possibilities are combinations of "identity", "gradnorm", "graddrop". Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).

('identity',)
use_imputation_layer bool

Flag to indicate if imputation layer should be used. Defaults to False.

False
individual_model bool

Flag to indicate whether the current model is applied individually or as part of, e.g., a modular HIVAE. Defaults to True.

True
Source code in vambn/modelling/models/hivae/hivae.py
class LstmHivae(Hivae):
    """
    LSTM-based HIVAE model with Encoder and Decoder structure for temporal data.

    Args:
        variable_types (VarTypes): List of VariableType objects defining the types
            of the variables in the data.
        input_dim (int): Dimension of input data (number of columns in the dataframe).
            If the data contains categorical variables, the input dimension is
            larger than the number of features.
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        dim_y (int): Dimension of y space.
        n_layers (int): Number of layers in the LSTM.
        num_timepoints (int): Number of time points in the temporal data.
        module_name (str, optional): Name of the module this HIVAE is associated with. Defaults to 'HIVAE'.
        mtl_method (Tuple[str], optional): List of methods to use for multi-task learning.
            Assessed possibilities are combinations of "identity", "gradnorm", "graddrop".
            Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).
        use_imputation_layer (bool, optional): Flag to indicate if imputation layer should be used. Defaults to False.
        individual_model (bool, optional): Flag to indicate if the current model
            is applied individually or as part of, e.g., a modular HIVAE. Defaults to True.
    """

    def __init__(
        self,
        variable_types: VarTypes,
        input_dim: int,
        dim_s: int,
        dim_z: int,
        dim_y: int,
        n_layers: int,
        num_timepoints: int,
        module_name: str | None = "HIVAE",
        mtl_method: Tuple[str] = ("identity",),
        use_imputation_layer: bool = False,
        individual_model: bool = True,
    ) -> None:
        super().__init__(
            variable_types,
            input_dim,
            dim_s,
            dim_z,
            dim_y,
            module_name,
            mtl_method,
            use_imputation_layer,
            individual_model,
        )
        self.n_layers = n_layers
        self.num_timepoints = num_timepoints

        self.encoder = LstmEncoder(
            input_dimension=input_dim,
            dim_s=dim_s,
            dim_z=dim_z,
            n_layers=n_layers,
            hidden_size=input_dim,
        )
        self.decoder = LstmDecoder(
            mtl_method=mtl_method,
            n_layers=n_layers,
            num_timepoints=num_timepoints,
            s_dim=dim_s,
            variable_types=variable_types,
            y_dim=dim_y,
            z_dim=dim_z,
            decoder_shared=nn.Identity(),
        )
        self.imputation_layer = nn.ModuleList(
            [ImputationLayer(input_dim) for _ in range(num_timepoints)]
        )

        # set normalization params for each timepoint
        self.register_buffer(
            "_mean_data",
            torch.zeros(
                num_timepoints, len(self.variable_types), requires_grad=False
            ),
        )
        self.register_buffer(
            "_std_data",
            torch.ones(
                num_timepoints, len(self.variable_types), requires_grad=False
            ),
        )

    def forward(self, data: torch.Tensor, mask: torch.Tensor) -> HivaeOutput:
        """
        Forward pass through the LSTM HIVAE model.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.

        Returns:
            HivaeOutput: The output of the LSTM HIVAE model.
        """

        data, mask, encoder_output = self.encoder_part(data, mask)
        decoder_output = self.decoder_part(data, mask, encoder_output)
        return decoder_output

    def decoder_part(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
        encoder_output: EncoderOutput,
    ) -> HivaeOutput:
        """
        Pass through the decoder part of the model.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.
            encoder_output (EncoderOutput): Output from the encoder.

        Returns:
            HivaeOutput: The output of the decoder.
        """

        decoder_output = self.decoder(
            data=data,
            mask=mask,
            encoder_output=encoder_output,
            normalization_parameters=self.normalization_parameters,
        )

        return decoder_output

    def encoder_part(
        self, data: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, EncoderOutput]:
        """
        Pass through the encoder part of the model.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Processed data, mask, and encoder output.
        """

        time_point_x = []
        time_point_mask = []
        normalization_parameters = [None] * self.num_timepoints
        for i in range(self.num_timepoints):
            (
                new_x,
                new_mask,
                normalization_parameters[i],
            ) = Normalization.normalize_data(
                data[:, i], mask[:, i], variable_types=self.variable_types, prior_parameters=self.normalization_parameters[i]
            )
            time_point_x.append(new_x)
            time_point_mask.append(new_mask)

        self.normalization_parameters = reduce(
            lambda x, y: x + y, normalization_parameters
        )
        new_x = torch.stack(time_point_x, dim=1)
        new_mask = torch.stack(time_point_mask, dim=1)

        if self.use_imp_layer and self.imputation_layer is not None:
            for i in range(self.num_timepoints):
                new_x[:, i] = self.imputation_layer[i](
                    new_x[:, i], new_mask[:, i]
                )

        encoder_output = self.encoder(new_x)
        return data, mask, encoder_output

    def decode(self, encoding: HivaeEncoding) -> torch.Tensor:
        """
        Decode the given encoding to reconstruct the input data.

        Args:
            encoding (HivaeEncoding): The encoding to decode.

        Returns:
            torch.Tensor: The reconstructed data tensor.
        """

        return self.decoder.decode(
            encoding_s=encoding.s,
            encoding_z=encoding.decoder_representation,
            normalization_params=self.normalization_parameters,
        )
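
`LstmHivae` keeps one imputation layer per time point in an `nn.ModuleList` and applies them slice by slice inside `encoder_part`. The toy below reproduces that per-timepoint looping pattern with plain `nn.Linear` layers standing in for `ImputationLayer` (a hypothetical stand-in, not the real layer, which also receives the mask):

import torch
import torch.nn as nn

batch, timepoints, features = 4, 3, 6

# one stand-in layer per time point, mirroring the ModuleList above
layers = nn.ModuleList([nn.Linear(features, features) for _ in range(timepoints)])

x = torch.randn(batch, timepoints, features)
slices = []
for t in range(timepoints):
    # each time point is transformed by its own parameters
    slices.append(layers[t](x[:, t]))
out = torch.stack(slices, dim=1)

print(out.shape)  # torch.Size([4, 3, 6])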
decode(encoding)

Decode the given encoding to reconstruct the input data.

Parameters:

Name Type Description Default
encoding HivaeEncoding

The encoding to decode.

required

Returns:

Type Description
Tensor

torch.Tensor: The reconstructed data tensor.

Source code in vambn/modelling/models/hivae/hivae.py
def decode(self, encoding: HivaeEncoding) -> torch.Tensor:
    """
    Decode the given encoding to reconstruct the input data.

    Args:
        encoding (HivaeEncoding): The encoding to decode.

    Returns:
        torch.Tensor: The reconstructed data tensor.
    """

    return self.decoder.decode(
        encoding_s=encoding.s,
        encoding_z=encoding.decoder_representation,
        normalization_params=self.normalization_parameters,
    )
decoder_part(data, mask, encoder_output)

Pass through the decoder part of the model.

Parameters:

Name Type Description Default
data Tensor

Input data tensor.

required
mask Tensor

Mask tensor indicating missing values.

required
encoder_output EncoderOutput

Output from the encoder.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

The output of the decoder.

Source code in vambn/modelling/models/hivae/hivae.py
def decoder_part(
    self,
    data: torch.Tensor,
    mask: torch.Tensor,
    encoder_output: EncoderOutput,
) -> HivaeOutput:
    """
    Pass through the decoder part of the model.

    Args:
        data (torch.Tensor): Input data tensor.
        mask (torch.Tensor): Mask tensor indicating missing values.
        encoder_output (EncoderOutput): Output from the encoder.

    Returns:
        HivaeOutput: The output of the decoder.
    """

    decoder_output = self.decoder(
        data=data,
        mask=mask,
        encoder_output=encoder_output,
        normalization_parameters=self.normalization_parameters,
    )

    return decoder_output
encoder_part(data, mask)

Pass through the encoder part of the model.

Parameters:

Name Type Description Default
data Tensor

Input data tensor.

required
mask Tensor

Mask tensor indicating missing values.

required

Returns:

Type Description
Tuple[Tensor, Tensor, EncoderOutput]

Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Processed data, mask, and encoder output.

Source code in vambn/modelling/models/hivae/hivae.py
def encoder_part(
    self, data: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, EncoderOutput]:
    """
    Pass through the encoder part of the model.

    Args:
        data (torch.Tensor): Input data tensor.
        mask (torch.Tensor): Mask tensor indicating missing values.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Processed data, mask, and encoder output.
    """

    time_point_x = []
    time_point_mask = []
    normalization_parameters = [None] * self.num_timepoints
    for i in range(self.num_timepoints):
        (
            new_x,
            new_mask,
            normalization_parameters[i],
        ) = Normalization.normalize_data(
            data[:, i], mask[:, i], variable_types=self.variable_types, prior_parameters=self.normalization_parameters[i]
        )
        time_point_x.append(new_x)
        time_point_mask.append(new_mask)

    self.normalization_parameters = reduce(
        lambda x, y: x + y, normalization_parameters
    )
    new_x = torch.stack(time_point_x, dim=1)
    new_mask = torch.stack(time_point_mask, dim=1)

    if self.use_imp_layer and self.imputation_layer is not None:
        for i in range(self.num_timepoints):
            new_x[:, i] = self.imputation_layer[i](
                new_x[:, i], new_mask[:, i]
            )

    encoder_output = self.encoder(new_x)
    return data, mask, encoder_output
forward(data, mask)

Forward pass through the LSTM HIVAE model.

Parameters:

Name Type Description Default
data Tensor

Input data tensor.

required
mask Tensor

Mask tensor indicating missing values.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

The output of the LSTM HIVAE model.

Source code in vambn/modelling/models/hivae/hivae.py
def forward(self, data: torch.Tensor, mask: torch.Tensor) -> HivaeOutput:
    """
    Forward pass through the LSTM HIVAE model.

    Args:
        data (torch.Tensor): Input data tensor.
        mask (torch.Tensor): Mask tensor indicating missing values.

    Returns:
        HivaeOutput: The output of the LSTM HIVAE model.
    """

    data, mask, encoder_output = self.encoder_part(data, mask)
    decoder_output = self.decoder_part(data, mask, encoder_output)
    return decoder_output

modular

ModularHivae

Bases: AbstractModularModel[Tuple[Tensor, ...], Tuple[Tensor, ...], ModularHivaeOutput, ModularHivaeEncoding]

Modular HIVAE model containing multiple data modules.

Parameters:

Name Type Description Default
module_config Tuple[DataModuleConfig]

Configuration for each data module. See DataModuleConfig for details.

required
dim_s int | Dict[str, int]

Number of mixture components for each module individually (dict) or a single value for all modules (int).

required
dim_z int

Dimension of the latent space. Equal for all modules.

required
dim_ys int

Dimension of the latent space ys. Equal for all modules.

required
dim_y int | Dict[str, int]

Dimension of the latent variable y for each module or a single value for all modules.

required
shared_element_type str

Type of shared element. Possible values are "none", "sharedLinear", "concatMtl", "concatIndiv", "avgMtl", "maxMtl", "encoder", "encoderMtl". Defaults to "none".

'none'
mtl_method Tuple[str, ...]

Methods for multi-task learning. Tested possibilities are combinations of "identity", "gradnorm", "graddrop". Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).

('identity',)
use_imputation_layer bool

Flag to indicate if imputation layer should be used. Defaults to False.

False
Source code in vambn/modelling/models/hivae/modular.py
class ModularHivae(
    AbstractModularModel[
        Tuple[Tensor, ...],
        Tuple[Tensor, ...],
        ModularHivaeOutput,
        ModularHivaeEncoding,
    ]
):
    """
    Modular HIVAE model containing multiple data modules.

    Args:
        module_config (Tuple[DataModuleConfig]): Configuration for each data module. See DataModuleConfig for details.
        dim_s (int | Dict[str, int]): Number of mixture components for each module individually (dict) or a single value for all modules (int).
        dim_z (int): Dimension of the latent space. Equal for all modules.
        dim_ys (int): Dimension of the latent space ys. Equal for all modules.
        dim_y (int | Dict[str, int]): Dimension of the latent variable y for each module or a single value for all modules.
        shared_element_type (str, optional): Type of shared element. Possible values are "none", "sharedLinear", "concatMtl", "concatIndiv", "avgMtl", "maxMtl", "encoder", "encoderMtl". Defaults to "none".
        mtl_method (Tuple[str, ...], optional): Methods for multi-task learning. Tested possibilities are combinations of "identity", "gradnorm", "graddrop". Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).
        use_imputation_layer (bool, optional): Flag to indicate if imputation layer should be used. Defaults to False.
    """

    def __init__(
        self,
        module_config: Tuple[DataModuleConfig],
        dim_s: int | Dict[str, int],
        dim_z: int,
        dim_ys: int,
        dim_y: int | Dict[str, int],
        shared_element_type: str = "none",
        mtl_method: Tuple[str, ...] = ("identity",),
        use_imputation_layer: bool = False,
    ):
        """
        Initialize the modular HIVAE model.

        Args:
            module_config (Tuple[DataModuleConfig]): Configuration for each data module.
            dim_s (int | Dict[str, int]): Number of the mixture components for each module or a single value for all modules.
            dim_z (int): Dimension of the latent space. Equal for all modules.
            dim_ys (int): Dimension of the latent space ys. Equal for all modules.
            dim_y (int | Dict[str, int]): Dimension of the latent variable y for each module or a single value for all modules.
            shared_element_type (str, optional): Type of shared element. Defaults to "none".
            mtl_method (Tuple[str, ...], optional): Methods for MTL. Defaults to ("identity",).
            use_imputation_layer (bool, optional): Flag to indicate if imputation layer should be used. Defaults to False.
        """

        super().__init__()
        self.module_configs = module_config
        self.mtl_method = mtl_method
        self.use_imputation_layer = use_imputation_layer
        self.dim_s = dim_s
        self.dim_z = dim_z
        self.dim_ys = dim_ys
        self.dim_y = dim_y

        if not isinstance(dim_s, int) and len(dim_s) != len(module_config):
            raise ValueError(
                "If dim_s is a tuple, it must have the same length as module_config"
            )

        # Initialize the shared element
        # The shared element is a module that takes the z samples from the encoder outputs
        # and generates a shared representation ys, finally a representation y for each module
        shared_element_class = SHARED_MODULES[shared_element_type]
        if shared_element_class is None:
            raise ValueError(
                f"Shared element {shared_element_type} is not available"
            )
        self.shared_element = shared_element_class(
            z_dim=self.dim_z,
            n_modules=len(module_config),
            ys_dim=self.dim_ys,
            y_dim=self.dim_y
            if isinstance(dim_y, int)
            else tuple([dim_y[module.name] for module in self.module_configs]),
            mtl_method=mtl_method if mtl_method else ("graddrop",),
            module_names=[module.name for module in module_config],
        )

        module_models = {}
        for module in module_config:
            module_name = module.name
            if module.is_longitudinal:
                module_models[module_name] = LstmHivae(
                    dim_s=dim_s
                    if isinstance(dim_s, int)
                    else dim_s[module_name],
                    dim_y=dim_y
                    if isinstance(dim_y, int)
                    else dim_y[module_name],
                    dim_z=dim_z,
                    individual_model=False,
                    input_dim=module.input_dim,
                    module_name=module_name,
                    mtl_method=mtl_method,
                    n_layers=module.n_layers,
                    num_timepoints=module.num_timepoints,
                    use_imputation_layer=use_imputation_layer,
                    variable_types=module.variable_types,
                )
            else:
                module_models[module_name] = Hivae(
                    dim_s=dim_s
                    if isinstance(dim_s, int)
                    else dim_s[module_name],
                    dim_y=dim_y
                    if isinstance(dim_y, int)
                    else dim_y[module_name],
                    dim_z=dim_z,
                    individual_model=False,
                    input_dim=module.input_dim,
                    module_name=module_name,
                    mtl_method=mtl_method,
                    use_imputation_layer=use_imputation_layer,
                    variable_types=module.variable_types,
                )

        self.module_models = nn.ModuleDict(module_models)
        self._tau = 1.0

    @property
    def decoding(self) -> bool:
        """
        Decoding flag indicating if the encoder and decoder are in decoding mode.

        Returns:
            bool: Decoding flag.
        """
        assert all([module.decoding for module in self.module_models.values()])
        return self.module_models[self.module_configs[0].name].decoding


    @decoding.setter
    def decoding(self, value: bool) -> None:
        """
        Sets the decoding flag for all modules.

        Args:
            value (bool): The decoding flag to set.
        """
        for module in self.module_models.values():
            module.decoding = value


    def colnames(self, module_name: str) -> Tuple[str, ...]:
        """
        Get column names for a specific module.

        Args:
            module_name (str): Name of the module.

        Returns:
            Tuple[str, ...]: Column names for the specified module.
        """
        return self.module_models[module_name].colnames


    def is_longitudinal(self, module_name: str) -> bool:
        """
        Check if a specific module is longitudinal.

        Args:
            module_name (str): Name of the module.

        Returns:
            bool: True if the module is longitudinal, False otherwise.
        """
        return self.module_models[module_name].is_longitudinal


    def forward(
        self, data: Tuple[Tensor, ...], mask: Tuple[Tensor, ...]
    ) -> ModularHivaeOutput:
        """
        Forward pass through the modular HIVAE model.

        Args:
            data (Tuple[Tensor, ...]): Input data tensors for each module.
            mask (Tuple[Tensor, ...]): Mask tensors indicating missing values for each module.

        Returns:
            ModularHivaeOutput: The output of the modular HIVAE model.
        """

        # Generate the encoder outputs and data for each module
        # This step does not differ from the normal HIVAE model
        encoder_outputs: List[EncoderOutput] = []
        modified_data = []
        modified_mask = []
        for i, module in enumerate(self.module_configs):
            mdata, mmask, output = self.module_models[module.name].encoder_part(
                data[i], mask[i]
            )
            encoder_outputs.append(output)
            modified_data.append(mdata)
            modified_mask.append(mmask)

        # Then retrieve the z samples from the encoder outputs
        # These z samples are then passed through the shared element to
        # achieve the modularity
        # The shared representation is a tuple of tensors, one for each module
        z_samples = tuple([x.samples_z for x in encoder_outputs])
        shared_representations = self.shared_element(z=z_samples)

        # This shared representation is then passed back to the encoder outputs
        for i, enc in enumerate(encoder_outputs):
            enc.decoder_representation = shared_representations[i]

        # Then we continue with the decoder part as usual
        decoder_outputs = []
        for i, module in enumerate(self.module_configs):
            output = self.module_models[module.name].decoder_part(
                data=modified_data[i],
                mask=modified_mask[i],
                encoder_output=encoder_outputs[i],
            )
            decoder_outputs.append(output)

        gathered_output = ModularHivaeOutput(outputs=tuple(decoder_outputs))
        return gathered_output

    @property
    def tau(self):
        """
        Get the temperature parameter for the model.

        Returns:
            float: The temperature parameter.
        """

        return self._tau

    @tau.setter
    def tau(self, value: float):
        """
        Set the temperature parameter for the model.

        Args:
            value (float): The temperature parameter to set.
        """

        self._tau = value
        for module in self.module_models.values():
            module.tau = value

    def decode(self, encoding: ModularHivaeEncoding) -> Tuple[Tensor, ...]:
        """
        Decode the given encoding to reconstruct the input data.

        Args:
            encoding (ModularHivaeEncoding): The encoding to decode.

        Returns:
            Tuple[Tensor, ...]: The reconstructed data tensors for each module.
        """

        self.eval()
        self.tau = 1e-3
        z_samples = tuple([x.z for x in encoding])

        modified_output = self.shared_element(z=z_samples)
        output_mapping = {
            x: i for i, x in enumerate(self.shared_element.module_names)
        }

        for enc in encoding.encodings:
            out_pos = output_mapping[enc.module]
            enc.decoder_representation = modified_output[out_pos]

        outputs = []
        for i, module in enumerate(self.module_configs):
            module_enc = encoding.get(module.name)
            output = self.module_models[module.name].decode(module_enc)
            outputs.append(output)
        assert len(outputs) == len(self.module_configs)
        return tuple(outputs)

    def _training_step(
        self,
        data: Tuple[Tensor],
        mask: Tuple[Tensor],
        optimizer: Tuple[optim.Optimizer],
    ) -> float:
        """
        Perform a training step to update the model parameters.

        Args:
            data (Tuple[Tensor]): Input data tensors for each module.
            mask (Tuple[Tensor]): Mask tensors indicating missing values for each module.
            optimizer (Tuple[optim.Optimizer]): Optimizers for updating model parameters.

        Returns:
            float: The training loss.
        """

        # set all gradients to zero
        for opt in optimizer:
            if opt is None:
                continue
            opt.zero_grad()
        output = self.forward(data, mask)
        self.fabric.backward(output.loss)
        for opt in optimizer:
            if opt is None:
                continue
            opt.step()
        return output.loss.item()

    def _validation_step(
        self, data: Tuple[Tensor], mask: Tuple[Tensor]
    ) -> float:
        """
        Perform a validation step to evaluate the model.

        Args:
            data (Tuple[Tensor]): Input data tensors for each module.
            mask (Tuple[Tensor]): Mask tensors indicating missing values for each module.

        Returns:
            float: The validation loss.
        """

        output = self.forward(data, mask)
        return output.loss.item()

    def _test_step(self, data: Tuple[Tensor], mask: Tuple[Tensor]) -> float:
        """
        Perform a test step to evaluate the model on test data.

        Args:
            data (Tuple[Tensor]): Input data tensors for each module.
            mask (Tuple[Tensor]): Mask tensors indicating missing values for each module.

        Returns:
            float: The test loss.
        """

        output = self.forward(data, mask)
        return output.loss.item()

    def _predict_step(
        self, data: Tuple[Tensor], mask: Tuple[Tensor]
    ) -> ModularHivaeOutput:
        """
        Perform a prediction step without gradient calculation.

        Args:
            data (Tuple[Tensor]): Input data tensors for each module.
            mask (Tuple[Tensor]): Mask tensors indicating missing values for each module.

        Returns:
            ModularHivaeOutput: The output of the modular HIVAE model.
        """

        return self.forward(data, mask)
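
The comments in `forward` describe the modular trick: collect each module's z sample, map the tuple through a shared element to a joint representation ys, and hand one y back per module as its decoder representation. The toy below illustrates that routing with a simple concatenate-and-project shared element; it is a conceptual sketch, not one of the actual SHARED_MODULES implementations.

import torch
import torch.nn as nn

class ToySharedElement(nn.Module):
    """Concatenate per-module z samples, build a shared ys, project one y per module."""

    def __init__(self, z_dim: int, n_modules: int, ys_dim: int, y_dim: int):
        super().__init__()
        self.to_ys = nn.Linear(z_dim * n_modules, ys_dim)
        self.heads = nn.ModuleList(
            [nn.Linear(ys_dim, y_dim) for _ in range(n_modules)]
        )

    def forward(self, z: tuple) -> tuple:
        ys = torch.relu(self.to_ys(torch.cat(z, dim=-1)))
        return tuple(head(ys) for head in self.heads)

z_samples = tuple(torch.randn(8, 4) for _ in range(3))   # three modules, z_dim=4
shared = ToySharedElement(z_dim=4, n_modules=3, ys_dim=16, y_dim=5)
per_module_y = shared(z=z_samples)
print([y.shape for y in per_module_y])  # three tensors of shape (8, 5)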
decoding: bool property writable

Decoding flag indicating if the encoder and decoder are in decoding mode.

Returns:

Name Type Description
bool bool

Decoding flag.

tau property writable

Get the temperature parameter for the model.

Returns:

Name Type Description
float

The temperature parameter.

__init__(module_config, dim_s, dim_z, dim_ys, dim_y, shared_element_type='none', mtl_method=('identity',), use_imputation_layer=False)

Initialize the modular HIVAE model.

Parameters:

Name Type Description Default
module_config Tuple[DataModuleConfig]

Configuration for each data module.

required
dim_s int | Dict[str, int]

Number of the mixture components for each module or a single value for all modules.

required
dim_z int

Dimension of the latent space. Equal for all modules.

required
dim_ys int

Dimension of the latent space ys. Equal for all modules.

required
dim_y int | Dict[str, int]

Dimension of the latent variable y for each module or a single value for all modules.

required
shared_element_type str

Type of shared element. Defaults to "none".

'none'
mtl_method Tuple[str, ...]

Methods for MTL. Defaults to ("identity",).

('identity',)
use_imputation_layer bool

Flag to indicate if imputation layer should be used. Defaults to False.

False
Source code in vambn/modelling/models/hivae/modular.py
def __init__(
    self,
    module_config: Tuple[DataModuleConfig],
    dim_s: int | Dict[str, int],
    dim_z: int,
    dim_ys: int,
    dim_y: int | Dict[str, int],
    shared_element_type: str = "none",
    mtl_method: Tuple[str, ...] = ("identity",),
    use_imputation_layer: bool = False,
):
    """
    Initialize the modular HIVAE model.

    Args:
        module_config (Tuple[DataModuleConfig]): Configuration for each data module.
        dim_s (int | Dict[str, int]): Number of the mixture components for each module or a single value for all modules.
        dim_z (int): Dimension of the latent space. Equal for all modules.
        dim_ys (int): Dimension of the latent space ys. Equal for all modules.
        dim_y (int | Dict[str, int]): Dimension of the latent variable y for each module or a single value for all modules.
        shared_element_type (str, optional): Type of shared element. Defaults to "none".
        mtl_method (Tuple[str, ...], optional): Methods for MTL. Defaults to ("identity",).
        use_imputation_layer (bool, optional): Flag to indicate if imputation layer should be used. Defaults to False.
    """

    super().__init__()
    self.module_configs = module_config
    self.mtl_method = mtl_method
    self.use_imputation_layer = use_imputation_layer
    self.dim_s = dim_s
    self.dim_z = dim_z
    self.dim_ys = dim_ys
    self.dim_y = dim_y

    if not isinstance(dim_s, int) and len(dim_s) != len(module_config):
        raise ValueError(
            "If dim_s is a tuple, it must have the same length as module_config"
        )

    # Initialize the shared element
    # The shared element is a module that takes the z samples from the encoder outputs
    # and generates a shared representation ys, finally a representation y for each module
    shared_element_class = SHARED_MODULES[shared_element_type]
    if shared_element_class is None:
        raise ValueError(
            f"Shared element {shared_element_type} is not available"
        )
    self.shared_element = shared_element_class(
        z_dim=self.dim_z,
        n_modules=len(module_config),
        ys_dim=self.dim_ys,
        y_dim=self.dim_y
        if isinstance(dim_y, int)
        else tuple([dim_y[module.name] for module in self.module_configs]),
        mtl_method=mtl_method if mtl_method else ("graddrop",),
        module_names=[module.name for module in module_config],
    )

    module_models = {}
    for module in module_config:
        module_name = module.name
        if module.is_longitudinal:
            module_models[module_name] = LstmHivae(
                dim_s=dim_s
                if isinstance(dim_s, int)
                else dim_s[module_name],
                dim_y=dim_y
                if isinstance(dim_y, int)
                else dim_y[module_name],
                dim_z=dim_z,
                individual_model=False,
                input_dim=module.input_dim,
                module_name=module_name,
                mtl_method=mtl_method,
                n_layers=module.n_layers,
                num_timepoints=module.num_timepoints,
                use_imputation_layer=use_imputation_layer,
                variable_types=module.variable_types,
            )
        else:
            module_models[module_name] = Hivae(
                dim_s=dim_s
                if isinstance(dim_s, int)
                else dim_s[module_name],
                dim_y=dim_y
                if isinstance(dim_y, int)
                else dim_y[module_name],
                dim_z=dim_z,
                individual_model=False,
                input_dim=module.input_dim,
                module_name=module_name,
                mtl_method=mtl_method,
                use_imputation_layer=use_imputation_layer,
                variable_types=module.variable_types,
            )

    self.module_models = nn.ModuleDict(module_models)
    self._tau = 1.0
colnames(module_name)

Get column names for a specific module.

Parameters:

Name Type Description Default
module_name str

Name of the module.

required

Returns:

Type Description
Tuple[str, ...]

Tuple[str, ...]: Column names for the specified module.

Source code in vambn/modelling/models/hivae/modular.py
def colnames(self, module_name: str) -> Tuple[str, ...]:
    """
    Get column names for a specific module.

    Args:
        module_name (str): Name of the module.

    Returns:
        Tuple[str, ...]: Column names for the specified module.
    """
    return self.module_models[module_name].colnames
decode(encoding)

Decode the given encoding to reconstruct the input data.

Parameters:

Name Type Description Default
encoding ModularHivaeEncoding

The encoding to decode.

required

Returns:

Type Description
Tuple[Tensor, ...]

Tuple[Tensor, ...]: The reconstructed data tensors for each module.

Source code in vambn/modelling/models/hivae/modular.py
def decode(self, encoding: ModularHivaeEncoding) -> Tuple[Tensor, ...]:
    """
    Decode the given encoding to reconstruct the input data.

    Args:
        encoding (ModularHivaeEncoding): The encoding to decode.

    Returns:
        Tuple[Tensor, ...]: The reconstructed data tensors for each module.
    """

    self.eval()
    self.tau = 1e-3
    z_samples = tuple([x.z for x in encoding])

    modified_output = self.shared_element(z=z_samples)
    output_mapping = {
        x: i for i, x in enumerate(self.shared_element.module_names)
    }

    for enc in encoding.encodings:
        out_pos = output_mapping[enc.module]
        enc.decoder_representation = modified_output[out_pos]

    outputs = []
    for i, module in enumerate(self.module_configs):
        module_enc = encoding.get(module.name)
        output = self.module_models[module.name].decode(module_enc)
        outputs.append(output)
    assert len(outputs) == len(self.module_configs)
    return tuple(outputs)
forward(data, mask)

Forward pass through the modular HIVAE model.

Parameters:

Name Type Description Default
data Tuple[Tensor, ...]

Input data tensors for each module.

required
mask Tuple[Tensor, ...]

Mask tensors indicating missing values for each module.

required

Returns:

Name Type Description
ModularHivaeOutput ModularHivaeOutput

The output of the modular HIVAE model.

Source code in vambn/modelling/models/hivae/modular.py
def forward(
    self, data: Tuple[Tensor, ...], mask: Tuple[Tensor, ...]
) -> ModularHivaeOutput:
    """
    Forward pass through the modular HIVAE model.

    Args:
        data (Tuple[Tensor, ...]): Input data tensors for each module.
        mask (Tuple[Tensor, ...]): Mask tensors indicating missing values for each module.

    Returns:
        ModularHivaeOutput: The output of the modular HIVAE model.
    """

    # Generate the encoder outputs and data for each module
    # This step does not differ from the normal HIVAE model
    encoder_outputs: List[EncoderOutput] = []
    modified_data = []
    modified_mask = []
    for i, module in enumerate(self.module_configs):
        mdata, mmask, output = self.module_models[module.name].encoder_part(
            data[i], mask[i]
        )
        encoder_outputs.append(output)
        modified_data.append(mdata)
        modified_mask.append(mmask)

    # Then retrieve the z samples from the encoder outputs
    # These z samples are then passed through the shared element to
    # achieve the modularity
    # The shared representation is a tuple of tensors, one for each module
    z_samples = tuple([x.samples_z for x in encoder_outputs])
    shared_representations = self.shared_element(z=z_samples)

    # This shared representation is then passed back to the encoder outputs
    for i, enc in enumerate(encoder_outputs):
        enc.decoder_representation = shared_representations[i]

    # Then we continue with the decoder part as usual
    decoder_outputs = []
    for i, module in enumerate(self.module_configs):
        output = self.module_models[module.name].decoder_part(
            data=modified_data[i],
            mask=modified_mask[i],
            encoder_output=encoder_outputs[i],
        )
        decoder_outputs.append(output)

    gathered_output = ModularHivaeOutput(outputs=tuple(decoder_outputs))
    return gathered_output
is_longitudinal(module_name)

Check if a specific module is longitudinal.

Parameters:

Name Type Description Default
module_name str

Name of the module.

required

Returns:

Name Type Description
bool bool

True if the module is longitudinal, False otherwise.

Source code in vambn/modelling/models/hivae/modular.py
def is_longitudinal(self, module_name: str) -> bool:
    """
    Check if a specific module is longitudinal.

    Args:
        module_name (str): Name of the module.

    Returns:
        bool: True if the module is longitudinal, False otherwise.
    """
    return self.module_models[module_name].is_longitudinal

normalization

Normalization

Class for normalization utilities, including broadcasting masks and normalizing/denormalizing data.

Source code in vambn/modelling/models/hivae/normalization.py
class Normalization:
    """
    Class for normalization utilities, including broadcasting masks and normalizing/denormalizing data.
    """
    @staticmethod
    def _broadcast_mask(
        mask: torch.Tensor, variable_types: VarTypes
    ) -> torch.Tensor:
        """
        Broadcast the mask tensor to match the shape required by variable types.

        Args:
            mask (torch.Tensor): The input mask tensor.
            variable_types (VarTypes): The variable types information.

        Returns:
            torch.Tensor: The broadcasted mask tensor.
        """
        # if all(d.n_parameters == 1 for d in variable_types):
        #     return mask

        new_mask = []
        for i, vtype in enumerate(variable_types):
            if vtype.data_type == "cat":
                new_mask.append(
                    mask[:, i].unsqueeze(1).expand(-1, vtype.n_parameters)
                )
            else:
                new_mask.append(mask[:, i].unsqueeze(1))
        return torch.cat(new_mask, dim=1)

    @staticmethod
    def normalize_data(
        x: torch.Tensor,
        mask: torch.Tensor,
        variable_types: VarTypes,
        prior_parameters: NormalizationParameters,
        eps: float = 1e-6,
    ) -> Tuple[torch.Tensor, torch.Tensor, NormalizationParameters]:
        """
        Normalize the input data based on variable types and prior parameters.

        Args:
            x (torch.Tensor): The input data tensor.
            mask (torch.Tensor): The mask tensor indicating missing values.
            variable_types (VarTypes): The variable types information.
            prior_parameters (NormalizationParameters): The prior normalization parameters.
            eps (float, optional): A small value to prevent division by zero. Defaults to 1e-6.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, NormalizationParameters]: The normalized data, updated mask, and new normalization parameters.
        """
        if x.ndim == 3 and x.shape[1] == 1:
            x = x.squeeze(1)
            mask = mask.squeeze(1)

        assert len(variable_types) == x.shape[-1]
        mean_data = prior_parameters.mean
        std_data = prior_parameters.std
        new_x = []
        for i, vtype in enumerate(variable_types):
            x_i = torch.masked_select(x[..., i], mask[..., i].bool())
            new_x_i = torch.unsqueeze(x[..., i], -1)

            if vtype.data_type == "real" or vtype.data_type == "truncate_norm":
                if x_i.shape[0] >= 4:
                    mean_data[i] = x_i.mean()
                    std_data[i] = x_i.std().clamp(min=eps, max=1e20)

                new_x_i = (new_x_i - mean_data[i]) / std_data[i]
            elif vtype.data_type == "pos":
                x_i = torch.log1p(x_i)
                if x_i.shape[0] >= 4:
                    mean_data[i] = x_i.mean()
                    std_data[i] = x_i.std().clamp(min=eps, max=1e20)

                new_x_i = (torch.log1p(new_x_i) - mean_data[i]) / std_data[i]
            elif vtype.data_type == "gamma":
                x_i = torch.log1p(x_i)
                if x_i.shape[0] >= 4:
                    mean_data[i] = x_i.mean()
                    std_data[i] = x_i.std().clamp(min=eps, max=1e20)

                new_x_i = (torch.log1p(new_x_i) - mean_data[i]) / std_data[i]
            elif vtype.data_type == "count":
                new_x_i = torch.log1p(new_x_i)
            elif vtype.data_type == "cat":
                # convert to one hot
                new_x_i = torch.nn.functional.one_hot(
                    new_x_i.long().squeeze(1), vtype.n_parameters
                )

            if torch.isnan(new_x_i).any():
                raise ValueError(
                    f"NaN values found in normalized data for {vtype}"
                )
            if torch.isnan(mean_data[i]) or torch.isnan(std_data[i]):
                raise ValueError(
                    f"NaN values found in normalization parameters for {vtype}"
                )
            new_x.append(new_x_i)

        new_x = torch.cat(new_x, dim=-1)
        mask = Normalization._broadcast_mask(mask, variable_types)
        new_x = new_x * mask
        return (
            new_x,
            mask,
            NormalizationParameters.from_tensors(mean_data, std_data),
        )

    @staticmethod
    def denormalize_params(
        params: Tuple[Parameters, ...],
        variable_types: VarTypes,
        normalization_params: NormalizationParameters,
    ) -> Tuple[Parameters, ...]:
        """
        Denormalize parameters based on variable types and normalization parameters.

        Args:
            params (Tuple[Parameters, ...]): The parameters to denormalize.
            variable_types (VarTypes): The variable types information.
            normalization_params (NormalizationParameters): The normalization parameters.

        Returns:
            Tuple[Parameters, ...]: The denormalized parameters.
        """
        out_params = []
        for i, vtype in enumerate(variable_types):
            param_i = params[i]
            if vtype.data_type in ["truncate_norm", "real"] and isinstance(
                param_i, NormalParameters
            ):
                mean_data, std_data = normalization_params[i]

                mean, std = param_i.loc, param_i.scale
                mean = mean * std_data + mean_data
                std = std * std_data

                out_params.append(NormalParameters(mean, std))
            elif vtype.data_type == "pos" and isinstance(
                param_i, LogNormalParameters
            ):
                mean_data, std_data = normalization_params[i]

                mean, std = param_i.loc, param_i.scale
                mean = mean * std_data + mean_data
                std = std * std_data

                out_params.append(LogNormalParameters(mean, std))
            elif vtype.data_type == "count":
                out_params.append(param_i)
            elif vtype.data_type == "cat":
                out_params.append(param_i)
            else:
                raise ValueError(f"Unknown data type {vtype.data_type}")
        return tuple(out_params)
denormalize_params(params, variable_types, normalization_params) staticmethod

Denormalize parameters based on variable types and normalization parameters.

Parameters:

Name Type Description Default
params Tuple[Parameters, ...]

The parameters to denormalize.

required
variable_types VarTypes

The variable types information.

required
normalization_params NormalizationParameters

The normalization parameters.

required

Returns:

Type Description
Tuple[Parameters, ...]

Tuple[Parameters, ...]: The denormalized parameters.

Source code in vambn/modelling/models/hivae/normalization.py
@staticmethod
def denormalize_params(
    params: Tuple[Parameters, ...],
    variable_types: VarTypes,
    normalization_params: NormalizationParameters,
) -> Tuple[Parameters, ...]:
    """
    Denormalize parameters based on variable types and normalization parameters.

    Args:
        params (Tuple[Parameters, ...]): The parameters to denormalize.
        variable_types (VarTypes): The variable types information.
        normalization_params (NormalizationParameters): The normalization parameters.

    Returns:
        Tuple[Parameters, ...]: The denormalized parameters.
    """
    out_params = []
    for i, vtype in enumerate(variable_types):
        param_i = params[i]
        if vtype.data_type in ["truncate_norm", "real"] and isinstance(
            param_i, NormalParameters
        ):
            mean_data, std_data = normalization_params[i]

            mean, std = param_i.loc, param_i.scale
            mean = mean * std_data + mean_data
            std = std * std_data

            out_params.append(NormalParameters(mean, std))
        elif vtype.data_type == "pos" and isinstance(
            param_i, LogNormalParameters
        ):
            mean_data, std_data = normalization_params[i]

            mean, std = param_i.loc, param_i.scale
            mean = mean * std_data + mean_data
            std = std * std_data

            out_params.append(LogNormalParameters(mean, std))
        elif vtype.data_type == "count":
            out_params.append(param_i)
        elif vtype.data_type == "cat":
            out_params.append(param_i)
        else:
            raise ValueError(f"Unknown data type {vtype.data_type}")
    return tuple(out_params)
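
For real-valued columns, denormalization simply undoes the standardization applied in `normalize_data`: the location is rescaled by the stored std and shifted by the stored mean, the scale is rescaled by the std. A small round-trip check with plain tensors (no vambn types involved):

import torch

x = torch.randn(100) * 3.0 + 7.0           # an "observed" real-valued column
mean_data, std_data = x.mean(), x.std()

x_norm = (x - mean_data) / std_data        # what the encoder sees

# pretend the decoder predicted these parameters in normalized space
loc_norm, scale_norm = x_norm.mean(), x_norm.std()

# denormalize as done for "real" variables above
loc = loc_norm * std_data + mean_data
scale = scale_norm * std_data

print(loc.item(), scale.item())            # close to 7.0 and 3.0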
normalize_data(x, mask, variable_types, prior_parameters, eps=1e-06) staticmethod

Normalize the input data based on variable types and prior parameters.

Parameters:

Name Type Description Default
x Tensor

The input data tensor.

required
mask Tensor

The mask tensor indicating missing values.

required
variable_types VarTypes

The variable types information.

required
prior_parameters NormalizationParameters

The prior normalization parameters.

required
eps float

A small value to prevent division by zero. Defaults to 1e-6.

1e-06

Returns:

Type Description
Tuple[Tensor, Tensor, NormalizationParameters]

Tuple[torch.Tensor, torch.Tensor, NormalizationParameters]: The normalized data, updated mask, and new normalization parameters.

Source code in vambn/modelling/models/hivae/normalization.py
@staticmethod
def normalize_data(
    x: torch.Tensor,
    mask: torch.Tensor,
    variable_types: VarTypes,
    prior_parameters: NormalizationParameters,
    eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor, NormalizationParameters]:
    """
    Normalize the input data based on variable types and prior parameters.

    Args:
        x (torch.Tensor): The input data tensor.
        mask (torch.Tensor): The mask tensor indicating missing values.
        variable_types (VarTypes): The variable types information.
        prior_parameters (NormalizationParameters): The prior normalization parameters.
        eps (float, optional): A small value to prevent division by zero. Defaults to 1e-6.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, NormalizationParameters]: The normalized data, updated mask, and new normalization parameters.
    """
    if x.ndim == 3 and x.shape[1] == 1:
        x = x.squeeze(1)
        mask = mask.squeeze(1)

    assert len(variable_types) == x.shape[-1]
    mean_data = prior_parameters.mean
    std_data = prior_parameters.std
    new_x = []
    for i, vtype in enumerate(variable_types):
        x_i = torch.masked_select(x[..., i], mask[..., i].bool())
        new_x_i = torch.unsqueeze(x[..., i], -1)

        if vtype.data_type == "real" or vtype.data_type == "truncate_norm":
            if x_i.shape[0] >= 4:
                mean_data[i] = x_i.mean()
                std_data[i] = x_i.std().clamp(min=eps, max=1e20)

            new_x_i = (new_x_i - mean_data[i]) / std_data[i]
        elif vtype.data_type == "pos":
            x_i = torch.log1p(x_i)
            if x_i.shape[0] >= 4:
                mean_data[i] = x_i.mean()
                std_data[i] = x_i.std().clamp(min=eps, max=1e20)

            new_x_i = (torch.log1p(new_x_i) - mean_data[i]) / std_data[i]
        elif vtype.data_type == "gamma":
            x_i = torch.log1p(x_i)
            if x_i.shape[0] >= 4:
                mean_data[i] = x_i.mean()
                std_data[i] = x_i.std().clamp(min=eps, max=1e20)

            new_x_i = (torch.log1p(new_x_i) - mean_data[i]) / std_data[i]
        elif vtype.data_type == "count":
            new_x_i = torch.log1p(new_x_i)
        elif vtype.data_type == "cat":
            # convert to one hot
            new_x_i = torch.nn.functional.one_hot(
                new_x_i.long().squeeze(1), vtype.n_parameters
            )

        if torch.isnan(new_x_i).any():
            raise ValueError(
                f"NaN values found in normalized data for {vtype}"
            )
        if torch.isnan(mean_data[i]) or torch.isnan(std_data[i]):
            raise ValueError(
                f"NaN values found in normalization parameters for {vtype}"
            )
        new_x.append(new_x_i)

    new_x = torch.cat(new_x, dim=-1)
    mask = Normalization._broadcast_mask(mask, variable_types)
    new_x = new_x * mask
    return (
        new_x,
        mask,
        NormalizationParameters.from_tensors(mean_data, std_data),
    )
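
The per-type transforms are standard: real/truncate_norm columns are standardized over the observed entries only, pos and gamma columns are log1p-transformed and then standardized, count columns only get log1p, and categorical columns are one-hot encoded. A standalone sketch of those transforms on one column per type (plain PyTorch, without the VarTypes/NormalizationParameters wrappers):

import torch
import torch.nn.functional as F

eps = 1e-6

# "real": standardize using only the observed entries
x_real = torch.tensor([1.0, 2.0, 0.0, 4.0, 5.0])
m_real = torch.tensor([1.0, 1.0, 0.0, 1.0, 1.0])        # third value is missing
obs = torch.masked_select(x_real, m_real.bool())
mean, std = obs.mean(), obs.std().clamp(min=eps)
real_norm = (x_real - mean) / std

# "pos" / "gamma": log1p first, then standardize the same way
x_pos = torch.tensor([0.5, 3.0, 10.0, 0.0, 2.0])
log_obs = torch.log1p(x_pos)
pos_norm = (torch.log1p(x_pos) - log_obs.mean()) / log_obs.std().clamp(min=eps)

# "count": only log1p
x_count = torch.tensor([0.0, 1.0, 7.0, 2.0, 30.0])
count_norm = torch.log1p(x_count)

# "cat": one-hot with n_parameters categories
x_cat = torch.tensor([0, 2, 1, 2, 0])
cat_onehot = F.one_hot(x_cat, num_classes=3).float()

print(real_norm.shape, pos_norm.shape, count_norm.shape, cat_onehot.shape)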
NormalizationParameters dataclass

Data class for normalization parameters, including mean and standard deviation.

This class is only used for the parameters of real and pos typed variables.

Parameters:

Name Type Description Default
mean Tensor

The mean values for normalization.

required
std Tensor

The standard deviation values for normalization.

required
Source code in vambn/modelling/models/hivae/normalization.py
@dataclass
class NormalizationParameters:
    """
    Data class for normalization parameters, including mean and standard deviation.

    This class is only used for the parameters of real and pos typed variables.

    Args:
        mean (torch.Tensor): The mean values for normalization.
        std (torch.Tensor): The standard deviation values for normalization.
    """

    mean: torch.Tensor
    std: torch.Tensor

    @classmethod
    def from_tensors(
        cls, mean: torch.Tensor, std: torch.Tensor
    ) -> "NormalizationParameters":
        """
        Create NormalizationParameters from mean and std tensors.

        Args:
            mean (torch.Tensor): The mean tensor.
            std (torch.Tensor): The standard deviation tensor.

        Returns:
            NormalizationParameters: An instance of NormalizationParameters.
        """
        return NormalizationParameters(mean, std)

    def __getitem__(
        self, idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor] | "NormalizationParameters":  # type: ignore
        """
        Get normalization parameters by index.

        Args:
            idx (int): The index to retrieve parameters for.

        Returns:
            Tuple[torch.Tensor, torch.Tensor] | NormalizationParameters: The mean and std tensors or a new NormalizationParameters instance.
        """
        if self.mean.ndim == 1:
            return self.mean[idx], self.std[idx]
        elif self.mean.ndim == 2:
            mean = self.mean[idx, :]
            std = self.std[idx, :]
            return NormalizationParameters(mean, std)

    def __setitem__(
        self, idx: int, value: Tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """
        Set normalization parameters by index.

        Args:
            idx (int): The index to set parameters for.
            value (Tuple[torch.Tensor, torch.Tensor]): The mean and std tensors to set.
        """
        self.mean[idx], self.std[idx] = value

    def __add__(
        self, other: "NormalizationParameters"
    ) -> "NormalizationParameters":
        """
        Add two NormalizationParameters instances.

        Args:
            other (NormalizationParameters): Another instance to add.

        Returns:
            NormalizationParameters: A new instance with combined parameters.
        """
        if self.mean.ndim == 1:
            # create a 2d tensor by stacking the tensors
            mean = torch.stack([self.mean, other.mean])
            std = torch.stack([self.std, other.std])
        elif self.mean.ndim == 2:
            mean = torch.cat([self.mean, other.mean.unsqueeze(0)], dim=0)
            std = torch.cat([self.std, other.std.unsqueeze(0)], dim=0)
        else:
            raise ValueError("Invalid dimension for mean tensor")

        return NormalizationParameters(mean, std)
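A short usage sketch of the container (assumes NormalizationParameters is importable from vambn.modelling.models.hivae.normalization, as indicated by the source path above; the tensor values are toy data):

import torch

from vambn.modelling.models.hivae.normalization import NormalizationParameters

# Toy per-feature statistics for two hypothetical modules.
params_a = NormalizationParameters.from_tensors(
    mean=torch.tensor([0.0, 1.0]), std=torch.tensor([1.0, 2.0])
)
params_b = NormalizationParameters.from_tensors(
    mean=torch.tensor([0.5, 1.5]), std=torch.tensor([1.5, 2.5])
)

mean_0, std_0 = params_a[0]      # 1D parameters: indexing returns a (mean, std) pair
stacked = params_a + params_b    # 1D + 1D: stacks into a 2D parameter set
row_0 = stacked[0]               # 2D parameters: indexing returns a sliced instance
print(mean_0, std_0, stacked.mean.shape, row_0.mean)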
__add__(other)

Add two NormalizationParameters instances.

Parameters:

Name Type Description Default
other NormalizationParameters

Another instance to add.

required

Returns:

Name Type Description
NormalizationParameters NormalizationParameters

A new instance with combined parameters.

Source code in vambn/modelling/models/hivae/normalization.py
def __add__(
    self, other: "NormalizationParameters"
) -> "NormalizationParameters":
    """
    Add two NormalizationParameters instances.

    Args:
        other (NormalizationParameters): Another instance to add.

    Returns:
        NormalizationParameters: A new instance with combined parameters.
    """
    if self.mean.ndim == 1:
        # create a 2d tensor by stacking the tensors
        mean = torch.stack([self.mean, other.mean])
        std = torch.stack([self.std, other.std])
    elif self.mean.ndim == 2:
        mean = torch.cat([self.mean, other.mean.unsqueeze(0)], dim=0)
        std = torch.cat([self.std, other.std.unsqueeze(0)], dim=0)
    else:
        raise ValueError("Invalid dimension for mean tensor")

    return NormalizationParameters(mean, std)
__getitem__(idx)

Get normalization parameters by index.

Parameters:

Name Type Description Default
idx int

The index to retrieve parameters for.

required

Returns:

Type Description
Tuple[Tensor, Tensor] | NormalizationParameters

Tuple[torch.Tensor, torch.Tensor] | NormalizationParameters: The mean and std tensors or a new NormalizationParameters instance.

Source code in vambn/modelling/models/hivae/normalization.py
def __getitem__(
    self, idx: int
) -> Tuple[torch.Tensor, torch.Tensor] | "NormalizationParameters":  # type: ignore
    """
    Get normalization parameters by index.

    Args:
        idx (int): The index to retrieve parameters for.

    Returns:
        Tuple[torch.Tensor, torch.Tensor] | NormalizationParameters: The mean and std tensors or a new NormalizationParameters instance.
    """
    if self.mean.ndim == 1:
        return self.mean[idx], self.std[idx]
    elif self.mean.ndim == 2:
        mean = self.mean[idx, :]
        std = self.std[idx, :]
        return NormalizationParameters(mean, std)
__setitem__(idx, value)

Set normalization parameters by index.

Parameters:

Name Type Description Default
idx int

The index to set parameters for.

required
value Tuple[Tensor, Tensor]

The mean and std tensors to set.

required
Source code in vambn/modelling/models/hivae/normalization.py
def __setitem__(
    self, idx: int, value: Tuple[torch.Tensor, torch.Tensor]
) -> None:
    """
    Set normalization parameters by index.

    Args:
        idx (int): The index to set parameters for.
        value (Tuple[torch.Tensor, torch.Tensor]): The mean and std tensors to set.
    """
    self.mean[idx], self.std[idx] = value
from_tensors(mean, std) classmethod

Create NormalizationParameters from mean and std tensors.

Parameters:

Name Type Description Default
mean Tensor

The mean tensor.

required
std Tensor

The standard deviation tensor.

required

Returns:

Name Type Description
NormalizationParameters NormalizationParameters

An instance of NormalizationParameters.

Source code in vambn/modelling/models/hivae/normalization.py
@classmethod
def from_tensors(
    cls, mean: torch.Tensor, std: torch.Tensor
) -> "NormalizationParameters":
    """
    Create NormalizationParameters from mean and std tensors.

    Args:
        mean (torch.Tensor): The mean tensor.
        std (torch.Tensor): The standard deviation tensor.

    Returns:
        NormalizationParameters: An instance of NormalizationParameters.
    """
    return NormalizationParameters(mean, std)

outputs

DecoderOutput dataclass

Dataclass to hold the output from the decoder.

Attributes:

Name Type Description
log_p_x Tensor

Log-likelihood of the data.

kl_z Tensor

KL divergence for the z variable.

kl_s Tensor

KL divergence for the s variable.

corr_loss Tensor

Correlation loss if applicable.

recon Optional[Tensor]

Reconstruction, if any.

samples Optional[Tensor]

Samples generated, if any.

enc_s Optional[Tensor]

Encoded s values, if any.

enc_z Optional[Tensor]

Encoded z values, if any.

output_name Optional[str]

Name of the output, if any.

detached bool

Whether the tensors have been detached.

Source code in vambn/modelling/models/hivae/outputs.py
@dataclass
class DecoderOutput:
    """Dataclass to hold the output from the decoder.

    Attributes:
        log_p_x (Tensor): Log-likelihood of the data.
        kl_z (Tensor): KL divergence for the z variable.
        kl_s (Tensor): KL divergence for the s variable.
        corr_loss (Tensor): Correlation loss if applicable.
        recon (Optional[Tensor]): Reconstruction, if any.
        samples (Optional[Tensor]): Samples generated, if any.
        enc_s (Optional[Tensor]): Encoded s values, if any.
        enc_z (Optional[Tensor]): Encoded z values, if any.
        output_name (Optional[str]): Name of the output, if any.
        detached (bool): Whether the tensors have been detached.
    """


    log_p_x: Tensor
    kl_z: Tensor
    kl_s: Tensor
    corr_loss: Tensor
    recon: Optional[Tensor] = None
    samples: Optional[Tensor] = None
    enc_s: Optional[Tensor] = None
    enc_z: Optional[Tensor] = None
    output_name: Optional[str] = None
    detached: bool = False

    def __post_init__(self):
        """
        Validate dimensions of the log-likelihood tensor.

        Raises:
            Exception: If log-likelihood tensor is not of dimension 1.
        """
        if self.log_p_x.ndim != 1:
            raise Exception(
                f"Log-likelihood is not of correct dimension ({self.log_p_x.ndim}, expected 1)"
            )

    def __str__(self) -> str:
        """
        String representation of the DecoderOutput object.

        Returns:
            str: A string describing the DecoderOutput object.
        """

        if self.output_name is not None:
            return f"Decoder output for {self.output_name})"
        else:
            return f"Decoder output (id={id(self)})"

    def detach_and_move(self) -> "DecoderOutput":
        """
        Detach all tensors and move them to CPU.

        Returns:
            DecoderOutput: The detached DecoderOutput object.
        """

        self.detached = True
        self.log_p_x = self.log_p_x.detach().cpu()
        self.kl_z = self.kl_z.detach().cpu()
        self.kl_s = self.kl_s.detach().cpu()
        if self.corr_loss is not None:
            self.corr_loss = self.corr_loss.detach().cpu()
        if self.samples is not None:
            self.samples = self.samples.detach().cpu()
        if self.enc_s is not None:
            self.enc_s = self.enc_s.detach().cpu()
        if self.enc_z is not None:
            self.enc_z = self.enc_z.detach().cpu()
        return self

    @property
    def elbo(self) -> Tensor:
        """
        Calculate the Evidence Lower Bound (ELBO).

        Returns:
            Tensor: The ELBO.
        """

        return self.log_p_x - self.kl_z - self.kl_s

    @property
    def loss(self) -> Tensor:
        """
        Calculate the loss based on the negative ELBO.

        Returns:
            Tensor: The loss tensor.

        Raises:
            Exception: If tensors have been detached.
        """

        if self.detached:
            logger.error("Cannot calculate loss. Tensors have been detached.")
            raise Exception(
                "Cannot calculate loss. Tensors have been detached."
            )
        loss = -self.elbo.sum()
        return loss
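The elbo and loss properties only perform light arithmetic on the stored tensors; the sketch below mirrors that relationship with plain tensors (shapes and values are illustrative):

import torch

# Mirrors the elbo/loss properties with plain tensors.
log_p_x = torch.tensor([-1.2, -0.8, -1.0])   # per-sample log-likelihood
kl_z = torch.tensor([0.3, 0.2, 0.4])
kl_s = torch.tensor([0.1, 0.1, 0.1])

elbo = log_p_x - kl_z - kl_s                 # what the elbo property computes
loss = -elbo.sum()                           # what the loss property computes
print(elbo, loss)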
elbo: Tensor property

Calculate the Evidence Lower Bound (ELBO).

Returns:

Name Type Description
Tensor Tensor

The ELBO.

loss: Tensor property

Calculate the loss based on the negative ELBO.

Returns:

Name Type Description
Tensor Tensor

The loss tensor.

Raises:

Type Description
Exception

If tensors have been detached.

__post_init__()

Validate dimensions of the log-likelihood tensor.

Raises:

Type Description
Exception

If log-likelihood tensor is not of dimension 1.

Source code in vambn/modelling/models/hivae/outputs.py
def __post_init__(self):
    """
    Validate dimensions of the log-likelihood tensor.

    Raises:
        Exception: If log-likelihood tensor is not of dimension 1.
    """
    if self.log_p_x.ndim != 1:
        raise Exception(
            f"Log-likelihood is not of correct dimension ({self.log_p_x.ndim}, expected 1)"
        )
__str__()

String representation of the DecoderOutput object.

Returns:

Name Type Description
str str

A string describing the DecoderOutput object.

Source code in vambn/modelling/models/hivae/outputs.py
def __str__(self) -> str:
    """
    String representation of the DecoderOutput object.

    Returns:
        str: A string describing the DecoderOutput object.
    """

    if self.output_name is not None:
        return f"Decoder output for {self.output_name})"
    else:
        return f"Decoder output (id={id(self)})"
detach_and_move()

Detach all tensors and move them to CPU.

Returns:

Name Type Description
DecoderOutput DecoderOutput

The detached DecoderOutput object.

Source code in vambn/modelling/models/hivae/outputs.py
def detach_and_move(self) -> "DecoderOutput":
    """
    Detach all tensors and move them to CPU.

    Returns:
        DecoderOutput: The detached DecoderOutput object.
    """

    self.detached = True
    self.log_p_x = self.log_p_x.detach().cpu()
    self.kl_z = self.kl_z.detach().cpu()
    self.kl_s = self.kl_s.detach().cpu()
    if self.corr_loss is not None:
        self.corr_loss = self.corr_loss.detach().cpu()
    if self.samples is not None:
        self.samples = self.samples.detach().cpu()
    if self.enc_s is not None:
        self.enc_s = self.enc_s.detach().cpu()
    if self.enc_z is not None:
        self.enc_z = self.enc_z.detach().cpu()
    return self
EncoderOutput dataclass

Dataclass for encoder output.

Attributes:

Name Type Description
samples_s Tensor

Samples from the s distribution.

logits_s Tensor

Logits for the s distribution.

mean_z Tensor

Mean of the z distribution.

scale_z Tensor

Scale of the z distribution.

samples_z Optional[Tensor]

Samples from the z distribution.

h_representation Optional[Tensor]

Hidden representation, if any.

Source code in vambn/modelling/models/hivae/outputs.py
@dataclass
class EncoderOutput:
    """Dataclass for encoder output.

    Attributes:
        samples_s (Tensor): Samples from the s distribution.
        logits_s (Tensor): Logits for the s distribution.
        mean_z (Tensor): Mean of the z distribution.
        scale_z (Tensor): Scale of the z distribution.
        samples_z (Optional[Tensor]): Samples from the z distribution.
        h_representation (Optional[Tensor]): Hidden representation, if any.
    """


    samples_s: Tensor
    logits_s: Tensor
    mean_z: Tensor
    scale_z: Tensor
    samples_z: Optional[Tensor]
    h_representation: Optional[Tensor] = None

    @property
    def decoder_representation(self) -> Tensor:
        """
        Get the decoder representation.

        Returns:
            Tensor: The hidden representation if available, otherwise the samples from the z distribution.
        """
        return (
            self.h_representation
            if self.h_representation is not None
            else self.samples_z
        )

    @decoder_representation.setter
    def decoder_representation(self, value: Tensor) -> None:
        """
        Set the decoder representation.

        Args:
            value (Tensor): The value to set as the hidden representation.
        """
        # if value.shape != self.samples_z.shape:
        #     raise ValueError(
        #         f"Shape of value ({value.shape}) does not match shape of samples_z ({self.samples_z.shape})"
        #     )

        self.h_representation = value
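A small sketch of the decoder_representation fallback (assumes EncoderOutput is importable from vambn.modelling.models.hivae.outputs, as indicated by the source path; shapes are toy values):

import torch

from vambn.modelling.models.hivae.outputs import EncoderOutput

enc = EncoderOutput(
    samples_s=torch.zeros(4, 3),
    logits_s=torch.zeros(4, 3),
    mean_z=torch.zeros(4, 2),
    scale_z=torch.ones(4, 2),
    samples_z=torch.randn(4, 2),
)
print(enc.decoder_representation is enc.samples_z)   # True: falls back to samples_z
enc.decoder_representation = torch.randn(4, 2)       # stored as h_representation
print(enc.decoder_representation is enc.samples_z)   # False: h_representation is used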
decoder_representation: Tensor property writable

Get the decoder representation.

Returns:

Name Type Description
Tensor Tensor

The hidden representation if available, otherwise the samples from the z distribution.

HivaeEncoding dataclass

Dataclass for HIVAE encoding.

Attributes:

Name Type Description
s Tensor

Encoding for s.

z Tensor

Encoding for z.

module str

Module name.

samples Optional[Tensor]

Samples generated, if any.

subjid Optional[List[str | int]]

Subject IDs.

h_representation Optional[Tensor]

Hidden representation, if any.

Source code in vambn/modelling/models/hivae/outputs.py
@dataclass
class HivaeEncoding:
    """Dataclass for HIVAE encoding.

    Attributes:
        s (torch.Tensor): Encoding for s.
        z (torch.Tensor): Encoding for z.
        module (str): Module name.
        samples (Optional[Tensor]): Samples generated, if any.
        subjid (Optional[List[str | int]]): Subject IDs.
        h_representation (Optional[Tensor]): Hidden representation, if any.
    """

    s: torch.Tensor
    z: torch.Tensor
    module: str
    samples: Optional[Tensor] = None
    subjid: Optional[List[str | int]] = None
    h_representation: Optional[Tensor] = None

    def __post_init__(self):
        """
        Initialize the encoding and ensure tensors are on CPU and have the correct dtype.
        """

        if self.s.device != torch.device("cpu"):
            self.s = self.s.cpu()
        if self.z.device != torch.device("cpu"):
            self.z = self.z.cpu()
        if self.samples is not None and self.samples.device != torch.device(
            "cpu"
        ):
            self.samples = self.samples.cpu()

        if self.s.dtype != torch.float32:
            self.s = self.s.float()

        # make sure z encoding is float
        if self.z.dtype != torch.float32:
            self.z = self.z.float()

    @property
    def decoder_representation(self) -> Tensor:
        """
        Get the decoder representation.

        Returns:
            Tensor: The hidden representation if available, otherwise the z encoding.
        """
        return (
            self.h_representation
            if self.h_representation is not None
            else self.z
        )

    @decoder_representation.setter
    def decoder_representation(self, value: Tensor) -> None:
        """
        Set the decoder representation.

        Args:
            value (Tensor): The value to set as the hidden representation.
        """
        self.h_representation = value

    def convert(self) -> Dict[str, List[str | float]]:
        """
        Convert the encoding to a dictionary format.

        Returns:
            Dict[str, List[str | float]]: The converted encoding.
        """
        data = {
            f"{self.module}_s": self.s.argmax(dim=1).tolist(),
            "SUBJID": self.subjid,
        }
        if self.z.ndim == 2 and self.z.shape[1] > 1:
            for i in range(self.z.shape[1]):
                data[f"{self.module}_z{i}"] = self.z[:, i].tolist()
        else:
            data[f"{self.module}_z"] = self.z.view(-1).tolist()
        return data

    def get_samples(self, module: Optional[str] = None) -> Tensor:
        """
        Get the samples for the specified module.

        Args:
            module (Optional[str]): The module name. If None, return all samples.

        Returns:
            Tensor: The samples tensor.
        """
        if module is None:
            return self.samples
        else:
            assert self.module is not None
            assert module == self.module
            return self.samples

    @typeguard.typechecked
    def save_meta_enc(self, path: Path):
        """
        Save the metadata encoding to a CSV file.

        Args:
            path (Path): The file path to save the metadata.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        converted = self.convert()
        df = pd.DataFrame(converted)
        df.to_csv(path, index=False)
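A hedged usage sketch of convert (assumes HivaeEncoding is importable from vambn.modelling.models.hivae.outputs; the module name "clinical" and the subject IDs are made-up examples):

import torch

from vambn.modelling.models.hivae.outputs import HivaeEncoding

encoding = HivaeEncoding(
    s=torch.tensor([[0.9, 0.1], [0.2, 0.8]]),   # per-subject s scores (toy values)
    z=torch.tensor([[0.3, -0.5], [1.1, 0.0]]),  # two z dimensions per subject
    module="clinical",                          # hypothetical module name
    subjid=["subj-1", "subj-2"],
)
print(encoding.convert())
# Expected keys: "clinical_s", "SUBJID", "clinical_z0", "clinical_z1"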
decoder_representation: Tensor property writable

Get the decoder representation.

Returns:

Name Type Description
Tensor Tensor

The hidden representation if available, otherwise the z encoding.

__post_init__()

Initialize the encoding and ensure tensors are on CPU and have the correct dtype.

Source code in vambn/modelling/models/hivae/outputs.py
def __post_init__(self):
    """
    Initialize the encoding and ensure tensors are on CPU and have the correct dtype.
    """

    if self.s.device != torch.device("cpu"):
        self.s = self.s.cpu()
    if self.z.device != torch.device("cpu"):
        self.z = self.z.cpu()
    if self.samples is not None and self.samples.device != torch.device(
        "cpu"
    ):
        self.samples = self.samples.cpu()

    if self.s.dtype != torch.float32:
        self.s = self.s.float()

    # make sure z encoding is float
    if self.z.dtype != torch.float32:
        self.z = self.z.float()
convert()

Convert the encoding to a dictionary format.

Returns:

Type Description
Dict[str, List[str | float]]

Dict[str, List[str | float]]: The converted encoding.

Source code in vambn/modelling/models/hivae/outputs.py
def convert(self) -> Dict[str, List[str | float]]:
    """
    Convert the encoding to a dictionary format.

    Returns:
        Dict[str, List[str | float]]: The converted encoding.
    """
    data = {
        f"{self.module}_s": self.s.argmax(dim=1).tolist(),
        "SUBJID": self.subjid,
    }
    if self.z.ndim == 2 and self.z.shape[1] > 1:
        for i in range(self.z.shape[1]):
            data[f"{self.module}_z{i}"] = self.z[:, i].tolist()
    else:
        data[f"{self.module}_z"] = self.z.view(-1).tolist()
    return data
get_samples(module=None)

Get the samples for the specified module.

Parameters:

Name Type Description Default
module Optional[str]

The module name. If None, return all samples.

None

Returns:

Name Type Description
Tensor Tensor

The samples tensor.

Source code in vambn/modelling/models/hivae/outputs.py
def get_samples(self, module: Optional[str] = None) -> Tensor:
    """
    Get the samples for the specified module.

    Args:
        module (Optional[str]): The module name. If None, return all samples.

    Returns:
        Tensor: The samples tensor.
    """
    if module is None:
        return self.samples
    else:
        assert self.module is not None
        assert module == self.module
        return self.samples
save_meta_enc(path)

Save the metadata encoding to a CSV file.

Parameters:

Name Type Description Default
path Path

The file path to save the metadata.

required
Source code in vambn/modelling/models/hivae/outputs.py
@typeguard.typechecked
def save_meta_enc(self, path: Path):
    """
    Save the metadata encoding to a CSV file.

    Args:
        path (Path): The file path to save the metadata.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    converted = self.convert()
    df = pd.DataFrame(converted)
    df.to_csv(path, index=False)
HivaeOutput dataclass

Dataclass for HIVAE output.

Attributes:

Name Type Description
loss Tensor

Loss tensor.

enc_z Tensor

Encoded z values.

enc_s Tensor

Encoded s values.

samples Optional[Tensor]

Samples generated, if any.

n Optional[int]

Number of samples.

single bool

Whether this is a single output.

Source code in vambn/modelling/models/hivae/outputs.py
@dataclass
class HivaeOutput:
    """Dataclass for HIVAE output.

    Attributes:
        loss (Tensor): Loss tensor.
        enc_z (Tensor): Encoded z values.
        enc_s (Tensor): Encoded s values.
        samples (Optional[Tensor]): Samples generated, if any.
        n (Optional[int]): Number of samples.
        single (bool): Whether this is a single output.
    """

    loss: Tensor
    enc_z: Tensor
    enc_s: Tensor
    samples: Optional[Tensor] = None
    n: Optional[int] = None
    single: bool = True

    def __post_init__(self):
        """
        Initialize the number of samples.
        """
        self.n = self.enc_z.shape[0]

    @property
    def n_loss(self) -> int:
        """
        Get the number of loss values.

        Returns:
            int: Number of loss values.
        """
        return 1 if self.loss.ndim == 0 else self.loss.shape[0]

    def detach(self) -> "HivaeOutput":
        """
        Detach all tensors and move them to CPU.

        Returns:
            HivaeOutput: The detached HivaeOutput object.
        """
        self.loss = self.loss.detach().cpu()
        self.enc_z = self.enc_z.detach().cpu()
        self.enc_s = self.enc_s.detach().cpu()

        if self.samples is not None:
            self.samples = self.samples.detach().cpu()
        return self

    @property
    def avg_loss(self) -> float:
        """
        Calculate the average loss.

        Returns:
            float: The average loss.

        Raises:
            ValueError: If the loss tensor has an invalid dimension.
        """
        if self.loss.ndim == 0:
            return float(self.loss)
        elif self.loss.ndim == 1:
            return float(self.loss.mean())
        else:
            raise ValueError(
                f"Loss is of wrong dimension ({self.loss.ndim}), expected 0 or 1"
            )

    def stack(self, other: "HivaeOutput") -> "HivaeOutput":
        """
        Stack another HivaeOutput object with this one.

        Args:
            other (HivaeOutput): The other HivaeOutput object to stack.

        Returns:
            HivaeOutput: The stacked HivaeOutput object.
        """

        self.single = False
        self.loss = torch.cat([self.loss.view(-1), other.loss.view(-1)])
        self.enc_z = torch.cat([self.enc_z, other.enc_z])
        self.enc_s = torch.cat([self.enc_s, other.enc_s])

        if self.samples is not None:
            self.samples = torch.cat([self.samples, other.samples])
        self.n += other.n
        return self

    def __add__(self, other: "HivaeOutput") -> "HivaeOutput":
        """
        Add another HivaeOutput object to this one.

        Args:
            other (HivaeOutput): The other HivaeOutput object to add.

        Returns:
            HivaeOutput: The resulting HivaeOutput object.
        """

        return self.stack(other)
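A sketch of how two batches are combined via stack / __add__ (assumes HivaeOutput is importable from vambn.modelling.models.hivae.outputs; shapes and loss values are toy data):

import torch

from vambn.modelling.models.hivae.outputs import HivaeOutput

batch_a = HivaeOutput(
    loss=torch.tensor(1.5), enc_z=torch.randn(8, 2), enc_s=torch.randn(8, 3)
)
batch_b = HivaeOutput(
    loss=torch.tensor(1.2), enc_z=torch.randn(4, 2), enc_s=torch.randn(4, 3)
)
combined = batch_a + batch_b                           # delegates to stack()
print(combined.n, combined.n_loss, combined.avg_loss)  # 12 samples, 2 losses, mean loss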
avg_loss: float property

Calculate the average loss.

Returns:

Name Type Description
float float

The average loss.

Raises:

Type Description
ValueError

If the loss tensor has an invalid dimension.

n_loss: int property

Get the number of loss values.

Returns:

Name Type Description
int int

Number of loss values.

__add__(other)

Add another HivaeOutput object to this one.

Parameters:

Name Type Description Default
other HivaeOutput

The other HivaeOutput object to add.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

The resulting HivaeOutput object.

Source code in vambn/modelling/models/hivae/outputs.py
def __add__(self, other: "HivaeOutput") -> "HivaeOutput":
    """
    Add another HivaeOutput object to this one.

    Args:
        other (HivaeOutput): The other HivaeOutput object to add.

    Returns:
        HivaeOutput: The resulting HivaeOutput object.
    """

    return self.stack(other)
__post_init__()

Initialize the number of samples.

Source code in vambn/modelling/models/hivae/outputs.py
def __post_init__(self):
    """
    Initialize the number of samples.
    """
    self.n = self.enc_z.shape[0]
detach()

Detach all tensors and move them to CPU.

Returns:

Name Type Description
HivaeOutput HivaeOutput

The detached HivaeOutput object.

Source code in vambn/modelling/models/hivae/outputs.py
def detach(self) -> "HivaeOutput":
    """
    Detach all tensors and move them to CPU.

    Returns:
        HivaeOutput: The detached HivaeOutput object.
    """
    self.loss = self.loss.detach().cpu()
    self.enc_z = self.enc_z.detach().cpu()
    self.enc_s = self.enc_s.detach().cpu()

    if self.samples is not None:
        self.samples = self.samples.detach().cpu()
    return self
stack(other)

Stack another HivaeOutput object with this one.

Parameters:

Name Type Description Default
other HivaeOutput

The other HivaeOutput object to stack.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

The stacked HivaeOutput object.

Source code in vambn/modelling/models/hivae/outputs.py
def stack(self, other: "HivaeOutput") -> "HivaeOutput":
    """
    Stack another HivaeOutput object with this one.

    Args:
        other (HivaeOutput): The other HivaeOutput object to stack.

    Returns:
        HivaeOutput: The stacked HivaeOutput object.
    """

    self.single = False
    self.loss = torch.cat([self.loss.view(-1), other.loss.view(-1)])
    self.enc_z = torch.cat([self.enc_z, other.enc_z])
    self.enc_s = torch.cat([self.enc_s, other.enc_s])

    if self.samples is not None:
        self.samples = torch.cat([self.samples, other.samples])
    self.n += other.n
    return self
LogLikelihoodOutput dataclass

Dataclass to hold the output from log-likelihood functions.

Attributes:

Name Type Description
log_p_x Tensor

Log-likelihood for observed data.

log_p_x_missing Tensor

Log-likelihood for missing data.

samples Optional[Tensor]

Samples generated, if any.

Source code in vambn/modelling/models/hivae/outputs.py
@dataclass
class LogLikelihoodOutput:
    """Dataclass to hold the output from log-likelihood functions.

    Attributes:
        log_p_x (Tensor): Log-likelihood for observed data.
        log_p_x_missing (Tensor): Log-likelihood for missing data.
        samples (Optional[Tensor]): Samples generated, if any.
    """


    log_p_x: Tensor
    log_p_x_missing: Tensor
    samples: Optional[Tensor] = None
LstmHivaeOutput dataclass

Bases: HivaeOutput

Dataclass for LSTM HIVAE output.

Source code in vambn/modelling/models/hivae/outputs.py
@dataclass
class LstmHivaeOutput(HivaeOutput):
    """Dataclass for LSTM HIVAE output."""

    pass
ModularHivaeEncoding dataclass

Dataclass for modular HIVAE encoding.

Attributes:

Name Type Description
encodings Tuple[HivaeEncoding, ...]

Tuple of HIVAE encodings. See HivaeEncoding for details.

modules List[str]

List of module names in the same order as encodings.

Source code in vambn/modelling/models/hivae/outputs.py
@dataclass
class ModularHivaeEncoding:
    """Dataclass for modular HIVAE encoding.

    Attributes:
        encodings (Tuple[HivaeEncoding, ...]): Tuple of HIVAE encodings. See HivaeEncoding for details.
        modules (List[str]): List of module names in the same order as encodings.
    """
    encodings: Tuple[HivaeEncoding, ...]
    modules: List[str]

    def __post_init__(self):
        """
        Initialize the modular encoding and ensure tensors are on CPU and have the correct dtype.

        Raises:
            Exception: If modules in encodings do not match modules in ModularHivaeEncoding.
        """
        for encoding in self.encodings:
            if encoding.s.device != torch.device("cpu"):
                encoding.s = encoding.s.cpu()
            if encoding.z.device != torch.device("cpu"):
                encoding.z = encoding.z.cpu()
            if (
                encoding.samples is not None
                and encoding.samples.device != torch.device("cpu")
            ):
                encoding.samples = encoding.samples.cpu()

            if any(
                [x.module != y for x, y in zip(self.encodings, self.modules)]
            ):
                raise Exception(
                    "Modules in encodings do not match modules in ModularHivaeEncoding"
                )

    def convert(self) -> Dict[str, List[float | str]]:
        """
        Convert the modular encoding to a dictionary format.

        Returns:
            Dict[str, List[float | str]]: The converted encoding.
        """
        out = {}
        for encoding in self.encodings:
            data = encoding.convert()
            out.update(data)
        return out

    @typeguard.typechecked
    def save_meta_enc(self, path: Path):
        """
        Save the metadata encoding to a CSV file.

        Args:
            path (Path): The file path to save the metadata.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        converted = self.convert()
        df = pd.DataFrame(converted)
        df.to_csv(path, index=False)

    def get_samples(
        self, module: Optional[str] = None
    ) -> Tensor | Dict[str, Tensor]:
        """
        Get the samples for the specified module.

        Args:
            module (Optional[str]): The module name. If None, return all samples.

        Returns:
            Tensor | Dict[str, Tensor]: The samples tensor or a dictionary of samples.
        """
        if module is None:
            return {x.module: x.samples for x in self.encodings}
        else:
            selected_encodings = [
                x for x in self.encodings if x.module == module
            ]
            assert len(selected_encodings) == 1
            return selected_encodings[0].samples

    def __getitem__(self, idx: int) -> HivaeEncoding:
        """
        Get an encoding by index.

        Args:
            idx (int): The index of the encoding to retrieve.

        Returns:
            HivaeEncoding: The encoding at the specified index.
        """
        return self.encodings[idx]

    def get(self, module: str) -> HivaeEncoding:
        """
        Get an encoding by module name.

        Args:
            module (str): The module name.

        Returns:
            HivaeEncoding: The encoding for the specified module.

        Raises:
            Exception: If the module is not found in encodings.
        """
        for encoding in self.encodings:
            if encoding.module == module:
                return encoding
        raise Exception(f"Module {module} not found in encodings")
__getitem__(idx)

Get an encoding by index.

Parameters:

Name Type Description Default
idx int

The index of the encoding to retrieve.

required

Returns:

Name Type Description
HivaeEncoding HivaeEncoding

The encoding at the specified index.

Source code in vambn/modelling/models/hivae/outputs.py
def __getitem__(self, idx: int) -> HivaeEncoding:
    """
    Get an encoding by index.

    Args:
        idx (int): The index of the encoding to retrieve.

    Returns:
        HivaeEncoding: The encoding at the specified index.
    """
    return self.encodings[idx]
__post_init__()

Initialize the modular encoding and ensure tensors are on CPU and have the correct dtype.

Raises:

Type Description
Exception

If modules in encodings do not match modules in ModularHivaeEncoding.

Source code in vambn/modelling/models/hivae/outputs.py
def __post_init__(self):
    """
    Initialize the modular encoding and ensure tensors are on CPU and have the correct dtype.

    Raises:
        Exception: If modules in encodings do not match modules in ModularHivaeEncoding.
    """
    for encoding in self.encodings:
        if encoding.s.device != torch.device("cpu"):
            encoding.s = encoding.s.cpu()
        if encoding.z.device != torch.device("cpu"):
            encoding.z = encoding.z.cpu()
        if (
            encoding.samples is not None
            and encoding.samples.device != torch.device("cpu")
        ):
            encoding.samples = encoding.samples.cpu()

        if any(
            [x.module != y for x, y in zip(self.encodings, self.modules)]
        ):
            raise Exception(
                "Modules in encodings do not match modules in ModularHivaeEncoding"
            )
convert()

Convert the modular encoding to a dictionary format.

Returns:

Type Description
Dict[str, List[float | str]]

Dict[str, List[float | str]]: The converted encoding.

Source code in vambn/modelling/models/hivae/outputs.py
def convert(self) -> Dict[str, List[float | str]]:
    """
    Convert the modular encoding to a dictionary format.

    Returns:
        Dict[str, List[float | str]]: The converted encoding.
    """
    out = {}
    for encoding in self.encodings:
        data = encoding.convert()
        out.update(data)
    return out
get(module)

Get an encoding by module name.

Parameters:

Name Type Description Default
module str

The module name.

required

Returns:

Name Type Description
HivaeEncoding HivaeEncoding

The encoding for the specified module.

Raises:

Type Description
Exception

If the module is not found in encodings.

Source code in vambn/modelling/models/hivae/outputs.py
def get(self, module: str) -> HivaeEncoding:
    """
    Get an encoding by module name.

    Args:
        module (str): The module name.

    Returns:
        HivaeEncoding: The encoding for the specified module.

    Raises:
        Exception: If the module is not found in encodings.
    """
    for encoding in self.encodings:
        if encoding.module == module:
            return encoding
    raise Exception(f"Module {module} not found in encodings")
get_samples(module=None)

Get the samples for the specified module.

Parameters:

Name Type Description Default
module Optional[str]

The module name. If None, return all samples.

None

Returns:

Type Description
Tensor | Dict[str, Tensor]

Tensor | Dict[str, Tensor]: The samples tensor or a dictionary of samples.

Source code in vambn/modelling/models/hivae/outputs.py
def get_samples(
    self, module: Optional[str] = None
) -> Tensor | Dict[str, Tensor]:
    """
    Get the samples for the specified module.

    Args:
        module (Optional[str]): The module name. If None, return all samples.

    Returns:
        Tensor | Dict[str, Tensor]: The samples tensor or a dictionary of samples.
    """
    if module is None:
        return {x.module: x.samples for x in self.encodings}
    else:
        selected_encodings = [
            x for x in self.encodings if x.module == module
        ]
        assert len(selected_encodings) == 1
        return selected_encodings[0].samples
save_meta_enc(path)

Save the metadata encoding to a CSV file.

Parameters:

Name Type Description Default
path Path

The file path to save the metadata.

required
Source code in vambn/modelling/models/hivae/outputs.py
@typeguard.typechecked
def save_meta_enc(self, path: Path):
    """
    Save the metadata encoding to a CSV file.

    Args:
        path (Path): The file path to save the metadata.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    converted = self.convert()
    df = pd.DataFrame(converted)
    df.to_csv(path, index=False)
ModularHivaeOutput dataclass

Dataclass for modular HIVAE output.

Attributes:

Name Type Description
outputs Tuple[HivaeOutputs, ...]

Tuple of HIVAE outputs. See HivaeOutputs for details.

Source code in vambn/modelling/models/hivae/outputs.py
@dataclass
class ModularHivaeOutput:
    """Dataclass for modular HIVAE output.

    Attributes:
        outputs (Tuple[HivaeOutputs, ...]): Tuple of HIVAE outputs. See HivaeOutputs for details.
    """

    outputs: Tuple[HivaeOutputs, ...]

    def __add__(self, other: "ModularHivaeOutput") -> "ModularHivaeOutput":
        """
        Add another ModularHivaeOutput object to this one.

        Args:
            other (ModularHivaeOutput): The other ModularHivaeOutput object to add.

        Returns:
            ModularHivaeOutput: The resulting ModularHivaeOutput object.
        """
        for old, new in zip(self.outputs, other.outputs):
            old += new
        logger.debug(f"Added {other} to {self}")
        return self

    def detach(self) -> "ModularHivaeOutput":
        """
        Detach all tensors in the outputs and move them to CPU.

        Returns:
            ModularHivaeOutput: The detached ModularHivaeOutput object.
        """
        for output in self.outputs:
            output.detach()
        return self

    @property
    def avg_loss(self) -> float:
        """
        Calculate the average loss across all outputs.

        Returns:
            float: The average loss.
        """
        return sum([x.avg_loss for x in self.outputs])

    @property
    def loss(self) -> Tensor:
        """
        Calculate the total loss across all outputs.

        Returns:
            Tensor: The total loss tensor.
        """
        return torch.stack([x.loss for x in self.outputs]).sum()

    def __iter__(self):
        """
        Iterate over the outputs.

        Returns:
            Iterator: An iterator over the outputs.
        """
        return iter(self.outputs)

    def __len__(self):
        """
        Get the number of outputs.

        Returns:
            int: The number of outputs.
        """
        return len(self.outputs)

    def __item__(self, idx: int) -> HivaeOutputs:
        """
        Get an output by index.

        Args:
            idx (int): The index of the output to retrieve.

        Returns:
            HivaeOutputs: The output at the specified index.
        """
        return self.outputs[idx]
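A sketch of the loss aggregation across modules (assumes HivaeOutput and ModularHivaeOutput are importable from vambn.modelling.models.hivae.outputs; values are toy data):

import torch

from vambn.modelling.models.hivae.outputs import HivaeOutput, ModularHivaeOutput

per_module = tuple(
    HivaeOutput(
        loss=torch.tensor(float(i + 1)),
        enc_z=torch.randn(4, 2),
        enc_s=torch.randn(4, 3),
    )
    for i in range(3)
)
modular = ModularHivaeOutput(outputs=per_module)

print(len(modular))      # 3 modules
print(modular.loss)      # per-module losses stacked and summed: tensor(6.)
print(modular.avg_loss)  # sum of the per-module average losses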
avg_loss: float property

Calculate the average loss across all outputs.

Returns:

Name Type Description
float float

The average loss.

loss: Tensor property

Calculate the total loss across all outputs.

Returns:

Name Type Description
Tensor Tensor

The total loss tensor.

__add__(other)

Add another ModularHivaeOutput object to this one.

Parameters:

Name Type Description Default
other ModularHivaeOutput

The other ModularHivaeOutput object to add.

required

Returns:

Name Type Description
ModularHivaeOutput ModularHivaeOutput

The resulting ModularHivaeOutput object.

Source code in vambn/modelling/models/hivae/outputs.py
def __add__(self, other: "ModularHivaeOutput") -> "ModularHivaeOutput":
    """
    Add another ModularHivaeOutput object to this one.

    Args:
        other (ModularHivaeOutput): The other ModularHivaeOutput object to add.

    Returns:
        ModularHivaeOutput: The resulting ModularHivaeOutput object.
    """
    for old, new in zip(self.outputs, other.outputs):
        old += new
    logger.debug(f"Added {other} to {self}")
    return self
__item__(idx)

Get an output by index.

Parameters:

Name Type Description Default
idx int

The index of the output to retrieve.

required

Returns:

Name Type Description
HivaeOutputs HivaeOutputs

The output at the specified index.

Source code in vambn/modelling/models/hivae/outputs.py
def __item__(self, idx: int) -> HivaeOutputs:
    """
    Get an output by index.

    Args:
        idx (int): The index of the output to retrieve.

    Returns:
        HivaeOutputs: The output at the specified index.
    """
    return self.outputs[idx]
__iter__()

Iterate over the outputs.

Returns:

Name Type Description
Iterator

An iterator over the outputs.

Source code in vambn/modelling/models/hivae/outputs.py
def __iter__(self):
    """
    Iterate over the outputs.

    Returns:
        Iterator: An iterator over the outputs.
    """
    return iter(self.outputs)
__len__()

Get the number of outputs.

Returns:

Name Type Description
int

The number of outputs.

Source code in vambn/modelling/models/hivae/outputs.py
def __len__(self):
    """
    Get the number of outputs.

    Returns:
        int: The number of outputs.
    """
    return len(self.outputs)
detach()

Detach all tensors in the outputs and move them to CPU.

Returns:

Name Type Description
ModularHivaeOutput ModularHivaeOutput

The detached ModularHivaeOutput object.

Source code in vambn/modelling/models/hivae/outputs.py
def detach(self) -> "ModularHivaeOutput":
    """
    Detach all tensors in the outputs and move them to CPU.

    Returns:
        ModularHivaeOutput: The detached ModularHivaeOutput object.
    """
    for output in self.outputs:
        output.detach()
    return self

shared

AvgModuleMtl

Bases: BaseModule

This module averages the z's and passes them through the MOO block. The output is then passed through individual layers to generate the final outputs.

Parameters:

Name Type Description Default
z_dim int

The dimension of the input z.

required
ys_dim int

The dimension of the shared output ys.

required
y_dim int | Tuple[int, ...]

The dimension of the individual outputs y.

required
n_modules int

The number of modules.

required
module_names Tuple[str, ...]

The names of the modules.

required
mtl_method Optional[Tuple[str, ...]]

The method used for multi-task learning. Defaults to None.

None

Attributes:

Name Type Description
_mtl_module MultiObjectiveOptimization

The multi-objective optimization module.

moo_block MultiMOOForLoop

The multi-objective optimization block.

scaling_layers ModuleList

The list of scaling layers.

Source code in vambn/modelling/models/hivae/shared.py
class AvgModuleMtl(BaseModule):
    """
    This module averages the z's and passes them through the MOO block. The output is then
    passed through individual layers to generate the final outputs.

    Args:
        z_dim (int): The dimension of the input z.
        ys_dim (int): The dimension of the shared output ys.
        y_dim (int | Tuple[int, ...]): The dimension of the individual outputs y.
        n_modules (int): The number of modules.
        module_names (Tuple[str, ...]): The names of the modules.
        mtl_method (Optional[Tuple[str, ...]]): The method used for multi-task learning. Defaults to None.

    Attributes:
        _mtl_module (MultiObjectiveOptimization): The multi-objective optimization module.
        moo_block (MultiMOOForLoop): The multi-objective optimization block.
        scaling_layers (nn.ModuleList): The list of scaling layers.

    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        self._mtl_module = moo.setup_moo(
            [MtlMethodParams(x) for x in mtl_method],
            num_tasks=n_modules,
        )
        self.moo_block = moo.MultiMOOForLoop(
            n_modules, moo_methods=(self._mtl_module,)
        )
        self.scaling_layers = nn.ModuleList(
            [BaseElement(self.z_dim, self.z_dim) for i in range(self.n_modules)]
        )

    def order_layers(self, module_names: Tuple[str]) -> None:
        """
        Orders the scaling layers based on the given module names.

        Args:
            module_names (Tuple[str]): The names of the modules.

        """
        prior_map = {name: i for i, name in enumerate(self.module_names)}
        self.scaling_layers = nn.ModuleList(
            [self.scaling_layers[prior_map[name]] for name in module_names]
        )

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Performs the forward pass of the module. The input z's are averaged and passed through
        the MOO block. The output is then passed through the individual scaling layers to
        generate the final output tensors.

        Args:
            z (Tuple[torch.Tensor, ...]): The input z.

        Returns:
            Tuple[torch.Tensor, ...]: The output tensors.

        """
        # Average the z's and pass through the shared layer
        (h,) = self.moo_block(torch.mean(torch.stack(z), dim=0))
        # Pass through the individual layers
        return tuple(
            [
                self.scaling_layers[i](hi)
                for i, hi in zip(range(self.n_modules), h)
            ]
        )
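The averaging step in forward is a single stack-and-mean over the module axis; the following standalone sketch shows just that step with toy shapes (the MOO block and the scaling layers are omitted):

import torch

# Only the averaging step of forward(): stack the per-module z tensors and
# average over the module axis (batch size and z_dim are toy values).
z_per_module = (
    torch.randn(16, 4),   # module 1
    torch.randn(16, 4),   # module 2
    torch.randn(16, 4),   # module 3
)
z_avg = torch.mean(torch.stack(z_per_module), dim=0)
print(z_avg.shape)        # torch.Size([16, 4])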
forward(z)

Performs the forward pass of the module. The input z's are averaged and passed through the MOO block. The output is then passed through the individual scaling layers to generate the final output tensors.

Parameters:

Name Type Description Default
z Tuple[Tensor, ...]

The input z.

required

Returns:

Type Description
Tuple[Tensor, ...]

Tuple[torch.Tensor, ...]: The output tensors.

Source code in vambn/modelling/models/hivae/shared.py
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """
    Performs the forward pass of the module. The input z's are averaged and passed through
    the MOO block. The output is then passed through the individual scaling layers to
    generate the final output tensors.

    Args:
        z (Tuple[torch.Tensor, ...]): The input z.

    Returns:
        Tuple[torch.Tensor, ...]: The output tensors.

    """
    # Average the z's and pass through the shared layer
    (h,) = self.moo_block(torch.mean(torch.stack(z), dim=0))
    # Pass through the individual layers
    return tuple(
        [
            self.scaling_layers[i](hi)
            for i, hi in zip(range(self.n_modules), h)
        ]
    )
order_layers(module_names)

Orders the scaling layers based on the given module names.

Parameters:

Name Type Description Default
module_names Tuple[str]

The names of the modules.

required
Source code in vambn/modelling/models/hivae/shared.py
def order_layers(self, module_names: Tuple[str]) -> None:
    """
    Orders the scaling layers based on the given module names.

    Args:
        module_names (Tuple[str]): The names of the modules.

    """
    prior_map = {name: i for i, name in enumerate(self.module_names)}
    self.scaling_layers = nn.ModuleList(
        [self.scaling_layers[prior_map[name]] for name in module_names]
    )
BaseElement

Bases: Module

A base neural network element that consists of a single modified linear layer.

Parameters:

Name Type Description Default
input_dim int

Input dimension of the layer.

required
output_dim int

Output dimension of the layer.

required
Source code in vambn/modelling/models/hivae/shared.py
class BaseElement(nn.Module):
    """
    A base neural network element that consists of a single modified linear layer.

    Args:
        input_dim (int): Input dimension of the layer.
        output_dim (int): Output dimension of the layer.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layer = ModifiedLinear(self.input_dim, self.output_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Perform the forward pass through the layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after passing through the layer.
        """
        return self.layer(x)
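A minimal sketch of the element's role, with torch.nn.Linear standing in for the package-specific ModifiedLinear (an assumption made only to keep the example self-contained):

import torch
from torch import nn

# nn.Linear stands in for the package's ModifiedLinear; the element is
# otherwise a thin single-layer wrapper.
class ElementSketch(nn.Module):
    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.layer = nn.Linear(input_dim, output_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

print(ElementSketch(4, 2)(torch.randn(8, 4)).shape)   # torch.Size([8, 2])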
forward(x)

Perform the forward pass through the layer.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required

Returns:

Type Description
Tensor

torch.Tensor: Output tensor after passing through the layer.

Source code in vambn/modelling/models/hivae/shared.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Perform the forward pass through the layer.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: Output tensor after passing through the layer.
    """
    return self.layer(x)
BaseModule

Bases: Module, ABC

Base class for all modules in the HIVAE model.

Parameters:

Name Type Description Default
z_dim int

Input dimension of the latent samples z.

required
ys_dim int

Intermediate dimension of the latent samples ys.

required
y_dim int or Tuple[int, ...]

Output dimension of the latent samples y, which can be different for each module.

required
n_modules int

Number of modules.

required
module_names Tuple[str, ...]

Names of the modules.

required
mtl_method Optional[Tuple[str, ...]]

MTL methods used to avoid conflicting gradients. Defaults to None.

None

Raises:

Type Description
ValueError

If the length of y_dim is not equal to the number of modules.

Source code in vambn/modelling/models/hivae/shared.py
class BaseModule(nn.Module, ABC):
    """
    Base class for all modules in the HIVAE model.

    Args:
        z_dim (int): Input dimension of the latent samples z.
        ys_dim (int): Intermediate dimension of the latent samples ys.
        y_dim (int or Tuple[int, ...]): Output dimension of the latent samples y, which can be different for each module.
        n_modules (int): Number of modules.
        module_names (Tuple[str, ...]): Names of the modules.
        mtl_method (Optional[Tuple[str, ...]]): MTL methods used to avoid conflicting gradients. Defaults to None.

    Raises:
        ValueError: If the length of y_dim is not equal to the number of modules.
    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        """
        Base class for all modules in the HIVAE model.

        Args:
            z_dim (int): Input dimension of the latent samples z.
            ys_dim (int): Intermediate dimension of the latent samples ys.
            y_dim (int | Tuple[int, ...]): Output dimension of the latent samples y, which can be different for each module.
            n_modules (int): Number of modules.
            module_names (Tuple[str, ...]): Names of the modules.
            mtl_method (Optional[Tuple[str, ...]], optional): MTL methods used to avoid conflicting gradients. Defaults to None.

        Raises:
            ValueError: If the length of y_dim is not equal to the number of modules.
        """
        super().__init__()
        self.z_dim = z_dim
        self.ys_dim = ys_dim
        self.y_dim = y_dim
        self.n_modules = n_modules
        self.mtl_method = mtl_method
        self.module_names = module_names
        self.has_params = True

        if (
            isinstance(self.y_dim, Iterable)
            and len(self.y_dim) != self.n_modules
        ):
            raise ValueError(
                f"Length of y_dim must be equal to the number of modules: {self.n_modules}"
            )


    @abstractmethod
    def order_layers(self, module_names: Tuple[str, ...]) -> None:
        """
        Order the layers of the module according to the module names.

        Args:
            module_names (Tuple[str, ...]): Names of the modules
        """
        pass

    @abstractmethod
    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        pass

    def _y_dim(self, i: int) -> int:
        """
        Returns the output dimension of the ith module.

        Args:
            i (int): Index of the module.

        Returns:
            int: Output dimension of the ith module.
        """
        return self.y_dim if isinstance(self.y_dim, int) else self.y_dim[i]
__init__(z_dim, ys_dim, y_dim, n_modules, module_names, mtl_method=None)

Base class for all modules in the HIVAE model.

Parameters:

- z_dim (int): Input dimension of the latent samples z. Required.
- ys_dim (int): Intermediate dimension of the latent samples ys. Required.
- y_dim (int | Tuple[int, ...]): Output dimension of the latent samples y, which can be different for each module. Required.
- n_modules (int): Number of modules. Required.
- module_names (Tuple[str, ...]): Names of the modules. Required.
- mtl_method (Optional[Tuple[str, ...]]): MTL methods used to avoid conflicting gradients. Defaults to None.

Raises:

- ValueError: If the length of y_dim is not equal to the number of modules.

Source code in vambn/modelling/models/hivae/shared.py
def __init__(
    self,
    z_dim: int,
    ys_dim: int,
    y_dim: int | Tuple[int, ...],
    n_modules: int,
    module_names: Tuple[str, ...],
    mtl_method: Optional[Tuple[str, ...]] = None,
) -> None:
    """
    Base class for all modules in the HIVAE model.

    Args:
        z_dim (int): Input dimension of the latent samples z.
        ys_dim (int): Intermediate dimension of the latent samples ys.
        y_dim (int | Tuple[int, ...]): Output dimension of the latent samples y, which can be different for each module.
        n_modules (int): Number of modules.
        module_names (Tuple[str, ...]): Names of the modules.
        mtl_method (Optional[Tuple[str, ...]], optional): MTL methods used to avoid conflicting gradients. Defaults to None.

    Raises:
        ValueError: If the length of y_dim is not equal to the number of modules.
    """
    super().__init__()
    self.z_dim = z_dim
    self.ys_dim = ys_dim
    self.y_dim = y_dim
    self.n_modules = n_modules
    self.mtl_method = mtl_method
    self.module_names = module_names
    self.has_params = True

    if (
        isinstance(self.y_dim, Iterable)
        and len(self.y_dim) != self.n_modules
    ):
        raise ValueError(
            f"Length of y_dim must be equal to the number of modules: {self.n_modules}"
        )
order_layers(module_names) abstractmethod

Order the layers of the module according to the module names.

Parameters:

- module_names (Tuple[str, ...]): Names of the modules. Required.
Source code in vambn/modelling/models/hivae/shared.py
@abstractmethod
def order_layers(self, module_names: Tuple[str, ...]) -> None:
    """
    Order the layers of the module according to the module names.

    Args:
        module_names (Tuple[str, ...]): Names of the modules
    """
    pass
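
The contract above is easiest to see with a toy subclass. The sketch below is purely illustrative (IdentityModule is not part of the package): a concrete module must implement order_layers() and forward(), the constructor validates that a tuple-valued y_dim matches n_modules, and _y_dim(i) resolves the per-module output dimension whether y_dim is an int or a tuple.

from typing import Tuple

import torch

class IdentityModule(BaseModule):
    def order_layers(self, module_names: Tuple[str, ...]) -> None:
        pass  # nothing to reorder in this toy example

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        return z

mod = IdentityModule(
    z_dim=2, ys_dim=3, y_dim=(4, 5), n_modules=2, module_names=("m1", "m2")
)
print(mod._y_dim(0), mod._y_dim(1))  # 4 5
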
ConcatModule

Bases: BaseModule

A module that concatenates multiple input tensors and applies scaling layers.

Parameters:

- z_dim (int): The dimension of the input z tensors. Required.
- ys_dim (int): The dimension of the output ys tensors. Required.
- y_dim (int or Tuple[int, ...]): The dimension of the output y tensors. Required.
- n_modules (int): The number of modules. Required.
- module_names (Tuple[str, ...]): The names of the modules. Required.
- mtl_method (Optional[Tuple[str, ...]]): The method used for multi-task learning. Defaults to None.

Attributes:

- shared_layer (BaseElement): The shared layer that concatenates the input z tensors.
- scaling_layers (ModuleList): The list of scaling layers for each module.

Source code in vambn/modelling/models/hivae/shared.py
class ConcatModule(BaseModule):
    """
    A module that concatenates multiple input tensors and applies scaling layers.

    Args:
        z_dim (int): The dimension of the input z tensors.
        ys_dim (int): The dimension of the output ys tensors.
        y_dim (int or Tuple[int, ...]): The dimension of the output y tensors.
        n_modules (int): The number of modules.
        module_names (Tuple[str, ...]): The names of the modules.
        mtl_method (Optional[Tuple[str, ...]]): The method used for multi-task learning (default: None).

    Attributes:
        shared_layer (BaseElement): The shared layer that concatenates the input z tensors.
        scaling_layers (nn.ModuleList): The list of scaling layers for each module.

    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        self.shared_layer = BaseElement(
            self.z_dim * n_modules, self.ys_dim * n_modules
        )
        self.scaling_layers = nn.ModuleList(
            [
                BaseElement(self.ys_dim, self.z_dim)
                for i in range(self.n_modules)
            ]
        )

    def order_layers(self, module_names: Tuple[str]) -> None:
        """
        Reorders the scaling layers based on the given module names.

        Args:
            module_names (Tuple[str]): The new order of the module names.

        """
        prior_map = {name: i for i, name in enumerate(self.module_names)}
        self.scaling_layers = nn.ModuleList(
            [self.scaling_layers[prior_map[name]] for name in module_names]
        )

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Performs forward pass through the ConcatModule. The input tensors z are concatenated
        and passed through the shared layer to generate a single output tensor. The output
        tensor is then passed through the scaling layers to generate the final output tensors.

        Args:
            z (Tuple[torch.Tensor, ...]): The input tensors z.

        Returns:
            Tuple[torch.Tensor, ...]: The output tensors after applying scaling layers.

        """
        # z-type: Tuple[torch.Tensor, ...], with each z having shape (batch_size, z_dim)
        h = self.shared_layer(torch.cat(z, dim=1))
        # size of h: (batch_size, ys_dim * n_modules)
        return tuple(
            [
                self.scaling_layers[i](
                    h[:, i * self.ys_dim : (i + 1) * self.ys_dim]
                )
                for i in range(self.n_modules)
            ]
        )  # size of h[:, i * self.ys_dim : (i + 1) * self.ys_dim]: (batch_size, ys_dim)
forward(z)

Performs forward pass through the ConcatModule. The input tensors z are concatenated and passed through the shared layer to generate a single output tensor. The output tensor is then passed through the scaling layers to generate the final output tensors.

Parameters:

- z (Tuple[Tensor, ...]): The input tensors z. Required.

Returns:

- Tuple[Tensor, ...]: The output tensors after applying scaling layers.

Source code in vambn/modelling/models/hivae/shared.py
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """
    Performs forward pass through the ConcatModule. The input tensors z are concatenated
    and passed through the shared layer to generate a single output tensor. The output
    tensor is then passed through the scaling layers to generate the final output tensors.

    Args:
        z (Tuple[torch.Tensor, ...]): The input tensors z.

    Returns:
        Tuple[torch.Tensor, ...]: The output tensors after applying scaling layers.

    """
    # z-type: Tuple[torch.Tensor, ...], with each z having shape (batch_size, z_dim)
    h = self.shared_layer(torch.cat(z, dim=1))
    # size of h: (batch_size, ys_dim * n_modules)
    return tuple(
        [
            self.scaling_layers[i](
                h[:, i * self.ys_dim : (i + 1) * self.ys_dim]
            )
            for i in range(self.n_modules)
        ]
    )  # size of h[:, i * self.ys_dim : (i + 1) * self.ys_dim]: (batch_size, ys_dim)
order_layers(module_names)

Reorders the scaling layers based on the given module names.

Parameters:

- module_names (Tuple[str]): The new order of the module names. Required.
Source code in vambn/modelling/models/hivae/shared.py
def order_layers(self, module_names: Tuple[str]) -> None:
    """
    Reorders the scaling layers based on the given module names.

    Args:
        module_names (Tuple[str]): The new order of the module names.

    """
    prior_map = {name: i for i, name in enumerate(self.module_names)}
    self.scaling_layers = nn.ModuleList(
        [self.scaling_layers[prior_map[name]] for name in module_names]
    )
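
A short shape walk-through may help; the dimensions and module names below are assumptions chosen for illustration. With n_modules = 3, the concatenated input has width z_dim * 3, the shared layer expands it to ys_dim * 3, and each ys_dim-wide slice is scaled back to z_dim by its module-specific layer.

import torch

module = ConcatModule(
    z_dim=2, ys_dim=3, y_dim=4, n_modules=3, module_names=("a", "b", "c")
)
z = tuple(torch.randn(8, 2) for _ in range(3))  # one (batch, z_dim) tensor per module
y = module(z)
print([yi.shape for yi in y])  # [torch.Size([8, 2]), torch.Size([8, 2]), torch.Size([8, 2])]
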
ConcatModuleMtl

Bases: BaseModule

This module concatenates the z's and passes them through a shared layer before passing through the MOO block. The output is then passed through individual layers to generate the final outputs. This is the same as the ConcatModule, but with the addition of the MOO block.

Source code in vambn/modelling/models/hivae/shared.py
class ConcatModuleMtl(BaseModule):
    """
    This module concatenates the z's and passes them through a shared layer before
    passing through the MOO block. The output is then passed through individual layers
    to generate the final outputs. This is the same as the ConcatModule, but with the
    addition of the MOO block.
    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        self.mtl_method = mtl_method
        self.shared_layer = BaseElement(z_dim * n_modules, ys_dim)
        self._mtl_module = moo.setup_moo(
            [MtlMethodParams(x) for x in mtl_method],
            num_tasks=n_modules,
        )
        self.moo_block = moo.MultiMOOForLoop(
            n_modules, moo_methods=(self._mtl_module,)
        )
        self.scaling_layers = nn.ModuleList(
            [
                BaseElement(self.ys_dim, self.z_dim)
                for i in range(self.n_modules)
            ]
        )

    def order_layers(self, module_names: Tuple[str]) -> None:
        """
        Order the layers based on the module names.

        Args:
            module_names (Tuple[str]): Names of the modules.
        """
        prior_map = {name: i for i, name in enumerate(self.module_names)}
        self.scaling_layers = nn.ModuleList(
            [self.scaling_layers[prior_map[name]] for name in module_names]
        )

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Forward pass through the module. The z's are concatenated and passed through
        the shared layer to generate one output which is identical for all modules. The
        output is then passed through the MOO block and individual layers to generate
        the final outputs.

        Args:
            z (Tuple[torch.Tensor, ...]): Input tensor.

        Returns:
            Tuple[torch.Tensor, ...]: Output tensor with individual outputs for each module.
        """
        # Concatenate the z's and pass through the
        # shared layer to generate one output which is identical for all modules
        h = self.shared_layer(torch.cat(z, dim=1))
        # Pass through the MOO block
        (hs,) = self.moo_block(h)
        # Pass through the individual layers
        return tuple(
            [
                self.scaling_layers[i](hi)
                for i, hi in zip(range(self.n_modules), hs)
            ]
        )
forward(z)

Forward pass through the module. The z's are concatenated and passed through the shared layer to generate one output which is identical for all modules. The output is then passed through the MOO block and individual layers to generate the final outputs.

Parameters:

- z (Tuple[Tensor, ...]): Input tensor. Required.

Returns:

- Tuple[Tensor, ...]: Output tensor with individual outputs for each module.

Source code in vambn/modelling/models/hivae/shared.py
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """
    Forward pass through the module. The z's are concatenated and passed through
    the shared layer to generate one output which is identical for all modules. The
    output is then passed through the MOO block and individual layers to generate
    the final outputs.

    Args:
        z (Tuple[torch.Tensor, ...]): Input tensor.

    Returns:
        Tuple[torch.Tensor, ...]: Output tensor with individual outputs for each module.
    """
    # Concatenate the z's and pass through the
    # shared layer to generate one output which is identical for all modules
    h = self.shared_layer(torch.cat(z, dim=1))
    # Pass through the MOO block
    (hs,) = self.moo_block(h)
    # Pass through the individual layers
    return tuple(
        [
            self.scaling_layers[i](hi)
            for i, hi in zip(range(self.n_modules), hs)
        ]
    )
order_layers(module_names)

Order the layers based on the module names.

Parameters:

- module_names (Tuple[str]): Names of the modules. Required.
Source code in vambn/modelling/models/hivae/shared.py
def order_layers(self, module_names: Tuple[str]) -> None:
    """
    Order the layers based on the module names.

    Args:
        module_names (Tuple[str]): Names of the modules.
    """
    prior_map = {name: i for i, name in enumerate(self.module_names)}
    self.scaling_layers = nn.ModuleList(
        [self.scaling_layers[prior_map[name]] for name in module_names]
    )
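
The reordering logic in order_layers is the same across these modules and can be read in isolation: a map from the construction-time module names to indices is built, and the scaling layers are re-collected in the newly requested order. The names and layer labels below are made up for illustration.

module_names = ("snp", "clinical", "imaging")          # order at construction time
scaling_layers = ["layer_snp", "layer_clinical", "layer_imaging"]

new_order = ("imaging", "snp", "clinical")             # order requested later
prior_map = {name: i for i, name in enumerate(module_names)}
reordered = [scaling_layers[prior_map[name]] for name in new_order]
print(reordered)  # ['layer_imaging', 'layer_snp', 'layer_clinical']
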
EncoderModule

Bases: BaseModule

EncoderModule class represents a module for encoding input data.

Parameters:

- z_dim (int): The dimension of the latent space. Required.
- ys_dim (int): The dimension of the output space. Required.
- y_dim (int or Tuple[int, ...]): The dimension(s) of the input data. Required.
- n_modules (int): The number of modules. Required.
- module_names (Tuple[str, ...]): The names of the modules. Required.
- mtl_method (Optional[Tuple[str, ...]]): The method(s) for multi-task learning. Defaults to None.

Attributes:

- attention (SelfAttention): The self-attention layer.
- feed_forward (Sequential): The feed-forward neural network.
- dropout (Dropout): The dropout layer.
- layer_norm_1 (LayerNorm): The first layer normalization layer.
- layer_norm_2 (LayerNorm): The second layer normalization layer.
- ys_layer (Linear): The linear layer for output.
- scaling_layers (ModuleList): The list of scaling layers.

Source code in vambn/modelling/models/hivae/shared.py
class EncoderModule(BaseModule):
    """
    EncoderModule class represents a module for encoding input data.

    Args:
        z_dim (int): The dimension of the latent space.
        ys_dim (int): The dimension of the output space.
        y_dim (int or Tuple[int, ...]): The dimension(s) of the input data.
        n_modules (int): The number of modules.
        module_names (Tuple[str, ...]): The names of the modules.
        mtl_method (Optional[Tuple[str, ...]]): The method(s) for multi-task learning. Defaults to None.

    Attributes:
        attention (SelfAttention): The self-attention layer.
        feed_forward (nn.Sequential): The feed-forward neural network.
        dropout (nn.Dropout): The dropout layer.
        layer_norm_1 (nn.LayerNorm): The layer normalization layer.
        layer_norm_2 (nn.LayerNorm): The layer normalization layer.
        ys_layer (nn.Linear): The linear layer for output.
        scaling_layers (nn.ModuleList): The list of scaling layers.

    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        cat_dim = z_dim * n_modules
        self.attention = SelfAttention(cat_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(cat_dim, ys_dim),
            nn.GELU(),
            nn.Linear(ys_dim, cat_dim),
            nn.Dropout(0.1),
        )
        self.dropout = nn.Dropout(0.1)
        self.layer_norm_1 = nn.LayerNorm(cat_dim)
        self.layer_norm_2 = nn.LayerNorm(cat_dim)
        self.ys_layer = nn.Linear(cat_dim, ys_dim)
        self.scaling_layers = nn.ModuleList(
            [
                BaseElement(self.ys_dim, self.z_dim)
                for i in range(self.n_modules)
            ]
        )

    def order_layers(self, module_names: Tuple[str]) -> None:
        """
        Orders the scaling layers based on the given module names.

        Args:
            module_names (Tuple[str]): The names of the modules.

        Returns:
            None

        """

        prior_map = {name: i for i, name in enumerate(self.module_names)}
        self.scaling_layers = nn.ModuleList(
            [self.scaling_layers[prior_map[name]] for name in module_names]
        )

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Performs forward pass through the encoder module. The input tensors z are concatenated
        and passed through the self-attention layer. The output is then passed through the feed-forward
        neural network and the output layer. The output is then passed through the scaling layers to
        generate the final output tensors.

        Args:
            z (Tuple[torch.Tensor, ...]): The input tensors.

        Returns:
            Tuple[torch.Tensor, ...]: The output tensors.

        """

        x = torch.cat(z, dim=1)
        attended = self.attention(torch.cat(z, dim=1))
        x = self.layer_norm_1(x + self.dropout(attended))

        h = self.feed_forward(x)
        h = self.layer_norm_2(x + self.dropout(h))
        yb = self.ys_layer(h)
        y = tuple([self.scaling_layers[i](yb) for i in range(self.n_modules)])
        return y
forward(z)

Performs forward pass through the encoder module. The input tensors z are concatenated and passed through the self-attention layer. The output is then passed through the feed-forward neural network and the output layer. The output is then passed through the scaling layers to generate the final output tensors.

Parameters:

- z (Tuple[Tensor, ...]): The input tensors. Required.

Returns:

- Tuple[Tensor, ...]: The output tensors.

Source code in vambn/modelling/models/hivae/shared.py
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """
    Performs forward pass through the encoder module. The input tensors z are concatenated
    and passed through the self-attention layer. The output is then passed through the feed-forward
    neural network and the output layer. The output is then passed through the scaling layers to
    generate the final output tensors.

    Args:
        z (Tuple[torch.Tensor, ...]): The input tensors.

    Returns:
        Tuple[torch.Tensor, ...]: The output tensors.

    """

    x = torch.cat(z, dim=1)
    attended = self.attention(torch.cat(z, dim=1))
    x = self.layer_norm_1(x + self.dropout(attended))

    h = self.feed_forward(x)
    h = self.layer_norm_2(x + self.dropout(h))
    yb = self.ys_layer(h)
    y = tuple([self.scaling_layers[i](yb) for i in range(self.n_modules)])
    return y
order_layers(module_names)

Orders the scaling layers based on the given module names.

Parameters:

- module_names (Tuple[str]): The names of the modules. Required.

Returns:

- None

Source code in vambn/modelling/models/hivae/shared.py
def order_layers(self, module_names: Tuple[str]) -> None:
    """
    Orders the scaling layers based on the given module names.

    Args:
        module_names (Tuple[str]): The names of the modules.

    Returns:
        None

    """

    prior_map = {name: i for i, name in enumerate(self.module_names)}
    self.scaling_layers = nn.ModuleList(
        [self.scaling_layers[prior_map[name]] for name in module_names]
    )
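
A minimal usage sketch of the encoder path, with illustrative dimensions: the z tensors are concatenated to width z_dim * n_modules, passed through self-attention with a residual connection and layer normalization, projected to ys_dim, and finally mapped back to z_dim by each module's scaling layer.

import torch

encoder = EncoderModule(
    z_dim=2, ys_dim=3, y_dim=4, n_modules=3, module_names=("a", "b", "c")
)
z = tuple(torch.randn(8, 2) for _ in range(3))
y = encoder(z)
print([yi.shape for yi in y])  # three tensors of shape (8, 2)
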
EncoderModuleMtl

Bases: BaseModule

Encoder module for multi-task learning.

Parameters:

- z_dim (int): The dimension of the latent space. Required.
- ys_dim (int): The dimension of the shared representation. Required.
- y_dim (int | Tuple[int, ...]): The dimension(s) of the task-specific representations. Required.
- n_modules (int): The number of task-specific modules. Required.
- module_names (Tuple[str, ...]): The names of the task-specific modules. Required.
- mtl_method (Optional[Tuple[str, ...]]): The multi-task learning method(s) to use. Defaults to None.

Attributes:

- attention (SelfAttention): The self-attention layer.
- feed_forward (Sequential): The feed-forward neural network.
- dropout (Dropout): The dropout layer.
- layer_norm_1 (LayerNorm): The first layer normalization.
- layer_norm_2 (LayerNorm): The second layer normalization.
- ys_layer (Linear): The linear layer for shared representation.
- scaling_layers (ModuleList): The list of scaling layers for task-specific representations.
- _mtl_module (MultiObjectiveOptimization): The multi-objective optimization module.
- moo_block (MultiMOOForLoop): The multi-objective optimization block.

Raises:

- Exception: This class should no longer be used.

Source code in vambn/modelling/models/hivae/shared.py
class EncoderModuleMtl(BaseModule):
    """Encoder module for multi-task learning.

    Args:
        z_dim (int): The dimension of the latent space.
        ys_dim (int): The dimension of the shared representation.
        y_dim (int | Tuple[int, ...]): The dimension(s) of the task-specific representations.
        n_modules (int): The number of task-specific modules.
        module_names (Tuple[str, ...]): The names of the task-specific modules.
        mtl_method (Optional[Tuple[str, ...]], optional): The multi-task learning method(s) to use. Defaults to None.

    Attributes:
        attention (SelfAttention): The self-attention layer.
        feed_forward (nn.Sequential): The feed-forward neural network.
        dropout (nn.Dropout): The dropout layer.
        layer_norm_1 (nn.LayerNorm): The first layer normalization.
        layer_norm_2 (nn.LayerNorm): The second layer normalization.
        ys_layer (nn.Linear): The linear layer for shared representation.
        scaling_layers (nn.ModuleList): The list of scaling layers for task-specific representations.
        _mtl_module (moo.MultiObjectiveOptimization): The multi-objective optimization module.
        moo_block (moo.MultiMOOForLoop): The multi-objective optimization block.

    Raises:
        Exception: This class should no longer be used.

    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        cat_dim = z_dim * n_modules
        self.attention = SelfAttention(cat_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(cat_dim, ys_dim),
            nn.GELU(),
            nn.Linear(ys_dim, cat_dim),
            nn.Dropout(0.1),
        )
        self.dropout = nn.Dropout(0.1)
        self.layer_norm_1 = nn.LayerNorm(cat_dim)
        self.layer_norm_2 = nn.LayerNorm(cat_dim)
        self.ys_layer = nn.Linear(cat_dim, ys_dim)
        self.scaling_layers = nn.ModuleList(
            [
                BaseElement(self.ys_dim, self.z_dim)
                for i in range(self.n_modules)
            ]
        )
        self._mtl_module = moo.setup_moo(
            [MtlMethodParams(x) for x in mtl_method],
            num_tasks=n_modules,
        )
        self.moo_block = moo.MultiMOOForLoop(
            n_modules, moo_methods=(self._mtl_module,)
        )

    def order_layers(self, module_names: Tuple[str]) -> None:
        """Order the scaling layers based on the given module names.

        Args:
            module_names (Tuple[str]): The names of the task-specific modules.

        """
        prior_map = {name: i for i, name in enumerate(self.module_names)}
        self.scaling_layers = nn.ModuleList(
            [self.scaling_layers[prior_map[name]] for name in module_names]
        )

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """Forward pass of the module. The input tensors z are concatenated and passed through
        the self-attention layer. The output is then passed through the feed-forward neural network
        combined with the residual connection and layer normalization. The output is then passed through
        the MTL block and the scaling layers to generate the final output tensors.

        Args:
            z (Tuple[torch.Tensor, ...]): The input tensors.

        Returns:
            Tuple[torch.Tensor, ...]: The output tensors.

        """
        x = torch.cat(z, dim=1)
        attended = self.attention(torch.cat(z, dim=1))
        x = self.layer_norm_1(x + self.dropout(attended))

        h = self.feed_forward(x)
        h = self.layer_norm_2(x + self.dropout(h))
        yb = self.ys_layer(h)
        (yb_moo,) = self.moo_block(yb)
        y = tuple(
            [
                self.scaling_layers[i](yi)
                for i, yi in zip(range(self.n_modules), yb_moo)
            ]
        )
        return y
forward(z)

Forward pass of the module. The input tensors z are concatenated and passed through the self-attention layer. The output is then passed through the feed-forward neural network combined with the residual connection and layer normalization. The output is then passed through the MTL block and the scaling layers to generate the final output tensors.

Parameters:

- z (Tuple[Tensor, ...]): The input tensors. Required.

Returns:

- Tuple[Tensor, ...]: The output tensors.

Source code in vambn/modelling/models/hivae/shared.py
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """Forward pass of the module. The input tensors z are concatenated and passed through
    the self-attention layer. The output is then passed through the feed-forward neural network
    combined with the residual connection and layer normalization. The output is then passed through
    the MTL block and the scaling layers to generate the final output tensors.

    Args:
        z (Tuple[torch.Tensor, ...]): The input tensors.

    Returns:
        Tuple[torch.Tensor, ...]: The output tensors.

    """
    x = torch.cat(z, dim=1)
    attended = self.attention(torch.cat(z, dim=1))
    x = self.layer_norm_1(x + self.dropout(attended))

    h = self.feed_forward(x)
    h = self.layer_norm_2(x + self.dropout(h))
    yb = self.ys_layer(h)
    (yb_moo,) = self.moo_block(yb)
    y = tuple(
        [
            self.scaling_layers[i](yi)
            for i, yi in zip(range(self.n_modules), yb_moo)
        ]
    )
    return y
order_layers(module_names)

Order the scaling layers based on the given module names.

Parameters:

- module_names (Tuple[str]): The names of the task-specific modules. Required.
Source code in vambn/modelling/models/hivae/shared.py
def order_layers(self, module_names: Tuple[str]) -> None:
    """Order the scaling layers based on the given module names.

    Args:
        module_names (Tuple[str]): The names of the task-specific modules.

    """
    prior_map = {name: i for i, name in enumerate(self.module_names)}
    self.scaling_layers = nn.ModuleList(
        [self.scaling_layers[prior_map[name]] for name in module_names]
    )
ImposterModule

Bases: BaseModule

This module is used as a placeholder when no modularity is desired. It simply returns the input z as the output.

Parameters:

- z_dim (int): Input dimension of the latent samples z. Required.
- ys_dim (int): Intermediate dimension of the latent samples ys. Required.
- y_dim (int or Tuple[int, ...]): Output dimension of the latent samples y, which can be different for each module. Required.
- n_modules (int): Number of modules. Required.
- module_names (Tuple[str, ...]): Names of the modules. Required.
- mtl_method (Optional[Tuple[str, ...]]): MTL methods used to avoid conflicting gradients. Defaults to None.

Attributes:

- has_params (bool): Whether the module has parameters.

Source code in vambn/modelling/models/hivae/shared.py
class ImposterModule(BaseModule):
    """
    This module is used as a placeholder when no modularity is desired. It simply
    returns the input z as the output.

    Args:
        z_dim (int): Input dimension of the latent samples z.
        ys_dim (int): Intermediate dimension of the latent samples ys.
        y_dim (int or Tuple[int, ...]): Output dimension of the latent samples y, which can be different for each module.
        n_modules (int): Number of modules.
        module_names (Tuple[str, ...]): Names of the modules.
        mtl_method (Optional[Tuple[str, ...]]): MTL methods used to avoid conflicting gradients. Defaults to None.

    Attributes:
        has_params (bool): Whether the module has parameters.
    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        self.has_params = False

    def order_layers(self, module_names: Tuple[str]) -> None:
        """
        Order the layers based on the module names. Since this module has no parameters,
        this method does nothing in the case of the ImposterModule.

        Args:
            module_names (Tuple[str]): Names of the modules.
        """
        pass

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Forward pass through the module. Since this module has no parameters, it simply
        returns the input z.

        Args:
            z (Tuple[torch.Tensor, ...]): Input tensor.

        Returns:
            Tuple[torch.Tensor, ...]: Output tensor.
        """
        return z
forward(z)

Forward pass through the module. Since this module has no parameters, it simply returns the input z.

Parameters:

- z (Tuple[Tensor, ...]): Input tensor. Required.

Returns:

- Tuple[Tensor, ...]: Output tensor.

Source code in vambn/modelling/models/hivae/shared.py
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """
    Forward pass through the module. Since this module has no parameters, it simply
    returns the input z.

    Args:
        z (Tuple[torch.Tensor, ...]): Input tensor.

    Returns:
        Tuple[torch.Tensor, ...]: Output tensor.
    """
    return z
order_layers(module_names)

Order the layers based on the module names. Since this module has no parameters, this method does nothing in the case of the ImposterModule.

Parameters:

- module_names (Tuple[str]): Names of the modules. Required.
Source code in vambn/modelling/models/hivae/shared.py
def order_layers(self, module_names: Tuple[str]) -> None:
    """
    Order the layers based on the module names. Since this module has no parameters,
    this method does nothing in the case of the ImposterModule.

    Args:
        module_names (Tuple[str]): Names of the modules.
    """
    pass
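
Because the ImposterModule is a no-op placeholder, its behavior is fully captured by a one-line check (dimensions and names below are illustrative):

import torch

imposter = ImposterModule(
    z_dim=2, ys_dim=3, y_dim=4, n_modules=2, module_names=("a", "b")
)
z = (torch.randn(8, 2), torch.randn(8, 2))
assert imposter(z) is z  # the input tuple is returned unchanged
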
MaxModuleMtl

Bases: BaseModule

This module takes the maximum of the z's and passes them through the MOO block. The output is then passed through individual layers to generate the final outputs.

Parameters:

- z_dim (int): The dimension of the input z. Required.
- ys_dim (int): The dimension of the ys. Required.
- y_dim (int | Tuple[int, ...]): The dimension of the output y. Required.
- n_modules (int): The number of modules. Required.
- module_names (Tuple[str, ...]): The names of the modules. Required.
- mtl_method (Optional[Tuple[str, ...]]): The method for multi-task learning. Defaults to None.

Attributes:

- _mtl_module (MultiObjectiveOptimization): The multi-objective optimization module.
- moo_block (MultiMOOForLoop): The multi-objective optimization block.
- scaling_layers (ModuleList): The list of scaling layers.

Source code in vambn/modelling/models/hivae/shared.py
class MaxModuleMtl(BaseModule):
    """
    This module takes the maximum of the z's and passes them through the MOO block. The
    output is then passed through individual layers to generate the final outputs.

    Args:
        z_dim (int): The dimension of the input z.
        ys_dim (int): The dimension of the ys.
        y_dim (int | Tuple[int, ...]): The dimension of the output y.
        n_modules (int): The number of modules.
        module_names (Tuple[str, ...]): The names of the modules.
        mtl_method (Optional[Tuple[str, ...]]): The method for multi-task learning. Defaults to None.

    Attributes:
        _mtl_module (MultiObjectiveOptimization): The multi-objective optimization module.
        moo_block (MultiMOOForLoop): The multi-objective optimization block.
        scaling_layers (nn.ModuleList): The list of scaling layers.
    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        self._mtl_module = moo.setup_moo(
            [MtlMethodParams(x) for x in mtl_method],
            num_tasks=n_modules,
        )
        self.moo_block = moo.MultiMOOForLoop(
            n_modules, moo_methods=(self._mtl_module,)
        )
        self.scaling_layers = nn.ModuleList(
            [BaseElement(self.z_dim, self.z_dim) for i in range(self.n_modules)]
        )

    def order_layers(self, module_names: Tuple[str]) -> None:
        """
        Orders the scaling layers based on the given module names.

        Args:
            module_names (Tuple[str]): The names of the modules.
        """
        prior_map = {name: i for i, name in enumerate(self.module_names)}
        self.scaling_layers = nn.ModuleList(
            [self.scaling_layers[prior_map[name]] for name in module_names]
        )

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Performs forward pass through the module. The maximum of the z's is passed through
        the MOO block. The output is then passed through individual scaling layers to
        generate the final output tensors.

        Args:
            z (Tuple[torch.Tensor, ...]): The input z.

        Returns:
            Tuple[torch.Tensor, ...]: The output tensors.
        """
        # Take the element-wise maximum of the z's and pass it through the MOO block
        (h,) = self.moo_block(torch.max(torch.stack(z), dim=0).values)
        # Pass through the individual layers
        return tuple(
            [
                self.scaling_layers[i](hi)
                for i, hi in zip(range(self.n_modules), h)
            ]
        )
forward(z)

Performs forward pass through the module. The maximum of the z's is passed through the MOO block. The output is then passed through individual scaling layers to generate the final output tensors.

Parameters:

- z (Tuple[Tensor, ...]): The input z. Required.

Returns:

- Tuple[Tensor, ...]: The output tensors.

Source code in vambn/modelling/models/hivae/shared.py
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """
    Performs forward pass through the module. The maximum of the z's is passed through
    the MOO block. The output is then passed through individual scaling layers to
    generate the final output tensors.

    Args:
        z (Tuple[torch.Tensor, ...]): The input z.

    Returns:
        Tuple[torch.Tensor, ...]: The output tensors.
    """
    # Take the element-wise maximum of the z's and pass it through the MOO block
    (h,) = self.moo_block(torch.max(torch.stack(z), dim=0).values)
    # Pass through the individual layers
    return tuple(
        [
            self.scaling_layers[i](hi)
            for i, hi in zip(range(self.n_modules), h)
        ]
    )
order_layers(module_names)

Orders the scaling layers based on the given module names.

Parameters:

- module_names (Tuple[str]): The names of the modules. Required.
Source code in vambn/modelling/models/hivae/shared.py
def order_layers(self, module_names: Tuple[str]) -> None:
    """
    Orders the scaling layers based on the given module names.

    Args:
        module_names (Tuple[str]): The names of the modules.
    """
    prior_map = {name: i for i, name in enumerate(self.module_names)}
    self.scaling_layers = nn.ModuleList(
        [self.scaling_layers[prior_map[name]] for name in module_names]
    )
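
Constructing this module requires valid MTL method names for moo.setup_moo, which are not listed here, so the sketch below only illustrates the reduction step used in forward: stacking the z tensors and taking the element-wise maximum across modules before the result enters the MOO block.

import torch

z = (torch.tensor([[1.0, 5.0]]), torch.tensor([[3.0, 2.0]]), torch.tensor([[0.0, 4.0]]))
pooled = torch.max(torch.stack(z), dim=0).values
print(pooled)  # tensor([[3., 5.]]) -- per-element maximum over the three modules
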
SelfAttention

Bases: Module

Self-Attention module.

Parameters:

- hidden_dim (int): The dimension of the input and output tensors. Required.

Attributes:

- query (Linear): Linear layer for computing the query tensor.
- key (Linear): Linear layer for computing the key tensor.
- value (Linear): Linear layer for computing the value tensor.
- softmax (Softmax): Softmax function for computing attention weights.

Source code in vambn/modelling/models/hivae/shared.py
class SelfAttention(nn.Module):
    """
    Self-Attention module.

    Args:
        hidden_dim (int): The dimension of the input and output tensors.

    Attributes:
        query (nn.Linear): Linear layer for computing the query tensor.
        key (nn.Linear): Linear layer for computing the key tensor.
        value (nn.Linear): Linear layer for computing the value tensor.
        softmax (nn.Softmax): Softmax function for computing attention weights.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, z):
        """
        Forward pass of the SelfAttention module. Computes the query, key, and value tensors
        and uses them to compute the attention weights. The attention weights are then used
        to compute the attended values.

        Args:
            z (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_dim).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, hidden_dim).
        """
        q = self.query(z)
        k = self.key(z)
        v = self.value(z)

        attention_scores = torch.matmul(q, k.transpose(-2, -1))
        attention_weights = self.softmax(attention_scores)

        attended_values = torch.matmul(attention_weights, v)
        return attended_values
forward(z)

Forward pass of the SelfAttention module. Computes the query, key, and value tensors and uses them to compute the attention weights. The attention weights are then used to compute the attended values.

Parameters:

- z (Tensor): Input tensor of shape (batch_size, seq_len, hidden_dim). Required.

Returns:

- Tensor: Output tensor of shape (batch_size, seq_len, hidden_dim).

Source code in vambn/modelling/models/hivae/shared.py
def forward(self, z):
    """
    Forward pass of the SelfAttention module. Computes the query, key, and value tensors
    and uses them to compute the attention weights. The attention weights are then used
    to compute the attended values.

    Args:
        z (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_dim).

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, seq_len, hidden_dim).
    """
    q = self.query(z)
    k = self.key(z)
    v = self.value(z)

    attention_scores = torch.matmul(q, k.transpose(-2, -1))
    attention_weights = self.softmax(attention_scores)

    attended_values = torch.matmul(attention_weights, v)
    return attended_values
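
A small usage sketch with illustrative shapes; note that the scores here are plain (unscaled) dot products between the projected queries and keys:

import torch

attn = SelfAttention(hidden_dim=8)
z = torch.randn(2, 5, 8)   # (batch_size, seq_len, hidden_dim)
out = attn(z)
print(out.shape)           # torch.Size([2, 5, 8])
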
SharedLinearModule

Bases: BaseModule

This module passes the z's through a single shared dense layer that generates the individual outputs for each module using the same weights. Outputs are generated one by one. The assumption is that z shares the same dimensional space across modules. This is the simplest form of modularity.

Parameters:

- z_dim (int): Input dimension of the latent samples z. Required.
- ys_dim (int): Intermediate dimension of the latent samples ys. Required.
- y_dim (int or Tuple[int, ...]): Output dimension of the latent samples y, which can be different for each module. Required.
- n_modules (int): Number of modules. Required.
- module_names (Tuple[str, ...]): Names of the modules. Required.
- mtl_method (Optional[Tuple[str, ...]]): MTL methods used to avoid conflicting gradients. Defaults to None.

Attributes:

- shared_layer (BaseElement): Shared dense layer.
- scaling_layers (ModuleList): List of individual dense layers for each module.

Source code in vambn/modelling/models/hivae/shared.py
class SharedLinearModule(BaseModule):
    """
    This module passes the z's through a single shared dense layer that generates the
    individual outputs for each module using the same weights. Outputs are generated
    one by one. The assumption is that z shares the same dimensional space across
    modules. This is the simplest form of modularity.

    Args:
        z_dim (int): Input dimension of the latent samples z.
        ys_dim (int): Intermediate dimension of the latent samples ys.
        y_dim (int or Tuple[int, ...]): Output dimension of the latent samples y, which can be different for each module.
        n_modules (int): Number of modules.
        module_names (Tuple[str, ...]): Names of the modules.
        mtl_method (Optional[Tuple[str, ...]]): MTL methods used to avoid conflicting gradients. Defaults to None.

    Attributes:
        shared_layer (BaseElement): Shared dense layer.
        scaling_layers (nn.ModuleList): List of individual dense layers for each module.
    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        self.shared_layer = BaseElement(self.z_dim, self.ys_dim)
        self.scaling_layers = nn.ModuleList(
            [
                BaseElement(self.ys_dim, self.z_dim)
                for i in range(self.n_modules)
            ]
        )

    def order_layers(self, module_names: Tuple[str]) -> None:
        """
        Order the layers based on the module names.

        Args:
            module_names (Tuple[str]): Names of the modules.
        """
        prior_map = {name: i for i, name in enumerate(self.module_names)}
        self.scaling_layers = nn.ModuleList(
            [self.scaling_layers[prior_map[name]] for name in module_names]
        )

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Forward pass through the module. Each z is passed through the same shared
        layer (identical weights for every module), and each intermediate output is
        then passed through its module-specific scaling layer to produce the final outputs.

        Args:
            z (Tuple[torch.Tensor, ...]): Input tensor.

        Returns:
            Tuple[torch.Tensor, ...]: Output tensor with individual outputs for each module.
        """
        ys = tuple([self.shared_layer(zi) for zi in z])
        return tuple([self.scaling_layers[i](ysi) for i, ysi in enumerate(ys)])
forward(z)

Forward pass through the module. Each z is passed through the same shared layer (identical weights for every module), and each intermediate output is then passed through its module-specific scaling layer to produce the final outputs.

Parameters:

- z (Tuple[Tensor, ...]): Input tensor. Required.

Returns:

- Tuple[Tensor, ...]: Output tensor with individual outputs for each module.

Source code in vambn/modelling/models/hivae/shared.py
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """
    Forward pass through the module. Each z is passed through the same shared
    layer (identical weights for every module), and each intermediate output is
    then passed through its module-specific scaling layer to produce the final outputs.

    Args:
        z (Tuple[torch.Tensor, ...]): Input tensor.

    Returns:
        Tuple[torch.Tensor, ...]: Output tensor with individual outputs for each module.
    """
    ys = tuple([self.shared_layer(zi) for zi in z])
    return tuple([self.scaling_layers[i](ysi) for i, ysi in enumerate(ys)])
order_layers(module_names)

Order the layers based on the module names.

Parameters:

- module_names (Tuple[str]): Names of the modules. Required.
Source code in vambn/modelling/models/hivae/shared.py
def order_layers(self, module_names: Tuple[str]) -> None:
    """
    Order the layers based on the module names.

    Args:
        module_names (Tuple[str]): Names of the modules.
    """
    prior_map = {name: i for i, name in enumerate(self.module_names)}
    self.scaling_layers = nn.ModuleList(
        [self.scaling_layers[prior_map[name]] for name in module_names]
    )
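
The sketch below (with illustrative dimensions) highlights the weight sharing: every module's z goes through the same shared layer object, so the outputs diverge only after the module-specific scaling layers.

import torch

module = SharedLinearModule(
    z_dim=2, ys_dim=3, y_dim=4, n_modules=2, module_names=("a", "b")
)
zi = torch.randn(8, 2)
y_a, y_b = module((zi, zi))  # feed the same z to both modules
# Both pass through the same shared_layer instance (shared weights); the two
# results differ only because each module has its own scaling layer.
print(y_a.shape, y_b.shape)  # torch.Size([8, 2]) torch.Size([8, 2])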

trainer

BaseTrainer

Bases: Generic[TConfig, TModel, TPredict, TrainerType, TEncoding], ABC

Base class for trainers in the VAMBN2 framework.

Parameters:

- dataset (VambnDataset): The dataset to use for training. Required.
- config (PipelineConfig): The configuration object for the pipeline. Required.
- workers (int): The number of workers to use for data loading. Required.
- checkpoint_path (Path): The path to save checkpoints during training. Required.
- module_name (Optional[str]): The name of the module. Defaults to None.
- experiment_name (Optional[str]): The name of the experiment. Defaults to None.
- force_cpu (bool): Whether to force CPU usage. Defaults to False.

Attributes:

- dataset (VambnDataset): The dataset used for training.
- config (PipelineConfig): The configuration object for the pipeline.
- workers (int): The number of workers used for data loading.
- checkpoint_path (Path): The path to save checkpoints during training.
- model (Optional[TModel]): The model used for training.
- model_config (Optional[TConfig]): The configuration object for the model.
- module_name (Optional[str]): The name of the module.
- experiment_name (Optional[str]): The name of the experiment.
- type (str): The type of the trainer.
- device (device): The device used for training.
- use_mtl (bool): Whether to use multi-task learning.
- use_gan (bool): Whether to use generative adversarial networks.

Source code in vambn/modelling/models/hivae/trainer.py
class BaseTrainer(
    Generic[TConfig, TModel, TPredict, TrainerType, TEncoding], ABC
):
    """
    Base class for trainers in the VAMBN2 framework.

    Args:
        dataset (VambnDataset): The dataset to use for training.
        config (PipelineConfig): The configuration object for the pipeline.
        workers (int): The number of workers to use for data loading.
        checkpoint_path (Path): The path to save checkpoints during training.
        module_name (Optional[str], optional): The name of the module. Defaults to None.
        experiment_name (Optional[str], optional): The name of the experiment. Defaults to None.
        force_cpu (bool, optional): Whether to force CPU usage. Defaults to False.

    Attributes:
        dataset (VambnDataset): The dataset used for training.
        config (PipelineConfig): The configuration object for the pipeline.
        workers (int): The number of workers used for data loading.
        checkpoint_path (Path): The path to save checkpoints during training.
        model (Optional[TModel]): The model used for training.
        model_config (Optional[TConfig]): The configuration object for the model.
        module_name (Optional[str]): The name of the module.
        experiment_name (Optional[str]): The name of the experiment.
        type (str): The type of the trainer.
        device (torch.device): The device used for training.
        use_mtl (bool): Whether to use multi-task learning.
        use_gan (bool): Whether to use generative adversarial networks.
    """

    def __init__(
        self,
        dataset: VambnDataset,
        config: PipelineConfig,
        workers: int,
        checkpoint_path: Path,
        module_name: Optional[str] = None,
        experiment_name: Optional[str] = None,
        force_cpu: bool = False,
    ):
        self.dataset = dataset
        self.config = config
        self.workers = workers
        self.checkpoint_path = checkpoint_path

        self.model = None
        self.model_config = None
        self.module_name = module_name
        self.experiment_name = experiment_name
        self.type = "base"
        self.device = (
            torch.device("cuda")
            if torch.cuda.is_available() and not force_cpu
            else torch.device("cpu")
        )
        self.use_mtl = config.training.use_mtl
        self.use_gan = config.training.with_gan
        logger.info(f"Use {self.device} for training")

    @cached_property
    def run_base(self) -> str:
        """
        Generates a base name for the training run, used e.g. for MLflow run names.

        Returns:
            str: The base name for the training run.
        """
        base = f"{self.type}_{'wmtl' if self.use_mtl else 'womtl'}_{'wgan' if self.use_gan else 'wogan'}"
        if self.module_name is not None:
            base += f"_{self.module_name}"
        return base

    def get_dataloader(
        self, dataset: VambnDataset, batch_size: int, shuffle: bool
    ) -> DataLoader:
        """
        Get a DataLoader object for the given dataset.

        Args:
            dataset (VambnDataset): The dataset to use.
            batch_size (int): The batch size.
            shuffle (bool): Whether to shuffle the data.

        Returns:
            DataLoader: The DataLoader object.

        Raises:
            ValueError: If the dataset is empty.
        """
        # Set the number of workers to 0
        self.workers = 0

        # Create and return the DataLoader object
        return DataLoader(
            dataset.get_iter_dataset(self.module_name)
            if self.module_name is not None
            else dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.workers,
            # collate_fn=self.custom_collate,
            persistent_workers=True if self.workers > 0 else False,
            pin_memory=True if self.device.type == "cuda" else False,
            drop_last=True
            if shuffle and len(dataset) % batch_size <= 3
            else False,
        )

    def _set_device(self):
        """
        Sets the device for the model and dataset.

        This method sets the device for the model and dataset to the device specified in the `device` attribute.

        Returns:
            None
        """
        if self.model is not None:
            self.model.to(self.device)

    def multiple_objective_selection(
        self, study: optuna.Study, corr_weight: float = 0.8
    ) -> optuna.trial.FrozenTrial:
        """
        Selects the best trial from a given Optuna study based on multiple objectives.

        Args:
            study (optuna.Study): The Optuna study object.
            corr_weight (float, optional): The weight for the relative correlation error. Defaults to 0.8.

        Returns:
            optuna.trial.FrozenTrial: The best trial.

        Raises:
            ValueError: If no trials are found in the study.
        """

        # Get the best trials from the study
        best_trials = study.best_trials

        if not best_trials:
            raise ValueError("No trials found in the study.")

        # Calculate the weighted sum of relative correlation error and loss for each trial
        weighted_scores = []
        for trial in best_trials:
            corr_error = trial.values[1]
            loss = trial.values[0]
            weighted_score = corr_weight * corr_error + (1 - corr_weight) * loss
            weighted_scores.append(weighted_score)

        # Find the index of the trial with the minimum weighted score
        best_index = weighted_scores.index(min(weighted_scores))

        # Select the best trial based on the weighted score
        best_trial = best_trials[best_index]
        logger.info(f"Selected trial: {best_trial.number}")

        return best_trial

    def hyperopt(self, study: optuna.Study, num_trials: int) -> Hyperparameters:
        """
        Perform hyperparameter optimization using Optuna.

        Args:
            study (optuna.Study): The Optuna study object.
            num_trials (int): The number of trials to run.

        Returns:
            Hyperparameters: The best hyperparameters found during optimization.

        Raises:
            ValueError: If no trials are found in the study.
        """
        with mlflow.start_run(run_name=f"{self.run_base}_hyperopt"):
            # Optimize the study
            study.optimize(self._objective, n_trials=num_trials)

            # Get the best trial parameters
            if self.config.optimization.use_relative_correlation_error_for_optimization:
                trial = self.multiple_objective_selection(study)
            else:
                trial = study.best_trial

            # Extract the best hyperparameters
            params = trial.params
            best_epoch = trial.user_attrs.get("best_epoch", None)
            params["epochs"] = best_epoch

            # Process the parameters
            params["batch_size"] = 2 ** params.pop("batch_size_n")
            if "hidden_dim_s" in params:
                params["dim_s"] = params.pop("hidden_dim_s")
            else:
                matching_keys = [k for k in params if "hidden_dim_s" in k]
                dim_s = {}
                for key in matching_keys:
                    module_name = key.split("_")[-1]
                    dim_s[module_name] = params.pop(key)
                params["dim_s"] = dim_s
            if "hidden_dim_y" in params:
                params["dim_y"] = params.pop("hidden_dim_y")
            else:
                matching_keys = [k for k in params if "hidden_dim_y_" in k]
                dim_y = {}
                for key in matching_keys:
                    module_name = key.split("_")[-1]
                    dim_y[module_name] = params.pop(key)
                params["dim_y"] = dim_y
            if "hidden_dim_ys" in params:
                params["dim_ys"] = params.pop("hidden_dim_ys")
            params["dim_z"] = params.pop("hidden_dim_z")
            if "learning_rate" not in params:
                matching_keys = [k for k in params if "learning_rate" in k]
                learning_rate = {}
                for key in matching_keys:
                    module_name = key.split("_")[-1]
                    learning_rate[module_name] = params.pop(key)
                params["learning_rate"] = learning_rate

            # Create Hyperparameters object
            hyperparameters = Hyperparameters(dropout=0.1, **params)

            # Log hyperparameters to MLflow
            mlflow.log_params(hyperparameters.__dict__)

        return hyperparameters

    def cleanup_checkpoints(self):
        """
        Cleans up the checkpoints directory.

        This method deletes the entire checkpoints directory specified by `self.checkpoint_path`.

        Returns:
            None
        """
        delete_directory(self.checkpoint_path)

    def optimize_model(self, model: Optional[TModel] = None) -> TModel:
        """
        Optimizes the model using a specified optimization function.

        Args:
            model (Optional[TModel], optional): The model to optimize. If None, the method optimizes self.model. Defaults to None.

        Returns:
            TModel: The optimized model.

        Notes:
            - The optimization function is specified by the opt_func variable.
            - The opt_func function should take a model as input and return an optimized model.
            - If model is None, the method optimizes self.model.
            - If model is not None, the method optimizes the specified model.

        Raises:
            TypeError: If the model is not of type TModel.

        """
        # Define the optimization function
        opt_func = partial(torch.compile, mode="reduce-overhead")
        # opt_func = lambda model: model

        if model is None:
            # Optimize self.model
            self.model = opt_func(model=self.model)
            return self.model
        else:
            return opt_func(model=model)

    def _add_encs(
        self,
        encodings_s: Dict[str, torch.Tensor],
        encodings_z: Dict[str, torch.Tensor],
        meta_enc: Dict[str, np.ndarray],
        decoder_output: DecoderOutput,
    ) -> None:
        """
        Helper function to prepare the encodings from the decoder output.

        This function takes the decoder output and adds the s and z encodings to the respective dictionaries.
        It also updates the meta encoding dictionary with the s and z encodings.

        Args:
            encodings_s (Dict[str, torch.Tensor]): Dictionary of s encodings per module.
            encodings_z (Dict[str, torch.Tensor]): Dictionary of z encodings per module.
            meta_enc (Dict[str, np.ndarray]): Entire meta encoding.
            decoder_output (DecoderOutput): The decoder output.

        Returns:
            None
        """
        dict_name = decoder_output.output_name

        # Add s and z encodings to the respective dictionaries
        if dict_name in encodings_s:
            encodings_s[dict_name].append(decoder_output.enc_s)
            encodings_z[dict_name].append(decoder_output.enc_z)
        else:
            encodings_s[dict_name] = [decoder_output.enc_s]
            encodings_z[dict_name] = [decoder_output.enc_z]

        # Update meta encoding dictionary with s and z encodings
        if decoder_output.enc_s.shape[1] > 1:
            s_dist = torch.argmax(decoder_output.enc_s, dim=1)
            s_name = f"{decoder_output.output_name}_s"
            if s_name in meta_enc:
                meta_enc[s_name] = np.concatenate(
                    [meta_enc[s_name], s_dist.numpy()]
                )
            else:
                meta_enc[s_name] = s_dist.numpy()
        for i in range(decoder_output.enc_z.shape[1]):
            z_name = f"{decoder_output.output_name}_z{i}"
            if z_name in meta_enc:
                meta_enc[z_name] = np.concatenate(
                    [meta_enc[z_name], decoder_output.enc_z[:, i].numpy()]
                )
            else:
                meta_enc[z_name] = decoder_output.enc_z[:, i].numpy()

    def save_model(self, path: Path) -> None:
        """
        Save the model and its configuration to the specified path.

        Args:
            path (Path): The path to save the model.

        Returns:
            None
        """
        self.save_model_config(path / "config.pkl")
        torch.save(self.model.state_dict(), path / "model.bin")

    def load_model(self, path: Path) -> None:
        """
        Load the model and its configuration from the specified path.

        Args:
            path (Path): The path to load the model from.

        Returns:
            None
        """
        self.read_model_config(path / "config.pkl")
        self.model = self.init_model(self.model_config)
        state_dict = torch.load(path / "model.bin")
        try:
            self.model.load_state_dict(state_dict)
        except RuntimeError:
            state_dict = OrderedDict(
                {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
            )
            self.model.load_state_dict(state_dict)

    @abstractmethod
    def save_model_config(self, path: Path) -> None:
        """
        Save the model configuration to the specified path.

        Args:
            path (Path): The path to save the model configuration.

        Returns:
            None
        """
        pass

    @abstractmethod
    def read_model_config(self, path: Path) -> None:
        """
        Read the model configuration from the specified path.

        Args:
            path (Path): The path to read the model configuration from.

        Returns:
            None
        """
        pass

    def cv_generator(
        self, splits: int, seed: int = 42
    ) -> Generator[Tuple[VambnDataset, VambnDataset], None, None]:
        """
        Generates train and validation datasets for cross-validation.

        Args:
            splits (int): Number of splits.
            seed (int, optional): Seed for split. Defaults to 42.

        Raises:
            ValueError: Number of splits must be greater than 0.

        Yields:
            Generator[Tuple[VambnDataset, VambnDataset], None, None]: A generator that yields tuples of train and validation datasets.
        """
        if splits < 1:
            raise ValueError("splits must be greater than 0")

        data_idx = np.arange(len(self.dataset))
        if splits == 1:
            train_idx, val_idx = train_test_split(
                data_idx, test_size=0.2, random_state=seed
            )
            yield (
                self.dataset.subset_by_idx(train_idx),
                self.dataset.subset_by_idx(val_idx),
            )
        else:
            cv = KFold(n_splits=splits, shuffle=True, random_state=seed)
            for train_idx, val_idx in cv.split(data_idx):
                if len(train_idx) == 0 or len(val_idx) == 0:
                    raise Exception("Empty split")
                yield (
                    self.dataset.subset_by_idx(train_idx),
                    self.dataset.subset_by_idx(val_idx),
                )

    @staticmethod
    def __lstm_forward_hook(module, input, output):
        # output is a tuple (output_tensor, (h_n, c_n))
        return output[0]  # We only want the output tensor for the next layer

    @staticmethod
    def setup_y_layer(
        z_dim: int, y_dim: int, dropout: float, n_layers: Optional[int] = None
    ) -> torch.nn.Module:
        if n_layers is None:
            return torch.nn.Sequential(
                # torch.nn.LayerNorm(z_dim),
                ModifiedLinear(z_dim, y_dim),
                torch.nn.Dropout(dropout),
            )
        else:
            layers = [
                # torch.nn.LayerNorm(z_dim),
                torch.nn.LSTM(
                    input_size=z_dim,
                    hidden_size=y_dim,
                    num_layers=n_layers,
                    dropout=dropout if n_layers > 1 else 0.0,
                    batch_first=True,
                ),
            ]

            layers[-1].register_forward_hook(BaseTrainer.__lstm_forward_hook)

            layers.append(torch.nn.Dropout(dropout))

            return torch.nn.Sequential(*layers)

    @abstractmethod
    def init_model(self, config: TConfig) -> TModel:
        pass

    @abstractmethod
    def train(self, best_parameters: Hyperparameters) -> TModel:
        pass

    def predict(self, dl: DataLoader | None) -> TPredict:
        if dl is None:
            data = self.dataset.get_iter_dataset(self.module_name)
            dl = self.get_dataloader(data, batch_size=128, shuffle=False)

        with torch.no_grad():
            return self.model.predict(dl)

    def decode(
        self,
        encoding: HivaeEncoding | ModularHivaeEncoding,
        use_mode: bool = True,
    ) -> Dict[str, torch.Tensor]:
        self.model.eval()
        self.model.decoding = use_mode

        return {self.module_name: self.model.decode(encoding)}

    def save_trainer(self, path: Path) -> None:
        path.parent.mkdir(parents=True, exist_ok=True)
        self.save_model(path)
        self.model = None
        torch.save(self, path / "trainer.pkl", pickle_module=dill)

        with (path / "pipeline-config.yml").open("w") as f:
            f.write(yaml.dump(self.config.__dict__))

    @classmethod
    def load_trainer(cls, path: Path) -> TrainerType:
        trainer = torch.load(path / "trainer.pkl", pickle_module=dill)
        trainer.load_model(path)
        return trainer

    @abstractmethod
    def hyperparameters(self, trial: optuna.Trial) -> Hyperparameters:
        pass

    @abstractmethod
    def _objective(self, trial: optuna.Trial) -> float | Tuple[float, float]:
        pass

    def __handle_continous_data(
        self,
        original: pd.Series,
        decoded: pd.Series,
        output_file: Path,
        dtype: str,
    ) -> Tuple[float, float, float, float]:
        df = pd.DataFrame({"original": original, "decoded": decoded}).melt()
        sns.boxplot(x="variable", y="value", data=df)
        plt.savefig(output_file, dpi=300)
        plt.close()

        dec_tensor = torch.tensor(decoded.values)
        orig_tensor = torch.tensor(original.values)
        if orig_tensor.isnan().any() or orig_tensor.isinf().any():
            raise Exception("NaN values in original tensor")

        if dec_tensor.isnan().any() or dec_tensor.isinf().any():
            logger.warning("NaN values in decoded tensor")
            logger.warning(
                f"Found nan values: {dec_tensor.isnan().any()}, {dec_tensor.isnan().sum()}"
            )
            logger.warning(
                f"Found inf values: {dec_tensor.isinf().any()}, {dec_tensor.isinf().sum()}"
            )
            dec_tensor = dec_tensor.nan_to_num()
            dec_tensor[dec_tensor == float("inf")] = dec_tensor[
                dec_tensor != float("inf")
            ].max()
            logger.info(
                f"After replacing: {dec_tensor.isnan().any()}, {dec_tensor.isnan().sum()}"
            )
            logger.info(
                f"After replacing: {dec_tensor.isinf().any()}, {dec_tensor.isinf().sum()}"
            )

        error = float(
            nrmse(dec_tensor, orig_tensor, torch.ones_like(dec_tensor))
        )
        jsd = jensen_shannon_distance(dec_tensor, orig_tensor, dtype)
        try:
            statistic, pval = pearsonr(
                decoded.values.tolist(), original.values.tolist()
            )
        except ValueError:
            statistic, pval = 0.0, 1.0
            logger.warning("ValueError in pearsonr")
        return error, jsd, statistic, pval

    def handle_pos_data(
        self, original: pd.Series, decoded: pd.Series, output_file: Path
    ) -> Tuple[float, float, float, float]:
        return self.__handle_continous_data(
            original, decoded, output_file, "pos"
        )

    def handle_real_data(
        self, original: pd.Series, decoded: pd.Series, output_file: Path
    ) -> Tuple[float, float, float, float]:
        return self.__handle_continous_data(
            original, decoded, output_file, "real"
        )

    def handle_count_data(
        self, original: pd.Series, decoded: pd.Series, output_file: Path
    ) -> Tuple[float, float, float, float]:
        df = pd.DataFrame({"original": original, "decoded": decoded}).melt()
        sns.histplot(
            data=df, x="value", hue="variable", stat="probability", bins=30
        )
        plt.savefig(output_file, dpi=300)
        plt.close()

        dec_tensor = torch.tensor(decoded.values)
        orig_tensor = torch.tensor(original.values)

        nrmse_val = nrmse(dec_tensor, orig_tensor, torch.ones_like(dec_tensor))
        jsd = jensen_shannon_distance(dec_tensor, orig_tensor, "count")
        statistic, pval = pearsonr(
            decoded.astype(int).values, original.astype(int).values
        )
        return nrmse_val, jsd, statistic, pval

    def handle_categorical_data(
        self, original: pd.Series, decoded: pd.Series, output_file: Path
    ) -> Tuple[float, float, float, float]:
        df = pd.DataFrame({"original": original, "decoded": decoded}).melt()
        sns.histplot(data=df, x="value", hue="variable", stat="probability")
        plt.savefig(output_file, dpi=300)
        plt.close()

        dec_tensor = torch.tensor(decoded.values)
        orig_tensor = torch.tensor(original.values)

        error = accuracy(dec_tensor, orig_tensor, torch.ones_like(dec_tensor))
        jsd = jensen_shannon_distance(dec_tensor, orig_tensor, "cat")
        statistic, pval = spearmanr(
            decoded.values.tolist(), original.values.tolist()
        )
        return error, jsd, statistic, pval

    @abstractmethod
    def process_encodings(self, predictions: TPredict) -> TEncoding:
        raise NotImplementedError

    @staticmethod
    def reverse_scale(x: Tensor, variable_types: VarTypes) -> Tensor:
        copied_x = x.clone()
        for i, var_type in enumerate(variable_types):
            if var_type.data_type in ["real", "pos", "truncate_norm", "gamma"]:
                copied_x[:, i] = var_type.reverse_scale(copied_x[:, i])
        return copied_x

    def evaluate(self, dl: DataLoader, output_path: Path):
        output_path.mkdir(parents=True, exist_ok=True)

        predictions = self.predict(dl)
        encoding = self.process_encodings(predictions)

        data_output_path = output_path / "data_outputs"
        data_output_path.mkdir(exist_ok=True, parents=True)
        encoding.save_meta_enc(data_output_path / "meta_enc.csv")

        modules = (
            (self.module_name,)
            if self.module_name is not None
            else self.dataset.module_names
        )
        overall_metrics = [None] * len(modules)
        for k, module_name in enumerate(modules):
            submodules = self.dataset.get_modules(module_name)
            data, mask = self.dataset.get_longitudinal_data(module_name)
            sampled_data = encoding.get_samples(module_name)
            if sampled_data.ndim == 2 and data.ndim == 3:
                sampled_data = sampled_data.unsqueeze(1)

            if data.shape != sampled_data.shape:
                raise Exception("Data and sampled data have different shapes")

            if data.shape != mask.shape:
                raise Exception("Data and mask have different shapes")

            if data.ndim == 2:
                data = data.unsqueeze(1)
                sampled_data = sampled_data.unsqueeze(1)
                mask = mask.unsqueeze(1)

            num_timepoints = data.shape[1]
            decoded_data = [None] * num_timepoints
            original_data = [None] * num_timepoints
            mask_data = [None] * num_timepoints
            colnames = [
                re.sub("_VIS[0-9]+", "", c) for c in submodules[0].columns
            ]

            for time_point in range(num_timepoints):
                data_df = pd.DataFrame(
                    self.reverse_scale(
                        data[:, time_point, :], submodules[0].variable_types
                    ),
                    columns=colnames,
                )
                data_df["SUBJID"] = self.dataset.subj
                data_df["VISIT"] = time_point + 1
                data_df.set_index(["SUBJID", "VISIT"], inplace=True)
                original_data[time_point] = data_df

                sampled_data_df = pd.DataFrame(
                    sampled_data[:, time_point, :],
                    columns=colnames,
                )
                sampled_data_df["SUBJID"] = self.dataset.subj
                sampled_data_df["VISIT"] = time_point + 1
                sampled_data_df.set_index(["SUBJID", "VISIT"], inplace=True)
                decoded_data[time_point] = sampled_data_df

                mask_df = pd.DataFrame(mask[:, time_point, :], columns=colnames)
                mask_df["SUBJID"] = self.dataset.subj
                mask_df["VISIT"] = time_point + 1
                mask_df.set_index(["SUBJID", "VISIT"], inplace=True)
                mask_data[time_point] = mask_df

            original_data = pd.concat(original_data)
            decoded_data = pd.concat(decoded_data)

            # assert that we have the same visit ids in the original and decoded data
            assert original_data.index.equals(decoded_data.index), (
                f"Original and decoded data have different indices: "
                f"{original_data.index} != {decoded_data.index}"
            )
            if decoded_data.isna().any().any():
                raise Exception("NaN in decoded data")
            mask_data = pd.concat(mask_data)

            decoded_data.to_csv(data_output_path / f"{module_name}_decoded.csv")
            mask_data.to_csv(data_output_path / f"{module_name}_mask.csv")
            original_data.to_csv(
                data_output_path / f"{module_name}_original.csv"
            )

            # Calculate metrics per column
            error = [None] * len(colnames)
            jsd = [None] * len(colnames)
            correlation_stat = [None] * len(colnames)
            correlation_pval = [None] * len(colnames)

            distribution_path = output_path / "distributions"
            distribution_path.mkdir(parents=True, exist_ok=True)

            for i, column in enumerate(colnames):
                orig_col = original_data[column]
                decoded_col = decoded_data[column]
                mask_col = mask_data[column]

                orig_avail = orig_col[mask_col == 1]
                decoded_avail = decoded_col[mask_col == 1]

                col_type = submodules[0].variable_types[i].data_type
                if col_type in ["real", "pos", "truncate_norm", "gamma"]:
                    (
                        error[i],
                        jsd[i],
                        correlation_stat[i],
                        correlation_pval[i],
                    ) = self.handle_real_data(
                        orig_avail,
                        decoded_avail,
                        distribution_path / f"{module_name}_{column}.png",
                    )
                elif col_type == "count":
                    (
                        error[i],
                        jsd[i],
                        correlation_stat[i],
                        correlation_pval[i],
                    ) = self.handle_count_data(
                        orig_avail,
                        decoded_avail,
                        distribution_path / f"{module_name}_{column}.png",
                    )
                elif col_type == "cat":
                    (
                        error[i],
                        jsd[i],
                        correlation_stat[i],
                        correlation_pval[i],
                    ) = self.handle_categorical_data(
                        orig_avail,
                        decoded_avail,
                        distribution_path / f"{module_name}_{column}.png",
                    )
                else:
                    raise ValueError(f"Unknown data type {col_type}")

            overall_metrics[k] = pd.DataFrame(
                {
                    "module_name": [module_name] * len(colnames),
                    "column": colnames,
                    "error": [float(x) for x in error],
                    "jsd": jsd,
                    "correlation_stat": correlation_stat,
                    "correlation_pval": correlation_pval,
                }
            )

        overall_metrics = pd.concat(overall_metrics)
        overall_metrics.to_csv(output_path / "overall_metrics.csv", index=False)
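BaseTrainer is abstract: a concrete trainer must supply the model construction, training loop, hyperparameter sampling, Optuna objective, encoding post-processing, and config (de)serialization. Below is a minimal sketch of that contract, assuming BaseTrainer and Hyperparameters are imported from vambn.modelling.models.hivae.trainer; the class name and method bodies are hypothetical, and pickle is shown only as one plausible way to produce the config.pkl files. Real subclasses such as ModularTrainer (documented further down) fill these methods in.

import pickle
from pathlib import Path

import optuna


class MyTrainer(BaseTrainer):  # hypothetical concrete subclass
    def init_model(self, config):
        # Build and return the model described by `config`.
        raise NotImplementedError

    def train(self, best_parameters: Hyperparameters):
        # Fit self.model with the chosen hyperparameters and return it.
        raise NotImplementedError

    def hyperparameters(self, trial: optuna.Trial) -> Hyperparameters:
        # Sample a Hyperparameters object from the Optuna trial.
        raise NotImplementedError

    def _objective(self, trial: optuna.Trial) -> float:
        # Train with the trial's hyperparameters and return the validation loss.
        raise NotImplementedError

    def process_encodings(self, predictions):
        # Turn raw model predictions into this trainer's encoding type.
        raise NotImplementedError

    def save_model_config(self, path: Path) -> None:
        with path.open("wb") as f:
            pickle.dump(self.model_config, f)

    def read_model_config(self, path: Path) -> None:
        with path.open("rb") as f:
            self.model_config = pickle.load(f)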
run_base: str cached property

Generates a base name for the training run, used e.g. for MLflow run names.

Returns:

Name Type Description
str str

The base name for the training run.

cleanup_checkpoints()

Cleans up the checkpoints directory.

This method deletes the entire checkpoints directory specified by self.checkpoint_path.

Returns:

Type Description

None

Source code in vambn/modelling/models/hivae/trainer.py
def cleanup_checkpoints(self):
    """
    Cleans up the checkpoints directory.

    This method deletes the entire checkpoints directory specified by `self.checkpoint_path`.

    Returns:
        None
    """
    delete_directory(self.checkpoint_path)
cv_generator(splits, seed=42)

Generates train and validation datasets for cross-validation.

Parameters:

Name Type Description Default
splits int

Number of splits.

required
seed int

Seed for split. Defaults to 42.

42

Raises:

Type Description
ValueError

Number of splits must be greater than 0.

Yields:

Type Description
VambnDataset

Generator[Tuple[VambnDataset, VambnDataset], None, None]: A generator that yields tuples of train and validation datasets.

Source code in vambn/modelling/models/hivae/trainer.py
def cv_generator(
    self, splits: int, seed: int = 42
) -> Generator[Tuple[VambnDataset, VambnDataset], None, None]:
    """
    Generates train and validation datasets for cross-validation.

    Args:
        splits (int): Number of splits.
        seed (int, optional): Seed for split. Defaults to 42.

    Raises:
        ValueError: Number of splits must be greater than 0.

    Yields:
        Generator[Tuple[VambnDataset, VambnDataset], None, None]: A generator that yields tuples of train and validation datasets.
    """
    if splits < 1:
        raise ValueError("splits must be greater than 0")

    data_idx = np.arange(len(self.dataset))
    if splits == 1:
        train_idx, val_idx = train_test_split(
            data_idx, test_size=0.2, random_state=seed
        )
        yield (
            self.dataset.subset_by_idx(train_idx),
            self.dataset.subset_by_idx(val_idx),
        )
    else:
        cv = KFold(n_splits=splits, shuffle=True, random_state=seed)
        for train_idx, val_idx in cv.split(data_idx):
            if len(train_idx) == 0 or len(val_idx) == 0:
                raise Exception("Empty split")
            yield (
                self.dataset.subset_by_idx(train_idx),
                self.dataset.subset_by_idx(val_idx),
            )
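A possible usage sketch, assuming `trainer` is an already constructed concrete trainer: iterate the generator to run k-fold cross-validation, building loaders for each split (with splits=1 a single 80/20 split is yielded instead).

for fold, (train_ds, val_ds) in enumerate(trainer.cv_generator(splits=5, seed=42)):
    train_dl = trainer.get_dataloader(train_ds, batch_size=64, shuffle=True)
    val_dl = trainer.get_dataloader(val_ds, batch_size=64, shuffle=False)
    # ... fit on train_dl, evaluate on val_dl ...
    print(f"Fold {fold}: {len(train_ds)} train / {len(val_ds)} validation samples")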
get_dataloader(dataset, batch_size, shuffle)

Get a DataLoader object for the given dataset.

Parameters:

Name Type Description Default
dataset VambnDataset

The dataset to use.

required
batch_size int

The batch size.

required
shuffle bool

Whether to shuffle the data.

required

Returns:

Name Type Description
DataLoader DataLoader

The DataLoader object.

Raises:

Type Description
ValueError

If the dataset is empty.

Source code in vambn/modelling/models/hivae/trainer.py
def get_dataloader(
    self, dataset: VambnDataset, batch_size: int, shuffle: bool
) -> DataLoader:
    """
    Get a DataLoader object for the given dataset.

    Args:
        dataset (VambnDataset): The dataset to use.
        batch_size (int): The batch size.
        shuffle (bool): Whether to shuffle the data.

    Returns:
        DataLoader: The DataLoader object.

    Raises:
        ValueError: If the dataset is empty.
    """
    # Set the number of workers to 0
    self.workers = 0

    # Create and return the DataLoader object
    return DataLoader(
        dataset.get_iter_dataset(self.module_name)
        if self.module_name is not None
        else dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=self.workers,
        # collate_fn=self.custom_collate,
        persistent_workers=True if self.workers > 0 else False,
        pin_memory=True if self.device.type == "cuda" else False,
        drop_last=True
        if shuffle and len(dataset) % batch_size <= 3
        else False,
    )
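The drop_last rule above discards the trailing batch only while shuffling (i.e. during training) and only when that batch would contain at most three samples. A small stand-alone illustration of the same condition:

def would_drop_last(n_samples: int, batch_size: int, shuffle: bool) -> bool:
    # Mirrors the drop_last condition used in get_dataloader above.
    return shuffle and n_samples % batch_size <= 3

print(would_drop_last(130, 64, shuffle=True))   # True: the last batch would hold only 2 samples
print(would_drop_last(150, 64, shuffle=True))   # False: the last batch holds 22 samples
print(would_drop_last(130, 64, shuffle=False))  # False: evaluation keeps every sample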
hyperopt(study, num_trials)

Perform hyperparameter optimization using Optuna.

Parameters:

Name Type Description Default
study Study

The Optuna study object.

required
num_trials int

The number of trials to run.

required

Returns:

Name Type Description
Hyperparameters Hyperparameters

The best hyperparameters found during optimization.

Raises:

Type Description
ValueError

If no trials are found in the study.

Source code in vambn/modelling/models/hivae/trainer.py
def hyperopt(self, study: optuna.Study, num_trials: int) -> Hyperparameters:
    """
    Perform hyperparameter optimization using Optuna.

    Args:
        study (optuna.Study): The Optuna study object.
        num_trials (int): The number of trials to run.

    Returns:
        Hyperparameters: The best hyperparameters found during optimization.

    Raises:
        ValueError: If no trials are found in the study.
    """
    with mlflow.start_run(run_name=f"{self.run_base}_hyperopt"):
        # Optimize the study
        study.optimize(self._objective, n_trials=num_trials)

        # Get the best trial parameters
        if self.config.optimization.use_relative_correlation_error_for_optimization:
            trial = self.multiple_objective_selection(study)
        else:
            trial = study.best_trial

        # Extract the best hyperparameters
        params = trial.params
        best_epoch = trial.user_attrs.get("best_epoch", None)
        params["epochs"] = best_epoch

        # Process the parameters
        params["batch_size"] = 2 ** params.pop("batch_size_n")
        if "hidden_dim_s" in params:
            params["dim_s"] = params.pop("hidden_dim_s")
        else:
            matching_keys = [k for k in params if "hidden_dim_s" in k]
            dim_s = {}
            for key in matching_keys:
                module_name = key.split("_")[-1]
                dim_s[module_name] = params.pop(key)
            params["dim_s"] = dim_s
        if "hidden_dim_y" in params:
            params["dim_y"] = params.pop("hidden_dim_y")
        else:
            matching_keys = [k for k in params if "hidden_dim_y_" in k]
            dim_y = {}
            for key in matching_keys:
                module_name = key.split("_")[-1]
                dim_y[module_name] = params.pop(key)
            params["dim_y"] = dim_y
        if "hidden_dim_ys" in params:
            params["dim_ys"] = params.pop("hidden_dim_ys")
        params["dim_z"] = params.pop("hidden_dim_z")
        if "learning_rate" not in params:
            matching_keys = [k for k in params if "learning_rate" in k]
            learning_rate = {}
            for key in matching_keys:
                module_name = key.split("_")[-1]
                learning_rate[module_name] = params.pop(key)
            params["learning_rate"] = learning_rate

        # Create Hyperparameters object
        hyperparameters = Hyperparameters(dropout=0.1, **params)

        # Log hyperparameters to MLflow
        mlflow.log_params(hyperparameters.__dict__)

    return hyperparameters
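A possible driver sketch, assuming `trainer` is a concrete trainer whose `_objective` returns a `(loss, relative correlation error)` tuple because the relative-correlation option is enabled in the pipeline config; the study setup, trial count, and output path are illustrative.

from pathlib import Path

import optuna

study = optuna.create_study(
    study_name=f"{trainer.run_base}_hyperopt",
    directions=["minimize", "minimize"],  # loss and relative correlation error
)
best = trainer.hyperopt(study, num_trials=50)
best.write_to_json(Path("output/best_hyperparameters.json"))
model = trainer.train(best)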
load_model(path)

Load the model and its configuration from the specified path.

Parameters:

Name Type Description Default
path Path

The path to load the model from.

required

Returns:

Type Description
None

None

Source code in vambn/modelling/models/hivae/trainer.py
def load_model(self, path: Path) -> None:
    """
    Load the model and its configuration from the specified path.

    Args:
        path (Path): The path to load the model from.

    Returns:
        None
    """
    self.read_model_config(path / "config.pkl")
    self.model = self.init_model(self.model_config)
    state_dict = torch.load(path / "model.bin")
    try:
        self.model.load_state_dict(state_dict)
    except RuntimeError:
        state_dict = OrderedDict(
            {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
        )
        self.model.load_state_dict(state_dict)
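The RuntimeError fallback above exists because checkpoints written from a torch.compile'd model carry an `_orig_mod.` prefix on every parameter name; stripping it lets the weights load into the uncompiled module. In isolation (illustrative key names):

from collections import OrderedDict

compiled_keys = ["_orig_mod.encoder.weight", "_orig_mod.encoder.bias"]
plain_keys = OrderedDict((k.replace("_orig_mod.", ""), None) for k in compiled_keys)
print(list(plain_keys))  # ['encoder.weight', 'encoder.bias']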
multiple_objective_selection(study, corr_weight=0.8)

Selects the best trial from a given Optuna study based on multiple objectives.

Parameters:

Name Type Description Default
study Study

The Optuna study object.

required
corr_weight float

The weight for the relative correlation error. Defaults to 0.8.

0.8

Returns:

Type Description
FrozenTrial

optuna.trial.FrozenTrial: The best trial.

Raises:

Type Description
ValueError

If no trials are found in the study.

Source code in vambn/modelling/models/hivae/trainer.py
def multiple_objective_selection(
    self, study: optuna.Study, corr_weight: float = 0.8
) -> optuna.trial.FrozenTrial:
    """
    Selects the best trial from a given Optuna study based on multiple objectives.

    Args:
        study (optuna.Study): The Optuna study object.
        corr_weight (float, optional): The weight for the relative correlation error. Defaults to 0.8.

    Returns:
        optuna.trial.FrozenTrial: The best trial.

    Raises:
        ValueError: If no trials are found in the study.
    """

    # Get the best trials from the study
    best_trials = study.best_trials

    if not best_trials:
        raise ValueError("No trials found in the study.")

    # Calculate the weighted sum of relative correlation error and loss for each trial
    weighted_scores = []
    for trial in best_trials:
        corr_error = trial.values[1]
        loss = trial.values[0]
        weighted_score = corr_weight * corr_error + (1 - corr_weight) * loss
        weighted_scores.append(weighted_score)

    # Find the index of the trial with the minimum weighted score
    best_index = weighted_scores.index(min(weighted_scores))

    # Select the best trial based on the weighted score
    best_trial = best_trials[best_index]
    logger.info(f"Selected trial: {best_trial.number}")

    return best_trial
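Numerically, with the default corr_weight of 0.8 the relative correlation error dominates the selection; trial values are ordered (loss, correlation error) as in `_objective`. A sketch with made-up numbers:

corr_weight = 0.8
trials = {"trial_0": (1.20, 0.30), "trial_1": (0.90, 0.45)}  # (loss, corr_error), illustrative
scores = {
    name: corr_weight * corr_error + (1 - corr_weight) * loss
    for name, (loss, corr_error) in trials.items()
}
print(min(scores, key=scores.get))  # trial_0 (score ~0.48) beats trial_1 (~0.54) despite its higher loss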
optimize_model(model=None)

Optimizes the model using a specified optimization function.

Parameters:

Name Type Description Default
model Optional[TModel]

The model to optimize. If None, the method optimizes self.model. Defaults to None.

None

Returns:

Name Type Description
TModel TModel

The optimized model.

Notes
  • The optimization function is specified by the opt_func variable.
  • The opt_func function should take a model as input and return an optimized model.
  • If model is None, the method optimizes self.model.
  • If model is not None, the method optimizes the specified model.

Raises:

Type Description
TypeError

If the model is not of type TModel.

Source code in vambn/modelling/models/hivae/trainer.py
def optimize_model(self, model: Optional[TModel] = None) -> TModel:
    """
    Optimizes the model using a specified optimization function.

    Args:
        model (Optional[TModel], optional): The model to optimize. If None, the method optimizes self.model. Defaults to None.

    Returns:
        TModel: The optimized model.

    Notes:
        - The optimization function is specified by the opt_func variable.
        - The opt_func function should take a model as input and return an optimized model.
        - If model is None, the method optimizes self.model.
        - If model is not None, the method optimizes the specified model.

    Raises:
        TypeError: If the model is not of type TModel.

    """
    # Define the optimization function
    opt_func = partial(torch.compile, mode="reduce-overhead")
    # opt_func = lambda model: model

    if model is None:
        # Optimize self.model
        self.model = opt_func(model=self.model)
        return self.model
    else:
        return opt_func(model=model)
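For reference, the optimization applied here is plain torch.compile with the "reduce-overhead" mode; a minimal stand-alone sketch (PyTorch 2.x):

import torch

net = torch.nn.Linear(4, 2)
compiled = torch.compile(net, mode="reduce-overhead")
out = compiled(torch.randn(8, 4))  # behaves like net(...), with compiled kernels after warm-up
# Note: the compiled wrapper's state_dict keys may gain an "_orig_mod." prefix,
# which is why load_model strips that prefix when loading checkpoints.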
read_model_config(path) abstractmethod

Read the model configuration from the specified path.

Parameters:

Name Type Description Default
path Path

The path to read the model configuration from.

required

Returns:

Type Description
None

None

Source code in vambn/modelling/models/hivae/trainer.py
@abstractmethod
def read_model_config(self, path: Path) -> None:
    """
    Read the model configuration from the specified path.

    Args:
        path (Path): The path to read the model configuration from.

    Returns:
        None
    """
    pass
save_model(path)

Save the model and its configuration to the specified path.

Parameters:

Name Type Description Default
path Path

The path to save the model.

required

Returns:

Type Description
None

None

Source code in vambn/modelling/models/hivae/trainer.py
def save_model(self, path: Path) -> None:
    """
    Save the model and its configuration to the specified path.

    Args:
        path (Path): The path to save the model.

    Returns:
        None
    """
    self.save_model_config(path / "config.pkl")
    torch.save(self.model.state_dict(), path / "model.bin")
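A round-trip sketch, assuming `trainer` already holds a trained model; the directory name is illustrative.

from pathlib import Path

model_dir = Path("output/hivae_model")
model_dir.mkdir(parents=True, exist_ok=True)
trainer.save_model(model_dir)   # writes config.pkl and model.bin into the directory

# Later, e.g. in a fresh process with an equivalently constructed trainer:
trainer.load_model(model_dir)   # re-reads the config, rebuilds the model, loads the weights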
save_model_config(path) abstractmethod

Save the model configuration to the specified path.

Parameters:

Name Type Description Default
path Path

The path to save the model configuration.

required

Returns:

Type Description
None

None

Source code in vambn/modelling/models/hivae/trainer.py
@abstractmethod
def save_model_config(self, path: Path) -> None:
    """
    Save the model configuration to the specified path.

    Args:
        path (Path): The path to save the model configuration.

    Returns:
        None
    """
    pass
Hyperparameters dataclass

Class representing the hyperparameters for the model trainer.

Parameters:

Name Type Description Default
dim_s int | Dict[str, int] | Tuple[int, ...]

Dimension(s) of the input sequence(s).

required
dim_y int | Dict[str, int] | Tuple[int, ...]

Dimension(s) of the output sequence(s).

required
dim_z int

Dimension of the latent space.

required
dropout float

Dropout rate.

required
batch_size int

Batch size.

required
learning_rate float | Dict[str, float] | Tuple[float, ...]

Learning rate(s).

required
epochs int

Number of training epochs.

required
mtl_methods Tuple[str, ...]

Multi-task learning methods. Defaults to ("identity",).

('identity',)
lstm_layers int

Number of LSTM layers. Defaults to 1.

1
dim_ys Optional[int]

Dimension of the output sequence. Defaults to None.

None

Attributes:

Name Type Description
dim_s int | Dict[str, int] | Tuple[int, ...]

Dimension(s) of the input sequence(s).

dim_y int | Dict[str, int] | Tuple[int, ...]

Dimension(s) of the output sequence(s).

dim_z int

Dimension of the latent space.

dropout float

Dropout rate.

batch_size int

Batch size.

learning_rate float | Dict[str, float] | Tuple[float, ...]

Learning rate(s).

epochs int

Number of training epochs.

mtl_methods Tuple[str, ...]

Multi-task learning methods.

lstm_layers int

Number of LSTM layers.

dim_ys Optional[int]

Dimension of the output sequence.

Methods:

Name Description
__post_init__

Post-initialization method.

write_to_json

Write the hyperparameters to a JSON file.

read_from_json

Read the hyperparameters from a JSON file.

Source code in vambn/modelling/models/hivae/trainer.py
@dataclass
class Hyperparameters:
    """
    Class representing the hyperparameters for the model trainer.

    Args:
        dim_s (int | Dict[str, int] | Tuple[int, ...]): Dimension(s) of the input sequence(s).
        dim_y (int | Dict[str, int] | Tuple[int, ...]): Dimension(s) of the output sequence(s).
        dim_z (int): Dimension of the latent space.
        dropout (float): Dropout rate.
        batch_size (int): Batch size.
        learning_rate (float | Dict[str, float] | Tuple[float, ...]): Learning rate(s).
        epochs (int): Number of training epochs.
        mtl_methods (Tuple[str, ...], optional): Multi-task learning methods. Defaults to ("identity",).
        lstm_layers (int, optional): Number of LSTM layers. Defaults to 1.
        dim_ys (Optional[int], optional): Dimension of the output sequence. Defaults to None.

    Attributes:
        dim_s (int | Dict[str, int] | Tuple[int, ...]): Dimension(s) of the input sequence(s).
        dim_y (int | Dict[str, int] | Tuple[int, ...]): Dimension(s) of the output sequence(s).
        dim_z (int): Dimension of the latent space.
        dropout (float): Dropout rate.
        batch_size (int): Batch size.
        learning_rate (float | Dict[str, float] | Tuple[float, ...]): Learning rate(s).
        epochs (int): Number of training epochs.
        mtl_methods (Tuple[str, ...]): Multi-task learning methods.
        lstm_layers (int): Number of LSTM layers.
        dim_ys (Optional[int]): Dimension of the output sequence.

    Methods:
        __post_init__(self): Post-initialization method.
        write_to_json(self, path: Path): Write the hyperparameters to a JSON file.
        read_from_json(cls, path: Path): Read the hyperparameters from a JSON file.

    """

    dim_s: int | Dict[str, int] | Tuple[int, ...]
    dim_y: int | Dict[str, int] | Tuple[int, ...]
    dim_z: int
    dropout: float
    batch_size: int
    learning_rate: float | Dict[str, float] | Tuple[float, ...]
    epochs: int
    mtl_methods: Tuple[str, ...] = ("identity",)
    lstm_layers: int = 1
    dim_ys: Optional[int] = None

    def __post_init__(self):
        if isinstance(self.mtl_methods, List):
            self.mtl_methods = tuple(self.mtl_methods)
        elif isinstance(self.mtl_methods, str):
            self.mtl_methods = (self.mtl_methods,)

    def write_to_json(self, path: Path):
        """
        Write the hyperparameters to a JSON file.

        Args:
            path (Path): Path to the JSON file.

        """
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w") as f:
            json.dump(self.__dict__, f, indent=4)

    @classmethod
    def read_from_json(cls, path: Path):
        """
        Read the hyperparameters from a JSON file.

        Args:
            path (Path): Path to the JSON file.

        Returns:
            Hyperparameters: An instance of the Hyperparameters class.

        """
        with path.open("r") as f:
            data = json.load(f)

        tmp = cls(**data)
        # FIXME: fix model script to avoid this step
        # check if , is in mtl_methods
        logger.info(tmp.mtl_methods)
        if len(tmp.mtl_methods) == 1 and "," in tmp.mtl_methods[0]:
            method_str = tmp.mtl_methods[0]
            tmp.mtl_methods = tuple(method_str.split(","))
        logger.info(tmp.mtl_methods)
        logger.debug(tmp)
        return tmp
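A construction sketch for a single-module setup; all values are illustrative. write_to_json and read_from_json round-trip the same fields (tuples come back as lists from JSON and are normalized in __post_init__).

from pathlib import Path

params = Hyperparameters(
    dim_s=4,
    dim_y=8,
    dim_z=3,
    dropout=0.1,
    batch_size=64,
    learning_rate=1e-3,
    epochs=200,
    mtl_methods=("gradnorm", "graddrop"),
    lstm_layers=1,
)
params.write_to_json(Path("output/hyperparameters.json"))
restored = Hyperparameters.read_from_json(Path("output/hyperparameters.json"))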
read_from_json(path) classmethod

Read the hyperparameters from a JSON file.

Parameters:

Name Type Description Default
path Path

Path to the JSON file.

required

Returns:

Name Type Description
Hyperparameters

An instance of the Hyperparameters class.

Source code in vambn/modelling/models/hivae/trainer.py
@classmethod
def read_from_json(cls, path: Path):
    """
    Read the hyperparameters from a JSON file.

    Args:
        path (Path): Path to the JSON file.

    Returns:
        Hyperparameters: An instance of the Hyperparameters class.

    """
    with path.open("r") as f:
        data = json.load(f)

    tmp = cls(**data)
    # FIXME: fix model script to avoid this step
    # check if , is in mtl_methods
    logger.info(tmp.mtl_methods)
    if len(tmp.mtl_methods) == 1 and "," in tmp.mtl_methods[0]:
        method_str = tmp.mtl_methods[0]
        tmp.mtl_methods = tuple(method_str.split(","))
    logger.info(tmp.mtl_methods)
    logger.debug(tmp)
    return tmp
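The fix-up above handles configurations in which the MTL methods were serialized as a single comma-joined string; in isolation:

methods = ("gradnorm,graddrop",)  # a single comma-joined entry
if len(methods) == 1 and "," in methods[0]:
    methods = tuple(methods[0].split(","))
print(methods)  # ('gradnorm', 'graddrop')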
write_to_json(path)

Write the hyperparameters to a JSON file.

Parameters:

Name Type Description Default
path Path

Path to the JSON file.

required
Source code in vambn/modelling/models/hivae/trainer.py
def write_to_json(self, path: Path):
    """
    Write the hyperparameters to a JSON file.

    Args:
        path (Path): Path to the JSON file.

    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w") as f:
        json.dump(self.__dict__, f, indent=4)
ModularTrainer

Bases: BaseTrainer[GenericMHivaeConfig, GenericMHivaeModel, ModularHivaeOutput, 'ModularTrainer', ModularHivaeEncoding]

Source code in vambn/modelling/models/hivae/trainer.py
class ModularTrainer(
    BaseTrainer[
        GenericMHivaeConfig,
        GenericMHivaeModel,
        ModularHivaeOutput,
        "ModularTrainer",
        ModularHivaeEncoding,
    ]
):
    def __init__(
        self,
        dataset: VambnDataset,
        config: PipelineConfig,
        workers: int,
        checkpoint_path: Path,
        module_name: str | None = None,
        experiment_name: str | None = None,
        force_cpu: bool = False,
        shared_element: str = "none",
    ):
        """
        Initialize the ModularTrainer class.

        Args:
            dataset: The VambnDataset object.
            config: The PipelineConfig object.
            workers: The number of workers for data loading.
            checkpoint_path: The path to save checkpoints.
            module_name: The name of the module.
            experiment_name: The name of the experiment.
            force_cpu: Whether to force CPU usage.
            shared_element: The type of shared element.
        """
        super().__init__(
            dataset=dataset,
            config=config,
            workers=workers,
            checkpoint_path=checkpoint_path,
            module_name=module_name,
            experiment_name=experiment_name,
            force_cpu=force_cpu,
        )
        self.type = "modular"
        self.shared_element = shared_element

    def process_encodings(
        self, predictions: ModularHivaeOutput
    ) -> ModularHivaeEncoding:
        """
        Process the model predictions and return the encodings.

        Args:
            predictions: The model predictions.

        Returns:
            The ModularHivaeEncoding object.
        """
        return ModularHivaeEncoding(
            encodings=tuple(
                HivaeEncoding(
                    z=pred.enc_z,
                    s=pred.enc_s,
                    module=module_name,
                    samples=pred.samples,
                    subjid=self.dataset.subj,
                )
                for module_name, pred in zip(
                    self.dataset.module_names, predictions
                )
            ),
            modules=self.dataset.module_names,
        )

    def init_model(self, config: ModularHivaeConfig) -> ModularHivae:
        """
        Initialize the ModularHivae model.

        Args:
            config: The ModularHivaeConfig object.

        Returns:
            The initialized ModularHivae model.
        """
        return ModularHivae(
            dim_s=config.dim_s,
            dim_ys=config.dim_ys,
            dim_y=config.dim_y,
            dim_z=config.dim_z,
            module_config=config.module_config,
            mtl_method=config.mtl_method,
            shared_element_type=config.shared_element,
            use_imputation_layer=config.use_imputation_layer,
        )

    def hyperparameters(self, trial: optuna.Trial) -> Hyperparameters:
        """
        Generate hyperparameters for the trial.

        Args:
            trial: The optuna.Trial object.

        Returns:
            The generated Hyperparameters object.
        """
        opt = self.config.optimization
        if opt.fixed_s_dim:
            # Use fixed dimension for s
            fixed_dim_s = trial.suggest_int(
                "hidden_dim_s",
                opt.s_dim_lower,
                opt.s_dim_upper,
                opt.s_dim_step,
            )
            dim_s = {
                module_name: fixed_dim_s
                for module_name in self.dataset.module_names
            }
        else:
            # Use different dimensions for s
            dim_s = {
                module_name: trial.suggest_int(
                    f"hidden_dim_s_{module_name}",
                    opt.s_dim_lower,
                    opt.s_dim_upper,
                    opt.s_dim_step,
                )
                for module_name in self.dataset.module_names
            }
        if opt.fixed_y_dim:
            # Use fixed dimension for y
            fixed_y = trial.suggest_int(
                "hidden_dim_y",
                opt.y_dim_lower,
                opt.y_dim_upper,
                opt.y_dim_step,
            )
            dim_y = {module_name: fixed_y for module_name in dim_s}
        else:
            # Use different dimensions for y
            dim_y = {
                module_name: trial.suggest_int(
                    f"hidden_dim_y_{module_name}",
                    opt.y_dim_lower,
                    opt.y_dim_upper,
                    opt.y_dim_step,
                )
                for module_name in self.dataset.module_names
            }
        dim_z = trial.suggest_int(
            "hidden_dim_z",
            opt.latent_dim_lower,
            opt.latent_dim_upper,
            opt.latent_dim_step,
        )
        dim_ys = trial.suggest_int(
            "hidden_dim_ys",
            opt.y_dim_lower,
            opt.y_dim_upper,
            opt.y_dim_step,
        )
        if opt.fixed_learning_rate:
            lr = trial.suggest_float(
                "learning_rate",
                opt.learning_rate_lower,
                opt.learning_rate_upper,
                log=True,
            )
            learning_rate = {
                module_name: lr for module_name in self.dataset.module_names
            }
            learning_rate["learning_rate_shared"] = lr
        else:
            learning_rate = {
                module_name: trial.suggest_float(
                    f"learning_rate_{module_name}",
                    opt.learning_rate_lower,
                    opt.learning_rate_upper,
                    log=True,
                )
                for module_name in self.dataset.module_names
            }
            learning_rate["learning_rate_shared"] = trial.suggest_float(
                "learning_rate_shared",
                opt.learning_rate_lower,
                opt.learning_rate_upper,
                log=True,
            )

        batch_size_n = trial.suggest_int(
            "batch_size_n", opt.batch_size_lower_n, opt.batch_size_upper_n
        )
        batch_size = 2**batch_size_n

        if self.dataset.is_longitudinal:
            lstm_layers = trial.suggest_int(
                "lstm_layers",
                opt.lstm_layers_lower,
                opt.lstm_layers_upper,
                opt.lstm_layers_step,
            )
        else:
            lstm_layers = 1
        if self.use_mtl:
            mtl_string = trial.suggest_categorical(
                "mtl_methods",
                [
                    "gradnorm",
                    "graddrop",
                    "gradnorm,graddrop",
                ],
            )
            mtl_methods = (
                tuple(mtl_string.split(","))
                if "," in mtl_string
                else tuple([mtl_string])
            )
        else:
            mtl_methods = ("identity",)
        return Hyperparameters(
            dim_s=dim_s,
            dim_y=dim_y,
            dropout=0.1,
            batch_size=batch_size,
            learning_rate=learning_rate,
            epochs=opt.max_epochs,
            lstm_layers=lstm_layers,
            mtl_methods=mtl_methods,
            dim_ys=dim_ys,
            dim_z=dim_z,
        )

    def get_module_config(
        self, hyperparameters: Hyperparameters
    ) -> Tuple[DataModuleConfig, ...]:
        """
        Get the module configuration based on the hyperparameters.

        Args:
            hyperparameters: The Hyperparameters object.

        Returns:
            A tuple of DataModuleConfig objects.
        """
        module_configs = []
        for module_name in self.dataset.module_names:
            module = self.dataset.get_modules(module_name)[0]
            module_config = DataModuleConfig(
                name=module_name,
                variable_types=module.variable_types,
                n_layers=hyperparameters.lstm_layers,
                num_timepoints=self.dataset.visits_per_module[module_name],
            )
            module_configs.append(module_config)
        return tuple(module_configs)

    def train(self, best_parameters: Hyperparameters) -> GenericMHivaeModel:
        whole_dataloader = self.get_dataloader(
            self.dataset, best_parameters.batch_size, shuffle=True
        )
        module_config = self.get_module_config(best_parameters)
        self.model_config = ModularHivaeConfig(
            module_config=module_config,
            dim_z=best_parameters.dim_z,
            dim_s=best_parameters.dim_s,
            dim_y=best_parameters.dim_y,
            dropout=best_parameters.dropout,
            mtl_method=best_parameters.mtl_methods,
            n_layers=best_parameters.lstm_layers,
            use_imputation_layer=self.config.training.use_imputation_layer,
            dim_ys=best_parameters.dim_ys,
        )
        self.model = self.init_model(self.model_config)
        if isinstance(best_parameters.learning_rate, Dict):
            order = self.dataset.module_names + ["learning_rate_shared"]
            learning_rate = tuple(
                best_parameters.learning_rate[module_name]
                for module_name in order
            )
        else:
            learning_rate = best_parameters.learning_rate
        with mlflow.start_run(run_name=f"{self.run_base}_final_fit"):
            mlflow.log_params(best_parameters.__dict__)
            self.model.fit(
                train_dataloader=whole_dataloader,
                num_epochs=best_parameters.epochs,
                learning_rate=learning_rate,
            )
        return self.model

    def _objective(self, trial: optuna.Trial) -> float | Tuple[float, float]:
        """
        Objective function for the optimization process.

        Args:
            trial (optuna.Trial): The trial object for hyperparameter optimization.

        Returns:
            float | Tuple[float, float]: The loss value or a tuple of loss values.
        """

        trial_params = self.hyperparameters(trial)
        logger.info(f"Trial parameters: {trial_params}")
        fold_loss = []
        n_epochs = []
        rel_corr_loss = []
        auc_metric = []
        module_config = self.get_module_config(trial_params)
        model_config = ModularHivaeConfig(
            module_config=module_config,
            dim_z=trial_params.dim_z,
            dim_y=trial_params.dim_y,
            dropout=trial_params.dropout,
            mtl_method=trial_params.mtl_methods,
            n_layers=trial_params.lstm_layers,
            use_imputation_layer=self.config.training.use_imputation_layer,
            dim_s=trial_params.dim_s,
            shared_element=self.shared_element,
            dim_ys=trial_params.dim_ys,
        )

        if isinstance(trial_params.learning_rate, Dict):
            order = self.dataset.module_names + ["learning_rate_shared"]
            learning_rate = tuple(
                trial_params.learning_rate[module_name] for module_name in order
            )
        else:
            learning_rate = trial_params.learning_rate
        for i, (train_data, val_data) in enumerate(
            self.cv_generator(self.config.optimization.folds)
        ):
            train_dataloader = self.get_dataloader(
                train_data, trial_params.batch_size, shuffle=True
            )
            val_dataloader = self.get_dataloader(
                val_data, trial_params.batch_size, shuffle=False
            )
            model = self.init_model(model_config)
            with mlflow.start_run(
                run_name=f"{self.run_base}_T{trial._trial_id}_F{i}",
                nested=True,
            ):
                mlflow.log_params(trial_params.__dict__)
                start = time.time()
                try:
                    fit_loss, fit_epoch = model.fit(
                        train_dataloader=train_dataloader,
                        val_dataloader=val_dataloader,
                        num_epochs=trial_params.epochs,
                        learning_rate=learning_rate,
                    )
                except ValueError:
                    raise optuna.TrialPruned(
                        "Trial pruned due to error during training. Likely due to unsuitable hyperparameters"
                    )

                if ((time.time() - start) / 3600) > 2:
                    logger.warning(
                        f"Trial pruned due to very long execution ({(time.time() - start) / 3600}h at fold {i})"
                    )
                    raise optuna.TrialPruned()

                res = model.predict(val_dataloader)
                mlflow.log_metric("Final loss", res.avg_loss)
                logger.info(f"Fold {i} loss: {res.avg_loss}")

                fold_loss.append(res.avg_loss)
                n_epochs.append(fit_epoch)

                if (
                    self.config.optimization.use_relative_correlation_error_for_optimization
                    or self.config.optimization.use_auc_for_optimization
                ):
                    # Calculate the relative correlation loss
                    assert len(res.outputs) == len(model.module_configs)

                    orig_data = val_data.to_pandas()
                    module_dfs = []
                    for module_output, conf in zip(
                        res.outputs, model.module_configs
                    ):
                        module_columns = self.dataset.get_modules(conf.name)[
                            0
                        ].columns
                        cleaned_columns = [
                            re.sub(r"_VIS[0-9]+", "", c) for c in module_columns
                        ]
                        module_samples = module_output.samples
                        if module_samples.ndim == 2:
                            module_samples = module_samples.unsqueeze(1)

                        module_data = pd.concat(
                            [
                                pd.DataFrame(
                                    module_samples[:, i, :].numpy(),
                                    columns=cleaned_columns,
                                )
                                for i in range(module_samples.shape[1])
                            ]
                        )
                        module_data["SUBJID"] = (
                            val_data.subj * conf.num_timepoints
                        )
                        module_data["VISIT"] = np.repeat(
                            np.arange(1, module_samples.shape[1] + 1),
                            len(val_data),
                        )
                        module_dfs.append(module_data)

                    decoded_data = reduce(
                        lambda x, y: pd.merge(
                            x, y, on=["SUBJID", "VISIT"], how="outer"
                        ),
                        module_dfs,
                    )
                    decoded_data_filtered = decoded_data.loc[
                        :, orig_data.columns
                    ].drop(columns=["SUBJID", "VISIT"])
                    orig_data_filtered = orig_data.drop(
                        columns=["SUBJID", "VISIT"]
                    )

                    fold_rel_corr_loss, m1, m2 = RelativeCorrelation.error(
                        real=orig_data_filtered,
                        synthetic=decoded_data_filtered,
                    )
                    mlflow.log_metric(
                        "Relative correlation loss", fold_rel_corr_loss
                    )
                    rel_corr_loss.append(fold_rel_corr_loss)

                    auc = get_auc(orig_data_filtered, decoded_data_filtered)
                    if auc < 0.5:
                        auc = 1 - auc
                    auc_quality = max(math.floor((1 - auc) * 200), 1)
                    mlflow.log_metric("AUC", auc)
                    mlflow.log_metric("AUC quality", auc_quality)
                    auc_metric.append(auc_quality)

        loss = np.mean(fold_loss)
        best_epoch = n_epochs[np.argmin(fold_loss)]
        mlflow.log_metric("Best epoch", best_epoch)
        trial.set_user_attr("best_epoch", best_epoch)
        trial.set_user_attr("n_epochs", n_epochs)

        if self.config.optimization.use_relative_correlation_error_for_optimization:
            rel_corr_loss = np.mean(rel_corr_loss)
            logger.info(f"Trial loss: {loss}; rel_corr_loss: {rel_corr_loss}")
            return loss, rel_corr_loss
        elif self.config.optimization.use_auc_for_optimization:
            return np.mean(auc_metric)
        else:
            logger.info(f"Trial loss: {loss}")
            return loss

    def predict(self, dl: DataLoader) -> ModularHivaeOutput:
        """
        Predict the output of the model.

        Args:
            dl (DataLoader): The DataLoader object.

        Returns:
            ModularHivaeOutput: The output of the model.
        """
        with torch.no_grad():
            return self.model.predict(dl)

    def save_model_config(self, path: Path):
        """
        Save the model configuration to a file.

        Args:
            path (Path): The path to save the model configuration.

        Raises:
            Exception: If the model is None.
            Exception: If the model config is None.
        """
        if self.model is None:
            raise Exception("Model should not be none")

        if self.model_config is None:
            raise Exception("Model config should not be none")

        with path.open("wb") as f:
            dill.dump(self.model_config, f)

    def read_model_config(self, path: Path):
        """
        Read the model configuration from a file.

        Args:
            path (Path): The path to read the model configuration from.
        """
        with path.open("rb") as f:
            self.model_config = dill.load(f)

    def decode(
        self,
        encoding: Union[HivaeEncoding, ModularHivaeEncoding],
        use_mode: bool = True,
    ) -> Dict[str, Tensor]:
        """
        Decode the given encoding to obtain the sampled data.

        Args:
            encoding (Union[HivaeEncoding, ModularHivaeEncoding]): The encoding to decode.
            use_mode (bool, optional): Whether to use the mode for decoding. Defaults to True.

        Returns:
            Dict[str, Tensor]: The decoded sampled data, with module names as keys.
        """
        self.model.eval()
        self.model.decoding = use_mode
        with torch.no_grad():
            sampled_data = self.model.decode(encoding)

        sample_dict = {
            module_name: sampled_data[i]
            for i, module_name in enumerate(self.dataset.module_names)
        }
        return sample_dict
__init__(dataset, config, workers, checkpoint_path, module_name=None, experiment_name=None, force_cpu=False, shared_element='none')

Initialize the ModularTrainer class.

Parameters:

    dataset (VambnDataset): The VambnDataset object. Required.
    config (PipelineConfig): The PipelineConfig object. Required.
    workers (int): The number of workers for data loading. Required.
    checkpoint_path (Path): The path to save checkpoints. Required.
    module_name (str | None): The name of the module. Default: None.
    experiment_name (str | None): The name of the experiment. Default: None.
    force_cpu (bool): Whether to force CPU usage. Default: False.
    shared_element (str): The type of shared element. Default: 'none'.
Source code in vambn/modelling/models/hivae/trainer.py
def __init__(
    self,
    dataset: VambnDataset,
    config: PipelineConfig,
    workers: int,
    checkpoint_path: Path,
    module_name: str | None = None,
    experiment_name: str | None = None,
    force_cpu: bool = False,
    shared_element: str = "none",
):
    """
    Initialize the ModularTrainer class.

    Args:
        dataset: The VambnDataset object.
        config: The PipelineConfig object.
        workers: The number of workers for data loading.
        checkpoint_path: The path to save checkpoints.
        module_name: The name of the module.
        experiment_name: The name of the experiment.
        force_cpu: Whether to force CPU usage.
        shared_element: The type of shared element.
    """
    super().__init__(
        dataset=dataset,
        config=config,
        workers=workers,
        checkpoint_path=checkpoint_path,
        module_name=module_name,
        experiment_name=experiment_name,
        force_cpu=force_cpu,
    )
    self.type = "modular"
    self.shared_element = shared_element
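
A minimal construction sketch, assuming a VambnDataset and a PipelineConfig have already been built elsewhere in the pipeline; the worker count, checkpoint path, and experiment name below are placeholders.

from pathlib import Path

# Hypothetical setup: `dataset` (VambnDataset) and `pipeline_config`
# (PipelineConfig) are assumed to exist already.
trainer = ModularTrainer(
    dataset=dataset,
    config=pipeline_config,
    workers=4,
    checkpoint_path=Path("checkpoints/modular"),
    experiment_name="modular_hivae",
    shared_element="none",
)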
decode(encoding, use_mode=True)

Decode the given encoding to obtain the sampled data.

Parameters:

    encoding (Union[HivaeEncoding, ModularHivaeEncoding]): The encoding to decode. Required.
    use_mode (bool): Whether to use the mode for decoding. Default: True.

Returns:

    Dict[str, Tensor]: The decoded sampled data, with module names as keys.

Source code in vambn/modelling/models/hivae/trainer.py
def decode(
    self,
    encoding: Union[HivaeEncoding, ModularHivaeEncoding],
    use_mode: bool = True,
) -> Dict[str, Tensor]:
    """
    Decode the given encoding to obtain the sampled data.

    Args:
        encoding (Union[HivaeEncoding, ModularHivaeEncoding]): The encoding to decode.
        use_mode (bool, optional): Whether to use the mode for decoding. Defaults to True.

    Returns:
        Dict[str, Tensor]: The decoded sampled data, with module names as keys.
    """
    self.model.eval()
    self.model.decoding = use_mode
    with torch.no_grad():
        sampled_data = self.model.decode(encoding)

    sample_dict = {
        module_name: sampled_data[i]
        for i, module_name in enumerate(self.dataset.module_names)
    }
    return sample_dict
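
A hedged end-to-end sketch of how predict, process_encodings, and decode fit together, assuming `trainer` is a fitted ModularTrainer and `val_dl` is a DataLoader obtained from get_dataloader; both names are placeholders.

# Hypothetical: encode validation data and decode it back per module.
predictions = trainer.predict(val_dl)               # ModularHivaeOutput
encoding = trainer.process_encodings(predictions)   # ModularHivaeEncoding
decoded = trainer.decode(encoding, use_mode=True)   # Dict[str, Tensor]
for module_name, samples in decoded.items():
    print(module_name, samples.shape)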
get_module_config(hyperparameters)

Get the module configuration based on the hyperparameters.

Parameters:

    hyperparameters (Hyperparameters): The Hyperparameters object. Required.

Returns:

    Tuple[DataModuleConfig, ...]: A tuple of DataModuleConfig objects.

Source code in vambn/modelling/models/hivae/trainer.py
def get_module_config(
    self, hyperparameters: Hyperparameters
) -> Tuple[DataModuleConfig, ...]:
    """
    Get the module configuration based on the hyperparameters.

    Args:
        hyperparameters: The Hyperparameters object.

    Returns:
        A tuple of DataModuleConfig objects.
    """
    module_configs = []
    for module_name in self.dataset.module_names:
        module = self.dataset.get_modules(module_name)[0]
        module_config = DataModuleConfig(
            name=module_name,
            variable_types=module.variable_types,
            n_layers=hyperparameters.lstm_layers,
            num_timepoints=self.dataset.visits_per_module[module_name],
        )
        module_configs.append(module_config)
    return tuple(module_configs)
hyperparameters(trial)

Generate hyperparameters for the trial.

Parameters:

    trial (optuna.Trial): The optuna.Trial object. Required.

Returns:

    Hyperparameters: The generated Hyperparameters object.

Source code in vambn/modelling/models/hivae/trainer.py
def hyperparameters(self, trial: optuna.Trial) -> Hyperparameters:
    """
    Generate hyperparameters for the trial.

    Args:
        trial: The optuna.Trial object.

    Returns:
        The generated Hyperparameters object.
    """
    opt = self.config.optimization
    if opt.fixed_s_dim:
        # Use fixed dimension for s
        fixed_dim_s = trial.suggest_int(
            "hidden_dim_s",
            opt.s_dim_lower,
            opt.s_dim_upper,
            opt.s_dim_step,
        )
        dim_s = {
            module_name: fixed_dim_s
            for module_name in self.dataset.module_names
        }
    else:
        # Use different dimensions for s
        dim_s = {
            module_name: trial.suggest_int(
                f"hidden_dim_s_{module_name}",
                opt.s_dim_lower,
                opt.s_dim_upper,
                opt.s_dim_step,
            )
            for module_name in self.dataset.module_names
        }
    if opt.fixed_y_dim:
        # Use fixed dimension for y
        fixed_y = trial.suggest_int(
            "hidden_dim_y",
            opt.y_dim_lower,
            opt.y_dim_upper,
            opt.y_dim_step,
        )
        dim_y = {module_name: fixed_y for module_name in dim_s}
    else:
        # Use different dimensions for y
        dim_y = {
            module_name: trial.suggest_int(
                f"hidden_dim_y_{module_name}",
                opt.y_dim_lower,
                opt.y_dim_upper,
                opt.y_dim_step,
            )
            for module_name in self.dataset.module_names
        }
    dim_z = trial.suggest_int(
        "hidden_dim_z",
        opt.latent_dim_lower,
        opt.latent_dim_upper,
        opt.latent_dim_step,
    )
    dim_ys = trial.suggest_int(
        "hidden_dim_ys",
        opt.y_dim_lower,
        opt.y_dim_upper,
        opt.y_dim_step,
    )
    if opt.fixed_learning_rate:
        lr = trial.suggest_float(
            "learning_rate",
            opt.learning_rate_lower,
            opt.learning_rate_upper,
            log=True,
        )
        learning_rate = {
            module_name: lr for module_name in self.dataset.module_names
        }
        learning_rate["learning_rate_shared"] = lr
    else:
        learning_rate = {
            module_name: trial.suggest_float(
                f"learning_rate_{module_name}",
                opt.learning_rate_lower,
                opt.learning_rate_upper,
                log=True,
            )
            for module_name in self.dataset.module_names
        }
        learning_rate["learning_rate_shared"] = trial.suggest_float(
            "learning_rate_shared",
            opt.learning_rate_lower,
            opt.learning_rate_upper,
            log=True,
        )

    batch_size_n = trial.suggest_int(
        "batch_size_n", opt.batch_size_lower_n, opt.batch_size_upper_n
    )
    batch_size = 2**batch_size_n

    if self.dataset.is_longitudinal:
        lstm_layers = trial.suggest_int(
            "lstm_layers",
            opt.lstm_layers_lower,
            opt.lstm_layers_upper,
            opt.lstm_layers_step,
        )
    else:
        lstm_layers = 1
    if self.use_mtl:
        mtl_string = trial.suggest_categorical(
            "mtl_methods",
            [
                "gradnorm",
                "graddrop",
                "gradnorm,graddrop",
            ],
        )
        mtl_methods = (
            tuple(mtl_string.split(","))
            if "," in mtl_string
            else tuple([mtl_string])
        )
    else:
        mtl_methods = ("identity",)
    return Hyperparameters(
        dim_s=dim_s,
        dim_y=dim_y,
        dropout=0.1,
        batch_size=batch_size,
        learning_rate=learning_rate,
        epochs=opt.max_epochs,
        lstm_layers=lstm_layers,
        mtl_methods=mtl_methods,
        dim_ys=dim_ys,
        dim_z=dim_z,
    )
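
A minimal sketch of how the suggested hyperparameters could drive an Optuna search, assuming the single-objective case (neither the relative correlation error nor the AUC criterion is enabled); the study setup below is an assumption, not the pipeline's actual wiring.

import optuna

# Hypothetical study: _objective builds a model from the hyperparameters
# suggested above and returns the mean validation loss across folds.
study = optuna.create_study(direction="minimize")
study.optimize(trainer._objective, n_trials=20)
print(study.best_params)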
init_model(config)

Initialize the ModularHivae model.

Parameters:

    config (ModularHivaeConfig): The ModularHivaeConfig object. Required.

Returns:

    ModularHivae: The initialized ModularHivae model.

Source code in vambn/modelling/models/hivae/trainer.py
def init_model(self, config: ModularHivaeConfig) -> ModularHivae:
    """
    Initialize the ModularHivae model.

    Args:
        config: The ModularHivaeConfig object.

    Returns:
        The initialized ModularHivae model.
    """
    return ModularHivae(
        dim_s=config.dim_s,
        dim_ys=config.dim_ys,
        dim_y=config.dim_y,
        dim_z=config.dim_z,
        module_config=config.module_config,
        mtl_method=config.mtl_method,
        shared_element_type=config.shared_element,
        use_imputation_layer=config.use_imputation_layer,
    )
predict(dl)

Predict the output of the model.

Parameters:

    dl (DataLoader): The DataLoader object. Required.

Returns:

    ModularHivaeOutput: The output of the model.

Source code in vambn/modelling/models/hivae/trainer.py
def predict(self, dl: DataLoader) -> ModularHivaeOutput:
    """
    Predict the output of the model.

    Args:
        dl (DataLoader): The DataLoader object.

    Returns:
        ModularHivaeOutput: The output of the model.
    """
    with torch.no_grad():
        return self.model.predict(dl)
process_encodings(predictions)

Process the model predictions and return the encodings.

Parameters:

    predictions (ModularHivaeOutput): The model predictions. Required.

Returns:

    ModularHivaeEncoding: The ModularHivaeEncoding object.

Source code in vambn/modelling/models/hivae/trainer.py
def process_encodings(
    self, predictions: ModularHivaeOutput
) -> ModularHivaeEncoding:
    """
    Process the model predictions and return the encodings.

    Args:
        predictions: The model predictions.

    Returns:
        The ModularHivaeEncoding object.
    """
    return ModularHivaeEncoding(
        encodings=tuple(
            HivaeEncoding(
                z=pred.enc_z,
                s=pred.enc_s,
                module=module_name,
                samples=pred.samples,
                subjid=self.dataset.subj,
            )
            for module_name, pred in zip(
                self.dataset.module_names, predictions
            )
        ),
        modules=self.dataset.module_names,
    )
read_model_config(path)

Read the model configuration from a file.

Parameters:

    path (Path): The path to read the model configuration from. Required.
Source code in vambn/modelling/models/hivae/trainer.py
def read_model_config(self, path: Path):
    """
    Read the model configuration from a file.

    Args:
        path (Path): The path to read the model configuration from.
    """
    with path.open("rb") as f:
        self.model_config = dill.load(f)
save_model_config(path)

Save the model configuration to a file.

Parameters:

    path (Path): The path to save the model configuration. Required.

Raises:

    Exception: If the model is None.
    Exception: If the model config is None.

Source code in vambn/modelling/models/hivae/trainer.py
def save_model_config(self, path: Path):
    """
    Save the model configuration to a file.

    Args:
        path (Path): The path to save the model configuration.

    Raises:
        Exception: If the model is None.
        Exception: If the model config is None.
    """
    if self.model is None:
        raise Exception("Model should not be none")

    if self.model_config is None:
        raise Exception("Model config should not be none")

    with path.open("wb") as f:
        dill.dump(self.model_config, f)
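
A hypothetical round trip for the two persistence helpers above, assuming `trainer` holds a trained model and its ModularHivaeConfig; the path is a placeholder.

from pathlib import Path

config_path = Path("checkpoints/modular_hivae_config.pkl")
trainer.save_model_config(config_path)  # serializes self.model_config with dill
trainer.read_model_config(config_path)  # restores it into self.model_config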
TraditionalGanTrainer

Bases: TraditionalTrainer[HivaeConfig, GanHivae]

Source code in vambn/modelling/models/hivae/trainer.py
class TraditionalGanTrainer(TraditionalTrainer[HivaeConfig, GanHivae]):
    def init_model(self, config: HivaeConfig) -> GanHivae:
        """
        Initializes the GAN-HIVAE model.

        Args:
            config (HivaeConfig): The configuration for the GAN-HIVAE model.

        Returns:
            GanHivae: The initialized GAN-HIVAE model.
        """
        return GanHivae(
            variable_types=config.variable_types,
            input_dim=config.input_dim,
            dim_s=config.dim_s,
            dim_y=config.dim_y,
            dim_z=config.dim_z,
            module_name=config.name,
            mtl_method=config.mtl_methods,
            use_imputation_layer=config.use_imputation_layer,
            individual_model=True,
            n_layers=config.n_layers,
            num_timepoints=config.num_timepoints,
            noise_size=10,
        )
init_model(config)

Initializes the GAN-HIVAE model.

Parameters:

    config (HivaeConfig): The configuration for the GAN-HIVAE model. Required.

Returns:

    GanHivae: The initialized GAN-HIVAE model.

Source code in vambn/modelling/models/hivae/trainer.py
def init_model(self, config: HivaeConfig) -> GanHivae:
    """
    Initializes the GAN-HIVAE model.

    Args:
        config (HivaeConfig): The configuration for the GAN-HIVAE model.

    Returns:
        GanHivae: The initialized GAN-HIVAE model.
    """
    return GanHivae(
        variable_types=config.variable_types,
        input_dim=config.input_dim,
        dim_s=config.dim_s,
        dim_y=config.dim_y,
        dim_z=config.dim_z,
        module_name=config.name,
        mtl_method=config.mtl_methods,
        use_imputation_layer=config.use_imputation_layer,
        individual_model=True,
        n_layers=config.n_layers,
        num_timepoints=config.num_timepoints,
        noise_size=10,
    )
TraditionalTrainer

Bases: BaseTrainer[GenericHivaeConfig, GenericHivaeModel, HivaeOutput, 'TraditionalTrainer', HivaeEncoding]

Source code in vambn/modelling/models/hivae/trainer.py
class TraditionalTrainer(
    BaseTrainer[
        GenericHivaeConfig,
        GenericHivaeModel,
        HivaeOutput,
        "TraditionalTrainer",
        HivaeEncoding,
    ]
):
    def __init__(
        self,
        dataset: VambnDataset,
        config: PipelineConfig,
        workers: int,
        checkpoint_path: Path,
        module_name: str | None = None,
        experiment_name: str | None = None,
        force_cpu: bool = False,
    ):
        super().__init__(
            dataset=dataset,
            config=config,
            workers=workers,
            checkpoint_path=checkpoint_path,
            module_name=module_name,
            experiment_name=experiment_name,
            force_cpu=force_cpu,
        )
        self.type = "traditional"
        if module_name is None:
            raise ValueError("Module name must be specified")
        elif module_name not in self.dataset.module_names:
            raise ValueError(
                f"Module name {module_name} not in dataset, available modules: {self.dataset.module_names}"
            )

    def process_encodings(self, predictions: HivaeOutput) -> HivaeEncoding:
        return HivaeEncoding(
            z=predictions.enc_z,
            s=predictions.enc_s,
            module=self.module_name,
            samples=predictions.samples,
            subjid=self.dataset.subj,
        )

    def init_model(self, config: GenericHivaeConfig) -> GenericHivaeModel:
        if config.is_longitudinal:
            return LstmHivae(
                n_layers=config.n_layers,
                num_timepoints=config.num_timepoints,
                dim_s=config.dim_s,
                dim_y=config.dim_y,
                dim_z=config.dim_z,
                input_dim=config.input_dim,
                module_name=self.module_name,
                mtl_method=config.mtl_methods,
                use_imputation_layer=config.use_imputation_layer,
                variable_types=config.variable_types,
                individual_model=True,
            )
        else:
            return Hivae(
                variable_types=config.variable_types,
                input_dim=config.input_dim,
                dim_s=config.dim_s,
                dim_y=config.dim_y,
                dim_z=config.dim_z,
                module_name=self.module_name,
                mtl_method=config.mtl_methods,
                use_imputation_layer=config.use_imputation_layer,
                individual_model=True,
            )

    def hyperparameters(self, trial: optuna.Trial) -> Hyperparameters:
        """
        Function to suggest hyperparameters for the model

        Args:
            trial (optuna.Trial): Trial instance

        Returns:
            Hyperparameters: The suggested hyperparameters
        """
        opt = self.config.optimization
        dim_s = trial.suggest_int(
            "hidden_dim_s",
            opt.s_dim_lower,
            opt.s_dim_upper,
            opt.s_dim_step,
        )
        dim_y = trial.suggest_int(
            "hidden_dim_y",
            opt.y_dim_lower,
            opt.y_dim_upper,
            opt.y_dim_step,
        )
        dim_z = trial.suggest_int(
            "hidden_dim_z",
            opt.latent_dim_lower,
            opt.latent_dim_upper,
            opt.latent_dim_step,
        )
        learning_rate = trial.suggest_float(
            "learning_rate",
            opt.learning_rate_lower,
            opt.learning_rate_upper,
            log=True,
        )
        batch_size_n = trial.suggest_int(
            "batch_size_n", opt.batch_size_lower_n, opt.batch_size_upper_n
        )
        batch_size = 2**batch_size_n
        # epochs = trial.suggest_int(
        #     "epochs",
        #     low=opt.epoch_lower,
        #     high=opt.epoch_upper,
        #     step=opt.epoch_step,
        # )

        if self.dataset.is_longitudinal:
            lstm_layers = trial.suggest_int(
                "lstm_layers",
                opt.lstm_layers_lower,
                opt.lstm_layers_upper,
                opt.lstm_layers_step,
            )
        else:
            lstm_layers = 1
        if self.use_mtl:
            mtl_string = trial.suggest_categorical(
                "mtl_methods",
                [
                    "gradnorm",
                    "graddrop",
                    "gradnorm,graddrop",
                ],
            )
            mtl_methods = (
                tuple(mtl_string.split(","))
                if "," in mtl_string
                else tuple([mtl_string])
            )
        else:
            mtl_methods = ("identity",)
        return Hyperparameters(
            dim_s=dim_s,
            dim_y=dim_y,
            dim_z=dim_z,
            dropout=0.1,
            batch_size=batch_size,
            learning_rate=learning_rate,
            epochs=opt.max_epochs,
            lstm_layers=lstm_layers,
            mtl_methods=mtl_methods,
        )

    def train(self, best_parameters: Hyperparameters) -> GenericHivaeModel:
        """
        Trains the model using the best hyperparameters.

        Args:
            best_parameters (Hyperparameters): The best hyperparameters obtained from optimization.

        Returns:
            GenericHivaeModel: The trained model.
        """
        ref_module = self.dataset.get_modules(self.module_name)[0]
        whole_dataloader = self.get_dataloader(
            self.dataset, best_parameters.batch_size, shuffle=True
        )
        self.model_config = HivaeConfig(
            name=self.module_name,
            variable_types=ref_module.variable_types,
            dim_s=best_parameters.dim_s,
            dim_y=best_parameters.dim_y,
            dim_z=best_parameters.dim_z,
            mtl_methods=best_parameters.mtl_methods,
            use_imputation_layer=self.config.training.use_imputation_layer,
            dropout=best_parameters.dropout,
            n_layers=best_parameters.lstm_layers,
            num_timepoints=self.dataset.visits_per_module[self.module_name],
        )
        self.model = self.init_model(self.model_config).to(self.device)
        self.model.device = self.device
        self.optimize_model()
        with mlflow.start_run(run_name=f"{self.run_base}_final_fit"):
            mlflow.log_params(best_parameters.__dict__)
            self.model.fit(
                train_dataloader=whole_dataloader,
                num_epochs=best_parameters.epochs,
                learning_rate=best_parameters.learning_rate,
            )
        return self.model

    def _objective(self, trial: optuna.Trial) -> float | Tuple[float, float]:
        """
        Objective function for the optimization process.

        Args:
            trial (optuna.Trial): The trial object for hyperparameter optimization.

        Returns:
            float or Tuple[float, float]: The loss value or a tuple of loss values.
        """
        trial_params = self.hyperparameters(trial)
        logger.info(f"Trial parameters: {trial_params}")
        fold_loss = []
        rel_corr_loss = []
        auc_metric = []
        ref_module = self.dataset.get_modules(self.module_name)[0]
        n_epochs = []

        for i, (train_data, val_data) in enumerate(
            self.cv_generator(self.config.optimization.folds)
        ):
            train_dataloader = self.get_dataloader(
                train_data, trial_params.batch_size, shuffle=True
            )
            val_dataloader = self.get_dataloader(
                val_data, trial_params.batch_size, shuffle=False
            )

            model_config = HivaeConfig(
                name=self.module_name,
                variable_types=ref_module.variable_types,
                dim_s=trial_params.dim_s,
                dim_y=trial_params.dim_y,
                dim_z=trial_params.dim_z,
                mtl_methods=trial_params.mtl_methods,
                use_imputation_layer=self.config.training.use_imputation_layer,
                dropout=trial_params.dropout,
                n_layers=trial_params.lstm_layers,
                num_timepoints=self.dataset.visits_per_module[self.module_name],
            )
            raw_model = self.init_model(model_config)
            raw_model.device = self.device
            model = self.optimize_model(raw_model)
            model.to(self.device)
            with mlflow.start_run(
                run_name=f"{self.run_base}_T{trial._trial_id}_F{i}",
                nested=True,
            ):
                mlflow.log_params(trial_params.__dict__)
                start = time.time()
                try:
                    fit_loss, fit_epoch = model.fit(
                        train_dataloader=train_dataloader,
                        val_dataloader=val_dataloader,
                        num_epochs=trial_params.epochs,
                        learning_rate=trial_params.learning_rate,
                    )
                except ValueError:
                    raise optuna.TrialPruned(
                        "Trial pruned due to error during training. Likely due to unsuitable hyperparameters"
                    )

                if ((time.time() - start) / 3600) > 2:
                    logger.warning(
                        f"Trial pruned due to very long execution ({(time.time() - start) / 3600}h at fold {i})"
                    )
                    raise optuna.TrialPruned()
                output = model.predict(val_dataloader)

                mlflow.log_metric("Final loss", output.avg_loss)
                logger.info(f"Fold {i} loss: {output.avg_loss}")
                fold_loss.append(output.avg_loss)
                n_epochs.append(fit_epoch)

                if (
                    self.config.optimization.use_relative_correlation_error_for_optimization
                    or self.config.optimization.use_auc_for_optimization
                ):
                    # Calculate the relative correlation loss
                    orig_data = val_data.to_pandas(module_name=self.module_name)
                    # convert decoded samples into a pandas dataframe
                    decoded_samples = output.samples
                    assert decoded_samples.shape[-1] == (
                        orig_data.shape[-1] - 2
                    )
                    column_names = [
                        re.sub(r"_VIS[0-9]+", "", c)
                        for c in self.dataset.get_modules(self.module_name)[
                            0
                        ].columns
                    ]
                    if decoded_samples.ndim == 2:
                        synthetic_data = pd.DataFrame(
                            decoded_samples.numpy(),
                            columns=column_names,
                        )
                        synthetic_data["SUBJID"] = val_data.subj
                        synthetic_data["VISIT"] = 1
                    else:
                        synthetic_data = pd.concat(
                            [
                                pd.DataFrame(
                                    decoded_samples[:, i, :].numpy(),
                                    columns=column_names,
                                )
                                for i in range(decoded_samples.shape[1])
                            ]
                        )
                        synthetic_data["SUBJID"] = (
                            val_data.subj
                            * self.dataset.visits_per_module[self.module_name]
                        )
                        synthetic_data["VISIT"] = np.repeat(
                            np.arange(1, decoded_samples.shape[1] + 1),
                            len(val_data),
                        )

                    synthetic_data_filtered = synthetic_data.loc[
                        :, orig_data.columns
                    ].drop(columns=["SUBJID", "VISIT"])
                    orig_data_filtered = orig_data.drop(
                        columns=["SUBJID", "VISIT"]
                    )

                    fold_rel_corr_loss, m1, m2 = RelativeCorrelation.error(
                        real=orig_data_filtered,
                        synthetic=synthetic_data_filtered,
                    )

                    auc = get_auc(orig_data_filtered, synthetic_data_filtered)
                    if auc < 0.5:
                        auc = 1 - auc
                    auc_quality = max(math.floor((1 - auc) * 200), 1)
                    mlflow.log_metric("AUC", auc)
                    mlflow.log_metric("AUC quality", auc_quality)
                    auc_metric.append(auc_quality)

                    mlflow.log_metric(
                        "Relative correlation loss", fold_rel_corr_loss
                    )
                    rel_corr_loss.append(fold_rel_corr_loss)

        loss = np.mean(fold_loss)
        avg_auc_metric = np.mean(auc_metric)

        # Get the n_epochs that correspond to the best loss
        best_epoch = n_epochs[np.argmin(fold_loss)]
        mlflow.log_metric("Best epoch", best_epoch)
        # Log the value in the trial
        trial.set_user_attr("best_epoch", best_epoch)
        trial.set_user_attr("n_epochs", n_epochs)

        if self.config.optimization.use_relative_correlation_error_for_optimization:
            rel_corr_loss = np.mean(rel_corr_loss)
            logger.info(f"Trial loss: {loss}; rel_corr_loss: {rel_corr_loss}")
            return loss, rel_corr_loss
        elif self.config.optimization.use_auc_for_optimization:
            return avg_auc_metric
        else:
            logger.info(f"Trial loss: {loss}")
            return loss

    def save_model_config(self, path: Path):
        """
        Saves the model configuration to a file.

        Args:
            path (Path): The path to save the model configuration.
        """
        if self.model is None:
            raise Exception("Model should not be none")

        if self.model_config is None:
            raise Exception("Model config should not be none")

        with path.open("wb") as f:
            dill.dump(self.model_config, f)

    def read_model_config(self, path: Path):
        """
        Reads the model configuration from a file.

        Args:
            path (Path): The path to read the model configuration from.
        """
        with path.open("rb") as f:
            self.model_config = dill.load(f)
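
Unlike the modular variant, a TraditionalTrainer is bound to a single data module, so module_name is mandatory and must appear in dataset.module_names. A minimal sketch with placeholder names:

from pathlib import Path

# Hypothetical: `dataset` (VambnDataset) and `pipeline_config` (PipelineConfig)
# are assumed to exist; "clinical" stands in for one of dataset.module_names.
trainer = TraditionalTrainer(
    dataset=dataset,
    config=pipeline_config,
    workers=4,
    checkpoint_path=Path("checkpoints/clinical"),
    module_name="clinical",
)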
hyperparameters(trial)

Function to suggest hyperparameters for the model

Parameters:

    trial (optuna.Trial): Trial instance. Required.

Returns:

    Hyperparameters: The suggested hyperparameters.

Source code in vambn/modelling/models/hivae/trainer.py
def hyperparameters(self, trial: optuna.Trial) -> Hyperparameters:
    """
    Function to suggest hyperparameters for the model

    Args:
        trial (optuna.Trial): Trial instance

    Returns:
        Hyperparameters: The suggested hyperparameters
    """
    opt = self.config.optimization
    dim_s = trial.suggest_int(
        "hidden_dim_s",
        opt.s_dim_lower,
        opt.s_dim_upper,
        opt.s_dim_step,
    )
    dim_y = trial.suggest_int(
        "hidden_dim_y",
        opt.y_dim_lower,
        opt.y_dim_upper,
        opt.y_dim_step,
    )
    dim_z = trial.suggest_int(
        "hidden_dim_z",
        opt.latent_dim_lower,
        opt.latent_dim_upper,
        opt.latent_dim_step,
    )
    learning_rate = trial.suggest_float(
        "learning_rate",
        opt.learning_rate_lower,
        opt.learning_rate_upper,
        log=True,
    )
    batch_size_n = trial.suggest_int(
        "batch_size_n", opt.batch_size_lower_n, opt.batch_size_upper_n
    )
    batch_size = 2**batch_size_n
    # epochs = trial.suggest_int(
    #     "epochs",
    #     low=opt.epoch_lower,
    #     high=opt.epoch_upper,
    #     step=opt.epoch_step,
    # )

    if self.dataset.is_longitudinal:
        lstm_layers = trial.suggest_int(
            "lstm_layers",
            opt.lstm_layers_lower,
            opt.lstm_layers_upper,
            opt.lstm_layers_step,
        )
    else:
        lstm_layers = 1
    if self.use_mtl:
        mtl_string = trial.suggest_categorical(
            "mtl_methods",
            [
                "gradnorm",
                "graddrop",
                "gradnorm,graddrop",
            ],
        )
        mtl_methods = (
            tuple(mtl_string.split(","))
            if "," in mtl_string
            else tuple([mtl_string])
        )
    else:
        mtl_methods = ("identity",)
    return Hyperparameters(
        dim_s=dim_s,
        dim_y=dim_y,
        dim_z=dim_z,
        dropout=0.1,
        batch_size=batch_size,
        learning_rate=learning_rate,
        epochs=opt.max_epochs,
        lstm_layers=lstm_layers,
        mtl_methods=mtl_methods,
    )
read_model_config(path)

Reads the model configuration from a file.

Parameters:

    path (Path): The path to read the model configuration from. Required.
Source code in vambn/modelling/models/hivae/trainer.py
def read_model_config(self, path: Path):
    """
    Reads the model configuration from a file.

    Args:
        path (Path): The path to read the model configuration from.
    """
    with path.open("rb") as f:
        self.model_config = dill.load(f)
save_model_config(path)

Saves the model configuration to a file.

Parameters:

    path (Path): The path to save the model configuration. Required.
Source code in vambn/modelling/models/hivae/trainer.py
def save_model_config(self, path: Path):
    """
    Saves the model configuration to a file.

    Args:
        path (Path): The path to save the model configuration.
    """
    if self.model is None:
        raise Exception("Model should not be none")

    if self.model_config is None:
        raise Exception("Model config should not be none")

    with path.open("wb") as f:
        dill.dump(self.model_config, f)
train(best_parameters)

Trains the model using the best hyperparameters.

Parameters:

    best_parameters (Hyperparameters): The best hyperparameters obtained from optimization. Required.

Returns:

    GenericHivaeModel: The trained model.

Source code in vambn/modelling/models/hivae/trainer.py
def train(self, best_parameters: Hyperparameters) -> GenericHivaeModel:
    """
    Trains the model using the best hyperparameters.

    Args:
        best_parameters (Hyperparameters): The best hyperparameters obtained from optimization.

    Returns:
        GenericHivaeModel: The trained model.
    """
    ref_module = self.dataset.get_modules(self.module_name)[0]
    whole_dataloader = self.get_dataloader(
        self.dataset, best_parameters.batch_size, shuffle=True
    )
    self.model_config = HivaeConfig(
        name=self.module_name,
        variable_types=ref_module.variable_types,
        dim_s=best_parameters.dim_s,
        dim_y=best_parameters.dim_y,
        dim_z=best_parameters.dim_z,
        mtl_methods=best_parameters.mtl_methods,
        use_imputation_layer=self.config.training.use_imputation_layer,
        dropout=best_parameters.dropout,
        n_layers=best_parameters.lstm_layers,
        num_timepoints=self.dataset.visits_per_module[self.module_name],
    )
    self.model = self.init_model(self.model_config).to(self.device)
    self.model.device = self.device
    self.optimize_model()
    with mlflow.start_run(run_name=f"{self.run_base}_final_fit"):
        mlflow.log_params(best_parameters.__dict__)
        self.model.fit(
            train_dataloader=whole_dataloader,
            num_epochs=best_parameters.epochs,
            learning_rate=best_parameters.learning_rate,
        )
    return self.model
timed(fn)

Helper that times a function call for benchmarking purposes; it relies on CUDA events, so a GPU is required.

Parameters:

    fn (Callable): Function to be timed. Required.

Returns:

    Tuple[Any, float]: Result of the function and the time taken to execute it.

Source code in vambn/modelling/models/hivae/trainer.py
def timed(fn: Callable) -> Tuple[Any, float]:
    """
    Times a function call for benchmarking purposes using CUDA events.

    Args:
        fn (Callable): Function to be timed.

    Returns:
        Tuple[Any, float]: Result of the function and the time taken to execute it.
    """

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000
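
timed runs the callable immediately and measures it with CUDA events, so a GPU is required. A minimal usage sketch; the matrix-multiplication workload is a placeholder.

import torch

def run_once():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    return a @ b

result, seconds = timed(run_once)  # elapsed_time is in ms, hence the / 1000
print(f"matmul took {seconds:.4f} s")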

layers

ImputationLayer

Bases: Module

Imputation layer capable of handling both 2D and 3D data.

Source code in vambn/modelling/models/layers.py
class ImputationLayer(nn.Module):
    """Imputation layer capable of handling both 2D and 3D data."""

    def __init__(self, feature_size: int) -> None:
        """Initialize the imputation layer.

        Args:
            feature_size (int): Size of the features dimension.

        Raises:
            ValueError: If `feature_size` is not positive.
        """
        super().__init__()

        if feature_size <= 0:
            raise ValueError(
                f"Feature size should be positive, got {feature_size}"
            )

        self.imputation_matrix = nn.Parameter(
            torch.zeros(feature_size, requires_grad=True)
        )
        nn.init.normal_(self.imputation_matrix)
        self.register_parameter("imputation_matrix", self.imputation_matrix)

    def forward(
        self, input_data: torch.Tensor, missing_mask: torch.Tensor
    ) -> torch.Tensor:
        """Perform the forward pass for data imputation.

        Args:
            input_data (torch.Tensor): Input data matrix, can be 2D (batch x features) or 3D (batch x time x features).
            missing_mask (torch.Tensor): Binary mask indicating missing values in `input_data`.

        Returns:
            torch.Tensor: Imputed data matrix.

        Raises:
            ValueError: If `input_data` is not 2D or 3D.
        """
        if input_data.dim() == 2:
            imputation = self.imputation_matrix
        elif input_data.dim() == 3:
            imputation = self.imputation_matrix.unsqueeze(0).expand(
                input_data.shape[1], -1
            )
        else:
            raise ValueError("Input data must be either 2D or 3D.")

        return input_data * missing_mask + imputation * (1.0 - missing_mask)
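
A minimal usage sketch for the 2D case, assuming a mask in which 1 marks observed and 0 marks missing entries, matching the forward pass above.

import torch

layer = ImputationLayer(feature_size=4)
x = torch.randn(8, 4)                    # batch x features
mask = (torch.rand(8, 4) > 0.3).float()  # 1 = observed, 0 = missing
imputed = layer(x, mask)                 # observed values kept, missing filled
assert imputed.shape == x.shape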
__init__(feature_size)

Initialize the imputation layer.

Parameters:

    feature_size (int): Size of the features dimension. Required.

Raises:

    ValueError: If `feature_size` is not positive.

Source code in vambn/modelling/models/layers.py
def __init__(self, feature_size: int) -> None:
    """Initialize the imputation layer.

    Args:
        feature_size (int): Size of the features dimension.

    Raises:
        ValueError: If `feature_size` is not positive.
    """
    super().__init__()

    if feature_size <= 0:
        raise ValueError(
            f"Feature size should be positive, got {feature_size}"
        )

    self.imputation_matrix = nn.Parameter(
        torch.zeros(feature_size, requires_grad=True)
    )
    nn.init.normal_(self.imputation_matrix)
    self.register_parameter("imputation_matrix", self.imputation_matrix)
forward(input_data, missing_mask)

Perform the forward pass for data imputation.

Parameters:

    input_data (Tensor): Input data matrix; can be 2D (batch x features) or 3D (batch x time x features). (required)
    missing_mask (Tensor): Binary mask indicating missing values in input_data. (required)

Returns:

    Tensor: Imputed data matrix.

Raises:

    ValueError: If input_data is not 2D or 3D.

Source code in vambn/modelling/models/layers.py
def forward(
    self, input_data: torch.Tensor, missing_mask: torch.Tensor
) -> torch.Tensor:
    """Perform the forward pass for data imputation.

    Args:
        input_data (torch.Tensor): Input data matrix, can be 2D (batch x features) or 3D (batch x time x features).
        missing_mask (torch.Tensor): Binary mask indicating missing values in `input_data`.

    Returns:
        torch.Tensor: Imputed data matrix.

    Raises:
        ValueError: If `input_data` is not 2D or 3D.
    """
    if input_data.dim() == 2:
        imputation = self.imputation_matrix
    elif input_data.dim() == 3:
        imputation = self.imputation_matrix.unsqueeze(0).expand(
            input_data.shape[1], -1
        )
    else:
        raise ValueError("Input data must be either 2D or 3D.")

    return input_data * missing_mask + imputation * (1.0 - missing_mask)

ModifiedLinear

Bases: Linear

Source code in vambn/modelling/models/layers.py
class ModifiedLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """Initialize the ModifiedLinear layer.

        Args:
            in_features (int): The number of input features.
            out_features (int): The number of output features.
            bias (bool, optional): If set to False, the layer will not learn an additive bias. Defaults to True.
            device: The device on which to create the tensor. Defaults to None.
            dtype: The desired data type of the tensor. Defaults to None.
        """
        super().__init__(in_features, out_features, bias, device, dtype)
        nn.init.orthogonal_(self.weight)

        if bias:
            nn.init.constant_(self.bias, 0)
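
A brief sketch showing that ModifiedLinear is used like a regular nn.Linear, the difference being orthogonal weight initialization and a zero-initialized bias (shapes are illustrative; the import path follows the source location shown above):

import torch

from vambn.modelling.models.layers import ModifiedLinear

layer = ModifiedLinear(in_features=8, out_features=3)
x = torch.randn(5, 8)
print(layer(x).shape)  # torch.Size([5, 3])
print(layer.bias)      # tensor of zeros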
__init__(in_features, out_features, bias=True, device=None, dtype=None)

Initialize the ModifiedLinear layer.

Parameters:

    in_features (int): The number of input features. (required)
    out_features (int): The number of output features. (required)
    bias (bool): If set to False, the layer will not learn an additive bias. Defaults to True.
    device: The device on which to create the tensor. Defaults to None.
    dtype: The desired data type of the tensor. Defaults to None.

Source code in vambn/modelling/models/layers.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    bias: bool = True,
    device=None,
    dtype=None,
) -> None:
    """Initialize the ModifiedLinear layer.

    Args:
        in_features (int): The number of input features.
        out_features (int): The number of output features.
        bias (bool, optional): If set to False, the layer will not learn an additive bias. Defaults to True.
        device: The device on which to create the tensor. Defaults to None.
        dtype: The desired data type of the tensor. Defaults to None.
    """
    super().__init__(in_features, out_features, bias, device, dtype)
    nn.init.orthogonal_(self.weight)

    if bias:
        nn.init.constant_(self.bias, 0)

init_weights(module)

Initialize the weights of a module.

Parameters:

    module (Module): The module to initialize. (required)

Source code in vambn/modelling/models/layers.py
def init_weights(module: nn.Module):
    """Initialize the weights of a module.

    Args:
        module (nn.Module): The module to initialize.
    """
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=0.05)
        if module.bias is not None:
            module.bias.data.fill_(0.01)
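
init_weights is meant to be applied recursively via Module.apply; a minimal sketch (the small network is illustrative, and the import path follows the source location shown above):

import torch.nn as nn

from vambn.modelling.models.layers import init_weights

net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
net.apply(init_weights)  # every nn.Linear gets weights drawn with std 0.05 and bias 0.01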

templates

AbstractGanModel

Bases: AbstractModel[NewBatchInput, NewBaseOutput, NewForwardOutput, EncodingInput, Tuple[Optimizer, Optimizer, Optimizer], Tuple[_LRScheduler, _LRScheduler, _LRScheduler], float]

Abstract model class for GAN models.

Source code in vambn/modelling/models/templates.py
class AbstractGanModel(
    AbstractModel[
        NewBatchInput,
        NewBaseOutput,
        NewForwardOutput,
        EncodingInput,
        Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer],
        Tuple[
            optim.lr_scheduler._LRScheduler,
            optim.lr_scheduler._LRScheduler,
            optim.lr_scheduler._LRScheduler,
        ],
        float,
    ]
):
    """Abstract model class for GAN models."""

    def _calc_gradient_penalty(
        self, real_data: Tensor, fake_data: Tensor
    ) -> Tensor:
        """
        Calculate the gradient penalty loss for WGAN-GP.

        Args:
            real_data (Tensor): Real data.
            fake_data (Tensor): Fake data.

        Returns:
            Tensor: Gradient penalty loss.
        """
        alpha = torch.rand(real_data.shape[0], 1, device=real_data.device)
        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        interpolates = torch.autograd.Variable(interpolates, requires_grad=True)
        disc_interpolates = self.discriminator(interpolates)
        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones(
                disc_interpolates.size(), device=real_data.device
            ),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    @abstractmethod
    def _train_gan_discriminator_step(
        self,
        data: NewBatchInput,
        mask: NewBatchInput,
        optimizer: optim.Optimizer,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Perform a single GAN discriminator training step.

        Args:
            data (NewBatchInput): Input data for the discriminator.
            mask (NewBatchInput): Mask for the input data.
            optimizer (optim.Optimizer): Optimizer used for training.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: Discriminator loss, real loss, and fake loss.
        """
        raise NotImplementedError

    @abstractmethod
    def _train_model_step(
        self,
        data: NewBatchInput,
        mask: NewBatchInput,
        optimizer: optim.Optimizer,
    ) -> NewForwardOutput:
        """
        Perform a single model training step.

        Args:
            data (NewBatchInput): Input data for the model.
            mask (NewBatchInput): Mask for the input data.
            optimizer (optim.Optimizer): Optimizer used for training.

        Returns:
            NewForwardOutput: Model output.
        """
        raise NotImplementedError

    @abstractmethod
    def _train_gan_generator_step(
        self,
        data: NewBatchInput,
        mask: NewBatchInput,
        optimizer: optim.Optimizer,
    ) -> Tensor:
        """
        Perform a single GAN generator training step.

        Args:
            data (NewBatchInput): Input data for the generator.
            mask (NewBatchInput): Mask for the input data.
            optimizer (optim.Optimizer): Optimizer used for training.

        Returns:
            Tensor: Generator loss.
        """
        raise NotImplementedError

    @abstractmethod
    def _train_model_from_discriminator_step(
        self,
        data: NewBatchInput,
        mask: NewBatchInput,
        optimizer: optim.Optimizer,
    ) -> Tensor:
        """
        Perform a single model training step from the discriminator.

        Args:
            data (NewBatchInput): Input data for the model.
            mask (NewBatchInput): Mask for the input data.
            optimizer (optim.Optimizer): Optimizer used for training.

        Returns:
            Tensor: Model loss from the discriminator.
        """
        raise NotImplementedError

    @staticmethod
    def concat_and_aggregate(
        metric_list: Tuple[Tensor, ...] | List[Tensor], n: int
    ) -> Tensor:
        """
        Concatenate and aggregate metric tensors.

        Args:
            metric_list (Tuple[Tensor, ...] | List[Tensor]): List of metric tensors.
            n (int): Number of items.

        Returns:
            Tensor: Aggregated metric tensor.
        """
        if metric_list[0].ndim == 0:
            metric_tensor = torch.stack(metric_list)
        else:
            metric_tensor = torch.cat(metric_list)
        return torch.sum(metric_tensor) / n

    @abstractmethod
    def _get_loss_from_output(self, output: NewForwardOutput) -> Tensor:
        """
        Get the loss from the model output.

        Args:
            output (NewForwardOutput): Model output.

        Returns:
            Tensor: Loss tensor.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_number_of_items(self, mask: NewBatchInput) -> int:
        """
        Get the number of items from the mask.

        Args:
            mask (NewBatchInput): Mask for the input data.

        Returns:
            int: Number of items.
        """
        raise NotImplementedError

    def _training_epoch(
        self,
        dataloader: DataLoader,
        optimizers: Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer],
    ) -> Tuple[float, List[float]]:
        """
        Perform a single training epoch.

        Args:
            dataloader (DataLoader): DataLoader for training data.
            optimizers (Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer]): Optimizers used for training.

        Returns:
            Tuple[float, List[float]]: Average model loss and list of batch losses.
        """
        self.train()

        (
            model_optimizer,
            gan_discriminator_optimizer,
            gan_generator_optimizer,
        ) = optimizers

        # train the model
        loss = []
        items = []
        for data, missing in dataloader:
            ploss = self._get_loss_from_output(
                self._train_model_step(data, missing, model_optimizer)
            )
            loss.append(float(ploss.detach().cpu()))
            items.append(self._get_number_of_items(missing))

        items = sum(items)
        model_outputs = sum(loss) / items

        # train the GAN
        errD_loss = []
        errD_real_loss = []
        errD_fake_loss = []
        errG_loss = []
        errD_model_loss = []

        for data, missing in dataloader:
            errD, errD_real, errD_fake = self._train_gan_discriminator_step(
                data, missing, gan_discriminator_optimizer
            )
            errG = self._train_gan_generator_step(
                data, missing, gan_generator_optimizer
            )
            errD_ = self._train_model_from_discriminator_step(
                data, missing, model_optimizer
            )
            errD_loss.append(errD.detach().cpu())
            errD_real_loss.append(errD_real.detach().cpu())
            errD_fake_loss.append(errD_fake.detach().cpu())
            errG_loss.append(errG.detach().cpu())
            errD_model_loss.append(errD_.detach().cpu())

        # log the metrics
        mlflow.log_metric("train_model_loss", model_outputs)
        mlflow.log_metric(
            "train_errD_loss", self.concat_and_aggregate(errD_loss, items)
        )
        mlflow.log_metric(
            "train_errD_real_loss",
            self.concat_and_aggregate(errD_real_loss, items),
        )
        mlflow.log_metric(
            "train_errD_fake_loss",
            self.concat_and_aggregate(errD_fake_loss, items),
        )
        mlflow.log_metric(
            "train_errG_loss", self.concat_and_aggregate(errG_loss, items)
        )
        mlflow.log_metric(
            "train_errD_model_loss",
            self.concat_and_aggregate(errD_model_loss, items),
        )
        return model_outputs, loss

    def _get_optimizer(
        self,
        learning_rate: float,
        beta1: float = 0.9,
        beta2: float = 0.999,
    ) -> Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer]:
        """
        Get the optimizers for the model, GAN discriminator, and GAN generator.

        Args:
            learning_rate (float): Learning rate for the optimizers.
            beta1 (float, optional): Beta1 hyperparameter for the Adam optimizer. Defaults to 0.9.
            beta2 (float, optional): Beta2 hyperparameter for the Adam optimizer. Defaults to 0.999.

        Returns:
            Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer]: Optimizers for the model, GAN discriminator, and GAN generator.
        """
        model_optimizer = optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            betas=(beta1, beta2),
            weight_decay=0.01,
        )
        gan_discriminator_optimizer = optim.Adam(
            self.discriminator.parameters(),
            lr=learning_rate,
            betas=(beta1, beta2),
            weight_decay=0.01,
        )
        gan_generator_optimizer = optim.Adam(
            self.generator.parameters(),
            lr=learning_rate,
            betas=(beta1, beta2),
            weight_decay=0.01,
        )
        return (
            model_optimizer,
            gan_discriminator_optimizer,
            gan_generator_optimizer,
        )

    def get_optimizer(
        self,
        learning_rate: float,
        num_epochs: int,
        beta1: float = 0.9,
        beta2: float = 0.999,
    ) -> Tuple[
        Tuple[Optimizer, Optimizer, Optimizer],
        Tuple[_LRScheduler, _LRScheduler, _LRScheduler],
    ]:
        """
        Get the optimizers and schedulers for the model.

        Args:
            learning_rate (float): Learning rate for the optimizers.
            num_epochs (int): Number of epochs for training.
            beta1 (float, optional): Beta1 hyperparameter for the Adam optimizer. Defaults to 0.9.
            beta2 (float, optional): Beta2 hyperparameter for the Adam optimizer. Defaults to 0.999.

        Returns:
            Tuple[
                Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer],
                Tuple[optim.lr_scheduler._LRScheduler, optim.lr_scheduler._LRScheduler, optim.lr_scheduler._LRScheduler]
            ]: Optimizers and schedulers for the model, GAN discriminator, and GAN generator.
        """
        optimizers = self._get_optimizer(
            learning_rate=learning_rate,
            beta1=beta1,
            beta2=beta2,
        )
        schedulers = tuple(
            torch.optim.lr_scheduler.OneCycleLR(
                optimizer=optimizer,
                max_lr=learning_rate,
                total_steps=num_epochs,
            )
            for optimizer in optimizers
        )
        return optimizers, schedulers

    def fit(
        self,
        train_dataloader: DataLoader,
        num_epochs: int,
        learning_rate: float,
        val_dataloader: DataLoader | None = None,
    ) -> Tuple[float, int]:
        """
        Fit the model to the training data.

        Args:
            train_dataloader (DataLoader): DataLoader for training data.
            num_epochs (int): Number of epochs to train.
            learning_rate (float): Learning rate for the optimizer.
            val_dataloader (Optional[DataLoader], optional): DataLoader for validation data. Defaults to None.

        Returns:
            Tuple[float, int]: Best validation loss and number of epochs trained.

        Raises:
            Exception: If number of epochs is less than 1.
        """
        if num_epochs <= 0:
            raise Exception("Number of epochs must be at least 1")

        # determine number of trainable parameters and non-trainable parameters
        trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        non_trainable_params = sum(
            p.numel() for p in self.parameters() if not p.requires_grad
        )
        logger.info(f"Trainable parameters: {trainable_params}")
        logger.info(f"Non-trainable parameters: {non_trainable_params}")

        # get the optimizers and schedulers
        optimizers, schedulers = self.get_optimizer(
            learning_rate=learning_rate,
            num_epochs=num_epochs,
        )
        out = self.fabric.setup(self, *optimizers)
        self = out[0]
        optimizers = out[1:]
        for optimizer in optimizers:
            self.fabric.clip_gradients(self, optimizer, max_norm=1.0)
        train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
        if val_dataloader is not None:
            val_dataloader = self.fabric.setup_dataloaders(val_dataloader)

            # use early stopping if val_dataloader is not None
        if val_dataloader is not None:
            best_loss = float("inf")
            patience = 0
        else:
            best_loss = None
            patience = None

        for current_epoch in tqdm(range(num_epochs), total=num_epochs):
            mlflow.log_metric("epoch", current_epoch, step=current_epoch)
            avg_loss, losses = self._training_epoch(
                dataloader=train_dataloader, optimizers=optimizers
            )
            mlflow.log_metric("train_loss", avg_loss, step=current_epoch)
            for scheduler in schedulers:
                scheduler.step()

            if val_dataloader is not None and (
                current_epoch % 10 == 0 or current_epoch == num_epochs - 1
            ):
                val_loss = self._validation_epoch(val_dataloader)
                mlflow.log_metric("val_loss", val_loss, step=current_epoch)

                if val_loss < best_loss:
                    best_loss = val_loss
                    patience = 0
                else:
                    patience += 1

                if patience == 10:
                    logger.info("Early stopping")
                    break
        if val_dataloader is not None:
            val_loss = self._validation_epoch(val_dataloader)
            if val_loss < best_loss:
                best_loss = val_loss
                patience = 0

        if best_loss is not None:
            best_epoch = (
                current_epoch - patience
                if patience is not None
                else current_epoch
            )
            if not isinstance(best_loss, float):
                best_loss = float(best_loss)
            assert isinstance(best_epoch, int)
        else:
            best_epoch = None

        return (best_loss, best_epoch)
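_calc_gradient_penalty implements the WGAN-GP penalty, i.e. the mean of (||grad D(x_interp)||_2 - 1)^2 over points interpolated between real and fake samples. A standalone sketch of the same computation with a toy discriminator (the discriminator and all shapes are illustrative, not part of the library):

import torch
import torch.nn as nn

discriminator = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))

real = torch.randn(8, 10)
fake = torch.randn(8, 10)

# Interpolate between real and fake samples and penalize gradient norms far from 1.
alpha = torch.rand(real.shape[0], 1)
interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
disc_out = discriminator(interpolates)
gradients = torch.autograd.grad(
    outputs=disc_out,
    inputs=interpolates,
    grad_outputs=torch.ones_like(disc_out),
    create_graph=True,
)[0]
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()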
concat_and_aggregate(metric_list, n) staticmethod

Concatenate and aggregate metric tensors.

Parameters:

    metric_list (Tuple[Tensor, ...] | List[Tensor]): List of metric tensors. (required)
    n (int): Number of items. (required)

Returns:

    Tensor: Aggregated metric tensor.

Source code in vambn/modelling/models/templates.py
@staticmethod
def concat_and_aggregate(
    metric_list: Tuple[Tensor, ...] | List[Tensor], n: int
) -> Tensor:
    """
    Concatenate and aggregate metric tensors.

    Args:
        metric_list (Tuple[Tensor, ...] | List[Tensor]): List of metric tensors.
        n (int): Number of items.

    Returns:
        Tensor: Aggregated metric tensor.
    """
    if metric_list[0].ndim == 0:
        metric_tensor = torch.stack(metric_list)
    else:
        metric_tensor = torch.cat(metric_list)
    return torch.sum(metric_tensor) / n
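
A small numeric example (values are made up; the import path follows the source location shown above): zero-dimensional losses are stacked, one-dimensional losses are concatenated, and the sum is divided by n:

import torch

from vambn.modelling.models.templates import AbstractGanModel

scalars = [torch.tensor(2.0), torch.tensor(4.0)]
print(AbstractGanModel.concat_and_aggregate(scalars, n=4))  # tensor(1.5000)

batched = [torch.tensor([1.0, 2.0]), torch.tensor([3.0])]
print(AbstractGanModel.concat_and_aggregate(batched, n=3))  # tensor(2.)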
fit(train_dataloader, num_epochs, learning_rate, val_dataloader=None)

Fit the model to the training data.

Parameters:

    train_dataloader (DataLoader): DataLoader for training data. (required)
    num_epochs (int): Number of epochs to train. (required)
    learning_rate (float): Learning rate for the optimizer. (required)
    val_dataloader (Optional[DataLoader]): DataLoader for validation data. Defaults to None.

Returns:

    Tuple[float, int]: Best validation loss and number of epochs trained.

Raises:

    Exception: If number of epochs is less than 1.

Source code in vambn/modelling/models/templates.py
def fit(
    self,
    train_dataloader: DataLoader,
    num_epochs: int,
    learning_rate: float,
    val_dataloader: DataLoader | None = None,
) -> Tuple[float, int]:
    """
    Fit the model to the training data.

    Args:
        train_dataloader (DataLoader): DataLoader for training data.
        num_epochs (int): Number of epochs to train.
        learning_rate (float): Learning rate for the optimizer.
        val_dataloader (Optional[DataLoader], optional): DataLoader for validation data. Defaults to None.

    Returns:
        Tuple[float, int]: Best validation loss and number of epochs trained.

    Raises:
        Exception: If number of epochs is less than 1.
    """
    if num_epochs <= 0:
        raise Exception("Number of epochs must be at least 1")

    # determine number of trainable parameters and non-trainable parameters
    trainable_params = sum(
        p.numel() for p in self.parameters() if p.requires_grad
    )
    non_trainable_params = sum(
        p.numel() for p in self.parameters() if not p.requires_grad
    )
    logger.info(f"Trainable parameters: {trainable_params}")
    logger.info(f"Non-trainable parameters: {non_trainable_params}")

    # get the optimizers and schedulers
    optimizers, schedulers = self.get_optimizer(
        learning_rate=learning_rate,
        num_epochs=num_epochs,
    )
    out = self.fabric.setup(self, *optimizers)
    self = out[0]
    optimizers = out[1:]
    for optimizer in optimizers:
        self.fabric.clip_gradients(self, optimizer, max_norm=1.0)
    train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
    if val_dataloader is not None:
        val_dataloader = self.fabric.setup_dataloaders(val_dataloader)

        # use early stopping if val_dataloader is not None
    if val_dataloader is not None:
        best_loss = float("inf")
        patience = 0
    else:
        best_loss = None
        patience = None

    for current_epoch in tqdm(range(num_epochs), total=num_epochs):
        mlflow.log_metric("epoch", current_epoch, step=current_epoch)
        avg_loss, losses = self._training_epoch(
            dataloader=train_dataloader, optimizers=optimizers
        )
        mlflow.log_metric("train_loss", avg_loss, step=current_epoch)
        for scheduler in schedulers:
            scheduler.step()

        if val_dataloader is not None and (
            current_epoch % 10 == 0 or current_epoch == num_epochs - 1
        ):
            val_loss = self._validation_epoch(val_dataloader)
            mlflow.log_metric("val_loss", val_loss, step=current_epoch)

            if val_loss < best_loss:
                best_loss = val_loss
                patience = 0
            else:
                patience += 1

            if patience == 10:
                logger.info("Early stopping")
                break
    if val_dataloader is not None:
        val_loss = self._validation_epoch(val_dataloader)
        if val_loss < best_loss:
            best_loss = val_loss
            patience = 0

    if best_loss is not None:
        best_epoch = (
            current_epoch - patience
            if patience is not None
            else current_epoch
        )
        if not isinstance(best_loss, float):
            best_loss = float(best_loss)
        assert isinstance(best_epoch, int)
    else:
        best_epoch = None

    return (best_loss, best_epoch)
get_optimizer(learning_rate, num_epochs, beta1=0.9, beta2=0.999)

Get the optimizers and schedulers for the model.

Parameters:

    learning_rate (float): Learning rate for the optimizers. (required)
    num_epochs (int): Number of epochs for training. (required)
    beta1 (float): Beta1 hyperparameter for the Adam optimizer. Defaults to 0.9.
    beta2 (float): Beta2 hyperparameter for the Adam optimizer. Defaults to 0.999.

Returns:

    Tuple[Tuple[Optimizer, Optimizer, Optimizer], Tuple[_LRScheduler, _LRScheduler, _LRScheduler]]: Optimizers and schedulers for the model, GAN discriminator, and GAN generator.

Source code in vambn/modelling/models/templates.py
def get_optimizer(
    self,
    learning_rate: float,
    num_epochs: int,
    beta1: float = 0.9,
    beta2: float = 0.999,
) -> Tuple[
    Tuple[Optimizer, Optimizer, Optimizer],
    Tuple[_LRScheduler, _LRScheduler, _LRScheduler],
]:
    """
    Get the optimizers and schedulers for the model.

    Args:
        learning_rate (float): Learning rate for the optimizers.
        num_epochs (int): Number of epochs for training.
        beta1 (float, optional): Beta1 hyperparameter for the Adam optimizer. Defaults to 0.9.
        beta2 (float, optional): Beta2 hyperparameter for the Adam optimizer. Defaults to 0.999.

    Returns:
        Tuple[
            Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer],
            Tuple[optim.lr_scheduler._LRScheduler, optim.lr_scheduler._LRScheduler, optim.lr_scheduler._LRScheduler]
        ]: Optimizers and schedulers for the model, GAN discriminator, and GAN generator.
    """
    optimizers = self._get_optimizer(
        learning_rate=learning_rate,
        beta1=beta1,
        beta2=beta2,
    )
    schedulers = tuple(
        torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=learning_rate,
            total_steps=num_epochs,
        )
        for optimizer in optimizers
    )
    return optimizers, schedulers
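
Each optimizer is paired with a OneCycleLR scheduler configured with total_steps=num_epochs, i.e. one scheduler step per epoch. A standalone sketch of the same pattern for a single toy parameter group (the parameter and values are illustrative):

import torch
from torch import optim

params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = optim.Adam(params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=50)

for epoch in range(50):
    # ... run one training epoch, calling optimizer.step() per batch ...
    scheduler.step()  # advance the one-cycle schedule once per epoch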

AbstractGanModularModel

Bases: AbstractModel[NewBatchInput, NewBaseOutput, NewForwardOutput, EncodingInput, Tuple[Tuple[Optimizer, Optimizer, Optimizer], ...], Tuple[Tuple[_LRScheduler, _LRScheduler, _LRScheduler], ...], Tuple[float, ...]]

Abstract model class for GAN modular models.

Source code in vambn/modelling/models/templates.py
class AbstractGanModularModel(
    AbstractModel[
        NewBatchInput,
        NewBaseOutput,
        NewForwardOutput,
        EncodingInput,
        Tuple[Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer], ...],
        Tuple[
            Tuple[
                optim.lr_scheduler._LRScheduler,
                optim.lr_scheduler._LRScheduler,
                optim.lr_scheduler._LRScheduler,
            ],
            ...,
        ],
        Tuple[float, ...],
    ]
):
    """Abstract model class for GAN modular models."""

    def _calc_gradient_penalty(
        self, real_data: Tensor, fake_data: Tensor, discriminator: nn.Module
    ) -> Tensor:
        """
        Calculate the gradient penalty loss for WGAN-GP.

        Args:
            real_data (Tensor): Real data.
            fake_data (Tensor): Fake data.
            discriminator (nn.Module): Discriminator.

        Returns:
            Tensor: Gradient penalty loss.
        """
        alpha = torch.rand(real_data.shape[0], 1, device=real_data.device)
        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        interpolates = torch.autograd.Variable(interpolates, requires_grad=True)
        disc_interpolates = discriminator(interpolates)
        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones(
                disc_interpolates.size(), device=real_data.device
            ),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    @abstractmethod
    def _train_model_step(
        self,
        data: NewBatchInput,
        mask: NewBatchInput,
        optimizer: Tuple[optim.Optimizer, ...],
    ) -> NewForwardOutput:
        """
        Perform a single model training step.

        Args:
            data (NewBatchInput): Input data for the model.
            mask (NewBatchInput): Mask for the input data.
            optimizer (Tuple[optim.Optimizer, ...]): Optimizers used for training.

        Returns:
            NewForwardOutput: Model output.
        """
        raise NotImplementedError

    @abstractmethod
    def _train_gan_discriminator_step(
        self,
        data: NewBatchInput,
        mask: NewBatchInput,
        optimizer: Tuple[optim.Optimizer, ...],
    ) -> Tuple[Tuple[Tensor, Tensor, Tensor], ...]:
        """
        Perform a single GAN discriminator training step.

        Args:
            data (NewBatchInput): Input data for the discriminator.
            mask (NewBatchInput): Mask for the input data.
            optimizer (Tuple[optim.Optimizer, ...]): Optimizers used for training.

        Returns:
            Tuple[Tuple[Tensor, Tensor, Tensor], ...]: Discriminator loss, real loss, and fake loss for each module.
        """
        raise NotImplementedError

    @abstractmethod
    def _train_gan_generator_step(
        self,
        data: NewBatchInput,
        mask: NewBatchInput,
        optimizer: Tuple[optim.Optimizer, ...],
    ) -> Tuple[Tensor, ...]:
        """
        Perform a single GAN generator training step.

        Args:
            data (NewBatchInput): Input data for the generator.
            mask (NewBatchInput): Mask for the input data.
            optimizer (Tuple[optim.Optimizer, ...]): Optimizers used for training.

        Returns:
            Tuple[Tensor, ...]: Generator loss for each module.
        """
        raise NotImplementedError

    @abstractmethod
    def _train_model_from_discriminator_step(
        self,
        data: NewBatchInput,
        mask: NewBatchInput,
        optimizer: Tuple[optim.Optimizer, ...],
    ) -> Tuple[Tensor, ...]:
        """
        Perform a single model training step from the discriminator.

        Args:
            data (NewBatchInput): Input data for the model.
            mask (NewBatchInput): Mask for the input data.
            optimizer (Tuple[optim.Optimizer, ...]): Optimizers used for training.

        Returns:
            Tuple[Tensor, ...]: Model loss from the discriminator for each module.
        """
        raise NotImplementedError

    @staticmethod
    def concat_and_aggregate(
        metric_list: Tuple[Tensor, ...] | List[Tensor], n: int
    ) -> Tensor:
        """
        Concatenate and aggregate metric tensors.

        Args:
            metric_list (Tuple[Tensor, ...] | List[Tensor]): List of metric tensors.
            n (int): Number of items.

        Returns:
            Tensor: Aggregated metric tensor.
        """
        return AbstractGanModel.concat_and_aggregate(metric_list, n)

    @abstractmethod
    def _get_loss_from_output(self, output: NewForwardOutput) -> float:
        """
        Get the loss from the model output.

        Args:
            output (NewForwardOutput): Model output.

        Returns:
            Tuple[Tensor, ...]: Loss tensor for each module.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_number_of_items(self, mask: NewBatchInput) -> int:
        """
        Get the number of items from the mask.

        Args:
            mask (NewBatchInput): Mask for the input data.

        Returns:
            Tuple[int, ...]: Number of items for each module.
        """
        raise NotImplementedError

    def _training_epoch(
        self,
        dataloader: DataLoader,
        optimizers: Tuple[
            Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer], ...
        ],
    ) -> Tuple[Tensor, Tensor]:
        """
        Perform a single training epoch.

        Args:
            dataloader (DataLoader): DataLoader for training data.
            optimizers (Tuple[Tuple[optim.Optimizer, optim.Optimizer, optim.Optimizer], ...]): Optimizers used for training.

        Returns:
            Tuple[Tensor, Tensor]: Average model loss and list of batch losses.
        """
        self.train()
        losses = []
        items = []
        model_optimizers = tuple([opt[0] for opt in optimizers])
        gan_generator_optimizers = tuple([opt[1] for opt in optimizers])
        gan_discriminator_optimizers = tuple([opt[2] for opt in optimizers])
        # start = time.time()
        for data, mask in dataloader:
            ploss = self._get_loss_from_output(
                self._train_model_step(data, mask, model_optimizers)
            )
            losses.append(ploss)
            items.append(self._get_number_of_items(mask))
        losses = torch.tensor(losses)
        model_outputs = losses.sum() / torch.stack(items).sum()
        # print(f"Model training time: {time.time() - start}")

        # train the GAN
        errD_loss = []
        errD_real_loss = []
        errD_fake_loss = []
        errG_loss = []
        errD_model_loss = []

        # start = time.time()
        for data, mask in dataloader:
            errD, errD_real, errD_fake = self._train_gan_discriminator_step(
                data, mask, gan_discriminator_optimizers
            )
            errG = self._train_gan_generator_step(
                data, mask, gan_generator_optimizers
            )
            errD_ = self._train_model_from_discriminator_step(
                data, mask, model_optimizers
            )
            errD_loss.append(errD)
            errD_real_loss.append(errD_real)
            errD_fake_loss.append(errD_fake)
            errG_loss.append(errG)
            errD_model_loss.append(errD_)

        # print(f"GAN training time: {time.time() - start}")

        # stack the losses
        errD_per_module = torch.tensor(errD_loss).sum(dim=0)
        errD_real_per_module = torch.tensor(errD_real_loss).sum(dim=0)
        errD_fake_per_module = torch.tensor(errD_fake_loss).sum(dim=0)
        errG_per_module = torch.tensor(errG_loss).sum(dim=0)
        errD_model_per_module = torch.tensor(errD_model_loss).sum(dim=0)

        # log the summed metrics
        mlflow.log_metric("train_model_loss", model_outputs)
        mlflow.log_metric("train_errD_loss", errD_per_module.sum())
        mlflow.log_metric("train_errD_real_loss", errD_real_per_module.sum())
        mlflow.log_metric("train_errD_fake_loss", errD_fake_per_module.sum())
        mlflow.log_metric("train_errG_loss", errG_per_module.sum())
        mlflow.log_metric("train_errD_model_loss", errD_model_per_module.sum())

        return model_outputs, losses

    def get_optimizer(
        self,
        learning_rate: Tuple[float, ...],
        num_epochs: int,
        beta1: float = 0.9,
        beta2: float = 0.999,
    ) -> Tuple[
        Tuple[Tuple[Optimizer, Optional[Optimizer], Optional[Optimizer]], ...],
        Tuple[_LRScheduler, ...],
    ]:
        """
        Get the optimizers and schedulers for the model.

        Args:
            learning_rate (Tuple[float, ...]): Learning rates for the optimizers.
            num_epochs (int): Number of epochs for training.
            beta1 (float, optional): Beta1 hyperparameter for the Adam optimizer. Defaults to 0.9.
            beta2 (float, optional): Beta2 hyperparameter for the Adam optimizer. Defaults to 0.999.

        Returns:
            Tuple[
                Tuple[Tuple[optim.Optimizer, Optional[optim.Optimizer], Optional[optim.Optimizer]], ...],
                Tuple[optim.lr_scheduler._LRScheduler, ...]
            ]: Optimizers and schedulers for the model, GAN discriminator, and GAN generator.
        """

        def _module_optimizer(module, learning_rate):
            opt = optim.Adam(
                module.parameters(),
                lr=learning_rate,
                betas=(beta1, beta2),
                weight_decay=0.01,
            )
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer=opt, max_lr=learning_rate, total_steps=num_epochs
            )
            return opt, scheduler

        learning_rate, shared_learning_rate = (
            learning_rate[:-1],
            learning_rate[-1],
        )
        assert len(learning_rate) == len(self.model.module_models)
        out = tuple(
            _module_optimizer(module, lr)
            for module, lr in zip(
                self.model.module_models.values(), learning_rate
            )
        )
        model_optimizers = tuple(x[0] for x in out)
        model_schedulers = tuple(x[1] for x in out)

        # get the optimizer for the shared element
        shared_optimizer = (
            optim.Adam(
                self.model.shared_element.parameters(),
                lr=shared_learning_rate,
                betas=(beta1, beta2),
                weight_decay=0.01,
            )
            if self.model.shared_element.has_params
            else None
        )
        shared_scheduler = (
            torch.optim.lr_scheduler.OneCycleLR(
                optimizer=shared_optimizer,
                total_steps=num_epochs,
                max_lr=shared_learning_rate,
            )
            if shared_optimizer is not None
            else None
        )

        model_optimizers = (*model_optimizers, shared_optimizer)
        schedulers = (*model_schedulers, shared_scheduler)

        assert isinstance(learning_rate, Tuple), "Learning rate must be a tuple"
        # Get the GAN optimizers
        gan_discriminator_optimizers = tuple(
            optim.Adam(
                module.parameters(),
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=0.01,
            )
            for lr, module in zip(learning_rate, self.discriminators)
        )

        gan_generator_optimizers = tuple(
            optim.Adam(
                module.parameters(),
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=0.01,
            )
            for lr, module in zip(learning_rate, self.generators)
        )

        optimizers = [
            (mod, gen, disc)
            for mod, gen, disc in zip(
                model_optimizers,
                gan_generator_optimizers,
                gan_discriminator_optimizers,
            )
        ]
        optimizers.append((shared_optimizer, None, None))

        return optimizers, schedulers  # type: ignore

    def fit(
        self,
        train_dataloader: DataLoader,
        num_epochs: int,
        learning_rate: Tuple[float],
        val_dataloader: DataLoader | None = None,
    ) -> Tuple[float, int]:
        """
        Fit the model to the training data.

        Args:
            train_dataloader (DataLoader): DataLoader for training data.
            num_epochs (int): Number of epochs to train.
            learning_rate (Tuple[float]): Learning rates for the optimizers.
            val_dataloader (Optional[DataLoader], optional): DataLoader for validation data. Defaults to None.

        Returns:
            Tuple[float, int]: Best validation loss and number of epochs trained.

        Raises:
            Exception: If number of epochs is less than 1.
        """

        if num_epochs <= 0:
            raise Exception("Number of epochs must be at least 1")

        # determine number of trainable parameters and non-trainable parameters
        trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        non_trainable_params = sum(
            p.numel() for p in self.parameters() if not p.requires_grad
        )
        logger.info(f"Trainable parameters: {trainable_params}")
        logger.info(f"Non-trainable parameters: {non_trainable_params}")

        # get the optimizers and schedulers
        optimizers, schedulers = self.get_optimizer(
            learning_rate=learning_rate,
            num_epochs=num_epochs,
        )
        flattened_optimizers = [
            opt
            for optimizer in optimizers
            for opt in optimizer
            if opt is not None
        ]
        out = self.fabric.setup(self, *flattened_optimizers)
        self = out[0]
        flattened_optimizers = list(out[1:])
        # append two None values to the end of the list to match the length of the optimizers
        for optimizer in flattened_optimizers:
            self.fabric.clip_gradients(self, optimizer, max_norm=1.0)
        flattened_optimizers.extend(
            [None, None] if self.model.shared_element.has_params else [None] * 3
        )
        optimizers = [
            tuple(flattened_optimizers[i : i + 3])
            for i in range(0, len(flattened_optimizers), 3)
        ]

        train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
        if val_dataloader is not None:
            val_dataloader = self.fabric.setup_dataloaders(val_dataloader)

        # use early stopping if val_dataloader is not None
        if val_dataloader is not None:
            best_loss = float("inf")
            patience = 0
        else:
            best_loss = None
            patience = None

        for current_epoch in tqdm(range(num_epochs), total=num_epochs):
            mlflow.log_metric("epoch", current_epoch, step=current_epoch)
            avg_loss, losses = self._training_epoch(
                dataloader=train_dataloader, optimizers=optimizers
            )
            mlflow.log_metric("train_loss", avg_loss, step=current_epoch)
            for scheduler in schedulers:
                if scheduler is None:
                    continue
                scheduler.step()

            if val_dataloader is not None and (
                current_epoch % 10 == 0 or current_epoch == num_epochs - 1
            ):
                val_loss = self._validation_epoch(val_dataloader)
                mlflow.log_metric("val_loss", val_loss, step=current_epoch)

                if val_loss < best_loss:
                    best_loss = val_loss
                    patience = 0
                else:
                    patience += 1

                if patience == 10:
                    logger.info("Early stopping")
                    break
        if val_dataloader is not None:
            val_loss = self._validation_epoch(val_dataloader)
            if val_loss < best_loss:
                best_loss = val_loss
                patience = 0

        if best_loss is not None:
            best_epoch = (
                current_epoch - patience
                if patience is not None
                else current_epoch
            )
            if not isinstance(best_loss, float):
                best_loss = float(best_loss)
            assert isinstance(best_epoch, int)
        else:
            best_epoch = None

        return (best_loss, best_epoch)
concat_and_aggregate(metric_list, n) staticmethod

Concatenate and aggregate metric tensors.

Parameters:

    metric_list (Tuple[Tensor, ...] | List[Tensor]): List of metric tensors. (required)
    n (int): Number of items. (required)

Returns:

    Tensor: Aggregated metric tensor.

Source code in vambn/modelling/models/templates.py
@staticmethod
def concat_and_aggregate(
    metric_list: Tuple[Tensor, ...] | List[Tensor], n: int
) -> Tensor:
    """
    Concatenate and aggregate metric tensors.

    Args:
        metric_list (Tuple[Tensor, ...] | List[Tensor]): List of metric tensors.
        n (int): Number of items.

    Returns:
        Tensor: Aggregated metric tensor.
    """
    return AbstractGanModel.concat_and_aggregate(metric_list, n)
fit(train_dataloader, num_epochs, learning_rate, val_dataloader=None)

Fit the model to the training data.

Parameters:

    train_dataloader (DataLoader): DataLoader for training data. (required)
    num_epochs (int): Number of epochs to train. (required)
    learning_rate (Tuple[float]): Learning rates for the optimizers. (required)
    val_dataloader (Optional[DataLoader]): DataLoader for validation data. Defaults to None.

Returns:

    Tuple[float, int]: Best validation loss and number of epochs trained.

Raises:

    Exception: If number of epochs is less than 1.

Source code in vambn/modelling/models/templates.py
def fit(
    self,
    train_dataloader: DataLoader,
    num_epochs: int,
    learning_rate: Tuple[float],
    val_dataloader: DataLoader | None = None,
) -> Tuple[float, int]:
    """
    Fit the model to the training data.

    Args:
        train_dataloader (DataLoader): DataLoader for training data.
        num_epochs (int): Number of epochs to train.
        learning_rate (Tuple[float]): Learning rates for the optimizers.
        val_dataloader (Optional[DataLoader], optional): DataLoader for validation data. Defaults to None.

    Returns:
        Tuple[float, int]: Best validation loss and number of epochs trained.

    Raises:
        Exception: If number of epochs is less than 1.
    """

    if num_epochs <= 0:
        raise Exception("Number of epochs must be at least 1")

    # determine number of trainable parameters and non-trainable parameters
    trainable_params = sum(
        p.numel() for p in self.parameters() if p.requires_grad
    )
    non_trainable_params = sum(
        p.numel() for p in self.parameters() if not p.requires_grad
    )
    logger.info(f"Trainable parameters: {trainable_params}")
    logger.info(f"Non-trainable parameters: {non_trainable_params}")

    # get the optimizers and schedulers
    optimizers, schedulers = self.get_optimizer(
        learning_rate=learning_rate,
        num_epochs=num_epochs,
    )
    flattened_optimizers = [
        opt
        for optimizer in optimizers
        for opt in optimizer
        if opt is not None
    ]
    out = self.fabric.setup(self, *flattened_optimizers)
    self = out[0]
    flattened_optimizers = list(out[1:])
    # append two None values to the end of the list to match the length of the optimizers
    for optimizer in flattened_optimizers:
        self.fabric.clip_gradients(self, optimizer, max_norm=1.0)
    flattened_optimizers.extend(
        [None, None] if self.model.shared_element.has_params else [None] * 3
    )
    optimizers = [
        tuple(flattened_optimizers[i : i + 3])
        for i in range(0, len(flattened_optimizers), 3)
    ]

    train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
    if val_dataloader is not None:
        val_dataloader = self.fabric.setup_dataloaders(val_dataloader)

    # use early stopping if val_dataloader is not None
    if val_dataloader is not None:
        best_loss = float("inf")
        patience = 0
    else:
        best_loss = None
        patience = None

    for current_epoch in tqdm(range(num_epochs), total=num_epochs):
        mlflow.log_metric("epoch", current_epoch, step=current_epoch)
        avg_loss, losses = self._training_epoch(
            dataloader=train_dataloader, optimizers=optimizers
        )
        mlflow.log_metric("train_loss", avg_loss, step=current_epoch)
        for scheduler in schedulers:
            if scheduler is None:
                continue
            scheduler.step()

        if val_dataloader is not None and (
            current_epoch % 10 == 0 or current_epoch == num_epochs - 1
        ):
            val_loss = self._validation_epoch(val_dataloader)
            mlflow.log_metric("val_loss", val_loss, step=current_epoch)

            if val_loss < best_loss:
                best_loss = val_loss
                patience = 0
            else:
                patience += 1

            if patience == 10:
                logger.info("Early stopping")
                break
    if val_dataloader is not None:
        val_loss = self._validation_epoch(val_dataloader)
        if val_loss < best_loss:
            best_loss = val_loss
            patience = 0

    if best_loss is not None:
        best_epoch = (
            current_epoch - patience
            if patience is not None
            else current_epoch
        )
        if not isinstance(best_loss, float):
            best_loss = float(best_loss)
        assert isinstance(best_epoch, int)
    else:
        best_epoch = None

    return (best_loss, best_epoch)
get_optimizer(learning_rate, num_epochs, beta1=0.9, beta2=0.999)

Get the optimizers and schedulers for the model.

Parameters:

    learning_rate (Tuple[float, ...]): Learning rates for the optimizers. (required)
    num_epochs (int): Number of epochs for training. (required)
    beta1 (float): Beta1 hyperparameter for the Adam optimizer. Defaults to 0.9.
    beta2 (float): Beta2 hyperparameter for the Adam optimizer. Defaults to 0.999.

Returns:

    Tuple[Tuple[Tuple[Optimizer, Optional[Optimizer], Optional[Optimizer]], ...], Tuple[_LRScheduler, ...]]: Optimizers and schedulers for the model, GAN discriminator, and GAN generator.

Source code in vambn/modelling/models/templates.py
def get_optimizer(
    self,
    learning_rate: Tuple[float, ...],
    num_epochs: int,
    beta1: float = 0.9,
    beta2: float = 0.999,
) -> Tuple[
    Tuple[Tuple[Optimizer, Optional[Optimizer], Optional[Optimizer]], ...],
    Tuple[_LRScheduler, ...],
]:
    """
    Get the optimizers and schedulers for the model.

    Args:
        learning_rate (Tuple[float, ...]): Learning rates for the optimizers.
        num_epochs (int): Number of epochs for training.
        beta1 (float, optional): Beta1 hyperparameter for the Adam optimizer. Defaults to 0.9.
        beta2 (float, optional): Beta2 hyperparameter for the Adam optimizer. Defaults to 0.999.

    Returns:
        Tuple[
            Tuple[Tuple[optim.Optimizer, Optional[optim.Optimizer], Optional[optim.Optimizer]], ...],
            Tuple[optim.lr_scheduler._LRScheduler, ...]
        ]: Optimizers and schedulers for the model, GAN discriminator, and GAN generator.
    """

    def _module_optimizer(module, learning_rate):
        opt = optim.Adam(
            module.parameters(),
            lr=learning_rate,
            betas=(beta1, beta2),
            weight_decay=0.01,
        )
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=opt, max_lr=learning_rate, total_steps=num_epochs
        )
        return opt, scheduler

    learning_rate, shared_learning_rate = (
        learning_rate[:-1],
        learning_rate[-1],
    )
    assert len(learning_rate) == len(self.model.module_models)
    out = tuple(
        _module_optimizer(module, lr)
        for module, lr in zip(
            self.model.module_models.values(), learning_rate
        )
    )
    model_optimizers = tuple(x[0] for x in out)
    model_schedulers = tuple(x[1] for x in out)

    # get the optimizer for the shared element
    shared_optimizer = (
        optim.Adam(
            self.model.shared_element.parameters(),
            lr=shared_learning_rate,
            betas=(beta1, beta2),
            weight_decay=0.01,
        )
        if self.model.shared_element.has_params
        else None
    )
    shared_scheduler = (
        torch.optim.lr_scheduler.OneCycleLR(
            optimizer=shared_optimizer,
            total_steps=num_epochs,
            max_lr=shared_learning_rate,
        )
        if shared_optimizer is not None
        else None
    )

    model_optimizers = (*model_optimizers, shared_optimizer)
    schedulers = (*model_schedulers, shared_scheduler)

    assert isinstance(learning_rate, Tuple), "Learning rate must be a tuple"
    # Get the GAN optimizers
    gan_discriminator_optimizers = tuple(
        optim.Adam(
            module.parameters(),
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=0.01,
        )
        for lr, module in zip(learning_rate, self.discriminators)
    )

    gan_generator_optimizers = tuple(
        optim.Adam(
            module.parameters(),
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=0.01,
        )
        for lr, module in zip(learning_rate, self.generators)
    )

    optimizers = [
        (mod, gen, disc)
        for mod, gen, disc in zip(
            model_optimizers,
            gan_generator_optimizers,
            gan_discriminator_optimizers,
        )
    ]
    optimizers.append((shared_optimizer, None, None))

    return optimizers, schedulers  # type: ignore

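The tuple structure above can be unpacked per module. A minimal usage sketch (the trainer instance, learning rates, and epoch count are illustrative assumptions, not library defaults):

# Hypothetical usage sketch for the GAN-enabled trainer's get_optimizer.
learning_rates = (1e-3, 1e-3, 1e-3, 5e-4)  # one rate per module, the last for the shared element
optimizers, schedulers = trainer.get_optimizer(learning_rate=learning_rates, num_epochs=100)

# All but the last entry are (module, generator, discriminator) triples;
# the final entry carries the shared-element optimizer with no GAN parts.
for module_opt, gen_opt, disc_opt in optimizers[:-1]:
    module_opt.zero_grad()
shared_opt, _, _ = optimizers[-1]
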
AbstractModel

Bases: Generic[BatchInput, BaseOutput, ForwardOutput, EncodingInput, OptimizerInput, SchedulerInput, LearningRateInput], ABC, Module

Abstract base class for all models.

Source code in vambn/modelling/models/templates.py
class AbstractModel(
    Generic[
        BatchInput,
        BaseOutput,
        ForwardOutput,
        EncodingInput,
        OptimizerInput,
        SchedulerInput,
        LearningRateInput,
    ],
    ABC,
    nn.Module,
):
    """Abstract base class for all models."""

    def __init__(self):
        """Initializes the model with device and fabric setup."""
        super(AbstractModel, self).__init__()
        nn.Module.__init__(self)
        self.device = torch.device("cpu")
        self.fabric = Fabric(accelerator="cpu", precision="bf16-mixed")
        self.fabric.seed_everything(1234)

    @abstractmethod
    def get_optimizer(
        self,
        learning_rate: LearningRateInput,
        num_epochs: int,
        beta1: float = 0.9,
        beta2: float = 0.999,
    ) -> Tuple[OptimizerInput, SchedulerInput]:
        """
        Get the optimizer for the Modular-HIVAE.

        Args:
            learning_rate (LearningRateInput): Learning rate for the optimizer.
            num_epochs (int): Number of epochs for training.
            beta1 (float, optional): Beta1 hyperparameter for the Adam optimizer. Defaults to 0.9.
            beta2 (float, optional): Beta2 hyperparameter for the Adam optimizer. Defaults to 0.999.

        Returns:
            Tuple[OptimizerInput, SchedulerInput]: The optimizer and scheduler.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(self, data: BatchInput, mask: BatchInput) -> ForwardOutput:
        """
        Defines the computation performed at every call.

        Args:
            data (BatchInput): Input data for the forward pass.
            mask (BatchInput): Mask for the input data.

        Returns:
            ForwardOutput: Output of the forward pass.
        """
        raise NotImplementedError

    @abstractmethod
    def decode(self, encoding: EncodingInput) -> BaseOutput:
        """
        Decodes the given encoding to the base output format.

        Args:
            encoding (EncodingInput): Encoding to be decoded.

        Returns:
            BaseOutput: Decoded output.
        """
        raise NotImplementedError

    @abstractmethod
    def _training_step(
        self, data: BatchInput, mask: BatchInput, optimizer: OptimizerInput
    ) -> float:
        """
        Perform a single training step.

        Args:
            data (BatchInput): Input data for the training step.
            mask (BatchInput): Mask for the input data.
            optimizer (OptimizerInput): Optimizer used for the training step.

        Returns:
            float: Loss value for the training step.
        """
        pass

    @staticmethod
    def _process_column(x_m_vt):
        """
        Processes a column of data.

        Args:
            x_m_vt: Tuple containing data, mask, and variable type.

        Returns:
            Processed column data.
        """
        x, m, vt = x_m_vt
        if vt.data_type == "cat":
            return Conversion._encode_categorical(x, m, vt.n_parameters)
        return x.view(-1, 1), m.view(-1, 1)

    @abstractmethod
    def _training_epoch(
        self, dataloader: DataLoader, optimizer: OptimizerInput
    ) -> Tuple[float, List[float]]:
        """
        Perform a single training epoch.

        Args:
            dataloader (DataLoader): DataLoader for training data.
            optimizer (OptimizerInput): Optimizer used for training.

        Returns:
            Tuple[float, List[float]]: Average loss and list of losses for each batch.
        """
        raise NotImplementedError

    @abstractmethod
    def _validation_step(self, data: BatchInput, mask: BatchInput) -> float:
        """
        Perform a single validation step.

        Args:
            data (BatchInput): Input data for the validation step.
            mask (BatchInput): Mask for the input data.

        Returns:
            float: Loss value for the validation step.
        """
        pass

    def _validation_epoch(self, dataloader: DataLoader) -> float:
        """
        Perform a validation epoch.

        Args:
            dataloader (DataLoader): DataLoader for validation data.

        Returns:
            float: Average validation loss.
        """
        self.eval()
        loss = []
        for data, missing in dataloader:
            loss.append(self._validation_step(data=data, mask=missing))

        # each loss is the average loss for a batch
        avg_loss = np.mean(loss)
        logger.info(f"Validation loss: {avg_loss}")
        return avg_loss

    @abstractmethod
    def fit(
        self,
        train_dataloader: DataLoader,
        num_epochs: int,
        learning_rate: LearningRateInput,
        val_dataloader: Optional[DataLoader] = None,
    ) -> Tuple[float, int]:
        """
        Fit the model to the training data.

        Args:
            train_dataloader (DataLoader): DataLoader for training data.
            num_epochs (int): Number of epochs to train.
            learning_rate (LearningRateInput): Learning rate for the optimizer.
            val_dataloader (Optional[DataLoader], optional): DataLoader for validation data. Defaults to None.

        Returns:
            Tuple[float, int]: Best validation loss and number of epochs trained.
        """
        raise NotImplementedError

    @abstractmethod
    def _test_step(self, data: BatchInput, mask: BatchInput) -> float:
        """
        Perform a single test step.

        Args:
            data (BatchInput): Input data for the test step.
            mask (BatchInput): Mask for the input data.

        Returns:
            float: Loss value for the test step.
        """
        pass

    @abstractmethod
    def _predict_step(
        self, data: BatchInput, mask: BatchInput
    ) -> ForwardOutput:
        """
        Perform a single prediction step.

        Args:
            data (BatchInput): Input data for the prediction step.
            mask (BatchInput): Mask for the input data.

        Returns:
            ForwardOutput: Prediction output.
        """

        pass

    def predict(self, dataloader: DataLoader) -> ForwardOutput:
        """
        Perform prediction on a dataset.

        Args:
            dataloader (DataLoader): DataLoader for prediction data.

        Returns:
            ForwardOutput: Combined predictions for the entire dataset.

        Raises:
            Exception: If no data is provided to the model.
        """
        self.eval()
        outputs = None
        self = self.fabric.setup(self)
        dataloader = self.fabric.setup_dataloaders(dataloader)

        for data, mask in dataloader:
            tmp = self._predict_step(data=data, mask=mask)
            if outputs is None:
                outputs = tmp
            else:
                outputs += tmp

        if outputs is None:
            raise Exception("No data was provided to the model.")

        return outputs
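
A concrete subclass only needs to supply the abstract hooks; the validation loop, prediction loop, and Fabric setup are inherited. The following is an illustrative sketch, not library code: the import path mirrors the source file shown above, and the layer, loss, and return values are placeholder assumptions.

# Illustrative sketch only: the minimum a concrete AbstractModel subclass must provide.
import numpy as np
import torch.nn as nn
import torch.optim as optim

from vambn.modelling.models.templates import AbstractModel  # assumed import path


class TinyModel(AbstractModel):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def get_optimizer(self, learning_rate, num_epochs, beta1=0.9, beta2=0.999):
        opt = optim.Adam(self.parameters(), lr=learning_rate, betas=(beta1, beta2))
        sched = optim.lr_scheduler.OneCycleLR(opt, max_lr=learning_rate, total_steps=num_epochs)
        return opt, sched

    def forward(self, data, mask):
        return self.layer(data) * mask

    def decode(self, encoding):
        return encoding

    def _training_step(self, data, mask, optimizer):
        optimizer.zero_grad()
        loss = ((self(data, mask) - data) ** 2).mean()  # placeholder loss
        loss.backward()
        optimizer.step()
        return float(loss)

    def _training_epoch(self, dataloader, optimizer):
        losses = [self._training_step(d, m, optimizer) for d, m in dataloader]
        return float(np.mean(losses)), losses

    def _validation_step(self, data, mask):
        return float(((self(data, mask) - data) ** 2).mean())

    def _test_step(self, data, mask):
        return self._validation_step(data, mask)

    def _predict_step(self, data, mask):
        return self(data, mask)

    def fit(self, train_dataloader, num_epochs, learning_rate, val_dataloader=None):
        opt, sched = self.get_optimizer(learning_rate, num_epochs)
        for _ in range(num_epochs):
            self._training_epoch(train_dataloader, opt)
            sched.step()
        return 0.0, num_epochs
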
__init__()

Initializes the model with device and fabric setup.

Source code in vambn/modelling/models/templates.py
def __init__(self):
    """Initializes the model with device and fabric setup."""
    super(AbstractModel, self).__init__()
    nn.Module.__init__(self)
    self.device = torch.device("cpu")
    self.fabric = Fabric(accelerator="cpu", precision="bf16-mixed")
    self.fabric.seed_everything(1234)
decode(encoding) abstractmethod

Decodes the given encoding to the base output format.

Parameters:

Name Type Description Default
encoding EncodingInput

Encoding to be decoded.

required

Returns:

Name Type Description
BaseOutput BaseOutput

Decoded output.

Source code in vambn/modelling/models/templates.py
@abstractmethod
def decode(self, encoding: EncodingInput) -> BaseOutput:
    """
    Decodes the given encoding to the base output format.

    Args:
        encoding (EncodingInput): Encoding to be decoded.

    Returns:
        BaseOutput: Decoded output.
    """
    raise NotImplementedError
fit(train_dataloader, num_epochs, learning_rate, val_dataloader=None) abstractmethod

Fit the model to the training data.

Parameters:

Name Type Description Default
train_dataloader DataLoader

DataLoader for training data.

required
num_epochs int

Number of epochs to train.

required
learning_rate LearningRateInput

Learning rate for the optimizer.

required
val_dataloader Optional[DataLoader]

DataLoader for validation data. Defaults to None.

None

Returns:

Type Description
Tuple[float, int]

Tuple[float, int]: Best validation loss and number of epochs trained.

Source code in vambn/modelling/models/templates.py
@abstractmethod
def fit(
    self,
    train_dataloader: DataLoader,
    num_epochs: int,
    learning_rate: LearningRateInput,
    val_dataloader: Optional[DataLoader] = None,
) -> Tuple[float, int]:
    """
    Fit the model to the training data.

    Args:
        train_dataloader (DataLoader): DataLoader for training data.
        num_epochs (int): Number of epochs to train.
        learning_rate (LearningRateInput): Learning rate for the optimizer.
        val_dataloader (Optional[DataLoader], optional): DataLoader for validation data. Defaults to None.

    Returns:
        Tuple[float, int]: Best validation loss and number of epochs trained.
    """
    raise NotImplementedError
forward(data, mask) abstractmethod

Defines the computation performed at every call.

Parameters:

Name Type Description Default
data BatchInput

Input data for the forward pass.

required
mask BatchInput

Mask for the input data.

required

Returns:

Name Type Description
ForwardOutput ForwardOutput

Output of the forward pass.

Source code in vambn/modelling/models/templates.py
@abstractmethod
def forward(self, data: BatchInput, mask: BatchInput) -> ForwardOutput:
    """
    Defines the computation performed at every call.

    Args:
        data (BatchInput): Input data for the forward pass.
        mask (BatchInput): Mask for the input data.

    Returns:
        ForwardOutput: Output of the forward pass.
    """
    raise NotImplementedError
get_optimizer(learning_rate, num_epochs, beta1=0.9, beta2=0.999) abstractmethod

Get the optimizer for the Modular-HIVAE.

Parameters:

Name Type Description Default
learning_rate LearningRateInput

Learning rate for the optimizer.

required
num_epochs int

Number of epochs for training.

required
beta1 float

Beta1 hyperparameter for the Adam optimizer. Defaults to 0.9.

0.9
beta2 float

Beta2 hyperparameter for the Adam optimizer. Defaults to 0.999.

0.999

Returns:

Type Description
Tuple[OptimizerInput, SchedulerInput]

Tuple[OptimizerInput, SchedulerInput]: The optimizer and scheduler.

Source code in vambn/modelling/models/templates.py
@abstractmethod
def get_optimizer(
    self,
    learning_rate: LearningRateInput,
    num_epochs: int,
    beta1: float = 0.9,
    beta2: float = 0.999,
) -> Tuple[OptimizerInput, SchedulerInput]:
    """
    Get the optimizer for the Modular-HIVAE.

    Args:
        learning_rate (LearningRateInput): Learning rate for the optimizer.
        num_epochs (int): Number of epochs for training.
        beta1 (float, optional): Beta1 hyperparameter for the Adam optimizer. Defaults to 0.9.
        beta2 (float, optional): Beta2 hyperparameter for the Adam optimizer. Defaults to 0.999.

    Returns:
        Tuple[OptimizerInput, SchedulerInput]: The optimizer and scheduler.
    """
    raise NotImplementedError
predict(dataloader)

Perform prediction on a dataset.

Parameters:

Name Type Description Default
dataloader DataLoader

DataLoader for prediction data.

required

Returns:

Name Type Description
ForwardOutput ForwardOutput

Combined predictions for the entire dataset.

Raises:

Type Description
Exception

If no data is provided to the model.

Source code in vambn/modelling/models/templates.py
def predict(self, dataloader: DataLoader) -> ForwardOutput:
    """
    Perform prediction on a dataset.

    Args:
        dataloader (DataLoader): DataLoader for prediction data.

    Returns:
        ForwardOutput: Combined predictions for the entire dataset.

    Raises:
        Exception: If no data is provided to the model.
    """
    self.eval()
    outputs = None
    self = self.fabric.setup(self)
    dataloader = self.fabric.setup_dataloaders(dataloader)

    for data, mask in dataloader:
        tmp = self._predict_step(data=data, mask=mask)
        if outputs is None:
            outputs = tmp
        else:
            outputs += tmp

    if outputs is None:
        raise Exception("No data was provided to the model.")

    return outputs
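
Note that predict() merges per-batch results with +=, so the concrete ForwardOutput type is expected to support addition (for example by concatenating along the batch dimension). A toy stand-in that satisfies this contract (purely illustrative, not the library's output type):

from dataclasses import dataclass

import torch


@dataclass
class ToyOutput:
    samples: torch.Tensor

    def __add__(self, other: "ToyOutput") -> "ToyOutput":
        # combine batches by concatenating along the batch dimension
        return ToyOutput(torch.cat([self.samples, other.samples], dim=0))


acc = ToyOutput(torch.zeros(4, 2))
acc += ToyOutput(torch.ones(3, 2))  # "+=" falls back to __add__, yielding 7 rows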

AbstractModularModel

Bases: AbstractModel[NewBatchInput, NewBaseOutput, NewForwardOutput, EncodingInput, Tuple[Optimizer, ...], Tuple[_LRScheduler, ...], Tuple[float, ...]]

Abstract model class for modular models.

Source code in vambn/modelling/models/templates.py
class AbstractModularModel(
    AbstractModel[
        NewBatchInput,
        NewBaseOutput,
        NewForwardOutput,
        EncodingInput,
        Tuple[optim.Optimizer, ...],
        Tuple[optim.lr_scheduler._LRScheduler, ...],
        Tuple[float, ...],
    ],
):
    """Abstract model class for normal models."""

    def _training_epoch(
        self,
        dataloader: DataLoader,
        optimizers: Tuple[optim.Optimizer, ...],
    ) -> Tuple[float, List[float]]:
        """
        Perform a single training epoch.

        Args:
            dataloader (DataLoader): DataLoader for training data.
            optimizers (Tuple[optim.Optimizer, ...]): Optimizers used for training.

        Returns:
            Tuple[float, List[float]]: Average loss and list of losses for each batch.
        """
        self.train()
        losses = []
        for data, mask in dataloader:
            loss = self._training_step(
                data=data, mask=mask, optimizer=optimizers
            )
            losses.append(loss)
        return np.mean(losses), losses

    def get_optimizer(
        self,
        learning_rate: Tuple[float, ...],
        num_epochs: int,
        beta1: float = 0.9,
        beta2: float = 0.999,
    ) -> Tuple[Tuple[Optimizer, ...], Tuple[_LRScheduler, ...]]:
        """
        Get the optimizer for the model.

        Args:
            learning_rate (Tuple[float, ...]): Learning rates for the module optimizers; the last entry is used for the shared element.
            num_epochs (int): Number of epochs for training.
            beta1 (float, optional): Beta1 hyperparameter for the Adam optimizer. Defaults to 0.9.
            beta2 (float, optional): Beta2 hyperparameter for the Adam optimizer. Defaults to 0.999.

        Returns:
            Tuple[Tuple[optim.Optimizer, ...], Tuple[optim.lr_scheduler._LRScheduler, ...]]: The optimizers and schedulers.
        """

        def _module_optimizer(module, learning_rate):
            opt = optim.Adam(
                module.parameters(),
                lr=learning_rate,
                betas=(beta1, beta2),
                weight_decay=0.01,
            )
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer=opt, max_lr=learning_rate, total_steps=num_epochs
            )
            return opt, scheduler

        learning_rate, shared_learning_rate = (
            learning_rate[:-1],
            learning_rate[-1],
        )
        assert len(learning_rate) == len(self.module_models)
        out = tuple(
            _module_optimizer(module, lr)
            for module, lr in zip(self.module_models.values(), learning_rate)
        )
        optimizers = tuple(x[0] for x in out)
        schedulers = tuple(x[1] for x in out)

        # get the optimizer for the shared element
        shared_optimizer = (
            optim.Adam(
                self.shared_element.parameters(),
                lr=shared_learning_rate,
                betas=(beta1, beta2),
                weight_decay=0.01,
            )
            if self.shared_element.has_params
            else None
        )
        shared_scheduler = (
            torch.optim.lr_scheduler.OneCycleLR(
                optimizer=shared_optimizer,
                total_steps=num_epochs,
                max_lr=shared_learning_rate,
            )
            if shared_optimizer is not None
            else None
        )

        optimizers = (*optimizers, shared_optimizer)
        schedulers = (*schedulers, shared_scheduler)

        return optimizers, schedulers

    def fit(
        self,
        train_dataloader: DataLoader,
        num_epochs: int,
        learning_rate: Tuple[float, ...],
        val_dataloader: DataLoader | None = None,
    ) -> Tuple[float, int]:
        """
        Fit the model to the training data.

        Args:
            train_dataloader (DataLoader): DataLoader for training data.
            num_epochs (int): Number of epochs to train.
            learning_rate (Tuple[float, ...]): Learning rates for the optimizers; the last entry is used for the shared element.
            val_dataloader (Optional[DataLoader], optional): DataLoader for validation data. Defaults to None.

        Returns:
            Tuple[float, int]: Best validation loss and number of epochs trained.

        Raises:
            Exception: If number of epochs is less than 1.
        """
        if num_epochs <= 0:
            raise Exception("Number of epochs must be at least 1")

        optimizers, schedulers = self.get_optimizer(
            learning_rate=learning_rate,
            num_epochs=num_epochs,
        )

        if val_dataloader is not None:
            best_loss = float("inf")
            patience = 0
        else:
            best_loss = None
            patience = None

        for current_epoch in tqdm(range(num_epochs), total=num_epochs):
            mlflow.log_metric("epoch", current_epoch, step=current_epoch)
            avg_loss, losses = self._training_epoch(
                dataloader=train_dataloader, optimizers=optimizers
            )
            for scheduler in schedulers:
                if scheduler is None:
                    continue
                scheduler.step()

            if val_dataloader is not None and (
                current_epoch % 10 == 0 or current_epoch == num_epochs - 1
            ):
                val_loss = self._validation_epoch(val_dataloader)
                mlflow.log_metric("val_loss", val_loss, step=current_epoch)

                if val_loss < best_loss:
                    best_loss = val_loss
                    patience = 0
                else:
                    patience += 1

                if patience == 10:
                    logger.info("Early stopping")
                    break
        if val_dataloader is not None:
            val_loss = self._validation_epoch(val_dataloader)
            if val_loss < best_loss:
                best_loss = val_loss
                patience = 0

        if best_loss is not None:
            best_epoch = (
                current_epoch - patience
                if patience is not None
                else current_epoch
            )
            if not isinstance(best_loss, float):
                best_loss = float(best_loss)
            assert isinstance(best_epoch, int)
        else:
            best_epoch = None

        return (best_loss, best_epoch)
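
The learning_rate tuple is split inside get_optimizer: all entries except the last map onto module_models in order, and the final entry is reserved for the shared element. A short sketch of that convention (the model, the rates, and the epoch count are placeholders):

module_rates = (1e-3, 5e-4, 1e-3)      # one per entry in model.module_models
shared_rate = 1e-4                     # consumed by the shared element
optimizers, schedulers = model.get_optimizer(
    learning_rate=(*module_rates, shared_rate), num_epochs=200
)
assert len(optimizers) == len(module_rates) + 1  # module optimizers + shared optimizer
# The last optimizer/scheduler pair is None if shared_element.has_params is False.
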
fit(train_dataloader, num_epochs, learning_rate, val_dataloader=None)

Fit the model to the training data.

Parameters:

Name Type Description Default
train_dataloader DataLoader

DataLoader for training data.

required
num_epochs int

Number of epochs to train.

required
learning_rate Tuple[float, ...]

Learning rates for the optimizers; the last entry is used for the shared element.

required
val_dataloader Optional[DataLoader]

DataLoader for validation data. Defaults to None.

None

Returns:

Type Description
Tuple[float, int]

Tuple[float, int]: Best validation loss and number of epochs trained.

Raises:

Type Description
Exception

If number of epochs is less than 1.

Source code in vambn/modelling/models/templates.py
def fit(
    self,
    train_dataloader: DataLoader,
    num_epochs: int,
    learning_rate: Tuple[float, ...],
    val_dataloader: DataLoader | None = None,
) -> Tuple[float, int]:
    """
    Fit the model to the training data.

    Args:
        train_dataloader (DataLoader): DataLoader for training data.
        num_epochs (int): Number of epochs to train.
        learning_rate (Tuple[float, ...]): Learning rates for the optimizers; the last entry is used for the shared element.
        val_dataloader (Optional[DataLoader], optional): DataLoader for validation data. Defaults to None.

    Returns:
        Tuple[float, int]: Best validation loss and number of epochs trained.

    Raises:
        Exception: If number of epochs is less than 1.
    """
    if num_epochs <= 0:
        raise Exception("Number of epochs must be at least 1")

    optimizers, schedulers = self.get_optimizer(
        learning_rate=learning_rate,
        num_epochs=num_epochs,
    )

    if val_dataloader is not None:
        best_loss = float("inf")
        patience = 0
    else:
        best_loss = None
        patience = None

    for current_epoch in tqdm(range(num_epochs), total=num_epochs):
        mlflow.log_metric("epoch", current_epoch, step=current_epoch)
        avg_loss, losses = self._training_epoch(
            dataloader=train_dataloader, optimizers=optimizers
        )
        for scheduler in schedulers:
            if scheduler is None:
                continue
            scheduler.step()

        if val_dataloader is not None and (
            current_epoch % 10 == 0 or current_epoch == num_epochs - 1
        ):
            val_loss = self._validation_epoch(val_dataloader)
            mlflow.log_metric("val_loss", val_loss, step=current_epoch)

            if val_loss < best_loss:
                best_loss = val_loss
                patience = 0
            else:
                patience += 1

            if patience == 10:
                logger.info("Early stopping")
                break
    if val_dataloader is not None:
        val_loss = self._validation_epoch(val_dataloader)
        if val_loss < best_loss:
            best_loss = val_loss
            patience = 0

    if best_loss is not None:
        best_epoch = (
            current_epoch - patience
            if patience is not None
            else current_epoch
        )
        if not isinstance(best_loss, float):
            best_loss = float(best_loss)
        assert isinstance(best_epoch, int)
    else:
        best_epoch = None

    return (best_loss, best_epoch)
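
Validation (and hence the early-stopping check) runs every tenth epoch; after ten consecutive checks without improvement the loop stops, and the best loss together with an estimated best epoch is returned. A usage sketch with placeholder names:

best_loss, best_epoch = model.fit(
    train_dataloader=train_loader,
    num_epochs=500,
    learning_rate=(1e-3, 1e-3, 1e-4),  # per-module rates plus the shared-element rate
    val_dataloader=val_loader,         # omit to disable validation and early stopping
)
print(f"best validation loss {best_loss:.4f} around epoch {best_epoch}")
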
get_optimizer(learning_rate, num_epochs, beta1=0.9, beta2=0.999)

Get the optimizer for the model.

Parameters:

Name Type Description Default
learning_rate Tuple[float, ...]

Learning rates for the module optimizers; the last entry is used for the shared element.

required
num_epochs int

Number of epochs for training.

required
beta1 float

Beta1 hyperparameter for the Adam optimizer. Defaults to 0.9.

0.9
beta2 float

Beta2 hyperparameter for the Adam optimizer. Defaults to 0.999.

0.999

Returns:

Type Description
Tuple[Tuple[Optimizer, ...], Tuple[_LRScheduler, ...]]

Tuple[Tuple[optim.Optimizer, ...], Tuple[optim.lr_scheduler._LRScheduler, ...]]: The optimizers and schedulers.

Source code in vambn/modelling/models/templates.py
def get_optimizer(
    self,
    learning_rate: Tuple[float, ...],
    num_epochs: int,
    beta1: float = 0.9,
    beta2: float = 0.999,
) -> Tuple[Tuple[Optimizer, ...], Tuple[_LRScheduler, ...]]:
    """
    Get the optimizer for the model.

    Args:
        learning_rate (Tuple[float, ...]): Learning rates for the module optimizers; the last entry is used for the shared element.
        num_epochs (int): Number of epochs for training.
        beta1 (float, optional): Beta1 hyperparameter for the Adam optimizer. Defaults to 0.9.
        beta2 (float, optional): Beta2 hyperparameter for the Adam optimizer. Defaults to 0.999.

    Returns:
        Tuple[Tuple[optim.Optimizer, ...], Tuple[optim.lr_scheduler._LRScheduler, ...]]: The optimizers and schedulers.
    """

    def _module_optimizer(module, learning_rate):
        opt = optim.Adam(
            module.parameters(),
            lr=learning_rate,
            betas=(beta1, beta2),
            weight_decay=0.01,
        )
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=opt, max_lr=learning_rate, total_steps=num_epochs
        )
        return opt, scheduler

    learning_rate, shared_learning_rate = (
        learning_rate[:-1],
        learning_rate[-1],
    )
    assert len(learning_rate) == len(self.module_models)
    out = tuple(
        _module_optimizer(module, lr)
        for module, lr in zip(self.module_models.values(), learning_rate)
    )
    optimizers = tuple(x[0] for x in out)
    schedulers = tuple(x[1] for x in out)

    # get the optimizer for the shared element
    shared_optimizer = (
        optim.Adam(
            self.shared_element.parameters(),
            lr=shared_learning_rate,
            betas=(beta1, beta2),
            weight_decay=0.01,
        )
        if self.shared_element.has_params
        else None
    )
    shared_scheduler = (
        torch.optim.lr_scheduler.OneCycleLR(
            optimizer=shared_optimizer,
            total_steps=num_epochs,
            max_lr=shared_learning_rate,
        )
        if shared_optimizer is not None
        else None
    )

    optimizers = (*optimizers, shared_optimizer)
    schedulers = (*schedulers, shared_scheduler)

    return optimizers, schedulers

AbstractNormalModel

Bases: AbstractModel[NewBatchInput, NewBaseOutput, NewForwardOutput, EncodingInput, Optimizer, _LRScheduler, float]

Source code in vambn/modelling/models/templates.py
class AbstractNormalModel(
    AbstractModel[
        NewBatchInput,
        NewBaseOutput,
        NewForwardOutput,
        EncodingInput,
        optim.Optimizer,
        optim.lr_scheduler._LRScheduler,
        float,
    ],
):
    def _training_epoch(
        self, dataloader: DataLoader, optimizer: optim.Optimizer
    ) -> Tuple[float, List[float]]:
        """Training epoch for HIVAE"""
        self.train()
        loss = []
        for data, missing in dataloader:
            ploss = self._training_step(
                data=data, mask=missing, optimizer=optimizer
            )
            loss.append(ploss)

        return np.mean(loss), loss

    def get_optimizer(
        self,
        learning_rate: float,
        num_epochs: int,
        beta1: float = 0.9,
        beta2: float = 0.999,
    ) -> Tuple[optim.Optimizer, optim.lr_scheduler.OneCycleLR]:
        """Get the optimizer for the Modular-HIVAE"""
        optimizer = optim.Adam(
            self.parameters(),
            lr=learning_rate,
            betas=(beta1, beta2),
            weight_decay=0.01,
        )
        return optimizer, torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer, max_lr=learning_rate, total_steps=num_epochs
        )

    def fit(
        self,
        train_dataloader: DataLoader,
        num_epochs: int,
        learning_rate: float,
        val_dataloader: Optional[DataLoader] = None,
    ) -> Tuple[float, int]:
        """Fit the HIVAE model"""
        if num_epochs <= 0:
            raise Exception("Number of epochs must be at least 1")

        # determine number of trainable parameters and non-trainable parameters
        trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        non_trainable_params = sum(
            p.numel() for p in self.parameters() if not p.requires_grad
        )
        logger.info(f"Trainable parameters: {trainable_params}")
        logger.info(f"Non-trainable parameters: {non_trainable_params}")

        optimizer, scheduler = self.get_optimizer(
            learning_rate=learning_rate,
            num_epochs=num_epochs,
        )
        self, optimizer = self.fabric.setup(self, optimizer)
        self.fabric.clip_gradients(self, optimizer, max_norm=1.0)
        train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
        if val_dataloader is not None:
            val_dataloader = self.fabric.setup_dataloaders(val_dataloader)

        # use early stopping if val_dataloader is not None
        if val_dataloader is not None:
            best_loss = float("inf")
            patience = 0
        else:
            best_loss = None
            patience = None
        for current_epoch in tqdm(range(num_epochs), total=num_epochs):
            mlflow.log_metric("epoch", current_epoch, step=current_epoch)
            avg_loss, losses = self._training_epoch(
                dataloader=train_dataloader, optimizer=optimizer
            )
            mlflow.log_metric("train_loss", avg_loss, step=current_epoch)
            scheduler.step()

            if current_epoch % 10 == 0 or current_epoch == num_epochs - 1:
                logger.info(f"Loss at epoch {current_epoch}: {avg_loss}")
                print(f"Loss at epoch {current_epoch}: {avg_loss}")

            if val_dataloader is not None and current_epoch % 25 == 0:
                val_loss = self._validation_epoch(val_dataloader)
                mlflow.log_metric("val_loss", val_loss, step=current_epoch)

                if val_loss < best_loss:
                    best_loss = val_loss
                    patience = 0
                else:
                    patience += 1

                if patience == 10:
                    logger.info("Early stopping")
                    break

        if val_dataloader is not None:
            val_loss = self._validation_epoch(val_dataloader)
            if val_loss < best_loss:
                best_loss = val_loss
                patience = 0

        if best_loss is not None:
            best_epoch = (
                current_epoch - patience
                if patience is not None
                else current_epoch
            )
            if not isinstance(best_loss, float):
                best_loss = float(best_loss)
            assert isinstance(best_epoch, int)
        else:
            best_epoch = None

        return (
            best_loss,
            best_epoch,
        )
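
For the single-model variant a plain float learning rate is enough: fit builds one Adam optimizer with a OneCycleLR schedule over num_epochs steps, wires model and dataloaders through Fabric, and validates every 25th epoch when a validation loader is given. Usage sketch with placeholder names:

best_loss, best_epoch = hivae.fit(
    train_dataloader=train_loader,
    num_epochs=300,
    learning_rate=1e-3,
    val_dataloader=val_loader,
)
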
fit(train_dataloader, num_epochs, learning_rate, val_dataloader=None)

Fit the HIVAE model

Source code in vambn/modelling/models/templates.py
def fit(
    self,
    train_dataloader: DataLoader,
    num_epochs: int,
    learning_rate: float,
    val_dataloader: Optional[DataLoader] = None,
) -> Tuple[float, int]:
    """Fit the HIVAE model"""
    if num_epochs <= 0:
        raise Exception("Number of epochs must be at least 1")

    # determine number of trainable parameters and non-trainable parameters
    trainable_params = sum(
        p.numel() for p in self.parameters() if p.requires_grad
    )
    non_trainable_params = sum(
        p.numel() for p in self.parameters() if not p.requires_grad
    )
    logger.info(f"Trainable parameters: {trainable_params}")
    logger.info(f"Non-trainable parameters: {non_trainable_params}")

    optimizer, scheduler = self.get_optimizer(
        learning_rate=learning_rate,
        num_epochs=num_epochs,
    )
    self, optimizer = self.fabric.setup(self, optimizer)
    self.fabric.clip_gradients(self, optimizer, max_norm=1.0)
    train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
    if val_dataloader is not None:
        val_dataloader = self.fabric.setup_dataloaders(val_dataloader)

    # use early stopping if val_dataloader is not None
    if val_dataloader is not None:
        best_loss = float("inf")
        patience = 0
    else:
        best_loss = None
        patience = None
    for current_epoch in tqdm(range(num_epochs), total=num_epochs):
        mlflow.log_metric("epoch", current_epoch, step=current_epoch)
        avg_loss, losses = self._training_epoch(
            dataloader=train_dataloader, optimizer=optimizer
        )
        mlflow.log_metric("train_loss", avg_loss, step=current_epoch)
        scheduler.step()

        if current_epoch % 10 == 0 or current_epoch == num_epochs - 1:
            logger.info(f"Loss at epoch {current_epoch}: {avg_loss}")
            print(f"Loss at epoch {current_epoch}: {avg_loss}")

        if val_dataloader is not None and current_epoch % 25 == 0:
            val_loss = self._validation_epoch(val_dataloader)
            mlflow.log_metric("val_loss", val_loss, step=current_epoch)

            if val_loss < best_loss:
                best_loss = val_loss
                patience = 0
            else:
                patience += 1

            if patience == 10:
                logger.info("Early stopping")
                break

    if val_dataloader is not None:
        val_loss = self._validation_epoch(val_dataloader)
        if val_loss < best_loss:
            best_loss = val_loss
            patience = 0

    if best_loss is not None:
        best_epoch = (
            current_epoch - patience
            if patience is not None
            else current_epoch
        )
        if not isinstance(best_loss, float):
            best_loss = float(best_loss)
        assert isinstance(best_epoch, int)
    else:
        best_epoch = None

    return (
        best_loss,
        best_epoch,
    )
get_optimizer(learning_rate, num_epochs, beta1=0.9, beta2=0.999)

Get the optimizer for the Modular-HIVAE

Source code in vambn/modelling/models/templates.py
def get_optimizer(
    self,
    learning_rate: float,
    num_epochs: int,
    beta1: float = 0.9,
    beta2: float = 0.999,
) -> Tuple[optim.Optimizer, optim.lr_scheduler.OneCycleLR]:
    """Get the optimizer for the Modular-HIVAE"""
    optimizer = optim.Adam(
        self.parameters(),
        lr=learning_rate,
        betas=(beta1, beta2),
        weight_decay=0.01,
    )
    return optimizer, torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer, max_lr=learning_rate, total_steps=num_epochs
    )

mtl

minnormsolver

This script includes code adapted from the 'impartial-vaes' repository with minor modifications. The original code can be found at: https://github.com/adrianjav/impartial-vaes

Credit to the original authors: Adrian Javaloy, Maryam Meghdadi, and Isabel Valera for their valuable work.

MinNormLinearSolver

Bases: Module

Solves the min norm problem in case of 2 vectors (lies on a line).

Source code in vambn/modelling/mtl/minnormsolver.py
class MinNormLinearSolver(nn.Module):
    """Solves the min norm problem in case of 2 vectors (lies on a line)."""

    def __init__(self):
        super().__init__()

    @torch.no_grad()
    def forward(self, v1v1, v1v2, v2v2):
        """
        Solver execution on scalar products of 2 vectors.

        Args:
            v1v1 (float): Scalar product <V1, V1>.
            v1v2 (float): Scalar product <V1, V2>.
            v2v2 (float): Scalar product <V2, V2>.

        Returns:
            tuple: A tuple containing:
                - gamma (float): Min-norm solution c = (gamma, 1. - gamma).
                - cost (float): The norm of min-norm point.
        """
        if v1v2 >= v1v1:
            return 1.0, v1v1
        if v1v2 >= v2v2:
            return 0.0, v2v2
        gamma = (v2v2 - v1v2) / (v1v1 + v2v2 - 2 * v1v2 + 1e-8)
        cost = v2v2 + gamma * (v1v2 - v2v2)
        return gamma, cost
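
For two vectors the solution is closed form: gamma = (v2v2 - v1v2) / (v1v1 + v2v2 - 2 * v1v2), clipped to the endpoints. A small numeric check (illustrative values; the import follows the source path shown above):

import torch

from vambn.modelling.mtl.minnormsolver import MinNormLinearSolver

solver = MinNormLinearSolver()
v1 = torch.tensor([1.0, 0.0])
v2 = torch.tensor([0.0, 2.0])
gamma, cost = solver(v1.dot(v1), v1.dot(v2), v2.dot(v2))
# gamma = (4 - 0) / (1 + 4 - 0) = 0.8, so the min-norm point is
# 0.8 * v1 + 0.2 * v2 = (0.8, 0.4) and cost = 0.8 (its squared norm).
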
forward(v1v1, v1v2, v2v2)

Solver execution on scalar products of 2 vectors.

Parameters:

Name Type Description Default
v1v1 float

Scalar product <V1, V1>.

required
v1v2 float

Scalar product <V1, V2>.

required
v2v2 float

Scalar product <V2, V2>.

required

Returns:

Name Type Description
tuple

A tuple containing: - gamma (float): Min-norm solution c = (gamma, 1. - gamma). - cost (float): The norm of min-norm point.

Source code in vambn/modelling/mtl/minnormsolver.py
@torch.no_grad()
def forward(self, v1v1, v1v2, v2v2):
    """
    Solver execution on scalar products of 2 vectors.

    Args:
        v1v1 (float): Scalar product <V1, V1>.
        v1v2 (float): Scalar product <V1, V2>.
        v2v2 (float): Scalar product <V2, V2>.

    Returns:
        tuple: A tuple containing:
            - gamma (float): Min-norm solution c = (gamma, 1. - gamma).
            - cost (float): The norm of min-norm point.
    """
    if v1v2 >= v1v1:
        return 1.0, v1v1
    if v1v2 >= v2v2:
        return 0.0, v2v2
    gamma = (v2v2 - v1v2) / (v1v1 + v2v2 - 2 * v1v2 + 1e-8)
    cost = v2v2 + gamma * (v1v2 - v2v2)
    return gamma, cost

MinNormPlanarSolver

Bases: Module

Solves the min norm problem in case the vectors lie on the same plane.

Source code in vambn/modelling/mtl/minnormsolver.py
class MinNormPlanarSolver(nn.Module):
    """Solves the min norm problem in case the vectors lie on the same plane."""

    def __init__(self, n_tasks):
        """
        Initializes the MinNormPlanarSolver.

        Args:
            n_tasks (int): Number of tasks/vectors.
        """
        super().__init__()
        i_grid = torch.arange(n_tasks)
        j_grid = torch.arange(n_tasks)
        ii_grid, jj_grid = torch.meshgrid(i_grid, j_grid)
        i_triu, j_triu = np.triu_indices(n_tasks, 1)

        self.register_buffer("n", torch.tensor(n_tasks))
        self.register_buffer("i_triu", torch.from_numpy(i_triu))
        self.register_buffer("j_triu", torch.from_numpy(j_triu))
        self.register_buffer("ii_triu", ii_grid[i_triu, j_triu])
        self.register_buffer("jj_triu", jj_grid[i_triu, j_triu])
        self.register_buffer("one", torch.ones(self.ii_triu.shape))
        self.register_buffer("zero", torch.zeros(self.ii_triu.shape))

    @torch.no_grad()
    def line_solver_vectorized(self, v1v1, v1v2, v2v2):
        """
        Linear case solver, but for collection of vector pairs (Vi, Vj).

        Args:
            v1v1 (Tensor): Vector of scalar products <Vi, Vi>.
            v1v2 (Tensor): Vector of scalar products <Vi, Vj>.
            v2v2 (Tensor): Vector of scalar products <Vj, Vj>.

        Returns:
            tuple: A tuple containing:
                - gamma (Tensor): Vector of min-norm solution c = (gamma, 1. - gamma).
                - cost (Tensor): Vector of the norm of min-norm point.
        """
        gamma = (v2v2 - v1v2) / (v1v1 + v2v2 - 2 * v1v2 + 1e-8)
        gamma = gamma.where(v1v2 < v2v2, self.zero)
        gamma = gamma.where(v1v2 < v1v1, self.one)

        cost = v2v2 + gamma * (v1v2 - v2v2)
        cost = cost.where(v1v2 < v2v2, v2v2)
        cost = cost.where(v1v2 < v1v1, v1v1)
        return gamma, cost

    @torch.no_grad()
    def forward(self, grammian):
        """
        Planar case solver, when Vi lies on the same plane.

        Args:
            grammian (Tensor): Grammian matrix G[i, j] = [<Vi, Vj>], G is a nxn tensor.

        Returns:
            Tensor: Coefficients c = [c1, ... cn] that solves the min-norm problem.
        """
        vivj = grammian[self.ii_triu, self.jj_triu]
        vivi = grammian[self.ii_triu, self.ii_triu]
        vjvj = grammian[self.jj_triu, self.jj_triu]

        gamma, cost = self.line_solver_vectorized(vivi, vivj, vjvj)
        offset = torch.argmin(cost)
        i_min, j_min = self.i_triu[offset], self.j_triu[offset]
        sol = torch.zeros(self.n, device=grammian.device)
        sol[i_min], sol[j_min] = gamma[offset], 1.0 - gamma[offset]
        return sol
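
The planar solver operates on the Gram matrix of the task vectors and returns a coefficient vector with at most two non-zero entries (the best pair found by the vectorized line solver). Illustrative sketch, assuming the same import path as above:

import torch

from vambn.modelling.mtl.minnormsolver import MinNormPlanarSolver

vecs = torch.tensor([[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])
grammian = vecs @ vecs.t()            # G[i, j] = <Vi, Vj>
planar = MinNormPlanarSolver(n_tasks=vecs.shape[0])
coeffs = planar(grammian)             # non-negative, sums to 1
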
__init__(n_tasks)

Initializes the MinNormPlanarSolver.

Parameters:

Name Type Description Default
n_tasks int

Number of tasks/vectors.

required
Source code in vambn/modelling/mtl/minnormsolver.py
def __init__(self, n_tasks):
    """
    Initializes the MinNormPlanarSolver.

    Args:
        n_tasks (int): Number of tasks/vectors.
    """
    super().__init__()
    i_grid = torch.arange(n_tasks)
    j_grid = torch.arange(n_tasks)
    ii_grid, jj_grid = torch.meshgrid(i_grid, j_grid)
    i_triu, j_triu = np.triu_indices(n_tasks, 1)

    self.register_buffer("n", torch.tensor(n_tasks))
    self.register_buffer("i_triu", torch.from_numpy(i_triu))
    self.register_buffer("j_triu", torch.from_numpy(j_triu))
    self.register_buffer("ii_triu", ii_grid[i_triu, j_triu])
    self.register_buffer("jj_triu", jj_grid[i_triu, j_triu])
    self.register_buffer("one", torch.ones(self.ii_triu.shape))
    self.register_buffer("zero", torch.zeros(self.ii_triu.shape))
forward(grammian)

Planar case solver, when Vi lies on the same plane.

Parameters:

Name Type Description Default
grammian Tensor

Grammian matrix G[i, j] = [<Vi, Vj>], G is a nxn tensor.

required

Returns:

Name Type Description
Tensor

Coefficients c = [c1, ... cn] that solves the min-norm problem.

Source code in vambn/modelling/mtl/minnormsolver.py
@torch.no_grad()
def forward(self, grammian):
    """
    Planar case solver, when Vi lies on the same plane.

    Args:
        grammian (Tensor): Grammian matrix G[i, j] = [<Vi, Vj>], G is a nxn tensor.

    Returns:
        Tensor: Coefficients c = [c1, ... cn] that solves the min-norm problem.
    """
    vivj = grammian[self.ii_triu, self.jj_triu]
    vivi = grammian[self.ii_triu, self.ii_triu]
    vjvj = grammian[self.jj_triu, self.jj_triu]

    gamma, cost = self.line_solver_vectorized(vivi, vivj, vjvj)
    offset = torch.argmin(cost)
    i_min, j_min = self.i_triu[offset], self.j_triu[offset]
    sol = torch.zeros(self.n, device=grammian.device)
    sol[i_min], sol[j_min] = gamma[offset], 1.0 - gamma[offset]
    return sol
line_solver_vectorized(v1v1, v1v2, v2v2)

Linear case solver, but for collection of vector pairs (Vi, Vj).

Parameters:

Name Type Description Default
v1v1 Tensor

Vector of scalar products <Vi, Vi>.

required
v1v2 Tensor

Vector of scalar products <Vi, Vj>.

required
v2v2 Tensor

Vector of scalar products <Vj, Vj>.

required

Returns:

Name Type Description
tuple

A tuple containing: - gamma (Tensor): Vector of min-norm solution c = (gamma, 1. - gamma). - cost (Tensor): Vector of the norm of min-norm point.

Source code in vambn/modelling/mtl/minnormsolver.py
@torch.no_grad()
def line_solver_vectorized(self, v1v1, v1v2, v2v2):
    """
    Linear case solver, but for collection of vector pairs (Vi, Vj).

    Args:
        v1v1 (Tensor): Vector of scalar products <Vi, Vi>.
        v1v2 (Tensor): Vector of scalar products <Vi, Vj>.
        v2v2 (Tensor): Vector of scalar products <Vj, Vj>.

    Returns:
        tuple: A tuple containing:
            - gamma (Tensor): Vector of min-norm solution c = (gamma, 1. - gamma).
            - cost (Tensor): Vector of the norm of min-norm point.
    """
    gamma = (v2v2 - v1v2) / (v1v1 + v2v2 - 2 * v1v2 + 1e-8)
    gamma = gamma.where(v1v2 < v2v2, self.zero)
    gamma = gamma.where(v1v2 < v1v1, self.one)

    cost = v2v2 + gamma * (v1v2 - v2v2)
    cost = cost.where(v1v2 < v2v2, v2v2)
    cost = cost.where(v1v2 < v1v1, v1v1)
    return gamma, cost

MinNormSolver

Bases: Module

Solves the min norm problem in the general case.

Source code in vambn/modelling/mtl/minnormsolver.py
class MinNormSolver(nn.Module):
    """Solves the min norm problem in the general case."""

    def __init__(self, n_tasks, max_iter=250, stop_crit=1e-6):
        """
        Initializes the MinNormSolver.

        Args:
            n_tasks (int): Number of tasks/vectors.
            max_iter (int, optional): Maximum number of iterations. Defaults to 250.
            stop_crit (float, optional): Stopping criterion. Defaults to 1e-6.
        """
        super().__init__()
        self.n = n_tasks
        self.linear_solver = MinNormLinearSolver()
        self.planar_solver = MinNormPlanarSolver(n_tasks)

        n_grid = torch.arange(n_tasks)
        i_grid = torch.arange(n_tasks, dtype=torch.float32) + 1
        ii_grid, jj_grid = torch.meshgrid(n_grid, n_grid)

        self.register_buffer("n_ts", torch.tensor(n_tasks))
        self.register_buffer("i_grid", i_grid)
        self.register_buffer("ii_grid", ii_grid)
        self.register_buffer("jj_grid", jj_grid)
        self.register_buffer("zero", torch.zeros(n_tasks))
        self.register_buffer("stop_crit", torch.tensor(stop_crit))

        self.max_iter = max_iter
        self.two_sol = nn.Parameter(torch.zeros(2))
        self.two_sol.requires_grad = False

    @torch.no_grad()
    def projection_to_simplex(self, gamma):
        """
        Projects gamma to the simplex.

        Args:
            gamma (Tensor): The input tensor to project.

        Returns:
            Tensor: The projected tensor.
        """
        sorted_gamma, indices = torch.sort(gamma, descending=True)
        tmp_sum = torch.cumsum(sorted_gamma, 0)
        tmp_max = (tmp_sum - 1.0) / self.i_grid

        non_zeros = torch.nonzero(tmp_max[:-1] > sorted_gamma[1:])
        if non_zeros.shape[0] > 0:
            tmax_f = tmp_max[:-1][non_zeros[0][0]]
        else:
            tmax_f = tmp_max[-1]
        return torch.max(gamma - tmax_f, self.zero)

    @torch.no_grad()
    def next_point(self, cur_val, grad):
        """
        Computes the next point in the optimization.

        Args:
            cur_val (Tensor): Current value.
            grad (Tensor): Gradient.

        Returns:
            Tensor: The next point.
        """
        proj_grad = grad - (torch.sum(grad) / self.n_ts)
        lt_zero = torch.nonzero(proj_grad < 0)
        lt_zero = lt_zero.view(lt_zero.numel())
        gt_zero = torch.nonzero(proj_grad > 0)
        gt_zero = gt_zero.view(gt_zero.numel())
        tm1 = -cur_val[lt_zero] / proj_grad[lt_zero]
        tm2 = (1.0 - cur_val[gt_zero]) / proj_grad[gt_zero]

        t = torch.tensor(1.0, device=grad.device)
        tm1_gt_zero = torch.nonzero(tm1 > 1e-7)
        tm1_gt_zero = tm1_gt_zero.view(tm1_gt_zero.numel())
        if tm1_gt_zero.shape[0] > 0:
            t = torch.min(tm1[tm1_gt_zero])

        tm2_gt_zero = torch.nonzero(tm2 > 1e-7)
        tm2_gt_zero = tm2_gt_zero.view(tm2_gt_zero.numel())
        if tm2_gt_zero.shape[0] > 0:
            t = torch.min(t, torch.min(tm2[tm2_gt_zero]))

        next_point = proj_grad * t + cur_val
        next_point = self.projection_to_simplex(next_point)
        return next_point

    @torch.no_grad()
    def forward(self, vecs):
        """
        General case solver using simplex projection algorithm.

        Args:
            vecs (Tensor): 2D tensor V, where each row is a vector Vi.

        Returns:
            Tensor: Coefficients c = [c1, ... cn] that solves the min-norm problem.
        """
        if self.n == 1:
            return vecs[0]
        if self.n == 2:
            v1v1 = torch.dot(vecs[0], vecs[0])
            v1v2 = torch.dot(vecs[0], vecs[1])
            v2v2 = torch.dot(vecs[1], vecs[1])
            self.two_sol[0], cost = self.linear_solver(v1v1, v1v2, v2v2)
            self.two_sol[1] = 1.0 - self.two_sol[0]
            return self.two_sol.clone()

        grammian = torch.mm(vecs, vecs.t())
        sol_vec = self.planar_solver(grammian)

        ii, jj = self.ii_grid, self.jj_grid
        for iter_count in range(self.max_iter):
            grad_dir = -torch.mv(grammian, sol_vec)
            new_point = self.next_point(sol_vec, grad_dir)

            v1v1 = (sol_vec[ii] * sol_vec[jj] * grammian[ii, jj]).sum()
            v1v2 = (sol_vec[ii] * new_point[jj] * grammian[ii, jj]).sum()
            v2v2 = (new_point[ii] * new_point[jj] * grammian[ii, jj]).sum()

            gamma, cost = self.linear_solver(v1v1, v1v2, v2v2)
            new_sol_vec = gamma * sol_vec + (1 - gamma) * new_point
            change = new_sol_vec - sol_vec
            if torch.sum(torch.abs(change)) < self.stop_crit:
                return sol_vec
            sol_vec = new_sol_vec
        return sol_vec
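
In the general case the solver starts from the planar solution and refines it with projected-gradient steps on the simplex. In a multi-task (MGDA-style) setting one would stack the flattened per-task gradients and use the returned coefficients to build a single shared update direction. An illustrative sketch, again assuming the import path shown above:

import torch

from vambn.modelling.mtl.minnormsolver import MinNormSolver

grads = torch.stack([                 # one flattened gradient vector per task
    torch.tensor([1.0, 0.0, 0.0]),
    torch.tensor([0.0, 1.0, 0.0]),
    torch.tensor([0.5, 0.5, 1.0]),
])
solver = MinNormSolver(n_tasks=grads.shape[0])
weights = solver(grads)               # coefficients on the simplex
combined = (weights.unsqueeze(1) * grads).sum(dim=0)  # min-norm combination of the gradients
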
__init__(n_tasks, max_iter=250, stop_crit=1e-06)

Initializes the MinNormSolver.

Parameters:

Name Type Description Default
n_tasks int

Number of tasks/vectors.

required
max_iter int

Maximum number of iterations. Defaults to 250.

250
stop_crit float

Stopping criterion. Defaults to 1e-6.

1e-06
Source code in vambn/modelling/mtl/minnormsolver.py
def __init__(self, n_tasks, max_iter=250, stop_crit=1e-6):
    """
    Initializes the MinNormSolver.

    Args:
        n_tasks (int): Number of tasks/vectors.
        max_iter (int, optional): Maximum number of iterations. Defaults to 250.
        stop_crit (float, optional): Stopping criterion. Defaults to 1e-6.
    """
    super().__init__()
    self.n = n_tasks
    self.linear_solver = MinNormLinearSolver()
    self.planar_solver = MinNormPlanarSolver(n_tasks)

    n_grid = torch.arange(n_tasks)
    i_grid = torch.arange(n_tasks, dtype=torch.float32) + 1
    ii_grid, jj_grid = torch.meshgrid(n_grid, n_grid)

    self.register_buffer("n_ts", torch.tensor(n_tasks))
    self.register_buffer("i_grid", i_grid)
    self.register_buffer("ii_grid", ii_grid)
    self.register_buffer("jj_grid", jj_grid)
    self.register_buffer("zero", torch.zeros(n_tasks))
    self.register_buffer("stop_crit", torch.tensor(stop_crit))

    self.max_iter = max_iter
    self.two_sol = nn.Parameter(torch.zeros(2))
    self.two_sol.requires_grad = False
forward(vecs)

General case solver using simplex projection algorithm.

Parameters:

Name Type Description Default
vecs Tensor

2D tensor V, where each row is a vector Vi.

required

Returns:

Name Type Description
Tensor

Coefficients c = [c1, ... cn] that solves the min-norm problem.

Source code in vambn/modelling/mtl/minnormsolver.py
@torch.no_grad()
def forward(self, vecs):
    """
    General case solver using simplex projection algorithm.

    Args:
        vecs (Tensor): 2D tensor V, where each row is a vector Vi.

    Returns:
        Tensor: Coefficients c = [c1, ... cn] that solve the min-norm problem.
    """
    if self.n == 1:
        return vecs[0]
    if self.n == 2:
        v1v1 = torch.dot(vecs[0], vecs[0])
        v1v2 = torch.dot(vecs[0], vecs[1])
        v2v2 = torch.dot(vecs[1], vecs[1])
        self.two_sol[0], cost = self.linear_solver(v1v1, v1v2, v2v2)
        self.two_sol[1] = 1.0 - self.two_sol[0]
        return self.two_sol.clone()

    grammian = torch.mm(vecs, vecs.t())
    sol_vec = self.planar_solver(grammian)

    ii, jj = self.ii_grid, self.jj_grid
    for iter_count in range(self.max_iter):
        grad_dir = -torch.mv(grammian, sol_vec)
        new_point = self.next_point(sol_vec, grad_dir)

        v1v1 = (sol_vec[ii] * sol_vec[jj] * grammian[ii, jj]).sum()
        v1v2 = (sol_vec[ii] * new_point[jj] * grammian[ii, jj]).sum()
        v2v2 = (new_point[ii] * new_point[jj] * grammian[ii, jj]).sum()

        gamma, cost = self.linear_solver(v1v1, v1v2, v2v2)
        new_sol_vec = gamma * sol_vec + (1 - gamma) * new_point
        change = new_sol_vec - sol_vec
        if torch.sum(torch.abs(change)) < self.stop_crit:
            return sol_vec
        sol_vec = new_sol_vec
    return sol_vec
next_point(cur_val, grad)

Computes the next point in the optimization.

Parameters:

    cur_val (Tensor): Current value. Required.
    grad (Tensor): Gradient. Required.

Returns:

    Tensor: The next point.

Source code in vambn/modelling/mtl/minnormsolver.py
@torch.no_grad()
def next_point(self, cur_val, grad):
    """
    Computes the next point in the optimization.

    Args:
        cur_val (Tensor): Current value.
        grad (Tensor): Gradient.

    Returns:
        Tensor: The next point.
    """
    proj_grad = grad - (torch.sum(grad) / self.n_ts)
    lt_zero = torch.nonzero(proj_grad < 0)
    lt_zero = lt_zero.view(lt_zero.numel())
    gt_zero = torch.nonzero(proj_grad > 0)
    gt_zero = gt_zero.view(gt_zero.numel())
    tm1 = -cur_val[lt_zero] / proj_grad[lt_zero]
    tm2 = (1.0 - cur_val[gt_zero]) / proj_grad[gt_zero]

    t = torch.tensor(1.0, device=grad.device)
    tm1_gt_zero = torch.nonzero(tm1 > 1e-7)
    tm1_gt_zero = tm1_gt_zero.view(tm1_gt_zero.numel())
    if tm1_gt_zero.shape[0] > 0:
        t = torch.min(tm1[tm1_gt_zero])

    tm2_gt_zero = torch.nonzero(tm2 > 1e-7)
    tm2_gt_zero = tm2_gt_zero.view(tm2_gt_zero.numel())
    if tm2_gt_zero.shape[0] > 0:
        t = torch.min(t, torch.min(tm2[tm2_gt_zero]))

    next_point = proj_grad * t + cur_val
    next_point = self.projection_to_simplex(next_point)
    return next_point
projection_to_simplex(gamma)

Projects gamma to the simplex.

Parameters:

    gamma (Tensor): The input tensor to project. Required.

Returns:

    Tensor: The projected tensor.

Source code in vambn/modelling/mtl/minnormsolver.py
@torch.no_grad()
def projection_to_simplex(self, gamma):
    """
    Projects gamma to the simplex.

    Args:
        gamma (Tensor): The input tensor to project.

    Returns:
        Tensor: The projected tensor.
    """
    sorted_gamma, indices = torch.sort(gamma, descending=True)
    tmp_sum = torch.cumsum(sorted_gamma, 0)
    tmp_max = (tmp_sum - 1.0) / self.i_grid

    non_zeros = torch.nonzero(tmp_max[:-1] > sorted_gamma[1:])
    if non_zeros.shape[0] > 0:
        tmax_f = tmp_max[:-1][non_zeros[0][0]]
    else:
        tmax_f = tmp_max[-1]
    return torch.max(gamma - tmax_f, self.zero)
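For intuition, projection_to_simplex is a Euclidean projection onto the probability simplex: the result is non-negative and sums to one. A small illustrative check, assuming the solver class above (values are arbitrary):

import torch

from vambn.modelling.mtl.minnormsolver import MinNormSolver

solver = MinNormSolver(n_tasks=4)
gamma = torch.tensor([0.7, 0.5, -0.1, 0.2])
p = solver.projection_to_simplex(gamma)     # approx. tensor([0.5667, 0.3667, 0.0000, 0.0667])
assert torch.all(p >= 0)
assert torch.isclose(p.sum(), torch.tensor(1.0))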

moo

This script includes code adapted from the 'impartial-vaes' repository with minor modifications. The original code can be found at: https://github.com/adrianjav/impartial-vaes

Credit to the original authors: Adrian Javaloy, Maryam Meghdadi, and Isabel Valera for their valuable work.

MOOForLoop

Bases: Module

A PyTorch Module for Multiple Objective Optimization (MOO) within a loop.

Source code in vambn/modelling/mtl/moo.py
class MOOForLoop(nn.Module):
    """A PyTorch Module for Multiple Objective Optimization (MOO) within a loop."""

    inputs: Optional[torch.Tensor]

    def __init__(self, num_heads: int, moo_method: Optional[nn.Module] = None):
        """
        Initialize the MOOForLoop module.

        Args:
            num_heads (int): Number of heads for extending the input.
            moo_method (nn.Module, optional): The MOO method to be used. Default is None.
        """
        super().__init__()

        self._moo_method = [moo_method]
        self.num_heads = num_heads
        self.inputs = None
        self.outputs = None

        if self.moo_method is not None:
            self.register_full_backward_hook(MOOForLoop._hook)

    @property
    def moo_method(self):
        """Get the MOO method."""
        return self._moo_method[0]

    def _hook(
        self, grads_input: Tuple[torch.Tensor], grads_output: Any
    ) -> Tuple[torch.Tensor]:
        """
        Hook function to replace gradients with MOO directions.

        Args:
            grads_input (Tuple[torch.Tensor]): Gradients of the module's inputs.
            grads_output (Any): Gradients of the module's outputs.

        Returns:
            Tuple[torch.Tensor]: Modified gradients.
        """
        moo_directions = self.moo_method(
            grads_output[0], self.inputs, self.outputs
        )
        self.outputs = None

        original_norm = grads_output[0].sum(dim=0).norm(p=2)
        moo_norm = moo_directions.sum(dim=0).norm(p=2).clamp_min(1e-10)
        moo_directions.mul_(original_norm / moo_norm)

        return (moo_directions.sum(dim=0),)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass. Extend the input to the number of heads and store it.

        Args:
            z (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Extended input tensor.
        """
        extended_shape = [self.num_heads] + [-1 for _ in range(z.ndim)]
        if self.moo_method.requires_input and z.requires_grad:
            self.inputs = z.detach()
        extended_z = z.unsqueeze(0).expand(extended_shape)
        return extended_z

    def __str__(self) -> str:
        return f"MOOForLoop({self.moo_method})"
moo_method property

Get the MOO method.

__init__(num_heads, moo_method=None)

Initialize the MOOForLoop module.

Parameters:

    num_heads (int): Number of heads for extending the input. Required.
    moo_method (Module): The MOO method to be used. Default: None.
Source code in vambn/modelling/mtl/moo.py
def __init__(self, num_heads: int, moo_method: Optional[nn.Module] = None):
    """
    Initialize the MOOForLoop module.

    Args:
        num_heads (int): Number of heads for extending the input.
        moo_method (nn.Module, optional): The MOO method to be used. Default is None.
    """
    super().__init__()

    self._moo_method = [moo_method]
    self.num_heads = num_heads
    self.inputs = None
    self.outputs = None

    if self.moo_method is not None:
        self.register_full_backward_hook(MOOForLoop._hook)
forward(z)

Forward pass. Extend the input to the number of heads and store it.

Parameters:

    z (Tensor): Input tensor. Required.

Returns:

    Tensor: Extended input tensor.

Source code in vambn/modelling/mtl/moo.py
def forward(self, z: torch.Tensor) -> torch.Tensor:
    """
    Forward pass. Extend the input to the number of heads and store it.

    Args:
        z (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: Extended input tensor.
    """
    extended_shape = [self.num_heads] + [-1 for _ in range(z.ndim)]
    if self.moo_method.requires_input and z.requires_grad:
        self.inputs = z.detach()
    extended_z = z.unsqueeze(0).expand(extended_shape)
    return extended_z

MooMulti

Bases: Module

A PyTorch Module for Multiple Objective Optimization (MOO) within a loop.

Source code in vambn/modelling/mtl/moo.py
class MooMulti(nn.Module):
    """A PyTorch Module for Multiple Objective Optimization (MOO) within a loop."""

    inputs: Optional[torch.Tensor]

    def __init__(
        self, num_modules: int, moo_method: Optional[nn.Module] = None
    ):
        """
        Initialize the MooMulti module.

        Args:
            num_modules (int): Number of heads for extending the input.
            moo_method (nn.Module, optional): The MOO method to be used. Default is None.
        """
        super().__init__()

        self._moo_method = [moo_method]
        self.num_heads = num_modules
        self.inputs = None
        self.outputs = None

        if self.moo_method is not None:
            self.register_full_backward_hook(MooMulti._hook)

    @property
    def moo_method(self):
        """Get the MOO method."""
        return self._moo_method[0]

    def _hook(
        self, grads_input: Tuple[torch.Tensor], grads_output: Any
    ) -> Tuple[torch.Tensor]:
        """
        Hook function to replace gradients with MOO directions.

        Args:
            grads_input (Tuple[torch.Tensor]): Gradients of the module's inputs.
            grads_output (Any): Gradients of the module's outputs.

        Returns:
            Tuple[torch.Tensor]: Modified gradients.
        """
        moo_directions = self.moo_method(
            grads_output[0], self.inputs, self.outputs
        )
        self.outputs = None

        if grads_output[0].shape != moo_directions.shape:
            raise ValueError(
                f"MOO directions shape {moo_directions.shape} does not match grads_output shape {grads_output[0].shape}"
            )

        original_norm = grads_output[0].norm(p=2)
        moo_norm = moo_directions.norm(p=2).clamp_min(1e-10)
        scaling_factor = original_norm / moo_norm
        scaled_moo_directions = moo_directions * scaling_factor

        if grads_input[0].shape != scaled_moo_directions.shape:
            raise ValueError(
                f"Scaled MOO directions shape {scaled_moo_directions.shape} does not match grads_input shape {grads_input[0].shape}"
            )
        return (scaled_moo_directions,)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass. Return the input unchanged; gradient mixing happens in the registered backward hook.

        Args:
            z (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The input tensor, unchanged.
        """
        return z

    def __str__(self) -> str:
        return f"MooMulti({self.moo_method})"
moo_method property

Get the MOO method.

__init__(num_modules, moo_method=None)

Initialize the MooMulti module.

Parameters:

    num_modules (int): Number of heads for extending the input. Required.
    moo_method (Module): The MOO method to be used. Default: None.
Source code in vambn/modelling/mtl/moo.py
def __init__(
    self, num_modules: int, moo_method: Optional[nn.Module] = None
):
    """
    Initialize the MooMulti module.

    Args:
        num_modules (int): Number of heads for extending the input.
        moo_method (nn.Module, optional): The MOO method to be used. Default is None.
    """
    super().__init__()

    self._moo_method = [moo_method]
    self.num_heads = num_modules
    self.inputs = None
    self.outputs = None

    if self.moo_method is not None:
        self.register_full_backward_hook(MooMulti._hook)
forward(z)

Forward pass. Returns the input unchanged; gradient mixing is handled by the backward hook.

Parameters:

    z (Tensor): Input tensor. Required.

Returns:

    Tensor: The input tensor, unchanged.

Source code in vambn/modelling/mtl/moo.py
def forward(self, z: torch.Tensor) -> torch.Tensor:
    """
    Forward pass. Return the input unchanged; gradient mixing happens in the registered backward hook.

    Args:
        z (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: The input tensor, unchanged.
    """
    return z

MultiMOOForLoop

Bases: Module

A PyTorch Module for applying multiple MOOForLoop modules in parallel.

Source code in vambn/modelling/mtl/moo.py
class MultiMOOForLoop(nn.Module):
    """A PyTorch Module for applying multiple MOOForLoop modules in parallel."""

    def __init__(self, num_heads: int, moo_methods: Sequence[nn.Module]):
        """
        Initialize the MultiMOOForLoop module.

        Args:
            num_heads (int): Number of heads for each MOOForLoop.
            moo_methods (Sequence[nn.Module]): List of MOO methods to be used.
        """
        super().__init__()

        self.num_inputs = len(moo_methods)
        self.loops = [MOOForLoop(num_heads, method) for method in moo_methods]

    def forward(self, *args) -> Generator[torch.Tensor, None, None]:
        """
        Forward pass. Applies each MOOForLoop to its corresponding input.

        Args:
            *args (torch.Tensor): Variable number of input tensors.

        Returns:
            Generator: A generator of extended input tensors after applying MOOForLoop.
        """
        if len(args) != self.num_inputs:
            raise ValueError(
                f"Expected {self.num_inputs} inputs, got {len(args)} instead."
            )
        return (loop(z) for z, loop in zip(args, self.loops))
__init__(num_heads, moo_methods)

Initialize the MultiMOOForLoop module.

Parameters:

    num_heads (int): Number of heads for each MOOForLoop. Required.
    moo_methods (Sequence[Module]): List of MOO methods to be used. Required.
Source code in vambn/modelling/mtl/moo.py
def __init__(self, num_heads: int, moo_methods: Sequence[nn.Module]):
    """
    Initialize the MultiMOOForLoop module.

    Args:
        num_heads (int): Number of heads for each MOOForLoop.
        moo_methods (Sequence[nn.Module]): List of MOO methods to be used.
    """
    super().__init__()

    self.num_inputs = len(moo_methods)
    self.loops = [MOOForLoop(num_heads, method) for method in moo_methods]
forward(*args)

Forward pass. Applies each MOOForLoop to its corresponding input.

Parameters:

    *args (Tensor): Variable number of input tensors.

Returns:

    Generator[Tensor, None, None]: A generator of extended input tensors after applying MOOForLoop.

Source code in vambn/modelling/mtl/moo.py
def forward(self, *args) -> Generator[torch.Tensor, None, None]:
    """
    Forward pass. Applies each MOOForLoop to its corresponding input.

    Args:
        *args (torch.Tensor): Variable number of input tensors.

    Returns:
        Generator: A generator of extended input tensors after applying MOOForLoop.
    """
    if len(args) != self.num_inputs:
        raise ValueError(
            f"Expected {self.num_inputs} inputs, got {len(args)} instead."
        )
    return (loop(z) for z, loop in zip(args, self.loops))
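A small sketch for the multi-input case (illustrative; the MOO methods, shapes, and import paths are assumptions based on the source notes above):

import torch

from vambn.modelling.mtl.moo import MultiMOOForLoop
from vambn.modelling.mtl.mtl import IMTLG, Identity

multi = MultiMOOForLoop(num_heads=2, moo_methods=[IMTLG(), Identity()])
z1, z2 = torch.randn(4, 8), torch.randn(4, 8)
z1_ext, z2_ext = multi(z1, z2)              # lazily extended inputs, each of shape (2, 4, 8)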

setup_moo(hparams, num_tasks)

Setup the multi-task learning module.

Parameters:

    hparams (List[MtlMethodParams]): MTL method parameters. Required.
    num_tasks (int): Number of tasks to perform. Required.

Raises:

    ValueError: If an invalid method name is provided.

Returns:

    nn.Module: Module for the MTL objective.

Source code in vambn/modelling/mtl/moo.py
def setup_moo(hparams: List[MtlMethodParams], num_tasks: int) -> nn.Module:
    """
    Setup the multi-task learning module.

    Args:
        hparams (List[MtlMethodParams]): MTL method parameters.
        num_tasks (int): Number of tasks to perform.

    Raises:
        ValueError: If invalid method name is provided.

    Returns:
        nn.Module: Module for MTL objective.
    """
    if len(hparams) == 0:
        return mtl.Identity()

    modules = []
    for obj in hparams:
        try:
            method = mtl.MtlMethods[obj.name].value
        except KeyError:
            raise ValueError(f"Invalid method name: {obj.name}")

        if obj.name in ["nsgd"]:
            modules.append(method(num_tasks=num_tasks, update_at=obj.update_at))
        elif obj.name in ["gradnorm"]:
            modules.append(
                method(
                    num_tasks=num_tasks,
                    alpha=obj.alpha,
                    update_at=obj.update_at,
                )
            )
        elif obj.name in ["cagrad"]:
            modules.append(method(alpha=obj.alpha))
        elif obj.name in ["graddrop"]:
            modules.append(method(leakage=[0.2] * num_tasks))
        else:
            modules.append(method())

    return mtl.Compose(*modules) if len(modules) != 0 else None
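A hypothetical construction sketch. MtlMethodParams itself is not shown here, so SimpleNamespace stand-ins expose only the attributes setup_moo reads (name, alpha, update_at); the method names are taken from the branches above:

from types import SimpleNamespace

from vambn.modelling.mtl.moo import setup_moo

params = [
    SimpleNamespace(name="graddrop", alpha=None, update_at=None),
    SimpleNamespace(name="gradnorm", alpha=1.0, update_at=20),
]
mtl_module = setup_moo(params, num_tasks=3)  # e.g. Compose(GradDrop(...), GradNorm(...))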

mtl

This script includes code adapted from the 'impartial-vaes' repository with minor modifications. The original code can be found at: https://github.com/adrianjav/impartial-vaes

Credit to the original authors: Adrian Javaloy, Maryam Meghdadi, and Isabel Valera for their valuable work.

CAGrad

Bases: MOOMethod

CAGrad method for multiple objective optimization.

Source code in vambn/modelling/mtl/mtl.py
class CAGrad(MOOMethod):
    """CAGrad method for multiple objective optimization."""

    requires_input: bool = False

    def __init__(self, alpha: float):
        """
        Initialize CAGrad method.

        Args:
            alpha: Alpha parameter for CAGrad.
        """
        super(CAGrad, self).__init__()
        self.alpha = alpha

    def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
        """
        Compute new gradients using CAGrad method.

        Args:
            grads: Gradients tensor.
            inputs: Input tensor.
            outputs: Output tensor.

        Returns:
            New gradients tensor.
        """
        shape = grads.size()
        num_tasks = len(grads)
        grads = grads.flatten(start_dim=1).t()

        GG = grads.t().mm(grads).cpu()
        g0_norm = (GG.mean() + 1e-8).sqrt()

        x_start = np.ones(num_tasks) / num_tasks
        bnds = tuple((0, 1) for _ in x_start)
        cons = {"type": "eq", "fun": lambda x: 1 - sum(x)}

        A = GG.numpy()
        b = x_start.copy()
        c = (self.alpha * g0_norm + 1e-8).item()

        def objfn(x):
            return (
                x.reshape(1, num_tasks).dot(A).dot(b.reshape(num_tasks, 1))
                + c
                * np.sqrt(
                    x.reshape(1, num_tasks).dot(A).dot(x.reshape(num_tasks, 1))
                    + 1e-8
                )
            ).sum()

        res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
        w_cpu = res.x

        ww = torch.Tensor(w_cpu).to(grads.device)
        gw = (grads * ww.view(1, -1)).sum(1)
        gw_norm = gw.norm()
        lmbda = c / (gw_norm + 1e-8)
        g = (grads + lmbda * gw.unsqueeze(1)) / num_tasks

        g = g.t().reshape(shape)
        grads = g

        return grads
__init__(alpha)

Initialize CAGrad method.

Parameters:

    alpha (float): Alpha parameter for CAGrad. Required.
Source code in vambn/modelling/mtl/mtl.py
def __init__(self, alpha: float):
    """
    Initialize CAGrad method.

    Args:
        alpha: Alpha parameter for CAGrad.
    """
    super(CAGrad, self).__init__()
    self.alpha = alpha
forward(grads, inputs, outputs)

Compute new gradients using CAGrad method.

Parameters:

    grads (Tensor): Gradients tensor. Required.
    inputs (Tensor): Input tensor. Required.
    outputs (Tensor): Output tensor. Required.

Returns:

    Tensor: New gradients tensor.

Source code in vambn/modelling/mtl/mtl.py
def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
    """
    Compute new gradients using CAGrad method.

    Args:
        grads: Gradients tensor.
        inputs: Input tensor.
        outputs: Output tensor.

    Returns:
        New gradients tensor.
    """
    shape = grads.size()
    num_tasks = len(grads)
    grads = grads.flatten(start_dim=1).t()

    GG = grads.t().mm(grads).cpu()
    g0_norm = (GG.mean() + 1e-8).sqrt()

    x_start = np.ones(num_tasks) / num_tasks
    bnds = tuple((0, 1) for _ in x_start)
    cons = {"type": "eq", "fun": lambda x: 1 - sum(x)}

    A = GG.numpy()
    b = x_start.copy()
    c = (self.alpha * g0_norm + 1e-8).item()

    def objfn(x):
        return (
            x.reshape(1, num_tasks).dot(A).dot(b.reshape(num_tasks, 1))
            + c
            * np.sqrt(
                x.reshape(1, num_tasks).dot(A).dot(x.reshape(num_tasks, 1))
                + 1e-8
            )
        ).sum()

    res = minimize(objfn, x_start, bounds=bnds, constraints=cons)
    w_cpu = res.x

    ww = torch.Tensor(w_cpu).to(grads.device)
    gw = (grads * ww.view(1, -1)).sum(1)
    gw_norm = gw.norm()
    lmbda = c / (gw_norm + 1e-8)
    g = (grads + lmbda * gw.unsqueeze(1)) / num_tasks

    g = g.t().reshape(shape)
    grads = g

    return grads
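A minimal call sketch (shapes and alpha are illustrative; CAGrad ignores the inputs and outputs arguments and relies on scipy.optimize.minimize internally):

import torch

from vambn.modelling.mtl.mtl import CAGrad  # module path assumed from the source note above

cagrad = CAGrad(alpha=0.5)
task_grads = torch.randn(3, 128)             # K x D stacked task gradients
new_grads = cagrad(task_grads, None, None)   # same shape as task_grads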

Compose

Bases: MOOMethod

Compose multiple MOO methods.

Parameters:

    *modules (MOOMethod): List of MOO methods to compose.

Attributes:

    methods (nn.ModuleList): List of MOO methods.
    requires_input (bool): Flag indicating if input is required.

Source code in vambn/modelling/mtl/mtl.py
class Compose(MOOMethod):
    """
    Compose multiple MOO methods.

    Args:
        modules (MOOMethod): List of MOO methods to compose.

    Attributes:
        methods (nn.ModuleList): List of MOO methods.
        requires_input (bool): Flag indicating if input is required.

    """

    def __init__(self, *modules: MOOMethod):
        super().__init__()
        self.methods = nn.ModuleList(modules)
        self.requires_input = any([m.requires_input for m in modules])

    def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
        """
        Apply composed MOO methods sequentially.

        Args:
            grads (torch.Tensor): Gradients tensor.
            inputs (torch.Tensor): Input tensor.
            outputs (torch.Tensor): Output tensor.

        Returns:
            torch.Tensor: Modified gradients.
        """
        for module in self.methods:
            grads = module(grads, inputs, outputs)
        return grads
forward(grads, inputs, outputs)

Apply composed MOO methods sequentially.

Parameters:

    grads (Tensor): Gradients tensor. Required.
    inputs (Tensor): Input tensor. Required.
    outputs (Tensor): Output tensor. Required.

Returns:

    Tensor: Modified gradients.

Source code in vambn/modelling/mtl/mtl.py
def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
    """
    Apply composed MOO methods sequentially.

    Args:
        grads (torch.Tensor): Gradients tensor.
        inputs (torch.Tensor): Input tensor.
        outputs (torch.Tensor): Output tensor.

    Returns:
        torch.Tensor: Modified gradients.
    """
    for module in self.methods:
        grads = module(grads, inputs, outputs)
    return grads
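A brief composition sketch (illustrative; both methods below work without the optional inputs/outputs tensors, and the import path is assumed from the source note above):

import torch

from vambn.modelling.mtl.mtl import Compose, IMTLG, Identity

pipeline = Compose(IMTLG(), Identity())
task_grads = torch.randn(3, 4, 8)            # K task gradients of shape (4, 8)
new_grads = pipeline(task_grads, None, None) # methods are applied left to right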

GradDrop

Bases: MOOMethod

Gradient Dropout (GradDrop) method for MOO.

Parameters:

    leakage (List[float]): List of leakage rates for each task. Required.

Attributes:

    leakage (List[float]): List of leakage rates for each task.

Source code in vambn/modelling/mtl/mtl.py
class GradDrop(MOOMethod):
    """Gradient Dropout (GradDrop) method for MOO.

    Args:
        leakage (List[float]): List of leakage rates for each task.

    Attributes:
        leakage (List[float]): List of leakage rates for each task.

    """

    requires_input: bool = True

    def __init__(self, leakage: List[float]):
        """
        Initialize GradDrop method.

        Args:
            leakage (List[float]): List of leakage rates for each task.

        Raises:
            AssertionError: If any leakage rate is not in the range [0, 1].

        """
        super(GradDrop, self).__init__()
        assert all(
            [0 <= x <= 1 for x in leakage]
        ), "All leakages should be in the range [0, 1]"
        self.leakage = leakage

    def forward(
        self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute new gradients using GradDrop method.

        Args:
            grads (torch.Tensor): Gradients tensor.
            inputs (torch.Tensor): Input tensor.
            outputs (torch.Tensor): Output tensor.

        Returns:
            torch.Tensor: New gradients tensor.

        Raises:
            AssertionError: If the number of leakage parameters does not match the number of task gradients.

        """
        assert len(self.leakage) == len(
            grads
        ), "Leakage parameters should match the number of task gradients"
        sign_grads = [None for _ in range(len(grads))]
        for i in range(len(grads)):
            sign_grads[i] = inputs.sign() * grads[i]
            if len(grads[0].size()) > 1:  # It is batch-separated
                sign_grads[i] = grads[i].sum(dim=0, keepdim=True)

        odds = 0.5 * (
            1 + sum(sign_grads) / (sum(map(torch.abs, sign_grads)) + 1e-15)
        ).clamp(0, 1)
        assert odds.size() == sign_grads[0].size()  # pytype: disable=attribute-error

        new_grads = []
        samples = torch.rand(odds.size(), device=grads[0].device)
        for i in range(len(grads)):
            mask_i = torch.where(
                (odds > samples) & (sign_grads[i] > 0)  # pytype: disable=unsupported-operands
                | (odds < samples) & (sign_grads[i] < 0),  # pytype: disable=unsupported-operands
                torch.ones_like(odds),
                torch.zeros_like(odds),
            )
            mask_i = torch.lerp(
                mask_i, torch.ones_like(mask_i), self.leakage[i]
            )
            assert mask_i.size() == odds.size()
            new_grads.append(mask_i * grads[i])

        return torch.stack(new_grads, dim=0)
__init__(leakage)

Initialize GradDrop method.

Parameters:

    leakage (List[float]): List of leakage rates for each task. Required.

Raises:

    AssertionError: If any leakage rate is not in the range [0, 1].

Source code in vambn/modelling/mtl/mtl.py
def __init__(self, leakage: List[float]):
    """
    Initialize GradDrop method.

    Args:
        leakage (List[float]): List of leakage rates for each task.

    Raises:
        AssertionError: If any leakage rate is not in the range [0, 1].

    """
    super(GradDrop, self).__init__()
    assert all(
        [0 <= x <= 1 for x in leakage]
    ), "All leakages should be in the range [0, 1]"
    self.leakage = leakage
forward(grads, inputs, outputs)

Compute new gradients using GradDrop method.

Parameters:

    grads (Tensor): Gradients tensor. Required.
    inputs (Tensor): Input tensor. Required.
    outputs (Tensor): Output tensor. Required.

Returns:

    Tensor: New gradients tensor.

Raises:

    AssertionError: If the number of leakage parameters does not match the number of task gradients.

Source code in vambn/modelling/mtl/mtl.py
def forward(
    self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor
) -> torch.Tensor:
    """
    Compute new gradients using GradDrop method.

    Args:
        grads (torch.Tensor): Gradients tensor.
        inputs (torch.Tensor): Input tensor.
        outputs (torch.Tensor): Output tensor.

    Returns:
        torch.Tensor: New gradients tensor.

    Raises:
        AssertionError: If the number of leakage parameters does not match the number of task gradients.

    """
    assert len(self.leakage) == len(
        grads
    ), "Leakage parameters should match the number of task gradients"
    sign_grads = [None for _ in range(len(grads))]
    for i in range(len(grads)):
        sign_grads[i] = inputs.sign() * grads[i]
        if len(grads[0].size()) > 1:  # It is batch-separated
            sign_grads[i] = grads[i].sum(dim=0, keepdim=True)

    odds = 0.5 * (
        1 + sum(sign_grads) / (sum(map(torch.abs, sign_grads)) + 1e-15)
    ).clamp(0, 1)
    assert odds.size() == sign_grads[0].size()  # pytype: disable=attribute-error

    new_grads = []
    samples = torch.rand(odds.size(), device=grads[0].device)
    for i in range(len(grads)):
        mask_i = torch.where(
            (odds > samples) & (sign_grads[i] > 0)  # pytype: disable=unsupported-operands
            | (odds < samples) & (sign_grads[i] < 0),  # pytype: disable=unsupported-operands
            torch.ones_like(odds),
            torch.zeros_like(odds),
        )
        mask_i = torch.lerp(
            mask_i, torch.ones_like(mask_i), self.leakage[i]
        )
        assert mask_i.size() == odds.size()
        new_grads.append(mask_i * grads[i])

    return torch.stack(new_grads, dim=0)
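A minimal call sketch. GradDrop is the one method in this module that uses the stored forward inputs, so an activation tensor must be supplied; all shapes and values are illustrative and the import path is assumed from the source note above:

import torch

from vambn.modelling.mtl.mtl import GradDrop

graddrop = GradDrop(leakage=[0.2, 0.2])
inputs = torch.randn(4, 8)                      # shared activations from the forward pass
task_grads = torch.randn(2, 4, 8)               # per-task gradients w.r.t. those activations
new_grads = graddrop(task_grads, inputs, None)  # same shape as task_grads, with sign-based random masking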

GradNorm

Bases: GradNormBase

Gradient Normalization (GradNorm) method for MOO.

Parameters:

    GradNormBase (class): Base class for GradNorm. Required.

Attributes:

    requires_input (bool): Flag indicating whether input is required.

Methods:

    forward: Compute new gradients using GradNorm method.

Source code in vambn/modelling/mtl/mtl.py
class GradNorm(GradNormBase):
    """Gradient Normalization (GradNorm) method for MOO.

    Args:
        GradNormBase (class): Base class for GradNorm.

    Attributes:
        requires_input (bool): Flag indicating whether input is required.

    Methods:
        forward: Compute new gradients using GradNorm method.

    """

    requires_input: bool = False

    def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
        """
        Compute new gradients using GradNorm method.

        Args:
            grads (torch.Tensor): Gradients tensor.
            inputs (torch.Tensor): Input tensor.
            outputs (torch.Tensor): Output tensor.

        Returns:
            torch.Tensor: New gradients tensor.
        """
        return self._forward(grads, outputs)
forward(grads, inputs, outputs)

Compute new gradients using GradNorm method.

Parameters:

    grads (Tensor): Gradients tensor. Required.
    inputs (Tensor): Input tensor. Required.
    outputs (Tensor): Output tensor. Required.

Returns:

    Tensor: New gradients tensor.

Source code in vambn/modelling/mtl/mtl.py
def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
    """
    Compute new gradients using GradNorm method.

    Args:
        grads (torch.Tensor): Gradients tensor.
        inputs (torch.Tensor): Input tensor.
        outputs (torch.Tensor): Output tensor.

    Returns:
        torch.Tensor: New gradients tensor.
    """
    return self._forward(grads, outputs)

GradNormBase

Bases: MOOMethod

Base class for Gradient Normalization (GradNorm) method.

Source code in vambn/modelling/mtl/mtl.py
class GradNormBase(MOOMethod):
    """Base class for Gradient Normalization (GradNorm) method."""

    initial_values: torch.Tensor
    counter: torch.Tensor

    def __init__(self, num_tasks: int, alpha: float, update_at: int = 20):
        """
        Initialize GradNormBase method.

        Args:
            num_tasks (int): Number of tasks.
            alpha (float): Alpha parameter for GradNorm.
            update_at (int): Update interval.
        """
        super(GradNormBase, self).__init__()
        self.epsilon = 1e-5
        self.num_tasks = num_tasks
        self.weight_ = nn.Parameter(torch.ones([num_tasks]), requires_grad=True)
        self.alpha = alpha
        self.update_at = update_at
        self.register_buffer("initial_values", torch.ones(self.num_tasks))
        self.register_buffer("counter", torch.zeros([]))

    @property
    def weight(self) -> torch.Tensor:
        """
        Compute normalized weights.

        Returns:
            torch.Tensor: Normalized weights.
        """
        ws = self.weight_.exp().clamp(self.epsilon, float("inf"))
        norm_coef = self.num_tasks / ws.sum()
        return ws * norm_coef

    def _forward(self, grads: torch.Tensor, values: List[float]) -> torch.Tensor:
        """
        Compute new gradients using GradNorm method.

        Args:
            grads (torch.Tensor): Gradients tensor.
            values (List[float]): Values for each task.

        Returns:
            torch.Tensor: New gradients tensor.
        """
        if self.initial_values is None or self.counter == self.update_at:
            self.initial_values = torch.tensor(values)
        self.counter += 1

        with torch.enable_grad():
            grads_norm = grads.flatten(start_dim=1).norm(p=2, dim=1)
            mean_grad_norm = (
                torch.mean(batch_product(grads_norm, self.weight), dim=0)
                .detach()
                .clone()
            )

            values = [
                x / y.clamp_min(self.epsilon)
                for x, y in zip(values, self.initial_values)
            ]
            average_value = torch.mean(torch.stack(values))

            loss = grads.new_zeros([])
            for i, [grad, value] in enumerate(zip(grads_norm, values)):
                r_i = value / average_value.clamp_min(self.epsilon)
                loss += torch.abs(
                    grad * self.weight[i]
                    - mean_grad_norm * torch.pow(r_i, self.alpha)
                )
            loss.backward()

        with torch.no_grad():
            new_grads = batch_product(grads, self.weight.detach())
        return new_grads
weight: torch.Tensor property

Compute normalized weights.

Returns:

    torch.Tensor: Normalized weights.

__init__(num_tasks, alpha, update_at=20)

Initialize GradNormBase method.

Parameters:

    num_tasks (int): Number of tasks. Required.
    alpha (float): Alpha parameter for GradNorm. Required.
    update_at (int): Update interval. Default: 20.
Source code in vambn/modelling/mtl/mtl.py
def __init__(self, num_tasks: int, alpha: float, update_at: int = 20):
    """
    Initialize GradNormBase method.

    Args:
        num_tasks (int): Number of tasks.
        alpha (float): Alpha parameter for GradNorm.
        update_at (int): Update interval.
    """
    super(GradNormBase, self).__init__()
    self.epsilon = 1e-5
    self.num_tasks = num_tasks
    self.weight_ = nn.Parameter(torch.ones([num_tasks]), requires_grad=True)
    self.alpha = alpha
    self.update_at = update_at
    self.register_buffer("initial_values", torch.ones(self.num_tasks))
    self.register_buffer("counter", torch.zeros([]))

GradNormModified

Bases: GradNormBase

Modified Gradient Normalization (GradNorm) method for MOO.

Uses task-gradient convergence instead of task loss convergence.

Attributes:

    requires_input (bool): Indicates whether the method requires input tensor.

Methods:

    forward: Compute new gradients using modified GradNorm method.

Source code in vambn/modelling/mtl/mtl.py
class GradNormModified(GradNormBase):
    """
    Modified Gradient Normalization (GradNorm) method for MOO.

    Uses task-gradient convergence instead of task loss convergence.

    Attributes:
        requires_input (bool): Indicates whether the method requires input tensor.

    Methods:
        forward(grads, inputs, outputs): Compute new gradients using modified GradNorm method.

    """

    requires_input: bool = False

    def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
        """
        Compute new gradients using modified GradNorm method.

        Args:
            grads (torch.Tensor): Gradients tensor.
            inputs (torch.Tensor): Input tensor.
            outputs (torch.Tensor): Output tensor.

        Returns:
            torch.Tensor: New gradients tensor.
        """
        return self._forward(grads, grads.flatten(start_dim=1).norm(p=2, dim=1))
forward(grads, inputs, outputs)

Compute new gradients using modified GradNorm method.

Parameters:

    grads (Tensor): Gradients tensor. Required.
    inputs (Tensor): Input tensor. Required.
    outputs (Tensor): Output tensor. Required.

Returns:

    Tensor: New gradients tensor.

Source code in vambn/modelling/mtl/mtl.py
def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
    """
    Compute new gradients using modified GradNorm method.

    Args:
        grads (torch.Tensor): Gradients tensor.
        inputs (torch.Tensor): Input tensor.
        outputs (torch.Tensor): Output tensor.

    Returns:
        torch.Tensor: New gradients tensor.
    """
    return self._forward(grads, grads.flatten(start_dim=1).norm(p=2, dim=1))

GradVac

Bases: MOOMethod

Gradient Vaccination (GradVac) method for MOO.

Source code in vambn/modelling/mtl/mtl.py
class GradVac(MOOMethod):
    """Gradient Vaccination (GradVac) method for MOO."""

    requires_input: bool = False

    def __init__(self, decay: float):
        """
        Initialize GradVac method.

        Args:
            decay: Decay rate for EMA.
        """
        super(GradVac, self).__init__()
        self.decay = decay

    def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
        """
        Compute new gradients using GradVac method.

        Args:
            grads: Gradients tensor.
            inputs: Input tensor.
            outputs: Output tensor.

        Returns:
            New gradients tensor.
        """

        def vac_projection(u: torch.Tensor, v: torch.Tensor, pre_ema: float, post_ema: float) -> torch.Tensor:
            norm_u = torch.dot(u, u).sqrt()
            norm_v = torch.dot(v, v).sqrt()

            numer = norm_u * (
                pre_ema * math.sqrt(1 - post_ema**2)
                - post_ema * math.sqrt(1 - pre_ema**2)
            )
            denom = norm_v * math.sqrt(1 - pre_ema**2)

            return numer / denom.clamp_min(1e-15) * v

        size = grads.size()[1:]
        num_tasks = grads.size(0)

        grads_list = [g.flatten() for g in grads]
        ema = [[0 for _ in range(num_tasks)] for _ in range(num_tasks)]

        new_grads = []
        for i in range(num_tasks):
            grad_i = grads_list[i]
            for j in np.random.permutation(num_tasks):
                if i == j:
                    continue
                grad_j = grads_list[j]
                cos_sim = torch.cosine_similarity(grad_i, grad_j, dim=0)
                if cos_sim < ema[i][j]:
                    grad_i = grad_i + vac_projection(
                        grad_i, grad_j, ema[i][j], cos_sim
                    )
                    assert id(grads_list[i]) != id(grad_i), "Aliasing!"
                ema[i][j] = (1 - self.decay) * ema[i][j] + self.decay * cos_sim
            new_grads.append(grad_i.reshape(size))

        return torch.stack(new_grads, dim=0)
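A minimal call sketch (illustrative values; GradVac only inspects the gradients themselves, and the import path is assumed from the source note above):

import torch

from vambn.modelling.mtl.mtl import GradVac

gradvac = GradVac(decay=0.01)
task_grads = torch.randn(3, 4, 8)             # K task gradients of shape (4, 8)
new_grads = gradvac(task_grads, None, None)   # pairs with negative cosine similarity get adjusted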
__init__(decay)

Initialize GradVac method.

Parameters:

    decay (float): Decay rate for EMA. Required.
Source code in vambn/modelling/mtl/mtl.py
def __init__(self, decay: float):
    """
    Initialize GradVac method.

    Args:
        decay: Decay rate for EMA.
    """
    super(GradVac, self).__init__()
    self.decay = decay
forward(grads, inputs, outputs)

Compute new gradients using GradVac method.

Parameters:

    grads (Tensor): Gradients tensor. Required.
    inputs (Tensor): Input tensor. Required.
    outputs (Tensor): Output tensor. Required.

Returns:

    Tensor: New gradients tensor.

Source code in vambn/modelling/mtl/mtl.py
def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
    """
    Compute new gradients using GradVac method.

    Args:
        grads: Gradients tensor.
        inputs: Input tensor.
        outputs: Output tensor.

    Returns:
        New gradients tensor.
    """

    def vac_projection(u: torch.Tensor, v: torch.Tensor, pre_ema: float, post_ema: float) -> torch.Tensor:
        norm_u = torch.dot(u, u).sqrt()
        norm_v = torch.dot(v, v).sqrt()

        numer = norm_u * (
            pre_ema * math.sqrt(1 - post_ema**2)
            - post_ema * math.sqrt(1 - pre_ema**2)
        )
        denom = norm_v * math.sqrt(1 - pre_ema**2)

        return numer / denom.clamp_min(1e-15) * v

    size = grads.size()[1:]
    num_tasks = grads.size(0)

    grads_list = [g.flatten() for g in grads]
    ema = [[0 for _ in range(num_tasks)] for _ in range(num_tasks)]

    new_grads = []
    for i in range(num_tasks):
        grad_i = grads_list[i]
        for j in np.random.permutation(num_tasks):
            if i == j:
                continue
            grad_j = grads_list[j]
            cos_sim = torch.cosine_similarity(grad_i, grad_j, dim=0)
            if cos_sim < ema[i][j]:
                grad_i = grad_i + vac_projection(
                    grad_i, grad_j, ema[i][j], cos_sim
                )
                assert id(grads_list[i]) != id(grad_i), "Aliasing!"
            ema[i][j] = (1 - self.decay) * ema[i][j] + self.decay * cos_sim
        new_grads.append(grad_i.reshape(size))

    return torch.stack(new_grads, dim=0)

IMTLG

Bases: MOOMethod

IMTLG method for multiple objective optimization.

Source code in vambn/modelling/mtl/mtl.py
class IMTLG(MOOMethod):
    """IMTLG method for multiple objective optimization."""

    requires_input: bool = False

    def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
        """
        Compute new gradients using IMTLG method.

        Args:
            grads (torch.Tensor): Gradients tensor.
            inputs (torch.Tensor): Input tensor.
            outputs (torch.Tensor): Output tensor.

        Returns:
            torch.Tensor: New gradients tensor.
        """
        flatten_grads = grads.flatten(start_dim=1)
        num_tasks = len(grads)
        if num_tasks == 1:
            return grads

        grad_diffs, unit_diffs = [], []
        for i in range(1, num_tasks):
            grad_diffs.append(flatten_grads[0] - flatten_grads[i])
            unit_diffs.append(
                unitary(flatten_grads[0]) - unitary(flatten_grads[i])
            )
        grad_diffs = torch.stack(grad_diffs, dim=0)
        unit_diffs = torch.stack(unit_diffs, dim=0)

        DU_T = torch.einsum("ik,jk->ij", grad_diffs, unit_diffs)
        DU_T_inv = torch.pinverse(DU_T)

        alphas = torch.einsum(
            "i,ki,kj->j", grads[0].flatten(), unit_diffs, DU_T_inv
        )
        alphas = torch.cat(
            (1 - alphas.sum(dim=0).unsqueeze(dim=0), alphas), dim=0
        )

        return batch_product(grads, alphas)
forward(grads, inputs, outputs)

Compute new gradients using IMTLG method.

Parameters:

    grads (Tensor): Gradients tensor. Required.
    inputs (Tensor): Input tensor. Required.
    outputs (Tensor): Output tensor. Required.

Returns:

    Tensor: New gradients tensor.

Source code in vambn/modelling/mtl/mtl.py
def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
    """
    Compute new gradients using IMTLG method.

    Args:
        grads (torch.Tensor): Gradients tensor.
        inputs (torch.Tensor): Input tensor.
        outputs (torch.Tensor): Output tensor.

    Returns:
        torch.Tensor: New gradients tensor.
    """
    flatten_grads = grads.flatten(start_dim=1)
    num_tasks = len(grads)
    if num_tasks == 1:
        return grads

    grad_diffs, unit_diffs = [], []
    for i in range(1, num_tasks):
        grad_diffs.append(flatten_grads[0] - flatten_grads[i])
        unit_diffs.append(
            unitary(flatten_grads[0]) - unitary(flatten_grads[i])
        )
    grad_diffs = torch.stack(grad_diffs, dim=0)
    unit_diffs = torch.stack(unit_diffs, dim=0)

    DU_T = torch.einsum("ik,jk->ij", grad_diffs, unit_diffs)
    DU_T_inv = torch.pinverse(DU_T)

    alphas = torch.einsum(
        "i,ki,kj->j", grads[0].flatten(), unit_diffs, DU_T_inv
    )
    alphas = torch.cat(
        (1 - alphas.sum(dim=0).unsqueeze(dim=0), alphas), dim=0
    )

    return batch_product(grads, alphas)

Identity

Bases: MOOMethod

Identity MOO method that returns the input gradients unchanged.

Source code in vambn/modelling/mtl/mtl.py
class Identity(MOOMethod):
    """Identity MOO method that returns the input gradients unchanged."""

    def forward(
        self,
        grads: torch.Tensor,
        inputs: Optional[torch.Tensor],
        outputs: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Return the input gradients unchanged.

        Args:
            grads (torch.Tensor): Input gradients.
            inputs (torch.Tensor, optional): Input tensor.
            outputs (torch.Tensor, optional): Output tensor.

        Returns:
            torch.Tensor: Unchanged input gradients.
        """
        return grads
forward(grads, inputs, outputs)

Return the input gradients unchanged.

Parameters:

    grads (Tensor): Input gradients. Required.
    inputs (Tensor, optional): Input tensor.
    outputs (Tensor, optional): Output tensor.

Returns:

    Tensor: Unchanged input gradients.

Source code in vambn/modelling/mtl/mtl.py
def forward(
    self,
    grads: torch.Tensor,
    inputs: Optional[torch.Tensor],
    outputs: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Return the input gradients unchanged.

    Args:
        grads (torch.Tensor): Input gradients.
        inputs (torch.Tensor, optional): Input tensor.
        outputs (torch.Tensor, optional): Output tensor.

    Returns:
        torch.Tensor: Unchanged input gradients.
    """
    return grads

MGDAUB

Bases: MOOMethod

MGDA-UB method for multiple objective optimization.

Source code in vambn/modelling/mtl/mtl.py
class MGDAUB(MOOMethod):
    """MGDA-UB method for multiple objective optimization."""

    requires_input: bool = False

    def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
        """
        Compute new gradients using MGDA-UB method.

        Args:
            grads (torch.Tensor): Gradients tensor.
            inputs (torch.Tensor): Input tensor.
            outputs (torch.Tensor): Output tensor.

        Returns:
            torch.Tensor: New gradients tensor.
        """
        epsilon: float = 1e-3
        shape: Tuple[int] = grads.size()[1:]
        grads = grads.flatten(start_dim=1).unsqueeze(dim=1)

        weights, min_norm = MinNormSolver.find_min_norm_element(
            grads.unbind(dim=0)
        )
        weights = [min(w, epsilon) for w in weights]

        grads = torch.stack(
            [g.reshape(shape) * w for g, w in zip(grads, weights)], dim=0
        )
        return grads
forward(grads, inputs, outputs)

Compute new gradients using MGDA-UB method.

Parameters:

    grads (Tensor): Gradients tensor. Required.
    inputs (Tensor): Input tensor. Required.
    outputs (Tensor): Output tensor. Required.

Returns:

    Tensor: New gradients tensor.

Source code in vambn/modelling/mtl/mtl.py
def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
    """
    Compute new gradients using MGDA-UB method.

    Args:
        grads (torch.Tensor): Gradients tensor.
        inputs (torch.Tensor): Input tensor.
        outputs (torch.Tensor): Output tensor.

    Returns:
        torch.Tensor: New gradients tensor.
    """
    epsilon: float = 1e-3
    shape: Tuple[int] = grads.size()[1:]
    grads = grads.flatten(start_dim=1).unsqueeze(dim=1)

    weights, min_norm = MinNormSolver.find_min_norm_element(
        grads.unbind(dim=0)
    )
    weights = [min(w, epsilon) for w in weights]

    grads = torch.stack(
        [g.reshape(shape) * w for g, w in zip(grads, weights)], dim=0
    )
    return grads
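An illustrative call (this assumes MinNormSolver.find_min_norm_element from further down in this module is available; shapes are arbitrary and the import path is assumed from the source note above):

import torch

from vambn.modelling.mtl.mtl import MGDAUB

mgda = MGDAUB()
task_grads = torch.randn(3, 128)              # K x D stacked task gradients
new_grads = mgda(task_grads, None, None)      # each gradient rescaled by its min-norm weight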

MOOMethod

Bases: Module

Base class for multiple objective optimization (MOO) methods.

Source code in vambn/modelling/mtl/mtl.py
class MOOMethod(nn.Module, metaclass=ABCMeta):
    """Base class for multiple objective optimization (MOO) methods."""

    requires_input: bool = False

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(
        self,
        grads: torch.Tensor,
        inputs: Optional[torch.Tensor],
        outputs: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Computes the new task gradients based on the original ones.

        Given K gradients of size D, returns a new set of K gradients of size D based on some criterion.

        Args:
            grads (torch.Tensor): Tensor of size K x D with the different gradients.
            inputs (torch.Tensor, optional): Tensor with the input of the forward pass (if requires_input is set to True).
            outputs (torch.Tensor, optional): Tensor with the K outputs of the module (not used currently).

        Returns:
            torch.Tensor: A tensor of the same size as `grads` with the new gradients to use during backpropagation.
        """
        raise NotImplementedError("You need to implement the forward pass.")
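To illustrate the contract, a minimal hypothetical subclass (not part of the library) that gives every task the average gradient direction; the import path is assumed from the source note below:

import torch

from vambn.modelling.mtl.mtl import MOOMethod


class MeanGradient(MOOMethod):
    """Toy MOO method: every task receives the mean gradient."""

    requires_input: bool = False

    def forward(self, grads, inputs, outputs):
        # grads is K x D...; broadcast the task-average back to all K slots
        return grads.mean(dim=0, keepdim=True).expand_as(grads)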
forward(grads, inputs, outputs) abstractmethod

Computes the new task gradients based on the original ones.

Given K gradients of size D, returns a new set of K gradients of size D based on some criterion.

Parameters:

    grads (Tensor): Tensor of size K x D with the different gradients. Required.
    inputs (Tensor): Tensor with the input of the forward pass (if requires_input is set to True). Required.
    outputs (Tensor): Tensor with the K outputs of the module (not used currently). Required.

Returns:

    Tensor: A tensor of the same size as `grads` with the new gradients to use during backpropagation.

Source code in vambn/modelling/mtl/mtl.py
@abstractmethod
def forward(
    self,
    grads: torch.Tensor,
    inputs: Optional[torch.Tensor],
    outputs: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Computes the new task gradients based on the original ones.

    Given K gradients of size D, returns a new set of K gradients of size D based on some criterion.

    Args:
        grads (torch.Tensor): Tensor of size K x D with the different gradients.
        inputs (torch.Tensor, optional): Tensor with the input of the forward pass (if requires_input is set to True).
        outputs (torch.Tensor, optional): Tensor with the K outputs of the module (not used currently).

    Returns:
        torch.Tensor: A tensor of the same size as `grads` with the new gradients to use during backpropagation.
    """
    raise NotImplementedError("You need to implement the forward pass.")

MinNormSolver

Solver for finding the minimum norm solution in the convex hull of vectors.

Source code in vambn/modelling/mtl/mtl.py
class MinNormSolver:
    """Solver for finding the minimum norm solution in the convex hull of vectors."""

    MAX_ITER = 250
    STOP_CRIT = 1e-5

    @staticmethod
    def _min_norm_element_from2(v1v1: float, v1v2: float, v2v2: float) -> tuple:
        """
        Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2.

        Args:
            v1v1: <x1, x1>.
            v1v2: <x1, x2>.
            v2v2: <x2, x2>.

        Returns:
            tuple: Coefficients and cost for the minimum norm element.
        """
        if v1v2 >= v1v1:
            gamma = 0.999
            cost = v1v1
            return gamma, cost
        if v1v2 >= v2v2:
            gamma = 0.001
            cost = v2v2
            return gamma, cost
        gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2))
        cost = v2v2 + gamma * (v1v2 - v2v2)
        return gamma, cost

    @staticmethod
    def _min_norm_2d(vecs: list, dps: dict) -> tuple:
        """
        Find the minimum norm solution as a combination of two points in 2D.

        Args:
            vecs: List of vectors.
            dps: Dictionary to store dot products.

        Returns:
            tuple: Solution and updated dot products.
        """
        dmin = float("inf")
        for i in range(len(vecs)):
            for j in range(i + 1, len(vecs)):
                if (i, j) not in dps:
                    dps[(i, j)] = sum(
                        torch.dot(vecs[i][k], vecs[j][k]).item()
                        for k in range(len(vecs[i]))
                    )
                    dps[(j, i)] = dps[(i, j)]
                if (i, i) not in dps:
                    dps[(i, i)] = sum(
                        torch.dot(vecs[i][k], vecs[i][k]).item()
                        for k in range(len(vecs[i]))
                    )
                if (j, j) not in dps:
                    dps[(j, j)] = sum(
                        torch.dot(vecs[j][k], vecs[j][k]).item()
                        for k in range(len(vecs[i]))
                    )
                c, d = MinNormSolver._min_norm_element_from2(
                    dps[(i, i)], dps[(i, j)], dps[(j, j)]
                )
                if d < dmin:
                    dmin = d
                    sol = [(i, j), c, d]
        return sol, dps

    @staticmethod
    def _projection2simplex(y: np.ndarray) -> np.ndarray:
        """
        Project y onto the simplex.

        Args:
            y: Input array.

        Returns:
            Projected array.
        """
        m = len(y)
        sorted_y = np.flip(np.sort(y), axis=0)
        tmpsum = 0.0
        tmax_f = (np.sum(y) - 1.0) / m
        for i in range(m - 1):
            tmpsum += sorted_y[i]
            tmax = (tmpsum - 1) / (i + 1.0)
            if tmax > sorted_y[i + 1]:
                tmax_f = tmax
                break
        return np.maximum(y - tmax_f, np.zeros(y.shape))

    @staticmethod
    def _next_point(cur_val: np.ndarray, grad: np.ndarray, n: int) -> np.ndarray:
        """
        Compute the next point for the projected gradient descent.

        Args:
            cur_val: Current value.
            grad: Gradient.
            n: Dimension of the problem.

        Returns:
            Next point.
        """
        proj_grad = grad - (np.sum(grad) / n)
        tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0]
        tm2 = (1.0 - cur_val[proj_grad > 0]) / proj_grad[proj_grad > 0]

        t = 1
        if len(tm1[tm1 > 1e-7]) > 0:
            t = np.min(tm1[tm1 > 1e-7])
        if len(tm2[tm2 > 1e-7]) > 0:
            t = min(t, np.min(tm2[tm2 > 1e-7]))

        next_point = proj_grad * t + cur_val
        next_point = MinNormSolver._projection2simplex(next_point)
        return next_point

    @staticmethod
    def find_min_norm_element(vecs: List) -> Tuple | None:
        """
        Find the minimum norm element in the convex hull of vectors.

        Args:
            vecs: List of vectors.

        Returns:
            Minimum norm element and its cost.
        """
        dps = {}
        init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)

        n = len(vecs)
        sol_vec = np.zeros(n)
        sol_vec[init_sol[0][0]] = init_sol[1]
        sol_vec[init_sol[0][1]] = 1 - init_sol[1]

        if n < 3:
            return sol_vec, init_sol[2]

        iter_count = 0

        grad_mat = np.zeros((n, n))
        for i in range(n):
            for j in range(n):
                grad_mat[i, j] = dps[(i, j)]

        while iter_count < MinNormSolver.MAX_ITER:
            grad_dir = -1.0 * np.dot(grad_mat, sol_vec)
            new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
            v1v1 = sum(
                sol_vec[i] * sol_vec[j] * dps[(i, j)]
                for i in range(n)
                for j in range(n)
            )
            v1v2 = sum(
                sol_vec[i] * new_point[j] * dps[(i, j)]
                for i in range(n)
                for j in range(n)
            )
            v2v2 = sum(
                new_point[i] * new_point[j] * dps[(i, j)]
                for i in range(n)
                for j in range(n)
            )
            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
            new_sol_vec = nc * sol_vec + (1 - nc) * new_point
            change = new_sol_vec - sol_vec
            if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
                return sol_vec, nd
            sol_vec = new_sol_vec
find_min_norm_element(vecs) staticmethod

Find the minimum norm element in the convex hull of vectors.

Parameters:
    vecs (List, required): List of vectors.

Returns:
    Tuple | None: Minimum norm element and its cost.

Source code in vambn/modelling/mtl/mtl.py
@staticmethod
def find_min_norm_element(vecs: List) -> Tuple | None:
    """
    Find the minimum norm element in the convex hull of vectors.

    Args:
        vecs: List of vectors.

    Returns:
        Minimum norm element and its cost.
    """
    dps = {}
    init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)

    n = len(vecs)
    sol_vec = np.zeros(n)
    sol_vec[init_sol[0][0]] = init_sol[1]
    sol_vec[init_sol[0][1]] = 1 - init_sol[1]

    if n < 3:
        return sol_vec, init_sol[2]

    iter_count = 0

    grad_mat = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            grad_mat[i, j] = dps[(i, j)]

    while iter_count < MinNormSolver.MAX_ITER:
        grad_dir = -1.0 * np.dot(grad_mat, sol_vec)
        new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
        v1v1 = sum(
            sol_vec[i] * sol_vec[j] * dps[(i, j)]
            for i in range(n)
            for j in range(n)
        )
        v1v2 = sum(
            sol_vec[i] * new_point[j] * dps[(i, j)]
            for i in range(n)
            for j in range(n)
        )
        v2v2 = sum(
            new_point[i] * new_point[j] * dps[(i, j)]
            for i in range(n)
            for j in range(n)
        )
        nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
        new_sol_vec = nc * sol_vec + (1 - nc) * new_point
        change = new_sol_vec - sol_vec
        if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
            return sol_vec, nd
        sol_vec = new_sol_vec
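
For illustration, the toy sketch below runs find_min_norm_element on three small task gradients; each entry of vecs is a list of flattened per-parameter gradient tensors for one task. The vectors are made up for this example and are not part of the library.

import torch

from vambn.modelling.mtl.mtl import MinNormSolver

# Each task contributes a list of flattened per-parameter gradients.
task_a = [torch.tensor([1.0, 0.0])]
task_b = [torch.tensor([0.0, 1.0])]
task_c = [torch.tensor([1.0, 1.0])]

weights, cost = MinNormSolver.find_min_norm_element([task_a, task_b, task_c])
print(weights)  # convex weights over the tasks (they sum to 1)
print(cost)     # squared norm of the resulting gradient combination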

MtlMethods

Bases: Enum

Enumeration of available multi-task learning methods.

Source code in vambn/modelling/mtl/mtl.py
class MtlMethods(Enum):
    """Enumeration of available multi-task learning methods."""

    imtlg = IMTLG
    nsgd = NSGD
    gradnorm = GradNormModified
    pcgrad = PCGrad
    mgda_ub = MGDAUB
    identity = Identity
    cagrad = CAGrad
    graddrop = GradDrop
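
A small sketch of resolving a method class by name through the enum; the chosen method name and constructor arguments below are illustrative only.

from vambn.modelling.mtl.mtl import MtlMethods

method_cls = MtlMethods["nsgd"].value        # -> the NSGD class
moo_method = method_cls(num_tasks=3, update_at=20)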

NSGD

Bases: MOOMethod

Normalized Stochastic Gradient Descent (NSGD) method for MOO.

Source code in vambn/modelling/mtl/mtl.py
class NSGD(MOOMethod):
    """Normalized Stochastic Gradient Descent (NSGD) method for MOO."""

    initial_grads: torch.Tensor
    requires_input: bool = False

    def __init__(self, num_tasks: int, update_at: int = 20):
        """
        Initialize NSGD method.

        Args:
            num_tasks (int): Number of tasks.
            update_at (int): Update interval.
        """
        super().__init__()
        self.num_tasks = num_tasks
        self.update_at = update_at
        self.register_buffer("initial_grads", torch.ones(num_tasks))
        self.counter = 0

    def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
        """
        Compute new gradients using NSGD method.

        Args:
            grads (torch.Tensor): Gradients tensor.
            inputs (torch.Tensor): Input tensor.
            outputs (torch.Tensor): Output tensor.

        Returns:
            torch.Tensor: New gradients tensor.
        """
        grad_norms = grads.flatten(start_dim=1).norm(dim=1)

        if self.initial_grads is None or self.counter == self.update_at:
            self.initial_grads = grad_norms

        self.counter += 1

        conv_ratios = grad_norms / self.initial_grads.clamp_min(1e-15)
        alphas = conv_ratios / conv_ratios.sum().clamp_min(1e-15)
        alphas = alphas / alphas.sum()

        weighted_sum_norms = (alphas * grad_norms).sum()
        grads = batch_product(
            grads, weighted_sum_norms / grad_norms.clamp_min(1e-15)
        )
        return grads
__init__(num_tasks, update_at=20)

Initialize NSGD method.

Parameters:
    num_tasks (int, required): Number of tasks.
    update_at (int, default 20): Update interval.
Source code in vambn/modelling/mtl/mtl.py
def __init__(self, num_tasks: int, update_at: int = 20):
    """
    Initialize NSGD method.

    Args:
        num_tasks (int): Number of tasks.
        update_at (int): Update interval.
    """
    super().__init__()
    self.num_tasks = num_tasks
    self.update_at = update_at
    self.register_buffer("initial_grads", torch.ones(num_tasks))
    self.counter = 0
forward(grads, inputs, outputs)

Compute new gradients using NSGD method.

Parameters:
    grads (Tensor, required): Gradients tensor.
    inputs (Tensor, required): Input tensor.
    outputs (Tensor, required): Output tensor.

Returns:
    Tensor: New gradients tensor.

Source code in vambn/modelling/mtl/mtl.py
def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
    """
    Compute new gradients using NSGD method.

    Args:
        grads (torch.Tensor): Gradients tensor.
        inputs (torch.Tensor): Input tensor.
        outputs (torch.Tensor): Output tensor.

    Returns:
        torch.Tensor: New gradients tensor.
    """
    grad_norms = grads.flatten(start_dim=1).norm(dim=1)

    if self.initial_grads is None or self.counter == self.update_at:
        self.initial_grads = grad_norms

    self.counter += 1

    conv_ratios = grad_norms / self.initial_grads.clamp_min(1e-15)
    alphas = conv_ratios / conv_ratios.sum().clamp_min(1e-15)
    alphas = alphas / alphas.sum()

    weighted_sum_norms = (alphas * grad_norms).sum()
    grads = batch_product(
        grads, weighted_sum_norms / grad_norms.clamp_min(1e-15)
    )
    return grads
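
As a rough usage sketch (the shapes below are hypothetical), NSGD can be applied directly to a stack of per-task gradients:

import torch

from vambn.modelling.mtl.mtl import NSGD

nsgd = NSGD(num_tasks=3, update_at=20)
grads = torch.randn(3, 128)            # K=3 tasks, D=128 parameters
scaled = nsgd(grads, None, None)       # rescaled gradients, same shape
print(scaled.shape)                    # torch.Size([3, 128])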

PCGrad

Bases: MOOMethod

Projected Conflicting Gradient (PCGrad) method for MOO.

Attributes:
    requires_input (bool): Indicates whether the method requires input tensor.

Source code in vambn/modelling/mtl/mtl.py
class PCGrad(MOOMethod):
    """Projected Conflicting Gradient (PCGrad) method for MOO.

    Attributes:
        requires_input (bool): Indicates whether the method requires input tensor.
    """

    requires_input: bool = False

    def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
        """
        Compute new gradients using PCGrad method.

        Args:
            grads (torch.Tensor): Gradients tensor.
            inputs (torch.Tensor): Input tensor.
            outputs (torch.Tensor): Output tensor.

        Returns:
            torch.Tensor: New gradients tensor.
        """
        size = grads.size()[1:]
        num_tasks = grads.size(0)
        grads_list = [g.flatten() for g in grads]

        new_grads = [None for _ in range(num_tasks)]
        for i in np.random.permutation(num_tasks):
            grad_i = grads_list[i]
            for j in np.random.permutation(num_tasks):
                if i == j:
                    continue
                grad_j = grads_list[j]
                if torch.cosine_similarity(grad_i, grad_j, dim=0) < 0:
                    grad_i = grad_i - projection(grad_i, grad_j)
                    assert id(grads_list[i]) != id(grad_i), "Aliasing!"
            new_grads[i] = grad_i.reshape(size)

        return torch.stack(new_grads, dim=0)
forward(grads, inputs, outputs)

Compute new gradients using PCGrad method.

Parameters:
    grads (Tensor, required): Gradients tensor.
    inputs (Tensor, required): Input tensor.
    outputs (Tensor, required): Output tensor.

Returns:
    Tensor: New gradients tensor.

Source code in vambn/modelling/mtl/mtl.py
def forward(self, grads: torch.Tensor, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
    """
    Compute new gradients using PCGrad method.

    Args:
        grads (torch.Tensor): Gradients tensor.
        inputs (torch.Tensor): Input tensor.
        outputs (torch.Tensor): Output tensor.

    Returns:
        torch.Tensor: New gradients tensor.
    """
    size = grads.size()[1:]
    num_tasks = grads.size(0)
    grads_list = [g.flatten() for g in grads]

    new_grads = [None for _ in range(num_tasks)]
    for i in np.random.permutation(num_tasks):
        grad_i = grads_list[i]
        for j in np.random.permutation(num_tasks):
            if i == j:
                continue
            grad_j = grads_list[j]
            if torch.cosine_similarity(grad_i, grad_j, dim=0) < 0:
                grad_i = grad_i - projection(grad_i, grad_j)
                assert id(grads_list[i]) != id(grad_i), "Aliasing!"
        new_grads[i] = grad_i.reshape(size)

    return torch.stack(new_grads, dim=0)
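
The toy example below shows PCGrad removing the conflicting component between two gradients; the input vectors are made up for illustration.

import torch

from vambn.modelling.mtl.mtl import PCGrad

grads = torch.stack([
    torch.tensor([1.0, 0.0]),          # task 1
    torch.tensor([-1.0, 1.0]),         # task 2, conflicts with task 1
])
projected = PCGrad()(grads, None, None)
print(projected)                       # conflicting components are projected out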

divide(numer, denom)

Numerically stable division.

Parameters:
    numer (Tensor, required): Numerator tensor.
    denom (Tensor, required): Denominator tensor.

Returns:
    torch.Tensor: Result of numerically stable division.

Source code in vambn/modelling/mtl/mtl.py
def divide(numer, denom):
    """
    Numerically stable division.

    Args:
        numer (torch.Tensor): Numerator tensor.
        denom (torch.Tensor): Denominator tensor.

    Returns:
        torch.Tensor: Result of numerically stable division.
    """
    epsilon = 1e-15
    return (
        torch.sign(numer)
        * torch.sign(denom)
        * torch.exp(
            torch.log(numer.abs() + epsilon) - torch.log(denom.abs() + epsilon)
        )
    )
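
A brief sketch of the helper's behaviour on hand-picked values (the inputs are illustrative only):

import torch

from vambn.modelling.mtl.mtl import divide

numer = torch.tensor([1.0, 2.0, 3.0])
denom = torch.tensor([2.0, 4.0, 0.0])
print(divide(numer, denom))  # approx. [0.5, 0.5, 0.0]; sign(0) makes the last entry 0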

gradient_normalizers(grads, losses, normalization_type)

Compute gradient normalizers based on the specified normalization type.

Parameters:
    grads (dict, required): A dictionary of gradients.
    losses (dict, required): A dictionary of losses.
    normalization_type (str, required): The type of normalization ('l2', 'loss', 'loss+', 'none').

Returns:
    dict: A dictionary of gradient normalizers.

Source code in vambn/modelling/mtl/mtl.py
def gradient_normalizers(grads: dict, losses: dict, normalization_type: str) -> dict:
    """
    Compute gradient normalizers based on the specified normalization type.

    Args:
        grads: A dictionary of gradients.
        losses: A dictionary of losses.
        normalization_type: The type of normalization ('l2', 'loss', 'loss+', 'none').

    Returns:
        A dictionary of gradient normalizers.
    """
    gn = {}
    if normalization_type == "l2":
        for t in grads:
            gn[t] = np.sqrt(np.sum([gr.pow(2).sum().item() for gr in grads[t]]))
    elif normalization_type == "loss":
        for t in grads:
            gn[t] = losses[t]
    elif normalization_type == "loss+":
        for t in grads:
            gn[t] = losses[t] * np.sqrt(
                np.sum([gr.pow(2).sum().item() for gr in grads[t]])
            )
    elif normalization_type == "none":
        for t in grads:
            gn[t] = 1.0
    else:
        print("ERROR: Invalid Normalization Type")
    return gn
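
Illustrative call with toy dictionaries (the task names, gradient values, and normalization choices are made up):

import torch

from vambn.modelling.mtl.mtl import gradient_normalizers

grads = {"t1": [torch.tensor([3.0, 4.0])], "t2": [torch.tensor([1.0, 0.0])]}
losses = {"t1": 0.7, "t2": 0.2}

print(gradient_normalizers(grads, losses, "l2"))    # {'t1': 5.0, 't2': 1.0}
print(gradient_normalizers(grads, losses, "loss"))  # {'t1': 0.7, 't2': 0.2}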

norm(tensor)

Compute the L2 norm of a tensor along the last dimension.

Parameters:
    tensor (Tensor, required): Input tensor.

Returns:
    torch.Tensor: L2 norm of the input tensor.

Source code in vambn/modelling/mtl/mtl.py
def norm(tensor):
    """
    Compute the L2 norm of a tensor along the last dimension.

    Args:
        tensor (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: L2 norm of the input tensor.
    """
    return tensor.norm(p=2, dim=-1, keepdim=True)

projection(u, v)

Project vector u onto vector v.

Parameters:
    u (Tensor, required): Vector to be projected.
    v (Tensor, required): Vector onto which u is projected.

Returns:
    torch.Tensor: Projection of u onto v.

Source code in vambn/modelling/mtl/mtl.py
def projection(u, v):
    """
    Project vector u onto vector v.

    Args:
        u (torch.Tensor): Vector to be projected.
        v (torch.Tensor): Vector onto which u is projected.

    Returns:
        torch.Tensor: Projection of u onto v.
    """
    numer = torch.dot(u, v)
    denom = torch.dot(v, v)

    return numer / denom.clamp_min(1e-15) * v
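
A one-line illustration of the projection used by PCGrad (the vectors are arbitrary):

import torch

from vambn.modelling.mtl.mtl import projection

u = torch.tensor([1.0, 1.0])
v = torch.tensor([1.0, 0.0])
print(projection(u, v))  # tensor([1., 0.]): the component of u along v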

unitary(tensor)

Normalize the tensor to unit norm.

Parameters:
    tensor (Tensor, required): Input tensor.

Returns:
    torch.Tensor: Unitary (normalized) tensor.

Source code in vambn/modelling/mtl/mtl.py
def unitary(tensor):
    """
    Normalize the tensor to unit norm.

    Args:
        tensor (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: Unitary (normalized) tensor.
    """
    return divide(tensor, norm(tensor) + 1e-15)
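
A short sketch of norm and unitary on a batch of row vectors (input values are illustrative):

import torch

from vambn.modelling.mtl.mtl import norm, unitary

x = torch.tensor([[3.0, 4.0], [0.0, 2.0]])
print(norm(x))     # tensor([[5.], [2.]])
print(unitary(x))  # rows rescaled to approximately unit L2 norm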

parameters

MtlMethodParams dataclass

Params and method description for multi-task learning.

Attributes:
    name (str): Name of the MTL method.
    update_at (Optional[int]): Update interval, specific to certain methods.
    alpha (Optional[float]): Alpha parameter, specific to certain methods.

Source code in vambn/modelling/mtl/parameters.py
@dataclass
class MtlMethodParams:
    """
    Params and method description for multi-task learning.

    Attributes:
        name (str): Name of the MTL method.
        update_at (Optional[int]): Update interval, specific to certain methods.
        alpha (Optional[float]): Alpha parameter, specific to certain methods.
    """

    name: str
    update_at: Optional[int] = None
    alpha: Optional[float] = None

    def __post_init__(self):
        """
        Post-initialization to set default values for specific methods.
        """
        if self.name == "nsgd":
            if self.update_at is None:
                self.update_at = 1
        elif self.name == "gradnorm":
            if self.update_at is None:
                self.update_at = 1
            if self.alpha is None:
                self.alpha = 1.0
        elif self.name == "pcgrad":
            if self.update_at is None:
                self.update_at = 1
        elif self.name == "cagrad":
            if self.alpha is None:
                self.alpha = 10
__post_init__()

Post-initialization to set default values for specific methods.

Source code in vambn/modelling/mtl/parameters.py
def __post_init__(self):
    """
    Post-initialization to set default values for specific methods.
    """
    if self.name == "nsgd":
        if self.update_at is None:
            self.update_at = 1
    elif self.name == "gradnorm":
        if self.update_at is None:
            self.update_at = 1
        if self.alpha is None:
            self.alpha = 1.0
    elif self.name == "pcgrad":
        if self.update_at is None:
            self.update_at = 1
    elif self.name == "cagrad":
        if self.alpha is None:
            self.alpha = 10
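
For example, constructing parameters for a method fills in its defaults; the method names follow the __post_init__ branches above, and the alpha override is illustrative.

from vambn.modelling.mtl.parameters import MtlMethodParams

params = MtlMethodParams(name="gradnorm")
print(params.update_at, params.alpha)    # 1 1.0

params = MtlMethodParams(name="cagrad", alpha=0.5)
print(params.alpha)                      # 0.5 (explicit values are kept)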

utils

This script includes code adapted from the 'impartial-vaes' repository with minor modifications. The original code can be found at: https://github.com/adrianjav/impartial-vaes

Credit to the original authors: Adrian Javaloy, Maryam Meghdadi, and Isabel Valera for their valuable work.

batch_product(batch, weight)

Multiplies each slice of the first dimension of batch by the corresponding scalar in the weight vector.

Parameters:
    batch (Tensor, required): Tensor of size [B, ...].
    weight (Tensor, required): Tensor of size [B].

Returns:
    torch.Tensor: A tensor such that result[i] = batch[i] * weight[i].

Source code in vambn/modelling/mtl/utils.py
14
15
16
17
18
19
20
21
22
23
24
25
26
def batch_product(batch: torch.Tensor, weight: torch.Tensor):
    r"""
    Multiplies each slice of the first dimension of batch by the corresponding scalar in the weight vector.

    Args:
        batch (torch.Tensor): Tensor of size [B, ...].
        weight (torch.Tensor): Tensor of size [B].

    Returns:
        torch.Tensor: A tensor such that `result[i] = batch[i] * weight[i]`.
    """
    assert batch.size(0) == weight.size(0)
    return (batch.T * weight.T).T
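
A minimal sketch of batch_product weighting each task gradient by a scalar (the shapes are illustrative):

import torch

from vambn.modelling.mtl.utils import batch_product

batch = torch.ones(3, 4)                 # three task gradients of size 4
weight = torch.tensor([1.0, 2.0, 3.0])   # one scalar weight per task
print(batch_product(batch, weight))      # rows scaled by 1, 2 and 3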

run_model