Utilities

config

GeneralConfig

Configuration for general settings.

Attributes:

Name             Type           Description
seed             int            The random seed value.
eval_batch_size  int            The batch size for evaluation.
logging          LoggingConfig  The logging configuration.
device           str            The device to run the computations on (e.g., 'cpu', 'cuda').
optuna_db        Optional[str]  Optional database string for Optuna.

Source code in vambn/utils/config.py
@dataclass
class GeneralConfig:
    """Configuration for general settings.

    Attributes:
        seed: The random seed value.
        eval_batch_size: The batch size for evaluation.
        logging: The logging configuration.
        device: The device to run the computations on (e.g., 'cpu', 'cuda').
        optuna_db: Optional database string for Optuna.
    """

    seed: int
    eval_batch_size: int
    logging: LoggingConfig
    device: str
    optuna_db: Optional[str]

LoggingConfig

Configuration class for logging settings.

Attributes:

Name    Type          Description
level   int           The logging level.
mlflow  MlflowConfig  The MLflow configuration.

Source code in vambn/utils/config.py
@dataclass
class LoggingConfig:
    """
    Configuration class for logging settings.

    Attributes:
        level (int): The logging level.
        mlflow (MlflowConfig): The MLflow configuration.
    """

    level: int
    mlflow: MlflowConfig

MlflowConfig

Configuration class for MLflow settings.

Attributes:

Name             Type  Description
use              bool  Whether to use MLflow for logging.
tracking_uri     str   The URI of the MLflow tracking server.
experiment_name  str   The name of the MLflow experiment.

Source code in vambn/utils/config.py
@dataclass
class MlflowConfig:
    """
    Configuration class for MLflow settings.

    Attributes:
        use (bool): Whether to use MLflow for logging.
        tracking_uri (str): The URI of the MLflow tracking server.
        experiment_name (str): The name of the MLflow experiment.
    """

    use: bool
    tracking_uri: str
    experiment_name: str

OptimizationConfig

Configuration class for optimization settings.

Attributes:

Name                                             Type   Description
max_epochs                                       int    The maximum number of epochs.
folds                                            int    The number of folds for cross-validation.
n_modular_trials                                 int    The number of trials for modular models.
n_traditional_trials                             int    The number of trials for traditional models.
s_dim_lower                                      int    The lower bound of the s dimension.
s_dim_upper                                      int    The upper bound of the s dimension.
s_dim_step                                       int    The step size for the s dimension.
fixed_s_dim                                      bool   Whether the s dimension is fixed.
y_dim_lower                                      int    The lower bound of the y dimension.
y_dim_upper                                      int    The upper bound of the y dimension.
y_dim_step                                       int    The step size for the y dimension.
fixed_y_dim                                      bool   Whether the y dimension is fixed.
latent_dim_lower                                 int    The lower bound of the latent dimension.
latent_dim_upper                                 int    The upper bound of the latent dimension.
latent_dim_step                                  int    The step size for the latent dimension.
batch_size_lower_n                               int    The lower bound of the batch size.
batch_size_upper_n                               int    The upper bound of the batch size.
learning_rate_lower                              float  The lower bound of the learning rate.
learning_rate_upper                              float  The upper bound of the learning rate.
fixed_learning_rate                              bool   Whether the learning rate is fixed.
lstm_layers_lower                                int    The lower bound of the LSTM layers.
lstm_layers_upper                                int    The upper bound of the LSTM layers.
lstm_layers_step                                 int    The step size for the LSTM layers.
use_relative_correlation_error_for_optimization  bool   Whether to use relative correlation error for optimization.
use_auc_for_optimization                         bool   Whether to use AUC for optimization.

Source code in vambn/utils/config.py
@dataclass
class OptimizationConfig:
    """
    Configuration class for optimization settings.

    Attributes:
        max_epochs (int): The maximum number of epochs.
        folds (int): The number of folds for cross-validation.
        n_modular_trials (int): The number of trials for modular models.
        n_traditional_trials (int): The number of trials for traditional models.
        s_dim_lower (int): The lower bound of the s dimension.
        s_dim_upper (int): The upper bound of the s dimension.
        s_dim_step (int): The step size for the s dimension.
        fixed_s_dim (bool): Whether the s dimension is fixed.
        y_dim_lower (int): The lower bound of the y dimension.
        y_dim_upper (int): The upper bound of the y dimension.
        y_dim_step (int): The step size for the y dimension.
        fixed_y_dim (bool): Whether the y dimension is fixed.
        latent_dim_lower (int): The lower bound of the latent dimension.
        latent_dim_upper (int): The upper bound of the latent dimension.
        latent_dim_step (int): The step size for the latent dimension.
        batch_size_lower_n (int): The lower bound of the batch size.
        batch_size_upper_n (int): The upper bound of the batch size.
        learning_rate_lower (float): The lower bound of the learning rate.
        learning_rate_upper (float): The upper bound of the learning rate.
        fixed_learning_rate (bool): Whether the learning rate is fixed.
        lstm_layers_lower (int): The lower bound of the LSTM layers.
        lstm_layers_upper (int): The upper bound of the LSTM layers.
        lstm_layers_step (int): The step size for the LSTM layers.
        use_relative_correlation_error_for_optimization (bool): Whether to use relative correlation error for optimization.
        use_auc_for_optimization (bool): Whether to use AUC for optimization.
    """

    max_epochs: int
    folds: int
    n_modular_trials: int
    n_traditional_trials: int
    s_dim_lower: int
    s_dim_upper: int
    s_dim_step: int
    fixed_s_dim: bool
    y_dim_lower: int
    y_dim_upper: int
    y_dim_step: int
    fixed_y_dim: bool
    latent_dim_lower: int
    latent_dim_upper: int
    latent_dim_step: int
    batch_size_lower_n: int
    batch_size_upper_n: int
    learning_rate_lower: float
    learning_rate_upper: float
    fixed_learning_rate: bool
    lstm_layers_lower: int
    lstm_layers_upper: int
    lstm_layers_step: int
    use_relative_correlation_error_for_optimization: bool
    use_auc_for_optimization: bool

PipelineConfig

Configuration for the pipeline settings.

Attributes:

Name          Type                Description
general       GeneralConfig       General configuration settings.
optimization  OptimizationConfig  Optimization configuration settings.
training      TrainingConfig      Training configuration settings.

Source code in vambn/utils/config.py
@dataclass
class PipelineConfig:
    """Configuration for the pipeline settings.

    Attributes:
        general: General configuration settings.
        optimization: Optimization configuration settings.
        training: Training configuration settings.
    """

    general: GeneralConfig
    optimization: OptimizationConfig
    training: TrainingConfig

TrainingConfig

Configuration for training settings.

Attributes:

Name                  Type  Description
use_imputation_layer  bool  Whether to use an imputation layer.
use_mtl               bool  Whether to use multi-task learning.
with_gan              bool  Whether to use a GAN.

Source code in vambn/utils/config.py
@dataclass
class TrainingConfig:
    """Configuration for training settings.

    Attributes:
        use_imputation_layer: Whether to use an imputation layer.
        use_mtl: Whether to use multi-task learning.
        with_gan: Whether to use a GAN.
    """

    use_imputation_layer: bool
    use_mtl: bool
    with_gan: bool
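
These dataclasses nest into a single PipelineConfig. A minimal construction sketch; every value shown is illustrative, not a package default:

import logging

from vambn.utils.config import (
    GeneralConfig,
    LoggingConfig,
    MlflowConfig,
    OptimizationConfig,
    PipelineConfig,
    TrainingConfig,
)

config = PipelineConfig(
    general=GeneralConfig(
        seed=42,
        eval_batch_size=128,
        logging=LoggingConfig(
            level=logging.INFO,
            mlflow=MlflowConfig(
                use=False,
                tracking_uri="http://localhost:5000",
                experiment_name="vambn-demo",
            ),
        ),
        device="cpu",
        optuna_db=None,
    ),
    optimization=OptimizationConfig(
        max_epochs=100,
        folds=5,
        n_modular_trials=50,
        n_traditional_trials=50,
        s_dim_lower=1, s_dim_upper=5, s_dim_step=1, fixed_s_dim=False,
        y_dim_lower=1, y_dim_upper=5, y_dim_step=1, fixed_y_dim=False,
        latent_dim_lower=2, latent_dim_upper=16, latent_dim_step=2,
        batch_size_lower_n=4, batch_size_upper_n=8,
        learning_rate_lower=1e-4, learning_rate_upper=1e-2,
        fixed_learning_rate=False,
        lstm_layers_lower=1, lstm_layers_upper=3, lstm_layers_step=1,
        use_relative_correlation_error_for_optimization=True,
        use_auc_for_optimization=False,
    ),
    training=TrainingConfig(
        use_imputation_layer=True,
        use_mtl=False,
        with_gan=False,
    ),
)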

exceptions

InvalidSamples

Bases: Exception

Exception type for invalid samples in the HIVAE.

Source code in vambn/utils/exceptions.py
class InvalidSamples(Exception):
    """Excpetion Type for invalid samples in the HIVAE."""

    pass

helpers

AggregatedMetric

Class to aggregate and compute average of float metrics.

Source code in vambn/utils/helpers.py
class AggregatedMetric:
    """Class to aggregate and compute average of float metrics."""

    def __init__(self) -> None:
        self._values = []

    def __add__(self, new_value: float) -> "AggregatedMetric":
        """Adds a new value to the metric list.

        Args:
            new_value: The new float value to add.

        Returns:
            self: The updated AggregatedMetric object.
        """
        self._values.append(new_value)
        return self

    def __call__(self) -> float:
        """Computes the average of the aggregated values.

        Returns:
            The average value of the aggregated metrics.
        """
        return sum(self._values) / len(self._values)

__add__(new_value)

Adds a new value to the metric list.

Parameters:

Name       Type   Description                  Default
new_value  float  The new float value to add.  required

Returns:

Name  Type              Description
self  AggregatedMetric  The updated AggregatedMetric object.
Source code in vambn/utils/helpers.py
def __add__(self, new_value: float) -> "AggregatedMetric":
    """Adds a new value to the metric list.

    Args:
        new_value: The new float value to add.

    Returns:
        self: The updated AggregatedMetric object.
    """
    self._values.append(new_value)
    return self

__call__()

Computes the average of the aggregated values.

Returns:

Type   Description
float  The average value of the aggregated metrics.

Source code in vambn/utils/helpers.py
def __call__(self) -> float:
    """Computes the average of the aggregated values.

    Returns:
        The average value of the aggregated metrics.
    """
    return sum(self._values) / len(self._values)
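
A short usage sketch: since __add__ returns the metric itself, values accumulate naturally with +=. Note that calling an empty metric divides by zero.

from vambn.utils.helpers import AggregatedMetric

loss_metric = AggregatedMetric()
for batch_loss in (0.9, 0.7, 0.5):
    loss_metric += batch_loss  # appends the value and returns the metric

print(loss_metric())  # ~0.7, the average of the accumulated values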

AggregatedTorchMetric

Class to aggregate and compute average of torch.Tensor metrics.

Source code in vambn/utils/helpers.py
class AggregatedTorchMetric:
    """Class to aggregate and compute average of torch.Tensor metrics."""

    def __init__(self) -> None:
        self._values = []

    def __add__(self, new_value: torch.Tensor) -> "AggregatedTorchMetric":
        """Adds a new tensor value to the metric list.

        Args:
            new_value: The new tensor value to add.

        Returns:
            self: The updated AggregatedTorchMetric object.
        """
        self._values.append(new_value)
        return self

    def __call__(self) -> torch.Tensor:
        """Computes the average of the aggregated tensor values.

        Returns:
            The average tensor value of the aggregated metrics.
        """
        return torch.mean(torch.stack(self._values))

__add__(new_value)

Adds a new tensor value to the metric list.

Parameters:

Name       Type    Description                   Default
new_value  Tensor  The new tensor value to add.  required

Returns:

Name  Type                   Description
self  AggregatedTorchMetric  The updated AggregatedTorchMetric object.

Source code in vambn/utils/helpers.py
def __add__(self, new_value: torch.Tensor) -> "AggregatedTorchMetric":
    """Adds a new tensor value to the metric list.

    Args:
        new_value: The new tensor value to add.

    Returns:
        self: The updated AggregatedTorchMetric object.
    """
    self._values.append(new_value)
    return self

__call__()

Computes the average of the aggregated tensor values.

Returns:

Type    Description
Tensor  The average tensor value of the aggregated metrics.

Source code in vambn/utils/helpers.py
def __call__(self) -> torch.Tensor:
    """Computes the average of the aggregated tensor values.

    Returns:
        The average tensor value of the aggregated metrics.
    """
    return torch.mean(torch.stack(self._values))
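
The tensor variant works the same way; a minimal sketch:

import torch

from vambn.utils.helpers import AggregatedTorchMetric

metric = AggregatedTorchMetric()
metric += torch.tensor(0.25)
metric += torch.tensor(0.75)
print(metric())  # tensor(0.5000)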

NaNHandlingStrategy

Bases: Enum

Enumeration of strategies for handling NaN values.

Source code in vambn/utils/helpers.py
class NaNHandlingStrategy(Enum):
    """Enumeration of strategies for handling NaN values."""

    accept_inbalance = "accept_inbalance"
    sample_random = "sample_random"
    sample_closest = "sample_closest"
    encode_nan = "encode_nan"

column_is_categorical(col)

Determines if a column is categorical.

Parameters:

Name  Type    Description           Default
col   Series  The column to check.  required

Returns:

Type  Description
bool  True if the column is categorical, False otherwise.

Source code in vambn/utils/helpers.py
def column_is_categorical(col: pd.Series) -> bool:
    """Determines if a column is categorical.

    Args:
        col: The column to check.

    Returns:
        True if the column is categorical, False otherwise.
    """
    if col.dtype == "object":
        return True
    # heuristic for encoded categorical values: the column sums to a whole
    # number and its maximum is below 10
    elif np.sum(np.asarray(col)) % 1 == 0 and np.max(np.asarray(col)) < 10:
        return True
    else:
        return False
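
Illustrative calls showing how the heuristic classifies a few columns:

import pandas as pd

from vambn.utils.helpers import column_is_categorical

print(column_is_categorical(pd.Series(["a", "b", "a"])))     # True (object dtype)
print(column_is_categorical(pd.Series([0, 1, 2, 1])))        # True (small integer codes)
print(column_is_categorical(pd.Series([12.5, 47.3, 33.1])))  # False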

delete_directory(dir_path)

Deletes a directory and all its contents.

Parameters:

Name      Type  Description                       Default
dir_path  Path  Path to the directory to delete.  required

Source code in vambn/utils/helpers.py
def delete_directory(dir_path: Path) -> None:
    """Deletes a directory and all its contents.

    Args:
        dir_path: Path to the directory to delete.
    """
    dir_path = Path(dir_path)
    if dir_path.exists():
        for item in dir_path.iterdir():
            if item.is_dir():
                # Recursively delete subdirectories: remove files first,
                # then the emptied directories, deepest paths first so
                # rmdir never hits a non-empty directory
                for subitem in sorted(item.rglob("*"), reverse=True):
                    if subitem.is_file():
                        subitem.unlink()
                    else:
                        subitem.rmdir()
                item.rmdir()
            else:
                # Delete files
                item.unlink()
        dir_path.rmdir()
    else:
        logger.warning("Directory does not exist")
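
Usage is a single call on a Path; a sketch with a hypothetical path:

from pathlib import Path

from vambn.utils.helpers import delete_directory

# A missing directory only triggers a warning log, no exception.
delete_directory(Path("tmp/run_artifacts"))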

encode_numerical_columns(patient_data)

Encodes non-numeric columns of a DataFrame with categorical values.

Parameters:

Name          Type       Description                       Default
patient_data  DataFrame  The DataFrame with patient data.  required

Returns:

Type       Description
DataFrame  A DataFrame with non-numeric columns encoded as categorical values.

Source code in vambn/utils/helpers.py
def encode_numerical_columns(patient_data: pd.DataFrame) -> pd.DataFrame:
    """Encodes non-numeric columns of a DataFrame with categorical values.

    Args:
        patient_data: The DataFrame with patient data.

    Returns:
        A DataFrame with non-numeric columns encoded as categorical values.
    """
    tmp = patient_data.copy()
    for column in tmp:
        if not is_numeric_dtype(tmp[column]):
            tmp[column] = tmp[column].astype("category")
            tmp[column] = tmp[column].cat.codes
    return tmp
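
A small sketch: object columns are replaced by their category codes (assigned in sorted category order), numeric columns pass through unchanged.

import pandas as pd

from vambn.utils.helpers import encode_numerical_columns

df = pd.DataFrame({"sex": ["m", "f", "f"], "age": [61, 58, 70]})
print(encode_numerical_columns(df))
#    sex  age
# 0    1   61
# 1    0   58
# 2    0   70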

get_normalized_vector_distance(vec1, vec2)

Computes the normalized distance between two vectors.

Parameters:

Name  Type     Description         Default
vec1  ndarray  The first vector.   required
vec2  ndarray  The second vector.  required

Returns:

Type   Description
float  The normalized distance between the two vectors.

Source code in vambn/utils/helpers.py
def get_normalized_vector_distance(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Computes the normalized distance between two vectors.

    Args:
        vec1: The first vector.
        vec2: The second vector.

    Returns:
        The normalized distance between the two vectors.
    """
    diff = np.abs(vec1 - vec2)
    upper = np.maximum.reduce([vec1, vec2])
    lower = np.minimum.reduce([vec1, vec2])
    # named to avoid shadowing the built-ins max, min, range, and sum
    value_range = np.abs(upper) + np.abs(lower)
    quotient = np.divide(diff, value_range)
    quotient[np.isnan(quotient)] = 0
    return np.sum(quotient)
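
A worked example: each element contributes |v1 - v2| divided by |max| + |min|, NaN quotients from zero-valued ranges are zeroed, and the contributions are summed.

import numpy as np

from vambn.utils.helpers import get_normalized_vector_distance

vec1 = np.array([1.0, 4.0, 0.0])
vec2 = np.array([3.0, 4.0, 0.0])
# Element-wise: |1 - 3| / (|3| + |1|) = 0.5, |4 - 4| / 8 = 0, and
# 0 / 0 = NaN, which is mapped to 0 before summing (NumPy may warn).
print(get_normalized_vector_distance(vec1, vec2))  # 0.5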

get_vector_to_mixed_matrix_distance(vec, matrix)

Computes the distance from a vector to a mixed-type matrix.

Parameters:

Name    Type     Description                     Default
vec     ndarray  The vector to compare.          required
matrix  ndarray  The matrix to compare against.  required

Returns:

Type     Description
ndarray  An array of distances.

Source code in vambn/utils/helpers.py
def get_vector_to_mixed_matrix_distance(
    vec: np.ndarray, matrix: np.ndarray
) -> np.ndarray:
    """Computes the distance from a vector to a mixed-type matrix.

    Args:
        vec: The vector to compare.
        matrix: The matrix to compare against.

    Returns:
        An array of distances.
    """
    # get mask for categorical columns
    cat_cols = [column_is_categorical(col) for col in matrix.T]
    # number of differing categorical values for each row of the matrix
    cat_distances = np.sum(matrix[:, cat_cols] != vec[cat_cols], axis=1)
    num_distances = np.array(
        [get_normalized_vector_distance(vec, row) for row in matrix]
    )
    return cat_distances + num_distances
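
A small illustrative sketch; the expected output assumes the mismatch-counting categorical term shown in the source above.

import numpy as np

from vambn.utils.helpers import get_vector_to_mixed_matrix_distance

# Column 0 is picked up as categorical by the heuristic above,
# column 1 as numerical.
matrix = np.array([[0.0, 10.0], [1.0, 50.0]])
vec = np.array([0.0, 30.0])

# Row 0 matches the categorical entry and is numerically closer, so it
# receives the smaller distance (NumPy may warn about 0/0 terms).
print(get_vector_to_mixed_matrix_distance(vec, matrix))  # [0.5  2.25]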

handle_nan_values(real, virtual, strategy=NaNHandlingStrategy.sample_random)

Handles NaN values in two dataframes according to the specified strategy.

Parameters:

Name      Type                 Description                                    Default
real      DataFrame            The real dataframe.                            required
virtual   DataFrame            The virtual dataframe.                         required
strategy  NaNHandlingStrategy  The strategy to use for handling NaN values.   sample_random

Returns:

Type                         Description
tuple[DataFrame, DataFrame]  A tuple containing the processed real and virtual dataframes.

Source code in vambn/utils/helpers.py
def handle_nan_values(
    real: pd.DataFrame,
    virtual: pd.DataFrame,
    strategy: NaNHandlingStrategy = NaNHandlingStrategy.sample_random,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Handles NaN values in two dataframes according to the specified strategy.

    Args:
        real: The real dataframe.
        virtual: The virtual dataframe.
        strategy: The strategy to use for handling NaN values.

    Returns:
        A tuple containing the processed real and virtual dataframes.
    """
    np.random.seed(42)

    # convert real and virtual to pandas if they are numpy arrays
    if isinstance(real, np.ndarray) or isinstance(real, list):
        real = pd.DataFrame(real)

    if isinstance(virtual, np.ndarray) or isinstance(virtual, list):
        virtual = pd.DataFrame(virtual)

    if not (real.isnull().values.any() or virtual.isnull().values.any()):
        return real, virtual

    if strategy == NaNHandlingStrategy.accept_inbalance:
        return real.dropna(), virtual.dropna()

    if strategy == NaNHandlingStrategy.sample_random:
        # Drop all rows with high NaN ratio
        real = real.dropna(axis=0, thresh=0.8 * real.shape[1])
        virtual = virtual.dropna(axis=0, thresh=0.8 * virtual.shape[1])

        # Drop columns with high NaN ratio
        real = real.dropna(axis=1, thresh=0.8 * real.shape[0])
        # use the same columns for real and virtual
        virtual = virtual[real.columns]

        # remove all rows containing NaN values
        real = real.dropna()
        virtual = virtual.dropna()
        # subsample in such a way that each dataframe has the same amount of rows
        if real.shape[0] < virtual.shape[0]:
            virtual = virtual.loc[
                np.random.choice(virtual.index, real.shape[0], replace=False)
            ]
        elif virtual.shape[0] < real.shape[0]:
            real = real.loc[
                np.random.choice(real.index, virtual.shape[0], replace=False)
            ]

    elif strategy == NaNHandlingStrategy.sample_closest:
        # remove all rows containing NaN values
        real = real.dropna()
        virtual = virtual.dropna()
        # order columns such that identical align in order
        real = real.reindex(sorted(real.columns), axis=1)
        virtual = virtual.reindex(sorted(virtual.columns), axis=1)
        # sample in the bigger set (either real or virtual) the data points that are most similar
        sample_idx = []
        if real.shape[0] < virtual.shape[0]:
            for a in real.to_numpy():
                distance = get_vector_to_mixed_matrix_distance(
                    a, virtual.to_numpy()
                )
                sample_idx.append(np.argmin(distance))
            virtual = virtual.iloc[sample_idx]
        elif virtual.shape[0] < real.shape[0]:
            for a in virtual.to_numpy():
                distance = get_vector_to_mixed_matrix_distance(
                    a, real.to_numpy()
                )
                sample_idx.append(np.argmin(distance))
            real = real.iloc[sample_idx]

    elif strategy == NaNHandlingStrategy.encode_nan:
        for col in real:
            sum_real_na = real[col].isna().sum()
            sum_virtual_na = virtual[col].isna().sum()
            sum_diff = abs(sum_virtual_na - sum_real_na)
            if sum_real_na > sum_virtual_na:
                # sample indices to replace with NaN
                replace_idx = rnd.sample(range(virtual.shape[0]), sum_diff)
                virtual.loc[replace_idx, col] = None
            elif sum_real_na < sum_virtual_na:
                # sample indices to replace with NaN
                replace_idx = rnd.sample(range(real.shape[0]), sum_diff)
                real.loc[replace_idx, col] = None
        # encode NaN once, after the NaN counts of all columns are balanced
        cat_col_bool = [
            column_is_categorical(col) for col in real.dropna().to_numpy().T
        ]
        cat_cols = np.array(real.columns)[cat_col_bool]
        num_cols = np.array(real.columns)[np.invert(cat_col_bool)]
        for col in cat_cols:
            real[col] = real[col].fillna(real[col].max() + 1)
            virtual[col] = virtual[col].fillna(virtual[col].max() + 1)
        for col in num_cols:
            real[col] = real[col].fillna(0)
            virtual[col] = virtual[col].fillna(0)

    return real, virtual
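
A minimal sketch of the sample_random strategy: high-NaN rows and columns are dropped, remaining NaN rows removed, and the larger frame subsampled to match the smaller one.

import numpy as np
import pandas as pd

from vambn.utils.helpers import NaNHandlingStrategy, handle_nan_values

real = pd.DataFrame({"a": [1.0, np.nan, 3.0, 4.0], "b": [0, 1, 0, 1]})
virtual = pd.DataFrame({"a": [1.5, 2.5, 3.5, 4.5], "b": [1, 0, 1, 0]})

real_out, virtual_out = handle_nan_values(
    real, virtual, strategy=NaNHandlingStrategy.sample_random
)
# Both frames come back NaN-free and subsampled to the same row count.
print(real_out.shape, virtual_out.shape)  # (3, 2) (3, 2)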

logging

setup_logging(level, log_file=None)

Set up logging to stdout or a specified file.

Parameters:

Name      Type            Description                                              Default
level     int             The logging level.                                       required
log_file  Optional[Path]  The file where logs should be saved. Defaults to None.   None

Source code in vambn/utils/logging.py
def setup_logging(level: int, log_file: Optional[Path] = None) -> None:
    """
    Set up logging to stdout or a specified file.

    Args:
        level (int): The logging level.
        log_file (Optional[Path], optional): The file where logs should be saved. Defaults to None.

    """
    if log_file is None:
        logging.basicConfig(
            stream=sys.stdout,
            level=level,
            format="%(asctime)s [%(levelname)s] %(message)s",
        )
        logger.info("Logging to stdout")
    else:
        log_file.parent.mkdir(exist_ok=True, parents=True)
        file_handler = RotatingFileHandler(
            str(log_file),
            maxBytes=10_000_000,
            backupCount=10,
        )
        # Create a format with the file name in the log
        formatter = logging.Formatter(
            "[%(asctime)s] {%(module)s:%(lineno)d} %(levelname)s - %(message)s",
            "%m-%d %H:%M:%S",
        )

        file_handler.setFormatter(formatter)
        logging.basicConfig(
            level=level,
            handlers=[file_handler],
            format="%(asctime)s [%(levelname)s] %(message)s",
        )

        try:
            pl_logger = logging.getLogger("lightning.pytorch")
            pl_logger.handlers.clear()
            pl_logger.setLevel(level=level)
            pl_logger.addHandler(file_handler)
            pl_logger.info("Logging Lightning to file")
        except Exception as e:
            logger.error(f"Could not setup Lightning logging: {e}")

        logger.info(f"Logging to {log_file}")

syndat_conversion

main(original_input, decoded_input, synthetic_input, original_output, decoded_output, synthetic_output)

Processes and pivots the original, decoded, and synthetic CSV files, then saves the results.

Parameters:

Name              Type  Description                                                         Default
original_input    Path  Path to the original input CSV file.                                required
decoded_input     Path  Path to the decoded input CSV file.                                 required
synthetic_input   Path  Path to the synthetic input CSV file.                               required
original_output   Path  Path where the processed original output CSV file will be saved.   required
decoded_output    Path  Path where the processed decoded output CSV file will be saved.    required
synthetic_output  Path  Path where the processed synthetic output CSV file will be saved.  required

Source code in vambn/utils/syndat_conversion.py
def main(
    original_input: Path,
    decoded_input: Path,
    synthetic_input: Path,
    original_output: Path,
    decoded_output: Path,
    synthetic_output: Path,
) -> None:
    """
    Processes and pivots the original, decoded, and synthetic CSV files, then saves the results.

    Args:
        original_input (Path): Path to the original input CSV file.
        decoded_input (Path): Path to the decoded input CSV file.
        synthetic_input (Path): Path to the synthetic input CSV file.
        original_output (Path): Path where the processed original output CSV file will be saved.
        decoded_output (Path): Path where the processed decoded output CSV file will be saved.
        synthetic_output (Path): Path where the processed synthetic output CSV file will be saved.
    """
    # Read input CSV files
    synthetic = pd.read_csv(synthetic_input)
    decoded = pd.read_csv(decoded_input)
    original = pd.read_csv(original_input)

    # Order by SUBJID and VISIT
    synthetic = synthetic.sort_values(["SUBJID", "VISIT"])
    decoded = decoded.sort_values(["SUBJID", "VISIT"])
    original = original.sort_values(["SUBJID", "VISIT"])

    # Get the common columns of synthetic and decoded
    common_cols = synthetic.columns.intersection(decoded.columns)
    # Subset all dataframes to the common columns
    synthetic = synthetic[common_cols]
    decoded = decoded[common_cols]
    original = original[common_cols]

    def pivot_to_wide(df: pd.DataFrame) -> pd.DataFrame:
        """
        Pivots a dataframe to wide format by SUBJID and VISIT.

        Args:
            df (pd.DataFrame): The input dataframe.

        Returns:
            pd.DataFrame: The pivoted dataframe.
        """
        pivot_df = df.pivot(index="SUBJID", columns="VISIT")
        # Add _VIS[visit_number] to the column names
        pivot_df.columns = [
            f"{col[0]}_VIS{str(col[1])}" for col in pivot_df.columns.values
        ]
        # Drop SUBJID and potentially VISIT
        pivot_df = pivot_df.reset_index().drop("SUBJID", axis=1)
        return pivot_df

    # Pivot the dataframes to wide format
    synthetic = pivot_to_wide(synthetic)
    decoded = pivot_to_wide(decoded)
    original = pivot_to_wide(original)

    # Save the processed data
    synthetic.to_csv(synthetic_output, index=False)
    decoded.to_csv(decoded_output, index=False)
    original.to_csv(original_output, index=False)
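
A usage sketch with hypothetical paths; each input CSV is assumed to contain SUBJID and VISIT columns alongside the measurement columns.

from pathlib import Path

from vambn.utils.syndat_conversion import main

main(
    original_input=Path("data/original.csv"),
    decoded_input=Path("data/decoded.csv"),
    synthetic_input=Path("data/synthetic.csv"),
    original_output=Path("out/original_wide.csv"),
    decoded_output=Path("out/decoded_wide.csv"),
    synthetic_output=Path("out/synthetic_wide.csv"),
)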

trial_counter

main(db_url, study_name)

Get the number of completed or pruned trials and return to bash.

Parameters:

Name        Type  Description                  Default
db_url      str   URL to the Optuna database.  required
study_name  str   Name of the study.           required

Source code in vambn/utils/trial_counter.py
def main(db_url: str, study_name: str) -> None:
    """
    Get the number of completed or pruned trials and return to bash.

    Args:
        db_url (str): URL to the Optuna database.
        study_name (str): Name of the study.
    """
    try:
        study = optuna.load_study(study_name=study_name, storage=db_url)
        completed_or_pruned_trials = sum(
            trial.state in [TrialState.COMPLETE, TrialState.PRUNED]
            for trial in study.trials
        )

        typer.echo(completed_or_pruned_trials)
    except KeyError:
        # In case the study is not found, output 0
        typer.echo(0)
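
A usage sketch with a hypothetical database URL and study name:

from vambn.utils.trial_counter import main

# Prints the number of COMPLETE or PRUNED trials, or 0 if the study
# does not exist in the given storage.
main(db_url="sqlite:///optuna.db", study_name="hivae_optimization")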