Data

dataclasses

VariableType dataclass

Dataclass for storing type information of the input variables

Source code in vambn/data/dataclasses.py
@dataclass
class VariableType:
    """Dataclass for storing type information of the input variables"""

    name: str = "default"
    data_type: str = "default"
    n_parameters: int = -1
    input_dim: int = 1
    scaler: Optional[StandardScaler] = None

    def __eq__(self, __value: "VariableType") -> bool:
        """
        Test if two VariableTypes are equal

        Args:
            __value (VariableType): Second VariableType object

        Returns:
            bool: True or false
        """
        if not isinstance(__value, VariableType):
            return False
        return (
            self.name == __value.name
            and self.data_type == __value.data_type
            and self.n_parameters == __value.n_parameters
            and self.input_dim == __value.input_dim
        )

    def __post_init__(self):
        if isinstance(self.input_dim, float):
            raise TypeError("input_dim must be int, not float")
        if isinstance(self.n_parameters, float):
            raise TypeError("n_parameters must be int, not float")

    @typeguard.typechecked
    def reverse_scale(self, x: torch.Tensor) -> torch.Tensor:
        """
        Use the variable's scaler to invert the transformation so that the original input format is restored.

        Args:
            x (torch.tensor): The transformed input tensor.

        Returns:
            torch.Tensor: Inverse transformed output
        """
        if self.scaler is not None:
            # x can be of shape (time, batch) or (batch)
            orig_shape = x.shape
            dev = x.device
            if x.ndim == 2:
                # time x batch
                transformed = torch.from_numpy(
                    self.scaler.inverse_transform(x.cpu().numpy())
                ).to(dev)
                if orig_shape != transformed.shape:
                    raise RuntimeError(
                        f"Shape of transformed data is {transformed.shape}, expected {orig_shape}"
                    )
                return transformed
            elif x.ndim == 1:
                # batch
                x = x.view(-1, 1)
                transformed = (
                    torch.from_numpy(
                        self.scaler.inverse_transform(x.cpu().numpy())
                    )
                    .to(dev)
                    .view(-1)
                )
                if orig_shape != transformed.shape:
                    raise RuntimeError(
                        f"Shape of transformed data is {transformed.shape}, expected {orig_shape}"
                    )
                return transformed
            else:
                raise RuntimeError(
                    f"Input data must have 1 or 2 dimensions, got {x.ndim}"
                )
        else:
            # raise RuntimeError("Scaler is not defined")
            return x
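
Example (illustrative, not part of the library source). A minimal sketch of how a VariableType carries a fitted scaler and how reverse_scale maps standardized values back to the original scale; the field values ("real", n_parameters=2) are assumptions made for the example.

import numpy as np
import torch
from sklearn.preprocessing import StandardScaler

from vambn.data.dataclasses import VariableType

# Hypothetical column: fit a StandardScaler on raw values of a single feature.
raw = np.array([[1.0], [2.0], [3.0], [4.0]])
scaler = StandardScaler().fit(raw)

vt = VariableType(
    name="age", data_type="real", n_parameters=2, input_dim=1, scaler=scaler
)

# A 1-D batch of standardized values is mapped back to the original scale.
scaled = torch.from_numpy(scaler.transform(raw)).float().view(-1)
restored = vt.reverse_scale(scaled)
print(restored)  # approximately tensor([1., 2., 3., 4.])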

__eq__(__value)

Test if two VariableTypes are equal

Parameters:

    __value (VariableType): Second VariableType object. Required.

Returns:

    bool: True or false

Source code in vambn/data/dataclasses.py
def __eq__(self, __value: "VariableType") -> bool:
    """
    Test if two VariableTypes are equal

    Args:
        __value (VariableType): Second VariableType object

    Returns:
        bool: True or false
    """
    if not isinstance(__value, VariableType):
        return False
    return (
        self.name == __value.name
        and self.data_type == __value.data_type
        and self.n_parameters == __value.n_parameters
        and self.input_dim == __value.input_dim
    )

reverse_scale(x)

Use the variable's scaler to invert the transformation so that the original input format is restored.

Parameters:

    x (torch.Tensor): The transformed input tensor. Required.

Returns:

    torch.Tensor: Inverse transformed output

Source code in vambn/data/dataclasses.py
@typeguard.typechecked
def reverse_scale(self, x: torch.Tensor) -> torch.Tensor:
    """
    Use the variable's scaler to invert the transformation so that the original input format is restored.

    Args:
        x (torch.tensor): The transformed input tensor.

    Returns:
        torch.Tensor: Inverse transformed output
    """
    if self.scaler is not None:
        # x can be of shape (time, batch) or (batch)
        orig_shape = x.shape
        dev = x.device
        if x.ndim == 2:
            # time x batch
            transformed = torch.from_numpy(
                self.scaler.inverse_transform(x.cpu().numpy())
            ).to(dev)
            if orig_shape != transformed.shape:
                raise RuntimeError(
                    f"Shape of transformed data is {transformed.shape}, expected {orig_shape}"
                )
            return transformed
        elif x.ndim == 1:
            # batch
            x = x.view(-1, 1)
            transformed = (
                torch.from_numpy(
                    self.scaler.inverse_transform(x.cpu().numpy())
                )
                .to(dev)
                .view(-1)
            )
            if orig_shape != transformed.shape:
                raise RuntimeError(
                    f"Shape of transformed data is {transformed.shape}, expected {orig_shape}"
                )
            return transformed
        else:
            raise RuntimeError(
                f"Input data must have 1 or 2 dimensions, got {x.ndim}"
            )
    else:
        # raise RuntimeError("Scaler is not defined")
        return x

check_equal_types(a, b)

Check if two lists of variable types are equal

Parameters:

    a (List[VariableType]): First list of variable types. Required.
    b (List[VariableType]): Second list of variable types. Required.

Returns:

    bool: True if equal, False otherwise

Source code in vambn/data/dataclasses.py
def check_equal_types(a: List[VariableType], b: List[VariableType]) -> bool:
    """
    Check if two lists of variable types are equal

    Args:
        a (List[VariableType]): First list of variable types
        b (List[VariableType]): Second list of variable types

    Returns:
        bool: True if equal, False otherwise
    """
    if len(a) != len(b):
        return False

    for x, y in zip(a, b):
        if x != y:
            return False

    return True

get_input_dim(types)

Get the input dimension of a list of variable types

Parameters:

    types (List[VariableType]): List of variable types. Required.

Returns:

    int: Sum of input dimensions

Source code in vambn/data/dataclasses.py
def get_input_dim(types: List[VariableType]) -> int:
    """
    Get the input dimension of a list of variable types

    Args:
        types (List[VariableType]): List of variable types

    Returns:
        int: Sum of input dimensions
    """
    return int(sum([x.input_dim for x in types]))
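
Example (illustrative). A short sketch of how these helpers compose: check_equal_types compares two type lists element-wise via VariableType.__eq__, and get_input_dim sums the input_dim fields. The data_type strings used here are assumptions.

from vambn.data.dataclasses import VariableType, check_equal_types, get_input_dim

a = [
    VariableType(name="age", data_type="real", n_parameters=2, input_dim=1),
    VariableType(name="diagnosis", data_type="cat", n_parameters=3, input_dim=3),
]
b = [
    VariableType(name="age", data_type="real", n_parameters=2, input_dim=1),
    VariableType(name="diagnosis", data_type="cat", n_parameters=3, input_dim=3),
]

print(check_equal_types(a, b))  # True: element-wise comparison succeeds
print(get_input_dim(a))         # 4: sum of the individual input_dim values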

datasets

IterDataset

Bases: Dataset

Source code in vambn/data/datasets.py
class IterDataset(Dataset):
    def __init__(
        self,
        data: torch.Tensor,
        missing_mask: torch.Tensor,
        types: List[VariableType],
    ) -> None:
        """
        Initialize the IterDataset.

        Args:
            data (torch.Tensor): Tensor containing the data.
            missing_mask (torch.Tensor): Tensor with the corresponding missing mask (0=missing, 1=available).
            types (List[VariableType]): List of VariableType, each specifying dtype, ndim, and nclasses.

        Raises:
            ValueError: If data or missing_mask is not 2-dimensional, or if data contains NaN values.
        """
        if data.ndim != 2 or missing_mask.ndim != 2:
            raise ValueError(
                "Both data and missing_mask tensors must be 2-dimensional"
            )

        # Vector with Samples x Features
        self.data = data
        self.missing_mask = missing_mask
        self.types = types
        self.num_visits = 1

        if torch.isnan(self.data).any():
            raise ValueError("Data contains NaN values, which are not allowed.")

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
        """
        Get the data and the missing mask for a specific index.

        Args:
            idx (int): Index of the sample.

        Returns:
            Tuple[Tensor, Tensor]: Data and missing mask for the sample.
        """
        data = self.data[idx, :]
        mask = self.missing_mask[idx, :]

        return data, mask

    def __len__(self) -> int:
        """
        Get the number of samples in the dataset.

        Returns:
            int: Number of samples.
        """
        return len(self.data)

    def subset(self, idx: List[int]) -> "IterDataset":
        """
        Create a subset of the dataset.

        Args:
            idx (List[int]): Indices of the samples to be selected.

        Returns:
            IterDataset: Subset of the dataset.
        """
        return IterDataset(
            self.data[idx, :], self.missing_mask[idx, :], self.types
        )

    @property
    def ndim(self) -> int:
        """
        Input dimensionality of the dataset.

        Returns:
            int: Dimensionality of the dataset.
        """
        return get_input_dim(self.types)

    def to(self, device: torch.device) -> "IterDataset":
        """
        Move the dataset to a specified device.

        Args:
            device (torch.device): The device to move the dataset to.

        Returns:
            IterDataset: The dataset moved to the specified device.
        """
        self.data = self.data.to(device)
        self.missing_mask = self.missing_mask.to(device)
        return self
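
Example (illustrative, synthetic tensors). IterDataset expects 2-dimensional data and mask tensors (samples x features) and plugs directly into a standard PyTorch DataLoader; the variable types here are assumptions.

import torch
from torch.utils.data import DataLoader

from vambn.data.dataclasses import VariableType
from vambn.data.datasets import IterDataset

# 6 samples x 2 features; the mask marks observed entries with 1 and missing with 0.
data = torch.randn(6, 2)
mask = torch.ones(6, 2)
types = [
    VariableType(name=f"feat{i}", data_type="real", n_parameters=2, input_dim=1)
    for i in range(2)
]

ds = IterDataset(data, mask, types)
print(len(ds), ds.ndim)  # 6 2

for batch_x, batch_mask in DataLoader(ds, batch_size=3):
    print(batch_x.shape, batch_mask.shape)  # torch.Size([3, 2]) each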

ndim: int property

Input dimensionality of the dataset.

Returns:

    int: Dimensionality of the dataset.

__getitem__(idx)

Get the data and the missing mask for a specific index.

Parameters:

    idx (int): Index of the sample. Required.

Returns:

    Tuple[Tensor, Tensor]: Data and missing mask for the sample.

Source code in vambn/data/datasets.py
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
    """
    Get the data and the missing mask for a specific index.

    Args:
        idx (int): Index of the sample.

    Returns:
        Tuple[Tensor, Tensor]: Data and missing mask for the sample.
    """
    data = self.data[idx, :]
    mask = self.missing_mask[idx, :]

    return data, mask

__init__(data, missing_mask, types)

Initialize the IterDataset.

Parameters:

    data (torch.Tensor): Tensor containing the data. Required.
    missing_mask (torch.Tensor): Tensor with the corresponding missing mask (0=missing, 1=available). Required.
    types (List[VariableType]): List of VariableType, each specifying dtype, ndim, and nclasses. Required.

Raises:

    ValueError: If data or missing_mask is not 2-dimensional, or if data contains NaN values.

Source code in vambn/data/datasets.py
def __init__(
    self,
    data: torch.Tensor,
    missing_mask: torch.Tensor,
    types: List[VariableType],
) -> None:
    """
    Initialize the IterDataset.

    Args:
        data (torch.Tensor): Tensor containing the data.
        missing_mask (torch.Tensor): Tensor with the corresponding missing mask (0=missing, 1=available).
        types (List[VariableType]): List of VariableType, each specifying dtype, ndim, and nclasses.

    Raises:
        ValueError: If data or missing_mask is not 2-dimensional, or if data contains NaN values.
    """
    if data.ndim != 2 or missing_mask.ndim != 2:
        raise ValueError(
            "Both data and missing_mask tensors must be 2-dimensional"
        )

    # Vector with Samples x Features
    self.data = data
    self.missing_mask = missing_mask
    self.types = types
    self.num_visits = 1

    if torch.isnan(self.data).any():
        raise ValueError("Data contains NaN values, which are not allowed.")

__len__()

Get the number of samples in the dataset.

Returns:

    int: Number of samples.

Source code in vambn/data/datasets.py
def __len__(self) -> int:
    """
    Get the number of samples in the dataset.

    Returns:
        int: Number of samples.
    """
    return len(self.data)

subset(idx)

Create a subset of the dataset.

Parameters:

    idx (List[int]): Indices of the samples to be selected. Required.

Returns:

    IterDataset: Subset of the dataset.

Source code in vambn/data/datasets.py
def subset(self, idx: List[int]) -> "IterDataset":
    """
    Create a subset of the dataset.

    Args:
        idx (List[int]): Indices of the samples to be selected.

    Returns:
        IterDataset: Subset of the dataset.
    """
    return IterDataset(
        self.data[idx, :], self.missing_mask[idx, :], self.types
    )

to(device)

Move the dataset to a specified device.

Parameters:

    device (torch.device): The device to move the dataset to. Required.

Returns:

    IterDataset: The dataset moved to the specified device.

Source code in vambn/data/datasets.py
def to(self, device: torch.device) -> "IterDataset":
    """
    Move the dataset to a specified device.

    Args:
        device (torch.device): The device to move the dataset to.

    Returns:
        IterDataset: The dataset moved to the specified device.
    """
    self.data = self.data.to(device)
    self.missing_mask = self.missing_mask.to(device)
    return self

LongitudinalDataset

Bases: Dataset

Dataset for longitudinal data, where each sample consists of multiple visits/timepoints.

Attributes:

    data (torch.Tensor): Tensor containing the data.
    missing_mask (torch.Tensor): Tensor indicating missing data (0=missing, 1=available).
    types (List[VariableType]): List of VariableType objects containing dtype, ndim, and nclasses.

Source code in vambn/data/datasets.py
class LongitudinalDataset(Dataset):
    """
    Dataset for longitudinal data, where each sample consists of multiple visits/timepoints.

    Attributes:
        data (torch.Tensor): Tensor containing the data.
        missing_mask (torch.Tensor): Tensor indicating missing data (0=missing, 1=available).
        types (List[VariableType]): List of VariableType objects containing dtype, ndim, and nclasses.
    """

    def __init__(
        self,
        data: torch.Tensor,
        missing_mask: torch.Tensor,
        types: List[VariableType],
    ) -> None:
        """
        Initialize the LongitudinalDataset.

        Args:
            data (torch.Tensor): Tensor containing the data.
            missing_mask (torch.Tensor): Tensor indicating missing data (0=missing, 1=available).
            types (List[VariableType]): List of VariableType objects containing dtype, ndim, and nclasses.

        Raises:
            ValueError: If `data` or `missing_mask` is not a 3-dimensional tensor.
            ValueError: If `data` contains NaN values.
        """
        if data.ndim != 3 or missing_mask.ndim != 3:
            raise ValueError(
                "Both data and missing_mask tensors must be 3-dimensional"
            )

        # Array with Time x Samples x Features
        self.data = data
        self.missing_mask = missing_mask
        self.types = types

        if torch.isnan(self.data).any():
            raise ValueError("Data contains NaN values, which are not allowed.")

    def to(self, device: torch.device) -> "LongitudinalDataset":
        """
        Move the dataset to the specified device.

        Args:
            device (torch.device): The device to which the data and mask should be moved.

        Returns:
            LongitudinalDataset: The dataset on the specified device.
        """
        self.data = self.data.to(device)
        self.missing_mask = self.missing_mask.to(device)
        return self

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
        """
        Get the longitudinal data and the missing mask for a specific index.

        Args:
            idx (int): Index of the sample.

        Returns:
            Tuple[Tensor, Tensor]: 3D tensor with the data and 3D tensor with the missing mask for the sample.
        """
        s_data = self.data[:, idx, :]
        s_mask = self.missing_mask[:, idx, :]

        return s_data, s_mask

    def __len__(self) -> int:
        """
        Number of samples in the dataset.

        Returns:
            int: Number of samples.
        """
        return self.data.shape[1]

    def subset(self, idx: List[int]) -> "LongitudinalDataset":
        """
        Create a subset of the dataset.

        Args:
            idx (List[int]): Indices of the samples to be selected.

        Returns:
            LongitudinalDataset: Subset of the dataset.
        """
        return LongitudinalDataset(
            self.data[:, idx, :], self.missing_mask[:, idx, :], self.types
        )

    @property
    def ndim(self) -> int:
        """
        Input dimensionality of the dataset.

        Returns:
            int: Input dimensionality of the dataset.
        """
        return get_input_dim(self.types)

    @property
    def num_visits(self) -> int:
        """
        Number of visits/timepoints in the dataset.

        Returns:
            int: Number of visits/timepoints.
        """
        return self.data.shape[0]
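
Example (illustrative, synthetic tensors). The class stores data as Time x Samples x Features, so indexing a subject returns all of that subject's visits; the variable types are assumptions.

import torch

from vambn.data.dataclasses import VariableType
from vambn.data.datasets import LongitudinalDataset

# 3 visits x 5 subjects x 2 features.
data = torch.randn(3, 5, 2)
mask = torch.ones(3, 5, 2)
types = [
    VariableType(name=f"feat{i}", data_type="real", n_parameters=2, input_dim=1)
    for i in range(2)
]

ds = LongitudinalDataset(data, mask, types)
print(len(ds), ds.num_visits, ds.ndim)  # 5 3 2

visits, visit_mask = ds[0]  # all visits of the first subject
print(visits.shape)         # torch.Size([3, 2])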

ndim: int property

Input dimensionality of the dataset.

Returns:

    int: Input dimensionality of the dataset.

num_visits: int property

Number of visits/timepoints in the dataset.

Returns:

    int: Number of visits/timepoints.

__getitem__(idx)

Get the longitudinal data and the missing mask for a specific index.

Parameters:

    idx (int): Index of the sample. Required.

Returns:

    Tuple[Tensor, Tensor]: 3D tensor with the data and 3D tensor with the missing mask for the sample.

Source code in vambn/data/datasets.py
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
    """
    Get the longitudinal data and the missing mask for a specific index.

    Args:
        idx (int): Index of the sample.

    Returns:
        Tuple[Tensor, Tensor]: 3D tensor with the data and 3D tensor with the missing mask for the sample.
    """
    s_data = self.data[:, idx, :]
    s_mask = self.missing_mask[:, idx, :]

    return s_data, s_mask

__init__(data, missing_mask, types)

Initialize the LongitudinalDataset.

Parameters:

    data (torch.Tensor): Tensor containing the data. Required.
    missing_mask (torch.Tensor): Tensor indicating missing data (0=missing, 1=available). Required.
    types (List[VariableType]): List of VariableType objects containing dtype, ndim, and nclasses. Required.

Raises:

    ValueError: If data or missing_mask is not a 3-dimensional tensor.
    ValueError: If data contains NaN values.

Source code in vambn/data/datasets.py
def __init__(
    self,
    data: torch.Tensor,
    missing_mask: torch.Tensor,
    types: List[VariableType],
) -> None:
    """
    Initialize the LongitudinalDataset.

    Args:
        data (torch.Tensor): Tensor containing the data.
        missing_mask (torch.Tensor): Tensor indicating missing data (0=missing, 1=available).
        types (List[VariableType]): List of VariableType objects containing dtype, ndim, and nclasses.

    Raises:
        ValueError: If `data` or `missing_mask` is not a 3-dimensional tensor.
        ValueError: If `data` contains NaN values.
    """
    if data.ndim != 3 or missing_mask.ndim != 3:
        raise ValueError(
            "Both data and missing_mask tensors must be 3-dimensional"
        )

    # Array with Time x Samples x Features
    self.data = data
    self.missing_mask = missing_mask
    self.types = types

    if torch.isnan(self.data).any():
        raise ValueError("Data contains NaN values, which are not allowed.")

__len__()

Number of samples in the dataset.

Returns:

    int: Number of samples.

Source code in vambn/data/datasets.py
def __len__(self) -> int:
    """
    Number of samples in the dataset.

    Returns:
        int: Number of samples.
    """
    return self.data.shape[1]

subset(idx)

Create a subset of the dataset.

Parameters:

    idx (List[int]): Indices of the samples to be selected. Required.

Returns:

    LongitudinalDataset: Subset of the dataset.

Source code in vambn/data/datasets.py
def subset(self, idx: List[int]) -> "LongitudinalDataset":
    """
    Create a subset of the dataset.

    Args:
        idx (List[int]): Indices of the samples to be selected.

    Returns:
        LongitudinalDataset: Subset of the dataset.
    """
    return LongitudinalDataset(
        self.data[:, idx, :], self.missing_mask[:, idx, :], self.types
    )

to(device)

Move the dataset to the specified device.

Parameters:

    device (torch.device): The device to which the data and mask should be moved. Required.

Returns:

    LongitudinalDataset: The dataset on the specified device.

Source code in vambn/data/datasets.py
def to(self, device: torch.device) -> "LongitudinalDataset":
    """
    Move the dataset to the specified device.

    Args:
        device (torch.device): The device to which the data and mask should be moved.

    Returns:
        LongitudinalDataset: The dataset on the specified device.
    """
    self.data = self.data.to(device)
    self.missing_mask = self.missing_mask.to(device)
    return self

ModuleDataset dataclass

A class to represent a module dataset.

Attributes:

    name (str): The name of the dataset.
    data (pd.DataFrame): The data of the dataset.
    mask (pd.DataFrame): The mask for the dataset.
    variable_types (List[VariableType]): The variable types for the dataset.
    scalers (Tuple[Optional[StandardScaler | LogStdScaler]]): The scalers for the dataset.
    columns (List[str]): The columns of the dataset.
    subjects (List[str]): The subjects in the dataset.
    visit_number (int): The visit number. Defaults to 1.
    id_name (Optional[str]): The ID name for the dataset. Defaults to None.
    ndim (int): The number of dimensions. Defaults to -1.
    device (torch.device): The device to use. Defaults to torch.device("cpu").
    move_to_device (bool): Whether to move the data to the specified device. Defaults to True.

Source code in vambn/data/datasets.py
@dataclass
class ModuleDataset:
    """
    A class to represent a module dataset.

    Attributes:
        name (str): The name of the dataset.
        data (pd.DataFrame): The data of the dataset.
        mask (pd.DataFrame): The mask for the dataset.
        variable_types (List[VariableType]): The variable types for the dataset.
        scalers (Tuple[Optional[StandardScaler | LogStdScaler]]): The scalers for the dataset.
        columns (List[str]): The columns of the dataset.
        subjects (List[str]): The subjects in the dataset.
        visit_number (int): The visit number. Defaults to 1.
        id_name (Optional[str]): The ID name for the dataset. Defaults to None.
        ndim (int): The number of dimensions. Defaults to -1.
        device (torch.device): The device to use. Defaults to torch.device("cpu").
        move_to_device (bool): Whether to move the data to the specified device. Defaults to True.
    """

    name: str
    data: pd.DataFrame
    mask: pd.DataFrame
    variable_types: List[VariableType]
    scalers: Tuple[Optional[StandardScaler | LogStdScaler]]

    columns: List[str]
    subjects: List[str]
    visit_number: int = 1
    id_name: Optional[str] = None
    ndim: int = -1
    device: torch.device = torch.device("cpu")
    move_to_device: bool = True

    def __post_init__(self):
        """
        Validate and initialize the dataset attributes after the object is created.

        Raises:
            Exception: If the length of variable_types does not match the number of columns in data.
            Exception: If the length of columns does not match the number of columns in data or variable_types.
            Exception: If the length of subjects does not match the number of rows in data.
        """
        if len(self.variable_types) != self.data.shape[1]:
            raise Exception("Types do not match to the data")

        if self.columns is None:
            logger.warning("No columns found, using column names from data")
            self.columns = self.data.columns.tolist()

        if self.subjects is None or len(self.subjects) != self.data.shape[0]:
            logger.warning("No subjects found, using index as subjects")
            self.subjects = self.data.index.tolist()

        self.ndim = get_input_dim(self.variable_types)
        if self.id_name is None:
            self.id_name = f"{self.name}_VIS{self.visit_number}"

        if len(self.columns) != self.data.shape[1] or len(self.columns) != len(
            self.variable_types
        ):
            raise Exception("Columns do not match to the data")

        if len(self.subjects) != self.data.shape[0]:
            raise Exception(
                f"Subjects do not match to the data. Found shape {self.data.shape[0]} and {len(self.subjects)} subjects."
            )

    def to(self, device: torch.device) -> "ModuleDataset":
        """
        Move the dataset to a specific device.

        Args:
            device (torch.device): Device to be used.

        Returns:
            ModuleDataset: Dataset on the specified device.
        """
        self.device = device
        return self

    @property
    def input_data(self) -> Tensor:
        """
        Get the input data as a tensor with nan values replaced by 0.

        Returns:
            Tensor: Data tensor.
        """
        x = torch.tensor(self.data.values).float()
        if self.move_to_device:
            x = x.to(self.device)
        return x.nan_to_num()

    @property
    def input_mask(self) -> Tensor:
        """
        Get the input mask as a tensor with nan values replaced by 0.

        Returns:
            Tensor: Mask tensor.
        """
        mask = torch.tensor(self.mask.values).float()
        if self.move_to_device:
            mask = mask.to(self.device)
        return mask.nan_to_num()

    def subset(self, idx: List[int] | np.ndarray) -> "ModuleDataset":
        """
        Subset the data and mask by a list of indices.

        Args:
            idx (List[int] | np.ndarray): Indices of the samples to be selected.

        Returns:
            ModuleDataset: New ModuleDataset object with the subsetted data and mask.
        """
        return ModuleDataset(
            name=self.name,
            data=self.data.iloc[idx, :],
            mask=self.mask.iloc[idx, :],
            variable_types=self.variable_types,
            scalers=self.scalers,
            columns=self.columns,
            subjects=self.subjects,
            visit_number=self.visit_number,
            id_name=self.id_name,
            ndim=self.ndim,
        )

    @property
    def pytorch_dataset(self) -> IterDataset:
        """
        Get a PyTorch compatible dataset based on the given data and mask.

        Returns:
            IterDataset: PyTorch compatible dataset.
        """
        return IterDataset(
            self.input_data, self.input_mask, self.variable_types
        )

    def to_pandas(self) -> pd.DataFrame:
        """
        Convert the data and mask to pandas DataFrame.

        Returns:
            pd.DataFrame: Data and mask as pandas DataFrame.
        """
        out_df = self.data.copy()
        out_df.columns = [re.sub(r"_VIS\d+", "", x) for x in self.columns]
        out_df["SUBJID"] = self.subjects
        out_df["VISIT"] = self.visit_number
        return out_df

    def __str__(self) -> str:
        """
        String representation of the ModuleDataset object.

        Returns:
            str: A string representation of the ModuleDataset.
        """
        return f"ModuleData ({self.name}, {self.visit_number})"

input_data: Tensor property

Get the input data as a tensor with nan values replaced by 0.

Returns:

    Tensor: Data tensor.

input_mask: Tensor property

Get the input mask as a tensor with nan values replaced by 0.

Returns:

    Tensor: Mask tensor.

pytorch_dataset: IterDataset property

Get a PyTorch compatible dataset based on the given data and mask.

Returns:

    IterDataset: PyTorch compatible dataset.

__post_init__()

Validate and initialize the dataset attributes after the object is created.

Raises:

    Exception: If the length of variable_types does not match the number of columns in data.
    Exception: If the length of columns does not match the number of columns in data or variable_types.
    Exception: If the length of subjects does not match the number of rows in data.

Source code in vambn/data/datasets.py
def __post_init__(self):
    """
    Validate and initialize the dataset attributes after the object is created.

    Raises:
        Exception: If the length of variable_types does not match the number of columns in data.
        Exception: If the length of columns does not match the number of columns in data or variable_types.
        Exception: If the length of subjects does not match the number of rows in data.
    """
    if len(self.variable_types) != self.data.shape[1]:
        raise Exception("Types do not match to the data")

    if self.columns is None:
        logger.warning("No columns found, using column names from data")
        self.columns = self.data.columns.tolist()

    if self.subjects is None or len(self.subjects) != self.data.shape[0]:
        logger.warning("No subjects found, using index as subjects")
        self.subjects = self.data.index.tolist()

    self.ndim = get_input_dim(self.variable_types)
    if self.id_name is None:
        self.id_name = f"{self.name}_VIS{self.visit_number}"

    if len(self.columns) != self.data.shape[1] or len(self.columns) != len(
        self.variable_types
    ):
        raise Exception("Columns do not match to the data")

    if len(self.subjects) != self.data.shape[0]:
        raise Exception(
            f"Subjects do not match to the data. Found shape {self.data.shape[0]} and {len(self.subjects)} subjects."
        )

__str__()

String representation of the ModuleDataset object.

Returns:

    str: A string representation of the ModuleDataset.

Source code in vambn/data/datasets.py
def __str__(self) -> str:
    """
    String representation of the ModuleDataset object.

    Returns:
        str: A string representation of the ModuleDataset.
    """
    return f"ModuleData ({self.name}, {self.visit_number})"

subset(idx)

Subset the data and mask by a list of indices.

Parameters:

    idx (List[int] | np.ndarray): Indices of the samples to be selected. Required.

Returns:

    ModuleDataset: New ModuleDataset object with the subsetted data and mask.

Source code in vambn/data/datasets.py
def subset(self, idx: List[int] | np.ndarray) -> "ModuleDataset":
    """
    Subset the data and mask by a list of indices.

    Args:
        idx (List[int] | np.ndarray): Indices of the samples to be selected.

    Returns:
        ModuleDataset: New ModuleDataset object with the subsetted data and mask.
    """
    return ModuleDataset(
        name=self.name,
        data=self.data.iloc[idx, :],
        mask=self.mask.iloc[idx, :],
        variable_types=self.variable_types,
        scalers=self.scalers,
        columns=self.columns,
        subjects=self.subjects,
        visit_number=self.visit_number,
        id_name=self.id_name,
        ndim=self.ndim,
    )

to(device)

Move the dataset to a specific device.

Parameters:

    device (torch.device): Device to be used. Required.

Returns:

    ModuleDataset: Dataset on the specified device.

Source code in vambn/data/datasets.py
def to(self, device: torch.device) -> "ModuleDataset":
    """
    Move the dataset to a specific device.

    Args:
        device (torch.device): Device to be used.

    Returns:
        ModuleDataset: Dataset on the specified device.
    """
    self.device = device
    return self

to_pandas()

Convert the data and mask to pandas DataFrame.

Returns:

    pd.DataFrame: Data and mask as pandas DataFrame.

Source code in vambn/data/datasets.py
def to_pandas(self) -> pd.DataFrame:
    """
    Convert the data and mask to pandas DataFrame.

    Returns:
        pd.DataFrame: Data and mask as pandas DataFrame.
    """
    out_df = self.data.copy()
    out_df.columns = [re.sub(r"_VIS\d+", "", x) for x in self.columns]
    out_df["SUBJID"] = self.subjects
    out_df["VISIT"] = self.visit_number
    return out_df

VambnDataset

Bases: Dataset

Dataset for the VAMBN model.

Attributes:

    modules (List[ModuleDataset]): List of module datasets.
    module_names (List[str]): List of unique module names.
    num_patients (int): Number of patients in the dataset.
    visits_per_module (Optional[dict]): Dictionary of visits per module.
    selected_modules (List[str]): List of selected modules.
    selected_visits (List[int]): List of selected visits.
    num_timepoints (int): Number of timepoints.
    subj (List[str]): List of subject IDs.
    device (torch.device): Device to use for tensor operations.

Source code in vambn/data/datasets.py
class VambnDataset(Dataset):
    """
    Dataset for the VAMBN model.

    Attributes:
        modules (List[ModuleDataset]): List of module datasets.
        module_names (List[str]): List of unique module names.
        num_patients (int): Number of patients in the dataset.
        visits_per_module (Optional[dict]): Dictionary of visits per module.
        selected_modules (List[str]): List of selected modules.
        selected_visits (List[int]): List of selected visits.
        num_timepoints (int): Number of timepoints.
        subj (List[str]): List of subject IDs.
        device (torch.device): Device to use for tensor operations.
    """

    def __init__(self, modules: List[ModuleDataset]) -> None:
        """
        Initialize the VambnDataset.

        Args:
            modules (List[ModuleDataset]): Modules to be included in the dataset.

        Raises:
            Exception: If no modules are provided or if the number of rows in the modules do not match.
        """
        super().__init__()

        if len(modules) < 1:
            raise Exception("No modules found")

        self.modules = sorted(modules, key=lambda x: x.name)
        self.module_names = sorted(list(set([x.name for x in modules])))

        unique_nrow = list(set([x.data.shape[0] for x in modules]))
        if len(unique_nrow) > 1:
            raise Exception(f"Number of rows do not match: {unique_nrow}")

        self.num_patients = unique_nrow[0]
        self.visits_per_module = None

        self._prepare_dataset(self.module_names)
        self.selected_modules = self.module_names
        self.selected_visits = list(set(self.module_wise_visits))
        self.num_timepoints = len(list(set(self.module_wise_visits)))
        self.subj = self.modules[0].subjects
        self.device = torch.device("cpu")

        if self.subj is None:
            raise Exception("No subjects found")

    def to(self, device: torch.device) -> "VambnDataset":
        """
        Move the dataset to a specific device.

        Args:
            device (torch.device): Device to be used.

        Returns:
            VambnDataset: Dataset on the specified device.
        """
        self.device = device
        self.modules = [x.to(device) for x in self.modules]
        return self

    def get_modules(self, name: str) -> List[ModuleDataset]:
        """
        Get the modules with the given name.

        Args:
            name (str): Name of the module.

        Returns:
            List[ModuleDataset]: ModuleDataset objects with the given name sorted by visit number.
        """
        selection = [x for x in self.modules if name == x.name]
        selection = sorted(selection, key=lambda x: x.visit_number)
        return selection

    def __getitem__(self, idx: int) -> tuple[list[Tensor], list[Tensor]]:
        """
        Get the data and mask for a specific index.

        Args:
            idx (int): Index of the sample.

        Returns:
            tuple[list[Tensor], list[Tensor]]: Data and mask tensors.
        """
        s_data = self.patient_x[idx]
        s_mask = self.patient_mask[idx]
        return s_data, s_mask  # , self.subj[idx]

    def __len__(self) -> int:
        """
        Get the number of samples in the dataset.

        Returns:
            int: Number of samples.
        """
        return self.num_patients

    def __str__(self) -> str:
        """
        String representation of the VambnDataset object.

        Returns:
            str: A string representation of the VambnDataset.
        """
        return f"""
        VambnDataset
            Modules: {self.module_names}
            N: {self.num_patients}
            V: {self.num_timepoints}
            M: {self.num_modules}
        """

    @property
    def num_modules(self) -> int:
        """
        Get the number of modules in the dataset.

        Returns:
            int: Number of modules.
        """
        return len(self.module_names)

    @property
    def is_longitudinal(self) -> bool:
        """
        Check if the dataset is longitudinal.

        Returns:
            bool: True if the dataset is longitudinal, False otherwise.
        """
        return self.num_timepoints > 1

    def _prepare_dataset(
        self,
        selection: Optional[List[str]] = None,
        visits: Optional[List[int]] = None,
    ):
        """
        Prepare the dataset by preparing the internal variables.

        Args:
            selection (Optional[List[str]]): List of module names to select. Defaults to None.
            visits (Optional[List[int]]): List of visits to select. Defaults to None.

        Raises:
            Exception: If types do not match across visits.
        """
        if selection is not None:
            self.selected_modules = selection
            self.module_names = selection
        if visits is not None and visits != []:
            self.selected_visits = visits
            self.modules = [
                x
                for x in self.modules
                if x.name in self.module_names and x.visit_number in visits
            ]
            self.num_timepoints = len(visits)
        else:
            self.modules = [
                x for x in self.modules if x.name in self.module_names
            ]

        self.module_wise_visits = tuple(
            set([x.visit_number for x in self.modules])
        )
        logger.debug(f"Module wise visits: {self.module_wise_visits}")
        self.num_timepoints = len(self.module_wise_visits)
        logger.debug(f"Number of timepoints: {self.num_timepoints}")
        self.module_wise_names = [x.name for x in self.modules]

        module_data = {}
        module_mask = {}
        for sel in self.module_names:
            modules = self.get_module_data(sel)
            if len(modules) < 1:
                raise Exception(
                    f"No modules found for {sel}, available: {self.module_names}, module_names: {tuple(x.name for x in self.modules)})"
                )
            if len(modules) > 1:
                unequal_types = [
                    check_equal_types(
                        modules[0].variable_types, x.variable_types
                    )
                    for x in modules[1:]
                ]
                if any(unequal_types):
                    logger.info(
                        f"Module types: {[x.variable_types for x in modules]}"
                    )
                    raise Exception(f"Types do not match across visits ({sel})")

            logger.info(f"Module {sel} has {len(modules)} visits")
            x = []
            mask = []
            for module in modules:
                x.append(module.input_data)
                mask.append(module.input_mask)

            n_features = x[0].shape[1]
            n_rows = x[0].shape[0]

            module_data[sel] = (
                torch.stack(x, dim=1).float().view(n_rows, -1, n_features)
                if self.is_longitudinal
                else x[0].float()
            )
            module_mask[sel] = (
                torch.stack(mask, dim=1).float().view(n_rows, -1, n_features)
                if self.is_longitudinal
                else mask[0].float()
            )
            logger.info(f"Module {sel} has shape {module_data[sel].shape}")

        self.x = module_data
        self.mask = module_mask

        self.patient_x = []
        self.patient_mask = []
        for i in range(self.num_patients):
            if self.is_longitudinal:
                x = [self.x[sel][i, :, :] for sel in self.module_names]
                mask = [self.mask[sel][i, :, :] for sel in self.module_names]
            else:
                x = [self.x[sel][i, :] for sel in self.module_names]
                mask = [self.mask[sel][i, :] for sel in self.module_names]
            self.patient_x.append(x)
            self.patient_mask.append(mask)

        self.visits_per_module = {}
        for x in self.modules:
            if x.name not in self.visits_per_module:
                self.visits_per_module[x.name] = 0
            self.visits_per_module[x.name] += 1

    def get_module(self, name: str) -> ModuleDataset:
        """
        Get the ModuleDataset for a given name.

        Args:
            name (str): Name of the module with visit number.

        Returns:
            ModuleDataset: ModuleDataset object with the given name.

        Raises:
            Exception: If the module is not found or multiple modules with the same name are found.
        """
        selection = [x for x in self.modules if name == x.id_name]
        if len(selection) != 1:
            logger.warning(
                f"Selection: {selection}, Modules: {[x.id_name for x in self.modules]}, Name: {name}"
            )
            raise Exception(
                f"Selection: {selection}, Modules: {[x.id_name for x in self.modules]}, Name: {name}"
            )
        return selection[0]

    def get_module_data(self, selection: str) -> List[ModuleDataset]:
        """
        Get the ModuleDataset for a given name without considering visit number.

        Args:
            selection (str): Name of the module.

        Returns:
            List[ModuleDataset]: List of all ModuleDataset objects with the given name sorted by visit number.
        """
        modules = sorted(
            [x for x in self.modules if x.name == selection],
            key=lambda x: x.visit_number,
        )
        return modules

    def get_longitudinal_data(self, selection: str) -> Tuple[Tensor, Tensor]:
        """
        Get longitudinal data for a specific module.

        Args:
            selection (str): Module name.

        Returns:
            Tuple[Tensor, Tensor]: Data and mask tensors.
        """
        modules = self.get_modules(selection)
        data = torch.stack(
            [torch.tensor(x.data.values) for x in modules], dim=1
        )
        mask = torch.stack(
            [torch.tensor(x.mask.values) for x in modules], dim=1
        )
        assert data.ndim == 3

        # if the data is longitudinal, then the shape is (n_subjects, n_visits, n_features)
        # otherwise remove the second dimension
        if not self.is_longitudinal:
            data = data.squeeze(1)
            mask = mask.squeeze(1)

        return data, mask

    def select_modules(
        self,
        selection: Optional[List[str]] = None,
        visits: Optional[List[int]] = None,
    ):
        """
        Select certain modules and visits from the existing dataset.

        Args:
            selection (Optional[List[str]]): Module names. Defaults to None.
            visits (Optional[List[int]]): Visit numbers. Defaults to None.
        """
        if selection is None and visits is None:
            return None
        else:
            if selection is not None and isinstance(selection, str):
                selection = [selection]

            if visits is not None and isinstance(visits, int):
                visits = [visits]

            self._prepare_dataset(selection=selection, visits=visits)

    def subset(self, ratio: float) -> "VambnDataset":
        """
        Subset the dataset by a given ratio.

        Args:
            ratio (float): Ratio of the subset to be returned.

        Returns:
            VambnDataset: Subset of the dataset.
        """
        patient_idx = np.arange(self.num_patients)
        selected_idx = np.random.choice(
            patient_idx, size=round(self.num_patients * ratio)
        )
        return self.subset_by_idx(selected_idx)

    def subset_by_idx(
        self, selected_idx: List[int] | np.ndarray[Any, np.dtype]
    ) -> "VambnDataset":
        """
        Subset the dataset by a given list of indices.

        Args:
            selected_idx (List[int]): Indices of the samples to be selected.

        Returns:
            VambnDataset: Subset of the dataset.
        """
        out_modules = [x.subset(selected_idx) for x in self.modules]
        out_ds = VambnDataset(out_modules)
        return out_ds

    def train_test_split(
        self, test_ratio: float
    ) -> Tuple["VambnDataset", "VambnDataset"]:
        """
        Generate a train and test split of the dataset.

        Args:
            test_ratio (float): Ratio of the dataset to be used as the test set.

        Returns:
            Tuple[VambnDataset, VambnDataset]: Train and test split.
        """
        idx = list(range(self.num_patients))
        train_idx, test_idx = train_test_split(
            idx, test_size=test_ratio, random_state=42
        )
        return self.subset_by_idx(train_idx), self.subset_by_idx(test_idx)

    def get_iter_dataset(self, name: str) -> IterDataset | LongitudinalDataset:
        """
        Get a PyTorch compatible dataset for a given module name.

        Args:
            name (str): Module name.

        Returns:
            IterDataset | LongitudinalDataset: Either an IterDataset or a LongitudinalDataset depending on the number of visits.
        """
        modules = self.get_modules(name)
        if self.is_longitudinal and len(modules) > 1:
            data = torch.stack([x.input_data for x in modules], dim=0)
            mask = torch.stack([x.input_mask for x in modules], dim=0)
            return LongitudinalDataset(data, mask, modules[0].variable_types)
        else:
            assert len(modules) == 1
            data = modules[0].input_data
            mask = modules[0].input_mask
            return IterDataset(data, mask, modules[0].variable_types)

    def to_pandas(self, module_name: Optional[str] = None) -> pd.DataFrame:
        """
        Convert the data and mask to pandas DataFrame.

        Args:
            module_name (Optional[str]): Name of the module to convert. Defaults to None.

        Returns:
            pd.DataFrame: Data and mask as pandas DataFrame.
        """
        if module_name is not None:
            selected_modules = [
                x for x in self.modules if x.name == module_name
            ]
        else:
            selected_modules = self.modules
        module_dfs = {}
        for module in selected_modules:
            if module.name not in module_dfs:
                module_dfs[module.name] = module.to_pandas()
            else:
                module_dfs[module.name] = pd.concat(
                    [module_dfs[module.name], module.to_pandas()]
                )
        # merge on subject id and visit
        df = None
        for m in module_dfs.values():
            if df is None:
                df = m
            else:
                df = df.merge(m, on=["SUBJID", "VISIT"], how="outer")

        return df
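
Example (illustrative). A sketch that combines two single-visit ModuleDataset objects into a VambnDataset and splits it into train and test sets; module names, subjects, and values are assumptions. With more than one visit per module, the dataset becomes longitudinal and get_iter_dataset returns a LongitudinalDataset instead of an IterDataset.

import pandas as pd

from vambn.data.dataclasses import VariableType
from vambn.data.datasets import ModuleDataset, VambnDataset

def make_module(name: str) -> ModuleDataset:
    # One hypothetical module observed at visit 1 for four subjects.
    data = pd.DataFrame({f"{name}_value_VIS1": [60.0, 72.0, 55.0, 68.0]})
    mask = pd.DataFrame({f"{name}_value_VIS1": [1, 1, 1, 1]})
    types = [
        VariableType(name=f"{name}_value", data_type="real", n_parameters=2, input_dim=1)
    ]
    return ModuleDataset(
        name=name,
        data=data,
        mask=mask,
        variable_types=types,
        scalers=(None,),
        columns=list(data.columns),
        subjects=["s1", "s2", "s3", "s4"],
        visit_number=1,
    )

ds = VambnDataset([make_module("clinical"), make_module("imaging")])
print(ds.num_patients, ds.num_modules, ds.is_longitudinal)  # 4 2 False

train, test = ds.train_test_split(test_ratio=0.25)
print(len(train), len(test))  # 3 1

iter_ds = ds.get_iter_dataset("clinical")  # IterDataset, since there is a single visit
print(len(iter_ds), iter_ds.ndim)          # 4 1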

is_longitudinal: bool property

Check if the dataset is longitudinal.

Returns:

    bool: True if the dataset is longitudinal, False otherwise.

num_modules: int property

Get the number of modules in the dataset.

Returns:

    int: Number of modules.

__getitem__(idx)

Get the data and mask for a specific index.

Parameters:

    idx (int): Index of the sample. Required.

Returns:

    tuple[list[Tensor], list[Tensor]]: Data and mask tensors.

Source code in vambn/data/datasets.py
def __getitem__(self, idx: int) -> tuple[list[Tensor], list[Tensor]]:
    """
    Get the data and mask for a specific index.

    Args:
        idx (int): Index of the sample.

    Returns:
        tuple[list[Tensor], list[Tensor]]: Data and mask tensors.
    """
    s_data = self.patient_x[idx]
    s_mask = self.patient_mask[idx]
    return s_data, s_mask  # , self.subj[idx]

__init__(modules)

Initialize the VambnDataset.

Parameters:

    modules (List[ModuleDataset]): Modules to be included in the dataset. Required.

Raises:

    Exception: If no modules are provided or if the number of rows in the modules do not match.

Source code in vambn/data/datasets.py
def __init__(self, modules: List[ModuleDataset]) -> None:
    """
    Initialize the VambnDataset.

    Args:
        modules (List[ModuleDataset]): Modules to be included in the dataset.

    Raises:
        Exception: If no modules are provided or if the number of rows in the modules do not match.
    """
    super().__init__()

    if len(modules) < 1:
        raise Exception("No modules found")

    self.modules = sorted(modules, key=lambda x: x.name)
    self.module_names = sorted(list(set([x.name for x in modules])))

    unique_nrow = list(set([x.data.shape[0] for x in modules]))
    if len(unique_nrow) > 1:
        raise Exception(f"Number of rows do not match: {unique_nrow}")

    self.num_patients = unique_nrow[0]
    self.visits_per_module = None

    self._prepare_dataset(self.module_names)
    self.selected_modules = self.module_names
    self.selected_visits = list(set(self.module_wise_visits))
    self.num_timepoints = len(list(set(self.module_wise_visits)))
    self.subj = self.modules[0].subjects
    self.device = torch.device("cpu")

    if self.subj is None:
        raise Exception("No subjects found")

__len__()

Get the number of samples in the dataset.

Returns:

    int: Number of samples.

Source code in vambn/data/datasets.py
def __len__(self) -> int:
    """
    Get the number of samples in the dataset.

    Returns:
        int: Number of samples.
    """
    return self.num_patients

__str__()

String representation of the VambnDataset object.

Returns:

    str: A string representation of the VambnDataset.

Source code in vambn/data/datasets.py
def __str__(self) -> str:
    """
    String representation of the VambnDataset object.

    Returns:
        str: A string representation of the VambnDataset.
    """
    return f"""
    VambnDataset
        Modules: {self.module_names}
        N: {self.num_patients}
        V: {self.num_timepoints}
        M: {self.num_modules}
    """

get_iter_dataset(name)

Get a PyTorch compatible dataset for a given module name.

Parameters:

    name (str): Module name. Required.

Returns:

    IterDataset | LongitudinalDataset: Either an IterDataset or a LongitudinalDataset depending on the number of visits.

Source code in vambn/data/datasets.py
def get_iter_dataset(self, name: str) -> IterDataset | LongitudinalDataset:
    """
    Get a PyTorch compatible dataset for a given module name.

    Args:
        name (str): Module name.

    Returns:
        IterDataset | LongitudinalDataset: Either an IterDataset or a LongitudinalDataset depending on the number of visits.
    """
    modules = self.get_modules(name)
    if self.is_longitudinal and len(modules) > 1:
        data = torch.stack([x.input_data for x in modules], dim=0)
        mask = torch.stack([x.input_mask for x in modules], dim=0)
        return LongitudinalDataset(data, mask, modules[0].variable_types)
    else:
        assert len(modules) == 1
        data = modules[0].input_data
        mask = modules[0].input_mask
        return IterDataset(data, mask, modules[0].variable_types)
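
Example (a sketch of wrapping a single module for training; ds is a hypothetical VambnDataset, and the batch structure yielded by IterDataset/LongitudinalDataset is assumed to be a data/mask pair):

from torch.utils.data import DataLoader

module_ds = ds.get_iter_dataset(ds.module_names[0])
loader = DataLoader(module_ds, batch_size=16, shuffle=True)
batch = next(iter(loader))   # assumed to contain data and mask tensors for one module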

get_longitudinal_data(selection)

Get longitudinal data for a specific module.

Parameters:

Name Type Description Default
selection str

Module name.

required

Returns:

Type Description
Tuple[Tensor, Tensor]

Tuple[Tensor, Tensor]: Data and mask tensors.

Source code in vambn/data/datasets.py
def get_longitudinal_data(self, selection: str) -> Tuple[Tensor, Tensor]:
    """
    Get longitudinal data for a specific module.

    Args:
        selection (str): Module name.

    Returns:
        Tuple[Tensor, Tensor]: Data and mask tensors.
    """
    modules = self.get_modules(selection)
    data = torch.stack(
        [torch.tensor(x.data.values) for x in modules], dim=1
    )
    mask = torch.stack(
        [torch.tensor(x.mask.values) for x in modules], dim=1
    )
    assert data.ndim == 3

    # if the data is longitudinal, then the shape is (n_subjects, n_visits, n_features)
    # otherwise remove the second dimension
    if not self.is_longitudinal:
        data = data.squeeze(1)
        mask = mask.squeeze(1)

    return data, mask
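
Example (a shape-oriented sketch; ds is a hypothetical VambnDataset and "clinical" a hypothetical module name):

data, mask = ds.get_longitudinal_data("clinical")
if ds.is_longitudinal:
    n_subjects, n_visits, n_features = data.shape
else:
    n_subjects, n_features = data.shape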

get_module(name)

Get the ModuleDataset for a given name.

Parameters:

Name Type Description Default
name str

Name of the module with visit number.

required

Returns:

Name Type Description
ModuleDataset ModuleDataset

ModuleDataset object with the given name.

Raises:

Type Description
Exception

If the module is not found or multiple modules with the same name are found.

Source code in vambn/data/datasets.py
def get_module(self, name: str) -> ModuleDataset:
    """
    Get the ModuleDataset for a given name.

    Args:
        name (str): Name of the module with visit number.

    Returns:
        ModuleDataset: ModuleDataset object with the given name.

    Raises:
        Exception: If the module is not found or multiple modules with the same name are found.
    """
    selection = [x for x in self.modules if name == x.id_name]
    if len(selection) != 1:
        logger.warning(
            f"Selection: {selection}, Modules: {[x.id_name for x in self.modules]}, Name: {name}"
        )
        raise Exception(
            f"Selection: {selection}, Modules: {[x.id_name for x in self.modules]}, Name: {name}"
        )
    return selection[0]

get_module_data(selection)

Get all ModuleDataset objects with a given name, regardless of visit number.

Parameters:

Name Type Description Default
selection str

Name of the module.

required

Returns:

Type Description
List[ModuleDataset]

List[ModuleDataset]: List of all ModuleDataset objects with the given name sorted by visit number.

Source code in vambn/data/datasets.py
def get_module_data(self, selection: str) -> List[ModuleDataset]:
    """
    Get all ModuleDataset objects with a given name, regardless of visit number.

    Args:
        selection (str): Name of the module.

    Returns:
        List[ModuleDataset]: List of all ModuleDataset objects with the given name sorted by visit number.
    """
    modules = sorted(
        [x for x in self.modules if x.name == selection],
        key=lambda x: x.visit_number,
    )
    return modules

get_modules(name)

Get the modules with the given name.

Parameters:

Name Type Description Default
name str

Name of the module.

required

Returns:

Type Description
List[ModuleDataset]

List[ModuleDataset]: ModuleDataset objects with the given name sorted by visit number.

Source code in vambn/data/datasets.py
def get_modules(self, name: str) -> List[ModuleDataset]:
    """
    Get the modules with the given name.

    Args:
        name (str): Name of the module.

    Returns:
        List[ModuleDataset]: ModuleDataset objects with the given name sorted by visit number.
    """
    selection = [x for x in self.modules if name == x.name]
    selection = sorted(selection, key=lambda x: x.visit_number)
    return selection

select_modules(selection=None, visits=None)

Select certain modules and visits from the existing dataset.

Parameters:

Name Type Description Default
selection Optional[List[str]]

Module names. Defaults to None.

None
visits Optional[List[int]]

Visit numbers. Defaults to None.

None
Source code in vambn/data/datasets.py
def select_modules(
    self,
    selection: Optional[List[str]] = None,
    visits: Optional[List[int]] = None,
):
    """
    Select certain modules and visits from the existing dataset.

    Args:
        selection (Optional[List[str]]): Module names. Defaults to None.
        visits (Optional[List[int]]): Visit numbers. Defaults to None.
    """
    if selection is None and visits is None:
        return None
    else:
        if selection is not None and isinstance(selection, str):
            selection = [selection]

        if visits is not None and isinstance(visits, int):
            visits = [visits]

        self._prepare_dataset(selection=selection, visits=visits)
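
Example (a sketch; ds is a hypothetical VambnDataset and the module names are hypothetical). Scalar arguments are wrapped into lists, so a single name or visit number can be passed directly:

ds.select_modules(selection=["clinical", "imaging"], visits=[1, 2])
ds.select_modules(selection="clinical")   # single string is wrapped into a list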

subset(ratio)

Subset the dataset by a given ratio.

Parameters:

Name Type Description Default
ratio float

Ratio of the subset to be returned.

required

Returns:

Name Type Description
VambnDataset VambnDataset

Subset of the dataset.

Source code in vambn/data/datasets.py
def subset(self, ratio: float) -> "VambnDataset":
    """
    Subset the dataset by a given ratio.

    Args:
        ratio (float): Ratio of the subset to be returned.

    Returns:
        VambnDataset: Subset of the dataset.
    """
    patient_idx = np.arange(self.num_patients)
    selected_idx = np.random.choice(
        patient_idx, size=round(self.num_patients * ratio)
    )
    return self.subset_by_idx(selected_idx)

subset_by_idx(selected_idx)

Subset the dataset by a given list of indices.

Parameters:

Name Type Description Default
selected_idx List[int] | ndarray

Indices of the samples to be selected.

required

Returns:

Name Type Description
VambnDataset VambnDataset

Subset of the dataset.

Source code in vambn/data/datasets.py
def subset_by_idx(
    self, selected_idx: List[int] | np.ndarray[Any, np.dtype]
) -> "VambnDataset":
    """
    Subset the dataset by a given list of indices.

    Args:
        selected_idx (List[int] | np.ndarray): Indices of the samples to be selected.

    Returns:
        VambnDataset: Subset of the dataset.
    """
    out_modules = [x.subset(selected_idx) for x in self.modules]
    out_ds = VambnDataset(out_modules)
    return out_ds

to(device)

Move the dataset to a specific device.

Parameters:

Name Type Description Default
device device

Device to be used.

required

Returns:

Name Type Description
VambnDataset VambnDataset

Dataset on the specified device.

Source code in vambn/data/datasets.py
def to(self, device: torch.device) -> "VambnDataset":
    """
    Move the dataset to a specific device.

    Args:
        device (torch.device): Device to be used.

    Returns:
        VambnDataset: Dataset on the specified device.
    """
    self.device = device
    self.modules = [x.to(device) for x in self.modules]
    return self

to_pandas(module_name=None)

Convert the data and mask to pandas DataFrame.

Parameters:

Name Type Description Default
module_name Optional[str]

Name of the module to convert. Defaults to None.

None

Returns:

Type Description
DataFrame

pd.DataFrame: Data and mask as pandas DataFrame.

Source code in vambn/data/datasets.py
def to_pandas(self, module_name: Optional[str] = None) -> pd.DataFrame:
    """
    Convert the data and mask to pandas DataFrame.

    Args:
        module_name (Optional[str]): Name of the module to convert. Defaults to None.

    Returns:
        pd.DataFrame: Data and mask as pandas DataFrame.
    """
    if module_name is not None:
        selected_modules = [
            x for x in self.modules if x.name == module_name
        ]
    else:
        selected_modules = self.modules
    module_dfs = {}
    for module in selected_modules:
        if module.name not in module_dfs:
            module_dfs[module.name] = module.to_pandas()
        else:
            module_dfs[module.name] = pd.concat(
                [module_dfs[module.name], module.to_pandas()]
            )
    # merge on subject id and visit
    df = None
    for m in module_dfs.values():
        if df is None:
            df = m
        else:
            df = df.merge(m, on=["SUBJID", "VISIT"], how="outer")

    return df
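
Example (a sketch; ds is a hypothetical VambnDataset and "clinical" a hypothetical module name). Module frames are concatenated over visits and outer-merged on SUBJID and VISIT:

full_df = ds.to_pandas()                             # all modules
clinical_df = ds.to_pandas(module_name="clinical")   # a single module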

train_test_split(test_ratio)

Generate a train and test split of the dataset.

Parameters:

Name Type Description Default
test_ratio float

Ratio of the dataset to be used as the test set.

required

Returns:

Type Description
Tuple[VambnDataset, VambnDataset]

Tuple[VambnDataset, VambnDataset]: Train and test split.

Source code in vambn/data/datasets.py
def train_test_split(
    self, test_ratio: float
) -> Tuple["VambnDataset", "VambnDataset"]:
    """
    Generate a train and test split of the dataset.

    Args:
        test_ratio (float): Ratio of the dataset to be used as the test set.

    Returns:
        Tuple[VambnDataset, VambnDataset]: Train and test split.
    """
    idx = list(range(self.num_patients))
    train_idx, test_idx = train_test_split(
        idx, test_size=test_ratio, random_state=42
    )
    return self.subset_by_idx(train_idx), self.subset_by_idx(test_idx)
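
Example (a sketch; ds is a hypothetical VambnDataset). Note that subset() calls np.random.choice with its default replace=True, so the drawn patient indices may contain duplicates, while train_test_split() yields two disjoint datasets with a fixed random_state of 42:

small = ds.subset(0.1)                         # roughly 10% of patients
train_ds, test_ds = ds.train_test_split(0.2)   # 80/20 split
print(len(small), len(train_ds), len(test_ds))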

helpers

filter_data(data, missingness_threshold, selected_columns, variance_threshold=0.1)

Filter data by removing columns with zero variance and too many missing values.

Parameters:

Name Type Description Default
data DataFrame

Input data.

required
missingness_threshold float

Threshold for missingness (in percent).

required
selected_columns List[str] | None

Columns to keep.

required
variance_threshold float

Minimum variance. Defaults to 0.1.

0.1

Returns:

Type Description
Tuple[DataFrame, Set[str]]

Tuple[pd.DataFrame, Set[str]]: Dataframe with filtered columns and set of selected columns.

Source code in vambn/data/helpers.py
def filter_data(
    data: pd.DataFrame,
    missingness_threshold: float,
    selected_columns: List[str] | None,
    variance_threshold: float = 0.1,
) -> Tuple[pd.DataFrame, Set[str]]:
    """
    Filter data by removing columns with zero variance and too many missing values.

    Args:
        data (pd.DataFrame): Input data.
        missingness_threshold (float): Threshold for missingness (in percent).
        selected_columns (List[str] | None): Columns to keep.
        variance_threshold (float, optional): Minimum variance. Defaults to 0.1.

    Returns:
        Tuple[pd.DataFrame, Set[str]]: Dataframe with filtered columns and set of selected columns.
    """
    original_data = data.copy()

    if selected_columns is None:
        subj_columns = set([x for x in data.columns if x.startswith("SUBJID")])
        data.drop(columns=list(subj_columns), inplace=True)

        # Filter columns that have only unique values
        selected_unique_columns = set()
        for column in data.columns:
            if data[column].nunique() > 1:
                selected_unique_columns.add(column)

        print(
            f"Ratio of selected unique columns: {len(selected_unique_columns) / data.shape[1]}"
        )

        # Function to compute 'variance' for numeric and 'diversity' for string columns
        def compute_variance_or_diversity(column):
            if column.dtype in [np.int64, np.float64]:
                # compute the variance of the numeric column
                return column.var(skipna=True)
            elif column.dtype == "object":  # If column has string data
                unique_count = column.nunique()
                return 10000 if unique_count > 1 else 0
            return 10000 if column.nunique() > 1 else 0

        column_measures = data.apply(compute_variance_or_diversity)
        print(column_measures)
        selected_var_columns = column_measures[
            column_measures > variance_threshold
        ].keys()
        logger.info(
            f"Ratio of selected var columns: {len(selected_var_columns) / data.shape[1]}"
        )

        column_missingness = data.isna().sum(axis=0) / data.shape[0]
        selected_missingness_columns = column_missingness[
            column_missingness < (missingness_threshold / 100)
        ].keys()
        logger.info(
            f"Ratio of selected missingness columns: {len(selected_missingness_columns) / data.shape[1]}"
        )

        # Get the overlap of both column sets
        selected_columns = (
            set(selected_var_columns)
            & set(selected_missingness_columns)
            & set(selected_unique_columns)
        )
        selected_columns |= set(
            data.columns[data.dtypes == "object"]
        )  # Include string columns
        selected_columns |= subj_columns  # Add SUBJID columns back
        logger.info(
            f"Ratio of selected columns: {len(selected_columns) / data.shape[1]}"
        )

    # Sort the selected columns to maintain a consistent order
    selected_columns = sorted(list(selected_columns))

    # Filter the original data to include only the selected columns
    filtered_data = original_data.loc[:, selected_columns].copy()

    return filtered_data, set(selected_columns)
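
Example (a small runnable sketch; the import path is an assumption). A constant column and a mostly missing column are dropped, while the SUBJID column is retained. Note that the function works on an internal copy but drops SUBJID from the passed frame in place:

import numpy as np
import pandas as pd
from vambn.data.helpers import filter_data   # assumed import path

df = pd.DataFrame({
    "SUBJID": ["s1", "s2", "s3", "s4"],
    "constant": [1.0, 1.0, 1.0, 1.0],                  # zero variance -> dropped
    "mostly_missing": [np.nan, np.nan, np.nan, 2.0],   # 75% missing -> dropped
    "useful": [0.1, 2.3, 4.5, 6.7],
})
filtered, selected = filter_data(
    data=df,
    missingness_threshold=50,   # percent
    selected_columns=None,
    variance_threshold=0.1,
)
print(sorted(selected))   # expected under these assumptions: ['SUBJID', 'useful']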

load_vambn_data(data_folder, selected_visits=None, selected_modules=None)

Load the data from the preprocessed folder.

Parameters:

Name Type Description Default
data_folder Path

Folder containing the preprocessed data.

required
selected_visits Optional[List[int]]

List of visits to select. Defaults to None.

None
selected_modules Optional[List[str]]

List of modules to select. Defaults to None.

None

Raises:

Type Description
FileNotFoundError

If the data folder or any required data file is not found.

Returns:

Name Type Description
VambnDataset VambnDataset

Dataset with the loaded data.

Source code in vambn/data/helpers.py
def load_vambn_data(
    data_folder: Path,
    selected_visits: Optional[List[int]] = None,
    selected_modules: Optional[List[str]] = None,
) -> VambnDataset:
    """
    Load the data from the preprocessed folder.

    Args:
        data_folder (Path): Folder containing the preprocessed data.
        selected_visits (Optional[List[int]], optional): List of visits to select. Defaults to None.
        selected_modules (Optional[List[str]], optional): List of modules to select. Defaults to None.

    Raises:
        FileNotFoundError: If the data folder or any required data file is not found.

    Returns:
        VambnDataset: Dataset with the loaded data.
    """
    if not data_folder.exists():
        raise FileNotFoundError(f"Data folder {data_folder} not found")
    data_files = [x for x in data_folder.glob("**/*imp.csv")]
    module_names = [str(x).split("/")[-1].split("_imp")[0] for x in data_files]

    modules_map = defaultdict(lambda: defaultdict())

    for i, module_name in enumerate(module_names):
        if "stalone" in module_name:
            continue
        # Read data
        logger.debug(f"Processing module {module_name}")
        file = str(data_folder / f"{module_name}_imp.csv")
        module_data = pd.read_csv(file, index_col=0)
        data = torch.from_numpy(module_data.values)
        # Read column (features) names
        column_names = module_data.columns.values
        # Read row (sample) names
        row_names = module_data.index.values

        scaler_files = list(
            data_folder.glob(f"**/{module_name.split('_')[0]}*_scaler.pkl")
        )
        scalers = {}
        for scaler_file in scaler_files:
            scaler = pickle.loads(scaler_file.read_bytes())
            column_name = scaler_file.stem.replace(
                f"{module_name.split('_')[0]}_", ""
            ).replace("_scaler", "")
            column_name_vis = f"{column_name}_{module_name.split('_')[1]}"
            scalers[column_name_vis] = scaler

        # Read types of columns
        type_df = pd.read_csv(str(data_folder / f"{module_name}_types.csv"))

        def get_type(type_row: pd.Series) -> VariableType:
            return VariableType(
                name=type_row["column_names"],
                data_type=type_row["type"],
                n_parameters=int(
                    1
                    if type_row["type"] == "count"
                    else 2
                    if type_row["type"]
                    in ["real", "pos", "truncate_norm", "gamma"]
                    else type_row["nclass"]
                    if type_row["type"] == "cat"
                    else 0
                ),
                input_dim=int(
                    type_row["nclass"] if type_row["type"] == "cat" else 1,
                ),
                scaler=scalers.get(type_row["column_names"])
                if type_row["type"] in ["real", "pos", "truncate_norm", "gamma"]
                else None,
            )

        types = [get_type(row) for _, row in type_df.iterrows()]

        # convert scalers into the correct list
        scaler_list = [None] * len(column_names)
        for i, name in enumerate(column_names):
            if name in scalers:
                scaler_list[i] = scalers[name]

        # assert that data matches type
        for i, (var_type, column) in enumerate(zip(types, column_names)):
            data_column = data[:, i].nan_to_num()
            if var_type.data_type == "count":
                assert torch.all(
                    data_column >= 0
                ), f"Negative values in column {column}"

            logger.info(
                f"Range of values for column {column} in module {module_name} ({var_type.data_type}): {data_column.min()} -- {data_column.max()}"
            )

        # Missing data
        # if 1 = observed, 0 = missing
        # TODO: save mask directly instead of having the longtable
        missing_mask = torch.ones(size=data.shape)
        try:
            missing = pd.read_csv(str(data_folder / f"{module_name}_mask.csv"))
            # assert column names "row" and "column"
            assert "row" in missing.columns.values
            assert "column" in missing.columns.values

            for tup in missing.itertuples():
                missing_mask[tup.row - 1, tup.column - 1] = 0
        except pd.errors.EmptyDataError:
            pass
        except FileNotFoundError:
            logger.warning(
                "File with missing observations not found. Assume that all values are present"
            )

        # Verify that torch.isnan(data) == missing_mask
        assert torch.all(
            torch.isnan(data) == (missing_mask == 0)
        ), "Missing mask does not match data"

        # Create mapping of module id to data
        modules_map[module_name]["data"] = data
        modules_map[module_name]["column_names"] = column_names
        modules_map[module_name]["row_names"] = row_names
        modules_map[module_name]["types"] = types
        modules_map[module_name]["missing_mask"] = missing_mask
        modules_map[module_name]["scalers"] = tuple(scaler_list)

    logger.info(f"Loaded {len(modules_map)} modules.")
    logger.info(f"Module names: {sorted(list(modules_map.keys()))}")

    module_data = []
    for module_name, data_dictionary in modules_map.items():
        if "VIS" in module_name:
            name = "_".join(module_name.split("_")[:-1])
            visit = re.sub("[A-Z]+", "", module_name.split("_")[-1])
        else:
            name = module_name
            visit = "1"
        types = data_dictionary["types"]

        if len(types) >= 1:
            module_data.append(
                ModuleDataset(
                    name=name,
                    data=pd.DataFrame(data_dictionary["data"].numpy()),
                    mask=pd.DataFrame(data_dictionary["missing_mask"].numpy()),
                    variable_types=types,
                    visit_number=int(visit),
                    scalers=data_dictionary["scalers"],
                    columns=data_dictionary["column_names"],
                    subjects=data_dictionary["row_names"],
                )
            )
        else:
            logger.warning(f"Module {module_name} has no variable. Skipping.")

    logger.info(f"Loaded {len(module_data)} modules.")
    ds = VambnDataset(module_data)
    if selected_visits is not None or selected_modules is not None:
        if selected_visits is not None and not isinstance(
            selected_visits, list
        ):
            selected_visits = [selected_visits]

        if selected_modules is not None and not isinstance(
            selected_modules, list
        ):
            selected_modules = [selected_modules]

        if selected_visits is not None and len(selected_visits) == 0:
            selected_visits = None

        if selected_modules is not None and len(selected_modules) == 0:
            selected_modules = None

        ds.select_modules(
            visits=selected_visits,
            selection=selected_modules,
        )

    # print selected visits and modules from ds
    logger.info(f"Selected visits: {ds.selected_visits}")
    logger.info(f"Selected modules: {ds.selected_modules}")

    return ds
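
Example (a usage sketch; the import path and folder location are assumptions). The folder is expected to contain the *_imp.csv, *_types.csv, *_mask.csv and optional *_scaler.pkl files written by prepare_data:

from pathlib import Path

from vambn.data.helpers import load_vambn_data   # assumed import path

ds = load_vambn_data(
    data_folder=Path("data/processed"),   # hypothetical folder
    selected_visits=[1, 2],
    selected_modules=["clinical"],        # hypothetical module name
)
print(ds)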

prepare_data(data, grouping, output_path, missingness_threshold, max_visit_dict, selected_modules, module_wise_features, selected_visits=None, scaling=True, variance_threshold=0.1)

Prepare data for VAMBN and save it in the respective output folder.

The function performs the following steps:
  1. Iterate over the modules and timepoints/visits.
  2. Filter out columns with zero variance and too many missing values.
  3. Keep track of missing values and create a mask (1 = missing, 0 = not missing).
  4. Impute missing data for standalone variables.
  5. Save imputed and raw data, as well as types and missing mask for each module.

Parameters:

Name Type Description Default
data DataFrame

Input data.

required
grouping DataFrame

DataFrame containing grouping information.

required
output_path Path

Path to save the processed data.

required
missingness_threshold float

Threshold for missingness (in percent).

required
max_visit_dict Dict[str, int]

Dictionary with maximum visit number for each module.

required
selected_modules List[str] | None

List of modules to select.

required
module_wise_features Dict[str, Optional[Set[str]]] | None

Features for each module.

required
selected_visits List[int] | None

List of visits to select. Defaults to None.

None
scaling bool

Whether to apply scaling. Defaults to True.

True
variance_threshold float

Minimum variance threshold. Defaults to 0.1.

0.1

Returns:

Type Description
DataFrame

pd.DataFrame: Prepared data.

Source code in vambn/data/helpers.py
def prepare_data(
    data: pd.DataFrame,
    grouping: pd.DataFrame,
    output_path: Path,
    missingness_threshold: float,
    max_visit_dict: Dict[str, int],
    selected_modules: List[str] | None,
    module_wise_features: Dict[str, Optional[Set[str]]] | None,
    selected_visits: List[int] | None = None,
    scaling: bool = True,
    variance_threshold: float = 0.1,
) -> pd.DataFrame:
    """
    Prepare data for VAMBN and save it in the respective output folder.

    The function performs the following steps:
        1. Iterate over the modules and timepoints/visits.
        2. Filter out columns with zero variance and too many missing values.
        3. Keep track of missing values and create a mask (1 = missing, 0 = not missing).
        4. Impute missing data for standalone variables.
        5. Save imputed and raw data, as well as types and missing mask for each module.

    Args:
        data (pd.DataFrame): Input data.
        grouping (pd.DataFrame): DataFrame containing grouping information.
        output_path (Path): Path to save the processed data.
        missingness_threshold (float): Threshold for missingness (in percent).
        max_visit_dict (Dict[str, int]): Dictionary with maximum visit number for each module.
        selected_modules (List[str] | None): List of modules to select.
        module_wise_features (Dict[str, Optional[Set[str]]] | None): Features for each module.
        selected_visits (List[int] | None, optional): List of visits to select. Defaults to None.
        scaling (bool, optional): Whether to apply scaling. Defaults to True.
        variance_threshold (float, optional): Minimum variance threshold. Defaults to 0.1.

    Returns:
        pd.DataFrame: Prepared data.
    """
    if selected_visits is not None:
        data = data.copy().loc[data["VISIT"].isin(selected_visits), :]

    if data.shape[0] == 0:
        raise Exception(
            f"No data left after filtering. Check selected visits. ({selected_visits})"
        )

    # available_visits = sorted(data["VISIT"].unique().tolist())
    available_subjects = sorted(data["SUBJID"].unique().tolist())
    # sorted(grouping["technical_group_name"].unique().tolist())
    grouping.sort_values(
        by=["technical_group_name", "column_names"], inplace=True
    )
    grouping.drop_duplicates(inplace=True)

    subject_df = pd.DataFrame(available_subjects, columns=["SUBJID"])

    if module_wise_features is None:
        module_wise_features: Dict[str, Optional[Set[str]]] = dict()

    overall_data = []

    for module_name, subset in grouping.groupby("technical_group_name"):
        assert isinstance(module_name, str)
        # Skip if module is not selected
        if (
            "stalone" not in module_name
            and selected_modules is not None
            and module_name not in selected_modules
        ):
            logger.info(f"Skipping group {module_name}.")
            continue
        else:
            logger.info(f"Processing group {module_name}.")

        # get the features for this module
        possible_features = subset["column_names"].tolist() + [
            "VISIT",
            "SUBJID",
        ]
        selected_features = list(
            set([x for x in possible_features if x in data.columns])
        )

        # get the data for this module
        module_data = data.loc[:, selected_features].copy()
        selected_columns = module_wise_features.get(module_name, None)

        if module_data["VISIT"].min() != 1:
            raise Exception("Minimum visit is not 1.")

        final_module_data = []

        max_visit = int(max_visit_dict.get(module_name, 0))
        for visit in range(1, max_visit + 1):
            if selected_columns is not None:
                # rename all columns in selected_columns with current visit
                # use regex to replace the part after _VIS with the current visit
                selected_columns = [
                    re.sub("_VIS[a-bA-B0-9]+", f"_VIS{visit}", x)
                    for x in selected_columns
                ]

            logger.info(f"Processing visit {visit}.")
            logger.debug(f"Selected columns: {selected_columns}.")
            # get the visit data
            visit_data = module_data[module_data["VISIT"] == visit].copy()

            # drop the visit column
            visit_data.drop(columns=["VISIT"], inplace=True)

            # merge with subject_df to ensure all subjects are present
            visit_data = subject_df.merge(visit_data, on="SUBJID", how="left")

            # rename all columns by appending VIS_{visit}
            visit_data.rename(
                columns={
                    x: f"{x}_VIS{visit}" if x not in ["SUBJID", "VISIT"] else x
                    for x in visit_data.columns
                },
                inplace=True,
            )
            if "stalone" in module_name:
                visit_data.rename(
                    columns={
                        x: f"SA_{x}" if x not in ["SUBJID", "VISIT"] else x
                        for x in visit_data.columns
                    },
                    inplace=True,
                )

            # filter data
            visit_data, selected_columns = filter_data(
                data=visit_data,
                missingness_threshold=missingness_threshold,
                selected_columns=selected_columns,
                variance_threshold=variance_threshold,
            )
            if len(selected_columns) <= 1:
                logger.warning(
                    f"No columns left after filtering in module {module_name} at visit {visit}."
                )
                raise Exception("No columns left after filtering.")
            else:
                logger.info(f"Selected columns: {selected_columns}.")

            if "SUBJID" in visit_data.columns:
                visit_data.set_index("SUBJID", inplace=True)

            # get the missing mask; output should be a dataframe with 1 = missing, 0 = not missing of the same shape as visit_data
            missing_mask = visit_data.isna().astype(int)

            # define auxiliary data
            # if "stalone" in module name this is equal to missing mask
            # else check row-wise if all values are missing
            if "stalone" in module_name:
                auxiliary_data = missing_mask.copy()
                auxiliary_data.rename(
                    columns={
                        x: f"AUX_{x}" if x not in ["SUBJID", "VISIT"] else x
                        for x in auxiliary_data.columns
                    },
                    inplace=True,
                )
            else:
                auxiliary_data = missing_mask.all(axis=1).astype(int)
                auxiliary_data.columns = [f"AUX_{module_name}_VIS{visit}"]

            # impute missing values if standalone variable
            # if the variable type is continuous (real, pos), use mean imputation
            # else use mode imputation

            # adapt the name of the subset to the current visit
            renamed_subset = subset.copy()
            renamed_subset["column_names"] = renamed_subset[
                "column_names"
            ].apply(
                lambda x: f"{x}_VIS{visit}"
                if x not in ["SUBJID", "VISIT"]
                else x
            )
            if "stalone" in module_name:
                renamed_subset["column_names"] = renamed_subset[
                    "column_names"
                ].apply(
                    lambda x: f"SA_{x}" if x not in ["SUBJID", "VISIT"] else x
                )

            # create subset filtered and keep the order of selected columns
            subset_filtered = renamed_subset.loc[
                renamed_subset["column_names"].isin(selected_columns), :
            ].copy()
            subset_filtered.sort_values(by=["column_names"], inplace=True)

            # make sure the order of "column_names" is identical to the visit_data.columns
            if "stalone" in module_name:
                # iterate over the columns and types
                for column, type in zip(
                    subset_filtered["column_names"],
                    subset_filtered["hivae_types"],
                ):
                    if type in ["real", "pos", "truncate_norm", "gamma"]:
                        mean_val = float(visit_data[column].mean())
                        visit_data[column].fillna(mean_val, inplace=True)
                        logger.info(f"Imputed {column} with mean {mean_val}.")
                    else:
                        option = visit_data[column].mode(dropna=True).tolist()
                        if len(option) > 1:
                            logger.warning(
                                f"Multiple modes found for {column}. Using the first one."
                            )
                            option = option[0]
                        elif len(option) == 0:
                            logger.warning(
                                f"No mode found for {column}. Imputing with 0."
                            )
                            raise Exception("No mode found.")
                        elif len(option) == 1:
                            option = option[0]

                        # convert to float if numeric
                        # if isinstance(option, (int, float)):
                        #     option = float(option)
                        visit_data[column].fillna(option, inplace=True)
                        logger.info(f"Imputed {column} with mode {option}.")

                # check if there are still missing values
                if visit_data.isna().sum().sum() > 0:
                    logger.warning(
                        f"Module {module_name} still contains missing values after imputation."
                    )
                    raise Exception("Missing values after imputation.")

            # convert the missingness mask to a long format with the row and column indices of the missing values
            missing_mask_long = pd.melt(
                missing_mask.reset_index(),
                id_vars=["SUBJID"],
                var_name="column",
                value_name="missing",
            )
            missing_mask_long["row"] = missing_mask_long["SUBJID"].apply(
                lambda x: visit_data.index.get_loc(x) + 1
            )
            missing_mask_long["column"] = missing_mask_long["column"].apply(
                lambda x: visit_data.columns.get_loc(x) + 1
            )
            missing_mask_long = missing_mask_long.loc[
                :, ["row", "column", "missing"]
            ]
            missing_mask_long = missing_mask_long[
                missing_mask_long["missing"] == 1
            ]
            if missing_mask_long.shape[0] == 0:
                missing_mask_long = pd.DataFrame({"row": [], "column": []})
            else:
                missing_mask_long.drop(columns=["missing"], inplace=True)

            # finally create a types dataframe with the columns (type, dim, nclass)
            # type => hivae_types of the respective column
            # dim => 1 if continuous (real, pos, count), number_of_classes if categorical
            # nclass => number_of_classes if categorical, empty if continuous
            subset_filtered["dim"] = None
            subset_filtered["nclass"] = None
            for column, type in zip(
                subset_filtered["column_names"], subset_filtered["hivae_types"]
            ):
                if type in ["real", "pos", "truncate_norm", "count", "gamma"]:
                    subset_filtered.loc[
                        subset_filtered["column_names"] == column, "dim"
                    ] = 1
                    subset_filtered.loc[
                        subset_filtered["column_names"] == column, "nclass"
                    ] = ""
                else:
                    # determine the number of classes
                    nclass = visit_data[column].nunique()
                    subset_filtered.loc[
                        subset_filtered["column_names"] == column, "dim"
                    ] = nclass
                    subset_filtered.loc[
                        subset_filtered["column_names"] == column, "nclass"
                    ] = nclass

            types = subset_filtered[
                [
                    "hivae_types",
                    "dim",
                    "nclass",
                    "column_names",
                ]
            ].copy()
            # rename categorical to cat in hivae_types
            types["hivae_types"] = types["hivae_types"].apply(
                lambda x: "cat" if x == "categorical" else x
            )
            types.rename({"hivae_types": "type"}, axis=1, inplace=True)

            # get raw data with same features and order
            raw_data = module_data[module_data["VISIT"] == visit].copy()
            # rename all columns by appending VIS_{visit}
            raw_data.rename(
                columns={
                    x: f"{x}_VIS{visit}" if x not in ["SUBJID", "VISIT"] else x
                    for x in raw_data.columns
                },
                inplace=True,
            )

            # add "SA_" for stalone variables
            if "stalone" in module_name:
                raw_data.rename(
                    columns={
                        x: f"SA_{x}" if x not in ["SUBJID", "VISIT"] else x
                        for x in raw_data.columns
                    },
                    inplace=True,
                )
            raw_data_merged = subject_df.merge(
                raw_data, on="SUBJID", how="left"
            )
            raw_data_filtered = raw_data_merged.loc[
                :, list(selected_columns)
            ].copy()
            raw_data_filtered.set_index("SUBJID", inplace=True)

            if visit_data.columns.tolist() != types["column_names"].tolist():
                raise Exception(
                    f"Column names do not match. Missing columns: {set(visit_data.columns) - set(types['column_names'])} or {set(types['column_names']) - set(visit_data.columns)}"
                )

            # print info about missingness
            for column in visit_data.columns:
                logger.info(
                    f"Missingness in module {module_name} at visit {visit}: {column} - {visit_data[column].isna().sum() / visit_data.shape[0]}"
                )

            if visit_data.shape[1] == 0:
                logger.warning(
                    f"No columns left after filtering in module {module_name} at visit {visit}."
                )
                raise Exception("No columns left after filtering.")

            logger.info(
                f"Saving data for module {module_name} at visit {visit} with {visit_data.shape[1]} columns."
            )
            # save the data
            visit_data.to_csv(
                output_path / f"{module_name}_VIS{visit}_imp.csv",
            )
            missing_mask_long.to_csv(
                output_path / f"{module_name}_VIS{visit}_mask.csv", index=False
            )
            auxiliary_data.to_csv(
                output_path / f"{module_name}_VIS{visit}_aux.csv",
            )
            types.to_csv(
                output_path / f"{module_name}_VIS{visit}_types.csv", index=False
            )
            raw_data_filtered.to_csv(
                output_path / f"{module_name}_VIS{visit}_raw.csv",
            )
            final_module_data.append(raw_data_filtered)

        # append selected modules to module_wise_features
        if module_name not in module_wise_features:
            module_wise_features[module_name] = selected_columns

        # concat all dataframes for this module vertically/stacked
        filtered_module_df = pd.concat(final_module_data, axis=0)
        overall_data.append(filtered_module_df)

    # merge all dataframes by subject
    overall_data_df = reduce(
        lambda x, y: pd.merge(x, y, on="SUBJID", how="outer"), overall_data
    )

    return overall_data_df
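
Example (a call sketch; the import path, file locations, and module name are assumptions). data is expected in long format with SUBJID and VISIT columns, and grouping must provide the column_names, technical_group_name, and hivae_types columns used above:

from pathlib import Path

import pandas as pd

from vambn.data.helpers import prepare_data   # assumed import path

data = pd.read_csv("raw/data.csv")             # hypothetical input files
grouping = pd.read_csv("raw/grouping.csv")

merged = prepare_data(
    data=data,
    grouping=grouping,
    output_path=Path("data/processed"),
    missingness_threshold=50,                  # percent
    max_visit_dict={"clinical": 2},            # hypothetical module name
    selected_modules=["clinical"],
    module_wise_features=None,
)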

make_data

extract_module_characteristics(name)

Extract the module name and visit from a given file name.

Parameters:

Name Type Description Default
name str

File name.

required

Returns:

Type Description
Tuple[str, str]

Tuple[str, str]: Module name and visit.

Raises:

Type Description
ValueError

If the module name or visit cannot be extracted from the file name.

Source code in vambn/data/make_data.py
def extract_module_characteristics(name: str) -> Tuple[str, str]:
    """
    Extract the module name and visit from a given file name.

    Args:
        name (str): File name.

    Returns:
        Tuple[str, str]: Module name and visit.

    Raises:
        ValueError: If the module name or visit cannot be extracted from the file name.
    """
    module_search = re.search("(^[a-zA-Z_0-9]+)_VIS", name)
    if module_search is None:
        raise ValueError(f"Module name could not be extracted from {name}")
    module_name = module_search.group(1)

    visit_search = re.search("_VIS([0-9a-zA-Z]+)_", name)
    if visit_search is None:
        raise ValueError(f"Visit could not be extracted from {name}")
    visit = visit_search.group(1)

    return module_name, visit
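
Example (a small sketch; the import path is an assumption):

from vambn.data.make_data import extract_module_characteristics   # assumed import path

module_name, visit = extract_module_characteristics("clinical_VIS2_imp.csv")
print(module_name, visit)   # expected: clinical 2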

make(data_file, grouping_file, groups_file, config_json, output_path, missingness_threshold=50, variance_threshold=0.1, log_file=None, scaling=True)

Process and prepare data for VAMBN analysis.

The function performs the following steps:
  1. Set up logging.
  2. Load configuration settings.
  3. Ensure output directories exist.
  4. Read and preprocess input data.
  5. Filter data based on missingness and variance thresholds.
  6. Prepare data for VAMBN analysis and save it.

Parameters:

Name Type Description Default
data_file Path

Path to the data file.

required
grouping_file Path

Path to the grouping file.

required
groups_file Path

Path to the file containing module groups.

required
config_json Path

Path to the configuration JSON file.

required
output_path Path

Path to save the processed data.

required
missingness_threshold int

Threshold for missingness (in percent). Defaults to 50.

50
variance_threshold float

Minimum variance threshold. Defaults to 0.1.

0.1
log_file Optional[Path]

Path to the log file. Defaults to None.

None
scaling bool

Whether to apply scaling. Defaults to True.

True
Source code in vambn/data/make_data.py
@preprocess_app.command()
def make(
    data_file: Path,
    grouping_file: Path,
    groups_file: Path,
    config_json: Path,
    output_path: Path,
    missingness_threshold: int = 50,
    variance_threshold: float = 0.1,
    log_file: Optional[Path] = None,
    scaling: bool = True,
):
    """
    Process and prepare data for VAMBN analysis.

    The function performs the following steps:
        1. Set up logging.
        2. Load configuration settings.
        3. Ensure output directories exist.
        4. Read and preprocess input data.
        5. Filter data based on missingness and variance thresholds.
        6. Prepare data for VAMBN analysis and save it.

    Args:
        data_file (Path): Path to the data file.
        grouping_file (Path): Path to the grouping file.
        groups_file (Path): Path to the file containing module groups.
        config_json (Path): Path to the configuration JSON file.
        output_path (Path): Path to save the processed data.
        missingness_threshold (int, optional): Threshold for missingness (in percent). Defaults to 50.
        variance_threshold (float, optional): Minimum variance threshold. Defaults to 0.1.
        log_file (Optional[Path], optional): Path to the log file. Defaults to None.
        scaling (bool, optional): Whether to apply scaling. Defaults to True.
    """
    # set up logging
    setup_logging(level=10, log_file=log_file)

    with config_json.open("r") as f:
        config = json.load(f)

    if "missingness_threshold" in config:
        missingness_threshold = config["missingness_threshold"]
        logger.info(f"Missingness threshold set to {missingness_threshold}%.")
    if "variance_threshold" in config:
        variance_threshold = config["variance_threshold"]
        logger.info(f"Variance threshold set to {variance_threshold}.")

    # ensure output folders exist
    output_path.mkdir(parents=True, exist_ok=True)

    # read in the data files
    data = pd.read_csv(data_file)
    if "SUBJID" not in data.columns:
        data.reset_index(inplace=True)
        if "SUBJID" not in data.columns:
            raise ValueError("SUBJID column not found in data.")
    # drop duplicates based on SUBJID and VISIT
    data.drop_duplicates(subset=["SUBJID", "VISIT"], inplace=True)
    grouping = pd.read_csv(grouping_file)
    with groups_file.open("r") as f:
        selected_modules = f.read().splitlines()
    if len(selected_modules) == 0:
        selected_modules = None

    # ensure that column names contain only two underscores
    def to_camel_case(string_list):
        camel_case_list = []
        for s in string_list:
            parts = s.split("_")
            camel_case = parts[0] + "".join(
                word.capitalize() for word in parts[1:]
            )
            camel_case_list.append(camel_case)
        return camel_case_list

    if any(["_" in x for x in data.columns]):
        data.columns = to_camel_case(data.columns.tolist())
        grouping["column_names"] = to_camel_case(
            grouping["column_names"].tolist()
        )

        # copy old files and save new
        data_file.rename(data_file.parent / f"{data_file.stem}.csv.backup")
        data.to_csv(data_file, index=False)
        grouping_file.rename(
            grouping_file.parent / f"{grouping_file.stem}.csv.backup"
        )
        grouping.to_csv(grouping_file, index=False)

    overall_max_visit = 0
    max_visit_dict = {}

    columns_to_drop = set()
    for data_module in grouping["technical_group_name"].unique():
        columns = tuple(
            set(
                grouping.loc[
                    grouping["technical_group_name"] == data_module,
                    "column_names",
                ].tolist()
                + ["SUBJID", "VISIT"]
            )
        )
        columns = tuple(x for x in columns if x in data.columns)
        if len(columns) == 0:
            logger.warning(
                f"No columns found for module {data_module}. Skipping..."
            )
            continue

        subset = data.loc[:, columns]
        number_of_subjects = subset["SUBJID"].nunique()
        # get the availability ratios of subjects per visit per column
        missingness_ratio = {
            col: 1
            - (subset.groupby("VISIT")[col].count().values / number_of_subjects)
            for col in columns
        }

        module_max_visit = None
        for column, missing_ratios in missingness_ratio.items():
            missing_at_first = missing_ratios[0]
            if missing_at_first > (missingness_threshold / 100):
                logger.info(
                    f"Few data available for column {column}. Dropped due to availability of {missing_at_first} on first visit."
                )
                columns_to_drop.add(column)
                continue
            flag_vector = missing_ratios <= (missingness_threshold / 100)
            max_visit_sum = flag_vector.sum()
            max_visit_validation = flag_vector[: max_visit_sum + 1].sum()
            if max_visit_sum != max_visit_validation:
                raise Exception(
                    f"Missingness is unexpected: {missing_ratios} @ {column}"
                )
            if module_max_visit is None or max_visit_sum < module_max_visit:
                logger.info(
                    f"Module {data_module} has {max_visit_sum} visits with more than {missingness_threshold}% data."
                )
                module_max_visit = max_visit_sum

        if module_max_visit is not None:
            max_visit_dict[data_module] = module_max_visit
            if module_max_visit > overall_max_visit:
                overall_max_visit = module_max_visit
    logger.info(f"Maximum visit to keep: {overall_max_visit}.")
    # filter out visits with higher number
    data = data[data["VISIT"] <= overall_max_visit]

    processed_data = prepare_data(
        data=data,
        grouping=grouping,
        output_path=output_path,
        missingness_threshold=missingness_threshold,
        selected_modules=selected_modules,
        module_wise_features=None,
        max_visit_dict=max_visit_dict,
        scaling=scaling,
        variance_threshold=variance_threshold,
    )

    # plot_dir = output_path / "plots"
    # # generate plots per column
    # for column in tqdm(processed_data.columns.drop("SUBJID")):
    #     # make barplot for categorical data
    #     # make boxplot for numerical data

    #     # get the type of the column
    #     column_type = grouping.loc[
    #         grouping["column_names"] == column, "type"
    #     ].values[0]
    #     if column_type == "categorical" or column_type == "cat":
    #         # make barplot
    #         sns.countplot(data=processed_data, x=column)
    #         plt.savefig(plot_dir / f"{column}_barplot.png")
    #         plt.close()
    #     else:
    #         # make boxplot
    #         sns.boxplot(data=processed_data, x=column)
    #         plt.savefig(plot_dir / f"{column}_boxplot.png")
    #         plt.close()
    # save descriptive statistics
    stats = processed_data.describe().T.sort_values("std", ascending=True)
    # add normalized standard deviation
    stats["std_norm"] = stats["std"] / stats["mean"]
    # round to 2 decimal places
    stats = stats.round(2)
    stats.to_csv(output_path / "descriptive_statistics.csv")

    # save a scatterplot with x=std and y=std_norm
    sns.scatterplot(data=stats, x="std", y="std_norm")
    # make x axis in log scale
    plt.xscale("log")
    plt.savefig(output_path / "std_vs_std_norm.png")

    logger.info("Finished preprocessing data.")
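
Example (a call sketch; make is registered as a preprocess_app command, so it is normally invoked through the project's Typer CLI, but it can also be called directly. The import path and file locations below are assumptions):

from pathlib import Path

from vambn.data.make_data import make   # assumed import path

make(
    data_file=Path("raw/data.csv"),            # hypothetical paths
    grouping_file=Path("raw/grouping.csv"),
    groups_file=Path("raw/groups.txt"),
    config_json=Path("raw/config.json"),
    output_path=Path("data/processed"),
    missingness_threshold=50,
    variance_threshold=0.1,
)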

merge_csv_files(folder, files, suffix)

Read preprocessed CSV files and merge them by a given column.

Parameters:

Name Type Description Default
folder Path

Folder where the files are located.

required
files List[Path]

Paths to the files.

required
suffix str

Suffix of the files to merge (e.g., '_imp.csv').

required

Raises:

Type Description
ValueError

If no files are provided or if the number of processed files does not match the number of provided files.

Returns:

Type Description
DataFrame

pd.DataFrame: Merged dataframe.

Source code in vambn/data/make_data.py
def merge_csv_files(
    folder: Path, files: List[Path], suffix: str
) -> pd.DataFrame:
    """
    Read preprocessed CSV files and merge them by a given column.

    Args:
        folder (Path): Folder where the files are located.
        files (List[Path]): Paths to the files.
        suffix (str): Suffix of the files to merge (e.g., '_imp.csv').

    Raises:
        ValueError: If no files are provided or if the number of processed files
            does not match the number of provided files.

    Returns:
        pd.DataFrame: Merged dataframe.
    """
    if files is None or len(files) == 0:
        raise ValueError("No files were provided.")

    # use regex to grep module names of the pattern /(^[a-zA-Z_]+)_VIS/
    avail_modules = set()
    avail_visits = set()
    for file in files:
        module, visit = extract_module_characteristics(file.name)
        avail_modules.add(module)
        avail_visits.add(visit)

    logger.info(f"Available modules: {avail_modules}")
    logger.info(f"Available visits: {avail_visits}")

    provided_files = set([str(x) for x in files])
    processed_files = set()
    module_data = []
    for module in avail_modules:
        logger.info(f"Processing module {module}")
        internal_df = []
        for visit in avail_visits:
            logger.info(f"Processing visit {visit}")
            file = folder / f"{module}_VIS{visit}{suffix}"
            if not file.exists():
                logger.warning(f"File {file} does not exist. Skipping...")
                continue
            processed_files.add(str(file))
            visit_data = pd.read_csv(file)
            visit_data["VISIT"] = visit
            # remove _VIS suffix from column names with regex (_VIS[0-9a-zA-Z]+)
            visit_data.columns = [
                re.sub("(_VIS[0-9a-zA-Z]+)", "", x) for x in visit_data.columns
            ]
            if "Unnamed: 0" in visit_data.columns:
                visit_data.drop(columns=["Unnamed: 0"], inplace=True)
            internal_df.append(visit_data)

        # concat along the rows
        internal_df = pd.concat(internal_df, axis=0)
        module_data.append(internal_df)

    # merge the different module data
    overall_data = module_data.pop()
    count = 1
    while module_data:
        count += 1
        overall_data = pd.merge(
            overall_data,
            module_data.pop(),
            how="outer",
            on=["SUBJID", "VISIT"],
        )

    overall_data.insert(0, "VISIT", overall_data.pop("VISIT"))
    overall_data.set_index("SUBJID", inplace=True)

    # check if all files were processed
    if provided_files != processed_files:
        raise ValueError(
            f"Files {provided_files - processed_files} were not processed. Please check the input."
        )

    return overall_data
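
A hypothetical call, assuming a folder that contains per-module, per-visit files following the <module>_VIS<visit>_imp.csv pattern (the folder layout and file names here are illustrative; the import path follows the source location shown above):

from pathlib import Path

from vambn.data.make_data import merge_csv_files

folder = Path("data/processed")           # hypothetical folder
files = sorted(folder.glob("*_imp.csv"))  # e.g. moduleA_VIS1_imp.csv, moduleA_VIS2_imp.csv

merged = merge_csv_files(folder, files, suffix="_imp.csv")
print(merged.head())  # indexed by SUBJID, with a VISIT column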

merge_imputed_data(input_folder, merged_data, transformed_data_path, log_file=None, log_level=20)

Merge imputed data into a single CSV file.

Parameters:

    input_folder (Path, required): Folder where the files are located.
    merged_data (Path, required): File where the merged data should be stored.
    transformed_data_path (Path, required): Path to save the transformed data.
    log_file (Optional[Path], default None): Optional file for logging.
    log_level (int, default 20): Logging level.

Source code in vambn/data/make_data.py (lines 345-387)
@preprocess_app.command()
def merge_imputed_data(
    input_folder: Path,
    merged_data: Path,
    transformed_data_path: Path,
    log_file: Optional[Path] = None,
    log_level: int = 20,
) -> None:
    """
    Merge imputed data into a single CSV file.

    Args:
        input_folder (Path): Folder where the files are located.
        merged_data (Path): File where the merged data should be stored.
        transformed_data_path (Path): Path to save the transformed data.
        log_file (Optional[Path], optional): Optional file for logging. Defaults to None.
        log_level (int, optional): Logging level. Defaults to 20.
    """
    setup_logging(level=log_level, log_file=log_file)
    input_files = input_folder.glob("**/*_imp.csv")
    overall_data = merge_csv_files(input_folder, list(input_files), "_imp.csv")
    overall_data.to_csv(str(merged_data))

    transformed_data = overall_data.copy()

    for scaler_file in input_folder.glob("**/*scaler.pkl"):
        # print(f"File: {scaler_file}")
        scaler = pickle.loads(scaler_file.read_bytes())

        column_name = "_".join(
            scaler_file.name.replace("_scaler.pkl", "").split("_")[1:]
        )
        try:
            transformed_data[column_name] = scaler.inverse_transform(
                transformed_data[column_name].values.reshape(-1, 1)
            )
        except KeyError:
            logger.warning(
                f"Could not find column {column_name} in dataframe. Skipping..."
            )
            continue

    transformed_data.to_csv(str(transformed_data_path))
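
Although merge_imputed_data is registered as a command, it can also be called directly from Python; a sketch with hypothetical paths (the import path follows the source location shown above):

from pathlib import Path

from vambn.data.make_data import merge_imputed_data

merge_imputed_data(
    input_folder=Path("data/imputed"),                                     # hypothetical
    merged_data=Path("data/merged/imputed_merged.csv"),                    # hypothetical
    transformed_data_path=Path("data/merged/imputed_original_scale.csv"),  # hypothetical
)

Any *_scaler.pkl files found below input_folder are used to map the corresponding columns of the merged data back to their original scale before writing transformed_data_path.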

merge_raw_data(input_folder, output_file, log_file=None, log_level=20)

Merge raw data into a single CSV file.

Parameters:

    input_folder (Path, required): Folder where the files are located.
    output_file (Path, required): File where the merged data should be stored.
    log_file (Optional[Path], default None): Optional file for logging.
    log_level (int, default 20): Logging level.

Source code in vambn/data/make_data.py (lines 414-437)
@preprocess_app.command()
def merge_raw_data(
    input_folder: Path,
    output_file: Path,
    log_file: Optional[Path] = None,
    log_level: int = 20,
) -> None:
    """
    Merge raw data into a single CSV file.

    Args:
        input_folder (Path): Folder where the files are located.
        output_file (Path): File where the merged data should be stored.
        log_file (Optional[Path], optional): Optional file for logging. Defaults to None.
        log_level (int, optional): Logging level. Defaults to 20.
    """

    setup_logging(level=log_level, log_file=log_file)

    input_files = input_folder.glob("**/*_raw.csv")
    overall_data = merge_csv_files(input_folder, list(input_files), "_raw.csv")
    # sort columns from overall data alphabetically
    overall_data = overall_data.reindex(sorted(overall_data.columns), axis=1)
    overall_data.to_csv(output_file)
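
A sketch of a direct call with hypothetical paths; all **/*_raw.csv files below input_folder are merged and the columns are sorted alphabetically:

from pathlib import Path

from vambn.data.make_data import merge_raw_data

merge_raw_data(
    input_folder=Path("data/raw"),                   # hypothetical
    output_file=Path("data/merged/raw_merged.csv"),  # hypothetical
)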

merge_stalone_data(input_folder, output_file, log_file=None, log_level=20)

Merge standalone (stalone) imputed data into a single CSV file.

Parameters:

    input_folder (Path, required): Folder where the files are located.
    output_file (Path, required): File where the merged data should be stored.
    log_file (Optional[Path], default None): Optional file for logging.
    log_level (int, default 20): Logging level.

Source code in vambn/data/make_data.py (lines 390-411)
@preprocess_app.command()
def merge_stalone_data(
    input_folder: Path,
    output_file: Path,
    log_file: Optional[Path] = None,
    log_level: int = 20,
) -> None:
    """
    Merge standalone (stalone) imputed data into a single CSV file.

    Args:
        input_folder (Path): Folder where the files are located.
        output_file (Path): File where the merged data should be stored.
        log_file (Optional[Path], optional): Optional file for logging. Defaults to None.
        log_level (int, optional): Logging level. Defaults to 20.
    """

    setup_logging(level=log_level, log_file=log_file)

    input_files = input_folder.glob("**/stalone*_imp.csv")
    overall_data = merge_csv_files(input_folder, list(input_files), "_imp.csv")
    overall_data.to_csv(output_file)
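
A sketch of a direct call with hypothetical paths; only files matching **/stalone*_imp.csv below input_folder are merged:

from pathlib import Path

from vambn.data.make_data import merge_stalone_data

merge_stalone_data(
    input_folder=Path("data/imputed"),                    # hypothetical
    output_file=Path("data/merged/stalone_merged.csv"),   # hypothetical
)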

modular(decoded_folder, input_file, output_data)

Gather data from decoded files and merge them with the stalone data.

Parameters:

    decoded_folder (Path, required): Path to the folder containing the decoded files.
    input_file (Path, required): Path to the stalone data.
    output_data (Path, required): Path to the output file.

Source code in vambn/data/make_data.py (lines 462-485)
@gather_app.command()
def modular(decoded_folder: Path, input_file: Path, output_data: Path):
    """
    Gather data from decoded files and merge them with the stalone data.

    Args:
        decoded_folder (Path): Path to the folder containing the decoded files.
        input_file (Path): Path to the stalone data.
        output_data (Path): Path to the output file.
    """

    output_data.parent.mkdir(exist_ok=True, parents=True)

    stalone_data = pd.read_csv(input_file)
    decoded_data = read_and_merge(list(decoded_folder.glob("**/*_decoded.csv")))
    # sort columns from decoded data alphabetically
    decoded_data = decoded_data.reindex(sorted(decoded_data.columns), axis=1)
    # assert (
    #     stalone_data.shape[0] == decoded_data.shape[0]
    # ), f"Shapes do not match (stalone: {stalone_data.shape[0]}, decoded: {decoded_data.shape[0]}))"
    merged = pd.merge(
        stalone_data, decoded_data, on=["SUBJID", "VISIT"], how="outer"
    )
    merged.to_csv(output_data, index=False)
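
A sketch of a direct call with hypothetical paths; decoded files are discovered via **/*_decoded.csv below decoded_folder and outer-merged with the stalone data on SUBJID and VISIT:

from pathlib import Path

from vambn.data.make_data import modular

modular(
    decoded_folder=Path("results/modular/decoded"),                 # hypothetical
    input_file=Path("data/merged/stalone_merged.csv"),              # hypothetical
    output_data=Path("results/modular/decoded_with_stalone.csv"),   # hypothetical
)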

read_and_merge(files)

Read all files and merge them on the columns SUBJID and VISIT.

Parameters:

    files (List[Path], required): List of files to read.

Returns:

    pd.DataFrame: Merged dataframe.

Source code in vambn/data/make_data.py (lines 440-459)
def read_and_merge(files: List[Path]) -> pd.DataFrame:
    """
    Read all files and merge them on the columns SUBJID and VISIT.

    Args:
        files (List[Path]): List of files to read.

    Returns:
        pd.DataFrame: Merged dataframe.
    """

    data = []
    for file in files:
        tmp = pd.read_csv(str(file))
        data.append(tmp)

    data = reduce(
        lambda x, y: pd.merge(x, y, on=["SUBJID", "VISIT"], how="outer"), data
    )
    return data
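
A sketch with hypothetical file paths; every file must contain SUBJID and VISIT columns, which are used as the outer-merge keys:

from pathlib import Path

from vambn.data.make_data import read_and_merge

files = [
    Path("results/decoded/moduleA_decoded.csv"),  # hypothetical
    Path("results/decoded/moduleB_decoded.csv"),  # hypothetical
]
merged = read_and_merge(files)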

traditional(decoded_folders, input_file, output_data)

Gather data from decoded files and merge them with the stalone data.

Parameters:

    decoded_folders (List[Path], required): List of folders containing the decoded files.
    input_file (Path, required): Path to the stalone data.
    output_data (Path, required): Path to the output file.

Source code in vambn/data/make_data.py (lines 488-518)
@gather_app.command()
def traditional(
    decoded_folders: List[Path],
    input_file: Path,
    output_data: Path,
):
    """
    Gather data from decoded files and merge them with the stalone data.

    Args:
        decoded_folders (List[Path]): List of folders containing the decoded files.
        input_file (Path): Path to the stalone data.
        output_data (Path): Path to the output file.
    """

    output_data.parent.mkdir(exist_ok=True, parents=True)

    stalone_data = pd.read_csv(input_file)
    decoded_data = read_and_merge(
        [n for x in decoded_folders for n in x.glob("**/*_decoded.csv")]
    )
    # sort columns from decoded data alphabetically
    decoded_data = decoded_data.reindex(sorted(decoded_data.columns), axis=1)
    # assert (
    #     stalone_data.shape[0] == decoded_data.shape[0]
    # ), f"Shapes do not match (stalone: {stalone_data.shape[0]}, decoded: {decoded_data.shape[0]}))"
    merged = pd.merge(
        stalone_data, decoded_data, on=["SUBJID", "VISIT"], how="outer"
    )

    merged.to_csv(output_data, index=False)
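
A sketch of a direct call with hypothetical paths; decoded files are collected from all listed folders before merging with the stalone data:

from pathlib import Path

from vambn.data.make_data import traditional

traditional(
    decoded_folders=[Path("results/run1/decoded"), Path("results/run2/decoded")],  # hypothetical
    input_file=Path("data/merged/stalone_merged.csv"),                             # hypothetical
    output_data=Path("results/traditional/decoded_with_stalone.csv"),              # hypothetical
)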

scalers

LogStdScaler

Bases: BaseEstimator, TransformerMixin

A custom scaler that applies a log transformation followed by standard scaling.

This class is deprecated and will be removed soon.

Attributes:

    scaler (StandardScaler): The standard scaler used after the log transformation.

Source code in vambn/data/scalers.py (lines 7-85)
class LogStdScaler(BaseEstimator, TransformerMixin):
    """
    A custom scaler that applies a log transformation followed by standard scaling.

    This class is deprecated and will be removed soon.

    Attributes:
        scaler (StandardScaler): The standard scaler used after the log transformation.
    """

    def __init__(self):
        """
        Initializes the LogStdScaler.
        """
        self.scaler = None
        warnings.warn(
            "This class is deprecated and will be removed soon",
            DeprecationWarning,
        )

    def fit(self, X, y=None):
        """
        Fits the scaler to the data after applying a log transformation.

        Args:
            X (array-like): The data to fit.
            y (None, optional): Ignored.

        Returns:
            LogStdScaler: The fitted scaler.
        """
        X_log = np.log1p(X)
        self.scaler = StandardScaler().fit(X_log)
        return self

    def transform(self, X, y=None):
        """
        Transforms the data using the fitted scaler after applying a log transformation.

        Args:
            X (array-like): The data to transform.
            y (None, optional): Ignored.

        Returns:
            array-like: The transformed data.

        Raises:
            RuntimeError: If the scaler has not been fitted yet.
        """
        if self.scaler is None:
            raise RuntimeError(
                "You must fit the scaler before transforming data"
            )
        X_log = np.log1p(X)
        X_scaled = self.scaler.transform(X_log)
        return X_scaled

    def inverse_transform(self, X, y=None):
        """
        Inversely transforms the data using the fitted scaler.

        Args:
            X (array-like): The data to inverse transform.
            y (None, optional): Ignored.

        Returns:
            array-like: The inversely transformed data.
        """
        X_reversed = self.scaler.inverse_transform(X)
        return X_reversed

    def __str__(self) -> str:
        """
        Returns a string representation of the LogStdScaler.

        Returns:
            str: The string representation.
        """
        return "LogStdScaler"

__init__()

Initializes the LogStdScaler.

Source code in vambn/data/scalers.py (lines 17-25)
def __init__(self):
    """
    Initializes the LogStdScaler.
    """
    self.scaler = None
    warnings.warn(
        "This class is deprecated and will be removed soon",
        DeprecationWarning,
    )

__str__()

Returns a string representation of the LogStdScaler.

Returns:

    str: The string representation.

Source code in vambn/data/scalers.py (lines 78-85)
def __str__(self) -> str:
    """
    Returns a string representation of the LogStdScaler.

    Returns:
        str: The string representation.
    """
    return "LogStdScaler"

fit(X, y=None)

Fits the scaler to the data after applying a log transformation.

Parameters:

    X (array-like, required): The data to fit.
    y (None, default None): Ignored.

Returns:

    LogStdScaler: The fitted scaler.

Source code in vambn/data/scalers.py (lines 27-40)
def fit(self, X, y=None):
    """
    Fits the scaler to the data after applying a log transformation.

    Args:
        X (array-like): The data to fit.
        y (None, optional): Ignored.

    Returns:
        LogStdScaler: The fitted scaler.
    """
    X_log = np.log1p(X)
    self.scaler = StandardScaler().fit(X_log)
    return self

inverse_transform(X, y=None)

Inversely transforms the data using the fitted scaler.

Parameters:

    X (array-like, required): The data to inverse transform.
    y (None, default None): Ignored.

Returns:

    array-like: The inversely transformed data.

Source code in vambn/data/scalers.py (lines 64-76)
def inverse_transform(self, X, y=None):
    """
    Inversely transforms the data using the fitted scaler.

    Args:
        X (array-like): The data to inverse transform.
        y (None, optional): Ignored.

    Returns:
        array-like: The inversely transformed data.
    """
    X_reversed = self.scaler.inverse_transform(X)
    return X_reversed
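
Note: as implemented, inverse_transform reverses only the standard scaling; the log1p applied during fit/transform is not undone, so np.expm1 is required to recover values on the original scale.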

transform(X, y=None)

Transforms the data using the fitted scaler after applying a log transformation.

Parameters:

    X (array-like, required): The data to transform.
    y (None, default None): Ignored.

Returns:

    array-like: The transformed data.

Raises:

    RuntimeError: If the scaler has not been fitted yet.

Source code in vambn/data/scalers.py (lines 42-62)
def transform(self, X, y=None):
    """
    Transforms the data using the fitted scaler after applying a log transformation.

    Args:
        X (array-like): The data to transform.
        y (None, optional): Ignored.

    Returns:
        array-like: The transformed data.

    Raises:
        RuntimeError: If the scaler has not been fitted yet.
    """
    if self.scaler is None:
        raise RuntimeError(
            "You must fit the scaler before transforming data"
        )
    X_log = np.log1p(X)
    X_scaled = self.scaler.transform(X_log)
    return X_scaled