
HIVAE

config

HivaeConfig dataclass

Bases: ModelConfig

Configuration class for HIVAE models.

Attributes:

    name (str): Name of the configuration. Typically the module name.
    variable_types (VarTypes): Definition of the variable types. See VarTypes.
    dim_s (int): Dimension of the latent space S.
    dim_y (int): Dimension of the Y space.
    dim_z (int): Dimension of the latent space Z.
    mtl_methods (Tuple[str, ...]): Methods for multi-task learning. Tested possibilities are combinations of "identity", "gradnorm", and "graddrop". Further implementations and details can be found in mtl.py.
    use_imputation_layer (bool): Flag to use the imputation layer.
    dropout (float): Dropout rate.
    n_layers (Optional[int]): Number of layers. Needed for longitudinal data.
    num_timepoints (int): Number of timepoints for longitudinal data. Defaults to 1.

Source code in vambn/modelling/models/hivae/config.py
@dataclass
class HivaeConfig(ModelConfig):
    """Configuration class for HIVAE models.

    Attributes:
        name (str): Name of the configuration. Typically the module name.
        variable_types (VarTypes): Definition of the variable types. See VarTypes.
        dim_s (int): Dimension of the latent space S.
        dim_y (int): Dimension of the Y space.
        dim_z (int): Dimension of the latent space Z.
        mtl_methods (Tuple[str, ...]): Methods for multi-task learning. Tested 
            possibilities are combinations of "identity", "gradnorm", "graddrop". 
            Further implementations and details can be found in the mtl.py file.
        use_imputation_layer (bool): Flag to use imputation layer.
        dropout (float): Dropout rate.
        n_layers (Optional[int]): Number of layers. Needed for longitudinal data.
        num_timepoints (int): Number of timepoints for longitudinal data. Default is 1.
    """
    name: str
    variable_types: VarTypes
    dim_s: int
    dim_y: int
    dim_z: int
    mtl_methods: Tuple[str, ...]
    use_imputation_layer: bool
    dropout: float

    # Only needed for longitudinal data
    n_layers: Optional[int] = None
    num_timepoints: int = 1

    def __post_init__(self):
        """Converts mtl_methods to a tuple if it is a list."""
        if isinstance(self.mtl_methods, List):
            self.mtl_methods = tuple(self.mtl_methods)

    @cached_property
    def input_dim(self):
        """Gets the input dimension based on variable types.

        Returns:
            int: The input dimension.
        """
        return get_input_dim(self.variable_types)

    @cached_property
    def is_longitudinal(self) -> bool:
        """Checks if the data is longitudinal.

        Returns:
            bool: True if the data has more than one timepoint, False otherwise.
        """
        return self.num_timepoints > 1
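
For orientation, a minimal construction sketch. It assumes an existing VarTypes instance (here called my_var_types, built elsewhere from the data specification) and that ModelConfig adds no further required fields:

from vambn.modelling.models.hivae.config import HivaeConfig

cfg = HivaeConfig(
    name="visit1_module",          # hypothetical module name
    variable_types=my_var_types,   # assumed VarTypes instance
    dim_s=4,
    dim_y=8,
    dim_z=8,
    mtl_methods=["gradnorm"],      # a list is accepted; __post_init__ converts it to a tuple
    use_imputation_layer=True,
    dropout=0.1,
    n_layers=1,                    # needed for longitudinal data
    num_timepoints=3,              # > 1, so cfg.is_longitudinal is True
)
assert cfg.mtl_methods == ("gradnorm",)
assert cfg.is_longitudinal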

input_dim cached property

Gets the input dimension based on variable types.

Returns:

    int: The input dimension.

is_longitudinal: bool cached property

Checks if the data is longitudinal.

Returns:

    bool: True if the data has more than one timepoint, False otherwise.

__post_init__()

Converts mtl_methods to a tuple if it is a list.

Source code in vambn/modelling/models/hivae/config.py
def __post_init__(self):
    """Converts mtl_methods to a tuple if it is a list."""
    if isinstance(self.mtl_methods, List):
        self.mtl_methods = tuple(self.mtl_methods)

ModularHivaeConfig dataclass

Bases: ModelConfig

Configuration class for Modular HIVAE models.

Attributes:

    module_config (Tuple[DataModuleConfig]): Configuration for the data modules. See DataModuleConfig.
    dim_s (int | Tuple[int, ...] | Dict[str, int]): Dimension of the latent space S.
    dim_z (int): Dimension of the latent space Z.
    dim_ys (int): Dimension of the YS space.
    dim_y (int | Tuple[int, ...] | Dict[str, int]): Dimension of the Y space.
    mtl_method (Tuple[str, ...]): Methods for multi-task learning. Tested possibilities are combinations of "identity", "gradnorm", and "graddrop". Further implementations and details can be found in mtl.py.
    use_imputation_layer (bool): Flag to use the imputation layer.
    dropout (float): Dropout rate.
    n_layers (int): Number of layers.
    shared_element (str): Shared element type. Possible values are "none", "sharedLinear", "concatMtl", "concatIndiv", "avgMtl", "maxMtl", "encoder", "encoderMtl". Defaults to "none".

Source code in vambn/modelling/models/hivae/config.py
@dataclass
class ModularHivaeConfig(ModelConfig):
    """Configuration class for Modular HIVAE models.

    Attributes:
        module_config (Tuple[DataModuleConfig]): Configuration for the data modules. See DataModuleConfig.
        dim_s (int | Tuple[int, ...] | Dict[str, int]): Dimension of the latent space S.
        dim_z (int): Dimension of the latent space Z.
        dim_ys (int): Dimension of the YS space.
        dim_y (int | Tuple[int, ...] | Dict[str, int]): Dimension of the Y space.         
        mtl_method (Tuple[str, ...]): Methods for multi-task learning. Tested
            possibilities are combinations of "identity", "gradnorm", "graddrop".
            Further implementations and details can be found in the mtl.py file.
        use_imputation_layer (bool): Flag to use imputation layer.
        dropout (float): Dropout rate.
        n_layers (int): Number of layers.
        shared_element (str): Shared element type. Possible values are "none", 
            "sharedLinear", "concatMtl", "concatIndiv", "avgMtl", "maxMtl", "encoder", "encoderMtl".
            Default is "none".
    """    
    module_config: Tuple[DataModuleConfig]
    dim_s: int | Tuple[int, ...] | Dict[str, int]
    dim_z: int
    dim_ys: int
    dim_y: int | Tuple[int, ...] | Dict[str, int]
    mtl_method: Tuple[str, ...]
    use_imputation_layer: bool
    dropout: float
    n_layers: int
    shared_element: str = "none"

    def __post_init__(self):
        """Validates and sets up the configuration after initialization.

        Raises:
            Exception: If the number of layers is less than 1.
        """
        if self.n_layers < 1:
            raise Exception("Number of layers must be at least 1")

        for module in self.module_config:
            module.n_layers = self.n_layers
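
A minimal sketch of how __post_init__ propagates n_layers to every module configuration. It assumes module_cfgs is a tuple of DataModuleConfig instances built elsewhere and that ModelConfig adds no further required fields:

from vambn.modelling.models.hivae.config import ModularHivaeConfig

cfg = ModularHivaeConfig(
    module_config=module_cfgs,     # assumed tuple of DataModuleConfig instances
    dim_s=4,                       # may also be a tuple or a dict keyed by module name
    dim_z=8,
    dim_ys=16,
    dim_y=8,
    mtl_method=("gradnorm",),
    use_imputation_layer=True,
    dropout=0.1,
    n_layers=2,
    shared_element="sharedLinear",
)
# __post_init__ has copied n_layers into each module configuration:
assert all(m.n_layers == cfg.n_layers for m in cfg.module_config)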

__post_init__()

Validates and sets up the configuration after initialization.

Raises:

    Exception: If the number of layers is less than 1.

Source code in vambn/modelling/models/hivae/config.py
def __post_init__(self):
    """Validates and sets up the configuration after initialization.

    Raises:
        Exception: If the number of layers is less than 1.
    """
    if self.n_layers < 1:
        raise Exception("Number of layers must be at least 1")

    for module in self.module_config:
        module.n_layers = self.n_layers

decoder

Decoder

Bases: Module

HIVAE Decoder class.

Parameters:

    variable_types (VarTypes): List of VariableType objects. See VarTypes in data/dataclasses.py. Required.
    s_dim (int): Dimension of s space. Required.
    z_dim (int): Dimension of z space. Required.
    y_dim (int): Dimension of y space. Required.
    mtl_method (Tuple[str, ...]): Methods to use for multi-task learning. Assessed possibilities are combinations of "identity", "gradnorm", "graddrop". Further implementations and details can be found in mtl.py. Defaults to ("identity",).
    decoder_shared (nn.Module): Shared decoder module. Defaults to nn.Identity().

Attributes:

    prior_s_val (nn.Parameter): Prior distribution for s values.
    prior_loc_z (ModifiedLinear): Linear layer for the z prior distribution.
    s_dim (int): Dimension of s space.
    z_dim (int): Dimension of z space.
    y_dim (int): Dimension of y space.
    variable_types (VarTypes): List of variable types.
    decoder_shared (nn.Module): Shared decoder module.
    internal_layer_norm (nn.Module): Layer normalization module.
    heads (nn.ModuleList): List of head modules, one per variable type.
    mtl_methods (Tuple[str, ...]): Methods for multi-task learning.
    _mtl_module_y (nn.Module): Multi-task learning module for y.
    _mtl_module_s (nn.Module): Multi-task learning module for s.
    moo_block (moo.MultiMOOForLoop): Multi-task learning block.
    _decoding (bool): Decoding flag.

Source code in vambn/modelling/models/hivae/decoder.py
class Decoder(nn.Module):
    """HIVAE Decoder class.

    Args:
        variable_types (VarTypes): List of VariableType objects. See VarTypes in
            data/dataclasses.py.
        s_dim (int): Dimension of s space.
        z_dim (int): Dimension of z space.
        y_dim (int): Dimension of y space.
        mtl_method (Tuple[str, ...], optional): List of methods to use for multi-task learning.
            Assessed possibilities are combinations of "identity", "gradnorm", "graddrop".
            Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).
        decoder_shared (nn.Module, optional): Shared decoder module. Defaults to nn.Identity().

    Attributes:
        prior_s_val (nn.Parameter): Prior distribution for s values.
        prior_loc_z (ModifiedLinear): Linear layer for z prior distribution.
        s_dim (int): Dimension of s space.
        z_dim (int): Dimension of z space.
        y_dim (int): Dimension of y space.
        variable_types (VarTypes): List of variable types.
        decoder_shared (nn.Module): Shared decoder module.
        internal_layer_norm (nn.Module): Layer normalization module.
        heads (nn.ModuleList): List of head modules for each variable type.
        mtl_methods (Tuple[str, ...]): Methods for multi-task learning.
        _mtl_module_y (nn.Module): Multi-task learning module for y.
        _mtl_module_s (nn.Module): Multi-task learning module for s.
        moo_block (moo.MultiMOOForLoop): Multi-task learning block.
        _decoding (bool): Decoding flag.
    """

    def __init__(
        self,
        variable_types: VarTypes,
        s_dim: int,
        z_dim: int,
        y_dim: int,
        mtl_method: Tuple[str, ...] = ("identity",),
        decoder_shared: nn.Module = nn.Identity(),
    ):
        """
        Initialize the HIVAE Decoder.

        Args:
            variable_types (List[VariableType]): List of VariableType objects.
            s_dim (int): Dimension of s space.
            z_dim (int): Dimension of z space.
            y_dim (int): Dimension of y space.
            mtl_method (Tuple[str]): List of methods to use for multi-task learning.
        """
        super().__init__()

        # prior distributions
        self.prior_s_val = nn.Parameter(
            torch.ones(s_dim) / s_dim, requires_grad=False
        )
        self.prior_loc_z = ModifiedLinear(s_dim, z_dim, bias=True)

        self.s_dim = s_dim
        self.z_dim = z_dim
        self.y_dim = y_dim
        self.variable_types = variable_types

        self.decoder_shared = decoder_shared
        self.internal_layer_norm = nn.Identity()  # nn.LayerNorm(self.z_dim)
        self.heads: List[
            RealHead | PosHead | CountHead | CatHead
        ] | nn.ModuleList = nn.ModuleList(
            [
                HEAD_DICT[variable_types[i].data_type](
                    variable_types[i], s_dim, z_dim, y_dim
                )
                for i in range(len(variable_types))
            ]
        )

        self.mtl_methods = mtl_method
        self._mtl_module_y: nn.Module = moo.setup_moo(
            [MtlMethodParams(x) for x in mtl_method],
            num_tasks=len(self.heads),
        )
        self._mtl_module_s: nn.Module = moo.setup_moo(
            [MtlMethodParams(x) for x in mtl_method],
            num_tasks=len(self.heads),
        )
        self.moo_block = moo.MultiMOOForLoop(
            len(self.heads),
            moo_methods=(self._mtl_module_y, self._mtl_module_s),
        )

        self._decoding = True

    @property
    def decoding(self) -> bool:
        """bool: Flag indicating whether the model is in decoding mode."""
        return self._decoding

    @decoding.setter
    def decoding(self, value: bool) -> None:
        """Sets the decoding flag.

        Args:
            value (bool): Decoding flag.
        """
        self._decoding = value

    @cached_property
    def colnames(self) -> List[str]:
        """Gets the column names of the data.

        Returns:
            List[str]: List of column names.
        """
        return [var.name for var in self.variable_types]

    @property
    def prior_s(self) -> torch.distributions.OneHotCategorical:
        """Gets the prior distribution for s.

        Returns:
            torch.distributions.OneHotCategorical: Prior distribution for s.
        """
        return torch.distributions.OneHotCategorical(
            probs=self.prior_s_val, validate_args=False
        )

    def prior_z(self, loc: torch.Tensor) -> torch.distributions.Normal:
        """Gets the prior distribution for z.

        Args:
            loc (torch.Tensor): Location parameter for z.

        Returns:
            torch.distributions.Normal: Prior distribution for z.
        """
        return torch.distributions.Normal(loc, torch.ones_like(loc))

    def kl_s(self, encoder_output: EncoderOutput) -> torch.Tensor:
        """Computes the KL divergence for s.

        Args:
            encoder_output (EncoderOutput): Encoder output.

        Returns:
            torch.Tensor: KL divergence for s.
        """
        return torch.distributions.kl.kl_divergence(
            torch.distributions.OneHotCategorical(
                logits=encoder_output.logits_s, validate_args=False
            ),
            self.prior_s,
        )

    def kl_z(
        self,
        mean_qz: torch.Tensor,
        std_qz: torch.Tensor,
        mean_pz: torch.Tensor,
        std_pz: torch.Tensor,
    ) -> torch.Tensor:
        """Computes the KL divergence for z.

        Args:
            mean_qz (torch.Tensor): Mean of the posterior distribution.
            std_qz (torch.Tensor): Standard deviation of the posterior distribution.
            mean_pz (torch.Tensor): Mean of the prior distribution.
            std_pz (torch.Tensor): Standard deviation of the prior distribution.

        Returns:
            torch.Tensor: KL divergence for z.
        """
        return torch.distributions.kl.kl_divergence(
            torch.distributions.Normal(mean_qz, std_qz),
            torch.distributions.Normal(mean_pz, std_pz),
        ).sum(dim=-1)

    def _cat_samples(self, samples: List[torch.Tensor]) -> torch.Tensor:
        """Concatenates samples.

        Args:
            samples (List[torch.Tensor]): List of samples.

        Returns:
            torch.Tensor: Concatenated samples.

        Raises:
            ValueError: If no samples were drawn or if samples have an incorrect shape.
        """
        if len(samples) == 0 or all(x is None for x in samples):
            raise ValueError("No samples were drawn")
        else:
            sample_stack = torch.stack(samples, dim=1)
            if sample_stack.ndim == 2:
                return sample_stack
            elif sample_stack.ndim == 3 and sample_stack.shape[2] == 1:
                return sample_stack.squeeze(2)
            else:
                raise ValueError(
                    "Samples should be of shape (batch, features) or (batch, features, 1)"
                )

    def forward(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
        encoder_output: EncoderOutput,
        normalization_parameters: NormalizationParameters,
    ) -> HivaeOutput:
        """Forward pass of the decoder.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Mask for the data.
            encoder_output (EncoderOutput): Output from the encoder.
            normalization_parameters (NormalizationParameters): Parameters for normalization.

        Returns:
            HivaeOutput: Output from the decoder.
        """
        samples_z = encoder_output.samples_z
        decoder_representation = encoder_output.decoder_representation
        samples_s = encoder_output.samples_s

        # Obtaining the parameters for the decoder
        # decoder representation of shape batch x dim_z
        interim_decoder_representation = self.decoder_shared(
            decoder_representation
        )  # identity
        interim_decoder_representation = self.internal_layer_norm(
            interim_decoder_representation
        )  # identity
        if self.training:
            moo_out_s, moo_out_z = self.moo_block(
                samples_s, interim_decoder_representation
            )
            x_params = []
            for head, s_i, k_i in zip(self.heads, moo_out_s, moo_out_z):
                x_params.append(head(k_i, s_i))
            x_params = tuple(x_params)
        else:
            x_params = tuple(
                [
                    head(interim_decoder_representation, samples_s)
                    for head in self.heads
                ]
            )

        x_params = Normalization.denormalize_params(
            x_params, self.variable_types, normalization_parameters
        )

        # Compute the likelihood and kl divergences
        log_probs: List[torch.Tensor] = [torch.Tensor([-1])] * len(
            self.variable_types
        )
        samples: List[torch.Tensor] = [torch.Tensor([-1])] * len(
            self.variable_types
        )
        for i, (x_i, m_i, head_i, params_i) in enumerate(
            zip(data.T, mask.T, self.heads, x_params)
        ):
            head_i.dist(params_i)
            log_probs[i] = head_i.log_prob(x_i) * m_i

            if not self.training and self.decoding:
                # draw samples for evaluation and decoding
                samples[i] = head_i.sample()
            elif not self.training and not self.decoding:
                # draw samples for evaluation
                samples[i] = head_i.sample()

        # Stack the log likelihoods
        log_prob = torch.stack(log_probs, dim=1)  # batch, features
        log_prob = log_prob.sum(dim=1)  # / (mask.sum(dim=1) + 1e-6)  # batch
        cat_samples = self._cat_samples(samples)  # batch, features

        # Compute the KL divergences
        # KL divergence for s
        # samples_s of shape (batch, dim_s)
        kl_s = self.kl_s(encoder_output)  # shape (batch,)

        # KL divergence for z
        pz_loc = self.prior_loc_z(encoder_output.samples_s)  # batch, dim_z
        mean_pz, std_pz = pz_loc, torch.ones_like(pz_loc)
        mean_qz, std_qz = encoder_output.mean_z, encoder_output.scale_z
        kl_z = self.kl_z(mean_qz, std_qz, mean_pz, std_pz)  # shape (batch,)

        loss = -torch.sum(log_prob - kl_s - kl_z) / (
            (torch.sum(mask) / mask.shape[-1]) + 1e-6
        )
        print(f"Loss: {loss}, num_samples: {mask.sum()}")
        return HivaeOutput(
            enc_s=samples_s,
            enc_z=samples_z,
            samples=cat_samples,
            loss=loss,
        )

    @torch.no_grad()
    def decode(
        self,
        encoding_z: torch.Tensor,
        encoding_s: torch.Tensor,
        normalization_params: NormalizationParameters,
    ) -> torch.Tensor:
        """Decoding logic for the decoder.

        Args:
            encoding_z (torch.Tensor): Encoding for z.
            encoding_s (torch.Tensor): Encoding for s.
            normalization_params (NormalizationParameters): Parameters for normalization.

        Returns:
            torch.Tensor: Decoded samples.

        Raises:
            ValueError: If no samples were drawn.
        """
        # Implement the decoding logic here
        assert not self.training, "Model should be in eval mode"
        # shared_y of shape (batch, dim_y)
        decoder_interim_representation = self.decoder_shared(
            encoding_z
        )  # identity
        decoder_interim_representation = self.internal_layer_norm(
            decoder_interim_representation
        )  # identity

        if encoding_s.shape[1] != self.s_dim:
            encoding_s = F.one_hot(
                encoding_s.squeeze(1).long(), num_classes=self.s_dim
            ).float()  # batch, dim_s

        x_params = tuple(
            [
                head(decoder_interim_representation, encoding_s)
                for head in self.heads
            ]
        )
        x_params = Normalization.denormalize_params(
            x_params, self.variable_types, normalization_params
        )
        _ = [
            head.dist(
                params=params,
            )
            for i, (head, params) in enumerate(zip(self.heads, x_params))
        ]
        if self.decoding:
            samples = [head.sample() for head in self.heads]
        else:
            samples = [head.sample() for head in self.heads]
        cat_samples = self._cat_samples(samples)
        if cat_samples is not None:
            return cat_samples
        else:
            raise ValueError("No samples were drawn")

colnames: List[str] cached property

Gets the column names of the data.

Returns:

    List[str]: List of column names.

decoding: bool property writable

bool: Flag indicating whether the model is in decoding mode.

prior_s: torch.distributions.OneHotCategorical property

Gets the prior distribution for s.

Returns:

    torch.distributions.OneHotCategorical: Prior distribution for s.

__init__(variable_types, s_dim, z_dim, y_dim, mtl_method=('identity',), decoder_shared=nn.Identity())

Initialize the HIVAE Decoder.

Parameters:

    variable_types (List[VariableType]): List of VariableType objects. Required.
    s_dim (int): Dimension of s space. Required.
    z_dim (int): Dimension of z space. Required.
    y_dim (int): Dimension of y space. Required.
    mtl_method (Tuple[str, ...]): Methods to use for multi-task learning. Defaults to ("identity",).
Source code in vambn/modelling/models/hivae/decoder.py
def __init__(
    self,
    variable_types: VarTypes,
    s_dim: int,
    z_dim: int,
    y_dim: int,
    mtl_method: Tuple[str, ...] = ("identity",),
    decoder_shared: nn.Module = nn.Identity(),
):
    """
    Initialize the HIVAE Decoder.

    Args:
        variable_types (List[VariableType]): List of VariableType objects.
        s_dim (int): Dimension of s space.
        z_dim (int): Dimension of z space.
        y_dim (int): Dimension of y space.
        mtl_method (Tuple[str]): List of methods to use for multi-task learning.
    """
    super().__init__()

    # prior distributions
    self.prior_s_val = nn.Parameter(
        torch.ones(s_dim) / s_dim, requires_grad=False
    )
    self.prior_loc_z = ModifiedLinear(s_dim, z_dim, bias=True)

    self.s_dim = s_dim
    self.z_dim = z_dim
    self.y_dim = y_dim
    self.variable_types = variable_types

    self.decoder_shared = decoder_shared
    self.internal_layer_norm = nn.Identity()  # nn.LayerNorm(self.z_dim)
    self.heads: List[
        RealHead | PosHead | CountHead | CatHead
    ] | nn.ModuleList = nn.ModuleList(
        [
            HEAD_DICT[variable_types[i].data_type](
                variable_types[i], s_dim, z_dim, y_dim
            )
            for i in range(len(variable_types))
        ]
    )

    self.mtl_methods = mtl_method
    self._mtl_module_y: nn.Module = moo.setup_moo(
        [MtlMethodParams(x) for x in mtl_method],
        num_tasks=len(self.heads),
    )
    self._mtl_module_s: nn.Module = moo.setup_moo(
        [MtlMethodParams(x) for x in mtl_method],
        num_tasks=len(self.heads),
    )
    self.moo_block = moo.MultiMOOForLoop(
        len(self.heads),
        moo_methods=(self._mtl_module_y, self._mtl_module_s),
    )

    self._decoding = True

decode(encoding_z, encoding_s, normalization_params)

Decoding logic for the decoder.

Parameters:

    encoding_z (Tensor): Encoding for z. Required.
    encoding_s (Tensor): Encoding for s. Required.
    normalization_params (NormalizationParameters): Parameters for normalization. Required.

Returns:

    torch.Tensor: Decoded samples.

Raises:

    ValueError: If no samples were drawn.

Source code in vambn/modelling/models/hivae/decoder.py
@torch.no_grad()
def decode(
    self,
    encoding_z: torch.Tensor,
    encoding_s: torch.Tensor,
    normalization_params: NormalizationParameters,
) -> torch.Tensor:
    """Decoding logic for the decoder.

    Args:
        encoding_z (torch.Tensor): Encoding for z.
        encoding_s (torch.Tensor): Encoding for s.
        normalization_params (NormalizationParameters): Parameters for normalization.

    Returns:
        torch.Tensor: Decoded samples.

    Raises:
        ValueError: If no samples were drawn.
    """
    # Implement the decoding logic here
    assert not self.training, "Model should be in eval mode"
    # shared_y of shape (batch, dim_y)
    decoder_interim_representation = self.decoder_shared(
        encoding_z
    )  # identity
    decoder_interim_representation = self.internal_layer_norm(
        decoder_interim_representation
    )  # identity

    if encoding_s.shape[1] != self.s_dim:
        encoding_s = F.one_hot(
            encoding_s.squeeze(1).long(), num_classes=self.s_dim
        ).float()  # batch, dim_s

    x_params = tuple(
        [
            head(decoder_interim_representation, encoding_s)
            for head in self.heads
        ]
    )
    x_params = Normalization.denormalize_params(
        x_params, self.variable_types, normalization_params
    )
    _ = [
        head.dist(
            params=params,
        )
        for i, (head, params) in enumerate(zip(self.heads, x_params))
    ]
    if self.decoding:
        samples = [head.sample() for head in self.heads]
    else:
        samples = [head.sample() for head in self.heads]
    cat_samples = self._cat_samples(samples)
    if cat_samples is not None:
        return cat_samples
    else:
        raise ValueError("No samples were drawn")

forward(data, mask, encoder_output, normalization_parameters)

Forward pass of the decoder.

Parameters:

    data (Tensor): Input data. Required.
    mask (Tensor): Mask for the data. Required.
    encoder_output (EncoderOutput): Output from the encoder. Required.
    normalization_parameters (NormalizationParameters): Parameters for normalization. Required.

Returns:

    HivaeOutput: Output from the decoder.

Source code in vambn/modelling/models/hivae/decoder.py
def forward(
    self,
    data: torch.Tensor,
    mask: torch.Tensor,
    encoder_output: EncoderOutput,
    normalization_parameters: NormalizationParameters,
) -> HivaeOutput:
    """Forward pass of the decoder.

    Args:
        data (torch.Tensor): Input data.
        mask (torch.Tensor): Mask for the data.
        encoder_output (EncoderOutput): Output from the encoder.
        normalization_parameters (NormalizationParameters): Parameters for normalization.

    Returns:
        HivaeOutput: Output from the decoder.
    """
    samples_z = encoder_output.samples_z
    decoder_representation = encoder_output.decoder_representation
    samples_s = encoder_output.samples_s

    # Obtaining the parameters for the decoder
    # decoder representation of shape batch x dim_z
    interim_decoder_representation = self.decoder_shared(
        decoder_representation
    )  # identity
    interim_decoder_representation = self.internal_layer_norm(
        interim_decoder_representation
    )  # identity
    if self.training:
        moo_out_s, moo_out_z = self.moo_block(
            samples_s, interim_decoder_representation
        )
        x_params = []
        for head, s_i, k_i in zip(self.heads, moo_out_s, moo_out_z):
            x_params.append(head(k_i, s_i))
        x_params = tuple(x_params)
    else:
        x_params = tuple(
            [
                head(interim_decoder_representation, samples_s)
                for head in self.heads
            ]
        )

    x_params = Normalization.denormalize_params(
        x_params, self.variable_types, normalization_parameters
    )

    # Compute the likelihood and kl divergences
    log_probs: List[torch.Tensor] = [torch.Tensor([-1])] * len(
        self.variable_types
    )
    samples: List[torch.Tensor] = [torch.Tensor([-1])] * len(
        self.variable_types
    )
    for i, (x_i, m_i, head_i, params_i) in enumerate(
        zip(data.T, mask.T, self.heads, x_params)
    ):
        head_i.dist(params_i)
        log_probs[i] = head_i.log_prob(x_i) * m_i

        if not self.training and self.decoding:
            # draw samples for evaluation and decoding
            samples[i] = head_i.sample()
        elif not self.training and not self.decoding:
            # draw samples for evaluation
            samples[i] = head_i.sample()

    # Stack the log likelihoods
    log_prob = torch.stack(log_probs, dim=1)  # batch, features
    log_prob = log_prob.sum(dim=1)  # / (mask.sum(dim=1) + 1e-6)  # batch
    cat_samples = self._cat_samples(samples)  # batch, features

    # Compute the KL divergences
    # KL divergence for s
    # samples_s of shape (batch, dim_s)
    kl_s = self.kl_s(encoder_output)  # shape (batch,)

    # KL divergence for z
    pz_loc = self.prior_loc_z(encoder_output.samples_s)  # batch, dim_z
    mean_pz, std_pz = pz_loc, torch.ones_like(pz_loc)
    mean_qz, std_qz = encoder_output.mean_z, encoder_output.scale_z
    kl_z = self.kl_z(mean_qz, std_qz, mean_pz, std_pz)  # shape (batch,)

    loss = -torch.sum(log_prob - kl_s - kl_z) / (
        (torch.sum(mask) / mask.shape[-1]) + 1e-6
    )
    print(f"Loss: {loss}, num_samples: {mask.sum()}")
    return HivaeOutput(
        enc_s=samples_s,
        enc_z=samples_z,
        samples=cat_samples,
        loss=loss,
    )
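
A standalone sketch (dummy shapes only) of how forward() assembles and normalizes the negative ELBO:

import torch

batch, n_features = 4, 3
log_prob = torch.randn(batch)          # per-sample log-likelihood, already summed over features
kl_s = torch.rand(batch)               # KL(q(s|x) || p(s)) per sample
kl_z = torch.rand(batch)               # KL(q(z|x,s) || p(z|s)) per sample
mask = torch.ones(batch, n_features)   # 1 = observed, 0 = missing

# The batch sum is divided by the effective number of samples,
# estimated as the mean number of observed rows (mask.sum() / n_features):
loss = -torch.sum(log_prob - kl_s - kl_z) / ((torch.sum(mask) / mask.shape[-1]) + 1e-6)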

kl_s(encoder_output)

Computes the KL divergence for s.

Parameters:

    encoder_output (EncoderOutput): Encoder output. Required.

Returns:

    torch.Tensor: KL divergence for s.

Source code in vambn/modelling/models/hivae/decoder.py
def kl_s(self, encoder_output: EncoderOutput) -> torch.Tensor:
    """Computes the KL divergence for s.

    Args:
        encoder_output (EncoderOutput): Encoder output.

    Returns:
        torch.Tensor: KL divergence for s.
    """
    return torch.distributions.kl.kl_divergence(
        torch.distributions.OneHotCategorical(
            logits=encoder_output.logits_s, validate_args=False
        ),
        self.prior_s,
    )

kl_z(mean_qz, std_qz, mean_pz, std_pz)

Computes the KL divergence for z.

Parameters:

    mean_qz (Tensor): Mean of the posterior distribution. Required.
    std_qz (Tensor): Standard deviation of the posterior distribution. Required.
    mean_pz (Tensor): Mean of the prior distribution. Required.
    std_pz (Tensor): Standard deviation of the prior distribution. Required.

Returns:

    torch.Tensor: KL divergence for z.

Source code in vambn/modelling/models/hivae/decoder.py
def kl_z(
    self,
    mean_qz: torch.Tensor,
    std_qz: torch.Tensor,
    mean_pz: torch.Tensor,
    std_pz: torch.Tensor,
) -> torch.Tensor:
    """Computes the KL divergence for z.

    Args:
        mean_qz (torch.Tensor): Mean of the posterior distribution.
        std_qz (torch.Tensor): Standard deviation of the posterior distribution.
        mean_pz (torch.Tensor): Mean of the prior distribution.
        std_pz (torch.Tensor): Standard deviation of the prior distribution.

    Returns:
        torch.Tensor: KL divergence for z.
    """
    return torch.distributions.kl.kl_divergence(
        torch.distributions.Normal(mean_qz, std_qz),
        torch.distributions.Normal(mean_pz, std_pz),
    ).sum(dim=-1)
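
For reference, the value returned here equals the closed-form Gaussian KL summed over the latent dimensions; a small self-contained check:

import torch

mean_qz, std_qz = torch.randn(4, 8), torch.rand(4, 8) + 0.1
mean_pz, std_pz = torch.randn(4, 8), torch.ones(4, 8)

kl_lib = torch.distributions.kl.kl_divergence(
    torch.distributions.Normal(mean_qz, std_qz),
    torch.distributions.Normal(mean_pz, std_pz),
).sum(dim=-1)

# KL(N(mq, sq) || N(mp, sp)) = log(sp / sq) + (sq^2 + (mq - mp)^2) / (2 * sp^2) - 1/2
kl_manual = (
    torch.log(std_pz / std_qz)
    + (std_qz**2 + (mean_qz - mean_pz) ** 2) / (2 * std_pz**2)
    - 0.5
).sum(dim=-1)

assert torch.allclose(kl_lib, kl_manual, atol=1e-5)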

prior_z(loc)

Gets the prior distribution for z.

Parameters:

    loc (Tensor): Location parameter for z. Required.

Returns:

    torch.distributions.Normal: Prior distribution for z.

Source code in vambn/modelling/models/hivae/decoder.py
def prior_z(self, loc: torch.Tensor) -> torch.distributions.Normal:
    """Gets the prior distribution for z.

    Args:
        loc (torch.Tensor): Location parameter for z.

    Returns:
        torch.distributions.Normal: Prior distribution for z.
    """
    return torch.distributions.Normal(loc, torch.ones_like(loc))
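
Taken together, prior_s, prior_loc_z and prior_z define the hierarchical prior p(s) p(z|s): s is drawn from a uniform one-hot categorical, and z is a unit-variance Gaussian whose mean is a linear function of s. A minimal standalone sketch of that structure, with a plain nn.Linear standing in for ModifiedLinear:

import torch
import torch.nn as nn

s_dim, z_dim, batch = 3, 5, 8
prior_loc_z = nn.Linear(s_dim, z_dim)          # stands in for ModifiedLinear

p_s = torch.distributions.OneHotCategorical(probs=torch.ones(s_dim) / s_dim)
s = p_s.sample((batch,))                       # (batch, s_dim), one-hot
p_z_given_s = torch.distributions.Normal(prior_loc_z(s), torch.ones(batch, z_dim))
z = p_z_given_s.sample()                       # (batch, z_dim)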

LstmDecoder

Bases: Decoder

LSTM-based HIVAE Decoder class.

Parameters:

    variable_types (VarTypes): List of VariableType objects. See VarTypes in data/dataclasses.py. Required.
    s_dim (int): Dimension of s space. Required.
    z_dim (int): Dimension of z space. Required.
    y_dim (int): Dimension of y space. Required.
    num_timepoints (int): Number of timepoints. Required.
    n_layers (int): Number of LSTM layers. Defaults to 1.
    mtl_method (Tuple[str, ...]): Methods to use for multi-task learning. Assessed possibilities are combinations of "identity", "gradnorm", "graddrop". Further implementations and details can be found in mtl.py. Defaults to ("identity",).
    decoder_shared (nn.Module): Shared decoder module. Defaults to nn.Identity().

Attributes:

    num_timepoints (int): Number of timepoints.
    n_layers (int): Number of LSTM layers.
    lstm_decoder (nn.LSTM): LSTM decoder module.
    fc (nn.Linear): Fully connected layer.

Source code in vambn/modelling/models/hivae/decoder.py
class LstmDecoder(Decoder):
    """LSTM-based HIVAE Decoder class.

    Args:
        variable_types (VarTypes): List of VariableType objects. See VarTypes in
            data/dataclasses.py.
        s_dim (int): Dimension of s space.
        z_dim (int): Dimension of z space.
        y_dim (int): Dimension of y space.
        num_timepoints (int): Number of timepoints.
        n_layers (int, optional): Number of LSTM layers. Defaults to 1.
        mtl_method (Tuple[str, ...], optional): List of methods to use for multi-task learning. 
            Assessed possibilities are combinations of "identity", "gradnorm", "graddrop".  
            Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).
        decoder_shared (nn.Module, optional): Shared decoder module. Defaults to nn.Identity().

    Attributes:
        num_timepoints (int): Number of timepoints.
        n_layers (int): Number of LSTM layers.
        lstm_decoder (nn.LSTM): LSTM decoder module.
        fc (nn.Linear): Fully connected layer.
    """

    def __init__(
        self,
        variable_types: VarTypes,
        s_dim: int,
        z_dim: int,
        y_dim: int,
        num_timepoints: int,
        n_layers: int = 1,
        mtl_method: Tuple[str, ...] = ("identity",),
        decoder_shared: nn.Module = nn.Identity(),
    ) -> None:
        """HIVAE Decoder

        Args:
            type_array (np.ndarray): Array containing the data type information (type, class, ndim)
            s_dim (int): Dimension of s space
            z_dim (int): Dimension of z space
            y_dim (int): Dimension of y space
            num_timepoints (int): Number of num_timepoints.
            mtl_method (Tuple[str], optional): List of methods to use for multi-task learning. Defaults to
                ["gradnorm", "pcgrad"].
        """
        super().__init__(
            variable_types=variable_types,
            s_dim=s_dim,
            z_dim=z_dim,
            y_dim=y_dim,
            mtl_method=mtl_method,
            decoder_shared=decoder_shared,
        )
        self.num_timepoints = num_timepoints
        self.n_layers = n_layers
        self.lstm_decoder = nn.LSTM(
            input_size=z_dim,
            hidden_size=z_dim,
            num_layers=self.n_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(z_dim, z_dim)

    def forward(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
        encoder_output: EncoderOutput,
        normalization_parameters: NormalizationParameters,
    ) -> HivaeOutput:
        """Forward pass of the LSTM decoder.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Mask for the data.
            encoder_output (EncoderOutput): Output from the encoder.
            normalization_parameters (NormalizationParameters): Parameters for normalization.

        Returns:
            HivaeOutput: Output from the decoder.
        """
        samples_z = encoder_output.samples_z
        decoder_representation = encoder_output.decoder_representation
        samples_s = encoder_output.samples_s

        # Obtaining the parameters for the decoder
        # decoder representation of shape batch x dim_z
        decoder_interim_representation = self.decoder_shared(
            decoder_representation
        )  # identity
        decoder_interim_representation = self.internal_layer_norm(
            decoder_interim_representation
        )  # identity

        h0 = self.fc(decoder_interim_representation).repeat(self.n_layers, 1, 1)
        c0 = torch.zeros_like(h0)
        y_repeated = decoder_interim_representation.unsqueeze(1).repeat(
            1, self.num_timepoints, 1
        )
        decoder_interim_representation, _ = self.lstm_decoder(
            y_repeated, (h0, c0)
        )

        log_probs = [None] * self.num_timepoints
        samples = [None] * self.num_timepoints
        for t in range(self.num_timepoints):
            sub_data = data[:, t, :]
            sub_mask = mask[:, t, :]
            if self.training:
                x_params = tuple(
                    [
                        head(y_i, s_i)
                        for head, s_i, y_i in zip(  # type: ignore
                            self.heads,
                            *self.moo_block.forward(
                                samples_s,
                                decoder_interim_representation[:, t, :],
                            ),
                        )
                    ]
                )
            else:
                x_params = tuple(
                    [
                        head(decoder_interim_representation[:, t, :], samples_s)
                        for head in self.heads
                    ]
                )

            x_params = Normalization.denormalize_params(
                x_params, self.variable_types, normalization_parameters[t]
            )

            log_probs[t] = torch.stack(
                [
                    head_i.log_prob(
                        params=params_i,
                        data=d_i,
                    )
                    * m_i
                    for i, (head_i, params_i, d_i, m_i) in enumerate(
                        zip(self.heads, x_params, sub_data.T, sub_mask.T)
                    )
                ],
                dim=1,
            ).sum(-1)  # batch
            assert isinstance(
                log_probs[t], torch.Tensor
            ), f"Log probs: {log_probs[t]}"
            assert log_probs[t].shape == (
                encoder_output.samples_s.shape[0],
            ), f"Log probs shape: {log_probs[t].shape}"

            if not self.training and self.decoding:
                samples[t] = self._cat_samples(
                    [
                        head.sample()
                        for head, params in zip(self.heads, x_params)
                    ]
                )
            elif not self.training and not self.decoding:
                samples[t] = self._cat_samples(
                    [
                        head.sample()
                        for head, params in zip(self.heads, x_params)
                    ]
                )

        log_prob = torch.stack(log_probs, dim=1)  # batch, timepoints
        log_prob = log_prob.sum(
            dim=1
        )  # / (mask.sum(dim=(1, 2)) + 1e-9)  # batch
        if torch.isnan(log_prob).any() or torch.isinf(log_prob).any():
            raise ValueError("Log likelihood is nan or inf")
        samples = (
            torch.stack(samples, dim=1) if samples[0] is not None else None
        )

        # Compute the KL divergences
        # # KL divergence for s
        kl_s = self.kl_s(encoder_output)
        if torch.isnan(kl_s).any() or torch.isinf(kl_s).any():
            raise ValueError("KL divergence for s is nan or inf")

        # KL divergence for z
        pz_loc = self.prior_loc_z(encoder_output.samples_s)
        mean_pz, std_pz = pz_loc, torch.ones_like(pz_loc)
        mean_qz, std_qz = encoder_output.mean_z, encoder_output.scale_z
        kl_z = self.kl_z(mean_qz, std_qz, mean_pz, std_pz)

        assert kl_z.shape == (encoder_output.samples_s.shape[0],)
        if torch.isnan(kl_z).any() or torch.isinf(kl_z).any():
            raise ValueError("KL divergence for z is nan or inf")

        # Compute the loss
        loss = -torch.sum(log_prob - kl_s - kl_z) / (
            torch.sum(mask) / mask.shape[-1]
        )

        return HivaeOutput(
            enc_s=encoder_output.samples_s,
            enc_z=samples_z,
            samples=samples,
            loss=loss,
        )

    @torch.no_grad()
    def decode(
        self,
        encoding_z: torch.Tensor,
        encoding_s: torch.Tensor,
        normalization_params: NormalizationParameters,
    ):
        """Decoding logic for 3D data.

        Args:
            encoding_z (torch.Tensor): Encoding for z.
            encoding_s (torch.Tensor): Encoding for s.
            normalization_params (NormalizationParameters): Parameters for normalization.

        Returns:
            torch.Tensor: Decoded samples.
        """
        assert not self.training, "Model should be in eval mode"

        # shared_y of shape (batch, dim_y)
        decoder_interim_representation = self.decoder_shared(
            encoding_z
        )  # identity
        decoder_interim_representation = self.internal_layer_norm(
            decoder_interim_representation
        )  # identity

        decoder_interim_representation = (
            decoder_interim_representation.unsqueeze(1).repeat(
                1, self.num_timepoints, 1
            )
        )
        decoder_interim_representation, _ = self.lstm_decoder(
            decoder_interim_representation
        )

        if encoding_s.shape[1] != self.s_dim:
            encoding_s = F.one_hot(
                encoding_s.squeeze(1).long(), num_classes=self.s_dim
            ).float()

        samples = [None] * self.num_timepoints
        for t in range(self.num_timepoints):
            x_params = tuple(
                [
                    head(decoder_interim_representation[:, t, :], encoding_s)
                    for head in self.heads
                ]
            )
            x_params = Normalization.denormalize_params(
                x_params, self.variable_types, normalization_params[t]
            )
            [head.dist(params) for head, params in zip(self.heads, x_params)]
            if self.decoding:
                samples[t] = self._cat_samples(
                    [head.sample() for head in self.heads]
                )
            else:
                samples[t] = self._cat_samples(
                    [
                        head.sample()
                        for head, params in zip(self.heads, x_params)
                    ]
                )
        time_dim_samples = torch.stack(samples, dim=1)
        return time_dim_samples
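
A standalone sketch (dummy shapes only) of the LSTM unrolling used in forward(): the shared representation initializes the hidden state through fc and is also repeated along the time axis as input, yielding one decoder representation per timepoint:

import torch
import torch.nn as nn

batch, z_dim, num_timepoints, n_layers = 4, 8, 5, 1
fc = nn.Linear(z_dim, z_dim)
lstm = nn.LSTM(input_size=z_dim, hidden_size=z_dim, num_layers=n_layers, batch_first=True)

rep = torch.randn(batch, z_dim)                              # decoder representation (batch, z_dim)
h0 = fc(rep).repeat(n_layers, 1, 1)                          # initial hidden state (n_layers, batch, z_dim)
c0 = torch.zeros_like(h0)
y_repeated = rep.unsqueeze(1).repeat(1, num_timepoints, 1)   # (batch, T, z_dim)

out, _ = lstm(y_repeated, (h0, c0))                          # (batch, T, z_dim): one slice per timepoint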

__init__(variable_types, s_dim, z_dim, y_dim, num_timepoints, n_layers=1, mtl_method=('identity',), decoder_shared=nn.Identity())

HIVAE Decoder

Parameters:

    variable_types (VarTypes): List of VariableType objects. Required.
    s_dim (int): Dimension of s space. Required.
    z_dim (int): Dimension of z space. Required.
    y_dim (int): Dimension of y space. Required.
    num_timepoints (int): Number of timepoints. Required.
    n_layers (int): Number of LSTM layers. Defaults to 1.
    mtl_method (Tuple[str, ...]): Methods to use for multi-task learning. Defaults to ("identity",).
    decoder_shared (nn.Module): Shared decoder module. Defaults to nn.Identity().
Source code in vambn/modelling/models/hivae/decoder.py
def __init__(
    self,
    variable_types: VarTypes,
    s_dim: int,
    z_dim: int,
    y_dim: int,
    num_timepoints: int,
    n_layers: int = 1,
    mtl_method: Tuple[str, ...] = ("identity",),
    decoder_shared: nn.Module = nn.Identity(),
) -> None:
    """HIVAE Decoder

    Args:
        type_array (np.ndarray): Array containing the data type information (type, class, ndim)
        s_dim (int): Dimension of s space
        z_dim (int): Dimension of z space
        y_dim (int): Dimension of y space
        num_timepoints (int): Number of num_timepoints.
        mtl_method (Tuple[str], optional): List of methods to use for multi-task learning. Defaults to
            ["gradnorm", "pcgrad"].
    """
    super().__init__(
        variable_types=variable_types,
        s_dim=s_dim,
        z_dim=z_dim,
        y_dim=y_dim,
        mtl_method=mtl_method,
        decoder_shared=decoder_shared,
    )
    self.num_timepoints = num_timepoints
    self.n_layers = n_layers
    self.lstm_decoder = nn.LSTM(
        input_size=z_dim,
        hidden_size=z_dim,
        num_layers=self.n_layers,
        batch_first=True,
    )
    self.fc = nn.Linear(z_dim, z_dim)

decode(encoding_z, encoding_s, normalization_params)

Decoding logic for 3D data.

Parameters:

    encoding_z (Tensor): Encoding for z. Required.
    encoding_s (Tensor): Encoding for s. Required.
    normalization_params (NormalizationParameters): Parameters for normalization. Required.

Returns:

    torch.Tensor: Decoded samples.

Source code in vambn/modelling/models/hivae/decoder.py
@torch.no_grad()
def decode(
    self,
    encoding_z: torch.Tensor,
    encoding_s: torch.Tensor,
    normalization_params: NormalizationParameters,
):
    """Decoding logic for 3D data.

    Args:
        encoding_z (torch.Tensor): Encoding for z.
        encoding_s (torch.Tensor): Encoding for s.
        normalization_params (NormalizationParameters): Parameters for normalization.

    Returns:
        torch.Tensor: Decoded samples.
    """
    assert not self.training, "Model should be in eval mode"

    # shared_y of shape (batch, dim_y)
    decoder_interim_representation = self.decoder_shared(
        encoding_z
    )  # identity
    decoder_interim_representation = self.internal_layer_norm(
        decoder_interim_representation
    )  # identity

    decoder_interim_representation = (
        decoder_interim_representation.unsqueeze(1).repeat(
            1, self.num_timepoints, 1
        )
    )
    decoder_interim_representation, _ = self.lstm_decoder(
        decoder_interim_representation
    )

    if encoding_s.shape[1] != self.s_dim:
        encoding_s = F.one_hot(
            encoding_s.squeeze(1).long(), num_classes=self.s_dim
        ).float()

    samples = [None] * self.num_timepoints
    for t in range(self.num_timepoints):
        x_params = tuple(
            [
                head(decoder_interim_representation[:, t, :], encoding_s)
                for head in self.heads
            ]
        )
        x_params = Normalization.denormalize_params(
            x_params, self.variable_types, normalization_params[t]
        )
        [head.dist(params) for head, params in zip(self.heads, x_params)]
        if self.decoding:
            samples[t] = self._cat_samples(
                [head.sample() for head in self.heads]
            )
        else:
            samples[t] = self._cat_samples(
                [
                    head.sample()
                    for head, params in zip(self.heads, x_params)
                ]
            )
    time_dim_samples = torch.stack(samples, dim=1)
    return time_dim_samples

forward(data, mask, encoder_output, normalization_parameters)

Forward pass of the LSTM decoder.

Parameters:

    data (Tensor): Input data. Required.
    mask (Tensor): Mask for the data. Required.
    encoder_output (EncoderOutput): Output from the encoder. Required.
    normalization_parameters (NormalizationParameters): Parameters for normalization. Required.

Returns:

    HivaeOutput: Output from the decoder.

Source code in vambn/modelling/models/hivae/decoder.py
def forward(
    self,
    data: torch.Tensor,
    mask: torch.Tensor,
    encoder_output: EncoderOutput,
    normalization_parameters: NormalizationParameters,
) -> HivaeOutput:
    """Forward pass of the LSTM decoder.

    Args:
        data (torch.Tensor): Input data.
        mask (torch.Tensor): Mask for the data.
        encoder_output (EncoderOutput): Output from the encoder.
        normalization_parameters (NormalizationParameters): Parameters for normalization.

    Returns:
        HivaeOutput: Output from the decoder.
    """
    samples_z = encoder_output.samples_z
    decoder_representation = encoder_output.decoder_representation
    samples_s = encoder_output.samples_s

    # Obtaining the parameters for the decoder
    # decoder representation of shape batch x dim_z
    decoder_interim_representation = self.decoder_shared(
        decoder_representation
    )  # identity
    decoder_interim_representation = self.internal_layer_norm(
        decoder_interim_representation
    )  # identity

    h0 = self.fc(decoder_interim_representation).repeat(self.n_layers, 1, 1)
    c0 = torch.zeros_like(h0)
    y_repeated = decoder_interim_representation.unsqueeze(1).repeat(
        1, self.num_timepoints, 1
    )
    decoder_interim_representation, _ = self.lstm_decoder(
        y_repeated, (h0, c0)
    )

    log_probs = [None] * self.num_timepoints
    samples = [None] * self.num_timepoints
    for t in range(self.num_timepoints):
        sub_data = data[:, t, :]
        sub_mask = mask[:, t, :]
        if self.training:
            x_params = tuple(
                [
                    head(y_i, s_i)
                    for head, s_i, y_i in zip(  # type: ignore
                        self.heads,
                        *self.moo_block.forward(
                            samples_s,
                            decoder_interim_representation[:, t, :],
                        ),
                    )
                ]
            )
        else:
            x_params = tuple(
                [
                    head(decoder_interim_representation[:, t, :], samples_s)
                    for head in self.heads
                ]
            )

        x_params = Normalization.denormalize_params(
            x_params, self.variable_types, normalization_parameters[t]
        )

        log_probs[t] = torch.stack(
            [
                head_i.log_prob(
                    params=params_i,
                    data=d_i,
                )
                * m_i
                for i, (head_i, params_i, d_i, m_i) in enumerate(
                    zip(self.heads, x_params, sub_data.T, sub_mask.T)
                )
            ],
            dim=1,
        ).sum(-1)  # batch
        assert isinstance(
            log_probs[t], torch.Tensor
        ), f"Log probs: {log_probs[t]}"
        assert log_probs[t].shape == (
            encoder_output.samples_s.shape[0],
        ), f"Log probs shape: {log_probs[t].shape}"

        if not self.training and self.decoding:
            samples[t] = self._cat_samples(
                [
                    head.sample()
                    for head, params in zip(self.heads, x_params)
                ]
            )
        elif not self.training and not self.decoding:
            samples[t] = self._cat_samples(
                [
                    head.sample()
                    for head, params in zip(self.heads, x_params)
                ]
            )

    log_prob = torch.stack(log_probs, dim=1)  # batch, timepoints
    log_prob = log_prob.sum(
        dim=1
    )  # / (mask.sum(dim=(1, 2)) + 1e-9)  # batch
    if torch.isnan(log_prob).any() or torch.isinf(log_prob).any():
        raise ValueError("Log likelihood is nan or inf")
    samples = (
        torch.stack(samples, dim=1) if samples[0] is not None else None
    )

    # Compute the KL divergences
    # # KL divergence for s
    kl_s = self.kl_s(encoder_output)
    if torch.isnan(kl_s).any() or torch.isinf(kl_s).any():
        raise ValueError("KL divergence for s is nan or inf")

    # KL divergence for z
    pz_loc = self.prior_loc_z(encoder_output.samples_s)
    mean_pz, std_pz = pz_loc, torch.ones_like(pz_loc)
    mean_qz, std_qz = encoder_output.mean_z, encoder_output.scale_z
    kl_z = self.kl_z(mean_qz, std_qz, mean_pz, std_pz)

    assert kl_z.shape == (encoder_output.samples_s.shape[0],)
    if torch.isnan(kl_z).any() or torch.isinf(kl_z).any():
        raise ValueError("KL divergence for z is nan or inf")

    # Compute the loss
    loss = -torch.sum(log_prob - kl_s - kl_z) / (
        torch.sum(mask) / mask.shape[-1]
    )

    return HivaeOutput(
        enc_s=encoder_output.samples_s,
        enc_z=samples_z,
        samples=samples,
        loss=loss,
    )
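
Read together, the loss computed above is a masked negative ELBO: per-feature log-likelihoods are masked and summed over features and timepoints, the two KL terms are subtracted per sample, and the result is normalized by the average number of observed feature vectors. As an equation (notation introduced here for illustration, not taken from the source):

$$
\mathcal{L} \;=\; -\,\frac{\sum_{i=1}^{B}\Big(\sum_{t=1}^{T} \log p(x_{i,t}\mid z_i, s_i)\;-\;\mathrm{KL}\big(q(s_i\mid x_i)\,\|\,p(s_i)\big)\;-\;\mathrm{KL}\big(q(z_i\mid x_i, s_i)\,\|\,p(z_i\mid s_i)\big)\Big)}{\tfrac{1}{J}\sum_{i,t,j} m_{i,t,j}}
$$

where $B$ is the batch size, $T$ the number of timepoints, $J$ the number of features, and $m_{i,t,j}$ the mask entries; the prior $p(z_i\mid s_i)$ is a unit-variance Gaussian whose mean is produced by prior_loc_z from samples_s.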

encoder

Encoder

Bases: Module

HIVAE Encoder.

Parameters:

Name Type Description Default
input_dim int

Dimension of input data (e.g., columns in dataframe).

required
dim_s int

Dimension of s space.

required
dim_z int

Dimension of z space.

required

Attributes:

Name Type Description
input_dim int

Dimension of input data.

dim_s int

Dimension of s space.

dim_z int

Dimension of z space.

encoder_s ModifiedLinear

Linear layer for s encoding.

encoder_z Module

Identity layer for z encoding.

param_z ModifiedLinear

Linear layer for z parameterization.

_tau float

Temperature parameter for Gumbel softmax.

_decoding bool

Flag indicating whether the model is in decoding mode.

Source code in vambn/modelling/models/hivae/encoder.py
class Encoder(nn.Module):
    """HIVAE Encoder.

    Args:
        input_dim (int): Dimension of input data (e.g., columns in dataframe).
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.

    Attributes:
        input_dim (int): Dimension of input data.
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        encoder_s (ModifiedLinear): Linear layer for s encoding.
        encoder_z (nn.Module): Identity layer for z encoding.
        param_z (ModifiedLinear): Linear layer for z parameterization.
        _tau (float): Temperature parameter for Gumbel softmax.
        _decoding (bool): Flag indicating whether the model is in decoding mode.
    """


    def __init__(self, input_dim: int, dim_s: int, dim_z: int):
        """HIVAE Encoder

        Args:
            input_dim (int): Dimension of input data (e.g. columns in dataframe)
            dim_s (int): Dimension of s space
            dim_z (int): Dimension of z space
        """
        super().__init__()
        if input_dim <= 0:
            raise ValueError(
                f"Input dimension must be positive, got {input_dim}"
            )

        if dim_s <= 0:
            raise ValueError(f"S dimension must be positive, got {dim_s}")

        if dim_z <= 0:
            raise ValueError(f"Z dimension must be positive, got {dim_z}")

        self.input_dim = input_dim
        self.dim_s = dim_s
        self.dim_z = dim_z

        self.encoder_s = ModifiedLinear(self.input_dim, dim_s, bias=True)
        self.encoder_z = nn.Identity()
        self.param_z = ModifiedLinear(
            self.input_dim + dim_s, dim_z * 2, bias=True
        )

        self._tau = 1.0
        self._decoding = True

    @property
    def decoding(self) -> bool:
        """bool: Flag indicating whether the model is in decoding mode."""
        return self._decoding

    @decoding.setter
    def decoding(self, value: bool) -> None:
        """Sets the decoding flag.

        Args:
            value (bool): Decoding flag.
        """
        self._decoding = value

    @property
    def tau(self) -> float:
        """float: Temperature parameter for Gumbel softmax."""
        return self._tau

    @tau.setter
    def tau(self, value: float) -> None:
        """Sets the temperature parameter for Gumbel softmax.

        Args:
            value (float): Temperature value.

        Raises:
            ValueError: If the temperature value is not positive.
        """
        if value <= 0:
            raise ValueError(f"Tau must be positive, got {value}")
        self._tau = value

    @staticmethod
    def q_z(loc, scale: Tensor) -> dists.Normal:
        """Creates a normal distribution for z.

        Args:
            loc (Tensor): Mean of the distribution.
            scale (Tensor): Standard deviation of the distribution.

        Returns:
            dists.Normal: Normal distribution.
        """
        return dists.Normal(loc, scale)

    def q_s(self, probs: Tensor) -> GumbelDistribution:
        """Creates a Gumbel distribution for s.

        Args:
            probs (Tensor): Probabilities for the Gumbel distribution.

        Returns:
            GumbelDistribution: Gumbel distribution.
        """
        return GumbelDistribution(probs=probs, temperature=self.tau)

    def forward(self, x: Tensor) -> EncoderOutput:
        """Forward pass of the encoder.

        Args:
            x (Tensor): Normalized input data.

        Raises:
            Exception: If samples contain NaN values.

        Returns:
            EncoderOutput: Contains samples, logits, and parameters.
        """
        logits_s = self.encoder_s(x)
        probs_s = F.softmax(logits_s, dim=-1).clamp(1e-6, 1 - 1e-6)

        if self.training:
            samples_s = self.q_s(probs_s).rsample()
        elif self.decoding:
            # NOTE:
            # The idea behind this was that we can use e.g. the mode during decoding
            # Experiments indicate that this is not helpful; the correlation between
            # the real and decoded data is improved, but the JSD is significantly worse.
            samples_s = self.q_s(probs_s).sample()
        else:
            samples_s = self.q_s(probs_s).sample()

        x_and_s = torch.cat((x, samples_s), dim=-1)

        h = self.encoder_z(x_and_s)

        loc_z, scale_z_unnorm = torch.chunk(self.param_z(h), 2, dim=-1)
        scale_z = F.softplus(scale_z_unnorm) + 1e-6

        samples_z = (
            self.q_z(loc_z, scale_z).rsample()
            if self.training
            else self.q_z(loc_z, scale_z).sample()
            if self.decoding
            else self.q_z(loc_z, scale_z).sample()
        )

        return EncoderOutput(
            samples_s=samples_s,
            samples_z=samples_z,
            logits_s=logits_s,
            mean_z=loc_z,
            scale_z=scale_z,
        )
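
A minimal usage sketch (illustrative only; the import path is assumed from the "Source code in vambn/modelling/models/hivae/encoder.py" reference above, and the shapes follow the forward pass):

import torch
from vambn.modelling.models.hivae.encoder import Encoder  # path assumed from the docs above

encoder = Encoder(input_dim=10, dim_s=3, dim_z=5)
x = torch.randn(32, 10)        # batch of 32 normalized samples with 10 features
out = encoder(x)               # EncoderOutput(samples_s, samples_z, logits_s, mean_z, scale_z)
print(out.samples_s.shape)     # torch.Size([32, 3])
print(out.samples_z.shape)     # torch.Size([32, 5])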

decoding: bool property writable

bool: Flag indicating whether the model is in decoding mode.

tau: float property writable

float: Temperature parameter for Gumbel softmax.
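
Because tau is writable, the Gumbel-softmax temperature can be annealed over training; a hedged illustration using the encoder instance from the sketch above (the exponential schedule is an assumption, not part of the library):

n_epochs = 20
for epoch in range(n_epochs):
    # keep the temperature strictly positive, as required by the setter
    encoder.tau = max(1e-3, 0.95 ** epoch)
    # ... run one training epoch with the current temperature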

__init__(input_dim, dim_s, dim_z)

HIVAE Encoder

Parameters:

Name Type Description Default
input_dim int

Dimension of input data (e.g. columns in dataframe)

required
dim_s int

Dimension of s space

required
dim_z int

Dimension of z space

required
Source code in vambn/modelling/models/hivae/encoder.py
def __init__(self, input_dim: int, dim_s: int, dim_z: int):
    """HIVAE Encoder

    Args:
        input_dim (int): Dimension of input data (e.g. columns in dataframe)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
    """
    super().__init__()
    if input_dim <= 0:
        raise ValueError(
            f"Input dimension must be positive, got {input_dim}"
        )

    if dim_s <= 0:
        raise ValueError(f"S dimension must be positive, got {dim_s}")

    if dim_z <= 0:
        raise ValueError(f"Z dimension must be positive, got {dim_z}")

    self.input_dim = input_dim
    self.dim_s = dim_s
    self.dim_z = dim_z

    self.encoder_s = ModifiedLinear(self.input_dim, dim_s, bias=True)
    self.encoder_z = nn.Identity()
    self.param_z = ModifiedLinear(
        self.input_dim + dim_s, dim_z * 2, bias=True
    )

    self._tau = 1.0
    self._decoding = True

forward(x)

Forward pass of the encoder.

Parameters:

Name Type Description Default
x Tensor

Normalized input data.

required

Raises:

Type Description
Exception

If samples contain NaN values.

Returns:

Name Type Description
EncoderOutput EncoderOutput

Contains samples, logits, and parameters.

Source code in vambn/modelling/models/hivae/encoder.py
def forward(self, x: Tensor) -> EncoderOutput:
    """Forward pass of the encoder.

    Args:
        x (Tensor): Normalized input data.

    Raises:
        Exception: If samples contain NaN values.

    Returns:
        EncoderOutput: Contains samples, logits, and parameters.
    """
    logits_s = self.encoder_s(x)
    probs_s = F.softmax(logits_s, dim=-1).clamp(1e-6, 1 - 1e-6)

    if self.training:
        samples_s = self.q_s(probs_s).rsample()
    elif self.decoding:
        # NOTE:
        # The idea behind this was that we can use e.g. the mode during decoding
        # Experiments indicate that this is not helpful; the correlation between
        # the real and decoded data is improved, but the JSD is significantly worse.
        samples_s = self.q_s(probs_s).sample()
    else:
        samples_s = self.q_s(probs_s).sample()

    x_and_s = torch.cat((x, samples_s), dim=-1)

    h = self.encoder_z(x_and_s)

    loc_z, scale_z_unnorm = torch.chunk(self.param_z(h), 2, dim=-1)
    scale_z = F.softplus(scale_z_unnorm) + 1e-6

    samples_z = (
        self.q_z(loc_z, scale_z).rsample()
        if self.training
        else self.q_z(loc_z, scale_z).sample()
        if self.decoding
        else self.q_z(loc_z, scale_z).sample()
    )

    return EncoderOutput(
        samples_s=samples_s,
        samples_z=samples_z,
        logits_s=logits_s,
        mean_z=loc_z,
        scale_z=scale_z,
    )

q_s(probs)

Creates a Gumbel distribution for s.

Parameters:

Name Type Description Default
probs Tensor

Probabilities for the Gumbel distribution.

required

Returns:

Name Type Description
GumbelDistribution GumbelDistribution

Gumbel distribution.

Source code in vambn/modelling/models/hivae/encoder.py
def q_s(self, probs: Tensor) -> GumbelDistribution:
    """Creates a Gumbel distribution for s.

    Args:
        probs (Tensor): Probabilities for the Gumbel distribution.

    Returns:
        GumbelDistribution: Gumbel distribution.
    """
    return GumbelDistribution(probs=probs, temperature=self.tau)

q_z(loc, scale) staticmethod

Creates a normal distribution for z.

Parameters:

Name Type Description Default
loc Tensor

Mean of the distribution.

required
scale Tensor

Standard deviation of the distribution.

required

Returns:

Type Description
Normal

dists.Normal: Normal distribution.

Source code in vambn/modelling/models/hivae/encoder.py
@staticmethod
def q_z(loc, scale: Tensor) -> dists.Normal:
    """Creates a normal distribution for z.

    Args:
        loc (Tensor): Mean of the distribution.
        scale (Tensor): Standard deviation of the distribution.

    Returns:
        dists.Normal: Normal distribution.
    """
    return dists.Normal(loc, scale)

LstmEncoder

Bases: Module

Encoder for longitudinal input.

Parameters:

Name Type Description Default
input_dimension int

Dimension of input data.

required
dim_s int

Dimension of s space.

required
dim_z int

Dimension of z space.

required
n_layers int

Number of LSTM layers.

required
hidden_size Optional[int]

Size of the hidden layer. Defaults to None.

None

Attributes:

Name Type Description
dim_s int

Dimension of s space.

dim_z int

Dimension of z space.

hidden_size int

Size of the hidden layer.

lstm LSTM

LSTM layer.

encoder Encoder

Encoder module.

Source code in vambn/modelling/models/hivae/encoder.py
class LstmEncoder(nn.Module):
    """Encoder for longitudinal input.

    Args:
        input_dimension (int): Dimension of input data.
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        n_layers (int): Number of LSTM layers.
        hidden_size (Optional[int], optional): Size of the hidden layer. Defaults to None.

    Attributes:
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        hidden_size (int): Size of the hidden layer.
        lstm (nn.LSTM): LSTM layer.
        encoder (Encoder): Encoder module.
    """

    @typeguard.typechecked
    def __init__(
        self,
        input_dimension: int,
        dim_s: int,
        dim_z: int,
        n_layers: int,
        hidden_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.dim_s = dim_s
        self.dim_z = dim_z

        if hidden_size is None:
            hidden_size = input_dimension

        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(
            input_size=input_dimension,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
        )
        self.encoder = Encoder(input_dim=hidden_size, dim_s=dim_s, dim_z=dim_z)

    def forward(self, input_data: Tensor) -> EncoderOutput:
        """Forward pass of the LSTM encoder.

        Args:
            input_data (Tensor): Batch size x time points/visits x variables/input size (the LSTM uses batch_first=True).

        Returns:
            EncoderOutput: Output for each time point.
        """
        out, (_, _) = self.lstm(input_data)
        # out: batch_size x time points x hidden size
        last_out = out[:, -1, :]
        encoder_output = self.encoder.forward(last_out)

        return encoder_output

    def q_z(self, loc, scale: Tensor) -> dists.Normal:
        """Creates a normal distribution for z.

        Args:
            loc (Tensor): Mean of the distribution.
            scale (Tensor): Standard deviation of the distribution.

        Returns:
            dists.Normal: Normal distribution.
        """
        return self.encoder.q_z(loc, scale)

    def q_s(self, probs: Tensor) -> GumbelDistribution:
        """Creates a Gumbel distribution for s.

        Args:
            probs (Tensor): Probabilities for the Gumbel distribution.

        Returns:
            GumbelDistribution: Gumbel distribution.
        """
        return self.encoder.q_s(probs)
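
A minimal sketch of the longitudinal encoder (illustrative; the import path is assumed). Because the LSTM is constructed with batch_first=True, the input is batch x timepoints x features:

import torch
from vambn.modelling.models.hivae.encoder import LstmEncoder  # path assumed

lstm_encoder = LstmEncoder(input_dimension=10, dim_s=3, dim_z=5, n_layers=2)
x = torch.randn(32, 4, 10)      # 32 subjects, 4 visits, 10 variables
out = lstm_encoder(x)           # EncoderOutput computed from the last LSTM time step
print(out.samples_z.shape)      # torch.Size([32, 5])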

forward(input_data)

Forward pass of the LSTM encoder.

Parameters:

Name Type Description Default
input_data Tensor

Batch size x time points/visits x variables/input size (batch_first).

required

Returns:

Name Type Description
EncoderOutput EncoderOutput

Output for each time point.

Source code in vambn/modelling/models/hivae/encoder.py
def forward(self, input_data: Tensor) -> EncoderOutput:
    """Forward pass of the LSTM encoder.

    Args:
        input_data (Tensor): Batch size x time points/visits x variables/input size (the LSTM uses batch_first=True).

    Returns:
        EncoderOutput: Output for each time point.
    """
    out, (_, _) = self.lstm(input_data)
    # out: batch_size x time points x hidden size
    last_out = out[:, -1, :]
    encoder_output = self.encoder.forward(last_out)

    return encoder_output

q_s(probs)

Creates a Gumbel distribution for s.

Parameters:

Name Type Description Default
probs Tensor

Probabilities for the Gumbel distribution.

required

Returns:

Name Type Description
GumbelDistribution GumbelDistribution

Gumbel distribution.

Source code in vambn/modelling/models/hivae/encoder.py
def q_s(self, probs: Tensor) -> GumbelDistribution:
    """Creates a Gumbel distribution for s.

    Args:
        probs (Tensor): Probabilities for the Gumbel distribution.

    Returns:
        GumbelDistribution: Gumbel distribution.
    """
    return self.encoder.q_s(probs)

q_z(loc, scale)

Creates a normal distribution for z.

Parameters:

Name Type Description Default
loc Tensor

Mean of the distribution.

required
scale Tensor

Standard deviation of the distribution.

required

Returns:

Type Description
Normal

dists.Normal: Normal distribution.

Source code in vambn/modelling/models/hivae/encoder.py
def q_z(self, loc, scale: Tensor) -> dists.Normal:
    """Creates a normal distribution for z.

    Args:
        loc (Tensor): Mean of the distribution.
        scale (Tensor): Standard deviation of the distribution.

    Returns:
        dists.Normal: Normal distribution.
    """
    return self.encoder.q_z(loc, scale)

gan_hivae

GanHivae

Bases: AbstractGanModel[Tensor, Tensor, HivaeOutput, HivaeEncoding]

GAN-enhanced HIVAE model.

Parameters:

Name Type Description Default
variable_types VarTypes

Types of variables.

required
input_dim int

Dimension of input data.

required
dim_s int

Dimension of s space.

required
dim_z int

Dimension of z space.

required
dim_y int

Dimension of y space.

required
n_layers int

Number of layers.

required
noise_size int

Size of the noise vector. Defaults to 10.

10
num_timepoints Optional[int]

Number of time points. Defaults to None.

None
module_name Optional[str]

Name of the module. Defaults to "GanHivae".

'GanHivae'
mtl_method Tuple[str, ...]

Methods for multi-task learning. Defaults to ("identity",).

('identity',)
use_imputation_layer bool

Whether to use an imputation layer. Defaults to False.

False
individual_model bool

Whether to use individual models. Defaults to True.

True

Attributes:

Name Type Description
noise_size int

Size of the noise vector.

is_longitudinal bool

Flag for longitudinal data.

model Hivae or LstmHivae

HIVAE model.

generator Generator

GAN generator.

discriminator Discriminator

GAN discriminator.

device device

Device to run the model on.

one Parameter

Parameter for GAN training.

mone Parameter

Parameter for GAN training.

Source code in vambn/modelling/models/hivae/gan_hivae.py
class GanHivae(
    AbstractGanModel[torch.Tensor, torch.Tensor, HivaeOutput, HivaeEncoding]
):
    """GAN-enhanced HIVAE model.

    Args:
        variable_types (VarTypes): Types of variables.
        input_dim (int): Dimension of input data.
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        dim_y (int): Dimension of y space.
        n_layers (int): Number of layers.
        noise_size (int, optional): Size of the noise vector. Defaults to 10.
        num_timepoints (Optional[int], optional): Number of time points. Defaults to None.
        module_name (Optional[str], optional): Name of the module. Defaults to "GanHivae".
        mtl_method (Tuple[str, ...], optional): Methods for multi-task learning. Defaults to ("identity",).
        use_imputation_layer (bool, optional): Whether to use an imputation layer. Defaults to False.
        individual_model (bool, optional): Whether to use individual models. Defaults to True.

    Attributes:
        noise_size (int): Size of the noise vector.
        is_longitudinal (bool): Flag for longitudinal data.
        model (Hivae or LstmHivae): HIVAE model.
        generator (Generator): GAN generator.
        discriminator (Discriminator): GAN discriminator.
        device (torch.device): Device to run the model on.
        one (nn.Parameter): Parameter for GAN training.
        mone (nn.Parameter): Parameter for GAN training.
    """

    def __init__(
        self,
        variable_types: VarTypes,
        input_dim: int,
        dim_s: int,
        dim_z: int,
        dim_y: int,
        n_layers: int,
        noise_size: int = 10,
        num_timepoints: Optional[int] = None,
        module_name: Optional[str] = "GanHivae",
        mtl_method: Tuple[str, ...] = ("identity",),
        use_imputation_layer: bool = False,
        individual_model: bool = True,
    ):
        super().__init__()

        self.noise_size = noise_size
        if num_timepoints is not None and num_timepoints > 1:
            self.is_longitudinal = True
            self.model = LstmHivae(
                variable_types=variable_types,
                input_dim=input_dim,
                dim_s=dim_s,
                dim_z=dim_z,
                dim_y=dim_y,
                n_layers=n_layers,
                num_timepoints=num_timepoints,
                module_name=module_name,
                mtl_method=mtl_method,
                use_imputation_layer=use_imputation_layer,
                individual_model=individual_model,
            )
        else:
            self.is_longitudinal = False
            self.model = Hivae(
                variable_types=variable_types,
                input_dim=input_dim,
                dim_s=dim_s,
                dim_z=dim_z,
                dim_y=dim_y,
                module_name=module_name,
                mtl_method=mtl_method,
                use_imputation_layer=use_imputation_layer,
                individual_model=individual_model,
            )

        self.generator = Generator(noise_size, (8, 4), dim_z)
        self.discriminator = Discriminator(dim_z, (8, 4), 1)
        #
        self.device = torch.device("cpu")
        self.one = nn.Parameter(
            torch.FloatTensor([1.0])[0], requires_grad=False
        )
        self.mone = nn.Parameter(
            torch.FloatTensor([-1.0])[0], requires_grad=False
        )

    def _train_gan_generator_step(
        self, data: torch.Tensor, mask: torch.Tensor, optimizer: Optimizer
    ) -> torch.Tensor:
        """Performs a training step for the GAN generator.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.
            optimizer (Optimizer): Optimizer for the generator.

        Returns:
            torch.Tensor: Generator loss.
        """
        optimizer.zero_grad()
        noise = torch.randn(data.shape[0], self.noise_size)
        fake_hidden = self.generator(noise)
        errG = self.discriminator(fake_hidden)
        self.fabric.backward(errG, self.one)
        optimizer.step()
        return errG

    def _train_gan_discriminator_step(
        self, data: torch.Tensor, mask: torch.Tensor, optimizer: Optimizer
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Performs a training step for the GAN discriminator.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.
            optimizer (Optimizer): Optimizer for the discriminator.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Discriminator loss, real loss, and fake loss.
        """
        optimizer.zero_grad()
        _, _, encoder_output = self.model.encoder_part(data, mask)
        real_hidden = encoder_output.samples_z
        errD_real = self.discriminator(real_hidden.detach())
        self.fabric.backward(errD_real, self.one)

        noise = torch.randn(data.shape[0], self.noise_size)
        fake_hidden = self.generator(noise)
        errD_fake = self.discriminator(fake_hidden.detach())
        self.fabric.backward(errD_fake, self.mone)

        gradient_penalty = self._calc_gradient_penalty(
            real_hidden, fake_hidden
        )
        self.fabric.backward(gradient_penalty)
        optimizer.step()
        errD = -(errD_real - errD_fake)
        return errD, errD_real, errD_fake

    def _train_model_step(
        self, data: torch.Tensor, mask: torch.Tensor, optimizer: Optimizer
    ) -> HivaeOutput:
        """Performs a training step for the HIVAE model.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.
            optimizer (Optimizer): Optimizer for the model.

        Returns:
            HivaeOutput: Model output.
        """
        optimizer.zero_grad()
        output = self.model.forward(data, mask)
        loss = output.loss
        # if loss > 1e10:
        #     raise optuna.TrialPruned("Unsuited hyperparameters. High loss")
        self.fabric.backward(loss)
        optimizer.step()
        return output

    def _train_model_from_discriminator_step(
        self, data: torch.Tensor, mask: torch.Tensor, optimizer: Optimizer
    ) -> torch.Tensor:
        """Performs a training step for the model from the discriminator's perspective.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.
            optimizer (Optimizer): Optimizer for the model.

        Returns:
            torch.Tensor: Discriminator real loss.
        """
        optimizer.zero_grad()
        _, _, encoder_output = self.model.encoder_part(data, mask)
        real_hidden = encoder_output.samples_z
        errD_real = self.discriminator(real_hidden)
        self.fabric.backward(errD_real, self.mone)
        optimizer.step()
        return errD_real

    def _get_loss_from_output(self, output: HivaeOutput) -> torch.Tensor:
        """Extracts the loss from the model output.

        Args:
            output (HivaeOutput): Model output.

        Returns:
            torch.Tensor: Loss value.
        """
        return output.loss

    def _get_number_of_items(self, mask: torch.Tensor) -> int:
        """Gets the number of items in the mask.

        Args:
            mask (torch.Tensor): Data mask.

        Returns:
            int: Number of items.
        """
        return mask.sum()

    @cached_property
    def colnames(self) -> Tuple[str]:
        """Gets the column names of the data.

        Returns:
            Tuple[str]: Column names.
        """
        return self.model.colnames

    @property
    def decoding(self) -> bool:
        """bool: Flag indicating whether the model is in decoding mode."""
        return self.model.decoding

    @decoding.setter
    def decoding(self, value: bool) -> None:
        """Sets the decoding flag.

        Args:
            value (bool): Decoding flag.
        """
        self.model.decoding = value

    @property
    def tau(self) -> float:
        """float: Temperature parameter for Gumbel softmax."""
        return self.model.tau

    @tau.setter
    def tau(self, value: float) -> None:
        """Sets the temperature parameter for Gumbel softmax.

        Args:
            value (float): Temperature value.
        """
        self.model.tau = value

    @property
    def normalization_parameters(self) -> NormalizationParameters:
        """NormalizationParameters: Parameters for normalization."""
        return self.model.normalization_parameters

    @normalization_parameters.setter
    def normalization_parameters(self, value: NormalizationParameters) -> None:
        """Sets the normalization parameters.

        Args:
            value (NormalizationParameters): Normalization parameters.
        """
        self.model.normalization_parameters = value

    def forward(self, data: torch.Tensor, mask: torch.Tensor) -> HivaeOutput:
        """Forward pass of the model.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.

        Returns:
            HivaeOutput: Model output.
        """
        data, mask, encoder_output = self.encoder_part(data, mask)
        decoder_output = self.decoder_part(data, mask, encoder_output)
        return decoder_output

    def decoder_part(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
        encoder_output: EncoderOutput,
    ) -> HivaeOutput:
        """Performs the decoder part of the forward pass.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.
            encoder_output (EncoderOutput): Output from the encoder.

        Returns:
            HivaeOutput: Model output.
        """
        return self.model.decoder_part(data, mask, encoder_output)

    def encoder_part(
        self, data: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, EncoderOutput]:
        """Performs the encoder part of the forward pass.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Encoder output and modified inputs.
        """
        return self.model.encoder_part(data, mask)

    def decode(self, encoding: HivaeEncoding) -> torch.Tensor:
        """Decodes the given encoding.

        Args:
            encoding (HivaeEncoding): Encoding to decode.

        Returns:
            torch.Tensor: Decoded data.
        """
        return self.model.decode(encoding)

    def _training_step(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
        optimizer: Tuple[Optimizer],
    ) -> float:
        """Training step is not needed for this class."""
        raise Exception(f"Method not needed for {self.__class__.__name__}")

    def _validation_step(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
    ) -> float:
        """Performs a validation step.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.

        Returns:
            float: Validation loss.
        """
        return self.model._validation_step(data, mask)

    def _test_step(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
    ) -> float:
        """Performs a test step.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.

        Returns:
            float: Test loss.
        """
        return self.model._test_step(data, mask)

    def _predict_step(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
    ) -> HivaeOutput:
        """Performs a prediction step.

        Args:
            data (torch.Tensor): Input data.
            mask (torch.Tensor): Data mask.

        Returns:
            HivaeOutput: Prediction output.
        """
        return self.model._predict_step(data, mask)
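
A hedged construction sketch (the import path is assumed from the source reference above; variable_types, data and mask are placeholders for objects built elsewhere in the pipeline):

from vambn.modelling.models.hivae.gan_hivae import GanHivae  # path assumed

model = GanHivae(
    variable_types=variable_types,   # an existing VarTypes definition (placeholder)
    input_dim=10,
    dim_s=3,
    dim_z=5,
    dim_y=8,
    n_layers=2,
    num_timepoints=4,                # > 1 selects the LstmHivae backbone, otherwise Hivae
    mtl_method=("identity",),
    use_imputation_layer=True,
)
output = model(data, mask)           # data/mask: placeholder tensors of shape batch x timepoints x features
loss = output.loss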

colnames: Tuple[str] cached property

Gets the column names of the data.

Returns:

Type Description
Tuple[str]

Tuple[str]: Column names.

decoding: bool property writable

bool: Flag indicating whether the model is in decoding mode.

normalization_parameters: NormalizationParameters property writable

NormalizationParameters: Parameters for normalization.

tau: float property writable

float: Temperature parameter for Gumbel softmax.

decode(encoding)

Decodes the given encoding.

Parameters:

Name Type Description Default
encoding HivaeEncoding

Encoding to decode.

required

Returns:

Type Description
Tensor

torch.Tensor: Decoded data.

Source code in vambn/modelling/models/hivae/gan_hivae.py
def decode(self, encoding: HivaeEncoding) -> torch.Tensor:
    """Decodes the given encoding.

    Args:
        encoding (HivaeEncoding): Encoding to decode.

    Returns:
        torch.Tensor: Decoded data.
    """
    return self.model.decode(encoding)

decoder_part(data, mask, encoder_output)

Performs the decoder part of the forward pass.

Parameters:

Name Type Description Default
data Tensor

Input data.

required
mask Tensor

Data mask.

required
encoder_output EncoderOutput

Output from the encoder.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

Model output.

Source code in vambn/modelling/models/hivae/gan_hivae.py
def decoder_part(
    self,
    data: torch.Tensor,
    mask: torch.Tensor,
    encoder_output: EncoderOutput,
) -> HivaeOutput:
    """Performs the decoder part of the forward pass.

    Args:
        data (torch.Tensor): Input data.
        mask (torch.Tensor): Data mask.
        encoder_output (EncoderOutput): Output from the encoder.

    Returns:
        HivaeOutput: Model output.
    """
    return self.model.decoder_part(data, mask, encoder_output)

encoder_part(data, mask)

Performs the encoder part of the forward pass.

Parameters:

Name Type Description Default
data Tensor

Input data.

required
mask Tensor

Data mask.

required

Returns:

Type Description
Tuple[Tensor, Tensor, EncoderOutput]

Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Encoder output and modified inputs.

Source code in vambn/modelling/models/hivae/gan_hivae.py
def encoder_part(
    self, data: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, EncoderOutput]:
    """Performs the encoder part of the forward pass.

    Args:
        data (torch.Tensor): Input data.
        mask (torch.Tensor): Data mask.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Encoder output and modified inputs.
    """
    return self.model.encoder_part(data, mask)

forward(data, mask)

Forward pass of the model.

Parameters:

Name Type Description Default
data Tensor

Input data.

required
mask Tensor

Data mask.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

Model output.

Source code in vambn/modelling/models/hivae/gan_hivae.py
def forward(self, data: torch.Tensor, mask: torch.Tensor) -> HivaeOutput:
    """Forward pass of the model.

    Args:
        data (torch.Tensor): Input data.
        mask (torch.Tensor): Data mask.

    Returns:
        HivaeOutput: Model output.
    """
    data, mask, encoder_output = self.encoder_part(data, mask)
    decoder_output = self.decoder_part(data, mask, encoder_output)
    return decoder_output

GanModularHivae

Bases: AbstractGanModularModel[Tuple[Tensor, ...], Tuple[Tensor, ...], ModularHivaeOutput, ModularHivaeEncoding]

GAN-enhanced Modular HIVAE model.

Parameters:

Name Type Description Default
module_config Tuple[DataModuleConfig]

Configuration for the data modules.

required
dim_s int | Dict[str, int]

Dimension of s space.

required
dim_z int

Dimension of z space.

required
dim_ys int

Dimension of YS space.

required
dim_y int | Dict[str, int]

Dimension of y space.

required
noise_size int

Size of the noise vector. Defaults to 10.

10
shared_element_type str

Type of shared element. Defaults to "none".

'none'
mtl_method Tuple[str, ...]

Methods for multi-task learning. Defaults to ("identity",).

('identity',)
use_imputation_layer bool

Whether to use an imputation layer. Defaults to False.

False

Attributes:

Name Type Description
module_configs Tuple[DataModuleConfig]

Configuration for the data modules.

mtl_method Tuple[str, ...]

Methods for multi-task learning.

use_imputation_layer bool

Whether to use an imputation layer.

dim_s int | Dict[str, int]

Dimension of s space.

dim_z int

Dimension of z space.

dim_ys int

Dimension of YS space.

dim_y int | Dict[str, int]

Dimension of y space.

model ModularHivae

Modular HIVAE model.

generators ModuleList

List of GAN generators.

discriminators ModuleList

List of GAN discriminators.

device device

Device to run the model on.

one Parameter

Parameter for GAN training.

mone Parameter

Parameter for GAN training.

noise_size int

Size of the noise vector.

Source code in vambn/modelling/models/hivae/gan_hivae.py
class GanModularHivae(
    AbstractGanModularModel[
        Tuple[Tensor, ...],
        Tuple[Tensor, ...],
        ModularHivaeOutput,
        ModularHivaeEncoding,
    ]
):
    """GAN-enhanced Modular HIVAE model.

    Args:
        module_config (Tuple[DataModuleConfig]): Configuration for the data modules.
        dim_s (int | Dict[str, int]): Dimension of s space.
        dim_z (int): Dimension of z space.
        dim_ys (int): Dimension of YS space.
        dim_y (int | Dict[str, int]): Dimension of y space.
        noise_size (int, optional): Size of the noise vector. Defaults to 10.
        shared_element_type (str, optional): Type of shared element. Defaults to "none".
        mtl_method (Tuple[str, ...], optional): Methods for multi-task learning. Defaults to ("identity",).
        use_imputation_layer (bool, optional): Whether to use an imputation layer. Defaults to False.

    Attributes:
        module_configs (Tuple[DataModuleConfig]): Configuration for the data modules.
        mtl_method (Tuple[str, ...]): Methods for multi-task learning.
        use_imputation_layer (bool): Whether to use an imputation layer.
        dim_s (int | Dict[str, int]): Dimension of s space.
        dim_z (int): Dimension of z space.
        dim_ys (int): Dimension of YS space.
        dim_y (int | Dict[str, int]): Dimension of y space.
        model (ModularHivae): Modular HIVAE model.
        generators (nn.ModuleList): List of GAN generators.
        discriminators (nn.ModuleList): List of GAN discriminators.
        device (torch.device): Device to run the model on.
        one (nn.Parameter): Parameter for GAN training.
        mone (nn.Parameter): Parameter for GAN training.
        noise_size (int): Size of the noise vector.
    """

    def __init__(
        self,
        module_config: Tuple[DataModuleConfig],
        dim_s: int | Dict[str, int],
        dim_z: int,
        dim_ys: int,
        dim_y: int | Dict[str, int],
        noise_size: int = 10,
        shared_element_type: str = "none",
        mtl_method: Tuple[str, ...] = ("identity",),
        use_imputation_layer: bool = False,
    ):
        super().__init__()
        self.module_configs = module_config
        self.mtl_method = mtl_method
        self.use_imputation_layer = use_imputation_layer
        self.dim_s = dim_s
        self.dim_z = dim_z
        self.dim_ys = dim_ys
        self.dim_y = dim_y

        self.model = ModularHivae(
            module_config=module_config,
            dim_s=dim_s,
            dim_z=dim_z,
            dim_ys=dim_ys,
            dim_y=dim_y,
            shared_element_type=shared_element_type,
            mtl_method=mtl_method,
            use_imputation_layer=use_imputation_layer,
        )

        self.generators = nn.ModuleList(
            [
                Generator(noise_size, (8, 4), dim_z)
                for _ in range(len(module_config))
            ]
        )
        self.discriminators = nn.ModuleList(
            [Discriminator(dim_z, (8, 4), 1) for _ in range(len(module_config))]
        )

        self.device = torch.device("cpu")
        self.one = nn.Parameter(
            torch.FloatTensor([1.0])[0], requires_grad=False
        )
        self.mone = nn.Parameter(
            torch.FloatTensor([-1.0])[0], requires_grad=False
        )
        self.noise_size = noise_size

    def _train_gan_generator_step(
        self,
        data: Tuple[Tensor],
        mask: Tuple[Tensor],
        optimizer: Tuple[Optimizer],
    ) -> Tuple[Tensor]:
        """Performs a training step for the GAN generators.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.
            optimizer (Tuple[Optimizer]): Optimizers for the generators.

        Returns:
            Tuple[Tensor]: Generator losses.
        """
        errG_list = []
        for data, mask, generator, discriminator, opt in zip(
            data, mask, self.generators, self.discriminators, optimizer
        ):
            opt.zero_grad()
            noise = torch.randn(data.shape[0], self.noise_size)
            fake_hidden = generator(noise)
            errG = discriminator(fake_hidden)
            self.fabric.backward(errG, self.one)
            opt.step()
            errG_list.append(errG.detach())
        return tuple(errG_list)

    def _train_gan_discriminator_step(
        self,
        data: Tuple[Tensor],
        mask: Tuple[Tensor],
        optimizer: Tuple[Optimizer],
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Performs a training step for the GAN discriminators.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.
            optimizer (Tuple[Optimizer]): Optimizers for the discriminators.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: Discriminator loss, real loss, and fake loss.
        """
        errD_list = []
        errD_real_list = []
        errD_fake_list = []
        for data, mask, generator, discriminator, opt, module in zip(
            data,
            mask,
            self.generators,
            self.discriminators,
            optimizer,
            self.model.module_models.values(),
        ):
            opt.zero_grad()
            _, _, encoder_output = module.encoder_part(data, mask)
            real_hidden = encoder_output.samples_z
            errD_real = discriminator(real_hidden.detach())
            self.fabric.backward(errD_real, self.one)

            noise = torch.randn((data.shape[0], self.noise_size))
            fake_hidden = generator(noise)
            errD_fake = discriminator(fake_hidden.detach())
            self.fabric.backward(errD_fake, self.mone)

            gradient_penalty = self._calc_gradient_penalty(
                real_hidden, fake_hidden, discriminator
            )
            self.fabric.backward(gradient_penalty)
            opt.step()
            errD = -(errD_real - errD_fake)
            errD_list.append(errD.detach())
            errD_real_list.append(errD_real.detach())
            errD_fake_list.append(errD_fake.detach())
        return tuple(errD_list), tuple(errD_real_list), tuple(errD_fake_list)

    def _train_model_step(
        self,
        data: Tuple[Tensor],
        mask: Tuple[Tensor],
        optimizer: Tuple[Optimizer],
    ) -> ModularHivaeOutput:
        """Performs a training step for the Modular HIVAE model.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.
            optimizer (Tuple[Optimizer]): Optimizers for the model.

        Returns:
            ModularHivaeOutput: Model output.
        """
        for opt in optimizer:
            if opt is None:
                continue
            opt.zero_grad()
        output = self.forward(data, mask)
        self.fabric.backward(output.loss)
        # if output.loss > 1e10:
        #     raise optuna.TrialPruned("Unsuited hyperparameters. High loss")
        for opt in optimizer:
            if opt is None:
                continue
            opt.step()
        return output

    def _train_model_from_discriminator_step(
        self,
        data: Tuple[Tensor],
        mask: Tuple[Tensor],
        optimizer: Tuple[Optimizer],
    ) -> Tuple[Tensor]:
        """Performs a training step for the model from the discriminator's perspective.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.
            optimizer (Tuple[Optimizer]): Optimizers for the model.

        Returns:
            Tuple[Tensor]: Discriminator real losses.
        """
        errD_real_list = []
        for data, mask, discriminator, opt, module in zip(
            data,
            mask,
            self.discriminators,
            optimizer,
            self.model.module_models.values(),
        ):
            opt.zero_grad()
            _, _, encoder_output = module.encoder_part(data, mask)
            real_hidden = encoder_output.samples_z
            errD_real = discriminator(real_hidden)
            self.fabric.backward(errD_real, self.mone)
            opt.step()
            errD_real_list.append(errD_real.detach())
        return tuple(errD_real_list)

    def _get_loss_from_output(self, output: ModularHivaeOutput) -> float:
        """Extracts the loss from the model output.

        Args:
            output (ModularHivaeOutput): Model output.

        Returns:
            float: Loss value.
        """
        return output.avg_loss

    def _get_number_of_items(self, mask: Tuple[Tensor]) -> int:
        """Gets the number of items in the mask.

        Args:
            mask (Tuple[Tensor]): Data mask.

        Returns:
            int: Number of items.
        """
        return sum([m.sum() for m in mask])

    @cached_property
    def colnames(self, module_name: str) -> Tuple[str, ...]:
        """Gets the column names for a specific module.

        Args:
            module_name (str): Name of the module.

        Returns:
            Tuple[str, ...]: Column names.
        """
        return self.model.module_models[module_name].colnames

    @property
    def decoding(self) -> bool:
        """bool: Flag indicating whether the model is in decoding mode."""
        return self.model.decoding

    @decoding.setter
    def decoding(self, value: bool) -> None:
        """Sets the decoding flag.

        Args:
            value (bool): Decoding flag.
        """
        self.model.decoding = value

    @property
    def tau(self) -> float:
        """float: Temperature parameter for Gumbel softmax."""
        return self.model.tau

    @tau.setter
    def tau(self, value: float) -> None:
        """Sets the temperature parameter for Gumbel softmax.

        Args:
            value (float): Temperature value.
        """
        self.model.tau = value

    def forward(
        self, data: Tuple[Tensor], mask: Tuple[Tensor]
    ) -> ModularHivaeOutput:
        """Forward pass of the model.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.

        Returns:
            ModularHivaeOutput: Model output.
        """
        return self.model.forward(data, mask)

    def decode(self, encoding: ModularHivaeEncoding) -> Tuple[Tensor, ...]:
        """Decodes the given encoding.

        Args:
            encoding (ModularHivaeEncoding): Encoding to decode.

        Returns:
            Tuple[Tensor, ...]: Decoded data.
        """
        return self.model.decode(encoding)

    def _training_step(
        self,
        data: Tuple[Tensor],
        mask: Tuple[Tensor],
        optimizer: Tuple[Optimizer],
    ) -> float:
        """Training step is not needed for this class."""
        raise Exception(f"Method not needed for {self.__class__.__name__}")

    def _validation_step(
        self, data: Tuple[Tensor], mask: Tuple[Tensor]
    ) -> float:
        """Performs a validation step.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.

        Returns:
            float: Validation loss.
        """
        return self.model._validation_step(data, mask)

    def _test_step(self, data: Tuple[Tensor], mask: Tuple[Tensor]) -> float:
        """Performs a test step.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.

        Returns:
            float: Test loss.
        """
        return self.model._test_step(data, mask)

    def _predict_step(
        self, data: Tuple[Tensor], mask: Tuple[Tensor]
    ) -> ModularHivaeOutput:
        """Performs a prediction step.

        Args:
            data (Tuple[Tensor]): Input data.
            mask (Tuple[Tensor]): Data mask.

        Returns:
            ModularHivaeOutput: Prediction output.
        """
        return self.model._predict_step(data, mask)
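
For orientation, the generator and discriminator steps above (and their single-module counterparts in GanHivae) follow a Wasserstein-GAN-style pattern with a gradient penalty on the critic: the real latent codes are the encoder's samples_z, the fake codes are generated from Gaussian noise, and the reported losses read directly off the code (with the penalty weight folded into _calc_gradient_penalty):

$$
\mathrm{errD} \;=\; -\big(D(z_{\mathrm{real}}) - D(z_{\mathrm{fake}})\big), \qquad
\mathrm{errG} \;=\; D\big(G(\epsilon)\big), \quad \epsilon \sim \mathcal{N}(0, I),
$$

where $D$ is the discriminator (critic), $G$ the generator, $z_{\mathrm{real}}$ the detached samples_z from the encoder, and $z_{\mathrm{fake}} = G(\epsilon)$; the gradient penalty is backpropagated separately before the optimizer step.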

colnames: Tuple[str, ...] cached property

Gets the column names for a specific module.

Parameters:

Name Type Description Default
module_name str

Name of the module.

required

Returns:

Type Description
Tuple[str, ...]

Tuple[str, ...]: Column names.

decoding: bool property writable

bool: Flag indicating whether the model is in decoding mode.

tau: float property writable

float: Temperature parameter for Gumbel softmax.

decode(encoding)

Decodes the given encoding.

Parameters:

Name Type Description Default
encoding ModularHivaeEncoding

Encoding to decode.

required

Returns:

Type Description
Tuple[Tensor, ...]

Tuple[Tensor, ...]: Decoded data.

Source code in vambn/modelling/models/hivae/gan_hivae.py
def decode(self, encoding: ModularHivaeEncoding) -> Tuple[Tensor, ...]:
    """Decodes the given encoding.

    Args:
        encoding (ModularHivaeEncoding): Encoding to decode.

    Returns:
        Tuple[Tensor, ...]: Decoded data.
    """
    return self.model.decode(encoding)

forward(data, mask)

Forward pass of the model.

Parameters:

Name Type Description Default
data Tuple[Tensor]

Input data.

required
mask Tuple[Tensor]

Data mask.

required

Returns:

Name Type Description
ModularHivaeOutput ModularHivaeOutput

Model output.

Source code in vambn/modelling/models/hivae/gan_hivae.py
def forward(
    self, data: Tuple[Tensor], mask: Tuple[Tensor]
) -> ModularHivaeOutput:
    """Forward pass of the model.

    Args:
        data (Tuple[Tensor]): Input data.
        mask (Tuple[Tensor]): Data mask.

    Returns:
        ModularHivaeOutput: Model output.
    """
    return self.model.forward(data, mask)

heads

BaseModuleHead

Bases: Generic[ParameterType, DistributionType], Module, ABC

Base class for different data types.

Parameters:

Name Type Description Default
variable_type VariableType

Dataclass containing the type information.

required
dim_s int

Dimension of s space.

required
dim_z int

Dimension of z space.

required
dim_y int

Dimension of y space.

required

Attributes:

Name Type Description
types VariableType

Dataclass containing the type information.

dim_s int

Dimension of s space.

dim_z int

Dimension of z space.

dim_y int

Dimension of y space.

internal_pass ModifiedLinear

Internal pass module.

Properties

num_parameters (int): Number of parameters.

Methods:

Name Description
forward(samples_z: Tensor, samples_s: Tensor) -> ParameterType

Forward pass of the module.

dist(params: ParameterType) -> DistributionType

Compute the distribution given the parameters.

log_prob(data: Tensor, params: Optional[ParameterType] = None) -> Tensor

Compute the log probability of the data given the parameters.

sample() -> Tensor

Sample from the distribution.

rsample() -> Tensor

Sample using the reparameterization trick.

mode() -> Tensor

Compute the mode of the distribution.

Raises:

Type Description
RuntimeError

If the distribution is not initialized.

Source code in vambn/modelling/models/hivae/heads.py
class BaseModuleHead(
    Generic[ParameterType, DistributionType], nn.Module, ABC
):
    """Base class for different data types.

    Args:
        variable_type (VariableType): Dataclass containing the type information.
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        dim_y (int): Dimension of y space.

    Attributes:
        types (VariableType): Dataclass containing the type information.
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        dim_y (int): Dimension of y space.
        internal_pass (ModifiedLinear): Internal pass module.

    Properties:
        num_parameters (int): Number of parameters.

    Methods:
        forward(samples_z: Tensor, samples_s: Tensor) -> ParameterType: Forward pass of the module.
        dist(params: ParameterType) -> DistributionType: Compute the distribution given the parameters.
        log_prob(data: Tensor, params: Optional[ParameterType] = None) -> Tensor: Compute the log probability of the data given the parameters.
        sample() -> Tensor: Sample from the distribution.
        rsample() -> Tensor: Sample using the reparameterization trick.
        mode() -> Tensor: Compute the mode of the distribution.

    Raises:
        RuntimeError: If the distribution is not initialized.

    """

    def __init__(
        self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
    ) -> None:
        """Base class for the different datatypes

        Args:
            types (VariableType): Array containing the data type information (type, class, ndim)
            dim_s (int): Dimension of s space
            dim_z (int): Dimension of z space
            dim_y (int): Dimension of y space
        """
        super().__init__()
        # General parameters
        self.types = variable_type
        self.dim_s = dim_s
        self.dim_z = dim_z
        self.dim_y = dim_y
        self.internal_pass = ModifiedLinear(self.dim_z, self.dim_y, bias=True)

        # Specific parameters
        self._dist = None
        self._n_pars = None

    @abstractmethod
    def forward(self, samples_z: Tensor, samples_s: Tensor) -> ParameterType:
        """Forward pass of the module.

        Args:
            samples_z (Tensor): Samples from the z space.
            samples_s (Tensor): Samples from the s space.

        Returns:
            ParameterType: The output of the forward pass.

        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError

    @abstractmethod
    def dist(self, params: ParameterType) -> DistributionType:
        """Compute the distribution given the parameters.

        Args:
            params (ParameterType): The parameters of the distribution.

        Returns:
            DistributionType: The computed distribution.

        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError

    def log_prob(
        self, data: Tensor, params: Optional[ParameterType] = None
    ) -> Tensor:
        """Compute the log probability of the data given the parameters.

        Args:
            data (Tensor): The input data.
            params (Optional[ParameterType], optional): The parameters of the distribution. Defaults to None.

        Returns:
            Tensor: The log probability of the data.

        Raises:
            RuntimeError: If the distribution is not initialized.
        """
        if self._dist is None and params is None:
            raise RuntimeError("Distribution is not initialized.")
        elif params is not None:
            self.dist(params)

        return self._dist.log_prob(data)

    def sample(self) -> Tensor:
        """Sample from the distribution.

        Returns:
            Tensor: The sampled data.

        Raises:
            RuntimeError: If the distribution is not initialized.
        """
        if self._dist is None:
            raise RuntimeError("Distribution is not initialized.")

        gen_sample = self._dist.sample()
        if gen_sample.ndim == 1:
            return gen_sample.unsqueeze(1)
        elif gen_sample.ndim == 0:
            # ensure 2 dim output
            return gen_sample.unsqueeze(0).unsqueeze(1)
        return gen_sample

    def rsample(self) -> Tensor:
        """Sample using the reparameterization trick.

        Returns:
            Tensor: The sampled data.

        Raises:
            RuntimeError: If the distribution is not initialized.
        """
        if self._dist is None:
            raise RuntimeError("Distribution is not initialized.")

        gen_sample = self._dist.rsample()
        if gen_sample.ndim == 1:
            return gen_sample.unsqueeze(1)
        return gen_sample

    @property
    def mode(self) -> Tensor:
        """Compute the mode of the distribution.

        Returns:
            Tensor: The mode of the distribution.

        Raises:
            RuntimeError: If the distribution is not initialized.
        """
        if self._dist is None:
            raise RuntimeError("Distribution is not initialized.")
        res = self._dist.mode
        if res.ndim == 1:
            return res.unsqueeze(1)
        return res

    @property
    def num_parameters(self) -> int:
        """Number of parameters.

        Returns:
            int: The number of parameters.
        """
        return self._n_pars
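
To make the contract above concrete, here is a minimal, self-contained sketch of a head that follows the same forward -> dist -> sample/log_prob flow. SketchBernoulliHead, BernoulliParameters and the plain nn.Linear layers are illustrative stand-ins only and are not part of the package.

# Hypothetical minimal head mirroring the BaseModuleHead pattern.
from collections import namedtuple

import torch
import torch.distributions as dists
import torch.nn as nn

BernoulliParameters = namedtuple("BernoulliParameters", ["logits"])


class SketchBernoulliHead(nn.Module):
    """Follows the forward -> dist -> sample/log_prob flow of BaseModuleHead."""

    def __init__(self, dim_s: int, dim_z: int, dim_y: int) -> None:
        super().__init__()
        self.internal_pass = nn.Linear(dim_z, dim_y, bias=True)  # stands in for ModifiedLinear
        self.logit_layer = nn.Linear(dim_y + dim_s, 1, bias=False)
        self._dist = None

    def forward(self, samples_z: torch.Tensor, samples_s: torch.Tensor) -> BernoulliParameters:
        y = self.internal_pass(samples_z)
        return BernoulliParameters(self.logit_layer(torch.cat([y, samples_s], dim=-1)))

    def dist(self, params: BernoulliParameters) -> dists.Bernoulli:
        self._dist = dists.Bernoulli(logits=params.logits.squeeze(-1))
        return self._dist


head = SketchBernoulliHead(dim_s=3, dim_z=4, dim_y=5)
params = head(torch.randn(8, 4), torch.randn(8, 3))  # z has dim 4, s has dim 3
bernoulli = head.dist(params)                        # builds and caches the distribution
print(bernoulli.sample().shape)                      # torch.Size([8])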

mode: Tensor property

Compute the mode of the distribution.

Returns:

Name Type Description
Tensor Tensor

The mode of the distribution.

Raises:

Type Description
RuntimeError

If the distribution is not initialized.

num_parameters: int property

Number of parameters.

Returns:

Name Type Description
int int

The number of parameters.

__init__(variable_type, dim_s, dim_z, dim_y)

Base class for the different data types.

Parameters:

Name Type Description Default
variable_type VariableType

Array containing the data type information (type, class, ndim)

required
dim_s int

Dimension of s space

required
dim_z int

Dimension of z space

required
dim_y int

Dimension of y space

required
Source code in vambn/modelling/models/hivae/heads.py (lines 62-83)
def __init__(
    self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
) -> None:
    """Base class for the different datatypes

    Args:
        types (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space
    """
    super().__init__()
    # General parameters
    self.types = variable_type
    self.dim_s = dim_s
    self.dim_z = dim_z
    self.dim_y = dim_y
    self.internal_pass = ModifiedLinear(self.dim_z, self.dim_y, bias=True)

    # Specific parameters
    self._dist = None
    self._n_pars = None

dist(params) abstractmethod

Compute the distribution given the parameters.

Parameters:

Name Type Description Default
params ParameterType

The parameters of the distribution.

required

Returns:

Name Type Description
DistributionType DistributionType

The computed distribution.

Raises:

Type Description
NotImplementedError

If the method is not implemented.

Source code in vambn/modelling/models/hivae/heads.py (lines 101-114)
@abstractmethod
def dist(self, params: ParameterType) -> DistributionType:
    """Compute the distribution given the parameters.

    Args:
        params (ParameterType): The parameters of the distribution.

    Returns:
        DistributionType: The computed distribution.

    Raises:
        NotImplementedError: If the method is not implemented.
    """
    raise NotImplementedError

forward(samples_z, samples_s) abstractmethod

Forward pass of the module.

Parameters:

Name Type Description Default
samples_z Tensor

Samples from the z space.

required
samples_s Tensor

Samples from the s space.

required

Returns:

Name Type Description
ParameterType ParameterType

The output of the forward pass.

Raises:

Type Description
NotImplementedError

If the method is not implemented.

Source code in vambn/modelling/models/hivae/heads.py (lines 85-99)
@abstractmethod
def forward(self, samples_z: Tensor, samples_s: Tensor) -> ParameterType:
    """Forward pass of the module.

    Args:
        samples_z (Tensor): Samples from the z space.
        samples_s (Tensor): Samples from the s space.

    Returns:
        ParameterType: The output of the forward pass.

    Raises:
        NotImplementedError: If the method is not implemented.
    """
    raise NotImplementedError

log_prob(data, params=None)

Compute the log probability of the data given the parameters.

Parameters:

Name Type Description Default
data Tensor

The input data.

required
params Optional[ParameterType]

The parameters of the distribution. Defaults to None.

None

Returns:

Name Type Description
Tensor Tensor

The log probability of the data.

Raises:

Type Description
RuntimeError

If the distribution is not initialized.

Source code in vambn/modelling/models/hivae/heads.py (lines 116-136)
def log_prob(
    self, data: Tensor, params: Optional[ParameterType] = None
) -> Tensor:
    """Compute the log probability of the data given the parameters.

    Args:
        data (Tensor): The input data.
        params (Optional[ParameterType], optional): The parameters of the distribution. Defaults to None.

    Returns:
        Tensor: The log probability of the data.

    Raises:
        RuntimeError: If the distribution is not initialized.
    """
    if self._dist is None and params is None:
        raise RuntimeError("Distribution is not initialized.")
    elif params is not None:
        self.dist(params)

    return self._dist.log_prob(data)
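
log_prob therefore works in two modes: it can rebuild the distribution from freshly passed parameters, or it can reuse the distribution cached by an earlier dist() call. The toy, self-contained stand-in below (CachedNormal is not part of the package) mirrors exactly this caching logic.

import torch
import torch.distributions as dists


class CachedNormal:
    """Toy stand-in for the dist()/log_prob() caching pattern shown above."""

    def __init__(self) -> None:
        self._dist = None

    def dist(self, params):  # params = (loc, scale)
        self._dist = dists.Normal(*params)
        return self._dist

    def log_prob(self, data, params=None):
        if self._dist is None and params is None:
            raise RuntimeError("Distribution is not initialized.")
        elif params is not None:
            self.dist(params)
        return self._dist.log_prob(data)


head = CachedNormal()
data = torch.randn(4)
print(head.log_prob(data, params=(torch.zeros(4), torch.ones(4))))  # builds, then scores
print(head.log_prob(data))                                          # reuses the cached Normal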

rsample()

Sample using the reparameterization trick.

Returns:

Name Type Description
Tensor Tensor

The sampled data.

Raises:

Type Description
RuntimeError

If the distribution is not initialized.

Source code in vambn/modelling/models/hivae/heads.py (lines 158-173)
def rsample(self) -> Tensor:
    """Sample using the reparameterization trick.

    Returns:
        Tensor: The sampled data.

    Raises:
        RuntimeError: If the distribution is not initialized.
    """
    if self._dist is None:
        raise RuntimeError("Distribution is not initialized.")

    gen_sample = self._dist.rsample()
    if gen_sample.ndim == 1:
        return gen_sample.unsqueeze(1)
    return gen_sample

sample()

Sample from the distribution.

Returns:

Name Type Description
Tensor Tensor

The sampled data.

Raises:

Type Description
RuntimeError

If the distribution is not initialized.

Source code in vambn/modelling/models/hivae/heads.py (lines 138-156)
def sample(self) -> Tensor:
    """Sample from the distribution.

    Returns:
        Tensor: The sampled data.

    Raises:
        RuntimeError: If the distribution is not initialized.
    """
    if self._dist is None:
        raise RuntimeError("Distribution is not initialized.")

    gen_sample = self._dist.sample()
    if gen_sample.ndim == 1:
        return gen_sample.unsqueeze(1)
    elif gen_sample.ndim == 0:
        # ensure 2 dim output
        return gen_sample.unsqueeze(0).unsqueeze(1)
    return gen_sample

CatHead

Bases: BaseModuleHead[CategoricalParameters, ReparameterizedCategorical]

Class representing the categorical head of a model.

Attributes:

Name Type Description
variable_type VariableType

Array containing the data type information (type, class, ndim)

dim_s int

Dimension of s space

dim_z int

Dimension of z space

dim_y int

Dimension of y space

logit_layer ModifiedLinear

Linear layer for computing logits

_dist_class ReparameterizedCategorical

Class for representing reparameterized categorical distribution

Source code in vambn/modelling/models/hivae/heads.py (lines 505-564)
class CatHead(
    BaseModuleHead[
        CategoricalParameters, ReparameterizedCategorical
    ]
):
    """Class representing the categorical head of a model.

    Attributes:
        variable_type (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space
        logit_layer (ModifiedLinear): Linear layer for computing logits
        _dist_class (ReparameterizedCategorical): Class for representing reparameterized categorical distribution
    """

    # @typechecked
    def __init__(
        self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
    ) -> None:
        """Initialize the CatHead class.

        Args:
            variable_type (VariableType): Array containing the data type information (type, class, ndim)
            dim_s (int): Dimension of s space
            dim_z (int): Dimension of z space
            dim_y (int): Dimension of y space
        """
        super().__init__(variable_type, dim_s, dim_z, dim_y)
        self.logit_layer = ModifiedLinear(
            self.dim_y + self.dim_s, variable_type.n_parameters, bias=False
        )
        self._dist_class = ReparameterizedCategorical

    def forward(self, samples_z: Tensor, samples_s: Tensor) -> CategoricalParameters:
        """Forward pass of the CatHead.

        Args:
            samples_z (Tensor): Samples from the z space
            samples_s (Tensor): Samples from the s space

        Returns:
            CategoricalParameters: Categorical parameters
        """
        y = self.internal_pass(samples_z)
        s_and_y = torch.cat([y, samples_s], dim=-1)
        logits = self.logit_layer(s_and_y)
        return CategoricalParameters(logits)

    def dist(self, params: CategoricalParameters) -> ReparameterizedCategorical:
        """Compute the reparameterized categorical distribution.

        Args:
            params (CategoricalParameters): Categorical parameters

        Returns:
            ReparameterizedCategorical: Reparameterized categorical distribution
        """
        self._dist = self._dist_class(logits=params.logits)
        return self._dist
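
The tensor shapes in CatHead.forward can be traced with plain torch operations. The sketch below mirrors the layer sizes used above; the dimensions are arbitrary and nn.Linear stands in for ModifiedLinear.

import torch
import torch.nn as nn

batch, dim_s, dim_z, dim_y, n_categories = 16, 2, 3, 4, 5

internal_pass = nn.Linear(dim_z, dim_y, bias=True)       # ModifiedLinear stand-in
logit_layer = nn.Linear(dim_y + dim_s, n_categories, bias=False)

samples_z = torch.randn(batch, dim_z)
samples_s = torch.randn(batch, dim_s)

y = internal_pass(samples_z)                              # (batch, dim_y)
logits = logit_layer(torch.cat([y, samples_s], dim=-1))   # (batch, n_categories)
probs = torch.softmax(logits, dim=-1)                     # one probability per category
print(logits.shape, probs.sum(dim=-1))                    # torch.Size([16, 5]), all ones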

__init__(variable_type, dim_s, dim_z, dim_y)

Initialize the CatHead class.

Parameters:

Name Type Description Default
variable_type VariableType

Array containing the data type information (type, class, ndim)

required
dim_s int

Dimension of s space

required
dim_z int

Dimension of z space

required
dim_y int

Dimension of y space

required
Source code in vambn/modelling/models/hivae/heads.py (lines 522-537)
def __init__(
    self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
) -> None:
    """Initialize the CatHead class.

    Args:
        variable_type (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space
    """
    super().__init__(variable_type, dim_s, dim_z, dim_y)
    self.logit_layer = ModifiedLinear(
        self.dim_y + self.dim_s, variable_type.n_parameters, bias=False
    )
    self._dist_class = ReparameterizedCategorical

dist(params)

Compute the reparameterized categorical distribution.

Parameters:

Name Type Description Default
params CategoricalParameters

Categorical parameters

required

Returns:

Name Type Description
ReparameterizedCategorical ReparameterizedCategorical

Reparameterized categorical distribution

Source code in vambn/modelling/models/hivae/heads.py (lines 554-564)
def dist(self, params: CategoricalParameters) -> ReparameterizedCategorical:
    """Compute the reparameterized categorical distribution.

    Args:
        params (CategoricalParameters): Categorical parameters

    Returns:
        ReparameterizedCategorical: Reparameterized categorical distribution
    """
    self._dist = self._dist_class(logits=params.logits)
    return self._dist

forward(samples_z, samples_s)

Forward pass of the CatHead.

Parameters:

Name Type Description Default
samples_z Tensor

Samples from the z space

required
samples_s Tensor

Samples from the s space

required

Returns:

Name Type Description
CategoricalParameters CategoricalParameters

Categorical parameters

Source code in vambn/modelling/models/hivae/heads.py (lines 539-552)
def forward(self, samples_z: Tensor, samples_s: Tensor) -> CategoricalParameters:
    """Forward pass of the CatHead.

    Args:
        samples_z (Tensor): Samples from the z space
        samples_s (Tensor): Samples from the s space

    Returns:
        CategoricalParameters: Categorical parameters
    """
    y = self.internal_pass(samples_z)
    s_and_y = torch.cat([y, samples_s], dim=-1)
    logits = self.logit_layer(s_and_y)
    return CategoricalParameters(logits)

CountHead

Bases: BaseModuleHead[PoissonParameters, Poisson]

Head module for the Poisson distribution (Count data).

Parameters:

Name Type Description Default
variable_type VariableType

Array containing the data type information (type, class, ndim)

required
dim_s int

Dimension of s space

required
dim_z int

Dimension of z space

required
dim_y int

Dimension of y space

required

Attributes:

Name Type Description
lambda_layer ModifiedLinear

Linear layer for computing the rate parameter

_dist_class Poisson

Class for representing the Poisson distribution

Source code in vambn/modelling/models/hivae/heads.py (lines 445-502)
class CountHead(BaseModuleHead[PoissonParameters, dists.Poisson]):
    """Head module for the Poisson distribution (Count data).

    Args:
        variable_type (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space

    Attributes:
        lambda_layer (ModifiedLinear): Linear layer for computing the rate parameter
        _dist_class (dists.Poisson): Class for representing the Poisson distribution

    """

    def __init__(
        self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
    ) -> None:
        """Initializes the CountHead class.

        Args:
            variable_type (VariableType): Array containing the data type information (type, class, ndim)
            dim_s (int): Dimension of s space
            dim_z (int): Dimension of z space
            dim_y (int): Dimension of y space
        """
        super().__init__(variable_type, dim_s, dim_z, dim_y)
        self.lambda_layer = ModifiedLinear(
            self.dim_y + self.dim_s, 1, bias=False
        )
        self._dist_class = dists.Poisson

    def forward(self, samples_z: Tensor, samples_s: Tensor) -> PoissonParameters:
        """Performs the forward pass of the CountHead.

        Args:
            samples_z (Tensor): Samples from the z space
            samples_s (Tensor): Samples from the s space

        Returns:
            PoissonParameters: The Poisson parameter
        """
        y = self.internal_pass(samples_z)
        s_and_y = torch.cat([y, samples_s], dim=-1)
        rate = self.lambda_layer(s_and_y)
        return PoissonParameters(rate)

    def dist(self, params: PoissonParameters) -> dists.Poisson:
        """Creates a Poisson distribution from the given parameters.

        Args:
            params (PoissonParameters): The Poisson parameters

        Returns:
            dists.Poisson: The Poisson distribution
        """
        self._dist = self._dist_class(params.rate.squeeze())
        return self._dist
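
A self-contained sketch of the count likelihood this head parameterises: a single rate per sample feeds torch.distributions.Poisson. The softplus below is used purely for illustration, to guarantee a positive rate; how PoissonParameters constrains the raw linear output is defined elsewhere in the package.

import torch
import torch.distributions as dists
import torch.nn.functional as F

raw_rate = torch.randn(16, 1)              # output of a lambda_layer-like linear map
rate = F.softplus(raw_rate).squeeze(-1)    # Poisson rates must be positive
poisson = dists.Poisson(rate)

counts = poisson.sample()                  # non-negative integer-valued samples, shape (16,)
log_lik = poisson.log_prob(counts)         # element-wise log likelihood
print(counts.min() >= 0, log_lik.shape)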

__init__(variable_type, dim_s, dim_z, dim_y)

Initializes the CountHead class.

Parameters:

Name Type Description Default
variable_type VariableType

Array containing the data type information (type, class, ndim)

required
dim_s int

Dimension of s space

required
dim_z int

Dimension of z space

required
dim_y int

Dimension of y space

required
Source code in vambn/modelling/models/hivae/heads.py (lines 460-475)
def __init__(
    self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
) -> None:
    """Initializes the CountHead class.

    Args:
        variable_type (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space
    """
    super().__init__(variable_type, dim_s, dim_z, dim_y)
    self.lambda_layer = ModifiedLinear(
        self.dim_y + self.dim_s, 1, bias=False
    )
    self._dist_class = dists.Poisson

dist(params)

Creates a Poisson distribution from the given parameters.

Parameters:

Name Type Description Default
params PoissonParameters

The Poisson parameters

required

Returns:

Type Description
Poisson

dists.Poisson: The Poisson distribution

Source code in vambn/modelling/models/hivae/heads.py (lines 492-502)
def dist(self, params: PoissonParameters) -> dists.Poisson:
    """Creates a Poisson distribution from the given parameters.

    Args:
        params (PoissonParameters): The Poisson parameters

    Returns:
        dists.Poisson: The Poisson distribution
    """
    self._dist = self._dist_class(params.rate.squeeze())
    return self._dist

forward(samples_z, samples_s)

Performs the forward pass of the CountHead.

Parameters:

Name Type Description Default
samples_z Tensor

Samples from the z space

required
samples_s Tensor

Samples from the s space

required

Returns:

Name Type Description
PoissonParameters PoissonParameters

The Poisson parameter

Source code in vambn/modelling/models/hivae/heads.py (lines 477-490)
def forward(self, samples_z: Tensor, samples_s: Tensor) -> PoissonParameters:
    """Performs the forward pass of the CountHead.

    Args:
        samples_z (Tensor): Samples from the z space
        samples_s (Tensor): Samples from the s space

    Returns:
        PoissonParameters: The Poisson parameter
    """
    y = self.internal_pass(samples_z)
    s_and_y = torch.cat([y, samples_s], dim=-1)
    rate = self.lambda_layer(s_and_y)
    return PoissonParameters(rate)

PosHead

Bases: BaseModuleHead[LogNormalParameters, LogNormal]

Head module for the LogNormal (pos) distribution

Attributes:

Name Type Description
variable_type VariableType

The type of variable.

dim_s int

The dimension of s.

dim_z int

The dimension of z.

dim_y int

The dimension of y.

loc_layer ModifiedLinear

The linear layer for computing the location parameter.

scale_layer ModifiedLinear

The linear layer for computing the scale parameter.

_dist_class LogNormal

The class representing the LogNormal distribution.

_n_pars int

The number of parameters in the distribution.

Source code in vambn/modelling/models/hivae/heads.py (lines 359-442)
class PosHead(
    BaseModuleHead[LogNormalParameters, dists.LogNormal]
):
    """
    Head module for the LogNormal (pos) distribution

    Attributes:
        variable_type (VariableType): The type of variable.
        dim_s (int): The dimension of s.
        dim_z (int): The dimension of z.
        dim_y (int): The dimension of y.
        loc_layer (ModifiedLinear): The linear layer for computing the location parameter.
        scale_layer (ModifiedLinear): The linear layer for computing the scale parameter.
        _dist_class (dists.LogNormal): The class representing the LogNormal distribution.
        _n_pars (int): The number of parameters in the distribution.

    """

    def __init__(
        self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
    ) -> None:
        """
        Initializes the PosHead class.

        Args:
            variable_type (VariableType): The type of variable.
            dim_s (int): The dimension of s.
            dim_z (int): The dimension of z.
            dim_y (int): The dimension of y.
        """
        super().__init__(variable_type, dim_s, dim_z, dim_y)
        self.loc_layer = ModifiedLinear(self.dim_y + self.dim_s, 1, bias=False)
        self.scale_layer = ModifiedLinear(self.dim_s, 1, bias=False)

        self._dist_class = dists.LogNormal
        self._n_pars = 2

    def forward(self, samples_z: Tensor, samples_s: Tensor) -> LogNormalParameters:
        """
        Performs the forward pass of the PosHead.

        Args:
            samples_z (Tensor): The z samples.
            samples_s (Tensor): The s samples.

        Returns:
            LogNormalParameters: The output of the forward pass.
        """
        y = self.internal_pass(samples_z)
        s_and_y = torch.cat([y, samples_s], dim=-1)

        loc = self.loc_layer(s_and_y)
        scale = nn.functional.softplus(self.scale_layer(samples_s))
        return LogNormalParameters(loc, scale)

    def dist(self, params: LogNormalParameters) -> dists.LogNormal:
        """
        Creates a LogNormal distribution based on the given parameters.

        Args:
            params (LogNormalParameters): The parameters of the distribution.

        Returns:
            dists.LogNormal: The LogNormal distribution.
        """
        self._dist = self._dist_class(
            params.loc.squeeze(), params.scale.squeeze()
        )
        return self._dist

    def log_prob(
        self, data: Tensor, params: LogNormalParameters | None = None
    ) -> Tensor:
        """
        Computes the log probability of the data given the parameters.

        Args:
            data (Tensor): The input data.
            params (LogNormalParameters | None): The parameters of the distribution.

        Returns:
            Tensor: The log probability of the data.
        """
        return super().log_prob(data.clamp(min=1e-3), params)
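
The clamp in log_prob exists because the LogNormal support is (0, inf): exact zeros (for example masked entries filled with 0) yield a non-finite log probability. A self-contained sketch of the effect; validate_args=False is used only so the out-of-support value can be evaluated.

import torch
import torch.distributions as dists

lognormal = dists.LogNormal(loc=torch.zeros(3), scale=torch.ones(3), validate_args=False)
data = torch.tensor([0.0, 0.5, 2.0])

print(lognormal.log_prob(data))                  # first entry is not finite (0 is outside the support)
print(lognormal.log_prob(data.clamp(min=1e-3)))  # clamping keeps every term finite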

__init__(variable_type, dim_s, dim_z, dim_y)

Initializes the PosHead class.

Parameters:

Name Type Description Default
variable_type VariableType

The type of variable.

required
dim_s int

The dimension of s.

required
dim_z int

The dimension of z.

required
dim_y int

The dimension of y.

required
Source code in vambn/modelling/models/hivae/heads.py (lines 377-394)
def __init__(
    self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
) -> None:
    """
    Initializes the PosHead class.

    Args:
        variable_type (VariableType): The type of variable.
        dim_s (int): The dimension of s.
        dim_z (int): The dimension of z.
        dim_y (int): The dimension of y.
    """
    super().__init__(variable_type, dim_s, dim_z, dim_y)
    self.loc_layer = ModifiedLinear(self.dim_y + self.dim_s, 1, bias=False)
    self.scale_layer = ModifiedLinear(self.dim_s, 1, bias=False)

    self._dist_class = dists.LogNormal
    self._n_pars = 2

dist(params)

Creates a LogNormal distribution based on the given parameters.

Parameters:

Name Type Description Default
params LogNormalParameters

The parameters of the distribution.

required

Returns:

Type Description
LogNormal

dists.LogNormal: The LogNormal distribution.

Source code in vambn/modelling/models/hivae/heads.py (lines 414-427)
def dist(self, params: LogNormalParameters) -> dists.LogNormal:
    """
    Creates a LogNormal distribution based on the given parameters.

    Args:
        params (LogNormalParameters): The parameters of the distribution.

    Returns:
        dists.LogNormal: The LogNormal distribution.
    """
    self._dist = self._dist_class(
        params.loc.squeeze(), params.scale.squeeze()
    )
    return self._dist

forward(samples_z, samples_s)

Performs the forward pass of the PosHead.

Parameters:

Name Type Description Default
samples_z Tensor

The z samples.

required
samples_s Tensor

The s samples.

required

Returns:

Name Type Description
LogNormalParameters LogNormalParameters

The output of the forward pass.

Source code in vambn/modelling/models/hivae/heads.py (lines 396-412)
def forward(self, samples_z: Tensor, samples_s: Tensor) -> LogNormalParameters:
    """
    Performs the forward pass of the PosHead.

    Args:
        samples_z (Tensor): The z samples.
        samples_s (Tensor): The s samples.

    Returns:
        LogNormalParameters: The output of the forward pass.
    """
    y = self.internal_pass(samples_z)
    s_and_y = torch.cat([y, samples_s], dim=-1)

    loc = self.loc_layer(s_and_y)
    scale = nn.functional.softplus(self.scale_layer(samples_s))
    return LogNormalParameters(loc, scale)

log_prob(data, params=None)

Computes the log probability of the data given the parameters.

Parameters:

Name Type Description Default
data Tensor

The input data.

required
params LogNormalParameters | None

The parameters of the distribution.

None

Returns:

Name Type Description
Tensor Tensor

The log probability of the data.

Source code in vambn/modelling/models/hivae/heads.py (lines 429-442)
def log_prob(
    self, data: Tensor, params: LogNormalParameters | None = None
) -> Tensor:
    """
    Computes the log probability of the data given the parameters.

    Args:
        data (Tensor): The input data.
        params (LogNormalParameters | None): The parameters of the distribution.

    Returns:
        Tensor: The log probability of the data.
    """
    return super().log_prob(data.clamp(min=1e-3), params)

RealHead

Bases: BaseModuleHead[NormalParameters, Normal]

Class representing the RealHead module.

This module is used for data of type real or pos.

Parameters:

Name Type Description Default
variable_type VariableType

Array containing the data type information (type, class, ndim)

required
dim_s int

Dimension of s space

required
dim_z int

Dimension of z space

required
dim_y int

Dimension of y space

required

Attributes:

Name Type Description
loc_layer ModifiedLinear

Linear layer for computing the location parameter

scale_layer ModifiedLinear

Linear layer for computing the scale parameter

_dist_class type

Class representing the distribution

_n_pars int

Number of parameters in the distribution

Source code in vambn/modelling/models/hivae/heads.py (lines 202-267)
class RealHead(BaseModuleHead[NormalParameters, dists.Normal]):
    """Class representing the RealHead module.

    This module is used for data of type real or pos.

    Args:
        variable_type (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space

    Attributes:
        loc_layer (ModifiedLinear): Linear layer for computing the location parameter
        scale_layer (ModifiedLinear): Linear layer for computing the scale parameter
        _dist_class (type): Class representing the distribution
        _n_pars (int): Number of parameters in the distribution

    """

    def __init__(
        self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
    ) -> None:
        """Initialize the RealHead module.

        Args:
            variable_type (VariableType): Array containing the data type information (type, class, ndim)
            dim_s (int): Dimension of s space
            dim_z (int): Dimension of z space
            dim_y (int): Dimension of y space
        """
        super().__init__(variable_type, dim_s, dim_z, dim_y)
        self.loc_layer = ModifiedLinear(self.dim_y + self.dim_s, 1, bias=False)
        self.scale_layer = ModifiedLinear(self.dim_s, 1, bias=False)

        self._dist_class = dists.Normal
        self._parameter_class = NormalParameters
        self._n_pars = 2

    def forward(self, samples_z: Tensor, samples_s: Tensor) -> NormalParameters:
        """Forward pass of the RealHead module.

        Args:
            samples_z (Tensor): Samples from the z space
            samples_s (Tensor): Samples from the s space

        Returns:
            NormalParameters: Parameters of the Normal distribution
        """
        y = self.internal_pass(samples_z)
        s_and_y = torch.cat([y, samples_s], dim=-1)
        loc = self.loc_layer(s_and_y)
        scale = nn.functional.softplus(self.scale_layer(samples_s))
        return NormalParameters(loc, scale)

    def dist(self, params: NormalParameters) -> dists.Normal:
        """Create a Normal distribution based on the given parameters.

        Args:
            params (NormalParameters): Parameters of the Normal distribution

        Returns:
            dists.Normal: Normal distribution
        """
        self._dist = self._dist_class(
            params.loc.squeeze(), params.scale.squeeze()
        )
        return self._dist
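
A self-contained sketch of the Gaussian likelihood RealHead parameterises: the location depends on both y and s, while the scale is computed from s alone and passed through softplus so it stays positive. The numbers are arbitrary and nn.Linear stands in for ModifiedLinear.

import torch
import torch.distributions as dists
import torch.nn as nn

batch, dim_s, dim_y = 16, 2, 4
loc_layer = nn.Linear(dim_y + dim_s, 1, bias=False)
scale_layer = nn.Linear(dim_s, 1, bias=False)

y = torch.randn(batch, dim_y)                                # output of the internal pass
s = torch.randn(batch, dim_s)

loc = loc_layer(torch.cat([y, s], dim=-1)).squeeze(-1)
scale = nn.functional.softplus(scale_layer(s)).squeeze(-1)   # positive scale
normal = dists.Normal(loc, scale)
print(normal.rsample().shape)                                # (16,), differentiable in loc and scale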

__init__(variable_type, dim_s, dim_z, dim_y)

Initialize the RealHead module.

Parameters:

Name Type Description Default
variable_type VariableType

Array containing the data type information (type, class, ndim)

required
dim_s int

Dimension of s space

required
dim_z int

Dimension of z space

required
dim_y int

Dimension of y space

required
Source code in vambn/modelling/models/hivae/heads.py (lines 221-237)
def __init__(
    self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
) -> None:
    """Initialize the RealHead module.

    Args:
        variable_type (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space
    """
    super().__init__(variable_type, dim_s, dim_z, dim_y)
    self.loc_layer = ModifiedLinear(self.dim_y + self.dim_s, 1, bias=False)
    self.scale_layer = ModifiedLinear(self.dim_s, 1, bias=False)

    self._dist_class = dists.Normal
    self._parameter_class = NormalParameters
    self._n_pars = 2

dist(params)

Create a Normal distribution based on the given parameters.

Parameters:

Name Type Description Default
params NormalParameters

Parameters of the Normal distribution

required

Returns:

Type Description
Normal

dists.Normal: Normal distribution

Source code in vambn/modelling/models/hivae/heads.py (lines 255-267)
def dist(self, params: NormalParameters) -> dists.Normal:
    """Create a Normal distribution based on the given parameters.

    Args:
        params (NormalParameters): Parameters of the Normal distribution

    Returns:
        dists.Normal: Normal distribution
    """
    self._dist = self._dist_class(
        params.loc.squeeze(), params.scale.squeeze()
    )
    return self._dist

forward(samples_z, samples_s)

Forward pass of the RealHead module.

Parameters:

Name Type Description Default
samples_z Tensor

Samples from the z space

required
samples_s Tensor

Samples from the s space

required

Returns:

Name Type Description
NormalParameters NormalParameters

Parameters of the Normal distribution

Source code in vambn/modelling/models/hivae/heads.py (lines 239-253)
def forward(self, samples_z: Tensor, samples_s: Tensor) -> NormalParameters:
    """Forward pass of the RealHead module.

    Args:
        samples_z (Tensor): Samples from the z space
        samples_s (Tensor): Samples from the s space

    Returns:
        NormalParameters: Parameters of the Normal distribution
    """
    y = self.internal_pass(samples_z)
    s_and_y = torch.cat([y, samples_s], dim=-1)
    loc = self.loc_layer(s_and_y)
    scale = nn.functional.softplus(self.scale_layer(samples_s))
    return NormalParameters(loc, scale)

TruncatedNormalHead

Bases: BaseModuleHead[NormalParameters, TruncatedNormal]

Class representing the TruncatedNormalHead module.

This module is used for data of type real or pos.

Parameters:

Name Type Description Default
variable_type VariableType

Array containing the data type information (type, class, ndim)

required
dim_s int

Dimension of s space

required
dim_z int

Dimension of z space

required
dim_y int

Dimension of y space

required

Attributes:

Name Type Description
loc_layer ModifiedLinear

Linear layer for computing the location parameter

scale_layer ModifiedLinear

Linear layer for computing the scale parameter

_dist_class type

Class representing the distribution

_n_pars int

Number of parameters in the distribution

Source code in vambn/modelling/models/hivae/heads.py (lines 270-356)
class TruncatedNormalHead(
    BaseModuleHead[NormalParameters, TruncatedNormal]
):
    """
    Class representing the TruncatedNormalHead module.

    This module is used for data of type real or pos.

    Args:
        variable_type (VariableType): Array containing the data type information (type, class, ndim)
        dim_s (int): Dimension of s space
        dim_z (int): Dimension of z space
        dim_y (int): Dimension of y space

    Attributes:
        loc_layer (ModifiedLinear): Linear layer for computing the location parameter
        scale_layer (ModifiedLinear): Linear layer for computing the scale parameter
        _dist_class (type): Class representing the distribution
        _n_pars (int): Number of parameters in the distribution

    """

    def __init__(
        self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
    ) -> None:
        """
        Initializes a TruncatedNormalHead object.

        Args:
            variable_type (VariableType): The type of variable.
            dim_s (int): The dimensionality of s.
            dim_z (int): The dimensionality of z.
            dim_y (int): The dimensionality of y.
        """
        super().__init__(variable_type, dim_s, dim_z, dim_y)
        self.loc_layer = ModifiedLinear(self.dim_y + self.dim_s, 1, bias=False)
        self.scale_layer = ModifiedLinear(self.dim_s, 1, bias=False)

        self._dist_class = TruncatedNormal
        self._n_pars = 2

    def forward(self, samples_z: Tensor, samples_s: Tensor) -> NormalParameters:
        """
        Performs a forward pass through the TruncatedNormalHead.

        Args:
            samples_z (Tensor): The z samples.
            samples_s (Tensor): The s samples.

        Returns:
            NormalParameters: The output of the forward pass.
        """
        y = self.internal_pass(samples_z)
        s_and_y = torch.cat([y, samples_s], dim=-1)
        loc = self.loc_layer(s_and_y)
        scale = nn.functional.softplus(self.scale_layer(samples_s))
        return NormalParameters(loc, scale)

    def dist(self, params: NormalParameters) -> TruncatedNormal:
        """
        Creates a TruncatedNormal distribution based on the given parameters.

        Args:
            params (NormalParameters): The parameters of the distribution.

        Returns:
            TruncatedNormal: The created TruncatedNormal distribution.
        """
        self._dist = self._dist_class(
            params.loc.squeeze(), params.scale.squeeze(), low=torch.tensor(0.0)
        )
        return self._dist

    def log_prob(
        self, data: Tensor, params: NormalParameters | None = None
    ) -> Tensor:
        """
        Computes the log probability of the data given the parameters.

        Args:
            data (Tensor): The input data.
            params (NormalParameters | None): The parameters of the distribution.

        Returns:
            Tensor: The log probability of the data.
        """
        return super().log_prob(data.clamp(min=1e-3), params)
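
TruncatedNormal is a package-specific distribution class. The self-contained sketch below only illustrates the underlying idea of truncating a Normal at a lower bound of 0, renormalising the density by the probability mass kept above the bound; it is not the package implementation.

import torch
import torch.distributions as dists

loc, scale = torch.tensor(0.5), torch.tensor(1.0)
base = dists.Normal(loc, scale)
low = torch.tensor(0.0)

mass_above = 1.0 - base.cdf(low)                   # probability mass retained after truncation

x = torch.tensor([0.1, 1.0, 2.5])                  # values inside the truncated support
trunc_log_prob = base.log_prob(x) - torch.log(mass_above)
print(trunc_log_prob)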

__init__(variable_type, dim_s, dim_z, dim_y)

Initializes a TruncatedNormalHead object.

Parameters:

Name Type Description Default
variable_type VariableType

The type of variable.

required
dim_s int

The dimensionality of s.

required
dim_z int

The dimensionality of z.

required
dim_y int

The dimensionality of y.

required
Source code in vambn/modelling/models/hivae/heads.py (lines 292-309)
def __init__(
    self, variable_type: VariableType, dim_s: int, dim_z: int, dim_y: int
) -> None:
    """
    Initializes a TruncatedNormalHead object.

    Args:
        variable_type (VariableType): The type of variable.
        dim_s (int): The dimensionality of s.
        dim_z (int): The dimensionality of z.
        dim_y (int): The dimensionality of y.
    """
    super().__init__(variable_type, dim_s, dim_z, dim_y)
    self.loc_layer = ModifiedLinear(self.dim_y + self.dim_s, 1, bias=False)
    self.scale_layer = ModifiedLinear(self.dim_s, 1, bias=False)

    self._dist_class = TruncatedNormal
    self._n_pars = 2

dist(params)

Creates a TruncatedNormal distribution based on the given parameters.

Parameters:

Name Type Description Default
params NormalParameters

The parameters of the distribution.

required

Returns:

Name Type Description
TruncatedNormal TruncatedNormal

The created TruncatedNormal distribution.

Source code in vambn/modelling/models/hivae/heads.py (lines 328-341)
def dist(self, params: NormalParameters) -> TruncatedNormal:
    """
    Creates a TruncatedNormal distribution based on the given parameters.

    Args:
        params (NormalParameters): The parameters of the distribution.

    Returns:
        TruncatedNormal: The created TruncatedNormal distribution.
    """
    self._dist = self._dist_class(
        params.loc.squeeze(), params.scale.squeeze(), low=torch.tensor(0.0)
    )
    return self._dist

forward(samples_z, samples_s)

Performs a forward pass through the TruncatedNormalHead.

Parameters:

Name Type Description Default
samples_z Tensor

The z samples.

required
samples_s Tensor

The s samples.

required

Returns:

Name Type Description
NormalParameters NormalParameters

The output of the forward pass.

Source code in vambn/modelling/models/hivae/heads.py (lines 311-326)
def forward(self, samples_z: Tensor, samples_s: Tensor) -> NormalParameters:
    """
    Performs a forward pass through the TruncatedNormalHead.

    Args:
        samples_z (Tensor): The z samples.
        samples_s (Tensor): The s samples.

    Returns:
        NormalParameters: The output of the forward pass.
    """
    y = self.internal_pass(samples_z)
    s_and_y = torch.cat([y, samples_s], dim=-1)
    loc = self.loc_layer(s_and_y)
    scale = nn.functional.softplus(self.scale_layer(samples_s))
    return NormalParameters(loc, scale)

log_prob(data, params=None)

Computes the log probability of the data given the parameters.

Parameters:

Name Type Description Default
data Tensor

The input data.

required
params NormalParameters | None

The parameters of the distribution.

None

Returns:

Name Type Description
Tensor Tensor

The log probability of the data.

Source code in vambn/modelling/models/hivae/heads.py (lines 343-356)
def log_prob(
    self, data: Tensor, params: NormalParameters | None = None
) -> Tensor:
    """
    Computes the log probability of the data given the parameters.

    Args:
        data (Tensor): The input data.
        params (NormalParameters | None): The parameters of the distribution.

    Returns:
        Tensor: The log probability of the data.
    """
    return super().log_prob(data.clamp(min=1e-3), params)

hivae

Hivae

Bases: AbstractNormalModel[Tensor, Tensor, HivaeOutput, HivaeEncoding]

Entire HIVAE model containing Encoder and Decoder structure.

Parameters:

Name Type Description Default
variable_types VarTypes

List of VariableType objects defining the types of the variables in the data.

required
input_dim int

Dimension of input data (number of columns in the dataframe). If the data contains categorical variables, the input dimension is larger than the number of features.

required
dim_s int

Dimension of s space.

required
dim_z int

Dimension of z space.

required
dim_y int

Dimension of y space.

required
module_name str

Name of the module this HIVAE is associated with. Defaults to 'HIVAE'.

'HIVAE'
mtl_method Tuple[str]

List of methods to use for multi-task learning. Assessed possibilities are combinations of "identity", "gradnorm", "graddrop". Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).

('identity',)
use_imputation_layer bool

Flag to indicate if imputation layer should be used. Defaults to False.

False
individual_model bool

Flag to indicate if the current model is applied individually or as part of e.g. a modular HIVAE. Defaults to True.

True
Source code in vambn/modelling/models/hivae/hivae.py (lines 27-383)
class Hivae(
    AbstractNormalModel[torch.Tensor, torch.Tensor, HivaeOutput, HivaeEncoding]
):
    """
    Entire HIVAE model containing Encoder and Decoder structure.

    Args:
        variable_types (VarTypes): List of VariableType objects defining the types
            of the variables in the data.
        input_dim (int): Dimension of input data (number of columns in the dataframe).
            If the data contains categorical variables, the input dimension is
            larger than the number of features.
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        dim_y (int): Dimension of y space.
        module_name (str, optional): Name of the module this HIVAE is associated with. Defaults to 'HIVAE'.
        mtl_method (Tuple[str], optional): List of methods to use for multi-task learning.
            Assessed possibilities are combinations of "identity", "gradnorm", "graddrop".
            Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).
        use_imputation_layer (bool, optional): Flag to indicate if imputation layer should be used. Defaults to False.
        individual_model (bool, optional): Flag to indicate if the current model
            is applied individually or as part of e.g. a modular HIVAE. Defaults to True.
    """


    def __init__(
        self,
        variable_types: VarTypes,
        input_dim: int,
        dim_s: int,
        dim_z: int,
        dim_y: int,
        module_name: Optional[str] = "HIVAE",
        mtl_method: Tuple[str, ...] = ("identity",),
        use_imputation_layer: bool = False,
        individual_model: bool = True,
    ) -> None:
        super().__init__()

        self.variable_types = variable_types
        self.input_dim = input_dim
        self.dim_s = dim_s
        self.dim_z = dim_z
        self.dim_y = dim_y
        self.individual_model = individual_model

        self.use_imp_layer = use_imputation_layer
        self.module_name = module_name

        # Imputation layer
        if use_imputation_layer:
            self.imputation_layer = ImputationLayer(self.input_dim)
        else:
            self.imputation_layer = None

        # Normalization parameters
        self.register_buffer(
            "_mean_data",
            torch.zeros(len(self.variable_types), requires_grad=False),
        )
        self.register_buffer(
            "_std_data",
            torch.ones(len(self.variable_types), requires_grad=False),
        )
        self._batch_mean_data = self._batch_std_data = None

        self.encoder = Encoder(input_dim=input_dim, dim_s=dim_s, dim_z=dim_z)
        self.decoder = Decoder(
            variable_types=variable_types,
            s_dim=dim_s,
            z_dim=dim_z,
            y_dim=dim_y,
            mtl_method=mtl_method,
            decoder_shared=nn.Identity(),
        )
        self._temperature = 1.0
        self.tau = self._temperature

        # set to cpu by default
        self.device = torch.device("cpu")
        self.module_name = module_name

    @cached_property
    def colnames(self) -> Tuple[str, ...]:
        """
        Tuple of column names derived from variable types.

        Returns:
            Tuple[str, ...]: A tuple containing column names.
        """

        return tuple([v.name for v in self.variable_types])

    @property
    def decoding(self) -> bool:
        """
        Decoding flag indicating if the encoder and decoder are in decoding mode.

        Returns:
            bool: Decoding flag.
        """
        assert self.encoder.decoding == self.decoder.decoding
        return self.encoder.decoding

    @decoding.setter
    def decoding(self, value: bool) -> None:
        """
        Sets the decoding flag for both encoder and decoder.

        Args:
            value (bool): The decoding flag to set.
        """
        self.encoder.decoding = value
        self.decoder.decoding = value

    @property
    def tau(self) -> float:
        """
        Gets the temperature parameter for the model.

        Returns:
            float: The temperature parameter.
        """
        return self._temperature

    @tau.setter
    def tau(self, value: float) -> None:
        """
        Sets the temperature parameter for the model.

        Args:
            value (float): The temperature parameter to set.

        Raises:
            ValueError: If the value is not positive.
        """

        if value <= 0:
            raise ValueError(f"Tau must be positive, got {value}")
        self._temperature = value
        self.encoder.tau = value

    @property
    def normalization_parameters(self) -> NormalizationParameters:
        """
        Gets the normalization parameters (mean and standard deviation).

        Returns:
            NormalizationParameters: The normalization parameters.
        """

        if self._batch_std_data is not None:
            return NormalizationParameters(
                mean=self._batch_mean_data, std=self._batch_std_data
            )
        else:
            return NormalizationParameters(
                mean=self._mean_data, std=self._std_data
            )

    @normalization_parameters.setter
    def normalization_parameters(self, value: NormalizationParameters) -> None:
        """
        Sets the normalization parameters (mean and standard deviation).

        Args:
            value (NormalizationParameters): The normalization parameters to set.
        """

        if self.training:
            # calculate the mean and std of the data with momentum
            momentum = 0.01
            # Choosing a momentum of 0.01 to update the mean and std of the data
            # this leads to a smoother update of the mean and std of the data
            new_mean = (1 - momentum) * self._mean_data + momentum * value.mean
            new_std = (1 - momentum) * self._std_data + momentum * value.std
            self._mean_data = new_mean
            self._std_data = new_std
            self._batch_mean_data = value.mean
            self._batch_std_data = value.std
        else:
            logger.warning(
                "Running parameters are not updated during evaluation"
            )
            self._batch_mean_data = value.mean
            self._batch_std_data = value.std

    def forward(self, data: torch.Tensor, mask: torch.Tensor) -> HivaeOutput:
        """
        Forward pass through the HIVAE model.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.

        Returns:
            HivaeOutput: The output of the HIVAE model.
        """

        data, mask, encoder_output = self.encoder_part(data, mask)
        decoder_output = self.decoder_part(data, mask, encoder_output)
        return decoder_output

    def decoder_part(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
        encoder_output: EncoderOutput,
    ) -> HivaeOutput:
        """
        Pass through the decoder part of the model.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.
            encoder_output (EncoderOutput): Output from the encoder.

        Returns:
            HivaeOutput: The output of the decoder.
        """

        decoder_output = self.decoder(
            data=data,
            mask=mask,
            encoder_output=encoder_output,
            normalization_parameters=self.normalization_parameters,
        )

        return decoder_output

    def encoder_part(
        self, data: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, EncoderOutput]:
        """
        Pass through the encoder part of the model.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Processed data, mask, and encoder output.

        Raises:
            ValueError: If data tensor has invalid shape.
        """

        if data.ndim == 3 and data.shape[1] == 1:
            data = data.squeeze(1)
            mask = mask.squeeze(1)
        elif data.ndim == 3:
            raise ValueError(
                f"Data should have shape (batch_size, n_features), got {data.shape}"
            )

        self._batch_mean_data = None
        self._batch_std_data = None
        (
            new_x,
            new_mask,
            self.normalization_parameters,
        ) = Normalization.normalize_data(
            data, mask, variable_types=self.variable_types, prior_parameters=self.normalization_parameters
        )

        if self.use_imp_layer and self.imputation_layer is not None:
            new_x = self.imputation_layer(new_x, new_mask)

        encoder_output = self.encoder(new_x)
        return data, mask, encoder_output

    def decode(self, encoding: HivaeEncoding) -> torch.Tensor:
        """
        Decode the given encoding to reconstruct the input data.

        Args:
            encoding (HivaeEncoding): The encoding to decode.

        Returns:
            torch.Tensor: The reconstructed data tensor.
        """

        return self.decoder.decode(
            encoding_s=encoding.s,
            encoding_z=encoding.decoder_representation,
            normalization_params=self.normalization_parameters,
        )

    def _predict_step(
        self, data: torch.Tensor, mask: torch.Tensor
    ) -> HivaeOutput:
        """
        Prediction step without gradient calculation.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.

        Returns:
            HivaeOutput: The output of the HIVAE model.
        """

        with torch.no_grad():
            return self.forward(data, mask)

    def _test_step(self, data: torch.Tensor, mask: torch.Tensor) -> float:
        """
        Test step to evaluate the model on test data.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.

        Returns:
            float: The test loss.
        """

        with torch.no_grad():
            return self.forward(data, mask).loss.detach()

    def _training_step(
        self, data: torch.Tensor, mask: torch.Tensor, optimizer: Optimizer
    ) -> float:
        """
        Training step to update the model parameters.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.
            optimizer (Optimizer): Optimizer for updating model parameters.

        Returns:
            float: The training loss.
        """

        optimizer.zero_grad()
        output = self.forward(data, mask)
        self.fabric.backward(output.loss)
        optimizer.step()
        return output.loss.detach()

    def _validation_step(self, data: torch.Tensor, mask: torch.Tensor) -> float:
        """
        Validation step to evaluate the model on validation data.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.

        Returns:
            float: The validation loss.
        """

        with torch.no_grad():
            return self.forward(data, mask).loss.detach()
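The step methods above are normally driven by an outer loop rather than called in isolation. Below is a minimal sketch of such a loop; the DataLoader yielding (data, mask) batches, the optimizer, and the Lightning Fabric setup (required because _training_step calls self.fabric.backward) are assumptions and not part of this class:

def run_epoch(model, loader, optimizer, train=True):
    """Average the per-batch losses returned by the internal step methods."""
    losses = []
    for data, mask in loader:  # assumed to yield (data, mask) tensor pairs
        if train:
            loss = model._training_step(data, mask, optimizer)
        else:
            loss = model._validation_step(data, mask)
        losses.append(float(loss))
    return sum(losses) / max(len(losses), 1)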

colnames: Tuple[str, ...] cached property

Tuple of column names derived from variable types.

Returns:

Type Description
Tuple[str, ...]

Tuple[str, ...]: A tuple containing column names.

decoding: bool property writable

Decoding flag indicating if the encoder and decoder are in decoding mode.

Returns:

Name Type Description
bool bool

Decoding flag.

normalization_parameters: NormalizationParameters property writable

Gets the normalization parameters (mean and standard deviation).

Returns:

Name Type Description
NormalizationParameters NormalizationParameters

The normalization parameters.

tau: float property writable

Gets the temperature parameter for the model.

Returns:

Name Type Description
float float

The temperature parameter.

decode(encoding)

Decode the given encoding to reconstruct the input data.

Parameters:

Name Type Description Default
encoding HivaeEncoding

The encoding to decode.

required

Returns:

Type Description
Tensor

torch.Tensor: The reconstructed data tensor.

Source code in vambn/modelling/models/hivae/hivae.py
def decode(self, encoding: HivaeEncoding) -> torch.Tensor:
    """
    Decode the given encoding to reconstruct the input data.

    Args:
        encoding (HivaeEncoding): The encoding to decode.

    Returns:
        torch.Tensor: The reconstructed data tensor.
    """

    return self.decoder.decode(
        encoding_s=encoding.s,
        encoding_z=encoding.decoder_representation,
        normalization_params=self.normalization_parameters,
    )

decoder_part(data, mask, encoder_output)

Pass through the decoder part of the model.

Parameters:

Name Type Description Default
data Tensor

Input data tensor.

required
mask Tensor

Mask tensor indicating missing values.

required
encoder_output EncoderOutput

Output from the encoder.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

The output of the decoder.

Source code in vambn/modelling/models/hivae/hivae.py
def decoder_part(
    self,
    data: torch.Tensor,
    mask: torch.Tensor,
    encoder_output: EncoderOutput,
) -> HivaeOutput:
    """
    Pass through the decoder part of the model.

    Args:
        data (torch.Tensor): Input data tensor.
        mask (torch.Tensor): Mask tensor indicating missing values.
        encoder_output (EncoderOutput): Output from the encoder.

    Returns:
        HivaeOutput: The output of the decoder.
    """

    decoder_output = self.decoder(
        data=data,
        mask=mask,
        encoder_output=encoder_output,
        normalization_parameters=self.normalization_parameters,
    )

    return decoder_output

encoder_part(data, mask)

Pass through the encoder part of the model.

Parameters:

Name Type Description Default
data Tensor

Input data tensor.

required
mask Tensor

Mask tensor indicating missing values.

required

Returns:

Type Description
Tuple[Tensor, Tensor, EncoderOutput]

Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Processed data, mask, and encoder output.

Raises:

Type Description
ValueError

If data tensor has invalid shape.

Source code in vambn/modelling/models/hivae/hivae.py
def encoder_part(
    self, data: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, EncoderOutput]:
    """
    Pass through the encoder part of the model.

    Args:
        data (torch.Tensor): Input data tensor.
        mask (torch.Tensor): Mask tensor indicating missing values.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Processed data, mask, and encoder output.

    Raises:
        ValueError: If data tensor has invalid shape.
    """

    if data.ndim == 3 and data.shape[1] == 1:
        data = data.squeeze(1)
        mask = mask.squeeze(1)
    elif data.ndim == 3:
        raise ValueError(
            f"Data should have shape (batch_size, n_features), got {data.shape}"
        )

    self._batch_mean_data = None
    self._batch_std_data = None
    (
        new_x,
        new_mask,
        self.normalization_parameters,
    ) = Normalization.normalize_data(
        data, mask, variable_types=self.variable_types, prior_parameters=self.normalization_parameters
    )

    if self.use_imp_layer and self.imputation_layer is not None:
        new_x = self.imputation_layer(new_x, new_mask)

    encoder_output = self.encoder(new_x)
    return data, mask, encoder_output

forward(data, mask)

Forward pass through the HIVAE model.

Parameters:

Name Type Description Default
data Tensor

Input data tensor.

required
mask Tensor

Mask tensor indicating missing values.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

The output of the HIVAE model.

Source code in vambn/modelling/models/hivae/hivae.py
def forward(self, data: torch.Tensor, mask: torch.Tensor) -> HivaeOutput:
    """
    Forward pass through the HIVAE model.

    Args:
        data (torch.Tensor): Input data tensor.
        mask (torch.Tensor): Mask tensor indicating missing values.

    Returns:
        HivaeOutput: The output of the HIVAE model.
    """

    data, mask, encoder_output = self.encoder_part(data, mask)
    decoder_output = self.decoder_part(data, mask, encoder_output)
    return decoder_output
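A minimal forward-pass sketch. The variable definitions are assumptions: var_types is a VarTypes instance built elsewhere from the dataset definition and, for simplicity, describes only real-valued columns (so the input dimension equals the number of columns):

import torch

from vambn.modelling.models.hivae.hivae import Hivae

model = Hivae(
    variable_types=var_types,      # assumed: VarTypes with real-valued columns only
    input_dim=len(var_types),
    dim_s=4,
    dim_z=8,
    dim_y=8,
)

data = torch.randn(32, len(var_types))   # (batch_size, n_features)
mask = torch.ones_like(data)             # 1 = observed, 0 = missing
output = model(data, mask)               # HivaeOutput; output.loss drives the step methods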

LstmHivae

Bases: Hivae

LSTM-based HIVAE model with Encoder and Decoder structure for temporal data.

Parameters:

Name Type Description Default
variable_types VarTypes

List of VariableType objects defining the types of the variables in the data.

required
input_dim int

Dimension of input data (number of columns in the dataframe). If the data contains categorical variables, the input dimension is larger than the number of features.

required
dim_s int

Dimension of s space.

required
dim_z int

Dimension of z space.

required
dim_y int

Dimension of y space.

required
n_layers int

Number of layers in the LSTM.

required
num_timepoints int

Number of time points in the temporal data.

required
module_name str

Name of the module this HIVAE is associated with. Defaults to 'HIVAE'.

'HIVAE'
mtl_method Tuple[str]

List of methods to use for multi-task learning. Assessed possibilities are combinations of "identity", "gradnorm", "graddrop". Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).

('identity',)
use_imputation_layer bool

Flag to indicate if imputation layer should be used. Defaults to False.

False
individual_model bool

Flag to indicate if the current model is applied individually or as part of e.g. a modular HIVAE. Defaults to True.

True
Source code in vambn/modelling/models/hivae/hivae.py
class LstmHivae(Hivae):
    """
    LSTM-based HIVAE model with Encoder and Decoder structure for temporal data.

    Args:
        variable_types (VarTypes): List of VariableType objects defining the types
            of the variables in the data.
        input_dim (int): Dimension of input data (number of columns in the dataframe).
            If the data contains categorical variables, the input dimension is
            larger than the number of features.
        dim_s (int): Dimension of s space.
        dim_z (int): Dimension of z space.
        dim_y (int): Dimension of y space.
        n_layers (int): Number of layers in the LSTM.
        num_timepoints (int): Number of time points in the temporal data.
        module_name (str, optional): Name of the module this HIVAE is associated with. Defaults to 'HIVAE'.
        mtl_method (Tuple[str], optional): List of methods to use for multi-task learning.
            Assessed possibilities are combinations of "identity", "gradnorm", "graddrop".
            Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).
        use_imputation_layer (bool, optional): Flag to indicate if imputation layer should be used. Defaults to False.
        individual_model (bool, optional): Flag to indicate if the current model
            is applied individually or as part of e.g. a modular HIVAE. Defaults to True.
    """

    def __init__(
        self,
        variable_types: VarTypes,
        input_dim: int,
        dim_s: int,
        dim_z: int,
        dim_y: int,
        n_layers: int,
        num_timepoints: int,
        module_name: str | None = "HIVAE",
        mtl_method: Tuple[str] = ("identity",),
        use_imputation_layer: bool = False,
        individual_model: bool = True,
    ) -> None:
        super().__init__(
            variable_types,
            input_dim,
            dim_s,
            dim_z,
            dim_y,
            module_name,
            mtl_method,
            use_imputation_layer,
            individual_model,
        )
        self.n_layers = n_layers
        self.num_timepoints = num_timepoints

        self.encoder = LstmEncoder(
            input_dimension=input_dim,
            dim_s=dim_s,
            dim_z=dim_z,
            n_layers=n_layers,
            hidden_size=input_dim,
        )
        self.decoder = LstmDecoder(
            mtl_method=mtl_method,
            n_layers=n_layers,
            num_timepoints=num_timepoints,
            s_dim=dim_s,
            variable_types=variable_types,
            y_dim=dim_y,
            z_dim=dim_z,
            decoder_shared=nn.Identity(),
        )
        self.imputation_layer = nn.ModuleList(
            [ImputationLayer(input_dim) for _ in range(num_timepoints)]
        )

        # set normalization params for each timepoint
        self.register_buffer(
            "_mean_data",
            torch.zeros(
                num_timepoints, len(self.variable_types), requires_grad=False
            ),
        )
        self.register_buffer(
            "_std_data",
            torch.ones(
                num_timepoints, len(self.variable_types), requires_grad=False
            ),
        )

    def forward(self, data: torch.Tensor, mask: torch.Tensor) -> HivaeOutput:
        """
        Forward pass through the LSTM HIVAE model.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.

        Returns:
            HivaeOutput: The output of the LSTM HIVAE model.
        """

        data, mask, encoder_output = self.encoder_part(data, mask)
        decoder_output = self.decoder_part(data, mask, encoder_output)
        return decoder_output

    def decoder_part(
        self,
        data: torch.Tensor,
        mask: torch.Tensor,
        encoder_output: EncoderOutput,
    ) -> HivaeOutput:
        """
        Pass through the decoder part of the model.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.
            encoder_output (EncoderOutput): Output from the encoder.

        Returns:
            HivaeOutput: The output of the decoder.
        """

        decoder_output = self.decoder(
            data=data,
            mask=mask,
            encoder_output=encoder_output,
            normalization_parameters=self.normalization_parameters,
        )

        return decoder_output

    def encoder_part(
        self, data: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, EncoderOutput]:
        """
        Pass through the encoder part of the model.

        Args:
            data (torch.Tensor): Input data tensor.
            mask (torch.Tensor): Mask tensor indicating missing values.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Processed data, mask, and encoder output.
        """

        time_point_x = []
        time_point_mask = []
        normalization_parameters = [None] * self.num_timepoints
        for i in range(self.num_timepoints):
            (
                new_x,
                new_mask,
                normalization_parameters[i],
            ) = Normalization.normalize_data(
                data[:, i], mask[:, i], variable_types=self.variable_types, prior_parameters=self.normalization_parameters[i]
            )
            time_point_x.append(new_x)
            time_point_mask.append(new_mask)

        self.normalization_parameters = reduce(
            lambda x, y: x + y, normalization_parameters
        )
        new_x = torch.stack(time_point_x, dim=1)
        new_mask = torch.stack(time_point_mask, dim=1)

        if self.use_imp_layer and self.imputation_layer is not None:
            for i in range(self.num_timepoints):
                new_x[:, i] = self.imputation_layer[i](
                    new_x[:, i], new_mask[:, i]
                )

        encoder_output = self.encoder(new_x)
        return data, mask, encoder_output

    def decode(self, encoding: HivaeEncoding) -> torch.Tensor:
        """
        Decode the given encoding to reconstruct the input data.

        Args:
            encoding (HivaeEncoding): The encoding to decode.

        Returns:
            torch.Tensor: The reconstructed data tensor.
        """

        return self.decoder.decode(
            encoding_s=encoding.s,
            encoding_z=encoding.decoder_representation,
            normalization_params=self.normalization_parameters,
        )

decode(encoding)

Decode the given encoding to reconstruct the input data.

Parameters:

Name Type Description Default
encoding HivaeEncoding

The encoding to decode.

required

Returns:

Type Description
Tensor

torch.Tensor: The reconstructed data tensor.

Source code in vambn/modelling/models/hivae/hivae.py
def decode(self, encoding: HivaeEncoding) -> torch.Tensor:
    """
    Decode the given encoding to reconstruct the input data.

    Args:
        encoding (HivaeEncoding): The encoding to decode.

    Returns:
        torch.Tensor: The reconstructed data tensor.
    """

    return self.decoder.decode(
        encoding_s=encoding.s,
        encoding_z=encoding.decoder_representation,
        normalization_params=self.normalization_parameters,
    )

decoder_part(data, mask, encoder_output)

Pass through the decoder part of the model.

Parameters:

Name Type Description Default
data Tensor

Input data tensor.

required
mask Tensor

Mask tensor indicating missing values.

required
encoder_output EncoderOutput

Output from the encoder.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

The output of the decoder.

Source code in vambn/modelling/models/hivae/hivae.py
def decoder_part(
    self,
    data: torch.Tensor,
    mask: torch.Tensor,
    encoder_output: EncoderOutput,
) -> HivaeOutput:
    """
    Pass through the decoder part of the model.

    Args:
        data (torch.Tensor): Input data tensor.
        mask (torch.Tensor): Mask tensor indicating missing values.
        encoder_output (EncoderOutput): Output from the encoder.

    Returns:
        HivaeOutput: The output of the decoder.
    """

    decoder_output = self.decoder(
        data=data,
        mask=mask,
        encoder_output=encoder_output,
        normalization_parameters=self.normalization_parameters,
    )

    return decoder_output

encoder_part(data, mask)

Pass through the encoder part of the model.

Parameters:

Name Type Description Default
data Tensor

Input data tensor.

required
mask Tensor

Mask tensor indicating missing values.

required

Returns:

Type Description
Tuple[Tensor, Tensor, EncoderOutput]

Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Processed data, mask, and encoder output.

Source code in vambn/modelling/models/hivae/hivae.py
def encoder_part(
    self, data: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, EncoderOutput]:
    """
    Pass through the encoder part of the model.

    Args:
        data (torch.Tensor): Input data tensor.
        mask (torch.Tensor): Mask tensor indicating missing values.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, EncoderOutput]: Processed data, mask, and encoder output.
    """

    time_point_x = []
    time_point_mask = []
    normalization_parameters = [None] * self.num_timepoints
    for i in range(self.num_timepoints):
        (
            new_x,
            new_mask,
            normalization_parameters[i],
        ) = Normalization.normalize_data(
            data[:, i], mask[:, i], variable_types=self.variable_types, prior_parameters=self.normalization_parameters[i]
        )
        time_point_x.append(new_x)
        time_point_mask.append(new_mask)

    self.normalization_parameters = reduce(
        lambda x, y: x + y, normalization_parameters
    )
    new_x = torch.stack(time_point_x, dim=1)
    new_mask = torch.stack(time_point_mask, dim=1)

    if self.use_imp_layer and self.imputation_layer is not None:
        for i in range(self.num_timepoints):
            new_x[:, i] = self.imputation_layer[i](
                new_x[:, i], new_mask[:, i]
            )

    encoder_output = self.encoder(new_x)
    return data, mask, encoder_output

forward(data, mask)

Forward pass through the LSTM HIVAE model.

Parameters:

Name Type Description Default
data Tensor

Input data tensor.

required
mask Tensor

Mask tensor indicating missing values.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

The output of the LSTM HIVAE model.

Source code in vambn/modelling/models/hivae/hivae.py
def forward(self, data: torch.Tensor, mask: torch.Tensor) -> HivaeOutput:
    """
    Forward pass through the LSTM HIVAE model.

    Args:
        data (torch.Tensor): Input data tensor.
        mask (torch.Tensor): Mask tensor indicating missing values.

    Returns:
        HivaeOutput: The output of the LSTM HIVAE model.
    """

    data, mask, encoder_output = self.encoder_part(data, mask)
    decoder_output = self.decoder_part(data, mask, encoder_output)
    return decoder_output
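The temporal variant expects an extra time dimension in the input. A minimal sketch under the same assumptions as the static example above (var_types contains only real-valued columns and is built elsewhere):

import torch

from vambn.modelling.models.hivae.hivae import LstmHivae

model = LstmHivae(
    variable_types=var_types,
    input_dim=len(var_types),
    dim_s=4,
    dim_z=8,
    dim_y=8,
    n_layers=2,
    num_timepoints=5,
)

data = torch.randn(16, 5, len(var_types))  # (batch_size, num_timepoints, n_features)
mask = torch.ones_like(data)
output = model(data, mask)                 # HivaeOutput, as in the static case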

modular

ModularHivae

Bases: AbstractModularModel[Tuple[Tensor, ...], Tuple[Tensor, ...], ModularHivaeOutput, ModularHivaeEncoding]

Modular HIVAE model containing multiple data modules.

Parameters:

Name Type Description Default
module_config Tuple[DataModuleConfig]

Configuration for each data module. See DataModuleConfig for details.

required
dim_s int | Dict[str, int]

Number of mixture components for each module individually (dict) or a single value for all modules (int).

required
dim_z int

Dimension of the latent space. Equal for all modules.

required
dim_ys int

Dimension of the latent space ys. Equal for all modules.

required
dim_y int | Dict[str, int]

Dimension of the latent variable y for each module or a single value for all modules.

required
shared_element_type str

Type of shared element. Possible values are "none", "sharedLinear", "concatMtl", "concatIndiv", "avgMtl", "maxMtl", "encoder", "encoderMtl". Defaults to "none".

'none'
mtl_method Tuple[str, ...]

Methods for multi-task learning. Tested possibilities are combinations of "identity", "gradnorm", "graddrop". Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).

('identity',)
use_imputation_layer bool

Flag to indicate if imputation layer should be used. Defaults to False.

False
Source code in vambn/modelling/models/hivae/modular.py
class ModularHivae(
    AbstractModularModel[
        Tuple[Tensor, ...],
        Tuple[Tensor, ...],
        ModularHivaeOutput,
        ModularHivaeEncoding,
    ]
):
    """
    Modular HIVAE model containing multiple data modules.

    Args:
        module_config (Tuple[DataModuleConfig]): Configuration for each data module. See DataModuleConfig for details.
        dim_s (int | Dict[str, int]): Number of mixture components for each module individually (dict) or a single value for all modules (int).
        dim_z (int): Dimension of the latent space. Equal for all modules.
        dim_ys (int): Dimension of the latent space ys. Equal for all modules.
        dim_y (int | Dict[str, int]): Dimension of the latent variable y for each module or a single value for all modules.
        shared_element_type (str, optional): Type of shared element. Possible values are "none", "sharedLinear", "concatMtl", "concatIndiv", "avgMtl", "maxMtl", "encoder", "encoderMtl". Defaults to "none".
        mtl_method (Tuple[str, ...], optional): Methods for multi-task learning. Tested possibilities are combinations of "identity", "gradnorm", "graddrop". Further implementations and details can be found in the mtl.py file. Defaults to ("identity",).
        use_imputation_layer (bool, optional): Flag to indicate if imputation layer should be used. Defaults to False.
    """

    def __init__(
        self,
        module_config: Tuple[DataModuleConfig],
        dim_s: int | Dict[str, int],
        dim_z: int,
        dim_ys: int,
        dim_y: int | Dict[str, int],
        shared_element_type: str = "none",
        mtl_method: Tuple[str, ...] = ("identity",),
        use_imputation_layer: bool = False,
    ):
        """
        Initialize the modular HIVAE model.

        Args:
            module_config (Tuple[DataModuleConfig]): Configuration for each data module.
            dim_s (int | Dict[str, int]): Number of mixture components for each module or a single value for all modules.
            dim_z (int): Dimension of the latent space. Equal for all modules.
            dim_ys (int): Dimension of the latent space ys. Equal for all modules.
            dim_y (int | Dict[str, int]): Dimension of the latent variable y for each module or a single value for all modules.
            shared_element_type (str, optional): Type of shared element. Defaults to "none".
            mtl_method (Tuple[str, ...], optional): Methods for MTL. Defaults to ("identity",).
            use_imputation_layer (bool, optional): Flag to indicate if imputation layer should be used. Defaults to False.
        """

        super().__init__()
        self.module_configs = module_config
        self.mtl_method = mtl_method
        self.use_imputation_layer = use_imputation_layer
        self.dim_s = dim_s
        self.dim_z = dim_z
        self.dim_ys = dim_ys
        self.dim_y = dim_y

        if not isinstance(dim_s, int) and len(dim_s) != len(module_config):
            raise ValueError(
                "If dim_s is a tuple, it must have the same length as module_config"
            )

        # Initialize the shared element
        # The shared element is a module that takes the z samples from the encoder outputs
        # and generates a shared representation ys, finally a representation y for each module
        shared_element_class = SHARED_MODULES[shared_element_type]
        if shared_element_class is None:
            raise ValueError(
                f"Shared element {shared_element_type} is not available"
            )
        self.shared_element = shared_element_class(
            z_dim=self.dim_z,
            n_modules=len(module_config),
            ys_dim=self.dim_ys,
            y_dim=self.dim_y
            if isinstance(dim_y, int)
            else tuple([dim_y[module.name] for module in self.module_configs]),
            mtl_method=mtl_method if mtl_method else ("graddrop",),
            module_names=[module.name for module in module_config],
        )

        module_models = {}
        for module in module_config:
            module_name = module.name
            if module.is_longitudinal:
                module_models[module_name] = LstmHivae(
                    dim_s=dim_s
                    if isinstance(dim_s, int)
                    else dim_s[module_name],
                    dim_y=dim_y
                    if isinstance(dim_y, int)
                    else dim_y[module_name],
                    dim_z=dim_z,
                    individual_model=False,
                    input_dim=module.input_dim,
                    module_name=module_name,
                    mtl_method=mtl_method,
                    n_layers=module.n_layers,
                    num_timepoints=module.num_timepoints,
                    use_imputation_layer=use_imputation_layer,
                    variable_types=module.variable_types,
                )
            else:
                module_models[module_name] = Hivae(
                    dim_s=dim_s
                    if isinstance(dim_s, int)
                    else dim_s[module_name],
                    dim_y=dim_y
                    if isinstance(dim_y, int)
                    else dim_y[module_name],
                    dim_z=dim_z,
                    individual_model=False,
                    input_dim=module.input_dim,
                    module_name=module_name,
                    mtl_method=mtl_method,
                    use_imputation_layer=use_imputation_layer,
                    variable_types=module.variable_types,
                )

        self.module_models = nn.ModuleDict(module_models)
        self._tau = 1.0

    @property
    def decoding(self) -> bool:
        """
        Decoding flag indicating if the encoder and decoder are in decoding mode.

        Returns:
            bool: Decoding flag.
        """
        assert all([module.decoding for module in self.module_models.values()])
        return self.module_models[self.module_configs[0].name].decoding


    @decoding.setter
    def decoding(self, value: bool) -> None:
        """
        Sets the decoding flag for all modules.

        Args:
            value (bool): The decoding flag to set.
        """
        for module in self.module_models.values():
            module.decoding = value


    def colnames(self, module_name: str) -> Tuple[str, ...]:
        """
        Get column names for a specific module.

        Args:
            module_name (str): Name of the module.

        Returns:
            Tuple[str, ...]: Column names for the specified module.
        """
        return self.module_models[module_name].colnames


    def is_longitudinal(self, module_name: str) -> bool:
        """
        Check if a specific module is longitudinal.

        Args:
            module_name (str): Name of the module.

        Returns:
            bool: True if the module is longitudinal, False otherwise.
        """
        return self.module_models[module_name].is_longitudinal


    def forward(
        self, data: Tuple[Tensor, ...], mask: Tuple[Tensor, ...]
    ) -> ModularHivaeOutput:
        """
        Forward pass through the modular HIVAE model.

        Args:
            data (Tuple[Tensor, ...]): Input data tensors for each module.
            mask (Tuple[Tensor, ...]): Mask tensors indicating missing values for each module.

        Returns:
            ModularHivaeOutput: The output of the modular HIVAE model.
        """

        # Generate the encoder outputs and data for each module
        # This step does not differ from the normal HIVAE model
        encoder_outputs: List[EncoderOutput] = []
        modified_data = []
        modified_mask = []
        for i, module in enumerate(self.module_configs):
            mdata, mmask, output = self.module_models[module.name].encoder_part(
                data[i], mask[i]
            )
            encoder_outputs.append(output)
            modified_data.append(mdata)
            modified_mask.append(mmask)

        # Then retrieve the z samples from the encoder outputs
        # These z samples are then passed through the shared element to
        # achieve the modularity
        # The shared representation is a tuple of tensors, one for each module
        z_samples = tuple([x.samples_z for x in encoder_outputs])
        shared_representations = self.shared_element(z=z_samples)

        # This shared representation is then passed back to the encoder outputs
        for i, enc in enumerate(encoder_outputs):
            enc.decoder_representation = shared_representations[i]

        # Then we continue with the decoder part as usual
        decoder_outputs = []
        for i, module in enumerate(self.module_configs):
            output = self.module_models[module.name].decoder_part(
                data=modified_data[i],
                mask=modified_mask[i],
                encoder_output=encoder_outputs[i],
            )
            decoder_outputs.append(output)

        gathered_output = ModularHivaeOutput(outputs=tuple(decoder_outputs))
        return gathered_output

    @property
    def tau(self):
        """
        Get the temperature parameter for the model.

        Returns:
            float: The temperature parameter.
        """

        return self._tau

    @tau.setter
    def tau(self, value: float):
        """
        Set the temperature parameter for the model.

        Args:
            value (float): The temperature parameter to set.
        """

        self._tau = value
        for module in self.module_models.values():
            module.tau = value

    def decode(self, encoding: ModularHivaeEncoding) -> Tuple[Tensor, ...]:
        """
        Decode the given encoding to reconstruct the input data.

        Args:
            encoding (ModularHivaeEncoding): The encoding to decode.

        Returns:
            Tuple[Tensor, ...]: The reconstructed data tensors for each module.
        """

        self.eval()
        self.tau = 1e-3
        z_samples = tuple([x.z for x in encoding])

        modified_output = self.shared_element(z=z_samples)
        output_mapping = {
            x: i for i, x in enumerate(self.shared_element.module_names)
        }

        for enc in encoding.encodings:
            out_pos = output_mapping[enc.module]
            enc.decoder_representation = modified_output[out_pos]

        outputs = []
        for i, module in enumerate(self.module_configs):
            module_enc = encoding.get(module.name)
            output = self.module_models[module.name].decode(module_enc)
            outputs.append(output)
        assert len(outputs) == len(self.module_configs)
        return tuple(outputs)

    def _training_step(
        self,
        data: Tuple[Tensor],
        mask: Tuple[Tensor],
        optimizer: Tuple[optim.Optimizer],
    ) -> float:
        """
        Perform a training step to update the model parameters.

        Args:
            data (Tuple[Tensor]): Input data tensors for each module.
            mask (Tuple[Tensor]): Mask tensors indicating missing values for each module.
            optimizer (Tuple[optim.Optimizer]): Optimizers for updating model parameters.

        Returns:
            float: The training loss.
        """

        # set all gradients to zero
        for opt in optimizer:
            if opt is None:
                continue
            opt.zero_grad()
        output = self.forward(data, mask)
        self.fabric.backward(output.loss)
        for opt in optimizer:
            if opt is None:
                continue
            opt.step()
        return output.loss.item()

    def _validation_step(
        self, data: Tuple[Tensor], mask: Tuple[Tensor]
    ) -> float:
        """
        Perform a validation step to evaluate the model.

        Args:
            data (Tuple[Tensor]): Input data tensors for each module.
            mask (Tuple[Tensor]): Mask tensors indicating missing values for each module.

        Returns:
            float: The validation loss.
        """

        output = self.forward(data, mask)
        return output.loss.item()

    def _test_step(self, data: Tuple[Tensor], mask: Tuple[Tensor]) -> float:
        """
        Perform a test step to evaluate the model on test data.

        Args:
            data (Tuple[Tensor]): Input data tensors for each module.
            mask (Tuple[Tensor]): Mask tensors indicating missing values for each module.

        Returns:
            float: The test loss.
        """

        output = self.forward(data, mask)
        return output.loss.item()

    def _predict_step(
        self, data: Tuple[Tensor], mask: Tuple[Tensor]
    ) -> ModularHivaeOutput:
        """
        Perform a prediction step without gradient calculation.

        Args:
            data (Tuple[Tensor]): Input data tensors for each module.
            mask (Tuple[Tensor]): Mask tensors indicating missing values for each module.

        Returns:
            ModularHivaeOutput: The output of the modular HIVAE model.
        """

        return self.forward(data, mask)

decoding: bool property writable

Decoding flag indicating if the encoder and decoder are in decoding mode.

Returns:

Name Type Description
bool bool

Decoding flag.

tau property writable

Get the temperature parameter for the model.

Returns:

Name Type Description
float

The temperature parameter.

__init__(module_config, dim_s, dim_z, dim_ys, dim_y, shared_element_type='none', mtl_method=('identity',), use_imputation_layer=False)

Initialize the modular HIVAE model.

Parameters:

Name Type Description Default
module_config Tuple[DataModuleConfig]

Configuration for each data module.

required
dim_s int | Dict[str, int]

Number of mixture components for each module or a single value for all modules.

required
dim_z int

Dimension of the latent space. Equal for all modules.

required
dim_ys int

Dimension of the latent space ys. Equal for all modules.

required
dim_y int | Dict[str, int]

Dimension of the latent variable y for each module or a single value for all modules.

required
shared_element_type str

Type of shared element. Defaults to "none".

'none'
mtl_method Tuple[str, ...]

Methods for MTL. Defaults to ("identity",).

('identity',)
use_imputation_layer bool

Flag to indicate if imputation layer should be used. Defaults to False.

False
Source code in vambn/modelling/models/hivae/modular.py
def __init__(
    self,
    module_config: Tuple[DataModuleConfig],
    dim_s: int | Dict[str, int],
    dim_z: int,
    dim_ys: int,
    dim_y: int | Dict[str, int],
    shared_element_type: str = "none",
    mtl_method: Tuple[str, ...] = ("identity",),
    use_imputation_layer: bool = False,
):
    """
    Initialize the modular HIVAE model.

    Args:
        module_config (Tuple[DataModuleConfig]): Configuration for each data module.
        dim_s (int | Dict[str, int]): Number of mixture components for each module or a single value for all modules.
        dim_z (int): Dimension of the latent space. Equal for all modules.
        dim_ys (int): Dimension of the latent space ys. Equal for all modules.
        dim_y (int | Dict[str, int]): Dimension of the latent variable y for each module or a single value for all modules.
        shared_element_type (str, optional): Type of shared element. Defaults to "none".
        mtl_method (Tuple[str, ...], optional): Methods for MTL. Defaults to ("identity",).
        use_imputation_layer (bool, optional): Flag to indicate if imputation layer should be used. Defaults to False.
    """

    super().__init__()
    self.module_configs = module_config
    self.mtl_method = mtl_method
    self.use_imputation_layer = use_imputation_layer
    self.dim_s = dim_s
    self.dim_z = dim_z
    self.dim_ys = dim_ys
    self.dim_y = dim_y

    if not isinstance(dim_s, int) and len(dim_s) != len(module_config):
        raise ValueError(
            "If dim_s is a tuple, it must have the same length as module_config"
        )

    # Initialize the shared element
    # The shared element is a module that takes the z samples from the encoder outputs
    # and generates a shared representation ys, finally a representation y for each module
    shared_element_class = SHARED_MODULES[shared_element_type]
    if shared_element_class is None:
        raise ValueError(
            f"Shared element {shared_element_type} is not available"
        )
    self.shared_element = shared_element_class(
        z_dim=self.dim_z,
        n_modules=len(module_config),
        ys_dim=self.dim_ys,
        y_dim=self.dim_y
        if isinstance(dim_y, int)
        else tuple([dim_y[module.name] for module in self.module_configs]),
        mtl_method=mtl_method if mtl_method else ("graddrop",),
        module_names=[module.name for module in module_config],
    )

    module_models = {}
    for module in module_config:
        module_name = module.name
        if module.is_longitudinal:
            module_models[module_name] = LstmHivae(
                dim_s=dim_s
                if isinstance(dim_s, int)
                else dim_s[module_name],
                dim_y=dim_y
                if isinstance(dim_y, int)
                else dim_y[module_name],
                dim_z=dim_z,
                individual_model=False,
                input_dim=module.input_dim,
                module_name=module_name,
                mtl_method=mtl_method,
                n_layers=module.n_layers,
                num_timepoints=module.num_timepoints,
                use_imputation_layer=use_imputation_layer,
                variable_types=module.variable_types,
            )
        else:
            module_models[module_name] = Hivae(
                dim_s=dim_s
                if isinstance(dim_s, int)
                else dim_s[module_name],
                dim_y=dim_y
                if isinstance(dim_y, int)
                else dim_y[module_name],
                dim_z=dim_z,
                individual_model=False,
                input_dim=module.input_dim,
                module_name=module_name,
                mtl_method=mtl_method,
                use_imputation_layer=use_imputation_layer,
                variable_types=module.variable_types,
            )

    self.module_models = nn.ModuleDict(module_models)
    self._tau = 1.0
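Both the scalar and the per-module forms of dim_s and dim_y are accepted. A sketch of the two variants, assuming two DataModuleConfig objects cfg_a and cfg_b that were built elsewhere:

from vambn.modelling.models.hivae.modular import ModularHivae

# One value shared by every module
model = ModularHivae(
    module_config=(cfg_a, cfg_b),
    dim_s=3, dim_z=8, dim_ys=16, dim_y=8,
)

# Per-module values keyed by module name
model = ModularHivae(
    module_config=(cfg_a, cfg_b),
    dim_s={cfg_a.name: 3, cfg_b.name: 5},
    dim_z=8,
    dim_ys=16,
    dim_y={cfg_a.name: 8, cfg_b.name: 4},
    shared_element_type="concatMtl",
)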

colnames(module_name)

Get column names for a specific module.

Parameters:

Name Type Description Default
module_name str

Name of the module.

required

Returns:

Type Description
Tuple[str, ...]

Tuple[str, ...]: Column names for the specified module.

Source code in vambn/modelling/models/hivae/modular.py
def colnames(self, module_name: str) -> Tuple[str, ...]:
    """
    Get column names for a specific module.

    Args:
        module_name (str): Name of the module.

    Returns:
        Tuple[str, ...]: Column names for the specified module.
    """
    return self.module_models[module_name].colnames

decode(encoding)

Decode the given encoding to reconstruct the input data.

Parameters:

Name Type Description Default
encoding ModularHivaeEncoding

The encoding to decode.

required

Returns:

Type Description
Tuple[Tensor, ...]

Tuple[Tensor, ...]: The reconstructed data tensors for each module.

Source code in vambn/modelling/models/hivae/modular.py
def decode(self, encoding: ModularHivaeEncoding) -> Tuple[Tensor, ...]:
    """
    Decode the given encoding to reconstruct the input data.

    Args:
        encoding (ModularHivaeEncoding): The encoding to decode.

    Returns:
        Tuple[Tensor, ...]: The reconstructed data tensors for each module.
    """

    self.eval()
    self.tau = 1e-3
    z_samples = tuple([x.z for x in encoding])

    modified_output = self.shared_element(z=z_samples)
    output_mapping = {
        x: i for i, x in enumerate(self.shared_element.module_names)
    }

    for enc in encoding.encodings:
        out_pos = output_mapping[enc.module]
        enc.decoder_representation = modified_output[out_pos]

    outputs = []
    for i, module in enumerate(self.module_configs):
        module_enc = encoding.get(module.name)
        output = self.module_models[module.name].decode(module_enc)
        outputs.append(output)
    assert len(outputs) == len(self.module_configs)
    return tuple(outputs)

forward(data, mask)

Forward pass through the modular HIVAE model.

Parameters:

Name Type Description Default
data Tuple[Tensor, ...]

Input data tensors for each module.

required
mask Tuple[Tensor, ...]

Mask tensors indicating missing values for each module.

required

Returns:

Name Type Description
ModularHivaeOutput ModularHivaeOutput

The output of the modular HIVAE model.

Source code in vambn/modelling/models/hivae/modular.py
def forward(
    self, data: Tuple[Tensor, ...], mask: Tuple[Tensor, ...]
) -> ModularHivaeOutput:
    """
    Forward pass through the modular HIVAE model.

    Args:
        data (Tuple[Tensor, ...]): Input data tensors for each module.
        mask (Tuple[Tensor, ...]): Mask tensors indicating missing values for each module.

    Returns:
        ModularHivaeOutput: The output of the modular HIVAE model.
    """

    # Generate the encoder outputs and data for each module
    # This step does not differ from the normal HIVAE model
    encoder_outputs: List[EncoderOutput] = []
    modified_data = []
    modified_mask = []
    for i, module in enumerate(self.module_configs):
        mdata, mmask, output = self.module_models[module.name].encoder_part(
            data[i], mask[i]
        )
        encoder_outputs.append(output)
        modified_data.append(mdata)
        modified_mask.append(mmask)

    # Then retrieve the z samples from the encoder outputs
    # These z samples are then passed through the shared element to
    # achieve the modularity
    # The shared representation is a tuple of tensors, one for each module
    z_samples = tuple([x.samples_z for x in encoder_outputs])
    shared_representations = self.shared_element(z=z_samples)

    # This shared representation is then passed back to the encoder outputs
    for i, enc in enumerate(encoder_outputs):
        enc.decoder_representation = shared_representations[i]

    # Then we continue with the decoder part as usual
    decoder_outputs = []
    for i, module in enumerate(self.module_configs):
        output = self.module_models[module.name].decoder_part(
            data=modified_data[i],
            mask=modified_mask[i],
            encoder_output=encoder_outputs[i],
        )
        decoder_outputs.append(output)

    gathered_output = ModularHivaeOutput(outputs=tuple(decoder_outputs))
    return gathered_output
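The forward call therefore takes one (data, mask) pair per module, in the same order as module_config. A hypothetical call sketch, where data_a, data_b, mask_a and mask_b are placeholder tensors and model was constructed as in the sketch further above:

data = (data_a, data_b)      # e.g. (batch, n_features_a) and (batch, T, n_features_b)
mask = (mask_a, mask_b)      # same shapes as the corresponding data tensors
output = model(data, mask)   # ModularHivaeOutput
loss = output.loss           # aggregated loss used by the step methods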

is_longitudinal(module_name)

Check if a specific module is longitudinal.

Parameters:

Name Type Description Default
module_name str

Name of the module.

required

Returns:

Name Type Description
bool bool

True if the module is longitudinal, False otherwise.

Source code in vambn/modelling/models/hivae/modular.py
def is_longitudinal(self, module_name: str) -> bool:
    """
    Check if a specific module is longitudinal.

    Args:
        module_name (str): Name of the module.

    Returns:
        bool: True if the module is longitudinal, False otherwise.
    """
    return self.module_models[module_name].is_longitudinal

normalization

Normalization

Class for normalization utilities, including broadcasting masks and normalizing/denormalizing data.

Source code in vambn/modelling/models/hivae/normalization.py
class Normalization:
    """
    Class for normalization utilities, including broadcasting masks and normalizing/denormalizing data.
    """
    @staticmethod
    def _broadcast_mask(
        mask: torch.Tensor, variable_types: VarTypes
    ) -> torch.Tensor:
        """
        Broadcast the mask tensor to match the shape required by variable types.

        Args:
            mask (torch.Tensor): The input mask tensor.
            variable_types (VarTypes): The variable types information.

        Returns:
            torch.Tensor: The broadcasted mask tensor.
        """
        # if all(d.n_parameters == 1 for d in variable_types):
        #     return mask

        new_mask = []
        for i, vtype in enumerate(variable_types):
            if vtype.data_type == "cat":
                new_mask.append(
                    mask[:, i].unsqueeze(1).expand(-1, vtype.n_parameters)
                )
            else:
                new_mask.append(mask[:, i].unsqueeze(1))
        return torch.cat(new_mask, dim=1)

    @staticmethod
    def normalize_data(
        x: torch.Tensor,
        mask: torch.Tensor,
        variable_types: VarTypes,
        prior_parameters: NormalizationParameters,
        eps: float = 1e-6,
    ) -> Tuple[torch.Tensor, torch.Tensor, NormalizationParameters]:
        """
        Normalize the input data based on variable types and prior parameters.

        Args:
            x (torch.Tensor): The input data tensor.
            mask (torch.Tensor): The mask tensor indicating missing values.
            variable_types (VarTypes): The variable types information.
            prior_parameters (NormalizationParameters): The prior normalization parameters.
            eps (float, optional): A small value to prevent division by zero. Defaults to 1e-6.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, NormalizationParameters]: The normalized data, updated mask, and new normalization parameters.
        """
        if x.ndim == 3 and x.shape[1] == 1:
            x = x.squeeze(1)
            mask = mask.squeeze(1)

        assert len(variable_types) == x.shape[-1]
        mean_data = prior_parameters.mean
        std_data = prior_parameters.std
        new_x = []
        for i, vtype in enumerate(variable_types):
            x_i = torch.masked_select(x[..., i], mask[..., i].bool())
            new_x_i = torch.unsqueeze(x[..., i], -1)

            if vtype.data_type == "real" or vtype.data_type == "truncate_norm":
                if x_i.shape[0] >= 4:
                    mean_data[i] = x_i.mean()
                    std_data[i] = x_i.std().clamp(min=eps, max=1e20)

                new_x_i = (new_x_i - mean_data[i]) / std_data[i]
            elif vtype.data_type == "pos":
                x_i = torch.log1p(x_i)
                if x_i.shape[0] >= 4:
                    mean_data[i] = x_i.mean()
                    std_data[i] = x_i.std().clamp(min=eps, max=1e20)

                new_x_i = (torch.log1p(new_x_i) - mean_data[i]) / std_data[i]
            elif vtype.data_type == "gamma":
                x_i = torch.log1p(x_i)
                if x_i.shape[0] >= 4:
                    mean_data[i] = x_i.mean()
                    std_data[i] = x_i.std().clamp(min=eps, max=1e20)

                new_x_i = (torch.log1p(new_x_i) - mean_data[i]) / std_data[i]
            elif vtype.data_type == "count":
                new_x_i = torch.log1p(new_x_i)
            elif vtype.data_type == "cat":
                # convert to one hot
                new_x_i = torch.nn.functional.one_hot(
                    new_x_i.long().squeeze(1), vtype.n_parameters
                )

            if torch.isnan(new_x_i).any():
                raise ValueError(
                    f"NaN values found in normalized data for {vtype}"
                )
            if torch.isnan(mean_data[i]) or torch.isnan(std_data[i]):
                raise ValueError(
                    f"NaN values found in normalization parameters for {vtype}"
                )
            new_x.append(new_x_i)

        new_x = torch.cat(new_x, dim=-1)
        mask = Normalization._broadcast_mask(mask, variable_types)
        new_x = new_x * mask
        return (
            new_x,
            mask,
            NormalizationParameters.from_tensors(mean_data, std_data),
        )

    @staticmethod
    def denormalize_params(
        params: Tuple[Parameters, ...],
        variable_types: VarTypes,
        normalization_params: NormalizationParameters,
    ) -> Tuple[Parameters, ...]:
        """
        Denormalize parameters based on variable types and normalization parameters.

        Args:
            params (Tuple[Parameters, ...]): The parameters to denormalize.
            variable_types (VarTypes): The variable types information.
            normalization_params (NormalizationParameters): The normalization parameters.

        Returns:
            Tuple[Parameters, ...]: The denormalized parameters.
        """
        out_params = []
        for i, vtype in enumerate(variable_types):
            param_i = params[i]
            if vtype.data_type in ["truncate_norm", "real"] and isinstance(
                param_i, NormalParameters
            ):
                mean_data, std_data = normalization_params[i]

                mean, std = param_i.loc, param_i.scale
                mean = mean * std_data + mean_data
                std = std * std_data

                out_params.append(NormalParameters(mean, std))
            elif vtype.data_type == "pos" and isinstance(
                param_i, LogNormalParameters
            ):
                mean_data, std_data = normalization_params[i]

                mean, std = param_i.loc, param_i.scale
                mean = mean * std_data + mean_data
                std = std * std_data

                out_params.append(LogNormalParameters(mean, std))
            elif vtype.data_type == "count":
                out_params.append(param_i)
            elif vtype.data_type == "cat":
                out_params.append(param_i)
            else:
                raise ValueError(f"Unknown data type {vtype.data_type}")
        return tuple(out_params)

denormalize_params(params, variable_types, normalization_params) staticmethod

Denormalize parameters based on variable types and normalization parameters.

Parameters:

Name Type Description Default
params Tuple[Parameters, ...]

The parameters to denormalize.

required
variable_types VarTypes

The variable types information.

required
normalization_params NormalizationParameters

The normalization parameters.

required

Returns:

Type Description
Tuple[Parameters, ...]

Tuple[Parameters, ...]: The denormalized parameters.

Source code in vambn/modelling/models/hivae/normalization.py
@staticmethod
def denormalize_params(
    params: Tuple[Parameters, ...],
    variable_types: VarTypes,
    normalization_params: NormalizationParameters,
) -> Tuple[Parameters, ...]:
    """
    Denormalize parameters based on variable types and normalization parameters.

    Args:
        params (Tuple[Parameters, ...]): The parameters to denormalize.
        variable_types (VarTypes): The variable types information.
        normalization_params (NormalizationParameters): The normalization parameters.

    Returns:
        Tuple[Parameters, ...]: The denormalized parameters.
    """
    out_params = []
    for i, vtype in enumerate(variable_types):
        param_i = params[i]
        if vtype.data_type in ["truncate_norm", "real"] and isinstance(
            param_i, NormalParameters
        ):
            mean_data, std_data = normalization_params[i]

            mean, std = param_i.loc, param_i.scale
            mean = mean * std_data + mean_data
            std = std * std_data

            out_params.append(NormalParameters(mean, std))
        elif vtype.data_type == "pos" and isinstance(
            param_i, LogNormalParameters
        ):
            mean_data, std_data = normalization_params[i]

            mean, std = param_i.loc, param_i.scale
            mean = mean * std_data + mean_data
            std = std * std_data

            out_params.append(LogNormalParameters(mean, std))
        elif vtype.data_type == "count":
            out_params.append(param_i)
        elif vtype.data_type == "cat":
            out_params.append(param_i)
        else:
            raise ValueError(f"Unknown data type {vtype.data_type}")
    return tuple(out_params)
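For real-valued columns the decoder parameters live in the normalized space, so the mapping inverted here is loc -> loc * std_data + mean_data and scale -> scale * std_data; the same relation applies in log space for "pos" columns. A worked example with illustrative numbers only:

# Column normalized with mean_data = 10.0 and std_data = 2.0;
# the decoder outputs loc = 0.5, scale = 0.8 in normalized space.
loc_orig = 0.5 * 2.0 + 10.0   # 11.0
scale_orig = 0.8 * 2.0        # 1.6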

normalize_data(x, mask, variable_types, prior_parameters, eps=1e-06) staticmethod

Normalize the input data based on variable types and prior parameters.

Parameters:

Name Type Description Default
x Tensor

The input data tensor.

required
mask Tensor

The mask tensor indicating missing values.

required
variable_types VarTypes

The variable types information.

required
prior_parameters NormalizationParameters

The prior normalization parameters.

required
eps float

A small value to prevent division by zero. Defaults to 1e-6.

1e-06

Returns:

Type Description
Tuple[Tensor, Tensor, NormalizationParameters]

Tuple[torch.Tensor, torch.Tensor, NormalizationParameters]: The normalized data, updated mask, and new normalization parameters.

Source code in vambn/modelling/models/hivae/normalization.py
@staticmethod
def normalize_data(
    x: torch.Tensor,
    mask: torch.Tensor,
    variable_types: VarTypes,
    prior_parameters: NormalizationParameters,
    eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor, NormalizationParameters]:
    """
    Normalize the input data based on variable types and prior parameters.

    Args:
        x (torch.Tensor): The input data tensor.
        mask (torch.Tensor): The mask tensor indicating missing values.
        variable_types (VarTypes): The variable types information.
        prior_parameters (NormalizationParameters): The prior normalization parameters.
        eps (float, optional): A small value to prevent division by zero. Defaults to 1e-6.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, NormalizationParameters]: The normalized data, updated mask, and new normalization parameters.
    """
    if x.ndim == 3 and x.shape[1] == 1:
        x = x.squeeze(1)
        mask = mask.squeeze(1)

    assert len(variable_types) == x.shape[-1]
    mean_data = prior_parameters.mean
    std_data = prior_parameters.std
    new_x = []
    for i, vtype in enumerate(variable_types):
        x_i = torch.masked_select(x[..., i], mask[..., i].bool())
        new_x_i = torch.unsqueeze(x[..., i], -1)

        if vtype.data_type == "real" or vtype.data_type == "truncate_norm":
            if x_i.shape[0] >= 4:
                mean_data[i] = x_i.mean()
                std_data[i] = x_i.std().clamp(min=eps, max=1e20)

            new_x_i = (new_x_i - mean_data[i]) / std_data[i]
        elif vtype.data_type == "pos":
            x_i = torch.log1p(x_i)
            if x_i.shape[0] >= 4:
                mean_data[i] = x_i.mean()
                std_data[i] = x_i.std().clamp(min=eps, max=1e20)

            new_x_i = (torch.log1p(new_x_i) - mean_data[i]) / std_data[i]
        elif vtype.data_type == "gamma":
            x_i = torch.log1p(x_i)
            if x_i.shape[0] >= 4:
                mean_data[i] = x_i.mean()
                std_data[i] = x_i.std().clamp(min=eps, max=1e20)

            new_x_i = (torch.log1p(new_x_i) - mean_data[i]) / std_data[i]
        elif vtype.data_type == "count":
            new_x_i = torch.log1p(new_x_i)
        elif vtype.data_type == "cat":
            # convert to one hot
            new_x_i = torch.nn.functional.one_hot(
                new_x_i.long().squeeze(1), vtype.n_parameters
            )

        if torch.isnan(new_x_i).any():
            raise ValueError(
                f"NaN values found in normalized data for {vtype}"
            )
        if torch.isnan(mean_data[i]) or torch.isnan(std_data[i]):
            raise ValueError(
                f"NaN values found in normalization parameters for {vtype}"
            )
        new_x.append(new_x_i)

    new_x = torch.cat(new_x, dim=-1)
    mask = Normalization._broadcast_mask(mask, variable_types)
    new_x = new_x * mask
    return (
        new_x,
        mask,
        NormalizationParameters.from_tensors(mean_data, std_data),
    )
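
A hedged usage sketch for two continuous columns; _RealType is again a stand-in for real VarTypes entries, and the prior statistics are plain per-column tensors:

import torch

class _RealType:  # hypothetical stand-in for one VarTypes entry
    data_type = "real"

x = torch.randn(8, 2) * 3.0 + 5.0   # 8 subjects, 2 real-valued columns
mask = torch.ones_like(x)           # 1 = observed, 0 = missing

prior = NormalizationParameters(mean=torch.zeros(2), std=torch.ones(2))
x_norm, mask_out, stats = Normalization.normalize_data(
    x, mask, [_RealType(), _RealType()], prior
)
# x_norm is roughly zero-mean/unit-std per column; stats carries the
# per-column mean/std needed later by denormalize_params.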

NormalizationParameters dataclass

Data class for normalization parameters, including mean and standard deviation.

This class is only used for the parameters of real and pos typed variables.

Parameters:

Name Type Description Default
mean Tensor

The mean values for normalization.

required
std Tensor

The standard deviation values for normalization.

required
Source code in vambn/modelling/models/hivae/normalization.py, lines 17-101
@dataclass
class NormalizationParameters:
    """
    Data class for normalization parameters, including mean and standard deviation.

    This class is only used for the parameters of real and pos typed variables.

    Args:
        mean (torch.Tensor): The mean values for normalization.
        std (torch.Tensor): The standard deviation values for normalization.
    """

    mean: torch.Tensor
    std: torch.Tensor

    @classmethod
    def from_tensors(
        cls, mean: torch.Tensor, std: torch.Tensor
    ) -> "NormalizationParameters":
        """
        Create NormalizationParameters from mean and std tensors.

        Args:
            mean (torch.Tensor): The mean tensor.
            std (torch.Tensor): The standard deviation tensor.

        Returns:
            NormalizationParameters: An instance of NormalizationParameters.
        """
        return NormalizationParameters(mean, std)

    def __getitem__(
        self, idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor] | "NormalizationParameters":  # type: ignore
        """
        Get normalization parameters by index.

        Args:
            idx (int): The index to retrieve parameters for.

        Returns:
            Tuple[torch.Tensor, torch.Tensor] | NormalizationParameters: The mean and std tensors or a new NormalizationParameters instance.
        """
        if self.mean.ndim == 1:
            return self.mean[idx], self.std[idx]
        elif self.mean.ndim == 2:
            mean = self.mean[idx, :]
            std = self.std[idx, :]
            return NormalizationParameters(mean, std)

    def __setitem__(
        self, idx: int, value: Tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """
        Set normalization parameters by index.

        Args:
            idx (int): The index to set parameters for.
            value (Tuple[torch.Tensor, torch.Tensor]): The mean and std tensors to set.
        """
        self.mean[idx], self.std[idx] = value

    def __add__(
        self, other: "NormalizationParameters"
    ) -> "NormalizationParameters":
        """
        Add two NormalizationParameters instances.

        Args:
            other (NormalizationParameters): Another instance to add.

        Returns:
            NormalizationParameters: A new instance with combined parameters.
        """
        if self.mean.ndim == 1:
            # create a 2d tensor by stacking the tensors
            mean = torch.stack([self.mean, other.mean])
            std = torch.stack([self.std, other.std])
        elif self.mean.ndim == 2:
            mean = torch.cat([self.mean, other.mean.unsqueeze(0)], dim=0)
            std = torch.cat([self.std, other.std.unsqueeze(0)], dim=0)
        else:
            raise ValueError("Invalid dimension for mean tensor")

        return NormalizationParameters(mean, std)

__add__(other)

Add two NormalizationParameters instances.

Parameters:

Name Type Description Default
other NormalizationParameters

Another instance to add.

required

Returns:

Name Type Description
NormalizationParameters NormalizationParameters

A new instance with combined parameters.

Source code in vambn/modelling/models/hivae/normalization.py, lines 79-101
def __add__(
    self, other: "NormalizationParameters"
) -> "NormalizationParameters":
    """
    Add two NormalizationParameters instances.

    Args:
        other (NormalizationParameters): Another instance to add.

    Returns:
        NormalizationParameters: A new instance with combined parameters.
    """
    if self.mean.ndim == 1:
        # create a 2d tensor by stacking the tensors
        mean = torch.stack([self.mean, other.mean])
        std = torch.stack([self.std, other.std])
    elif self.mean.ndim == 2:
        mean = torch.cat([self.mean, other.mean.unsqueeze(0)], dim=0)
        std = torch.cat([self.std, other.std.unsqueeze(0)], dim=0)
    else:
        raise ValueError("Invalid dimension for mean tensor")

    return NormalizationParameters(mean, std)

__getitem__(idx)

Get normalization parameters by index.

Parameters:

Name Type Description Default
idx int

The index to retrieve parameters for.

required

Returns:

Type Description
Tuple[Tensor, Tensor] | NormalizationParameters

Tuple[torch.Tensor, torch.Tensor] | NormalizationParameters: The mean and std tensors or a new NormalizationParameters instance.

Source code in vambn/modelling/models/hivae/normalization.py, lines 48-65
def __getitem__(
    self, idx: int
) -> Tuple[torch.Tensor, torch.Tensor] | "NormalizationParameters":  # type: ignore
    """
    Get normalization parameters by index.

    Args:
        idx (int): The index to retrieve parameters for.

    Returns:
        Tuple[torch.Tensor, torch.Tensor] | NormalizationParameters: The mean and std tensors or a new NormalizationParameters instance.
    """
    if self.mean.ndim == 1:
        return self.mean[idx], self.std[idx]
    elif self.mean.ndim == 2:
        mean = self.mean[idx, :]
        std = self.std[idx, :]
        return NormalizationParameters(mean, std)

__setitem__(idx, value)

Set normalization parameters by index.

Parameters:

Name Type Description Default
idx int

The index to set parameters for.

required
value Tuple[Tensor, Tensor]

The mean and std tensors to set.

required
Source code in vambn/modelling/models/hivae/normalization.py, lines 67-77
def __setitem__(
    self, idx: int, value: Tuple[torch.Tensor, torch.Tensor]
) -> None:
    """
    Set normalization parameters by index.

    Args:
        idx (int): The index to set parameters for.
        value (Tuple[torch.Tensor, torch.Tensor]): The mean and std tensors to set.
    """
    self.mean[idx], self.std[idx] = value

from_tensors(mean, std) classmethod

Create NormalizationParameters from mean and std tensors.

Parameters:

Name Type Description Default
mean Tensor

The mean tensor.

required
std Tensor

The standard deviation tensor.

required

Returns:

Name Type Description
NormalizationParameters NormalizationParameters

An instance of NormalizationParameters.

Source code in vambn/modelling/models/hivae/normalization.py, lines 32-46
@classmethod
def from_tensors(
    cls, mean: torch.Tensor, std: torch.Tensor
) -> "NormalizationParameters":
    """
    Create NormalizationParameters from mean and std tensors.

    Args:
        mean (torch.Tensor): The mean tensor.
        std (torch.Tensor): The standard deviation tensor.

    Returns:
        NormalizationParameters: An instance of NormalizationParameters.
    """
    return NormalizationParameters(mean, std)
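
A short sketch of how indexing and addition behave for the two supported layouts (1D per-column tensors and 2D stacked tensors); values are illustrative:

import torch

a = NormalizationParameters(torch.tensor([0.0, 1.0]), torch.tensor([1.0, 2.0]))
b = NormalizationParameters(torch.tensor([5.0, 6.0]), torch.tensor([3.0, 4.0]))

mean_0, std_0 = a[0]      # 1D case: returns the (mean, std) pair of column 0
a[1] = (torch.tensor(9.0), torch.tensor(0.5))   # in-place update of column 1

stacked = a + b           # 1D + 1D: stacks into 2D tensors of shape (2, 2)
row_1 = stacked[1]        # 2D case: returns a new NormalizationParameters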

outputs

DecoderOutput dataclass

Dataclass to hold the output from the decoder.

Attributes:

Name Type Description
log_p_x Tensor

Log-likelihood of the data.

kl_z Tensor

KL divergence for the z variable.

kl_s Tensor

KL divergence for the s variable.

corr_loss Tensor

Correlation loss if applicable.

recon Optional[Tensor]

Reconstruction, if any.

samples Optional[Tensor]

Samples generated, if any.

enc_s Optional[Tensor]

Encoded s values, if any.

enc_z Optional[Tensor]

Encoded z values, if any.

output_name Optional[str]

Name of the output, if any.

detached bool

Whether the tensors have been detached.

Source code in vambn/modelling/models/hivae/outputs.py, lines 66-171
@dataclass
class DecoderOutput:
    """Dataclass to hold the output from the decoder.

    Attributes:
        log_p_x (Tensor): Log-likelihood of the data.
        kl_z (Tensor): KL divergence for the z variable.
        kl_s (Tensor): KL divergence for the s variable.
        corr_loss (Tensor): Correlation loss if applicable.
        recon (Optional[Tensor]): Reconstruction, if any.
        samples (Optional[Tensor]): Samples generated, if any.
        enc_s (Optional[Tensor]): Encoded s values, if any.
        enc_z (Optional[Tensor]): Encoded z values, if any.
        output_name (Optional[str]): Name of the output, if any.
        detached (bool): Whether the tensors have been detached.
    """


    log_p_x: Tensor
    kl_z: Tensor
    kl_s: Tensor
    corr_loss: Tensor
    recon: Optional[Tensor] = None
    samples: Optional[Tensor] = None
    enc_s: Optional[Tensor] = None
    enc_z: Optional[Tensor] = None
    output_name: Optional[str] = None
    detached: bool = False

    def __post_init__(self):
        """
        Validate dimensions of the log-likelihood tensor.

        Raises:
            Exception: If log-likelihood tensor is not of dimension 1.
        """
        if self.log_p_x.ndim != 1:
            raise Exception(
                f"Log-likelihood is not of correct dimension ({self.log_p_x.ndim}, expected 1)"
            )

    def __str__(self) -> str:
        """
        String representation of the DecoderOutput object.

        Returns:
            str: A string describing the DecoderOutput object.
        """

        if self.output_name is not None:
            return f"Decoder output for {self.output_name})"
        else:
            return f"Decoder output (id={id(self)})"

    def detach_and_move(self) -> "DecoderOutput":
        """
        Detach all tensors and move them to CPU.

        Returns:
            DecoderOutput: The detached DecoderOutput object.
        """

        self.detached = True
        self.log_p_x = self.log_p_x.detach().cpu()
        self.kl_z = self.kl_z.detach().cpu()
        self.kl_s = self.kl_s.detach().cpu()
        if self.corr_loss is not None:
            self.corr_loss = self.corr_loss.detach().cpu()
        if self.samples is not None:
            self.samples = self.samples.detach().cpu()
        if self.enc_s is not None:
            self.enc_s = self.enc_s.detach().cpu()
        if self.enc_z is not None:
            self.enc_z = self.enc_z.detach().cpu()
        return self

    @property
    def elbo(self) -> Tensor:
        """
        Calculate the Evidence Lower Bound (ELBO) per sample.

        Returns:
            Tensor: The per-sample ELBO.
        """

        return self.log_p_x - self.kl_z - self.kl_s

    @property
    def loss(self) -> Tensor:
        """
        Calculate the loss based on the negative ELBO.

        Returns:
            Tensor: The loss tensor.

        Raises:
            Exception: If tensors have been detached.
        """

        if self.detached:
            logger.error("Cannot calculate loss. Tensors have been detached.")
            raise Exception(
                "Cannot calculate loss. Tensors have been detached."
            )
        loss = -self.elbo.sum()
        return loss

elbo: Tensor property

Calculate the Evidence Lower Bound (ELBO) per sample.

Returns:

Name Type Description
Tensor Tensor

The per-sample ELBO.

loss: Tensor property

Calculate the loss based on the negative ELBO.

Returns:

Name Type Description
Tensor Tensor

The loss tensor.

Raises:

Type Description
Exception

If tensors have been detached.

__post_init__()

Validate dimensions of the log-likelihood tensor.

Raises:

Type Description
Exception

If log-likelihood tensor is not of dimension 1.

Source code in vambn/modelling/models/hivae/outputs.py, lines 95-105
def __post_init__(self):
    """
    Validate dimensions of the log-likelihood tensor.

    Raises:
        Exception: If log-likelihood tensor is not of dimension 1.
    """
    if self.log_p_x.ndim != 1:
        raise Exception(
            f"Log-likelihood is not of correct dimension ({self.log_p_x.ndim}, expected 1)"
        )

__str__()

String representation of the DecoderOutput object.

Returns:

Name Type Description
str str

A string describing the DecoderOutput object.

Source code in vambn/modelling/models/hivae/outputs.py, lines 107-118
def __str__(self) -> str:
    """
    String representation of the DecoderOutput object.

    Returns:
        str: A string describing the DecoderOutput object.
    """

    if self.output_name is not None:
        return f"Decoder output for {self.output_name})"
    else:
        return f"Decoder output (id={id(self)})"

detach_and_move()

Detach all tensors and move them to CPU.

Returns:

Name Type Description
DecoderOutput DecoderOutput

The detached DecoderOutput object.

Source code in vambn/modelling/models/hivae/outputs.py, lines 120-140
def detach_and_move(self) -> "DecoderOutput":
    """
    Detach all tensors and move them to CPU.

    Returns:
        DecoderOutput: The detached DecoderOutput object.
    """

    self.detached = True
    self.log_p_x = self.log_p_x.detach().cpu()
    self.kl_z = self.kl_z.detach().cpu()
    self.kl_s = self.kl_s.detach().cpu()
    if self.corr_loss is not None:
        self.corr_loss = self.corr_loss.detach().cpu()
    if self.samples is not None:
        self.samples = self.samples.detach().cpu()
    if self.enc_s is not None:
        self.enc_s = self.enc_s.detach().cpu()
    if self.enc_z is not None:
        self.enc_z = self.enc_z.detach().cpu()
    return self
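
A small sketch of the loss bookkeeping with toy tensors (all values are illustrative):

import torch

out = DecoderOutput(
    log_p_x=torch.tensor([-1.2, -0.8]),  # per-sample log-likelihood
    kl_z=torch.tensor([0.3, 0.2]),
    kl_s=torch.tensor([0.1, 0.1]),
    corr_loss=torch.tensor(0.0),
)
out.elbo   # log_p_x - kl_z - kl_s, roughly tensor([-1.6, -1.1])
out.loss   # -elbo.sum(), roughly tensor(2.7)
out.detach_and_move()   # after this, accessing .loss raises an Exception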

EncoderOutput dataclass

Dataclass for encoder output.

Attributes:

Name Type Description
samples_s Tensor

Samples from the s distribution.

logits_s Tensor

Logits for the s distribution.

mean_z Tensor

Mean of the z distribution.

scale_z Tensor

Scale of the z distribution.

samples_z Optional[Tensor]

Samples from the z distribution.

h_representation Optional[Tensor]

Hidden representation, if any.

Source code in vambn/modelling/models/hivae/outputs.py, lines 14-62
@dataclass
class EncoderOutput:
    """Dataclass for encoder output.

    Attributes:
        samples_s (Tensor): Samples from the s distribution.
        logits_s (Tensor): Logits for the s distribution.
        mean_z (Tensor): Mean of the z distribution.
        scale_z (Tensor): Scale of the z distribution.
        samples_z (Optional[Tensor]): Samples from the z distribution.
        h_representation (Optional[Tensor]): Hidden representation, if any.
    """


    samples_s: Tensor
    logits_s: Tensor
    mean_z: Tensor
    scale_z: Tensor
    samples_z: Optional[Tensor]
    h_representation: Optional[Tensor] = None

    @property
    def decoder_representation(self) -> Tensor:
        """
        Get the decoder representation.

        Returns:
            Tensor: The hidden representation if available, otherwise the samples from the z distribution.
        """
        return (
            self.h_representation
            if self.h_representation is not None
            else self.samples_z
        )

    @decoder_representation.setter
    def decoder_representation(self, value: Tensor) -> None:
        """
        Set the decoder representation.

        Args:
            value (Tensor): The value to set as the hidden representation.
        """
        # if value.shape != self.samples_z.shape:
        #     raise ValueError(
        #         f"Shape of value ({value.shape}) does not match shape of samples_z ({self.samples_z.shape})"
        #     )

        self.h_representation = value

decoder_representation: Tensor property writable

Get the decoder representation.

Returns:

Name Type Description
Tensor Tensor

The hidden representation if available, otherwise the samples from the z distribution.
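
A minimal illustration of the fallback behaviour of decoder_representation; shapes are illustrative:

import torch

enc = EncoderOutput(
    samples_s=torch.zeros(4, 3),
    logits_s=torch.zeros(4, 3),
    mean_z=torch.zeros(4, 2),
    scale_z=torch.ones(4, 2),
    samples_z=torch.randn(4, 2),
)
enc.decoder_representation                       # falls back to samples_z
enc.decoder_representation = torch.randn(4, 2)   # sets h_representation, used from now on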

HivaeEncoding dataclass

Dataclass for HIVAE encoding.

Attributes:

Name Type Description
s Tensor

Encoding for s.

z Tensor

Encoding for z.

module str

Module name.

samples Optional[Tensor]

Samples generated, if any.

subjid Optional[List[str | int]]

Subject IDs.

h_representation Optional[Tensor]

Hidden representation, if any.

Source code in vambn/modelling/models/hivae/outputs.py, lines 393-504
@dataclass
class HivaeEncoding:
    """Dataclass for HIVAE encoding.

    Attributes:
        s (torch.Tensor): Encoding for s.
        z (torch.Tensor): Encoding for z.
        module (str): Module name.
        samples (Optional[Tensor]): Samples generated, if any.
        subjid (Optional[List[str | int]]): Subject IDs.
        h_representation (Optional[Tensor]): Hidden representation, if any.
    """

    s: torch.Tensor
    z: torch.Tensor
    module: str
    samples: Optional[Tensor] = None
    subjid: Optional[List[str | int]] = None
    h_representation: Optional[Tensor] = None

    def __post_init__(self):
        """
        Initialize the encoding and ensure tensors are on CPU and have the correct dtype.
        """

        if self.s.device != torch.device("cpu"):
            self.s = self.s.cpu()
        if self.z.device != torch.device("cpu"):
            self.z = self.z.cpu()
        if self.samples is not None and self.samples.device != torch.device(
            "cpu"
        ):
            self.samples = self.samples.cpu()

        if self.s.dtype != torch.float32:
            self.s = self.s.float()

        # make sure z encoding is float
        if self.z.dtype != torch.float32:
            self.z = self.z.float()

    @property
    def decoder_representation(self) -> Tensor:
        """
        Get the decoder representation.

        Returns:
            Tensor: The hidden representation if available, otherwise the z encoding.
        """
        return (
            self.h_representation
            if self.h_representation is not None
            else self.z
        )

    @decoder_representation.setter
    def decoder_representation(self, value: Tensor) -> None:
        """
        Set the decoder representation.

        Args:
            value (Tensor): The value to set as the hidden representation.
        """
        self.h_representation = value

    def convert(self) -> Dict[str, List[str | float]]:
        """
        Convert the encoding to a dictionary format.

        Returns:
            Dict[str, List[str | float]]: The converted encoding.
        """
        data = {
            f"{self.module}_s": self.s.argmax(dim=1).tolist(),
            "SUBJID": self.subjid,
        }
        if self.z.ndim == 2 and self.z.shape[1] > 1:
            for i in range(self.z.shape[1]):
                data[f"{self.module}_z{i}"] = self.z[:, i].tolist()
        else:
            data[f"{self.module}_z"] = self.z.view(-1).tolist()
        return data

    def get_samples(self, module: Optional[str] = None) -> Tensor:
        """
        Get the samples for the specified module.

        Args:
            module (Optional[str]): The module name. If None, return all samples.

        Returns:
            Tensor: The samples tensor.
        """
        if module is None:
            return self.samples
        else:
            assert self.module is not None
            assert module == self.module
            return self.samples

    @typeguard.typechecked
    def save_meta_enc(self, path: Path):
        """
        Save the metadata encoding to a CSV file.

        Args:
            path (Path): The file path to save the metadata.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        converted = self.convert()
        df = pd.DataFrame(converted)
        df.to_csv(path, index=False)

decoder_representation: Tensor property writable

Get the decoder representation.

Returns:

Name Type Description
Tensor Tensor

The hidden representation if available, otherwise the z encoding.

__post_init__()

Initialize the encoding and ensure tensors are on CPU and have the correct dtype.

Source code in vambn/modelling/models/hivae/outputs.py, lines 413-432
def __post_init__(self):
    """
    Initialize the encoding and ensure tensors are on CPU and have the correct dtype.
    """

    if self.s.device != torch.device("cpu"):
        self.s = self.s.cpu()
    if self.z.device != torch.device("cpu"):
        self.z = self.z.cpu()
    if self.samples is not None and self.samples.device != torch.device(
        "cpu"
    ):
        self.samples = self.samples.cpu()

    if self.s.dtype != torch.float32:
        self.s = self.s.float()

    # make sure z encoding is float
    if self.z.dtype != torch.float32:
        self.z = self.z.float()

convert()

Convert the encoding to a dictionary format.

Returns:

Type Description
Dict[str, List[str | float]]

Dict[str, List[str | float]]: The converted encoding.

Source code in vambn/modelling/models/hivae/outputs.py, lines 458-474
def convert(self) -> Dict[str, List[str | float]]:
    """
    Convert the encoding to a dictionary format.

    Returns:
        Dict[str, List[str | float]]: The converted encoding.
    """
    data = {
        f"{self.module}_s": self.s.argmax(dim=1).tolist(),
        "SUBJID": self.subjid,
    }
    if self.z.ndim == 2 and self.z.shape[1] > 1:
        for i in range(self.z.shape[1]):
            data[f"{self.module}_z{i}"] = self.z[:, i].tolist()
    else:
        data[f"{self.module}_z"] = self.z.view(-1).tolist()
    return data

get_samples(module=None)

Get the samples for the specified module.

Parameters:

Name Type Description Default
module Optional[str]

The module name. If None, return all samples.

None

Returns:

Name Type Description
Tensor Tensor

The samples tensor.

Source code in vambn/modelling/models/hivae/outputs.py, lines 476-491
def get_samples(self, module: Optional[str] = None) -> Tensor:
    """
    Get the samples for the specified module.

    Args:
        module (Optional[str]): The module name. If None, return all samples.

    Returns:
        Tensor: The samples tensor.
    """
    if module is None:
        return self.samples
    else:
        assert self.module is not None
        assert module == self.module
        return self.samples

save_meta_enc(path)

Save the metadata encoding to a CSV file.

Parameters:

Name Type Description Default
path Path

The file path to save the metadata.

required
Source code in vambn/modelling/models/hivae/outputs.py, lines 493-504
@typeguard.typechecked
def save_meta_enc(self, path: Path):
    """
    Save the metadata encoding to a CSV file.

    Args:
        path (Path): The file path to save the metadata.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    converted = self.convert()
    df = pd.DataFrame(converted)
    df.to_csv(path, index=False)
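
A hedged usage sketch; the module name "clinical", the subject IDs, and the output path are purely illustrative:

import torch
from pathlib import Path

enc = HivaeEncoding(
    s=torch.eye(3),          # pseudo one-hot s scores for 3 subjects
    z=torch.randn(3, 2),     # 2-dimensional z per subject
    module="clinical",
    subjid=["A", "B", "C"],
)
enc.convert()
# {'clinical_s': [0, 1, 2], 'SUBJID': ['A', 'B', 'C'],
#  'clinical_z0': [...], 'clinical_z1': [...]}
enc.save_meta_enc(Path("encodings/clinical_meta.csv"))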

HivaeOutput dataclass

Dataclass for HIVAE output.

Attributes:

Name Type Description
loss Tensor

Loss tensor.

enc_z Tensor

Encoded z values.

enc_s Tensor

Encoded s values.

samples Optional[Tensor]

Samples generated, if any.

n Optional[int]

Number of samples.

single bool

Whether this is a single output.

Source code in vambn/modelling/models/hivae/outputs.py, lines 190-293
@dataclass
class HivaeOutput:
    """Dataclass for HIVAE output.

    Attributes:
        loss (Tensor): Loss tensor.
        enc_z (Tensor): Encoded z values.
        enc_s (Tensor): Encoded s values.
        samples (Optional[Tensor]): Samples generated, if any.
        n (Optional[int]): Number of samples.
        single (bool): Whether this is a single output.
    """

    loss: Tensor
    enc_z: Tensor
    enc_s: Tensor
    samples: Optional[Tensor] = None
    n: Optional[int] = None
    single: bool = True

    def __post_init__(self):
        """
        Initialize the number of samples.
        """
        self.n = self.enc_z.shape[0]

    @property
    def n_loss(self) -> int:
        """
        Get the number of loss values.

        Returns:
            int: Number of loss values.
        """
        return 1 if self.loss.ndim == 0 else self.loss.shape[0]

    def detach(self) -> "HivaeOutput":
        """
        Detach all tensors and move them to CPU.

        Returns:
            HivaeOutput: The detached HivaeOutput object.
        """
        self.loss = self.loss.detach().cpu()
        self.enc_z = self.enc_z.detach().cpu()
        self.enc_s = self.enc_s.detach().cpu()

        if self.samples is not None:
            self.samples = self.samples.detach().cpu()
        return self

    @property
    def avg_loss(self) -> float:
        """
        Calculate the average loss.

        Returns:
            float: The average loss.

        Raises:
            ValueError: If the loss tensor has an invalid dimension.
        """
        if self.loss.ndim == 0:
            return float(self.loss)
        elif self.loss.ndim == 1:
            return float(self.loss.mean())
        else:
            raise ValueError(
                f"Loss is of wrong dimension ({self.loss.ndim}), expected 0 or 1"
            )

    def stack(self, other: "HivaeOutput") -> "HivaeOutput":
        """
        Stack another HivaeOutput object with this one.

        Args:
            other (HivaeOutput): The other HivaeOutput object to stack.

        Returns:
            HivaeOutput: The stacked HivaeOutput object.
        """

        self.single = False
        self.loss = torch.cat([self.loss.view(-1), other.loss.view(-1)])
        self.enc_z = torch.cat([self.enc_z, other.enc_z])
        self.enc_s = torch.cat([self.enc_s, other.enc_s])

        if self.samples is not None:
            self.samples = torch.cat([self.samples, other.samples])
        self.n += other.n
        return self

    def __add__(self, other: "HivaeOutput") -> "HivaeOutput":
        """
        Add another HivaeOutput object to this one.

        Args:
            other (HivaeOutput): The other HivaeOutput object to add.

        Returns:
            HivaeOutput: The resulting HivaeOutput object.
        """

        return self.stack(other)

avg_loss: float property

Calculate the average loss.

Returns:

Name Type Description
float float

The average loss.

Raises:

Type Description
ValueError

If the loss tensor has an invalid dimension.

n_loss: int property

Get the number of loss values.

Returns:

Name Type Description
int int

Number of loss values.

__add__(other)

Add another HivaeOutput object to this one.

Parameters:

Name Type Description Default
other HivaeOutput

The other HivaeOutput object to add.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

The resulting HivaeOutput object.

Source code in vambn/modelling/models/hivae/outputs.py, lines 282-293
def __add__(self, other: "HivaeOutput") -> "HivaeOutput":
    """
    Add another HivaeOutput object to this one.

    Args:
        other (HivaeOutput): The other HivaeOutput object to add.

    Returns:
        HivaeOutput: The resulting HivaeOutput object.
    """

    return self.stack(other)

__post_init__()

Initialize the number of samples.

Source code in vambn/modelling/models/hivae/outputs.py, lines 210-214
def __post_init__(self):
    """
    Initialize the number of samples.
    """
    self.n = self.enc_z.shape[0]

detach()

Detach all tensors and move them to CPU.

Returns:

Name Type Description
HivaeOutput HivaeOutput

The detached HivaeOutput object.

Source code in vambn/modelling/models/hivae/outputs.py, lines 226-239
def detach(self) -> "HivaeOutput":
    """
    Detach all tensors and move them to CPU.

    Returns:
        HivaeOutput: The detached HivaeOutput object.
    """
    self.loss = self.loss.detach().cpu()
    self.enc_z = self.enc_z.detach().cpu()
    self.enc_s = self.enc_s.detach().cpu()

    if self.samples is not None:
        self.samples = self.samples.detach().cpu()
    return self

stack(other)

Stack another HivaeOutput object with this one.

Parameters:

Name Type Description Default
other HivaeOutput

The other HivaeOutput object to stack.

required

Returns:

Name Type Description
HivaeOutput HivaeOutput

The stacked HivaeOutput object.

Source code in vambn/modelling/models/hivae/outputs.py, lines 261-280
def stack(self, other: "HivaeOutput") -> "HivaeOutput":
    """
    Stack another HivaeOutput object with this one.

    Args:
        other (HivaeOutput): The other HivaeOutput object to stack.

    Returns:
        HivaeOutput: The stacked HivaeOutput object.
    """

    self.single = False
    self.loss = torch.cat([self.loss.view(-1), other.loss.view(-1)])
    self.enc_z = torch.cat([self.enc_z, other.enc_z])
    self.enc_s = torch.cat([self.enc_s, other.enc_s])

    if self.samples is not None:
        self.samples = torch.cat([self.samples, other.samples])
    self.n += other.n
    return self
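
A short sketch of accumulating per-batch outputs across an epoch; note that __add__ delegates to stack() and mutates and returns the left operand (tensors are toy values):

import torch

b1 = HivaeOutput(loss=torch.tensor(1.5), enc_z=torch.randn(4, 2), enc_s=torch.randn(4, 3))
b2 = HivaeOutput(loss=torch.tensor(2.5), enc_z=torch.randn(4, 2), enc_s=torch.randn(4, 3))

epoch = b1 + b2      # b1 now holds the concatenated tensors
epoch.n              # 8 encoded samples in total
epoch.avg_loss       # 2.0, the mean of the stacked per-batch losses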

LogLikelihoodOutput dataclass

Dataclass to hold the output from log-likelihood functions.

Attributes:

Name Type Description
log_p_x Tensor

Log-likelihood for observed data.

log_p_x_missing Tensor

Log-likelihood for missing data.

samples Optional[Tensor]

Samples generated, if any.

Source code in vambn/modelling/models/hivae/outputs.py, lines 174-187
@dataclass
class LogLikelihoodOutput:
    """Dataclass to hold the output from log-likelihood functions.

    Attributes:
        log_p_x (Tensor): Log-likelihood for observed data.
        log_p_x_missing (Tensor): Log-likelihood for missing data.
        samples (Optional[Tensor]): Samples generated, if any.
    """


    log_p_x: Tensor
    log_p_x_missing: Tensor
    samples: Optional[Tensor] = None

LstmHivaeOutput dataclass

Bases: HivaeOutput

Dataclass for LSTM HIVAE output.

Source code in vambn/modelling/models/hivae/outputs.py, lines 296-300
@dataclass
class LstmHivaeOutput(HivaeOutput):
    """Dataclass for LSTM HIVAE output."""

    pass

ModularHivaeEncoding dataclass

Dataclass for modular HIVAE encoding.

Attributes:

Name Type Description
encodings Tuple[HivaeEncoding, ...]

Tuple of HIVAE encodings. See HivaeEncoding for details.

modules List[str]

List of module names in the same order as encodings.

Source code in vambn/modelling/models/hivae/outputs.py, lines 507-618
@dataclass
class ModularHivaeEncoding:
    """Dataclass for modular HIVAE encoding.

    Attributes:
        encodings (Tuple[HivaeEncoding, ...]): Tuple of HIVAE encodings. See HivaeEncoding for details.
        modules (List[str]): List of module names in the same order as encodings.
    """
    encodings: Tuple[HivaeEncoding, ...]
    modules: List[str]

    def __post_init__(self):
        """
        Initialize the modular encoding and ensure tensors are on CPU and have the correct dtype.

        Raises:
            Exception: If modules in encodings do not match modules in ModularHivaeEncoding.
        """
        for encoding in self.encodings:
            if encoding.s.device != torch.device("cpu"):
                encoding.s = encoding.s.cpu()
            if encoding.z.device != torch.device("cpu"):
                encoding.z = encoding.z.cpu()
            if (
                encoding.samples is not None
                and encoding.samples.device != torch.device("cpu")
            ):
                encoding.samples = encoding.samples.cpu()

            if any(
                [x.module != y for x, y in zip(self.encodings, self.modules)]
            ):
                raise Exception(
                    "Modules in encodings do not match modules in ModularHivaeEncoding"
                )

    def convert(self) -> Dict[str, List[float | str]]:
        """
        Convert the modular encoding to a dictionary format.

        Returns:
            Dict[str, List[float | str]]: The converted encoding.
        """
        out = {}
        for encoding in self.encodings:
            data = encoding.convert()
            out.update(data)
        return out

    @typeguard.typechecked
    def save_meta_enc(self, path: Path):
        """
        Save the metadata encoding to a CSV file.

        Args:
            path (Path): The file path to save the metadata.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        converted = self.convert()
        df = pd.DataFrame(converted)
        df.to_csv(path, index=False)

    def get_samples(
        self, module: Optional[str] = None
    ) -> Tensor | Dict[str, Tensor]:
        """
        Get the samples for the specified module.

        Args:
            module (Optional[str]): The module name. If None, return all samples.

        Returns:
            Tensor | Dict[str, Tensor]: The samples tensor or a dictionary of samples.
        """
        if module is None:
            return {x.module: x.samples for x in self.encodings}
        else:
            selected_encodings = [
                x for x in self.encodings if x.module == module
            ]
            assert len(selected_encodings) == 1
            return selected_encodings[0].samples

    def __getitem__(self, idx: int) -> HivaeEncoding:
        """
        Get an encoding by index.

        Args:
            idx (int): The index of the encoding to retrieve.

        Returns:
            HivaeEncoding: The encoding at the specified index.
        """
        return self.encodings[idx]

    def get(self, module: str) -> HivaeEncoding:
        """
        Get an encoding by module name.

        Args:
            module (str): The module name.

        Returns:
            HivaeEncoding: The encoding for the specified module.

        Raises:
            Exception: If the module is not found in encodings.
        """
        for encoding in self.encodings:
            if encoding.module == module:
                return encoding
        raise Exception(f"Module {module} not found in encodings")

__getitem__(idx)

Get an encoding by index.

Parameters:

Name Type Description Default
idx int

The index of the encoding to retrieve.

required

Returns:

Name Type Description
HivaeEncoding HivaeEncoding

The encoding at the specified index.

Source code in vambn/modelling/models/hivae/outputs.py, lines 590-600
def __getitem__(self, idx: int) -> HivaeEncoding:
    """
    Get an encoding by index.

    Args:
        idx (int): The index of the encoding to retrieve.

    Returns:
        HivaeEncoding: The encoding at the specified index.
    """
    return self.encodings[idx]

__post_init__()

Initialize the modular encoding and ensure tensors are on CPU and have the correct dtype.

Raises:

Type Description
Exception

If modules in encodings do not match modules in ModularHivaeEncoding.

Source code in vambn/modelling/models/hivae/outputs.py, lines 518-541
def __post_init__(self):
    """
    Initialize the modular encoding and ensure tensors are on CPU and have the correct dtype.

    Raises:
        Exception: If modules in encodings do not match modules in ModularHivaeEncoding.
    """
    for encoding in self.encodings:
        if encoding.s.device != torch.device("cpu"):
            encoding.s = encoding.s.cpu()
        if encoding.z.device != torch.device("cpu"):
            encoding.z = encoding.z.cpu()
        if (
            encoding.samples is not None
            and encoding.samples.device != torch.device("cpu")
        ):
            encoding.samples = encoding.samples.cpu()

        if any(
            [x.module != y for x, y in zip(self.encodings, self.modules)]
        ):
            raise Exception(
                "Modules in encodings do not match modules in ModularHivaeEncoding"
            )

convert()

Convert the modular encoding to a dictionary format.

Returns:

Type Description
Dict[str, List[float | str]]

Dict[str, List[float | str]]: The converted encoding.

Source code in vambn/modelling/models/hivae/outputs.py, lines 543-554
def convert(self) -> Dict[str, List[float | str]]:
    """
    Convert the modular encoding to a dictionary format.

    Returns:
        Dict[str, List[float | str]]: The converted encoding.
    """
    out = {}
    for encoding in self.encodings:
        data = encoding.convert()
        out.update(data)
    return out

get(module)

Get an encoding by module name.

Parameters:

Name Type Description Default
module str

The module name.

required

Returns:

Name Type Description
HivaeEncoding HivaeEncoding

The encoding for the specified module.

Raises:

Type Description
Exception

If the module is not found in encodings.

Source code in vambn/modelling/models/hivae/outputs.py, lines 602-618
def get(self, module: str) -> HivaeEncoding:
    """
    Get an encoding by module name.

    Args:
        module (str): The module name.

    Returns:
        HivaeEncoding: The encoding for the specified module.

    Raises:
        Exception: If the module is not found in encodings.
    """
    for encoding in self.encodings:
        if encoding.module == module:
            return encoding
    raise Exception(f"Module {module} not found in encodings")

get_samples(module=None)

Get the samples for the specified module.

Parameters:

Name Type Description Default
module Optional[str]

The module name. If None, return all samples.

None

Returns:

Type Description
Tensor | Dict[str, Tensor]

Tensor | Dict[str, Tensor]: The samples tensor or a dictionary of samples.

Source code in vambn/modelling/models/hivae/outputs.py, lines 569-588
def get_samples(
    self, module: Optional[str] = None
) -> Tensor | Dict[str, Tensor]:
    """
    Get the samples for the specified module.

    Args:
        module (Optional[str]): The module name. If None, return all samples.

    Returns:
        Tensor | Dict[str, Tensor]: The samples tensor or a dictionary of samples.
    """
    if module is None:
        return {x.module: x.samples for x in self.encodings}
    else:
        selected_encodings = [
            x for x in self.encodings if x.module == module
        ]
        assert len(selected_encodings) == 1
        return selected_encodings[0].samples

save_meta_enc(path)

Save the metadata encoding to a CSV file.

Parameters:

Name Type Description Default
path Path

The file path to save the metadata.

required
Source code in vambn/modelling/models/hivae/outputs.py, lines 556-567
@typeguard.typechecked
def save_meta_enc(self, path: Path):
    """
    Save the metadata encoding to a CSV file.

    Args:
        path (Path): The file path to save the metadata.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    converted = self.convert()
    df = pd.DataFrame(converted)
    df.to_csv(path, index=False)
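
A hedged sketch of combining two per-module encodings; the module names are illustrative and must appear in the same order in both arguments:

import torch

enc_a = HivaeEncoding(s=torch.eye(2), z=torch.randn(2, 1), module="clinical", subjid=[1, 2])
enc_b = HivaeEncoding(s=torch.eye(2), z=torch.randn(2, 1), module="imaging", subjid=[1, 2])

modular = ModularHivaeEncoding(encodings=(enc_a, enc_b), modules=["clinical", "imaging"])
modular.convert()        # one dict with the columns of both modules
modular.get("imaging")   # look up an encoding by module name
modular[0]               # or by position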

ModularHivaeOutput dataclass

Dataclass for modular HIVAE output.

Attributes:

Name Type Description
outputs Tuple[HivaeOutputs, ...]

Tuple of HIVAE outputs. See HivaeOutputs for details.

Source code in vambn/modelling/models/hivae/outputs.py, lines 306-390
@dataclass
class ModularHivaeOutput:
    """Dataclass for modular HIVAE output.

    Attributes:
        outputs (Tuple[HivaeOutputs, ...]): Tuple of HIVAE outputs. See HivaeOutputs for details.
    """

    outputs: Tuple[HivaeOutputs, ...]

    def __add__(self, other: "ModularHivaeOutput") -> "ModularHivaeOutput":
        """
        Add another ModularHivaeOutput object to this one.

        Args:
            other (ModularHivaeOutput): The other ModularHivaeOutput object to add.

        Returns:
            ModularHivaeOutput: The resulting ModularHivaeOutput object.
        """
        for old, new in zip(self.outputs, other.outputs):
            old += new
        logger.debug(f"Added {other} to {self}")
        return self

    def detach(self) -> "ModularHivaeOutput":
        """
        Detach all tensors in the outputs and move them to CPU.

        Returns:
            ModularHivaeOutput: The detached ModularHivaeOutput object.
        """
        for output in self.outputs:
            output.detach()
        return self

    @property
    def avg_loss(self) -> float:
        """
        Calculate the average loss across all outputs.

        Returns:
            float: The average loss.
        """
        return sum([x.avg_loss for x in self.outputs])

    @property
    def loss(self) -> Tensor:
        """
        Calculate the total loss across all outputs.

        Returns:
            Tensor: The total loss tensor.
        """
        return torch.stack([x.loss for x in self.outputs]).sum()

    def __iter__(self):
        """
        Iterate over the outputs.

        Returns:
            Iterator: An iterator over the outputs.
        """
        return iter(self.outputs)

    def __len__(self):
        """
        Get the number of outputs.

        Returns:
            int: The number of outputs.
        """
        return len(self.outputs)

    def __item__(self, idx: int) -> HivaeOutputs:
        """
        Get an output by index.

        Args:
            idx (int): The index of the output to retrieve.

        Returns:
            HivaeOutputs: The output at the specified index.
        """
        return self.outputs[idx]

avg_loss: float property

Calculate the average loss across all outputs.

Returns:

Name Type Description
float float

The average loss.

loss: Tensor property

Calculate the total loss across all outputs.

Returns:

Name Type Description
Tensor Tensor

The total loss tensor.

__add__(other)

Add another ModularHivaeOutput object to this one.

Parameters:

Name Type Description Default
other ModularHivaeOutput

The other ModularHivaeOutput object to add.

required

Returns:

Name Type Description
ModularHivaeOutput ModularHivaeOutput

The resulting ModularHivaeOutput object.

Source code in vambn/modelling/models/hivae/outputs.py, lines 316-329
def __add__(self, other: "ModularHivaeOutput") -> "ModularHivaeOutput":
    """
    Add another ModularHivaeOutput object to this one.

    Args:
        other (ModularHivaeOutput): The other ModularHivaeOutput object to add.

    Returns:
        ModularHivaeOutput: The resulting ModularHivaeOutput object.
    """
    for old, new in zip(self.outputs, other.outputs):
        old += new
    logger.debug(f"Added {other} to {self}")
    return self

__item__(idx)

Get an output by index.

Parameters:

Name Type Description Default
idx int

The index of the output to retrieve.

required

Returns:

Name Type Description
HivaeOutputs HivaeOutputs

The output at the specified index.

Source code in vambn/modelling/models/hivae/outputs.py, lines 380-390
def __item__(self, idx: int) -> HivaeOutputs:
    """
    Get an output by index.

    Args:
        idx (int): The index of the output to retrieve.

    Returns:
        HivaeOutputs: The output at the specified index.
    """
    return self.outputs[idx]

__iter__()

Iterate over the outputs.

Returns:

Name Type Description
Iterator

An iterator over the outputs.

Source code in vambn/modelling/models/hivae/outputs.py, lines 362-369
def __iter__(self):
    """
    Iterate over the outputs.

    Returns:
        Iterator: An iterator over the outputs.
    """
    return iter(self.outputs)

__len__()

Get the number of outputs.

Returns:

Name Type Description
int

The number of outputs.

Source code in vambn/modelling/models/hivae/outputs.py, lines 371-378
def __len__(self):
    """
    Get the number of outputs.

    Returns:
        int: The number of outputs.
    """
    return len(self.outputs)

detach()

Detach all tensors in the outputs and move them to CPU.

Returns:

Name Type Description
ModularHivaeOutput ModularHivaeOutput

The detached ModularHivaeOutput object.

Source code in vambn/modelling/models/hivae/outputs.py, lines 331-340
def detach(self) -> "ModularHivaeOutput":
    """
    Detach all tensors in the outputs and move them to CPU.

    Returns:
        ModularHivaeOutput: The detached ModularHivaeOutput object.
    """
    for output in self.outputs:
        output.detach()
    return self
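
A small sketch of aggregating losses over two module outputs (toy tensors, as above):

import torch

out_a = HivaeOutput(loss=torch.tensor(1.0), enc_z=torch.randn(4, 2), enc_s=torch.randn(4, 3))
out_b = HivaeOutput(loss=torch.tensor(3.0), enc_z=torch.randn(4, 2), enc_s=torch.randn(4, 3))

modular = ModularHivaeOutput(outputs=(out_a, out_b))
modular.loss       # tensor(4.), summed over the modules
modular.avg_loss   # 4.0, the sum of the per-module average losses
modular.detach()   # detach everything and move to CPU after the update step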

shared

AvgModuleMtl

Bases: BaseModule

This module averages the z's and passes them through the MOO block. The output is then passed through individual layers to generate the final outputs.

Parameters:

Name Type Description Default
z_dim int

The dimension of the input z.

required
ys_dim int

The dimension of the shared output ys.

required
y_dim int | Tuple[int, ...]

The dimension of the individual outputs y.

required
n_modules int

The number of modules.

required
module_names Tuple[str, ...]

The names of the modules.

required
mtl_method Optional[Tuple[str, ...]]

The method used for multi-task learning. Defaults to None.

None

Attributes:

Name Type Description
_mtl_module MultiObjectiveOptimization

The multi-objective optimization module.

moo_block MultiMOOForLoop

The multi-objective optimization block.

scaling_layers ModuleList

The list of scaling layers.

Source code in vambn/modelling/models/hivae/shared.py, lines 421-503
class AvgModuleMtl(BaseModule):
    """
    This module averages the z's and passes them through the MOO block. The output is then
    passed through individual layers to generate the final outputs.

    Args:
        z_dim (int): The dimension of the input z.
        ys_dim (int): The dimension of the shared output ys.
        y_dim (int | Tuple[int, ...]): The dimension of the individual outputs y.
        n_modules (int): The number of modules.
        module_names (Tuple[str, ...]): The names of the modules.
        mtl_method (Optional[Tuple[str, ...]]): The method used for multi-task learning. Defaults to None.

    Attributes:
        _mtl_module (MultiObjectiveOptimization): The multi-objective optimization module.
        moo_block (MultiMOOForLoop): The multi-objective optimization block.
        scaling_layers (nn.ModuleList): The list of scaling layers.

    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        self._mtl_module = moo.setup_moo(
            [MtlMethodParams(x) for x in mtl_method],
            num_tasks=n_modules,
        )
        self.moo_block = moo.MultiMOOForLoop(
            n_modules, moo_methods=(self._mtl_module,)
        )
        self.scaling_layers = nn.ModuleList(
            [BaseElement(self.z_dim, self.z_dim) for i in range(self.n_modules)]
        )

    def order_layers(self, module_names: Tuple[str]) -> None:
        """
        Orders the scaling layers based on the given module names.

        Args:
            module_names (Tuple[str]): The names of the modules.

        """
        prior_map = {name: i for i, name in enumerate(self.module_names)}
        self.scaling_layers = nn.ModuleList(
            [self.scaling_layers[prior_map[name]] for name in module_names]
        )

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Performs the forward pass of the module. The input z's are averaged and passed through
        the MOO block. The output is then passed through the individual scaling layers to
        generate the final output tensors.

        Args:
            z (Tuple[torch.Tensor, ...]): The input z.

        Returns:
            Tuple[torch.Tensor, ...]: The output tensors.

        """
        # Average the z's and pass through the shared layer
        (h,) = self.moo_block(torch.mean(torch.stack(z), dim=0))
        # Pass through the individual layers
        return tuple(
            [
                self.scaling_layers[i](hi)
                for i, hi in zip(range(self.n_modules), h)
            ]
        )

forward(z)

Performs the forward pass of the module. The input z's are averaged and passed through the MOO block. The output is then passed through the individual scaling layers to generate the final output tensors.

Parameters:

Name Type Description Default
z Tuple[Tensor, ...]

The input z.

required

Returns:

Type Description
Tuple[Tensor, ...]

Tuple[torch.Tensor, ...]: The output tensors.

Source code in vambn/modelling/models/hivae/shared.py, lines 482-503
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """
    Performs the forward pass of the module. The input z's are averaged and passed through
    the MOO block. The output is then passed through the individual scaling layers to
    generate the final output tensors.

    Args:
        z (Tuple[torch.Tensor, ...]): The input z.

    Returns:
        Tuple[torch.Tensor, ...]: The output tensors.

    """
    # Average the z's and pass through the shared layer
    (h,) = self.moo_block(torch.mean(torch.stack(z), dim=0))
    # Pass through the individual layers
    return tuple(
        [
            self.scaling_layers[i](hi)
            for i, hi in zip(range(self.n_modules), h)
        ]
    )

order_layers(module_names)

Orders the scaling layers based on the given module names.

Parameters:

Name Type Description Default
module_names Tuple[str]

The names of the modules.

required
Source code in vambn/modelling/models/hivae/shared.py, lines 469-480
def order_layers(self, module_names: Tuple[str]) -> None:
    """
    Orders the scaling layers based on the given module names.

    Args:
        module_names (Tuple[str]): The names of the modules.

    """
    prior_map = {name: i for i, name in enumerate(self.module_names)}
    self.scaling_layers = nn.ModuleList(
        [self.scaling_layers[prior_map[name]] for name in module_names]
    )
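
The central step of forward is the reduction over the module-specific z tensors before the shared MOO block; a torch-only illustration of that reduction, independent of the MOO machinery (shapes are illustrative):

import torch

z = (torch.randn(8, 4), torch.randn(8, 4), torch.randn(8, 4))  # one z per module
avg_z = torch.mean(torch.stack(z), dim=0)                       # shape (8, 4)
# avg_z is what gets passed through the MOO block; its outputs are then fed
# through the per-module scaling layers to produce the final tuple.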

BaseElement

Bases: Module

A base neural network element that consists of a single modified linear layer.

Parameters:

Name Type Description Default
input_dim int

Input dimension of the layer.

required
output_dim int

Output dimension of the layer.

required
Source code in vambn/modelling/models/hivae/shared.py, lines 12-37
class BaseElement(nn.Module):
    """
    A base neural network element that consists of a single modified linear layer.

    Args:
        input_dim (int): Input dimension of the layer.
        output_dim (int): Output dimension of the layer.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layer = ModifiedLinear(self.input_dim, self.output_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Perform the forward pass through the layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after passing through the layer.
        """
        return self.layer(x)

forward(x)

Perform the forward pass through the layer.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required

Returns:

Type Description
Tensor

torch.Tensor: Output tensor after passing through the layer.

Source code in vambn/modelling/models/hivae/shared.py
27
28
29
30
31
32
33
34
35
36
37
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Perform the forward pass through the layer.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: Output tensor after passing through the layer.
    """
    return self.layer(x)
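
BaseElement is a thin wrapper around a single linear map. ModifiedLinear is defined elsewhere in the package; as an illustrative stand-in, torch.nn.Linear gives the same shape behaviour:

import torch
import torch.nn as nn

layer = nn.Linear(16, 4, bias=True)   # stand-in for ModifiedLinear(16, 4, bias=True)
x = torch.randn(32, 16)               # (batch_size, input_dim)
y = layer(x)
assert y.shape == (32, 4)             # (batch_size, output_dim)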

BaseModule

Bases: Module, ABC

Base class for all modules in the HIVAE model.

Parameters:

Name Type Description Default
z_dim int

Input dimension of the latent samples z.

required
ys_dim int

Intermediate dimension of the latent samples ys.

required
y_dim int or Tuple[int, ...]

Output dimension of the latent samples y, which can be different for each module.

required
n_modules int

Number of modules.

required
module_names Tuple[str, ...]

Names of the modules.

required
mtl_method Optional[Tuple[str, ...]]

MTL methods used to avoid conflicting gradients. Defaults to None.

None

Raises:

Type Description
ValueError

If the length of y_dim is not equal to the number of modules.

Source code in vambn/modelling/models/hivae/shared.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
class BaseModule(nn.Module, ABC):
    """
    Base class for all modules in the HIVAE model.

    Args:
        z_dim (int): Input dimension of the latent samples z.
        ys_dim (int): Intermediate dimension of the latent samples ys.
        y_dim (int or Tuple[int, ...]): Output dimension of the latent samples y, which can be different for each module.
        n_modules (int): Number of modules.
        module_names (Tuple[str, ...]): Names of the modules.
        mtl_method (Optional[Tuple[str, ...]]): MTL methods used to avoid conflicting gradients. Defaults to None.

    Raises:
        ValueError: If the length of y_dim is not equal to the number of modules.
    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        """
        Base class for all modules in the HIVAE model.

        Args:
            z_dim (int): Input dimension of the latent samples z.
            ys_dim (int): Intermediate dimension of the latent samples ys.
            y_dim (int | Tuple[int, ...]): Output dimension of the latent samples y, which can be different for each module.
            n_modules (int): Number of modules.
            module_names (Tuple[str, ...]): Names of the modules.
            mtl_method (Optional[Tuple[str, ...]], optional): MTL methods used to avoid conflicting gradients. Defaults to None.

        Raises:
            ValueError: If the length of y_dim is not equal to the number of modules.
        """
        super().__init__()
        self.z_dim = z_dim
        self.ys_dim = ys_dim
        self.y_dim = y_dim
        self.n_modules = n_modules
        self.mtl_method = mtl_method
        self.module_names = module_names
        self.has_params = True

        if (
            isinstance(self.y_dim, Iterable)
            and len(self.y_dim) != self.n_modules
        ):
            raise ValueError(
                f"Length of y_dim must be equal to the number of modules: {self.n_modules}"
            )


    @abstractmethod
    def order_layers(self, module_names: Tuple[str, ...]) -> None:
        """
        Order the layers of the module according to the module names.

        Args:
            module_names (Tuple[str, ...]): Names of the modules
        """
        pass

    @abstractmethod
    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        pass

    def _y_dim(self, i: int) -> int:
        """
        Returns the output dimension of the ith module.

        Args:
            i (int): Index of the module.

        Returns:
            int: Output dimension of the ith module.
        """
        return self.y_dim if isinstance(self.y_dim, int) else self.y_dim[i]

__init__(z_dim, ys_dim, y_dim, n_modules, module_names, mtl_method=None)

Base class for all modules in the HIVAE model.

Parameters:

Name Type Description Default
z_dim int

Input dimension of the latent samples z.

required
ys_dim int

Intermediate dimension of the latent samples ys.

required
y_dim int | Tuple[int, ...]

Output dimension of the latent samples y, which can be different for each module.

required
n_modules int

Number of modules.

required
module_names Tuple[str, ...]

Names of the modules.

required
mtl_method Optional[Tuple[str, ...]]

MTL methods used to avoid conflicting gradients. Defaults to None.

None

Raises:

Type Description
ValueError

If the length of y_dim is not equal to the number of modules.

Source code in vambn/modelling/models/hivae/shared.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def __init__(
    self,
    z_dim: int,
    ys_dim: int,
    y_dim: int | Tuple[int, ...],
    n_modules: int,
    module_names: Tuple[str, ...],
    mtl_method: Optional[Tuple[str, ...]] = None,
) -> None:
    """
    Base class for all modules in the HIVAE model.

    Args:
        z_dim (int): Input dimension of the latent samples z.
        ys_dim (int): Intermediate dimension of the latent samples ys.
        y_dim (int | Tuple[int, ...]): Output dimension of the latent samples y, which can be different for each module.
        n_modules (int): Number of modules.
        module_names (Tuple[str, ...]): Names of the modules.
        mtl_method (Optional[Tuple[str, ...]], optional): MTL methods used to avoid conflicting gradients. Defaults to None.

    Raises:
        ValueError: If the length of y_dim is not equal to the number of modules.
    """
    super().__init__()
    self.z_dim = z_dim
    self.ys_dim = ys_dim
    self.y_dim = y_dim
    self.n_modules = n_modules
    self.mtl_method = mtl_method
    self.module_names = module_names
    self.has_params = True

    if (
        isinstance(self.y_dim, Iterable)
        and len(self.y_dim) != self.n_modules
    ):
        raise ValueError(
            f"Length of y_dim must be equal to the number of modules: {self.n_modules}"
        )

order_layers(module_names) abstractmethod

Order the layers of the module according to the module names.

Parameters:

Name Type Description Default
module_names Tuple[str, ...]

Names of the modules

required
Source code in vambn/modelling/models/hivae/shared.py
 97
 98
 99
100
101
102
103
104
105
@abstractmethod
def order_layers(self, module_names: Tuple[str, ...]) -> None:
    """
    Order the layers of the module according to the module names.

    Args:
        module_names (Tuple[str, ...]): Names of the modules
    """
    pass
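
The y_dim argument may be a single int shared by all modules or a tuple with one entry per module; _y_dim resolves it per module index, and __init__ validates the tuple length. The lookup reduces to the following standalone sketch:

from typing import Tuple, Union

def resolve_y_dim(y_dim: Union[int, Tuple[int, ...]], i: int) -> int:
    # mirrors BaseModule._y_dim
    return y_dim if isinstance(y_dim, int) else y_dim[i]

assert resolve_y_dim(5, 2) == 5            # shared output dimension
assert resolve_y_dim((3, 4, 5), 2) == 5    # per-module output dimensions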

ConcatModule

Bases: BaseModule

A module that concatenates multiple input tensors and applies scaling layers.

Parameters:

Name Type Description Default
z_dim int

The dimension of the input z tensors.

required
ys_dim int

The dimension of the output ys tensors.

required
y_dim int or Tuple[int, ...]

The dimension of the output y tensors.

required
n_modules int

The number of modules.

required
module_names Tuple[str, ...]

The names of the modules.

required
mtl_method Optional[Tuple[str, ...]]

The method used for multi-task learning (default: None).

None

Attributes:

Name Type Description
shared_layer BaseElement

The shared layer that concatenates the input z tensors.

scaling_layers ModuleList

The list of scaling layers for each module.

Source code in vambn/modelling/models/hivae/shared.py
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
class ConcatModule(BaseModule):
    """
    A module that concatenates multiple input tensors and applies scaling layers.

    Args:
        z_dim (int): The dimension of the input z tensors.
        ys_dim (int): The dimension of the output ys tensors.
        y_dim (int or Tuple[int, ...]): The dimension of the output y tensors.
        n_modules (int): The number of modules.
        module_names (Tuple[str, ...]): The names of the modules.
        mtl_method (Optional[Tuple[str, ...]]): The method used for multi-task learning (default: None).

    Attributes:
        shared_layer (BaseElement): The shared layer that concatenates the input z tensors.
        scaling_layers (nn.ModuleList): The list of scaling layers for each module.

    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        self.shared_layer = BaseElement(
            self.z_dim * n_modules, self.ys_dim * n_modules
        )
        self.scaling_layers = nn.ModuleList(
            [
                BaseElement(self.ys_dim, self.z_dim)
                for i in range(self.n_modules)
            ]
        )

    def order_layers(self, module_names: Tuple[str]) -> None:
        """
        Reorders the scaling layers based on the given module names.

        Args:
            module_names (Tuple[str]): The new order of the module names.

        """
        prior_map = {name: i for i, name in enumerate(self.module_names)}
        self.scaling_layers = nn.ModuleList(
            [self.scaling_layers[prior_map[name]] for name in module_names]
        )

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Performs forward pass through the ConcatModule. The input tensors z are concatenated
        and passed through the shared layer to generate a single output tensor. The output
        tensor is then passed through the scaling layers to generate the final output tensors.

        Args:
            z (Tuple[torch.Tensor, ...]): The input tensors z.

        Returns:
            Tuple[torch.Tensor, ...]: The output tensors after applying scaling layers.

        """
        # z-type: Tuple[torch.Tensor, ...], with each z having shape (batch_size, z_dim)
        h = self.shared_layer(torch.cat(z, dim=1))
        # size of h: (batch_size, ys_dim * n_modules)
        return tuple(
            [
                self.scaling_layers[i](
                    h[:, i * self.ys_dim : (i + 1) * self.ys_dim]
                )
                for i in range(self.n_modules)
            ]
        )  # size of h[:, i * self.ys_dim : (i + 1) * self.ys_dim]: (batch_size, ys_dim)

forward(z)

Performs forward pass through the ConcatModule. The input tensors z are concatenated and passed through the shared layer to generate a single output tensor. The output tensor is then passed through the scaling layers to generate the final output tensors.

Parameters:

Name Type Description Default
z Tuple[Tensor, ...]

The input tensors z.

required

Returns:

Type Description
Tuple[Tensor, ...]

Tuple[torch.Tensor, ...]: The output tensors after applying scaling layers.

Source code in vambn/modelling/models/hivae/shared.py
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """
    Performs forward pass through the ConcatModule. The input tensors z are concatenated
    and passed through the shared layer to generate a single output tensor. The output
    tensor is then passed through the scaling layers to generate the final output tensors.

    Args:
        z (Tuple[torch.Tensor, ...]): The input tensors z.

    Returns:
        Tuple[torch.Tensor, ...]: The output tensors after applying scaling layers.

    """
    # z-type: Tuple[torch.Tensor, ...], with each z having shape (batch_size, z_dim)
    h = self.shared_layer(torch.cat(z, dim=1))
    # size of h: (batch_size, ys_dim * n_modules)
    return tuple(
        [
            self.scaling_layers[i](
                h[:, i * self.ys_dim : (i + 1) * self.ys_dim]
            )
            for i in range(self.n_modules)
        ]
    )  # size of h[:, i * self.ys_dim : (i + 1) * self.ys_dim]: (batch_size, ys_dim)

order_layers(module_names)

Reorders the scaling layers based on the given module names.

Parameters:

Name Type Description Default
module_names Tuple[str]

The new order of the module names.

required
Source code in vambn/modelling/models/hivae/shared.py
382
383
384
385
386
387
388
389
390
391
392
393
def order_layers(self, module_names: Tuple[str]) -> None:
    """
    Reorders the scaling layers based on the given module names.

    Args:
        module_names (Tuple[str]): The new order of the module names.

    """
    prior_map = {name: i for i, name in enumerate(self.module_names)}
    self.scaling_layers = nn.ModuleList(
        [self.scaling_layers[prior_map[name]] for name in module_names]
    )
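
In ConcatModule.forward, the concatenated input has shape (batch_size, z_dim * n_modules); after the shared layer its width is ys_dim * n_modules, and each module reads its own ys_dim-wide slice. A self-contained sketch of the concatenate-then-slice pattern, with torch.nn.Linear standing in for the BaseElement layers:

import torch
import torch.nn as nn

n_modules, z_dim, ys_dim, batch = 3, 8, 4, 5
z = tuple(torch.randn(batch, z_dim) for _ in range(n_modules))

shared = nn.Linear(z_dim * n_modules, ys_dim * n_modules)  # stand-in for the shared BaseElement
scaling = [nn.Linear(ys_dim, z_dim) for _ in range(n_modules)]

h = shared(torch.cat(z, dim=1))                            # (batch, ys_dim * n_modules)
outputs = tuple(
    scaling[i](h[:, i * ys_dim : (i + 1) * ys_dim])        # each module reads its own slice
    for i in range(n_modules)
)
assert all(o.shape == (batch, z_dim) for o in outputs)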

ConcatModuleMtl

Bases: BaseModule

This module concatenates the z's and passes them through a shared layer before passing through the MOO block. The output is then passed through individual layers to generate the final outputs. This is the same as the ConcatModule, but with the addition of the MOO block.

Source code in vambn/modelling/models/hivae/shared.py
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
class ConcatModuleMtl(BaseModule):
    """
    This module concatenates the z's and passes them through a shared layer before
    passing through the MOO block. The output is then passed through individual layers
    to generate the final outputs. This is the same as the ConcatModule, but with the
    addition of the MOO block.
    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        self.mtl_method = mtl_method
        self.shared_layer = BaseElement(z_dim * n_modules, ys_dim)
        self._mtl_module = moo.setup_moo(
            [MtlMethodParams(x) for x in mtl_method],
            num_tasks=n_modules,
        )
        self.moo_block = moo.MultiMOOForLoop(
            n_modules, moo_methods=(self._mtl_module,)
        )
        self.scaling_layers = nn.ModuleList(
            [
                BaseElement(self.ys_dim, self.z_dim)
                for i in range(self.n_modules)
            ]
        )

    def order_layers(self, module_names: Tuple[str]) -> None:
        """
        Order the layers based on the module names.

        Args:
            module_names (Tuple[str]): Names of the modules.
        """
        prior_map = {name: i for i, name in enumerate(self.module_names)}
        self.scaling_layers = nn.ModuleList(
            [self.scaling_layers[prior_map[name]] for name in module_names]
        )

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Forward pass through the module. The z's are concatenated and passed through
        the shared layer to generate one output which is identical for all modules. The
        output is then passed through the MOO block and individual layers to generate
        the final outputs.

        Args:
            z (Tuple[torch.Tensor, ...]): Input tensor.

        Returns:
            Tuple[torch.Tensor, ...]: Output tensor with individual outputs for each module.
        """
        # Concatenate the z's and pass through the
        # shared layer to generate one output which is identical for all modules
        h = self.shared_layer(torch.cat(z, dim=1))
        # Pass through the MOO block
        (hs,) = self.moo_block(h)
        # Pass through the individual layers
        return tuple(
            [
                self.scaling_layers[i](hi)
                for i, hi in zip(range(self.n_modules), hs)
            ]
        )

forward(z)

Forward pass through the module. The z's are concatenated and passed through the shared layer to generate one output which is identical for all modules. The output is then passed through the MOO block and individual layers to generate the final outputs.

Parameters:

Name Type Description Default
z Tuple[Tensor, ...]

Input tensor.

required

Returns:

Type Description
Tuple[Tensor, ...]

Tuple[torch.Tensor, ...]: Output tensor with individual outputs for each module.

Source code in vambn/modelling/models/hivae/shared.py
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """
    Forward pass through the module. The z's are concatenated and passed through
    the shared layer to generate one output which is identical for all modules. The
    output is then passed through the MOO block and individual layers to generate
    the final outputs.

    Args:
        z (Tuple[torch.Tensor, ...]): Input tensor.

    Returns:
        Tuple[torch.Tensor, ...]: Output tensor with individual outputs for each module.
    """
    # Concatenate the z's and pass through the
    # shared layer to generate one output which is identical for all modules
    h = self.shared_layer(torch.cat(z, dim=1))
    # Pass through the MOO block
    (hs,) = self.moo_block(h)
    # Pass through the individual layers
    return tuple(
        [
            self.scaling_layers[i](hi)
            for i, hi in zip(range(self.n_modules), hs)
        ]
    )

order_layers(module_names)

Order the layers based on the module names.

Parameters:

Name Type Description Default
module_names Tuple[str]

Names of the modules.

required
Source code in vambn/modelling/models/hivae/shared.py
298
299
300
301
302
303
304
305
306
307
308
def order_layers(self, module_names: Tuple[str]) -> None:
    """
    Order the layers based on the module names.

    Args:
        module_names (Tuple[str]): Names of the modules.
    """
    prior_map = {name: i for i, name in enumerate(self.module_names)}
    self.scaling_layers = nn.ModuleList(
        [self.scaling_layers[prior_map[name]] for name in module_names]
    )
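
ConcatModuleMtl produces a single shared representation h and lets the MOO block hand one view of it to each module before the per-module scaling layers. moo.MultiMOOForLoop applies the configured gradient-handling methods; in terms of forward shapes alone, it can be pictured as fanning h out into n_modules copies. A hedged sketch with a trivial fan-out standing in for the MOO block and torch.nn.Linear standing in for BaseElement:

import torch
import torch.nn as nn

n_modules, z_dim, ys_dim, batch = 3, 8, 4, 5
z = tuple(torch.randn(batch, z_dim) for _ in range(n_modules))

shared = nn.Linear(z_dim * n_modules, ys_dim)            # one output shared by all modules
scaling = [nn.Linear(ys_dim, z_dim) for _ in range(n_modules)]

h = shared(torch.cat(z, dim=1))                          # (batch, ys_dim)
hs = tuple(h for _ in range(n_modules))                  # stand-in for the MOO block fan-out
outputs = tuple(scaling[i](hi) for i, hi in enumerate(hs))
assert all(o.shape == (batch, z_dim) for o in outputs)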

EncoderModule

Bases: BaseModule

EncoderModule class represents a module for encoding input data.

Parameters:

Name Type Description Default
z_dim int

The dimension of the latent space.

required
ys_dim int

The intermediate dimension of the latent samples ys.

required
y_dim int or Tuple[int, ...]

The output dimension(s) of the latent samples y; can differ per module.

required
n_modules int

The number of modules.

required
module_names Tuple[str, ...]

The names of the modules.

required
mtl_method Optional[Tuple[str, ...]]

The method(s) for multi-task learning. Defaults to None.

None

Attributes:

Name Type Description
attention SelfAttention

The self-attention layer.

feed_forward Sequential

The feed-forward neural network.

dropout Dropout

The dropout layer.

layer_norm_1 LayerNorm

The layer normalization layer.

layer_norm_2 LayerNorm

The layer normalization layer.

ys_layer Linear

The linear layer for output.

scaling_layers ModuleList

The list of scaling layers.

Source code in vambn/modelling/models/hivae/shared.py
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
class EncoderModule(BaseModule):
    """
    EncoderModule class represents a module for encoding input data.

    Args:
        z_dim (int): The dimension of the latent space.
        ys_dim (int): The intermediate dimension of the latent samples ys.
        y_dim (int or Tuple[int, ...]): The output dimension(s) of the latent samples y; can differ per module.
        n_modules (int): The number of modules.
        module_names (Tuple[str, ...]): The names of the modules.
        mtl_method (Optional[Tuple[str, ...]]): The method(s) for multi-task learning. Defaults to None.

    Attributes:
        attention (SelfAttention): The self-attention layer.
        feed_forward (nn.Sequential): The feed-forward neural network.
        dropout (nn.Dropout): The dropout layer.
        layer_norm_1 (nn.LayerNorm): The layer normalization layer.
        layer_norm_2 (nn.LayerNorm): The layer normalization layer.
        ys_layer (nn.Linear): The linear layer for output.
        scaling_layers (nn.ModuleList): The list of scaling layers.

    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        cat_dim = z_dim * n_modules
        self.attention = SelfAttention(cat_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(cat_dim, ys_dim),
            nn.GELU(),
            nn.Linear(ys_dim, cat_dim),
            nn.Dropout(0.1),
        )
        self.dropout = nn.Dropout(0.1)
        self.layer_norm_1 = nn.LayerNorm(cat_dim)
        self.layer_norm_2 = nn.LayerNorm(cat_dim)
        self.ys_layer = nn.Linear(cat_dim, ys_dim)
        self.scaling_layers = nn.ModuleList(
            [
                BaseElement(self.ys_dim, self.z_dim)
                for i in range(self.n_modules)
            ]
        )

    def order_layers(self, module_names: Tuple[str]) -> None:
        """
        Orders the scaling layers based on the given module names.

        Args:
            module_names (Tuple[str]): The names of the modules.

        Returns:
            None

        """

        prior_map = {name: i for i, name in enumerate(self.module_names)}
        self.scaling_layers = nn.ModuleList(
            [self.scaling_layers[prior_map[name]] for name in module_names]
        )

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Performs forward pass through the encoder module. The input tensors z are concatenated
        and passed through the self-attention layer. The output is then passed through the feed-forward
        neural network and the output layer. The output is then passed through the scaling layers to
        generate the final output tensors.

        Args:
            z (Tuple[torch.Tensor, ...]): The input tensors.

        Returns:
            Tuple[torch.Tensor, ...]: The output tensors.

        """

        x = torch.cat(z, dim=1)
        attended = self.attention(torch.cat(z, dim=1))
        x = self.layer_norm_1(x + self.dropout(attended))

        h = self.feed_forward(x)
        h = self.layer_norm_2(x + self.dropout(h))
        yb = self.ys_layer(h)
        y = tuple([self.scaling_layers[i](yb) for i in range(self.n_modules)])
        return y

forward(z)

Performs forward pass through the encoder module. The input tensors z are concatenated and passed through the self-attention layer. The output is then passed through the feed-forward neural network and the output layer. The output is then passed through the scaling layers to generate the final output tensors.

Parameters:

Name Type Description Default
z Tuple[Tensor, ...]

The input tensors.

required

Returns:

Type Description
Tuple[Tensor, ...]

Tuple[torch.Tensor, ...]: The output tensors.

Source code in vambn/modelling/models/hivae/shared.py
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """
    Performs forward pass through the encoder module. The input tensors z are concatenated
    and passed through the self-attention layer. The output is then passed through the feed-forward
    neural network and the output layer. The output is then passed through the scaling layers to
    generate the final output tensors.

    Args:
        z (Tuple[torch.Tensor, ...]): The input tensors.

    Returns:
        Tuple[torch.Tensor, ...]: The output tensors.

    """

    x = torch.cat(z, dim=1)
    attended = self.attention(torch.cat(z, dim=1))
    x = self.layer_norm_1(x + self.dropout(attended))

    h = self.feed_forward(x)
    h = self.layer_norm_2(x + self.dropout(h))
    yb = self.ys_layer(h)
    y = tuple([self.scaling_layers[i](yb) for i in range(self.n_modules)])
    return y

order_layers(module_names)

Orders the scaling layers based on the given module names.

Parameters:

Name Type Description Default
module_names Tuple[str]

The names of the modules.

required

Returns:

Type Description
None

None

Source code in vambn/modelling/models/hivae/shared.py
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
def order_layers(self, module_names: Tuple[str]) -> None:
    """
    Orders the scaling layers based on the given module names.

    Args:
        module_names (Tuple[str]): The names of the modules.

    Returns:
        None

    """

    prior_map = {name: i for i, name in enumerate(self.module_names)}
    self.scaling_layers = nn.ModuleList(
        [self.scaling_layers[prior_map[name]] for name in module_names]
    )
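
The EncoderModule forward follows a small transformer-style block: attention over the concatenated input, a residual connection with layer normalization, a feed-forward sub-network with a second residual connection, and finally the shared ys projection. A stripped-down sketch of that flow, with plain torch layers standing in for SelfAttention and BaseElement:

import torch
import torch.nn as nn

cat_dim, ys_dim, batch = 12, 6, 4
x = torch.randn(batch, cat_dim)              # concatenated z's

attn = nn.Linear(cat_dim, cat_dim)           # placeholder for SelfAttention(cat_dim)
ff = nn.Sequential(nn.Linear(cat_dim, ys_dim), nn.GELU(), nn.Linear(ys_dim, cat_dim))
norm1, norm2 = nn.LayerNorm(cat_dim), nn.LayerNorm(cat_dim)
drop = nn.Dropout(0.1)
ys_layer = nn.Linear(cat_dim, ys_dim)

h = norm1(x + drop(attn(x)))                 # attention + residual + layer norm
h = norm2(h + drop(ff(h)))                   # feed-forward + residual + layer norm
yb = ys_layer(h)                             # shared ys representation
assert yb.shape == (batch, ys_dim)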

EncoderModuleMtl

Bases: BaseModule

Encoder module for multi-task learning.

Parameters:

Name Type Description Default
z_dim int

The dimension of the latent space.

required
ys_dim int

The dimension of the shared representation.

required
y_dim int | Tuple[int, ...]

The dimension(s) of the task-specific representations.

required
n_modules int

The number of task-specific modules.

required
module_names Tuple[str, ...]

The names of the task-specific modules.

required
mtl_method Optional[Tuple[str, ...]]

The multi-task learning method(s) to use. Defaults to None.

None

Attributes:

Name Type Description
attention SelfAttention

The self-attention layer.

feed_forward Sequential

The feed-forward neural network.

dropout Dropout

The dropout layer.

layer_norm_1 LayerNorm

The first layer normalization.

layer_norm_2 LayerNorm

The second layer normalization.

ys_layer Linear

The linear layer for shared representation.

scaling_layers ModuleList

The list of scaling layers for task-specific representations.

_mtl_module MultiObjectiveOptimization

The multi-objective optimization module.

moo_block MultiMOOForLoop

The multi-objective optimization block.

Raises:

Type Description
Exception

This class should no longer be used.

Source code in vambn/modelling/models/hivae/shared.py
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
class EncoderModuleMtl(BaseModule):
    """Encoder module for multi-task learning.

    Args:
        z_dim (int): The dimension of the latent space.
        ys_dim (int): The dimension of the shared representation.
        y_dim (int | Tuple[int, ...]): The dimension(s) of the task-specific representations.
        n_modules (int): The number of task-specific modules.
        module_names (Tuple[str, ...]): The names of the task-specific modules.
        mtl_method (Optional[Tuple[str, ...]], optional): The multi-task learning method(s) to use. Defaults to None.

    Attributes:
        attention (SelfAttention): The self-attention layer.
        feed_forward (nn.Sequential): The feed-forward neural network.
        dropout (nn.Dropout): The dropout layer.
        layer_norm_1 (nn.LayerNorm): The first layer normalization.
        layer_norm_2 (nn.LayerNorm): The second layer normalization.
        ys_layer (nn.Linear): The linear layer for shared representation.
        scaling_layers (nn.ModuleList): The list of scaling layers for task-specific representations.
        _mtl_module (moo.MultiObjectiveOptimization): The multi-objective optimization module.
        moo_block (moo.MultiMOOForLoop): The multi-objective optimization block.

    Raises:
        Exception: This class should no longer be used.

    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        cat_dim = z_dim * n_modules
        self.attention = SelfAttention(cat_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(cat_dim, ys_dim),
            nn.GELU(),
            nn.Linear(ys_dim, cat_dim),
            nn.Dropout(0.1),
        )
        self.dropout = nn.Dropout(0.1)
        self.layer_norm_1 = nn.LayerNorm(cat_dim)
        self.layer_norm_2 = nn.LayerNorm(cat_dim)
        self.ys_layer = nn.Linear(cat_dim, ys_dim)
        self.scaling_layers = nn.ModuleList(
            [
                BaseElement(self.ys_dim, self.z_dim)
                for i in range(self.n_modules)
            ]
        )
        self._mtl_module = moo.setup_moo(
            [MtlMethodParams(x) for x in mtl_method],
            num_tasks=n_modules,
        )
        self.moo_block = moo.MultiMOOForLoop(
            n_modules, moo_methods=(self._mtl_module,)
        )

    def order_layers(self, module_names: Tuple[str]) -> None:
        """Order the scaling layers based on the given module names.

        Args:
            module_names (Tuple[str]): The names of the task-specific modules.

        """
        prior_map = {name: i for i, name in enumerate(self.module_names)}
        self.scaling_layers = nn.ModuleList(
            [self.scaling_layers[prior_map[name]] for name in module_names]
        )

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """Forward pass of the module. The input tensors z are concatenated and passed through
        the self-attention layer. The output is then passed through the feed-forward neural network
        combined with the residual connection and layer normalization. The output is then passed through
        the MTL block and the scaling layers to generate the final output tensors.

        Args:
            z (Tuple[torch.Tensor, ...]): The input tensors.

        Returns:
            Tuple[torch.Tensor, ...]: The output tensors.

        """
        x = torch.cat(z, dim=1)
        attended = self.attention(torch.cat(z, dim=1))
        x = self.layer_norm_1(x + self.dropout(attended))

        h = self.feed_forward(x)
        h = self.layer_norm_2(x + self.dropout(h))
        yb = self.ys_layer(h)
        (yb_moo,) = self.moo_block(yb)
        y = tuple(
            [
                self.scaling_layers[i](yi)
                for i, yi in zip(range(self.n_modules), yb_moo)
            ]
        )
        return y

forward(z)

Forward pass of the module. The input tensors z are concatenated and passed through the self-attention layer. The output is then passed through the feed-forward neural network combined with the residual connection and layer normalization. The output is then passed through the MTL block and the scaling layers to generate the final output tensors.

Parameters:

Name Type Description Default
z Tuple[Tensor, ...]

The input tensors.

required

Returns:

Type Description
Tuple[Tensor, ...]

Tuple[torch.Tensor, ...]: The output tensors.

Source code in vambn/modelling/models/hivae/shared.py
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """Forward pass of the module. The input tensors z are concatenated and passed through
    the self-attention layer. The output is then passed through the feed-forward neural network
    combined with the residual connection and layer normalization. The output is then passed through
    the MTL block and the scaling layers to generate the final output tensors.

    Args:
        z (Tuple[torch.Tensor, ...]): The input tensors.

    Returns:
        Tuple[torch.Tensor, ...]: The output tensors.

    """
    x = torch.cat(z, dim=1)
    attended = self.attention(torch.cat(z, dim=1))
    x = self.layer_norm_1(x + self.dropout(attended))

    h = self.feed_forward(x)
    h = self.layer_norm_2(x + self.dropout(h))
    yb = self.ys_layer(h)
    (yb_moo,) = self.moo_block(yb)
    y = tuple(
        [
            self.scaling_layers[i](yi)
            for i, yi in zip(range(self.n_modules), yb_moo)
        ]
    )
    return y

order_layers(module_names)

Order the scaling layers based on the given module names.

Parameters:

Name Type Description Default
module_names Tuple[str]

The names of the task-specific modules.

required
Source code in vambn/modelling/models/hivae/shared.py
804
805
806
807
808
809
810
811
812
813
814
def order_layers(self, module_names: Tuple[str]) -> None:
    """Order the scaling layers based on the given module names.

    Args:
        module_names (Tuple[str]): The names of the task-specific modules.

    """
    prior_map = {name: i for i, name in enumerate(self.module_names)}
    self.scaling_layers = nn.ModuleList(
        [self.scaling_layers[prior_map[name]] for name in module_names]
    )

ImposterModule

Bases: BaseModule

This module is used as a placeholder when no modularity is desired. It simply returns the input z as the output.

Parameters:

Name Type Description Default
z_dim int

Input dimension of the latent samples z.

required
ys_dim int

Intermediate dimension of the latent samples ys.

required
y_dim int or Tuple[int, ...]

Output dimension of the latent samples y, which can be different for each module.

required
n_modules int

Number of modules.

required
module_names Tuple[str, ...]

Names of the modules.

required
mtl_method Optional[Tuple[str, ...]]

MTL methods used to avoid conflicting gradients. Defaults to None.

None

Attributes:

Name Type Description
has_params bool

Whether the module has parameters.

Source code in vambn/modelling/models/hivae/shared.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
class ImposterModule(BaseModule):
    """
    This module is used as a placeholder when no modularity is desired. It simply
    returns the input z as the output.

    Args:
        z_dim (int): Input dimension of the latent samples z.
        ys_dim (int): Intermediate dimension of the latent samples ys.
        y_dim (int or Tuple[int, ...]): Output dimension of the latent samples y, which can be different for each module.
        n_modules (int): Number of modules.
        module_names (Tuple[str, ...]): Names of the modules.
        mtl_method (Optional[Tuple[str, ...]]): MTL methods used to avoid conflicting gradients. Defaults to None.

    Attributes:
        has_params (bool): Whether the module has parameters.
    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        self.has_params = False

    def order_layers(self, module_names: Tuple[str]) -> None:
        """
        Order the layers based on the module names. Since this module has no parameters,
        this method does nothing in the case of the ImposterModule.

        Args:
            module_names (Tuple[str]): Names of the modules.
        """
        pass

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Forward pass through the module. Since this module has no parameters, it simply
        returns the input z.

        Args:
            z (Tuple[torch.Tensor, ...]): Input tensor.

        Returns:
            Tuple[torch.Tensor, ...]: Output tensor.
        """
        return z

forward(z)

Forward pass through the module. Since this module has no parameters, it simply returns the input z.

Parameters:

Name Type Description Default
z Tuple[Tensor, ...]

Input tensor.

required

Returns:

Type Description
Tuple[Tensor, ...]

Tuple[torch.Tensor, ...]: Output tensor.

Source code in vambn/modelling/models/hivae/shared.py
170
171
172
173
174
175
176
177
178
179
180
181
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """
    Forward pass through the module. Since this module has no parameters, it simply
    returns the input z.

    Args:
        z (Tuple[torch.Tensor, ...]): Input tensor.

    Returns:
        Tuple[torch.Tensor, ...]: Output tensor.
    """
    return z

order_layers(module_names)

Order the layers based on the module names. Since this module has no parameters, this method does nothing in the case of the ImposterModule.

Parameters:

Name Type Description Default
module_names Tuple[str]

Names of the modules.

required
Source code in vambn/modelling/models/hivae/shared.py
160
161
162
163
164
165
166
167
168
def order_layers(self, module_names: Tuple[str]) -> None:
    """
    Order the layers based on the module names. Since this module has no parameters,
    this method does nothing in the case of the ImposterModule.

    Args:
        module_names (Tuple[str]): Names of the modules.
    """
    pass

MaxModuleMtl

Bases: BaseModule

This module takes the maximum of the z's and passes them through the MOO block. The output is then passed through individual layers to generate the final outputs.

Parameters:

Name Type Description Default
z_dim int

The dimension of the input z.

required
ys_dim int

The dimension of the ys.

required
y_dim int | Tuple[int, ...]

The dimension of the output y.

required
n_modules int

The number of modules.

required
module_names Tuple[str, ...]

The names of the modules.

required
mtl_method Optional[Tuple[str, ...]]

The method for multi-task learning. Defaults to None.

None

Attributes:

Name Type Description
_mtl_module MultiObjectiveOptimization

The multi-objective optimization module.

moo_block MultiMOOForLoop

The multi-objective optimization block.

scaling_layers ModuleList

The list of scaling layers.

Source code in vambn/modelling/models/hivae/shared.py
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
class MaxModuleMtl(BaseModule):
    """
    This module takes the maximum of the z's and passes them through the MOO block. The
    output is then passed through individual layers to generate the final outputs.

    Args:
        z_dim (int): The dimension of the input z.
        ys_dim (int): The dimension of the ys.
        y_dim (int | Tuple[int, ...]): The dimension of the output y.
        n_modules (int): The number of modules.
        module_names (Tuple[str, ...]): The names of the modules.
        mtl_method (Optional[Tuple[str, ...]]): The method for multi-task learning. Defaults to None.

    Attributes:
        _mtl_module (MultiObjectiveOptimization): The multi-objective optimization module.
        moo_block (MultiMOOForLoop): The multi-objective optimization block.
        scaling_layers (nn.ModuleList): The list of scaling layers.
    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        self._mtl_module = moo.setup_moo(
            [MtlMethodParams(x) for x in mtl_method],
            num_tasks=n_modules,
        )
        self.moo_block = moo.MultiMOOForLoop(
            n_modules, moo_methods=(self._mtl_module,)
        )
        self.scaling_layers = nn.ModuleList(
            [BaseElement(self.z_dim, self.z_dim) for i in range(self.n_modules)]
        )

    def order_layers(self, module_names: Tuple[str]) -> None:
        """
        Orders the scaling layers based on the given module names.

        Args:
            module_names (Tuple[str]): The names of the modules.
        """
        prior_map = {name: i for i, name in enumerate(self.module_names)}
        self.scaling_layers = nn.ModuleList(
            [self.scaling_layers[prior_map[name]] for name in module_names]
        )

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Performs forward pass through the module. The maximum of the z's is passed through
        the MOO block. The output is then passed through individual scaling layers to
        generate the final output tensors.

        Args:
            z (Tuple[torch.Tensor, ...]): The input z.

        Returns:
            Tuple[torch.Tensor, ...]: The output tensors.
        """
        # Take the element-wise maximum of the z's and pass through the MOO block
        (h,) = self.moo_block(torch.max(torch.stack(z), dim=0).values)
        # Pass through the individual layers
        return tuple(
            [
                self.scaling_layers[i](hi)
                for i, hi in zip(range(self.n_modules), h)
            ]
        )

forward(z)

Performs forward pass through the module. The maximum of the z's is passed through the MOO block. The output is then passed through individual scaling layers to generate the final output tensors.

Parameters:

Name Type Description Default
z Tuple[Tensor, ...]

The input z.

required

Returns:

Type Description
Tuple[Tensor, ...]

Tuple[torch.Tensor, ...]: The output tensors.

Source code in vambn/modelling/models/hivae/shared.py
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """
    Performs forward pass through the module. The maximum of the z's is passed through
    the MOO block. The output is then passed through individual scaling layers to
    generate the final output tensors.

    Args:
        z (Tuple[torch.Tensor, ...]): The input z.

    Returns:
        Tuple[torch.Tensor, ...]: The output tensors.
    """
    # Take the element-wise maximum of the z's and pass through the MOO block
    (h,) = self.moo_block(torch.max(torch.stack(z), dim=0).values)
    # Pass through the individual layers
    return tuple(
        [
            self.scaling_layers[i](hi)
            for i, hi in zip(range(self.n_modules), h)
        ]
    )

order_layers(module_names)

Orders the scaling layers based on the given module names.

Parameters:

Name Type Description Default
module_names Tuple[str]

The names of the modules.

required
Source code in vambn/modelling/models/hivae/shared.py
553
554
555
556
557
558
559
560
561
562
563
def order_layers(self, module_names: Tuple[str]) -> None:
    """
    Orders the scaling layers based on the given module names.

    Args:
        module_names (Tuple[str]): The names of the modules.
    """
    prior_map = {name: i for i, name in enumerate(self.module_names)}
    self.scaling_layers = nn.ModuleList(
        [self.scaling_layers[prior_map[name]] for name in module_names]
    )
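
MaxModuleMtl differs from the averaging variant shown earlier only in the pooling operator: the per-module z's are reduced with an element-wise maximum before the MOO block. The pooling step in isolation:

import torch

batch, z_dim = 4, 8
z = tuple(torch.randn(batch, z_dim) for _ in range(3))

# torch.max over the stacked module axis returns (values, indices); only the values are used
pooled = torch.max(torch.stack(z), dim=0).values
assert pooled.shape == (batch, z_dim)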

SelfAttention

Bases: Module

Self-Attention module.

Parameters:

Name Type Description Default
hidden_dim int

The dimension of the input and output tensors.

required

Attributes:

Name Type Description
query Linear

Linear layer for computing the query tensor.

key Linear

Linear layer for computing the key tensor.

value Linear

Linear layer for computing the value tensor.

softmax Softmax

Softmax function for computing attention weights.

Source code in vambn/modelling/models/hivae/shared.py
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
class SelfAttention(nn.Module):
    """
    Self-Attention module.

    Args:
        hidden_dim (int): The dimension of the input and output tensors.

    Attributes:
        query (nn.Linear): Linear layer for computing the query tensor.
        key (nn.Linear): Linear layer for computing the key tensor.
        value (nn.Linear): Linear layer for computing the value tensor.
        softmax (nn.Softmax): Softmax function for computing attention weights.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, z):
        """
        Forward pass of the SelfAttention module. Computes the query, key, and value tensors
        and uses them to compute the attention weights. The attention weights are then used
        to compute the attended values.

        Args:
            z (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_dim).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, hidden_dim).
        """
        q = self.query(z)
        k = self.key(z)
        v = self.value(z)

        attention_scores = torch.matmul(q, k.transpose(-2, -1))
        attention_weights = self.softmax(attention_scores)

        attended_values = torch.matmul(attention_weights, v)
        return attended_values

forward(z)

Forward pass of the SelfAttention module. Computes the query, key, and value tensors and uses them to compute the attention weights. The attention weights are then used to compute the attended values.

Parameters:

Name Type Description Default
z Tensor

Input tensor of shape (batch_size, seq_len, hidden_dim).

required

Returns:

Type Description

torch.Tensor: Output tensor of shape (batch_size, seq_len, hidden_dim).

Source code in vambn/modelling/models/hivae/shared.py
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
def forward(self, z):
    """
    Forward pass of the SelfAttention module. Computes the query, key, and value tensors
    and uses them to compute the attention weights. The attention weights are then used
    to compute the attended values.

    Args:
        z (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_dim).

    Returns:
        torch.Tensor: Output tensor of shape (batch_size, seq_len, hidden_dim).
    """
    q = self.query(z)
    k = self.key(z)
    v = self.value(z)

    attention_scores = torch.matmul(q, k.transpose(-2, -1))
    attention_weights = self.softmax(attention_scores)

    attended_values = torch.matmul(attention_weights, v)
    return attended_values
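
The attention computation is the standard softmax(Q K^T) V pattern; note that the code shown applies no 1/sqrt(d) scaling to the scores. For a 3-D input as described in the docstring, the shapes work out as follows. This is a self-contained sketch of the computation rather than the module itself:

import torch
import torch.nn as nn

batch, seq_len, hidden = 2, 5, 16
z = torch.randn(batch, seq_len, hidden)

query, key, value = (nn.Linear(hidden, hidden) for _ in range(3))
q, k, v = query(z), key(z), value(z)

scores = torch.matmul(q, k.transpose(-2, -1))   # (batch, seq_len, seq_len)
weights = torch.softmax(scores, dim=-1)
attended = torch.matmul(weights, v)             # (batch, seq_len, hidden)
assert attended.shape == z.shape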

SharedLinearModule

Bases: BaseModule

This module passes the z's through a single shared dense layer that generates the individual outputs for each module using the same weights. Outputs are generated one by one. The assumption is that z shares the same dimensional space across modules. This is the simplest form of modularity.

Parameters:

Name Type Description Default
z_dim int

Input dimension of the latent samples z.

required
ys_dim int

Intermediate dimension of the latent samples ys.

required
y_dim int or Tuple[int, ...]

Output dimension of the latent samples y, which can be different for each module.

required
n_modules int

Number of modules.

required
module_names Tuple[str, ...]

Names of the modules.

required
mtl_method Optional[Tuple[str, ...]]

MTL methods used to avoid conflicting gradients. Defaults to None.

None

Attributes:

Name Type Description
shared_layer BaseElement

Shared dense layer.

scaling_layers ModuleList

List of individual dense layers for each module.

Source code in vambn/modelling/models/hivae/shared.py
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
class SharedLinearModule(BaseModule):
    """
    This module passes the z's through a single shared dense layer that generates the
    individual outputs for each module using the same weights. Outputs are generated
    one by one. The assumption is that z shares the same dimensional space across
    modules. This is the simplest form of modularity.

    Args:
        z_dim (int): Input dimension of the latent samples z.
        ys_dim (int): Intermediate dimension of the latent samples ys.
        y_dim (int or Tuple[int, ...]): Output dimension of the latent samples y, which can be different for each module.
        n_modules (int): Number of modules.
        module_names (Tuple[str, ...]): Names of the modules.
        mtl_method (Optional[Tuple[str, ...]]): MTL methods used to avoid conflicting gradients. Defaults to None.

    Attributes:
        shared_layer (BaseElement): Shared dense layer.
        scaling_layers (nn.ModuleList): List of individual dense layers for each module.
    """

    def __init__(
        self,
        z_dim: int,
        ys_dim: int,
        y_dim: int | Tuple[int, ...],
        n_modules: int,
        module_names: Tuple[str, ...],
        mtl_method: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__(
            z_dim=z_dim,
            ys_dim=ys_dim,
            y_dim=y_dim,
            n_modules=n_modules,
            mtl_method=mtl_method,
            module_names=module_names,
        )
        self.shared_layer = BaseElement(self.z_dim, self.ys_dim)
        self.scaling_layers = nn.ModuleList(
            [
                BaseElement(self.ys_dim, self.z_dim)
                for i in range(self.n_modules)
            ]
        )

    def order_layers(self, module_names: Tuple[str]) -> None:
        """
        Order the layers based on the module names.

        Args:
            module_names (Tuple[str]): Names of the modules.
        """
        prior_map = {name: i for i, name in enumerate(self.module_names)}
        self.scaling_layers = nn.ModuleList(
            [self.scaling_layers[prior_map[name]] for name in module_names]
        )

    def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """
        Forward pass through the module. Each z is passed through the shared layer
        (the same weights for every module) and each result is then passed through its
        module-specific scaling layer to generate the final outputs.

        Args:
            z (Tuple[torch.Tensor, ...]): Input tensor.

        Returns:
            Tuple[torch.Tensor, ...]: Output tensor with individual outputs for each module.
        """
        ys = tuple([self.shared_layer(zi) for zi in z])
        return tuple([self.scaling_layers[i](ysi) for i, ysi in enumerate(ys)])

forward(z)

Forward pass through the module. Each z is passed through the shared layer (the same weights for every module) and each result is then passed through its module-specific scaling layer to generate the final outputs.

Parameters:

Name Type Description Default
z Tuple[Tensor, ...]

Input tensor.

required

Returns:

Type Description
Tuple[Tensor, ...]

Tuple[torch.Tensor, ...]: Output tensor with individual outputs for each module.

Source code in vambn/modelling/models/hivae/shared.py
241
242
243
244
245
246
247
248
249
250
251
252
253
254
def forward(self, z: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """
    Forward pass through the module. Each z is passed through the shared layer
    (the same weights for every module) and each result is then passed through its
    module-specific scaling layer to generate the final outputs.

    Args:
        z (Tuple[torch.Tensor, ...]): Input tensor.

    Returns:
        Tuple[torch.Tensor, ...]: Output tensor with individual outputs for each module.
    """
    ys = tuple([self.shared_layer(zi) for zi in z])
    return tuple([self.scaling_layers[i](ysi) for i, ysi in enumerate(ys)])

order_layers(module_names)

Order the layers based on the module names.

Parameters:

Name Type Description Default
module_names Tuple[str]

Names of the modules.

required
Source code in vambn/modelling/models/hivae/shared.py
229
230
231
232
233
234
235
236
237
238
239
def order_layers(self, module_names: Tuple[str]) -> None:
    """
    Order the layers based on the module names.

    Args:
        module_names (Tuple[str]): Names of the modules.
    """
    prior_map = {name: i for i, name in enumerate(self.module_names)}
    self.scaling_layers = nn.ModuleList(
        [self.scaling_layers[prior_map[name]] for name in module_names]
    )
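
SharedLinearModule applies one dense layer with the same weights to every module's z, then a module-specific scaling layer to each result. A compact sketch with torch.nn.Linear standing in for BaseElement:

import torch
import torch.nn as nn

n_modules, z_dim, ys_dim, batch = 3, 8, 4, 5
z = tuple(torch.randn(batch, z_dim) for _ in range(n_modules))

shared = nn.Linear(z_dim, ys_dim)                        # same weights for every module
scaling = [nn.Linear(ys_dim, z_dim) for _ in range(n_modules)]

ys = tuple(shared(zi) for zi in z)                       # one intermediate tensor per module
outputs = tuple(scaling[i](ysi) for i, ysi in enumerate(ys))
assert all(o.shape == (batch, z_dim) for o in outputs)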

trainer

BaseTrainer

Bases: Generic[TConfig, TModel, TPredict, TrainerType, TEncoding], ABC

Base class for trainers in the VAMBN2 framework.

Parameters:

Name Type Description Default
dataset VambnDataset

The dataset to use for training.

required
config PipelineConfig

The configuration object for the pipeline.

required
workers int

The number of workers to use for data loading.

required
checkpoint_path Path

The path to save checkpoints during training.

required
module_name Optional[str]

The name of the module. Defaults to None.

None
experiment_name Optional[str]

The name of the experiment. Defaults to None.

None
force_cpu bool

Whether to force CPU usage. Defaults to False.

False

Attributes:

Name Type Description
dataset VambnDataset

The dataset used for training.

config PipelineConfig

The configuration object for the pipeline.

workers int

The number of workers used for data loading.

checkpoint_path Path

The path to save checkpoints during training.

model Optional[TModel]

The model used for training.

model_config Optional[TConfig]

The configuration object for the model.

module_name Optional[str]

The name of the module.

experiment_name Optional[str]

The name of the experiment.

type str

The type of the trainer.

device device

The device used for training.

use_mtl bool

Whether to use multi-task learning.

use_gan bool

Whether to use generative adversarial networks.

Source code in vambn/modelling/models/hivae/trainer.py, lines 183-950
class BaseTrainer(
    Generic[TConfig, TModel, TPredict, TrainerType, TEncoding], ABC
):
    """
    Base class for trainers in the VAMBN2 framework.

    Args:
        dataset (VambnDataset): The dataset to use for training.
        config (PipelineConfig): The configuration object for the pipeline.
        workers (int): The number of workers to use for data loading.
        checkpoint_path (Path): The path to save checkpoints during training.
        module_name (Optional[str], optional): The name of the module. Defaults to None.
        experiment_name (Optional[str], optional): The name of the experiment. Defaults to None.
        force_cpu (bool, optional): Whether to force CPU usage. Defaults to False.

    Attributes:
        dataset (VambnDataset): The dataset used for training.
        config (PipelineConfig): The configuration object for the pipeline.
        workers (int): The number of workers used for data loading.
        checkpoint_path (Path): The path to save checkpoints during training.
        model (Optional[TModel]): The model used for training.
        model_config (Optional[TConfig]): The configuration object for the model.
        module_name (Optional[str]): The name of the module.
        experiment_name (Optional[str]): The name of the experiment.
        type (str): The type of the trainer.
        device (torch.device): The device used for training.
        use_mtl (bool): Whether to use multi-task learning.
        use_gan (bool): Whether to use generative adversarial networks.
    """

    def __init__(
        self,
        dataset: VambnDataset,
        config: PipelineConfig,
        workers: int,
        checkpoint_path: Path,
        module_name: Optional[str] = None,
        experiment_name: Optional[str] = None,
        force_cpu: bool = False,
    ):
        self.dataset = dataset
        self.config = config
        self.workers = workers
        self.checkpoint_path = checkpoint_path

        self.model = None
        self.model_config = None
        self.module_name = module_name
        self.experiment_name = experiment_name
        self.type = "base"
        self.device = (
            torch.device("cuda")
            if torch.cuda.is_available() and not force_cpu
            else torch.device("cpu")
        )
        self.use_mtl = config.training.use_mtl
        self.use_gan = config.training.with_gan
        logger.info(f"Use {self.device} for training")

    @cached_property
    def run_base(self) -> str:
        """
        Generates a base name for the training run. Used e.g. for MLflow run names.

        Returns:
            str: The base name for the Training run.
        """
        base = f"{self.type}_{'wmtl' if self.use_mtl else 'womtl'}_{'wgan' if self.use_gan else 'wogan'}"
        if self.module_name is not None:
            base += f"_{self.module_name}"
        return base

    def get_dataloader(
        self, dataset: VambnDataset, batch_size: int, shuffle: bool
    ) -> DataLoader:
        """
        Get a DataLoader object for the given dataset.

        Args:
            dataset (VambnDataset): The dataset to use.
            batch_size (int): The batch size.
            shuffle (bool): Whether to shuffle the data.

        Returns:
            DataLoader: The DataLoader object.

        Raises:
            ValueError: If the dataset is empty.
        """
        # Set the number of workers to 0
        self.workers = 0

        # Create and return the DataLoader object
        return DataLoader(
            dataset.get_iter_dataset(self.module_name)
            if self.module_name is not None
            else dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.workers,
            # collate_fn=self.custom_collate,
            persistent_workers=True if self.workers > 0 else False,
            pin_memory=True if self.device.type == "cuda" else False,
            drop_last=True
            if shuffle and len(dataset) % batch_size <= 3
            else False,
        )

    def _set_device(self):
        """
        Sets the device for the model and dataset.

        This method sets the device for the model and dataset to the device specified in the `device` attribute.

        Returns:
            None
        """
        if self.model is not None:
            self.model.to(self.device)

    def multiple_objective_selection(
        self, study: optuna.Study, corr_weight: float = 0.8
    ) -> optuna.trial.FrozenTrial:
        """
        Selects the best trial from a given Optuna study based on multiple objectives.

        Args:
            study (optuna.Study): The Optuna study object.
            corr_weight (float, optional): The weight for the relative correlation error. Defaults to 0.8.

        Returns:
            optuna.trial.FrozenTrial: The best trial.

        Raises:
            ValueError: If no trials are found in the study.
        """

        # Get the best trials from the study
        best_trials = study.best_trials

        if not best_trials:
            raise ValueError("No trials found in the study.")

        # Calculate the weighted sum of relative correlation error and loss for each trial
        weighted_scores = []
        for trial in best_trials:
            corr_error = trial.values[1]
            loss = trial.values[0]
            weighted_score = corr_weight * corr_error + (1 - corr_weight) * loss
            weighted_scores.append(weighted_score)

        # Find the index of the trial with the minimum weighted score
        best_index = weighted_scores.index(min(weighted_scores))

        # Select the best trial based on the weighted score
        best_trial = best_trials[best_index]
        logger.info(f"Selected trial: {best_trial.number}")

        return best_trial

    def hyperopt(self, study: optuna.Study, num_trials: int) -> Hyperparameters:
        """
        Perform hyperparameter optimization using Optuna.

        Args:
            study (optuna.Study): The Optuna study object.
            num_trials (int): The number of trials to run.

        Returns:
            Hyperparameters: The best hyperparameters found during optimization.

        Raises:
            ValueError: If no trials are found in the study.
        """
        with mlflow.start_run(run_name=f"{self.run_base}_hyperopt"):
            # Optimize the study
            study.optimize(self._objective, n_trials=num_trials)

            # Get the best trial parameters
            if self.config.optimization.use_relative_correlation_error_for_optimization:
                trial = self.multiple_objective_selection(study)
            else:
                trial = study.best_trial

            # Extract the best hyperparameters
            params = trial.params
            best_epoch = trial.user_attrs.get("best_epoch", None)
            params["epochs"] = best_epoch

            # Process the parameters
            params["batch_size"] = 2 ** params.pop("batch_size_n")
            if "hidden_dim_s" in params:
                params["dim_s"] = params.pop("hidden_dim_s")
            else:
                matching_keys = [k for k in params if "hidden_dim_s" in k]
                dim_s = {}
                for key in matching_keys:
                    module_name = key.split("_")[-1]
                    dim_s[module_name] = params.pop(key)
                params["dim_s"] = dim_s
            if "hidden_dim_y" in params:
                params["dim_y"] = params.pop("hidden_dim_y")
            else:
                matching_keys = [k for k in params if "hidden_dim_y_" in k]
                dim_y = {}
                for key in matching_keys:
                    module_name = key.split("_")[-1]
                    dim_y[module_name] = params.pop(key)
                params["dim_y"] = dim_y
            if "hidden_dim_ys" in params:
                params["dim_ys"] = params.pop("hidden_dim_ys")
            params["dim_z"] = params.pop("hidden_dim_z")
            if "learning_rate" not in params:
                matching_keys = [k for k in params if "learning_rate" in k]
                learning_rate = {}
                for key in matching_keys:
                    module_name = key.split("_")[-1]
                    learning_rate[module_name] = params.pop(key)
                params["learning_rate"] = learning_rate

            # Create Hyperparameters object
            hyperparameters = Hyperparameters(dropout=0.1, **params)

            # Log hyperparameters to MLflow
            mlflow.log_params(hyperparameters.__dict__)

        return hyperparameters

    def cleanup_checkpoints(self):
        """
        Cleans up the checkpoints directory.

        This method deletes the entire checkpoints directory specified by `self.checkpoint_path`.

        Returns:
            None
        """
        delete_directory(self.checkpoint_path)

    def optimize_model(self, model: Optional[TModel] = None) -> TModel:
        """
        Optimizes the model using a specified optimization function.

        Args:
            model (Optional[TModel], optional): The model to optimize. If None, the method optimizes self.model. Defaults to None.

        Returns:
            TModel: The optimized model.

        Notes:
            - The optimization function is specified by the opt_func variable.
            - The opt_func function should take a model as input and return an optimized model.
            - If model is None, the method optimizes self.model.
            - If model is not None, the method optimizes the specified model.

        Raises:
            TypeError: If the model is not of type TModel.

        """
        # Define the optimization function
        opt_func = partial(torch.compile, mode="reduce-overhead")
        # opt_func = lambda model: model

        if model is None:
            # Optimize self.model
            self.model = opt_func(model=self.model)
            return self.model
        else:
            return opt_func(model=model)

    def _add_encs(
        self,
        encodings_s: Dict[str, torch.Tensor],
        encodings_z: Dict[str, torch.Tensor],
        meta_enc: Dict[str, np.ndarray],
        decoder_output: DecoderOutput,
    ) -> None:
        """
        Helper function to prepare the encodings from the decoder output.

        This function takes the decoder output and adds the s and z encodings to the respective dictionaries.
        It also updates the meta encoding dictionary with the s and z encodings.

        Args:
            encodings_s (Dict[str, torch.Tensor]): Dictionary of s encodings per module.
            encodings_z (Dict[str, torch.Tensor]): Dictionary of z encodings per module.
            meta_enc (Dict[str, np.ndarray]): Entire meta encoding.
            decoder_output (DecoderOutput): The decoder output.

        Returns:
            None
        """
        dict_name = decoder_output.output_name

        # Add s and z encodings to the respective dictionaries
        if dict_name in encodings_s:
            encodings_s[dict_name].append(decoder_output.enc_s)
            encodings_z[dict_name].append(decoder_output.enc_z)
        else:
            encodings_s[dict_name] = [decoder_output.enc_s]
            encodings_z[dict_name] = [decoder_output.enc_z]

        # Update meta encoding dictionary with s and z encodings
        if decoder_output.enc_s.shape[1] > 1:
            s_dist = torch.argmax(decoder_output.enc_s, dim=1)
            s_name = f"{decoder_output.output_name}_s"
            if s_name in meta_enc:
                meta_enc[s_name] = np.concatenate(
                    [meta_enc[s_name], s_dist.numpy()]
                )
            else:
                meta_enc[s_name] = s_dist.numpy()
        for i in range(decoder_output.enc_z.shape[1]):
            z_name = f"{decoder_output.output_name}_z{i}"
            if z_name in meta_enc:
                meta_enc[z_name] = np.concatenate(
                    [meta_enc[z_name], decoder_output.enc_z[:, i].numpy()]
                )
            else:
                meta_enc[z_name] = decoder_output.enc_z[:, i].numpy()

    def save_model(self, path: Path) -> None:
        """
        Save the model and its configuration to the specified path.

        Args:
            path (Path): The path to save the model.

        Returns:
            None
        """
        self.save_model_config(path / "config.pkl")
        torch.save(self.model.state_dict(), path / "model.bin")

    def load_model(self, path: Path) -> None:
        """
        Load the model and its configuration from the specified path.

        Args:
            path (Path): The path to load the model from.

        Returns:
            None
        """
        self.read_model_config(path / "config.pkl")
        self.model = self.init_model(self.model_config)
        state_dict = torch.load(path / "model.bin")
        try:
            self.model.load_state_dict(state_dict)
        except RuntimeError:
            state_dict = OrderedDict(
                {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
            )
            self.model.load_state_dict(state_dict)

    @abstractmethod
    def save_model_config(self, path: Path) -> None:
        """
        Save the model configuration to the specified path.

        Args:
            path (Path): The path to save the model configuration.

        Returns:
            None
        """
        pass

    @abstractmethod
    def read_model_config(self, path: Path) -> None:
        """
        Read the model configuration from the specified path.

        Args:
            path (Path): The path to read the model configuration from.

        Returns:
            None
        """
        pass

    def cv_generator(
        self, splits: int, seed: int = 42
    ) -> Generator[Tuple[VambnDataset, VambnDataset], None, None]:
        """
        Generates train and validation datasets for cross-validation.

        Args:
            splits (int): Number of splits.
            seed (int, optional): Seed for split. Defaults to 42.

        Raises:
            ValueError: Number of splits must be greater than 0.

        Yields:
            Generator[Tuple[VambnDataset, VambnDataset], None, None]: A generator that yields tuples of train and validation datasets.
        """
        if splits < 1:
            raise ValueError("splits must be greater than 0")

        data_idx = np.arange(len(self.dataset))
        if splits == 1:
            train_idx, val_idx = train_test_split(
                data_idx, test_size=0.2, random_state=seed
            )
            yield (
                self.dataset.subset_by_idx(train_idx),
                self.dataset.subset_by_idx(val_idx),
            )
        else:
            cv = KFold(n_splits=splits, shuffle=True, random_state=seed)
            for train_idx, val_idx in cv.split(data_idx):
                if len(train_idx) == 0 or len(val_idx) == 0:
                    raise Exception("Empty split")
                yield (
                    self.dataset.subset_by_idx(train_idx),
                    self.dataset.subset_by_idx(val_idx),
                )

    @staticmethod
    def __lstm_forward_hook(module, input, output):
        # output is a tuple (output_tensor, (h_n, c_n))
        return output[0]  # We only want the output tensor for the next layer

    @staticmethod
    def setup_y_layer(
        z_dim: int, y_dim: int, dropout: float, n_layers: Optional[int] = None
    ) -> torch.nn.Module:
        if n_layers is None:
            return torch.nn.Sequential(
                # torch.nn.LayerNorm(z_dim),
                ModifiedLinear(z_dim, y_dim),
                torch.nn.Dropout(dropout),
            )
        else:
            layers = [
                # torch.nn.LayerNorm(z_dim),
                torch.nn.LSTM(
                    input_size=z_dim,
                    hidden_size=y_dim,
                    num_layers=n_layers,
                    dropout=dropout if n_layers > 1 else 0.0,
                    batch_first=True,
                ),
            ]

            layers[-1].register_forward_hook(BaseTrainer.__lstm_forward_hook)

            layers.append(torch.nn.Dropout(dropout))

            return torch.nn.Sequential(*layers)

    @abstractmethod
    def init_model(self, config: TConfig) -> TModel:
        pass

    @abstractmethod
    def train(self, best_parameters: Hyperparameters) -> TModel:
        pass

    def predict(self, dl: DataLoader | None) -> TPredict:
        if dl is None:
            data = self.dataset.get_iter_dataset(self.module_name)
            dl = self.get_dataloader(data, batch_size=128, shuffle=False)

        with torch.no_grad():
            return self.model.predict(dl)

    def decode(
        self,
        encoding: HivaeEncoding | ModularHivaeEncoding,
        use_mode: bool = True,
    ) -> Dict[str, torch.Tensor]:
        self.model.eval()
        self.model.decoding = use_mode

        return {self.module_name: self.model.decode(encoding)}

    def save_trainer(self, path: Path) -> None:
        path.parent.mkdir(parents=True, exist_ok=True)
        self.save_model(path)
        self.model = None
        torch.save(self, path / "trainer.pkl", pickle_module=dill)

        with (path / "pipeline-config.yml").open("w") as f:
            f.write(yaml.dump(self.config.__dict__))

    @classmethod
    def load_trainer(cls, path: Path) -> TrainerType:
        trainer = torch.load(path / "trainer.pkl", pickle_module=dill)
        trainer.load_model(path)
        return trainer

    @abstractmethod
    def hyperparameters(self, trial: optuna.Trial) -> Hyperparameters:
        pass

    @abstractmethod
    def _objective(self, trial: optuna.Trial) -> float | Tuple[float, float]:
        pass

    def __handle_continous_data(
        self,
        original: pd.Series,
        decoded: pd.Series,
        output_file: Path,
        dtype: str,
    ) -> Tuple[float, float, float, float]:
        df = pd.DataFrame({"original": original, "decoded": decoded}).melt()
        sns.boxplot(x="variable", y="value", data=df)
        plt.savefig(output_file, dpi=300)
        plt.close()

        dec_tensor = torch.tensor(decoded.values)
        orig_tensor = torch.tensor(original.values)
        if orig_tensor.isnan().any() or orig_tensor.isinf().any():
            raise Exception("NaN values in original tensor")

        if dec_tensor.isnan().any() or dec_tensor.isinf().any():
            logger.warning("NaN values in decoded tensor")
            logger.warning(
                f"Found nan values: {dec_tensor.isnan().any()}, {dec_tensor.isnan().sum()}"
            )
            logger.warning(
                f"Found inf values: {dec_tensor.isinf().any()}, {dec_tensor.isinf().sum()}"
            )
            dec_tensor = dec_tensor.nan_to_num()
            dec_tensor[dec_tensor == float("inf")] = dec_tensor[
                dec_tensor != float("inf")
            ].max()
            logger.info(
                f"After replacing: {dec_tensor.isnan().any()}, {dec_tensor.isnan().sum()}"
            )
            logger.info(
                f"After replacing: {dec_tensor.isinf().any()}, {dec_tensor.isinf().sum()}"
            )

        error = float(
            nrmse(dec_tensor, orig_tensor, torch.ones_like(dec_tensor))
        )
        jsd = jensen_shannon_distance(dec_tensor, orig_tensor, dtype)
        try:
            statistic, pval = pearsonr(
                decoded.values.tolist(), original.values.tolist()
            )
        except ValueError:
            statistic, pval = 0.0, 1.0
            logger.warning("ValueError in pearsonr")
        return error, jsd, statistic, pval

    def handle_pos_data(
        self, original: pd.Series, decoded: pd.Series, output_file: Path
    ) -> Tuple[float, float, float, float]:
        return self.__handle_continous_data(
            original, decoded, output_file, "pos"
        )

    def handle_real_data(
        self, original: pd.Series, decoded: pd.Series, output_file: Path
    ) -> Tuple[float, float, float, float]:
        return self.__handle_continous_data(
            original, decoded, output_file, "real"
        )

    def handle_count_data(
        self, original: pd.Series, decoded: pd.Series, output_file: Path
    ) -> Tuple[float, float, float, float]:
        df = pd.DataFrame({"original": original, "decoded": decoded}).melt()
        sns.histplot(
            data=df, x="value", hue="variable", stat="probability", bins=30
        )
        plt.savefig(output_file, dpi=300)
        plt.close()

        dec_tensor = torch.tensor(decoded.values)
        orig_tensor = torch.tensor(original.values)

        nrmse_val = nrmse(dec_tensor, orig_tensor, torch.ones_like(dec_tensor))
        jsd = jensen_shannon_distance(dec_tensor, orig_tensor, "count")
        statistic, pval = pearsonr(
            decoded.astype(int).values, original.astype(int).values
        )
        return nrmse_val, jsd, statistic, pval

    def handle_categorical_data(
        self, original: pd.Series, decoded: pd.Series, output_file: Path
    ) -> Tuple[float, float, float, float]:
        df = pd.DataFrame({"original": original, "decoded": decoded}).melt()
        sns.histplot(data=df, x="value", hue="variable", stat="probability")
        plt.savefig(output_file, dpi=300)
        plt.close()

        dec_tensor = torch.tensor(decoded.values)
        orig_tensor = torch.tensor(original.values)

        error = accuracy(dec_tensor, orig_tensor, torch.ones_like(dec_tensor))
        jsd = jensen_shannon_distance(dec_tensor, orig_tensor, "cat")
        statistic, pval = spearmanr(
            decoded.values.tolist(), original.values.tolist()
        )
        return error, jsd, statistic, pval

    @abstractmethod
    def process_encodings(self, predictions: TPredict) -> TEncoding:
        raise NotImplementedError

    @staticmethod
    def reverse_scale(x: Tensor, variable_types: VarTypes) -> Tensor:
        copied_x = x.clone()
        for i, var_type in enumerate(variable_types):
            if var_type.data_type in ["real", "pos", "truncate_norm", "gamma"]:
                copied_x[:, i] = var_type.reverse_scale(copied_x[:, i])
        return copied_x

    def evaluate(self, dl: DataLoader, output_path: Path):
        output_path.mkdir(parents=True, exist_ok=True)

        predictions = self.predict(dl)
        encoding = self.process_encodings(predictions)

        data_output_path = output_path / "data_outputs"
        data_output_path.mkdir(exist_ok=True, parents=True)
        encoding.save_meta_enc(data_output_path / "meta_enc.csv")

        modules = (
            (self.module_name,)
            if self.module_name is not None
            else self.dataset.module_names
        )
        overall_metrics = [None] * len(modules)
        for k, module_name in enumerate(modules):
            submodules = self.dataset.get_modules(module_name)
            data, mask = self.dataset.get_longitudinal_data(module_name)
            sampled_data = encoding.get_samples(module_name)
            if sampled_data.ndim == 2 and data.ndim == 3:
                sampled_data = sampled_data.unsqueeze(1)

            if data.shape != sampled_data.shape:
                raise Exception("Data and sampled data have different shapes")

            if data.shape != mask.shape:
                raise Exception("Data and mask have different shapes")

            if data.ndim == 2:
                data = data.unsqueeze(1)
                sampled_data = sampled_data.unsqueeze(1)
                mask = mask.unsqueeze(1)

            num_timepoints = data.shape[1]
            decoded_data = [None] * num_timepoints
            original_data = [None] * num_timepoints
            mask_data = [None] * num_timepoints
            colnames = [
                re.sub("_VIS[0-9]+", "", c) for c in submodules[0].columns
            ]

            for time_point in range(num_timepoints):
                data_df = pd.DataFrame(
                    self.reverse_scale(
                        data[:, time_point, :], submodules[0].variable_types
                    ),
                    columns=colnames,
                )
                data_df["SUBJID"] = self.dataset.subj
                data_df["VISIT"] = time_point + 1
                data_df.set_index(["SUBJID", "VISIT"], inplace=True)
                original_data[time_point] = data_df

                sampled_data_df = pd.DataFrame(
                    sampled_data[:, time_point, :],
                    columns=colnames,
                )
                sampled_data_df["SUBJID"] = self.dataset.subj
                sampled_data_df["VISIT"] = time_point + 1
                sampled_data_df.set_index(["SUBJID", "VISIT"], inplace=True)
                decoded_data[time_point] = sampled_data_df

                mask_df = pd.DataFrame(mask[:, time_point, :], columns=colnames)
                mask_df["SUBJID"] = self.dataset.subj
                mask_df["VISIT"] = time_point + 1
                mask_df.set_index(["SUBJID", "VISIT"], inplace=True)
                mask_data[time_point] = mask_df

            original_data = pd.concat(original_data)
            decoded_data = pd.concat(decoded_data)

            # assert that we have the same visit ids in the original and decoded data
            assert original_data.index.equals(decoded_data.index), (
                f"Original and decoded data have different indices: "
                f"{original_data.index} != {decoded_data.index}"
            )
            if decoded_data.isna().any().any():
                raise Exception("NaN in decoded data")
            mask_data = pd.concat(mask_data)

            decoded_data.to_csv(data_output_path / f"{module_name}_decoded.csv")
            mask_data.to_csv(data_output_path / f"{module_name}_mask.csv")
            original_data.to_csv(
                data_output_path / f"{module_name}_original.csv"
            )

            # Calculate metrics per column
            error = [None] * len(colnames)
            jsd = [None] * len(colnames)
            correlation_stat = [None] * len(colnames)
            correlation_pval = [None] * len(colnames)

            distribution_path = output_path / "distributions"
            distribution_path.mkdir(parents=True, exist_ok=True)

            for i, column in enumerate(colnames):
                orig_col = original_data[column]
                decoded_col = decoded_data[column]
                mask_col = mask_data[column]

                orig_avail = orig_col[mask_col == 1]
                decoded_avail = decoded_col[mask_col == 1]

                col_type = submodules[0].variable_types[i].data_type
                if col_type in ["real", "pos", "truncate_norm", "gamma"]:
                    (
                        error[i],
                        jsd[i],
                        correlation_stat[i],
                        correlation_pval[i],
                    ) = self.handle_real_data(
                        orig_avail,
                        decoded_avail,
                        distribution_path / f"{module_name}_{column}.png",
                    )
                elif col_type == "count":
                    (
                        error[i],
                        jsd[i],
                        correlation_stat[i],
                        correlation_pval[i],
                    ) = self.handle_count_data(
                        orig_avail,
                        decoded_avail,
                        distribution_path / f"{module_name}_{column}.png",
                    )
                elif col_type == "cat":
                    (
                        error[i],
                        jsd[i],
                        correlation_stat[i],
                        correlation_pval[i],
                    ) = self.handle_categorical_data(
                        orig_avail,
                        decoded_avail,
                        distribution_path / f"{module_name}_{column}.png",
                    )
                else:
                    raise ValueError(f"Unknown data type {col_type}")

            overall_metrics[k] = pd.DataFrame(
                {
                    "module_name": [module_name] * len(colnames),
                    "column": colnames,
                    "error": [float(x) for x in error],
                    "jsd": jsd,
                    "correlation_stat": correlation_stat,
                    "correlation_pval": correlation_pval,
                }
            )

        overall_metrics = pd.concat(overall_metrics)
        overall_metrics.to_csv(output_path / "overall_metrics.csv", index=False)
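
Taken together, a concrete trainer subclass is typically driven in a few steps. The sketch below is hedged: it assumes `trainer` is an already-initialized concrete subclass and `study` an existing Optuna study, both created elsewhere, and the paths and trial count are placeholders.

from pathlib import Path

best_params = trainer.hyperopt(study, num_trials=50)            # tune hyperparameters
trainer.train(best_params)                                      # fit the final model
eval_dl = trainer.get_dataloader(trainer.dataset, batch_size=128, shuffle=False)
trainer.evaluate(eval_dl, Path("results/evaluation"))           # metrics, plots and CSV outputs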

run_base: str cached property

Generates a base name for the training run. Used e.g. for MLflow run names.

Returns:

Name Type Description
str str

The base name for the Training run.
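
As a rough illustration of the naming scheme (all values hypothetical), run_base composes the trainer type, the MTL/GAN flags, and, if set, the module name:

trainer_type, use_mtl, use_gan, module_name = "modular", True, False, "clinical"
base = f"{trainer_type}_{'wmtl' if use_mtl else 'womtl'}_{'wgan' if use_gan else 'wogan'}"
if module_name is not None:
    base += f"_{module_name}"
print(base)  # modular_wmtl_wogan_clinical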

cleanup_checkpoints()

Cleans up the checkpoints directory.

This method deletes the entire checkpoints directory specified by self.checkpoint_path.

Returns:

Type Description

None

Source code in vambn/modelling/models/hivae/trainer.py, lines 411-420
def cleanup_checkpoints(self):
    """
    Cleans up the checkpoints directory.

    This method deletes the entire checkpoints directory specified by `self.checkpoint_path`.

    Returns:
        None
    """
    delete_directory(self.checkpoint_path)

cv_generator(splits, seed=42)

Generates train and validation datasets for cross-validation.

Parameters:

Name Type Description Default
splits int

Number of splits.

required
seed int

Seed for split. Defaults to 42.

42

Raises:

Type Description
ValueError

Number of splits must be greater than 0.

Yields:

Type Description
VambnDataset

Generator[Tuple[VambnDataset, VambnDataset], None, None]: A generator that yields tuples of train and validation datasets.

Source code in vambn/modelling/models/hivae/trainer.py, lines 564-600
def cv_generator(
    self, splits: int, seed: int = 42
) -> Generator[Tuple[VambnDataset, VambnDataset], None, None]:
    """
    Generates train and validation datasets for cross-validation.

    Args:
        splits (int): Number of splits.
        seed (int, optional): Seed for split. Defaults to 42.

    Raises:
        ValueError: Number of splits must be greater than 0.

    Yields:
        Generator[Tuple[VambnDataset, VambnDataset], None, None]: A generator that yields tuples of train and validation datasets.
    """
    if splits < 1:
        raise ValueError("splits must be greater than 0")

    data_idx = np.arange(len(self.dataset))
    if splits == 1:
        train_idx, val_idx = train_test_split(
            data_idx, test_size=0.2, random_state=seed
        )
        yield (
            self.dataset.subset_by_idx(train_idx),
            self.dataset.subset_by_idx(val_idx),
        )
    else:
        cv = KFold(n_splits=splits, shuffle=True, random_state=seed)
        for train_idx, val_idx in cv.split(data_idx):
            if len(train_idx) == 0 or len(val_idx) == 0:
                raise Exception("Empty split")
            yield (
                self.dataset.subset_by_idx(train_idx),
                self.dataset.subset_by_idx(val_idx),
            )
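
A hypothetical usage sketch, assuming `trainer` is an initialized BaseTrainer subclass:

for fold, (train_ds, val_ds) in enumerate(trainer.cv_generator(splits=5, seed=42)):
    train_dl = trainer.get_dataloader(train_ds, batch_size=64, shuffle=True)
    val_dl = trainer.get_dataloader(val_ds, batch_size=64, shuffle=False)
    # ... train on train_dl and evaluate on val_dl for this fold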

get_dataloader(dataset, batch_size, shuffle)

Get a DataLoader object for the given dataset.

Parameters:

Name Type Description Default
dataset VambnDataset

The dataset to use.

required
batch_size int

The batch size.

required
shuffle bool

Whether to shuffle the data.

required

Returns:

Name Type Description
DataLoader DataLoader

The DataLoader object.

Raises:

Type Description
ValueError

If the dataset is empty.

Source code in vambn/modelling/models/hivae/trainer.py, lines 255-289
def get_dataloader(
    self, dataset: VambnDataset, batch_size: int, shuffle: bool
) -> DataLoader:
    """
    Get a DataLoader object for the given dataset.

    Args:
        dataset (VambnDataset): The dataset to use.
        batch_size (int): The batch size.
        shuffle (bool): Whether to shuffle the data.

    Returns:
        DataLoader: The DataLoader object.

    Raises:
        ValueError: If the dataset is empty.
    """
    # Set the number of workers to 0
    self.workers = 0

    # Create and return the DataLoader object
    return DataLoader(
        dataset.get_iter_dataset(self.module_name)
        if self.module_name is not None
        else dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=self.workers,
        # collate_fn=self.custom_collate,
        persistent_workers=True if self.workers > 0 else False,
        pin_memory=True if self.device.type == "cuda" else False,
        drop_last=True
        if shuffle and len(dataset) % batch_size <= 3
        else False,
    )
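
The drop_last rule can be read directly from the call above: during shuffled training the last batch is dropped when it would contain three or fewer samples. A standalone illustration (the helper name is made up for this example):

def drops_last_batch(n_samples: int, batch_size: int, shuffle: bool) -> bool:
    # Mirrors the drop_last condition used in get_dataloader.
    return shuffle and n_samples % batch_size <= 3

print(drops_last_batch(130, 64, shuffle=True))   # True: the last batch would hold only 2 samples
print(drops_last_batch(150, 64, shuffle=True))   # False: the last batch holds 22 samples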

hyperopt(study, num_trials)

Perform hyperparameter optimization using Optuna.

Parameters:

Name Type Description Default
study Study

The Optuna study object.

required
num_trials int

The number of trials to run.

required

Returns:

Name Type Description
Hyperparameters Hyperparameters

The best hyperparameters found during optimization.

Raises:

Type Description
ValueError

If no trials are found in the study.

Source code in vambn/modelling/models/hivae/trainer.py, lines 343-409
def hyperopt(self, study: optuna.Study, num_trials: int) -> Hyperparameters:
    """
    Perform hyperparameter optimization using Optuna.

    Args:
        study (optuna.Study): The Optuna study object.
        num_trials (int): The number of trials to run.

    Returns:
        Hyperparameters: The best hyperparameters found during optimization.

    Raises:
        ValueError: If no trials are found in the study.
    """
    with mlflow.start_run(run_name=f"{self.run_base}_hyperopt"):
        # Optimize the study
        study.optimize(self._objective, n_trials=num_trials)

        # Get the best trial parameters
        if self.config.optimization.use_relative_correlation_error_for_optimization:
            trial = self.multiple_objective_selection(study)
        else:
            trial = study.best_trial

        # Extract the best hyperparameters
        params = trial.params
        best_epoch = trial.user_attrs.get("best_epoch", None)
        params["epochs"] = best_epoch

        # Process the parameters
        params["batch_size"] = 2 ** params.pop("batch_size_n")
        if "hidden_dim_s" in params:
            params["dim_s"] = params.pop("hidden_dim_s")
        else:
            matching_keys = [k for k in params if "hidden_dim_s" in k]
            dim_s = {}
            for key in matching_keys:
                module_name = key.split("_")[-1]
                dim_s[module_name] = params.pop(key)
            params["dim_s"] = dim_s
        if "hidden_dim_y" in params:
            params["dim_y"] = params.pop("hidden_dim_y")
        else:
            matching_keys = [k for k in params if "hidden_dim_y_" in k]
            dim_y = {}
            for key in matching_keys:
                module_name = key.split("_")[-1]
                dim_y[module_name] = params.pop(key)
            params["dim_y"] = dim_y
        if "hidden_dim_ys" in params:
            params["dim_ys"] = params.pop("hidden_dim_ys")
        params["dim_z"] = params.pop("hidden_dim_z")
        if "learning_rate" not in params:
            matching_keys = [k for k in params if "learning_rate" in k]
            learning_rate = {}
            for key in matching_keys:
                module_name = key.split("_")[-1]
                learning_rate[module_name] = params.pop(key)
            params["learning_rate"] = learning_rate

        # Create Hyperparameters object
        hyperparameters = Hyperparameters(dropout=0.1, **params)

        # Log hyperparameters to MLflow
        mlflow.log_params(hyperparameters.__dict__)

    return hyperparameters
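
The parameter renaming above can be summarised with a hypothetical trial result; the module names and values are made up:

params = {
    "batch_size_n": 6,            # becomes batch_size = 2 ** 6 = 64
    "hidden_dim_s_clinical": 4,   # collected into dim_s = {"clinical": 4, "imaging": 6}
    "hidden_dim_s_imaging": 6,
    "hidden_dim_y": 16,           # becomes dim_y = 16
    "hidden_dim_z": 8,            # becomes dim_z = 8
    "learning_rate": 1e-3,        # kept as a single learning rate
}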

load_model(path)

Load the model and its configuration from the specified path.

Parameters:

Name Type Description Default
path Path

The path to load the model from.

required

Returns:

Type Description
None

None

Source code in vambn/modelling/models/hivae/trainer.py, lines 517-536
def load_model(self, path: Path) -> None:
    """
    Load the model and its configuration from the specified path.

    Args:
        path (Path): The path to load the model from.

    Returns:
        None
    """
    self.read_model_config(path / "config.pkl")
    self.model = self.init_model(self.model_config)
    state_dict = torch.load(path / "model.bin")
    try:
        self.model.load_state_dict(state_dict)
    except RuntimeError:
        state_dict = OrderedDict(
            {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
        )
        self.model.load_state_dict(state_dict)

multiple_objective_selection(study, corr_weight=0.8)

Selects the best trial from a given Optuna study based on multiple objectives.

Parameters:

Name Type Description Default
study Study

The Optuna study object.

required
corr_weight float

The weight for the relative correlation error. Defaults to 0.8.

0.8

Returns:

Type Description
FrozenTrial

optuna.trial.FrozenTrial: The best trial.

Raises:

Type Description
ValueError

If no trials are found in the study.

Source code in vambn/modelling/models/hivae/trainer.py, lines 303-341
def multiple_objective_selection(
    self, study: optuna.Study, corr_weight: float = 0.8
) -> optuna.trial.FrozenTrial:
    """
    Selects the best trial from a given Optuna study based on multiple objectives.

    Args:
        study (optuna.Study): The Optuna study object.
        corr_weight (float, optional): The weight for the relative correlation error. Defaults to 0.8.

    Returns:
        optuna.trial.FrozenTrial: The best trial.

    Raises:
        ValueError: If no trials are found in the study.
    """

    # Get the best trials from the study
    best_trials = study.best_trials

    if not best_trials:
        raise ValueError("No trials found in the study.")

    # Calculate the weighted sum of relative correlation error and loss for each trial
    weighted_scores = []
    for trial in best_trials:
        corr_error = trial.values[1]
        loss = trial.values[0]
        weighted_score = corr_weight * corr_error + (1 - corr_weight) * loss
        weighted_scores.append(weighted_score)

    # Find the index of the trial with the minimum weighted score
    best_index = weighted_scores.index(min(weighted_scores))

    # Select the best trial based on the weighted score
    best_trial = best_trials[best_index]
    logger.info(f"Selected trial: {best_trial.number}")

    return best_trial
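
A small, self-contained illustration of the selection rule with two hypothetical Pareto-optimal trials (values[0] is the loss, values[1] the relative correlation error):

corr_weight = 0.8
trials = [
    {"loss": 1.20, "corr_error": 0.30},   # score = 0.8 * 0.30 + 0.2 * 1.20 = 0.48
    {"loss": 0.90, "corr_error": 0.45},   # score = 0.8 * 0.45 + 0.2 * 0.90 = 0.54
]
scores = [corr_weight * t["corr_error"] + (1 - corr_weight) * t["loss"] for t in trials]
best_index = scores.index(min(scores))    # 0: the first trial wins despite its higher loss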

optimize_model(model=None)

Optimizes the model using a specified optimization function.

Parameters:

Name Type Description Default
model Optional[TModel]

The model to optimize. If None, the method optimizes self.model. Defaults to None.

None

Returns:

Name Type Description
TModel TModel

The optimized model.

Notes
  • The optimization function is specified by the opt_func variable.
  • The opt_func function should take a model as input and return an optimized model.
  • If model is None, the method optimizes self.model.
  • If model is not None, the method optimizes the specified model.

Raises:

Type Description
TypeError

If the model is not of type TModel.

Source code in vambn/modelling/models/hivae/trainer.py, lines 422-451
def optimize_model(self, model: Optional[TModel] = None) -> TModel:
    """
    Optimizes the model using a specified optimization function.

    Args:
        model (Optional[TModel], optional): The model to optimize. If None, the method optimizes self.model. Defaults to None.

    Returns:
        TModel: The optimized model.

    Notes:
        - The optimization function is specified by the opt_func variable.
        - The opt_func function should take a model as input and return an optimized model.
        - If model is None, the method optimizes self.model.
        - If model is not None, the method optimizes the specified model.

    Raises:
        TypeError: If the model is not of type TModel.

    """
    # Define the optimization function
    opt_func = partial(torch.compile, mode="reduce-overhead")
    # opt_func = lambda model: model

    if model is None:
        # Optimize self.model
        self.model = opt_func(model=self.model)
        return self.model
    else:
        return opt_func(model=model)

read_model_config(path) abstractmethod

Read the model configuration from the specified path.

Parameters:

Name Type Description Default
path Path

The path to read the model configuration from.

required

Returns:

Type Description
None

None

Source code in vambn/modelling/models/hivae/trainer.py, lines 551-562
@abstractmethod
def read_model_config(self, path: Path) -> None:
    """
    Read the model configuration from the specified path.

    Args:
        path (Path): The path to read the model configuration from.

    Returns:
        None
    """
    pass

save_model(path)

Save the model and its configuration to the specified path.

Parameters:

Name Type Description Default
path Path

The path to save the model.

required

Returns:

Type Description
None

None

Source code in vambn/modelling/models/hivae/trainer.py, lines 504-515
def save_model(self, path: Path) -> None:
    """
    Save the model and its configuration to the specified path.

    Args:
        path (Path): The path to save the model.

    Returns:
        None
    """
    self.save_model_config(path / "config.pkl")
    torch.save(self.model.state_dict(), path / "model.bin")
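
A hypothetical round trip, assuming `trainer` is a trained concrete subclass: save_model writes config.pkl and model.bin into the given directory, and load_model restores both (stripping torch.compile's "_orig_mod." prefix from the state dict if present). The directory path below is a placeholder.

from pathlib import Path

model_dir = Path("results/hivae_clinical")
model_dir.mkdir(parents=True, exist_ok=True)
trainer.save_model(model_dir)
trainer.load_model(model_dir)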

save_model_config(path) abstractmethod

Save the model configuration to the specified path.

Parameters:

Name Type Description Default
path Path

The path to save the model configuration.

required

Returns:

Type Description
None

None

Source code in vambn/modelling/models/hivae/trainer.py, lines 538-549
@abstractmethod
def save_model_config(self, path: Path) -> None:
    """
    Save the model configuration to the specified path.

    Args:
        path (Path): The path to save the model configuration.

    Returns:
        None
    """
    pass

Hyperparameters dataclass

Class representing the hyperparameters for the model trainer.

Parameters:

Name Type Description Default
dim_s int | Dict[str, int] | Tuple[int, ...]

Dimension(s) of the input sequence(s).

required
dim_y int | Dict[str, int] | Tuple[int, ...]

Dimension(s) of the output sequence(s).

required
dim_z int

Dimension of the latent space.

required
dropout float

Dropout rate.

required
batch_size int

Batch size.

required
learning_rate float | Dict[str, float] | Tuple[float, ...]

Learning rate(s).

required
epochs int

Number of training epochs.

required
mtl_methods Tuple[str, ...]

Multi-task learning methods. Defaults to ("identity",).

('identity',)
lstm_layers int

Number of LSTM layers. Defaults to 1.

1
dim_ys Optional[int]

Dimension of the output sequence. Defaults to None.

None

Attributes:

Name Type Description
dim_s int | Dict[str, int] | Tuple[int, ...]

Dimension(s) of the input sequence(s).

dim_y int | Dict[str, int] | Tuple[int, ...]

Dimension(s) of the output sequence(s).

dim_z int

Dimension of the latent space.

dropout float

Dropout rate.

batch_size int

Batch size.

learning_rate float | Dict[str, float] | Tuple[float, ...]

Learning rate(s).

epochs int

Number of training epochs.

mtl_methods Tuple[str, ...]

Multi-task learning methods.

lstm_layers int

Number of LSTM layers.

dim_ys Optional[int]

Dimension of the output sequence.

Methods:

Name Description
__post_init__

Post-initialization method.

write_to_json(path)

Write the hyperparameters to a JSON file.

read_from_json(path)

Read the hyperparameters from a JSON file.

Source code in vambn/modelling/models/hivae/trainer.py, lines 91-180
@dataclass
class Hyperparameters:
    """
    Class representing the hyperparameters for the model trainer.

    Args:
        dim_s (int | Dict[str, int] | Tuple[int, ...]): Dimension(s) of the input sequence(s).
        dim_y (int | Dict[str, int] | Tuple[int, ...]): Dimension(s) of the output sequence(s).
        dim_z (int): Dimension of the latent space.
        dropout (float): Dropout rate.
        batch_size (int): Batch size.
        learning_rate (float | Dict[str, float] | Tuple[float, ...]): Learning rate(s).
        epochs (int): Number of training epochs.
        mtl_methods (Tuple[str, ...], optional): Multi-task learning methods. Defaults to ("identity",).
        lstm_layers (int, optional): Number of LSTM layers. Defaults to 1.
        dim_ys (Optional[int], optional): Dimension of the output sequence. Defaults to None.

    Attributes:
        dim_s (int | Dict[str, int] | Tuple[int, ...]): Dimension(s) of the input sequence(s).
        dim_y (int | Dict[str, int] | Tuple[int, ...]): Dimension(s) of the output sequence(s).
        dim_z (int): Dimension of the latent space.
        dropout (float): Dropout rate.
        batch_size (int): Batch size.
        learning_rate (float | Dict[str, float] | Tuple[float, ...]): Learning rate(s).
        epochs (int): Number of training epochs.
        mtl_methods (Tuple[str, ...]): Multi-task learning methods.
        lstm_layers (int): Number of LSTM layers.
        dim_ys (Optional[int]): Dimension of the output sequence.

    Methods:
        __post_init__(self): Post-initialization method.
        write_to_json(self, path: Path): Write the hyperparameters to a JSON file.
        read_from_json(cls, path: Path): Read the hyperparameters from a JSON file.

    """

    dim_s: int | Dict[str, int] | Tuple[int, ...]
    dim_y: int | Dict[str, int] | Tuple[int, ...]
    dim_z: int
    dropout: float
    batch_size: int
    learning_rate: float | Dict[str, float] | Tuple[float, ...]
    epochs: int
    mtl_methods: Tuple[str, ...] = ("identity",)
    lstm_layers: int = 1
    dim_ys: Optional[int] = None

    def __post_init__(self):
        if isinstance(self.mtl_methods, List):
            self.mtl_methods = tuple(self.mtl_methods)
        elif isinstance(self.mtl_methods, str):
            self.mtl_methods = (self.mtl_methods,)

    def write_to_json(self, path: Path):
        """
        Write the hyperparameters to a JSON file.

        Args:
            path (Path): Path to the JSON file.

        """
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w") as f:
            json.dump(self.__dict__, f, indent=4)

    @classmethod
    def read_from_json(cls, path: Path):
        """
        Read the hyperparameters from a JSON file.

        Args:
            path (Path): Path to the JSON file.

        Returns:
            Hyperparameters: An instance of the Hyperparameters class.

        """
        with path.open("r") as f:
            data = json.load(f)

        tmp = cls(**data)
        # FIXME: fix model script to avoid this step
        # check if , is in mtl_methods
        logger.info(tmp.mtl_methods)
        if len(tmp.mtl_methods) == 1 and "," in tmp.mtl_methods[0]:
            method_str = tmp.mtl_methods[0]
            tmp.mtl_methods = tuple(method_str.split(","))
        logger.info(tmp.mtl_methods)
        logger.debug(tmp)
        return tmp
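
A minimal usage sketch; the import path is assumed from the source file location shown above, and all field values and paths are placeholders:

from pathlib import Path

from vambn.modelling.models.hivae.trainer import Hyperparameters

hp = Hyperparameters(
    dim_s=4,
    dim_y=16,
    dim_z=8,
    dropout=0.1,
    batch_size=64,
    learning_rate=1e-3,
    epochs=100,
    mtl_methods=("identity", "gradnorm"),
)
hp.write_to_json(Path("results/hyperparameters.json"))
restored = Hyperparameters.read_from_json(Path("results/hyperparameters.json"))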

read_from_json(path) classmethod

Read the hyperparameters from a JSON file.

Parameters:

Name Type Description Default
path Path

Path to the JSON file.

required

Returns:

Name Type Description
Hyperparameters

An instance of the Hyperparameters class.

Source code in vambn/modelling/models/hivae/trainer.py, lines 156-180
@classmethod
def read_from_json(cls, path: Path):
    """
    Read the hyperparameters from a JSON file.

    Args:
        path (Path): Path to the JSON file.

    Returns:
        Hyperparameters: An instance of the Hyperparameters class.

    """
    with path.open("r") as f:
        data = json.load(f)

    tmp = cls(**data)
    # FIXME: fix model script to avoid this step
    # check if , is in mtl_methods
    logger.info(tmp.mtl_methods)
    if len(tmp.mtl_methods) == 1 and "," in tmp.mtl_methods[0]:
        method_str = tmp.mtl_methods[0]
        tmp.mtl_methods = tuple(method_str.split(","))
    logger.info(tmp.mtl_methods)
    logger.debug(tmp)
    return tmp

write_to_json(path)

Write the hyperparameters to a JSON file.

Parameters:

Name Type Description Default
path Path

Path to the JSON file.

required
Source code in vambn/modelling/models/hivae/trainer.py, lines 144-154
def write_to_json(self, path: Path):
    """
    Write the hyperparameters to a JSON file.

    Args:
        path (Path): Path to the JSON file.

    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w") as f:
        json.dump(self.__dict__, f, indent=4)

ModularTrainer

Bases: BaseTrainer[GenericMHivaeConfig, GenericMHivaeModel, ModularHivaeOutput, 'ModularTrainer', ModularHivaeEncoding]

Source code in vambn/modelling/models/hivae/trainer.py, lines 1370-1891
class ModularTrainer(
    BaseTrainer[
        GenericMHivaeConfig,
        GenericMHivaeModel,
        ModularHivaeOutput,
        "ModularTrainer",
        ModularHivaeEncoding,
    ]
):
    def __init__(
        self,
        dataset: VambnDataset,
        config: PipelineConfig,
        workers: int,
        checkpoint_path: Path,
        module_name: str | None = None,
        experiment_name: str | None = None,
        force_cpu: bool = False,
        shared_element: str = "none",
    ):
        """
        Initialize the ModularTrainer class.

        Args:
            dataset: The VambnDataset object.
            config: The PipelineConfig object.
            workers: The number of workers for data loading.
            checkpoint_path: The path to save checkpoints.
            module_name: The name of the module.
            experiment_name: The name of the experiment.
            force_cpu: Whether to force CPU usage.
            shared_element: The type of shared element.
        """
        super().__init__(
            dataset=dataset,
            config=config,
            workers=workers,
            checkpoint_path=checkpoint_path,
            module_name=module_name,
            experiment_name=experiment_name,
            force_cpu=force_cpu,
        )
        self.type = "modular"
        self.shared_element = shared_element

    def process_encodings(
        self, predictions: ModularHivaeOutput
    ) -> ModularHivaeEncoding:
        """
        Process the model predictions and return the encodings.

        Args:
            predictions: The model predictions.

        Returns:
            The ModularHivaeEncoding object.
        """
        return ModularHivaeEncoding(
            encodings=tuple(
                HivaeEncoding(
                    z=pred.enc_z,
                    s=pred.enc_s,
                    module=module_name,
                    samples=pred.samples,
                    subjid=self.dataset.subj,
                )
                for module_name, pred in zip(
                    self.dataset.module_names, predictions
                )
            ),
            modules=self.dataset.module_names,
        )

    def init_model(self, config: ModularHivaeConfig) -> ModularHivae:
        """
        Initialize the ModularHivae model.

        Args:
            config: The ModularHivaeConfig object.

        Returns:
            The initialized ModularHivae model.
        """
        return ModularHivae(
            dim_s=config.dim_s,
            dim_ys=config.dim_ys,
            dim_y=config.dim_y,
            dim_z=config.dim_z,
            module_config=config.module_config,
            mtl_method=config.mtl_method,
            shared_element_type=config.shared_element,
            use_imputation_layer=config.use_imputation_layer,
        )

    def hyperparameters(self, trial: optuna.Trial) -> Hyperparameters:
        """
        Generate hyperparameters for the trial.

        Args:
            trial: The optuna.Trial object.

        Returns:
            The generated Hyperparameters object.
        """
        opt = self.config.optimization
        if opt.fixed_s_dim:
            # Use fixed dimension for s
            fixed_dim_s = trial.suggest_int(
                "hidden_dim_s",
                opt.s_dim_lower,
                opt.s_dim_upper,
                opt.s_dim_step,
            )
            dim_s = {
                module_name: fixed_dim_s
                for module_name in self.dataset.module_names
            }
        else:
            # Use different dimensions for s
            dim_s = {
                module_name: trial.suggest_int(
                    f"hidden_dim_s_{module_name}",
                    opt.s_dim_lower,
                    opt.s_dim_upper,
                    opt.s_dim_step,
                )
                for module_name in self.dataset.module_names
            }
        if opt.fixed_y_dim:
            # Use fixed dimension for y
            fixed_y = trial.suggest_int(
                "hidden_dim_y",
                opt.y_dim_lower,
                opt.y_dim_upper,
                opt.y_dim_step,
            )
            dim_y = {module_name: fixed_y for module_name in dim_s}
        else:
            # Use different dimensions for y
            dim_y = {
                module_name: trial.suggest_int(
                    f"hidden_dim_y_{module_name}",
                    opt.y_dim_lower,
                    opt.y_dim_upper,
                    opt.y_dim_step,
                )
                for module_name in self.dataset.module_names
            }
        dim_z = trial.suggest_int(
            "hidden_dim_z",
            opt.latent_dim_lower,
            opt.latent_dim_upper,
            opt.latent_dim_step,
        )
        dim_ys = trial.suggest_int(
            "hidden_dim_ys",
            opt.y_dim_lower,
            opt.y_dim_upper,
            opt.y_dim_step,
        )
        if opt.fixed_learning_rate:
            lr = trial.suggest_float(
                "learning_rate",
                opt.learning_rate_lower,
                opt.learning_rate_upper,
                log=True,
            )
            learning_rate = {
                module_name: lr for module_name in self.dataset.module_names
            }
            learning_rate["learning_rate_shared"] = lr
        else:
            learning_rate = {
                module_name: trial.suggest_float(
                    f"learning_rate_{module_name}",
                    opt.learning_rate_lower,
                    opt.learning_rate_upper,
                    log=True,
                )
                for module_name in self.dataset.module_names
            }
            learning_rate["learning_rate_shared"] = trial.suggest_float(
                "learning_rate_shared",
                opt.learning_rate_lower,
                opt.learning_rate_upper,
                log=True,
            )

        batch_size_n = trial.suggest_int(
            "batch_size_n", opt.batch_size_lower_n, opt.batch_size_upper_n
        )
        batch_size = 2**batch_size_n

        if self.dataset.is_longitudinal:
            lstm_layers = trial.suggest_int(
                "lstm_layers",
                opt.lstm_layers_lower,
                opt.lstm_layers_upper,
                opt.lstm_layers_step,
            )
        else:
            lstm_layers = 1
        if self.use_mtl:
            mtl_string = trial.suggest_categorical(
                "mtl_methods",
                [
                    "gradnorm",
                    "graddrop",
                    "gradnorm,graddrop",
                ],
            )
            mtl_methods = (
                tuple(mtl_string.split(","))
                if "," in mtl_string
                else tuple([mtl_string])
            )
        else:
            mtl_methods = ("identity",)
        return Hyperparameters(
            dim_s=dim_s,
            dim_y=dim_y,
            dropout=0.1,
            batch_size=batch_size,
            learning_rate=learning_rate,
            epochs=opt.max_epochs,
            lstm_layers=lstm_layers,
            mtl_methods=mtl_methods,
            dim_ys=dim_ys,
            dim_z=dim_z,
        )

    def get_module_config(
        self, hyperparameters: Hyperparameters
    ) -> Tuple[DataModuleConfig, ...]:
        """
        Get the module configuration based on the hyperparameters.

        Args:
            hyperparameters: The Hyperparameters object.

        Returns:
            A tuple of DataModuleConfig objects.
        """
        module_configs = []
        for module_name in self.dataset.module_names:
            module = self.dataset.get_modules(module_name)[0]
            module_config = DataModuleConfig(
                name=module_name,
                variable_types=module.variable_types,
                n_layers=hyperparameters.lstm_layers,
                num_timepoints=self.dataset.visits_per_module[module_name],
            )
            module_configs.append(module_config)
        return tuple(module_configs)

    def train(self, best_parameters: Hyperparameters) -> GenericMHivaeModel:
        """
        Train the model on the full dataset using the best hyperparameters.

        Args:
            best_parameters (Hyperparameters): The best hyperparameters obtained from optimization.

        Returns:
            GenericMHivaeModel: The trained model.
        """
        whole_dataloader = self.get_dataloader(
            self.dataset, best_parameters.batch_size, shuffle=True
        )
        module_config = self.get_module_config(best_parameters)
        self.model_config = ModularHivaeConfig(
            module_config=module_config,
            dim_z=best_parameters.dim_z,
            dim_s=best_parameters.dim_s,
            dim_y=best_parameters.dim_y,
            dropout=best_parameters.dropout,
            mtl_method=best_parameters.mtl_methods,
            n_layers=best_parameters.lstm_layers,
            use_imputation_layer=self.config.training.use_imputation_layer,
            dim_ys=best_parameters.dim_ys,
            shared_element=self.shared_element,
        )
        self.model = self.init_model(self.model_config)
        if isinstance(best_parameters.learning_rate, Dict):
            order = self.dataset.module_names + ["learning_rate_shared"]
            learning_rate = tuple(
                best_parameters.learning_rate[module_name]
                for module_name in order
            )
        else:
            learning_rate = best_parameters.learning_rate
        with mlflow.start_run(run_name=f"{self.run_base}_final_fit"):
            mlflow.log_params(best_parameters.__dict__)
            self.model.fit(
                train_dataloader=whole_dataloader,
                num_epochs=best_parameters.epochs,
                learning_rate=learning_rate,
            )
        return self.model

    def _objective(self, trial: optuna.Trial) -> float | Tuple[float, float]:
        """
        Objective function for the optimization process.

        Args:
            trial (optuna.Trial): The trial object for hyperparameter optimization.

        Returns:
            float | Tuple[float, float]: The loss value or a tuple of loss values.
        """

        trial_params = self.hyperparameters(trial)
        logger.info(f"Trial parameters: {trial_params}")
        fold_loss = []
        n_epochs = []
        rel_corr_loss = []
        auc_metric = []
        module_config = self.get_module_config(trial_params)
        model_config = ModularHivaeConfig(
            module_config=module_config,
            dim_z=trial_params.dim_z,
            dim_y=trial_params.dim_y,
            dropout=trial_params.dropout,
            mtl_method=trial_params.mtl_methods,
            n_layers=trial_params.lstm_layers,
            use_imputation_layer=self.config.training.use_imputation_layer,
            dim_s=trial_params.dim_s,
            shared_element=self.shared_element,
            dim_ys=trial_params.dim_ys,
        )

        if isinstance(trial_params.learning_rate, Dict):
            order = self.dataset.module_names + ["learning_rate_shared"]
            learning_rate = tuple(
                trial_params.learning_rate[module_name] for module_name in order
            )
        else:
            learning_rate = trial_params.learning_rate
        for i, (train_data, val_data) in enumerate(
            self.cv_generator(self.config.optimization.folds)
        ):
            train_dataloader = self.get_dataloader(
                train_data, trial_params.batch_size, shuffle=True
            )
            val_dataloader = self.get_dataloader(
                val_data, trial_params.batch_size, shuffle=False
            )
            model = self.init_model(model_config)
            with mlflow.start_run(
                run_name=f"{self.run_base}_T{trial._trial_id}_F{i}",
                nested=True,
            ):
                mlflow.log_params(trial_params.__dict__)
                start = time.time()
                try:
                    fit_loss, fit_epoch = model.fit(
                        train_dataloader=train_dataloader,
                        val_dataloader=val_dataloader,
                        num_epochs=trial_params.epochs,
                        learning_rate=learning_rate,
                    )
                except ValueError:
                    raise optuna.TrialPruned(
                        "Trial pruned due to error during training. Likely due to unsuitable hyperparameters"
                    )

                if ((time.time() - start) / 3600) > 2:
                    logger.warning(
                        f"Trial pruned due to very long execution ({(time.time() - start) / 3600}h at fold {i})"
                    )
                    raise optuna.TrialPruned()

                res = model.predict(val_dataloader)
                mlflow.log_metric("Final loss", res.avg_loss)
                logger.info(f"Fold {i} loss: {res.avg_loss}")

                fold_loss.append(res.avg_loss)
                n_epochs.append(fit_epoch)

                if (
                    self.config.optimization.use_relative_correlation_error_for_optimization
                    or self.config.optimization.use_auc_for_optimization
                ):
                    # Calculate the relative correlation loss
                    assert len(res.outputs) == len(model.module_configs)

                    orig_data = val_data.to_pandas()
                    module_dfs = []
                    for module_output, conf in zip(
                        res.outputs, model.module_configs
                    ):
                        module_columns = self.dataset.get_modules(conf.name)[
                            0
                        ].columns
                        cleaned_columns = [
                            re.sub(r"_VIS[0-9]+", "", c) for c in module_columns
                        ]
                        module_samples = module_output.samples
                        if module_samples.ndim == 2:
                            module_samples = module_samples.unsqueeze(1)

                        module_data = pd.concat(
                            [
                                pd.DataFrame(
                                    module_samples[:, i, :].numpy(),
                                    columns=cleaned_columns,
                                )
                                for i in range(module_samples.shape[1])
                            ]
                        )
                        module_data["SUBJID"] = (
                            val_data.subj * conf.num_timepoints
                        )
                        module_data["VISIT"] = np.repeat(
                            np.arange(1, module_samples.shape[1] + 1),
                            len(val_data),
                        )
                        module_dfs.append(module_data)

                    decoded_data = reduce(
                        lambda x, y: pd.merge(
                            x, y, on=["SUBJID", "VISIT"], how="outer"
                        ),
                        module_dfs,
                    )
                    decoded_data_filtered = decoded_data.loc[
                        :, orig_data.columns
                    ].drop(columns=["SUBJID", "VISIT"])
                    orig_data_filtered = orig_data.drop(
                        columns=["SUBJID", "VISIT"]
                    )

                    fold_rel_corr_loss, m1, m2 = RelativeCorrelation.error(
                        real=orig_data_filtered,
                        synthetic=decoded_data_filtered,
                    )
                    mlflow.log_metric(
                        "Relative correlation loss", fold_rel_corr_loss
                    )
                    rel_corr_loss.append(fold_rel_corr_loss)

                    auc = get_auc(orig_data_filtered, decoded_data_filtered)
                    if auc < 0.5:
                        auc = 1 - auc
                    auc_quality = max(math.floor((1 - auc) * 200), 1)
                    mlflow.log_metric("AUC", auc)
                    mlflow.log_metric("AUC quality", auc_quality)
                    auc_metric.append(auc_quality)

        loss = np.mean(fold_loss)
        best_epoch = n_epochs[np.argmin(fold_loss)]
        mlflow.log_metric("Best epoch", best_epoch)
        trial.set_user_attr("best_epoch", best_epoch)
        trial.set_user_attr("n_epochs", n_epochs)

        if self.config.optimization.use_relative_correlation_error_for_optimization:
            rel_corr_loss = np.mean(rel_corr_loss)
            logger.info(f"Trial loss: {loss}; rel_corr_loss: {rel_corr_loss}")
            return loss, rel_corr_loss
        elif self.config.optimization.use_auc_for_optimization:
            return np.mean(auc_metric)
        else:
            logger.info(f"Trial loss: {loss}")
            return loss

    def predict(self, dl: DataLoader) -> ModularHivaeOutput:
        """
        Predict the output of the model.

        Args:
            dl (DataLoader): The DataLoader object.

        Returns:
            ModularHivaeOutput: The output of the model.
        """
        with torch.no_grad():
            return self.model.predict(dl)

    def save_model_config(self, path: Path):
        """
        Save the model configuration to a file.

        Args:
            path (Path): The path to save the model configuration.

        Raises:
            Exception: If the model is None.
            Exception: If the model config is None.
        """
        if self.model is None:
            raise Exception("Model should not be none")

        if self.model_config is None:
            raise Exception("Model config should not be none")

        with path.open("wb") as f:
            dill.dump(self.model_config, f)

    def read_model_config(self, path: Path):
        """
        Read the model configuration from a file.

        Args:
            path (Path): The path to read the model configuration from.
        """
        with path.open("rb") as f:
            self.model_config = dill.load(f)

    def decode(
        self,
        encoding: Union[HivaeEncoding, ModularHivaeEncoding],
        use_mode: bool = True,
    ) -> Dict[str, Tensor]:
        """
        Decode the given encoding to obtain the sampled data.

        Args:
            encoding (Union[HivaeEncoding, ModularHivaeEncoding]): The encoding to decode.
            use_mode (bool, optional): Whether to use the mode for decoding. Defaults to True.

        Returns:
            Dict[str, Tensor]: The decoded sampled data, with module names as keys.
        """
        self.model.eval()
        self.model.decoding = use_mode
        with torch.no_grad():
            sampled_data = self.model.decode(encoding)

        sample_dict = {
            module_name: sampled_data[i]
            for i, module_name in enumerate(self.dataset.module_names)
        }
        return sample_dict
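
A hedged usage sketch of the class documented above. Building the VambnDataset and PipelineConfig is outside the scope of this page, so dataset, config, and best_parameters are placeholders, and the constructor arguments are illustrative rather than recommended settings:

from pathlib import Path

from vambn.modelling.models.hivae.trainer import ModularTrainer

trainer = ModularTrainer(
    dataset=dataset,    # assumed VambnDataset, built elsewhere
    config=config,      # assumed PipelineConfig, built elsewhere
    workers=2,
    checkpoint_path=Path("checkpoints/modular"),
    experiment_name="modular_hivae",
    shared_element="none",
)

# best_parameters is a Hyperparameters object from a finished hyperparameter search.
model = trainer.train(best_parameters)
dataloader = trainer.get_dataloader(
    trainer.dataset, best_parameters.batch_size, shuffle=False
)
output = trainer.predict(dataloader)
encodings = trainer.process_encodings(output)  # ModularHivaeEncoding, one HivaeEncoding per module
decoded = trainer.decode(encodings)            # dict: module name -> sampled tensor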

__init__(dataset, config, workers, checkpoint_path, module_name=None, experiment_name=None, force_cpu=False, shared_element='none')

Initialize the ModularTrainer class.

Parameters:

Name Type Description Default
dataset VambnDataset

The VambnDataset object.

required
config PipelineConfig

The PipelineConfig object.

required
workers int

The number of workers for data loading.

required
checkpoint_path Path

The path to save checkpoints.

required
module_name str | None

The name of the module.

None
experiment_name str | None

The name of the experiment.

None
force_cpu bool

Whether to force CPU usage.

False
shared_element str

The type of shared element.

'none'
Source code in vambn/modelling/models/hivae/trainer.py
def __init__(
    self,
    dataset: VambnDataset,
    config: PipelineConfig,
    workers: int,
    checkpoint_path: Path,
    module_name: str | None = None,
    experiment_name: str | None = None,
    force_cpu: bool = False,
    shared_element: str = "none",
):
    """
    Initialize the ModularTrainer class.

    Args:
        dataset: The VambnDataset object.
        config: The PipelineConfig object.
        workers: The number of workers for data loading.
        checkpoint_path: The path to save checkpoints.
        module_name: The name of the module.
        experiment_name: The name of the experiment.
        force_cpu: Whether to force CPU usage.
        shared_element: The type of shared element.
    """
    super().__init__(
        dataset=dataset,
        config=config,
        workers=workers,
        checkpoint_path=checkpoint_path,
        module_name=module_name,
        experiment_name=experiment_name,
        force_cpu=force_cpu,
    )
    self.type = "modular"
    self.shared_element = shared_element

decode(encoding, use_mode=True)

Decode the given encoding to obtain the sampled data.

Parameters:

Name Type Description Default
encoding Union[HivaeEncoding, ModularHivaeEncoding]

The encoding to decode.

required
use_mode bool

Whether to use the mode for decoding. Defaults to True.

True

Returns:

Type Description
Dict[str, Tensor]

Dict[str, Tensor]: The decoded sampled data, with module names as keys.

Source code in vambn/modelling/models/hivae/trainer.py
def decode(
    self,
    encoding: Union[HivaeEncoding, ModularHivaeEncoding],
    use_mode: bool = True,
) -> Dict[str, Tensor]:
    """
    Decode the given encoding to obtain the sampled data.

    Args:
        encoding (Union[HivaeEncoding, ModularHivaeEncoding]): The encoding to decode.
        use_mode (bool, optional): Whether to use the mode for decoding. Defaults to True.

    Returns:
        Dict[str, Tensor]: The decoded sampled data, with module names as keys.
    """
    self.model.eval()
    self.model.decoding = use_mode
    with torch.no_grad():
        sampled_data = self.model.decode(encoding)

    sample_dict = {
        module_name: sampled_data[i]
        for i, module_name in enumerate(self.dataset.module_names)
    }
    return sample_dict
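
For orientation, the returned dictionary maps module names (in the order of dataset.module_names) to the decoded sample tensors. A minimal sketch, assuming encoding comes from process_encodings on an earlier predict call:

decoded = trainer.decode(encoding, use_mode=True)
for module_name, samples in decoded.items():
    print(module_name, tuple(samples.shape))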

get_module_config(hyperparameters)

Get the module configuration based on the hyperparameters.

Parameters:

Name Type Description Default
hyperparameters Hyperparameters

The Hyperparameters object.

required

Returns:

Type Description
Tuple[DataModuleConfig, ...]

A tuple of DataModuleConfig objects.

Source code in vambn/modelling/models/hivae/trainer.py
def get_module_config(
    self, hyperparameters: Hyperparameters
) -> Tuple[DataModuleConfig, ...]:
    """
    Get the module configuration based on the hyperparameters.

    Args:
        hyperparameters: The Hyperparameters object.

    Returns:
        A tuple of DataModuleConfig objects.
    """
    module_configs = []
    for module_name in self.dataset.module_names:
        module = self.dataset.get_modules(module_name)[0]
        module_config = DataModuleConfig(
            name=module_name,
            variable_types=module.variable_types,
            n_layers=hyperparameters.lstm_layers,
            num_timepoints=self.dataset.visits_per_module[module_name],
        )
        module_configs.append(module_config)
    return tuple(module_configs)

hyperparameters(trial)

Generate hyperparameters for the trial.

Parameters:

Name Type Description Default
trial Trial

The optuna.Trial object.

required

Returns:

Type Description
Hyperparameters

The generated Hyperparameters object.

Source code in vambn/modelling/models/hivae/trainer.py
def hyperparameters(self, trial: optuna.Trial) -> Hyperparameters:
    """
    Generate hyperparameters for the trial.

    Args:
        trial: The optuna.Trial object.

    Returns:
        The generated Hyperparameters object.
    """
    opt = self.config.optimization
    if opt.fixed_s_dim:
        # Use fixed dimension for s
        fixed_dim_s = trial.suggest_int(
            "hidden_dim_s",
            opt.s_dim_lower,
            opt.s_dim_upper,
            opt.s_dim_step,
        )
        dim_s = {
            module_name: fixed_dim_s
            for module_name in self.dataset.module_names
        }
    else:
        # Use different dimensions for s
        dim_s = {
            module_name: trial.suggest_int(
                f"hidden_dim_s_{module_name}",
                opt.s_dim_lower,
                opt.s_dim_upper,
                opt.s_dim_step,
            )
            for module_name in self.dataset.module_names
        }
    if opt.fixed_y_dim:
        # Use fixed dimension for y
        fixed_y = trial.suggest_int(
            "hidden_dim_y",
            opt.y_dim_lower,
            opt.y_dim_upper,
            opt.y_dim_step,
        )
        dim_y = {module_name: fixed_y for module_name in dim_s}
    else:
        # Use different dimensions for y
        dim_y = {
            module_name: trial.suggest_int(
                f"hidden_dim_y_{module_name}",
                opt.y_dim_lower,
                opt.y_dim_upper,
                opt.y_dim_step,
            )
            for module_name in self.dataset.module_names
        }
    dim_z = trial.suggest_int(
        "hidden_dim_z",
        opt.latent_dim_lower,
        opt.latent_dim_upper,
        opt.latent_dim_step,
    )
    dim_ys = trial.suggest_int(
        "hidden_dim_ys",
        opt.y_dim_lower,
        opt.y_dim_upper,
        opt.y_dim_step,
    )
    if opt.fixed_learning_rate:
        lr = trial.suggest_float(
            "learning_rate",
            opt.learning_rate_lower,
            opt.learning_rate_upper,
            log=True,
        )
        learning_rate = {
            module_name: lr for module_name in self.dataset.module_names
        }
        learning_rate["learning_rate_shared"] = lr
    else:
        learning_rate = {
            module_name: trial.suggest_float(
                f"learning_rate_{module_name}",
                opt.learning_rate_lower,
                opt.learning_rate_upper,
                log=True,
            )
            for module_name in self.dataset.module_names
        }
        learning_rate["learning_rate_shared"] = trial.suggest_float(
            "learning_rate_shared",
            opt.learning_rate_lower,
            opt.learning_rate_upper,
            log=True,
        )

    batch_size_n = trial.suggest_int(
        "batch_size_n", opt.batch_size_lower_n, opt.batch_size_upper_n
    )
    batch_size = 2**batch_size_n

    if self.dataset.is_longitudinal:
        lstm_layers = trial.suggest_int(
            "lstm_layers",
            opt.lstm_layers_lower,
            opt.lstm_layers_upper,
            opt.lstm_layers_step,
        )
    else:
        lstm_layers = 1
    if self.use_mtl:
        mtl_string = trial.suggest_categorical(
            "mtl_methods",
            [
                "gradnorm",
                "graddrop",
                "gradnorm,graddrop",
            ],
        )
        mtl_methods = (
            tuple(mtl_string.split(","))
            if "," in mtl_string
            else tuple([mtl_string])
        )
    else:
        mtl_methods = ("identity",)
    return Hyperparameters(
        dim_s=dim_s,
        dim_y=dim_y,
        dropout=0.1,
        batch_size=batch_size,
        learning_rate=learning_rate,
        epochs=opt.max_epochs,
        lstm_layers=lstm_layers,
        mtl_methods=mtl_methods,
        dim_ys=dim_ys,
        dim_z=dim_z,
    )
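
The method above only defines the search space; it is consumed by the trainer's objective function. A hedged sketch of how a search might be launched directly with Optuna (the study setup and trial count are illustrative, and the project may expose its own wrapper around this loop; note that when the relative correlation error is used for optimization, _objective returns two values and a study with two directions would be required):

import optuna

# trainer is an already constructed ModularTrainer (see the earlier sketch).
study = optuna.create_study(direction="minimize")
study.optimize(trainer._objective, n_trials=20)

print(study.best_trial.params)                    # e.g. hidden_dim_z, batch_size_n, ...
print(study.best_trial.user_attrs["best_epoch"])  # stored by _objective via set_user_attr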

init_model(config)

Initialize the ModularHivae model.

Parameters:

Name Type Description Default
config ModularHivaeConfig

The ModularHivaeConfig object.

required

Returns:

Type Description
ModularHivae

The initialized ModularHivae model.

Source code in vambn/modelling/models/hivae/trainer.py
def init_model(self, config: ModularHivaeConfig) -> ModularHivae:
    """
    Initialize the ModularHivae model.

    Args:
        config: The ModularHivaeConfig object.

    Returns:
        The initialized ModularHivae model.
    """
    return ModularHivae(
        dim_s=config.dim_s,
        dim_ys=config.dim_ys,
        dim_y=config.dim_y,
        dim_z=config.dim_z,
        module_config=config.module_config,
        mtl_method=config.mtl_method,
        shared_element_type=config.shared_element,
        use_imputation_layer=config.use_imputation_layer,
    )

predict(dl)

Predict the output of the model.

Parameters:

Name Type Description Default
dl DataLoader

The DataLoader object.

required

Returns:

Name Type Description
ModularHivaeOutput ModularHivaeOutput

The output of the model.

Source code in vambn/modelling/models/hivae/trainer.py
def predict(self, dl: DataLoader) -> ModularHivaeOutput:
    """
    Predict the output of the model.

    Args:
        dl (DataLoader): The DataLoader object.

    Returns:
        ModularHivaeOutput: The output of the model.
    """
    with torch.no_grad():
        return self.model.predict(dl)

process_encodings(predictions)

Process the model predictions and return the encodings.

Parameters:

Name Type Description Default
predictions ModularHivaeOutput

The model predictions.

required

Returns:

Type Description
ModularHivaeEncoding

The ModularHivaeEncoding object.

Source code in vambn/modelling/models/hivae/trainer.py
def process_encodings(
    self, predictions: ModularHivaeOutput
) -> ModularHivaeEncoding:
    """
    Process the model predictions and return the encodings.

    Args:
        predictions: The model predictions.

    Returns:
        The ModularHivaeEncoding object.
    """
    return ModularHivaeEncoding(
        encodings=tuple(
            HivaeEncoding(
                z=pred.enc_z,
                s=pred.enc_s,
                module=module_name,
                samples=pred.samples,
                subjid=self.dataset.subj,
            )
            for module_name, pred in zip(
                self.dataset.module_names, predictions
            )
        ),
        modules=self.dataset.module_names,
    )

read_model_config(path)

Read the model configuration from a file.

Parameters:

Name Type Description Default
path Path

The path to read the model configuration from.

required
Source code in vambn/modelling/models/hivae/trainer.py
def read_model_config(self, path: Path):
    """
    Read the model configuration from a file.

    Args:
        path (Path): The path to read the model configuration from.
    """
    with path.open("rb") as f:
        self.model_config = dill.load(f)

save_model_config(path)

Save the model configuration to a file.

Parameters:

Name Type Description Default
path Path

The path to save the model configuration.

required

Raises:

Type Description
Exception

If the model is None.

Exception

If the model config is None.

Source code in vambn/modelling/models/hivae/trainer.py
def save_model_config(self, path: Path):
    """
    Save the model configuration to a file.

    Args:
        path (Path): The path to save the model configuration.

    Raises:
        Exception: If the model is None.
        Exception: If the model config is None.
    """
    if self.model is None:
        raise Exception("Model should not be none")

    if self.model_config is None:
        raise Exception("Model config should not be none")

    with path.open("wb") as f:
        dill.dump(self.model_config, f)
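
A small round-trip sketch, assuming trainer already holds a populated model_config; the file path is illustrative:

from pathlib import Path

config_path = Path("checkpoints/modular/model_config.pkl")
trainer.save_model_config(config_path)  # raises if model or model_config is still None

# Later, e.g. in a fresh process, restore the configuration and rebuild the model from it.
trainer.read_model_config(config_path)
model = trainer.init_model(trainer.model_config)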

TraditionalGanTrainer

Bases: TraditionalTrainer[HivaeConfig, GanHivae]

Source code in vambn/modelling/models/hivae/trainer.py
class TraditionalGanTrainer(TraditionalTrainer[HivaeConfig, GanHivae]):
    def init_model(self, config: HivaeConfig) -> GanHivae:
        """
        Initializes the GAN-HIVAE model.

        Args:
            config (HivaeConfig): The configuration for the GAN-HIVAE model.

        Returns:
            GanHivae: The initialized GAN-HIVAE model.
        """
        return GanHivae(
            variable_types=config.variable_types,
            input_dim=config.input_dim,
            dim_s=config.dim_s,
            dim_y=config.dim_y,
            dim_z=config.dim_z,
            module_name=config.name,
            mtl_method=config.mtl_methods,
            use_imputation_layer=config.use_imputation_layer,
            individual_model=True,
            n_layers=config.n_layers,
            num_timepoints=config.num_timepoints,
            noise_size=10,
        )

init_model(config)

Initializes the GAN-HIVAE model.

Parameters:

Name Type Description Default
config HivaeConfig

The configuration for the GAN-HIVAE model.

required

Returns:

Name Type Description
GanHivae GanHivae

The initialized GAN-HIVAE model.

Source code in vambn/modelling/models/hivae/trainer.py
def init_model(self, config: HivaeConfig) -> GanHivae:
    """
    Initializes the GAN-HIVAE model.

    Args:
        config (HivaeConfig): The configuration for the GAN-HIVAE model.

    Returns:
        GanHivae: The initialized GAN-HIVAE model.
    """
    return GanHivae(
        variable_types=config.variable_types,
        input_dim=config.input_dim,
        dim_s=config.dim_s,
        dim_y=config.dim_y,
        dim_z=config.dim_z,
        module_name=config.name,
        mtl_method=config.mtl_methods,
        use_imputation_layer=config.use_imputation_layer,
        individual_model=True,
        n_layers=config.n_layers,
        num_timepoints=config.num_timepoints,
        noise_size=10,
    )

TraditionalTrainer

Bases: BaseTrainer[GenericHivaeConfig, GenericHivaeModel, HivaeOutput, 'TraditionalTrainer', HivaeEncoding]

Source code in vambn/modelling/models/hivae/trainer.py
class TraditionalTrainer(
    BaseTrainer[
        GenericHivaeConfig,
        GenericHivaeModel,
        HivaeOutput,
        "TraditionalTrainer",
        HivaeEncoding,
    ]
):
    def __init__(
        self,
        dataset: VambnDataset,
        config: PipelineConfig,
        workers: int,
        checkpoint_path: Path,
        module_name: str | None = None,
        experiment_name: str | None = None,
        force_cpu: bool = False,
    ):
        """
        Initialize the TraditionalTrainer class.

        Args:
            dataset: The VambnDataset object.
            config: The PipelineConfig object.
            workers: The number of workers for data loading.
            checkpoint_path: The path to save checkpoints.
            module_name: The name of the module to train. Must be one of the dataset's modules.
            experiment_name: The name of the experiment.
            force_cpu: Whether to force CPU usage.

        Raises:
            ValueError: If module_name is None or not part of the dataset.
        """
        super().__init__(
            dataset=dataset,
            config=config,
            workers=workers,
            checkpoint_path=checkpoint_path,
            module_name=module_name,
            experiment_name=experiment_name,
            force_cpu=force_cpu,
        )
        self.type = "traditional"
        if module_name is None:
            raise ValueError("Module name must be specified")
        elif module_name not in self.dataset.module_names:
            raise ValueError(
                f"Module name {module_name} not in dataset, available modules: {self.dataset.module_names}"
            )

    def process_encodings(self, predictions: HivaeOutput) -> HivaeEncoding:
        """Convert the model predictions into a HivaeEncoding for this module."""
        return HivaeEncoding(
            z=predictions.enc_z,
            s=predictions.enc_s,
            module=self.module_name,
            samples=predictions.samples,
            subjid=self.dataset.subj,
        )

    def init_model(self, config: GenericHivaeConfig) -> GenericHivaeModel:
        """
        Initialize the HIVAE model for this module.

        Returns an LstmHivae for longitudinal configurations (more than one
        timepoint) and a plain Hivae otherwise.

        Args:
            config: The GenericHivaeConfig object.

        Returns:
            The initialized model.
        """
        if config.is_longitudinal:
            return LstmHivae(
                n_layers=config.n_layers,
                num_timepoints=config.num_timepoints,
                dim_s=config.dim_s,
                dim_y=config.dim_y,
                dim_z=config.dim_z,
                input_dim=config.input_dim,
                module_name=self.module_name,
                mtl_method=config.mtl_methods,
                use_imputation_layer=config.use_imputation_layer,
                variable_types=config.variable_types,
                individual_model=True,
            )
        else:
            return Hivae(
                variable_types=config.variable_types,
                input_dim=config.input_dim,
                dim_s=config.dim_s,
                dim_y=config.dim_y,
                dim_z=config.dim_z,
                module_name=self.module_name,
                mtl_method=config.mtl_methods,
                use_imputation_layer=config.use_imputation_layer,
                individual_model=True,
            )

    def hyperparameters(self, trial: optuna.Trial) -> Hyperparameters:
        """
        Suggest hyperparameters for the model.

        Args:
            trial (optuna.Trial): Trial instance.

        Returns:
            Hyperparameters: The suggested hyperparameters.
        """
        opt = self.config.optimization
        dim_s = trial.suggest_int(
            "hidden_dim_s",
            opt.s_dim_lower,
            opt.s_dim_upper,
            opt.s_dim_step,
        )
        dim_y = trial.suggest_int(
            "hidden_dim_y",
            opt.y_dim_lower,
            opt.y_dim_upper,
            opt.y_dim_step,
        )
        dim_z = trial.suggest_int(
            "hidden_dim_z",
            opt.latent_dim_lower,
            opt.latent_dim_upper,
            opt.latent_dim_step,
        )
        learning_rate = trial.suggest_float(
            "learning_rate",
            opt.learning_rate_lower,
            opt.learning_rate_upper,
            log=True,
        )
        batch_size_n = trial.suggest_int(
            "batch_size_n", opt.batch_size_lower_n, opt.batch_size_upper_n
        )
        batch_size = 2**batch_size_n
        # epochs = trial.suggest_int(
        #     "epochs",
        #     low=opt.epoch_lower,
        #     high=opt.epoch_upper,
        #     step=opt.epoch_step,
        # )

        if self.dataset.is_longitudinal:
            lstm_layers = trial.suggest_int(
                "lstm_layers",
                opt.lstm_layers_lower,
                opt.lstm_layers_upper,
                opt.lstm_layers_step,
            )
        else:
            lstm_layers = 1
        if self.use_mtl:
            mtl_string = trial.suggest_categorical(
                "mtl_methods",
                [
                    "gradnorm",
                    "graddrop",
                    "gradnorm,graddrop",
                ],
            )
            mtl_methods = (
                tuple(mtl_string.split(","))
                if "," in mtl_string
                else tuple([mtl_string])
            )
        else:
            mtl_methods = ("identity",)
        return Hyperparameters(
            dim_s=dim_s,
            dim_y=dim_y,
            dim_z=dim_z,
            dropout=0.1,
            batch_size=batch_size,
            learning_rate=learning_rate,
            epochs=opt.max_epochs,
            lstm_layers=lstm_layers,
            mtl_methods=mtl_methods,
        )

    def train(self, best_parameters: Hyperparameters) -> GenericHivaeModel:
        """
        Trains the model using the best hyperparameters.

        Args:
            best_parameters (Hyperparameters): The best hyperparameters obtained from optimization.

        Returns:
            GenericHivaeModel: The trained model.
        """
        ref_module = self.dataset.get_modules(self.module_name)[0]
        whole_dataloader = self.get_dataloader(
            self.dataset, best_parameters.batch_size, shuffle=True
        )
        self.model_config = HivaeConfig(
            name=self.module_name,
            variable_types=ref_module.variable_types,
            dim_s=best_parameters.dim_s,
            dim_y=best_parameters.dim_y,
            dim_z=best_parameters.dim_z,
            mtl_methods=best_parameters.mtl_methods,
            use_imputation_layer=self.config.training.use_imputation_layer,
            dropout=best_parameters.dropout,
            n_layers=best_parameters.lstm_layers,
            num_timepoints=self.dataset.visits_per_module[self.module_name],
        )
        self.model = self.init_model(self.model_config).to(self.device)
        self.model.device = self.device
        self.optimize_model()
        with mlflow.start_run(run_name=f"{self.run_base}_final_fit"):
            mlflow.log_params(best_parameters.__dict__)
            self.model.fit(
                train_dataloader=whole_dataloader,
                num_epochs=best_parameters.epochs,
                learning_rate=best_parameters.learning_rate,
            )
        return self.model

    def _objective(self, trial: optuna.Trial) -> float | Tuple[float, float]:
        """
        Objective function for the optimization process.

        Args:
            trial (optuna.Trial): The trial object for hyperparameter optimization.

        Returns:
            float or Tuple[float, float]: The loss value or a tuple of loss values.
        """
        trial_params = self.hyperparameters(trial)
        logger.info(f"Trial parameters: {trial_params}")
        fold_loss = []
        rel_corr_loss = []
        auc_metric = []
        ref_module = self.dataset.get_modules(self.module_name)[0]
        n_epochs = []

        for i, (train_data, val_data) in enumerate(
            self.cv_generator(self.config.optimization.folds)
        ):
            train_dataloader = self.get_dataloader(
                train_data, trial_params.batch_size, shuffle=True
            )
            val_dataloader = self.get_dataloader(
                val_data, trial_params.batch_size, shuffle=False
            )

            model_config = HivaeConfig(
                name=self.module_name,
                variable_types=ref_module.variable_types,
                dim_s=trial_params.dim_s,
                dim_y=trial_params.dim_y,
                dim_z=trial_params.dim_z,
                mtl_methods=trial_params.mtl_methods,
                use_imputation_layer=self.config.training.use_imputation_layer,
                dropout=trial_params.dropout,
                n_layers=trial_params.lstm_layers,
                num_timepoints=self.dataset.visits_per_module[self.module_name],
            )
            raw_model = self.init_model(model_config)
            raw_model.device = self.device
            model = self.optimize_model(raw_model)
            model.to(self.device)
            with mlflow.start_run(
                run_name=f"{self.run_base}_T{trial._trial_id}_F{i}",
                nested=True,
            ):
                mlflow.log_params(trial_params.__dict__)
                start = time.time()
                try:
                    fit_loss, fit_epoch = model.fit(
                        train_dataloader=train_dataloader,
                        val_dataloader=val_dataloader,
                        num_epochs=trial_params.epochs,
                        learning_rate=trial_params.learning_rate,
                    )
                except ValueError:
                    raise optuna.TrialPruned(
                        "Trial pruned due to error during training. Likely due to unsuitable hyperparameters"
                    )

                if ((time.time() - start) / 3600) > 2:
                    logger.warning(
                        f"Trial pruned due to very long execution ({(time.time() - start) / 3600}h at fold {i})"
                    )
                    raise optuna.TrialPruned()
                output = model.predict(val_dataloader)

                mlflow.log_metric("Final loss", output.avg_loss)
                logger.info(f"Fold {i} loss: {output.avg_loss}")
                fold_loss.append(output.avg_loss)
                n_epochs.append(fit_epoch)

                if (
                    self.config.optimization.use_relative_correlation_error_for_optimization
                    or self.config.optimization.use_auc_for_optimization
                ):
                    # Calculate the relative correlation loss
                    orig_data = val_data.to_pandas(module_name=self.module_name)
                    # convert decoded samples into a pandas dataframe
                    decoded_samples = output.samples
                    assert decoded_samples.shape[-1] == (
                        orig_data.shape[-1] - 2
                    )
                    column_names = [
                        re.sub(r"_VIS[0-9]+", "", c)
                        for c in self.dataset.get_modules(self.module_name)[
                            0
                        ].columns
                    ]
                    if decoded_samples.ndim == 2:
                        synthetic_data = pd.DataFrame(
                            decoded_samples.numpy(),
                            columns=column_names,
                        )
                        synthetic_data["SUBJID"] = val_data.subj
                        synthetic_data["VISIT"] = 1
                    else:
                        synthetic_data = pd.concat(
                            [
                                pd.DataFrame(
                                    decoded_samples[:, i, :].numpy(),
                                    columns=column_names,
                                )
                                for i in range(decoded_samples.shape[1])
                            ]
                        )
                        synthetic_data["SUBJID"] = (
                            val_data.subj
                            * self.dataset.visits_per_module[self.module_name]
                        )
                        synthetic_data["VISIT"] = np.repeat(
                            np.arange(1, decoded_samples.shape[1] + 1),
                            len(val_data),
                        )

                    synthetic_data_filtered = synthetic_data.loc[
                        :, orig_data.columns
                    ].drop(columns=["SUBJID", "VISIT"])
                    orig_data_filtered = orig_data.drop(
                        columns=["SUBJID", "VISIT"]
                    )

                    fold_rel_corr_loss, m1, m2 = RelativeCorrelation.error(
                        real=orig_data_filtered,
                        synthetic=synthetic_data_filtered,
                    )

                    auc = get_auc(orig_data_filtered, synthetic_data_filtered)
                    if auc < 0.5:
                        auc = 1 - auc
                    auc_quality = max(math.floor((1 - auc) * 200), 1)
                    mlflow.log_metric("AUC", auc)
                    mlflow.log_metric("AUC quality", auc_quality)
                    auc_metric.append(auc_quality)

                    mlflow.log_metric(
                        "Relative correlation loss", fold_rel_corr_loss
                    )
                    rel_corr_loss.append(fold_rel_corr_loss)

        loss = np.mean(fold_loss)
        avg_auc_metric = np.mean(auc_metric)

        # Get the n_epochs that correspond to the best loss
        best_epoch = n_epochs[np.argmin(fold_loss)]
        mlflow.log_metric("Best epoch", best_epoch)
        # Log the value in the trial
        trial.set_user_attr("best_epoch", best_epoch)
        trial.set_user_attr("n_epochs", n_epochs)

        if self.config.optimization.use_relative_correlation_error_for_optimization:
            rel_corr_loss = np.mean(rel_corr_loss)
            logger.info(f"Trial loss: {loss}; rel_corr_loss: {rel_corr_loss}")
            return loss, rel_corr_loss
        elif self.config.optimization.use_auc_for_optimization:
            return avg_auc_metric
        else:
            logger.info(f"Trial loss: {loss}")
            return loss

    def save_model_config(self, path: Path):
        """
        Saves the model configuration to a file.

        Args:
            path (Path): The path to save the model configuration.
        """
        if self.model is None:
            raise Exception("Model should not be none")

        if self.model_config is None:
            raise Exception("Model config should not be none")

        with path.open("wb") as f:
            dill.dump(self.model_config, f)

    def read_model_config(self, path: Path):
        """
        Reads the model configuration from a file.

        Args:
            path (Path): The path to read the model configuration from.
        """
        with path.open("rb") as f:
            self.model_config = dill.load(f)
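
A hedged sketch of using the per-module trainer; dataset, config, and best_parameters are placeholders, and "medical_history" stands in for one of dataset.module_names:

from pathlib import Path

from vambn.modelling.models.hivae.trainer import TraditionalTrainer

trainer = TraditionalTrainer(
    dataset=dataset,                # assumed VambnDataset
    config=config,                  # assumed PipelineConfig
    workers=2,
    checkpoint_path=Path("checkpoints/medical_history"),
    module_name="medical_history",  # must be one of dataset.module_names
    experiment_name="hivae_medical_history",
)

# best_parameters is a Hyperparameters object from a finished Optuna search.
model = trainer.train(best_parameters)
trainer.save_model_config(Path("checkpoints/medical_history/model_config.pkl"))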

hyperparameters(trial)

Suggest hyperparameters for the model.

Parameters:

Name Type Description Default
trial Trial

Trial instance

required

Returns:

Type Description
Hyperparameters

The suggested hyperparameters.

Source code in vambn/modelling/models/hivae/trainer.py
def hyperparameters(self, trial: optuna.Trial) -> Hyperparameters:
    """
    Suggest hyperparameters for the model.

    Args:
        trial (optuna.Trial): Trial instance.

    Returns:
        Hyperparameters: The suggested hyperparameters.
    """
    opt = self.config.optimization
    dim_s = trial.suggest_int(
        "hidden_dim_s",
        opt.s_dim_lower,
        opt.s_dim_upper,
        opt.s_dim_step,
    )
    dim_y = trial.suggest_int(
        "hidden_dim_y",
        opt.y_dim_lower,
        opt.y_dim_upper,
        opt.y_dim_step,
    )
    dim_z = trial.suggest_int(
        "hidden_dim_z",
        opt.latent_dim_lower,
        opt.latent_dim_upper,
        opt.latent_dim_step,
    )
    learning_rate = trial.suggest_float(
        "learning_rate",
        opt.learning_rate_lower,
        opt.learning_rate_upper,
        log=True,
    )
    batch_size_n = trial.suggest_int(
        "batch_size_n", opt.batch_size_lower_n, opt.batch_size_upper_n
    )
    batch_size = 2**batch_size_n
    # epochs = trial.suggest_int(
    #     "epochs",
    #     low=opt.epoch_lower,
    #     high=opt.epoch_upper,
    #     step=opt.epoch_step,
    # )

    if self.dataset.is_longitudinal:
        lstm_layers = trial.suggest_int(
            "lstm_layers",
            opt.lstm_layers_lower,
            opt.lstm_layers_upper,
            opt.lstm_layers_step,
        )
    else:
        lstm_layers = 1
    if self.use_mtl:
        mtl_string = trial.suggest_categorical(
            "mtl_methods",
            [
                "gradnorm",
                "graddrop",
                "gradnorm,graddrop",
            ],
        )
        mtl_methods = (
            tuple(mtl_string.split(","))
            if "," in mtl_string
            else tuple([mtl_string])
        )
    else:
        mtl_methods = ("identity",)
    return Hyperparameters(
        dim_s=dim_s,
        dim_y=dim_y,
        dim_z=dim_z,
        dropout=0.1,
        batch_size=batch_size,
        learning_rate=learning_rate,
        epochs=opt.max_epochs,
        lstm_layers=lstm_layers,
        mtl_methods=mtl_methods,
    )

read_model_config(path)

Reads the model configuration from a file.

Parameters:

Name Type Description Default
path Path

The path to read the model configuration from.

required
Source code in vambn/modelling/models/hivae/trainer.py
def read_model_config(self, path: Path):
    """
    Reads the model configuration from a file.

    Args:
        path (Path): The path to read the model configuration from.
    """
    with path.open("rb") as f:
        self.model_config = dill.load(f)

save_model_config(path)

Saves the model configuration to a file.

Parameters:

Name Type Description Default
path Path

The path to save the model configuration.

required
Source code in vambn/modelling/models/hivae/trainer.py
def save_model_config(self, path: Path):
    """
    Saves the model configuration to a file.

    Args:
        path (Path): The path to save the model configuration.
    """
    if self.model is None:
        raise Exception("Model should not be none")

    if self.model_config is None:
        raise Exception("Model config should not be none")

    with path.open("wb") as f:
        dill.dump(self.model_config, f)

train(best_parameters)

Trains the model using the best hyperparameters.

Parameters:

Name Type Description Default
best_parameters Hyperparameters

The best hyperparameters obtained from optimization.

required

Returns:

Name Type Description
GenericHivaeModel GenericHivaeModel

The trained model.

Source code in vambn/modelling/models/hivae/trainer.py
def train(self, best_parameters: Hyperparameters) -> GenericHivaeModel:
    """
    Trains the model using the best hyperparameters.

    Args:
        best_parameters (Hyperparameters): The best hyperparameters obtained from optimization.

    Returns:
        GenericHivaeModel: The trained model.
    """
    ref_module = self.dataset.get_modules(self.module_name)[0]
    whole_dataloader = self.get_dataloader(
        self.dataset, best_parameters.batch_size, shuffle=True
    )
    self.model_config = HivaeConfig(
        name=self.module_name,
        variable_types=ref_module.variable_types,
        dim_s=best_parameters.dim_s,
        dim_y=best_parameters.dim_y,
        dim_z=best_parameters.dim_z,
        mtl_methods=best_parameters.mtl_methods,
        use_imputation_layer=self.config.training.use_imputation_layer,
        dropout=best_parameters.dropout,
        n_layers=best_parameters.lstm_layers,
        num_timepoints=self.dataset.visits_per_module[self.module_name],
    )
    self.model = self.init_model(self.model_config).to(self.device)
    self.model.device = self.device
    self.optimize_model()
    with mlflow.start_run(run_name=f"{self.run_base}_final_fit"):
        mlflow.log_params(best_parameters.__dict__)
        self.model.fit(
            train_dataloader=whole_dataloader,
            num_epochs=best_parameters.epochs,
            learning_rate=best_parameters.learning_rate,
        )
    return self.model

timed(fn)

Time a function call for benchmarking purposes using CUDA events.

Parameters:

Name Type Description Default
fn Callable

Function to be timed.

required

Returns:

Type Description
Tuple[Any, float]

Tuple[Any, float]: Result of the function and the elapsed time in seconds.

Source code in vambn/modelling/models/hivae/trainer.py
def timed(fn: Callable) -> Tuple[Any, float]:
    """
    Time a function call for benchmarking purposes using CUDA events
    (requires a CUDA-capable device).

    Args:
        fn (Callable): Function to be timed.

    Returns:
        Tuple[Any, float]: Result of the function and the elapsed time in seconds.
    """

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000
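
A usage sketch; timed relies on torch.cuda.Event, so it only runs on a machine with a CUDA device, and the matrix sizes below are arbitrary:

import torch

from vambn.modelling.models.hivae.trainer import timed

if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    # Returns the result of the call and the elapsed time in seconds.
    result, seconds = timed(lambda: a @ b)
    print(f"matmul took {seconds:.4f}s")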