Visualization

calculate_metrics

calculate_auc(grouping, original_data, decoded_file, virtual_file, auc_file)

Calculate the Area Under the Curve (AUC) for the given datasets.

Parameters:

- grouping (Path, required): Path to the grouping CSV file.
- original_data (Path, required): Path to the original data CSV file.
- decoded_file (Path, required): Path to the decoded data CSV file.
- virtual_file (Path, required): Path to the virtual data CSV file.
- auc_file (Path, required): Path to save the AUC results CSV file.
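
A minimal usage sketch; all file names below are placeholders. Besides auc_file, the command also writes a specific_auc.csv with per-module, per-visit AUCs into the same directory:

from pathlib import Path

from vambn.visualization.calculate_metrics import calculate_auc

calculate_auc(
    grouping=Path("grouping.csv"),          # placeholder paths
    original_data=Path("original.csv"),
    decoded_file=Path("decoded.csv"),
    virtual_file=Path("virtual.csv"),
    auc_file=Path("results/auc.csv"),
)
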
Source code in vambn/visualization/calculate_metrics.py
@app.command()
def calculate_auc(
    grouping: Path,
    original_data: Path,
    decoded_file: Path,
    virtual_file: Path,
    auc_file: Path,
) -> None:
    """
    Calculate the Area Under the Curve (AUC) for the given datasets.

    Args:
        grouping (Path): Path to the grouping CSV file.
        original_data (Path): Path to the original data CSV file.
        decoded_file (Path): Path to the decoded data CSV file.
        virtual_file (Path): Path to the virtual data CSV file.
        auc_file (Path): Path to save the AUC results CSV file.
    """
    setup_logging(level=logging.INFO)

    grouping_df = pd.read_csv(grouping)
    module_cols = grouping_df.loc[
        ~grouping_df["technical_group_name"].str.match("stalone_"),
        "column_names",
    ]
    module_cols = set(module_cols.to_list())
    modules = set(grouping_df["technical_group_name"].to_list())
    aucs = []
    original_df = pd.read_csv(original_data)
    decoded_df = pd.read_csv(decoded_file)
    virtual_df = pd.read_csv(virtual_file)

    # Relevant columns
    relevant_columns = list(
        set(grouping_df["column_names"].tolist())
        .intersection(original_df.columns.tolist())
        .intersection(decoded_df.columns.tolist())
        .intersection(virtual_df.columns.tolist())
    ) + ["SUBJID", "VISIT"]

    # Order by visit
    original_df = original_df.sort_values("VISIT").loc[:, relevant_columns]
    decoded_df = decoded_df.sort_values("VISIT").loc[:, relevant_columns]
    virtual_df = virtual_df.sort_values("VISIT").loc[:, relevant_columns]

    # Sort and reindex
    original_df = original_df.reindex(sorted(original_df.columns), axis=1)
    decoded_df = decoded_df.reindex(sorted(decoded_df.columns), axis=1)
    virtual_df = virtual_df.reindex(sorted(virtual_df.columns), axis=1)

    assert original_df.shape[1] == decoded_df.shape[1]
    assert original_df.shape[1] == virtual_df.shape[1]

    original_base = original_df.loc[
        original_df["VISIT"] == 1, relevant_columns
    ].drop(columns=["VISIT", "SUBJID"])
    decoded_base = decoded_df.loc[
        decoded_df["VISIT"] == 1, relevant_columns
    ].drop(columns=["VISIT", "SUBJID"])
    virtual_base = virtual_df.loc[
        virtual_df["VISIT"] == 1, relevant_columns
    ].drop(columns=["VISIT", "SUBJID"])

    logger.info("Calculate AUC for all modules and baseline")
    logger.info("Calculate AUC for real vs decoded")
    pauc_decoded, auc_decoded, n_dec = get_auc(
        original_base, decoded_base, n_folds=5
    )
    logger.info("Calculate AUC for real vs virtual")
    pauc_virtual, auc_virtual, n_vir = get_auc(
        original_base, virtual_base, n_folds=5
    )
    logger.info("Calculate AUC for decoded vs virtual")
    pauc_virVdec, auc_virVdec, n_virVdec = get_auc(
        decoded_base, virtual_base, n_folds=5
    )
    aucs.append(
        {
            "module": "all-modules-baseline",
            "pauc_virtual": pauc_decoded,
            "pauc_decoded": pauc_virtual,
            "pauc_virVdec": pauc_virVdec,
            "auc_decoded": auc_decoded,
            "auc_virtual": auc_virtual,
            "auc_virVdec": auc_virVdec,
            "n_virtual": n_vir,
            "n_decoded": n_dec,
            "n_virVdec": n_virVdec,
        }
    )

    # Calculate per module AUC
    for module in modules:
        if "stalone" in module:
            continue
        logger.info(f"Calculating AUC for {module}")

        module_subset = grouping_df.loc[
            grouping_df["technical_group_name"] == module, "column_names"
        ]
        module_subset = set(module_subset.to_list())

        common_cols = module_subset.intersection(original_df.columns)
        common_cols = list(common_cols)
        if len(common_cols) == 0:
            continue

        subset_original = original_df.loc[:, common_cols + ["VISIT"]].dropna()
        max_visit = subset_original["VISIT"].max()
        subset_decoded = decoded_df.loc[
            decoded_df["VISIT"] <= max_visit, common_cols + ["VISIT"]
        ].dropna()
        subset_virtual = virtual_df.loc[
            virtual_df["VISIT"] <= max_visit, common_cols + ["VISIT"]
        ].dropna()
        assert (
            subset_original["VISIT"].unique().tolist()
            == subset_decoded["VISIT"].unique().tolist()
        )
        assert (
            subset_original["VISIT"].unique().tolist()
            == subset_virtual["VISIT"].unique().tolist()
        )

        for col in ["SUBJID", "VISIT"]:
            if col in subset_decoded.columns:
                subset_decoded.drop(col, axis=1, inplace=True)
            if col in subset_virtual.columns:
                subset_virtual.drop(col, axis=1, inplace=True)
            if col in subset_original.columns:
                subset_original.drop(col, axis=1, inplace=True)

        assert subset_original.isna().sum().sum() == 0
        assert subset_decoded.isna().sum().sum() == 0
        assert subset_virtual.isna().sum().sum() == 0
        assert (
            subset_original.columns.to_list()
            == subset_decoded.columns.to_list()
        )

        logger.info("Calculate AUC for real vs decoded")
        pauc_decoded, auc_decoded, n_dec = get_auc(
            subset_original, subset_decoded, n_folds=5
        )

        logger.info("Calculate AUC for real vs virtual")
        pauc_virtual, auc_virtual, n_vir = get_auc(
            subset_original, subset_virtual, n_folds=5
        )

        logger.info("Calculate AUC for decoded vs virtual")
        pauc_virVdec, auc_virVdec, n_virVdec = get_auc(
            subset_decoded, subset_virtual, n_folds=5
        )
        aucs.append(
            {
                "module": module,
                "pauc_virtual": auc_decoded,
                "pauc_decoded": auc_virtual,
                "pauc_virVdec": auc_virVdec,
                "n_virtual": n_vir,
                "n_decoded": n_dec,
                "n_virVdec": n_virVdec,
            }
        )
        logger.info(
            f"AUC for {module} - Decoded: {auc_decoded}, Virtual: {auc_virtual}, Decoded vs Virtual: {auc_virVdec}"
        )

    df = pd.DataFrame(aucs)
    auc_file.parent.mkdir(exist_ok=True, parents=True)
    df.to_csv(auc_file, index=False)

    # Calculate AUC for each module and visit individually
    specific_aucs = []
    for module in modules:
        if "stalone" in module:
            continue
        logger.info(f"Calculating AUC for {module}")
        module_subset = grouping_df.loc[
            grouping_df["technical_group_name"] == module, "column_names"
        ]
        module_subset = set(module_subset.to_list())

        common_cols = module_subset.intersection(original_df.columns)
        common_cols = list(common_cols)
        if len(common_cols) == 0:
            continue

        subset_original = original_df.loc[:, common_cols + ["VISIT"]].dropna()
        max_visit = subset_original["VISIT"].max()
        subset_decoded = decoded_df.loc[
            decoded_df["VISIT"] <= max_visit, common_cols + ["VISIT"]
        ].dropna()
        subset_virtual = virtual_df.loc[
            virtual_df["VISIT"] <= max_visit, common_cols + ["VISIT"]
        ].dropna()
        assert (
            subset_original["VISIT"].unique().tolist()
            == subset_decoded["VISIT"].unique().tolist()
        )
        assert (
            subset_original["VISIT"].unique().tolist()
            == subset_virtual["VISIT"].unique().tolist()
        )

        assert subset_original.isna().sum().sum() == 0
        assert subset_decoded.isna().sum().sum() == 0
        assert subset_virtual.isna().sum().sum() == 0
        assert (
            subset_original.columns.to_list()
            == subset_decoded.columns.to_list()
        )
        for visit in subset_original["VISIT"].unique():
            logger.info(f"Calculating AUC for {module} - Visit {visit}")
            subset_original_visit = subset_original.loc[
                subset_original["VISIT"] == visit
            ]
            subset_decoded_visit = subset_decoded.loc[
                subset_decoded["VISIT"] == visit
            ]
            subset_virtual_visit = subset_virtual.loc[
                subset_virtual["VISIT"] == visit
            ]

            subset_original_visit = drop_irrelevant_columns(
                subset_original_visit
            )
            subset_decoded_visit = drop_irrelevant_columns(subset_decoded_visit)
            subset_virtual_visit = drop_irrelevant_columns(subset_virtual_visit)

            pauc_decoded, auc_decoded, n_dec = get_auc(
                subset_original_visit, subset_decoded_visit, n_folds=5
            )
            pauc_virtual, auc_virtual, n_vir = get_auc(
                subset_original_visit, subset_virtual_visit, n_folds=5
            )
            pauc_virVdec, auc_virVdec, n_virVdec = get_auc(
                subset_decoded_visit, subset_virtual_visit, n_folds=5
            )

            specific_aucs.append(
                {
                    "module": module,
                    "visit": visit,
                    "pauc_virtual": pauc_decoded,
                    "pauc_decoded": pauc_virtual,
                    "pauc_virVdec": pauc_virVdec,
                    "n_virtual": n_vir,
                    "n_decoded": n_dec,
                    "n_virVdec": n_virVdec,
                    "auc_decoded": auc_decoded,
                    "auc_virtual": auc_virtual,
                    "auc_virVdec": auc_virVdec,
                }
            )

    df = pd.DataFrame(specific_aucs)
    specific_auc_path = auc_file.parent / "specific_auc.csv"
    df.to_csv(specific_auc_path, index=False)

calculate_corr_error(grouping, original_data, decoded_data, virtual_data, all_heatmap_virtual, all_heatmap_decoded, cont_heatmap_virtual, cont_heatmap_decoded, result_file, dataset_name, experiment)

Calculate the correlation error and generate heatmaps.

Parameters:

- grouping (Path, required): Path to the grouping CSV file.
- original_data (Path, required): Path to the original data CSV file.
- decoded_data (Path, required): Path to the decoded data CSV file.
- virtual_data (Path, required): Path to the virtual data CSV file.
- all_heatmap_virtual (Path, required): Path to save the heatmap for all virtual data.
- all_heatmap_decoded (Path, required): Path to save the heatmap for all decoded data.
- cont_heatmap_virtual (Path, required): Path to save the heatmap for continuous virtual data.
- cont_heatmap_decoded (Path, required): Path to save the heatmap for continuous decoded data.
- result_file (Path, required): Path to save the results JSON file.
- dataset_name (str, required): The name of the dataset.
- experiment (str, required): The name of the experiment.
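
A hedged invocation sketch; all paths and the dataset/experiment labels are placeholders:

from pathlib import Path

from vambn.visualization.calculate_metrics import calculate_corr_error

calculate_corr_error(
    grouping=Path("grouping.csv"),
    original_data=Path("original.csv"),
    decoded_data=Path("decoded.csv"),
    virtual_data=Path("virtual.csv"),
    all_heatmap_virtual=Path("plots/heatmap_all_virtual.png"),
    all_heatmap_decoded=Path("plots/heatmap_all_decoded.png"),
    cont_heatmap_virtual=Path("plots/heatmap_cont_virtual.png"),
    cont_heatmap_decoded=Path("plots/heatmap_cont_decoded.png"),
    result_file=Path("results/corr_metrics.json"),
    dataset_name="my_dataset",      # placeholder labels
    experiment="my_experiment",
)
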
Source code in vambn/visualization/calculate_metrics.py
@app.command()
def calculate_corr_error(
    grouping: Path,
    original_data: Path,
    decoded_data: Path,
    virtual_data: Path,
    all_heatmap_virtual: Path,
    all_heatmap_decoded: Path,
    cont_heatmap_virtual: Path,
    cont_heatmap_decoded: Path,
    result_file: Path,
    dataset_name: str,
    experiment: str,
) -> None:
    """
    Calculate the correlation error and generate heatmaps.

    Args:
        grouping (Path): Path to the grouping CSV file.
        original_data (Path): Path to the original data CSV file.
        decoded_data (Path): Path to the decoded data CSV file.
        virtual_data (Path): Path to the virtual data CSV file.
        all_heatmap_virtual (Path): Path to save the heatmap for all virtual data.
        all_heatmap_decoded (Path): Path to save the heatmap for all decoded data.
        cont_heatmap_virtual (Path): Path to save the heatmap for continuous virtual data.
        cont_heatmap_decoded (Path): Path to save the heatmap for continuous decoded data.
        result_file (Path): Path to save the results JSON file.
        dataset_name (str): The name of the dataset.
        experiment (str): The name of the experiment.
    """
    # Create necessary directories
    all_heatmap_virtual.parent.mkdir(exist_ok=True, parents=True)
    all_heatmap_decoded.parent.mkdir(exist_ok=True, parents=True)
    cont_heatmap_virtual.parent.mkdir(exist_ok=True, parents=True)
    cont_heatmap_decoded.parent.mkdir(exist_ok=True, parents=True)

    # Read and process grouping data
    groups = pd.read_csv(grouping)
    subset = groups.loc[
        ~groups["technical_group_name"].str.match("stalone_"), :
    ]
    continous_cols = set(
        subset.loc[
            subset["hivae_types"].isin(
                ["pos", "real", "truncate_norm", "count", "gamma"]
            ),
            "column_names",
        ].tolist()
    )
    module_cols = set(subset["column_names"].to_list())

    # Read data and check overlap with column sets
    initial = pd.read_csv(original_data)
    virtual = pd.read_csv(virtual_data)
    decoded = pd.read_csv(decoded_data)

    initial_cols = set(initial.columns.to_list())
    decoded_cols = set(decoded.columns.to_list())
    general_overlap = decoded_cols.intersection(initial_cols)
    all_vars_subset = list(general_overlap.intersection(module_cols))
    continous_subset = list(set(all_vars_subset).intersection(continous_cols))

    # Remove rows with NaN values
    initial_nan = initial.loc[:, all_vars_subset].isna().any(axis=1)
    subset_initial = initial.loc[~initial_nan, all_vars_subset]
    subset_virtual = virtual.loc[:, all_vars_subset].dropna()
    subset_decoded = decoded.loc[:, all_vars_subset].dropna()

    assert subset_initial.shape[1] == subset_virtual.shape[1]
    assert subset_initial.shape[1] == subset_decoded.shape[1]

    # SPEARMAN Correlation Error Calculation
    (
        spearman_corr_error_virtual,
        corr_initial,
        corr_virtual,
    ) = RelativeCorrelation.error(subset_initial, subset_virtual)

    assert (
        corr_initial.shape[1] == len(all_vars_subset)
    ), f"corr_initial.shape[1] = {corr_initial.shape[1]}, hivae_subset.__len__() = {len(all_vars_subset)}"
    assert (
        corr_virtual.shape[1] == len(all_vars_subset)
    ), f"corr_virtual.shape[1] = {corr_virtual.shape[1]}, hivae_subset.__len__() = {len(all_vars_subset)}"

    m_init = corr_initial.reset_index().melt("index")
    m_init["type"] = "Real"

    m_virtual = corr_virtual.reset_index().melt("index")
    m_virtual["type"] = "Virtual"

    merged = pd.concat([m_init, m_virtual])
    merged["value"] = merged["value"].astype(float)
    category_order = ["Real", "Virtual"]
    merged["type"] = pd.Categorical(
        merged["type"], categories=category_order, ordered=True
    )
    g = (
        ggplot(
            data=merged,
            mapping=aes(x="index", y="variable", fill="value"),
        )
        + geom_tile()
        + facet_wrap("type")
        + labs(
            title=f"Relative correlation error: {spearman_corr_error_virtual}",
            x="",
            y="",
        )
        + scale_fill_gradient2(
            low="darkblue", mid="lightgrey", high="darkred", midpoint=0
        )
        + theme_bw()
        + theme(axis_text=element_blank(), axis_ticks=element_blank())
    )
    g.save(
        str(all_heatmap_virtual), dpi=300, width=21.7, height=21.7, units="cm"
    )

    (
        spearman_corr_error_decoded,
        corr_initial,
        corr_decoded,
    ) = RelativeCorrelation.error(subset_initial, subset_decoded)

    m_init = corr_initial.reset_index().melt("index")
    m_init["type"] = "Real"

    m_decoded = corr_decoded.reset_index().melt("index")
    m_decoded["type"] = "Decoded"

    merged = pd.concat([m_init, m_decoded])
    merged["value"] = merged["value"].astype(float)
    category_order = ["Real", "Decoded"]
    merged["type"] = pd.Categorical(
        merged["type"], categories=category_order, ordered=True
    )

    g = (
        ggplot(
            data=merged,
            mapping=aes(x="index", y="variable", fill="value"),
        )
        + geom_tile()
        + facet_wrap("type")
        + labs(
            title=f"Relative correlation error: {spearman_corr_error_decoded}",
            x="",
            y="",
        )
        + scale_fill_gradient2(
            low="darkblue", mid="lightgrey", high="darkred", midpoint=0
        )
        + theme_bw()
        + theme(axis_text=element_blank(), axis_ticks=element_blank())
    )
    g.save(
        str(all_heatmap_decoded), dpi=300, width=21.7, height=21.7, units="cm"
    )

    # PEARSON Correlation Error Calculation for Continuous Data
    continous_initial = subset_initial.loc[:, continous_subset]
    continous_virtual = subset_virtual.loc[:, continous_subset]
    continous_decoded = subset_decoded.loc[:, continous_subset]

    (
        pearson_corr_error_virtual,
        corr_initial,
        corr_virtual,
    ) = RelativeCorrelation.error(
        continous_initial, continous_virtual, method="pearson"
    )

    assert (
        corr_initial.shape[1] == len(continous_subset)
    ), f"corr_initial.shape[1] = {corr_initial.shape[1]}, hivae_subset.__len__() = {len(continous_subset)}"
    assert (
        corr_virtual.shape[1] == len(continous_subset)
    ), f"corr_virtual.shape[1] = {corr_virtual.shape[1]}, hivae_subset.__len__() = {len(continous_subset)}"

    m_init = corr_initial.reset_index().melt("index")
    m_init["type"] = "Real"

    m_virtual = corr_virtual.reset_index().melt("index")
    m_virtual["type"] = "Virtual"

    merged = pd.concat([m_init, m_virtual])
    merged["value"] = merged["value"].astype(float)
    category_order = ["Real", "Virtual"]
    merged["type"] = pd.Categorical(
        merged["type"], categories=category_order, ordered=True
    )
    g = (
        ggplot(
            data=merged,
            mapping=aes(x="index", y="variable", fill="value"),
        )
        + geom_tile()
        + facet_wrap("type")
        + labs(
            title=f"Relative correlation error: {pearson_corr_error_virtual}",
            x="",
            y="",
        )
        + scale_fill_gradient2(
            low="darkblue", mid="lightgrey", high="darkred", midpoint=0
        )
        + theme_bw()
        + theme(axis_text=element_blank(), axis_ticks=element_blank())
    )
    g.save(
        str(cont_heatmap_virtual), dpi=300, width=21.7, height=21.7, units="cm"
    )

    (
        pearson_corr_error_decoded,
        corr_initial,
        corr_decoded,
    ) = RelativeCorrelation.error(
        # use Pearson here as well, matching the virtual comparison above
        continous_initial, continous_decoded, method="pearson"
    )

    m_init = corr_initial.reset_index().melt("index")
    m_init["type"] = "Real"

    m_decoded = corr_decoded.reset_index().melt("index")
    m_decoded["type"] = "Decoded"

    merged = pd.concat([m_init, m_decoded])
    merged["value"] = merged["value"].astype(float)
    category_order = ["Real", "Decoded"]
    merged["type"] = pd.Categorical(
        merged["type"], categories=category_order, ordered=True
    )

    g = (
        ggplot(
            data=merged,
            mapping=aes(x="index", y="variable", fill="value"),
        )
        + geom_tile()
        + facet_wrap("type")
        + labs(
            title=f"Relative correlation error: {pearson_corr_error_decoded}",
            x="",
            y="",
        )
        + scale_fill_gradient2(
            low="darkblue", mid="lightgrey", high="darkred", midpoint=0
        )
        + theme_bw()
        + theme(axis_text=element_blank(), axis_ticks=element_blank())
    )
    g.save(
        str(cont_heatmap_decoded), dpi=300, width=21.7, height=21.7, units="cm"
    )

    numeric_results = {
        "spearman_relcorr_virtual": spearman_corr_error_virtual,
        "spearman_relcorr_decoded": spearman_corr_error_decoded,
        "pearson_relcorr_virtual": pearson_corr_error_virtual,
        "pearson_relcorr_decoded": pearson_corr_error_decoded,
        "dataset": dataset_name,
        "experiment": experiment,
    }
    result_file.parent.mkdir(exist_ok=True, parents=True)
    with result_file.open("w+") as f:
        f.write(json.dumps(numeric_results))

drop_irrelevant_columns(df, cols=['SUBJID', 'VISIT'])

Drop irrelevant columns from the DataFrame.

Parameters:

- df (DataFrame, required): The input DataFrame.
- cols (List[str]): List of columns to drop. Defaults to ["SUBJID", "VISIT"].

Returns:

- DataFrame: The DataFrame with the specified columns dropped.

Source code in vambn/visualization/calculate_metrics.py
def drop_irrelevant_columns(
    df: pd.DataFrame, cols: List[str] = ["SUBJID", "VISIT"]
) -> pd.DataFrame:
    """
    Drop irrelevant columns from the DataFrame.

    Args:
        df (pd.DataFrame): The input DataFrame.
        cols (List[str], optional): List of columns to drop. Defaults to ["SUBJID", "VISIT"].

    Returns:
        pd.DataFrame: The DataFrame with specified columns dropped.
    """
    cols = [x for x in cols if x in df.columns]
    return df.drop(columns=cols)
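
For example (illustrative data). Note that cols is rebound inside the function, so the mutable default list is never modified in place:

import pandas as pd

df = pd.DataFrame({"SUBJID": [1, 2], "VISIT": [1, 1], "age": [70, 65]})
drop_irrelevant_columns(df)             # keeps only the "age" column
drop_irrelevant_columns(df, ["VISIT"])  # drops "VISIT" only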

generate_jsd_plot(grouping, original_data, comparison_file, jsd_plot)

Generate a Jensen-Shannon Distance plot.

Parameters:

- grouping (Path, required): Path to the grouping CSV file.
- original_data (Path, required): Path to the original data CSV file.
- comparison_file (Path, required): Path to the comparison data CSV file.
- jsd_plot (Path, required): Path to save the Jensen-Shannon Distance plot.
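
A minimal usage sketch (paths are placeholders); the comparison file is typically a decoded or virtual dataset sharing columns with the original:

from pathlib import Path

from vambn.visualization.calculate_metrics import generate_jsd_plot

generate_jsd_plot(
    grouping=Path("grouping.csv"),
    original_data=Path("original.csv"),
    comparison_file=Path("virtual.csv"),
    jsd_plot=Path("plots/jsd.png"),
)
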
Source code in vambn/visualization/calculate_metrics.py
@app.command()
def generate_jsd_plot(
    grouping: Path, original_data: Path, comparison_file: Path, jsd_plot: Path
) -> None:
    """
    Generate a Jensen-Shannon Distance plot.

    Args:
        grouping (Path): Path to the grouping CSV file.
        original_data (Path): Path to the original data CSV file.
        comparison_file (Path): Path to the comparison data CSV file.
        jsd_plot (Path): Path to save the Jensen-Shannon Distance plot.
    """
    jsd_plot.parent.mkdir(exist_ok=True, parents=True)

    grouping_df = pd.read_csv(grouping)
    module_cols = grouping_df.loc[
        ~grouping_df["technical_group_name"].str.match("stalone_"),
        "column_names",
    ]
    module_cols = set(module_cols.to_list())

    original_df = pd.read_csv(original_data)
    compared_df = pd.read_csv(comparison_file)

    original_cols = set(original_df.columns.to_list())
    compared_cols = set(compared_df.columns.to_list())
    general_overlap = compared_cols.intersection(original_cols)
    module_subset = list(general_overlap.intersection(module_cols))

    subset_original = original_df.loc[:, module_subset]
    subset_compared = compared_df.loc[:, module_subset]

    df_list = []
    for col in module_subset:
        orig_series = subset_original.loc[:, col]
        if orig_series.isna().any():
            logger.info(
                f"Warning: {col} contains NaN values (n = {orig_series.isna().sum()})"
            )
            orig_series.dropna(inplace=True)
        vec_original = orig_series.to_numpy()
        compared_series = subset_compared.loc[:, col]
        if compared_series.isna().any():
            logger.info(
                f"Warning: {col} contains NaN values (n = {compared_series.isna().sum()})"
            )
            compared_series.dropna(inplace=True)
        vec_compared = compared_series.to_numpy()

        dtype = grouping_df.loc[
            grouping_df["column_names"] == col, "hivae_types"
        ].tolist()[0]
        if dtype == "categorical":
            dtype = "cat"
        module = grouping_df.loc[
            grouping_df["column_names"] == col, "technical_group_name"
        ].tolist()[0]
        jsd = jensen_shannon_distance(vec_original, vec_compared, dtype)

        df_list.append(
            {"col": col, "type": dtype, "jsd": jsd, "module": module}
        )

    df = pd.DataFrame(df_list)
    g = (
        ggplot(data=df, mapping=aes(x="module", y="jsd"))
        + geom_boxplot()
        + geom_jitter(alpha=0.6, size=1, color="black")
        + labs(
            title="Distribution of JSDs",
            x="Module",
            y="Jensen-Shannon Distance",
        )
        + ylim(0, 1)
        + theme_bw()
        + theme(axis_text_x=element_text(angle=45, hjust=1, vjust=1))
    )
    g.save(jsd_plot)

distribution

compare_data(original_file, decoded_file, virtual_file, output_data, output_data_dec, metric_file, grouping, dataset_name, var)

Compare data from original, decoded, and virtual sources, and generate metrics and visualizations.

Parameters:

- original_file (Path, required): Path to the original data CSV file.
- decoded_file (Path, required): Path to the decoded data CSV file.
- virtual_file (Path, required): Path to the virtual data CSV file.
- output_data (Path, required): Path to save the PDF report of all data comparisons.
- output_data_dec (Path, required): Path to save the PDF report of original vs. decoded comparisons.
- metric_file (Path, required): Path to save the JSON file with Jensen-Shannon distances.
- grouping (Path, required): Path to the grouping CSV file.
- dataset_name (str, required): Name of the dataset.
- var (str, required): Experiment variable name.
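
A usage sketch with placeholder paths. Alongside the two PDF reports, per-column PNGs are written to all_data/ and original_vs_decoded/ folders next to the respective outputs:

from pathlib import Path

from vambn.visualization.distribution import compare_data

compare_data(
    original_file=Path("original.csv"),
    decoded_file=Path("decoded.csv"),
    virtual_file=Path("virtual.csv"),
    output_data=Path("reports/all_comparisons.pdf"),
    output_data_dec=Path("reports/original_vs_decoded.pdf"),
    metric_file=Path("results/jsd_metrics.json"),
    grouping=Path("grouping.csv"),
    dataset_name="my_dataset",   # placeholder labels
    var="my_experiment",
)
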
Source code in vambn/visualization/distribution.py
@app.command()
def compare_data(
    original_file: Path,
    decoded_file: Path,
    virtual_file: Path,
    output_data: Path,
    output_data_dec: Path,
    metric_file: Path,
    grouping: Path,
    dataset_name: str,
    var: str,
) -> None:
    """
    Compare data from original, decoded, and virtual sources, and generate metrics and visualizations.

    Args:
        original_file (Path): Path to the original data CSV file.
        decoded_file (Path): Path to the decoded data CSV file.
        virtual_file (Path): Path to the virtual data CSV file.
        output_data (Path): Path to save the PDF report of all data comparisons.
        output_data_dec (Path): Path to save the PDF report of original vs. decoded comparisons.
        metric_file (Path): Path to save the JSON file with Jensen-Shannon distances.
        grouping (Path): Path to the grouping CSV file.
        dataset_name (str): Name of the dataset.
        var (str): Experiment variable name.
    """
    groups = pd.read_csv(grouping)
    # Read in the data
    decoded_data = remove_vis(pd.read_csv(decoded_file))
    virtual_data = pd.read_csv(virtual_file)

    # Read original data
    original_data = pd.read_csv(original_file, index_col=0).reset_index()

    # Reduce dataframes to overlapping columns
    common_cols = list(
        set(original_data.columns)
        & set(decoded_data.columns)
        & set(virtual_data.columns)
    )
    original_data = original_data[common_cols]
    decoded_data = decoded_data[common_cols]
    virtual_data = virtual_data[common_cols]

    # Clip inf values to 1e20 and -1e20
    decoded_data = decoded_data.replace([np.inf, -np.inf], [1e20, -1e20])
    virtual_data = virtual_data.replace([np.inf, -np.inf], [1e20, -1e20])

    jsd_list = []

    output_data.parent.mkdir(exist_ok=True, parents=True)
    # Open the pdf file
    all_folder = output_data.parent / "all_data"
    all_folder.mkdir(exist_ok=True, parents=True)
    with PdfPages(output_data) as pdf:
        for col in common_cols:
            if col in ["SUBJID", "VISIT"]:
                continue

            dtype = groups.loc[
                groups["column_names"] == col, "hivae_types"
            ].tolist()[0]
            sub_orig = (
                original_data.reset_index()
                .loc[:, ["SUBJID", "VISIT", col]]
                .dropna()
            )
            max_visit = sub_orig["VISIT"].max()
            orig = sub_orig[col].tolist()
            dec = decoded_data.reset_index().loc[
                decoded_data["VISIT"] <= max_visit, ["SUBJID", "VISIT", col]
            ]
            dec.rename(columns={col: f"{col}_dec"}, inplace=True)
            dec = (
                pd.merge(sub_orig, dec, on=["SUBJID", "VISIT"], how="left")
                .drop(columns=col)[f"{col}_dec"]
                .tolist()
            )
            vir = virtual_data.loc[
                virtual_data["VISIT"] <= max_visit, col
            ].tolist()

            # Identify outliers
            lower_bound = min(orig) * 0.5 if min(orig) < 0 else min(orig) * 1.5
            if lower_bound == 0:
                lower_bound = -0.5

            max_length = max(len(orig), len(dec), len(vir))
            orig = pad_with_nan(orig, max_length)
            dec = pad_with_nan(dec, max_length)
            vir = pad_with_nan(vir, max_length)

            wide = pd.DataFrame(
                {
                    "Real (original)": orig,
                    "Decoded": dec,
                    "Virtual": vir,
                }
            )
            plot_df = wide.melt().rename(
                columns={"variable": "type", "value": "value"}
            )

            dec_jsd = round(jensen_shannon_distance(orig, dec, dtype), 3)
            vir_jsd = round(jensen_shannon_distance(orig, vir, dtype), 3)

            jsd_list.append(
                {
                    "column": col,
                    "jsd_decoded": dec_jsd,
                    "jsd_virtual": vir_jsd,
                    "dataset_var": dataset_name,
                    "experiment": var,
                }
            )

            title = (
                f"Distribution for {col}\n Decoded {dec_jsd}; Virtual {vir_jsd}"
            )

            plt.figure()
            if dtype in ["cat", "categorical", "count"]:
                plot_df_freq = (
                    plot_df.groupby("type")["value"]
                    .value_counts(normalize=True)
                    .rename("frequency")
                    .reset_index()
                )

                # Plot using sns.barplot
                sns.barplot(
                    x="value",
                    y="frequency",
                    hue="type",
                    data=plot_df_freq,
                    dodge=True,
                )
                plt.xlabel(col)
                plt.ylabel("Frequency")
                plt.title(title)
                plt.legend()

            else:
                print(
                    f"Plotting {col}, min value in plot_df: {plot_df['value'].min()}, max value in plot_df: {plot_df['value'].max()}"
                )
                plt.subplot(211)
                sns.violinplot(x="type", y="value", data=plot_df.dropna())
                plt.ylabel("Value")
                plt.title(title)
                plt.legend()

                plt.subplot(212)
                plt.axis("off")
                plt.table(
                    cellText=wide.describe().round(1).values,
                    colLabels=wide.columns,
                    rowLabels=wide.describe().index,
                    loc="center",
                )

            plt.tight_layout()

            pdf.savefig()
            plt.savefig(all_folder / f"{col}.png", dpi=300)

            plt.close()

    # Open the pdf file
    indiv_folder = output_data_dec.parent / "original_vs_decoded"
    indiv_folder.mkdir(exist_ok=True, parents=True)
    with PdfPages(output_data_dec) as pdf:
        for col in common_cols:
            if col in ["SUBJID", "VISIT"]:
                continue

            dtype = groups.loc[
                groups["column_names"] == col, "hivae_types"
            ].tolist()[0]
            sub_orig = (
                original_data.reset_index()
                .loc[:, ["SUBJID", "VISIT", col]]
                .dropna()
            )
            max_visit = sub_orig["VISIT"].max()
            orig = sub_orig[col].tolist()
            dec = decoded_data.reset_index().loc[
                decoded_data["VISIT"] <= max_visit, ["SUBJID", "VISIT", col]
            ]
            dec.rename(columns={col: f"{col}_dec"}, inplace=True)
            dec = (
                pd.merge(sub_orig, dec, on=["SUBJID", "VISIT"], how="left")
                .drop(columns=col)[f"{col}_dec"]
                .tolist()
            )

            # Identify outliers
            lower_bound = min(orig) * 0.5 if min(orig) < 0 else min(orig) * 1.5
            if lower_bound == 0:
                lower_bound = -0.5

            max_length = max(len(orig), len(dec))
            orig = pad_with_nan(orig, max_length)
            dec = pad_with_nan(dec, max_length)

            plt.figure()

            wide = pd.DataFrame(
                {
                    "Real (original)": orig,
                    "Decoded": dec,
                }
            )
            plot_df = wide.melt().rename(
                columns={"variable": "type", "value": "value"}
            )

            dec_jsd = round(jensen_shannon_distance(orig, dec, dtype), 3)
            try:
                orig, dec = handle_nan_values(orig, dec)
                orig = orig.iloc[:, 0].to_numpy()
                dec = dec.iloc[:, 0].to_numpy()

                if np.isinf(dec).any():
                    # dec is a NumPy array here, so pandas .replace is not
                    # available; clip infinities with np.where instead
                    dec = np.where(np.isinf(dec), np.sign(dec) * 1e20, dec)
                if dtype in ("pos", "real", "count", "truncate_norm", "gamma"):
                    corr, pval = pearsonr(orig, dec)
                    title = f"Distribution for {col}\n JSD: {dec_jsd}; Correlation: {round(corr, 3)} / {round(pval, 4)}, type: {dtype}"
                elif dtype == "cat" or dtype == "categorical":
                    # calculate accuracy for categorical data
                    acc = accuracy_score(orig, dec)
                    title = f"Distribution for {col}\n JSD: {dec_jsd}; Accuracy: {round(acc, 3)}, type: {dtype}"
                else:
                    raise Exception(f"Unknown dtype: {dtype}")
            except ValueError:
                raise ValueError(
                    f"Orig: {len(orig)}, Dec: {len(dec)}, any nan? {pd.isna(orig).any()} {pd.isna(dec).any()}"
                )

            if dtype in ["cat", "categorical", "count"]:
                plot_df_freq = (
                    plot_df.groupby("type")["value"]
                    .value_counts(normalize=True)
                    .rename("frequency")
                    .reset_index()
                )

                # Plot using sns.barplot
                ax = sns.barplot(
                    x="value", y="frequency", hue="type", data=plot_df_freq
                )
                plt.xlabel(col)
                plt.ylabel("Frequency")
                plt.title(title)
                plt.legend()
            else:
                # make a subfigure for both the violin plot and the table
                ax = plt.subplot(211)

                # use the first subplot for the violin plot
                ax = sns.violinplot(x="type", y="value", data=plot_df)
                plt.ylabel("Value")
                plt.title(title)
                plt.legend()

                # use the second subplot for the table
                ax = plt.subplot(212)
                ax.axis("off")
                ax.table(
                    cellText=wide.describe().round(1).values,
                    colLabels=wide.columns,
                    rowLabels=wide.describe().index,
                    loc="center",
                )

            plt.tight_layout()
            plt.savefig(indiv_folder / f"{col}.png", dpi=300)
            pdf.savefig()
            plt.close()

    metric_file.parent.mkdir(exist_ok=True, parents=True)
    with metric_file.open("w+") as f:
        f.write(json.dumps(jsd_list, indent=4))

pad_with_nan(x, num)

Pad a list with NaN values to a specified length.

Parameters:

- x (List[Any], required): The input list.
- num (int, required): The desired length of the list.

Returns:

- List[Any]: The padded list.

Source code in vambn/visualization/distribution.py
34
35
36
37
38
39
40
41
42
43
44
45
def pad_with_nan(x: List[Any], num: int) -> List[Any]:
    """
    Pad a list with NaN values to a specified length.

    Args:
        x (List[Any]): The input list.
        num (int): The desired length of the list.

    Returns:
        List[Any]: The padded list.
    """
    return x + [np.nan] * (num - len(x))
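
For example:

from vambn.visualization.distribution import pad_with_nan

pad_with_nan([1.0, 2.0], 4)  # -> [1.0, 2.0, nan, nan]
pad_with_nan([1.0, 2.0], 2)  # unchanged -> [1.0, 2.0]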

remove_vis(data)

Remove '_VIS1' from column names in the DataFrame.

Parameters:

- data (DataFrame, required): The input DataFrame.

Returns:

- DataFrame: The DataFrame with modified column names.

Source code in vambn/visualization/distribution.py
20
21
22
23
24
25
26
27
28
29
30
31
def remove_vis(data: pd.DataFrame) -> pd.DataFrame:
    """
    Remove '_VIS1' from column names in the DataFrame.

    Args:
        data (pd.DataFrame): The input DataFrame.

    Returns:
        pd.DataFrame: The DataFrame with modified column names.
    """
    data.columns = [x.replace("_VIS1", "") for x in data.columns]
    return data
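
For example. Note that the function renames the columns of the input DataFrame in place and returns the same object:

import pandas as pd

from vambn.visualization.distribution import remove_vis

df = pd.DataFrame(columns=["AGE_VIS1", "SUBJID"])
list(remove_vis(df).columns)  # ['AGE', 'SUBJID']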

generate_jsd_plot

main(results_file, output_file)

Generate a boxplot with jitter from a results CSV file and save it as an image.

Parameters:

- results_file (Path, required): Path to the input CSV file containing results.
- output_file (Path, required): Path to save the output image file.
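
A minimal usage sketch (paths are placeholders); the input CSV is expected to provide at least module_name, jsd, and column columns:

from pathlib import Path

from vambn.visualization.generate_jsd_plot import main

main(Path("results/jsd_results.csv"), Path("plots/jsd_boxplot.png"))
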
Source code in vambn/visualization/generate_jsd_plot.py
def main(results_file: Path, output_file: Path) -> None:
    """
    Generate a boxplot with jitter from a results CSV file and save it as an image.

    Args:
        results_file (Path): Path to the input CSV file containing results.
        output_file (Path): Path to save the output image file.
    """
    assert output_file.parent.exists(), "Output directory does not exist."

    plot_df = pd.read_csv(str(results_file))

    g = (
        ggplot(plot_df, aes("module_name", "jsd", color="column"))
        + geom_boxplot()
        + geom_jitter(alpha=0.6, size=1, color="black")
        + coord_flip()
        + theme_bw()
    )
    g.save(str(output_file), dpi=300)

generate_optuna_plots

plot_study_results(study_uri, study_name, output_folder)

Load the Optuna study from the SQLite database and save various graphics about the conducted study.

Parameters:

- study_uri (str, required): The URI of the SQLite database where the study is stored.
- study_name (str, required): The name of the study to load.
- output_folder (Path, required): The directory where the plots will be saved.

Raises:

- RuntimeError: If no trials are found in the study or if the variance of trials is zero.
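
A usage sketch; the storage URI and study name below are placeholders:

from pathlib import Path

from vambn.visualization.generate_optuna_plots import plot_study_results

plot_study_results(
    study_uri="sqlite:///optuna.db",  # placeholder database
    study_name="my_study",
    output_folder=Path("plots/optuna"),
)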

Source code in vambn/visualization/generate_optuna_plots.py
@app.command()
def plot_study_results(study_uri: str, study_name: str, output_folder: Path):
    """
    Load the Optuna study from the SQLite database and save various graphics about the conducted study.

    Args:
        study_uri (str): The URI of the SQLite database where the study is stored.
        study_name (str): The name of the study to load.
        output_folder (Path): The directory where the plots will be saved.

    Raises:
        RuntimeError: If no trials are found in the study or if the variance of trials is zero.
    """
    # Load study
    study = optuna.load_study(study_name=study_name, storage=study_uri)

    # Ensure the output folder exists
    output_folder.mkdir(parents=True, exist_ok=True)

    # Check if it's a single or multi-objective study
    is_multi_obj = len(study.directions) > 1

    # Plot param importances
    try:
        fig = V.plot_param_importances(study)
        fig.write_image(os.path.join(output_folder, "param_importances.png"))

        if not is_multi_obj:
            # Plot optimization history
            fig = V.plot_optimization_history(study)
            fig.write_image(
                os.path.join(output_folder, "optimization_history.png")
            )

            # Plot parallel coordinate
            fig = V.plot_parallel_coordinate(study)
            fig.write_image(
                os.path.join(output_folder, "parallel_coordinate.png")
            )

            # Plot slice
            fig = V.plot_slice(study)
            fig.write_image(os.path.join(output_folder, "slice.png"))
        else:
            fig = V.plot_pareto_front(study)
            fig.write_image(os.path.join(output_folder, "pareto_front.png"))

            # Generate the other plots per target
            for i in range(len(study.directions)):

                def target(trial):
                    return trial.values[i]

                target_name = (
                    study.metric_names[i]
                    if study.metric_names
                    else f"Objective {i}"
                )

                # Plot optimization history
                fig = V.plot_optimization_history(
                    study, target=target, target_name=target_name
                )
                fig.write_image(
                    os.path.join(output_folder, f"optimization_history_{i}.png")
                )

                # Plot parallel coordinate
                fig = V.plot_parallel_coordinate(
                    study, target=target, target_name=target_name
                )
                fig.write_image(
                    os.path.join(output_folder, f"parallel_coordinate_{i}.png")
                )

                # Plot slice
                fig = V.plot_slice(
                    study, target=target, target_name=target_name
                )
                fig.write_image(os.path.join(output_folder, f"slice_{i}.png"))

        typer.echo(f"Plots saved in {output_folder}.")
    except RuntimeError:
        logger.error("No trials found in study or variance equals 0.")

generate_results_plot

merge_results(grouping_file, input_files, output_csv, output_plot)

Merge results from multiple JSON and CSV files and generate summary plots.

Parameters:

- grouping_file (Path, required): The path to the CSV file containing column name mappings.
- input_files (List[Path], required): A list of input JSON and CSV files to process.
- output_csv (Path, required): The path to the output CSV file where merged results will be saved.
- output_plot (Path, required): The path to the output plot file where the generated plot will be saved.

Raises:

- ValueError: If there are no input files or if a column has multiple mappings.
- Exception: If JSON files contain unsupported data types or lists.
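
A hedged invocation sketch. JSON inputs must be named jsd_metrics.json or corr_metrics.json (anything else raises a ValueError), and the variant/experiment labels are parsed from the parent directory names; all paths below are placeholders:

from pathlib import Path

from vambn.visualization.generate_results_plot import merge_results

inputs = (
    list(Path("results").rglob("jsd_metrics.json"))
    + list(Path("results").rglob("corr_metrics.json"))
    + list(Path("results").rglob("*.csv"))  # per-experiment AUC tables
)
merge_results(
    grouping_file=Path("grouping.csv"),
    input_files=inputs,
    output_csv=Path("summary/merged_results.csv"),
    output_plot=Path("summary/quality_scores.png"),
)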

Source code in vambn/visualization/generate_results_plot.py
@app.command()
def merge_results(
    grouping_file: Path,
    input_files: List[Path],
    output_csv: Path,
    output_plot: Path,
):
    """
    Merge results from multiple JSON and CSV files and generate summary plots.

    Args:
        grouping_file (Path): The path to the CSV file containing column name mappings.
        input_files (List[Path]): A list of input JSON and CSV files to process.
        output_csv (Path): The path to the output CSV file where merged results will be saved.
        output_plot (Path): The path to the output plot file where the generated plot will be saved.

    Raises:
        ValueError: If there are no input files or if a column has multiple mappings.
        Exception: If JSON files contain unsupported data types or lists.
    """
    json_files = [x for x in input_files if x.suffix == ".json"]
    csv_files = [x for x in input_files if x.suffix == ".csv"]
    grouping = pd.read_csv(grouping_file)
    # Derive mapping from grouping file
    colname_map = {
        column: grouping.loc[
            grouping["column_names"] == column, "technical_group_name"
        ]
        .drop_duplicates()
        .tolist()
        for column in grouping["column_names"]
    }
    for key, value in colname_map.items():
        if len(value) > 1:
            raise ValueError(f"Column {key} has multiple mappings {value}")
        else:
            colname_map[key] = value[0]

    if len(json_files) == 0 and len(csv_files) == 0:
        raise ValueError("No input files found")

    corr_metrics = []
    jsd_metrics = []
    for file in tqdm(json_files, desc="Reading JSON files"):
        with file.open("r") as f:
            obj = json.loads(f.read())
            if file.stem == "jsd_metrics":
                tmp = pd.DataFrame(obj)
                tmp.rename(columns={"dataset_var": "dataset"}, inplace=True)
                if "modular" in file.parent.stem:
                    tmp["variant"] = "-".join(file.parent.stem.split("_")[:2])
                else:
                    tmp["variant"] = file.parent.stem.split("_")[0]

                tmp["module"] = tmp["column"].map(lambda x: colname_map[x])

                jsd_metrics.append(tmp)
            elif file.stem == "corr_metrics":
                if "modular" in file.parent.stem:
                    obj["variant"] = "-".join(file.parent.stem.split("_")[:2])
                else:
                    obj["variant"] = file.parent.stem.split("_")[0]

                # Ensure that the values are floats and not list
                for key, value in obj.items():
                    if isinstance(value, list) and len(value) > 1:
                        raise Exception(
                            "The JSON file contains lists, which is not supported"
                        )
                    elif isinstance(value, list) and len(value) == 1:
                        obj[key] = value[0]
                    elif isinstance(value, float):
                        pass
                    elif isinstance(value, str):
                        pass
                    else:
                        raise Exception(
                            f"Unsupported type {type(value)} for key {key}"
                        )

                corr_metrics.append(obj)
            else:
                raise ValueError(f"Unknown JSON file {file.stem}")

    corr_df = pd.DataFrame(corr_metrics)
    # Reduce the list of jsd dataframes
    jsd_df = reduce(lambda x, y: pd.concat([x, y]), jsd_metrics)
    agg_jsd = (
        jsd_df.groupby(["dataset", "variant", "experiment", "module"])
        .aggregate(
            {"jsd_virtual": ["mean", "std"], "jsd_decoded": ["mean", "std"]}
        )
        .reset_index()
    )
    # agg jsd has two levels of columns, we need to flatten it
    agg_jsd.columns = ["_".join(x).strip("_") for x in agg_jsd.columns.ravel()]

    csv_objects = []
    for file in tqdm(csv_files, desc="Reading CSV files"):
        tmp = pd.read_csv(file)
        dataset_name = file.parent.stem.split("_")[-3]
        if "modular" in file.parent.stem:
            variant = "-".join(file.parent.stem.split("_")[:2])
        else:
            variant = file.parent.stem.split("_")[0]

        tmp["experiment"] = "_".join(file.parent.stem.split("_")[-2:])

        tmp["dataset"] = dataset_name
        tmp["variant"] = variant
        csv_objects.append(tmp)

    auc_df = reduce(lambda x, y: pd.concat([x, y]), csv_objects)
    module_aucs = auc_df.loc[
        auc_df["module"] != "all-modules-baseline", :
    ].drop(columns=["pauc_virVdec", "n_virtual", "n_decoded", "n_virVdec"])
    auc_df = auc_df.loc[auc_df["module"] == "all-modules-baseline", :].drop(
        columns=["pauc_virVdec", "n_virtual", "n_decoded", "n_virVdec"]
    )
    # Aggregate results
    reshaped = agg_jsd.melt(
        id_vars=["dataset", "variant", "module", "experiment"],
        var_name="metric",
        value_name="value",
    )
    merged_auc_jsd = pd.concat(
        [
            reshaped.loc[
                reshaped["metric"].isin(
                    ["jsd_virtual_mean", "jsd_decoded_mean"]
                ),
                :,
            ],
            module_aucs.melt(
                id_vars=["dataset", "variant", "module", "experiment"],
                var_name="metric",
                value_name="value",
            ),
        ]
    )

    aggregate_over_modules = (
        reshaped.groupby(["dataset", "variant", "experiment", "metric"])
        .aggregate({"value": ["mean", "std"]})
        .reset_index()
    )
    auc_df_reshaped = auc_df.drop(columns="module").melt(
        id_vars=["dataset", "variant", "experiment"],
        var_name="metric",
        value_name="value",
    )

    aggregate_over_modules.columns = [
        "_".join(x).strip("_") for x in aggregate_over_modules.columns.ravel()
    ]
    aggregate_over_modules.rename(
        columns={"value_mean": "value", "value_std": "std"}, inplace=True
    )

    aggregate_over_modules = pd.concat(
        [aggregate_over_modules, auc_df_reshaped]
    )

    corr_df_reshaped = corr_df.melt(
        id_vars=["dataset", "variant", "experiment"],
        var_name="metric",
        value_name="value",
    )

    # Merge the two dataframes
    merged = pd.concat([corr_df_reshaped, aggregate_over_modules])

    def _assign_type(x: str) -> str:
        if "pearson_relcorr" in x:
            return "pearson-corr"
        elif "spearman_relcorr" in x:
            return "spearman-corr"
        elif "jsd" in x:
            return "jsd"
        elif "auc" in x:
            return "auc"
        else:
            raise ValueError(f"Unknown metric type {x}")

    merged_auc_jsd["metric_type"] = merged_auc_jsd["metric"].map(_assign_type)
    merged["metric_type"] = merged["metric"].map(_assign_type)
    # Drop metrics with "_std" suffix
    merged = merged[~merged["metric"].str.contains("_std")]

    merged["data_type"] = merged["metric"].map(
        lambda x: "virtual" if "virtual" in x else "decoded"
    )
    merged_auc_jsd["data_type"] = merged_auc_jsd["metric"].map(
        lambda x: "virtual" if "virtual" in x else "decoded"
    )

    # Normalize the metrics
    def _normalize(x: pd.Series) -> pd.Series:
        if pd.isna(x["value"]) or x["value"] == "NA" or x["value"] == "NaN":
            return math.nan
        elif (
            x["metric_type"] == "pearson-corr"
            or x["metric_type"] == "spearman-corr"
        ):
            x_mod = min(x["value"], 1)
            return math.floor((1 - x_mod) * 100)
        elif x["metric_type"] == "jsd":
            return math.floor((1 - x["value"]) * 100)
        elif x["metric_type"] == "auc":
            auc = x["value"]
            if auc < 0.5:
                auc = 1 - auc

            return max(math.floor((1 - auc) * 200), 1)
        else:
            raise ValueError(f"Unknown metric type {x['metric_type']}")

    merged["normalized_value"] = merged.apply(_normalize, axis=1)
    merged_auc_jsd["normalized_value"] = merged_auc_jsd.apply(
        _normalize, axis=1
    )

    # Remove spearman correlation and rename pearson correlation
    merged = merged[~merged["metric_type"].str.contains("spearman-corr")]
    merged["metric_type"] = merged["metric_type"].map(
        lambda x: "norm" if x == "pearson-corr" else x
    )
    merged.to_csv(output_csv, index=False)

    merged["overall_variant"] = merged["variant"] + "-" + merged["experiment"]
    merged_auc_jsd["overall_variant"] = (
        merged_auc_jsd["variant"] + "-" + merged_auc_jsd["experiment"]
    )

    # Plot the results
    plot = (
        ggplot(
            merged, aes(x="metric_type", y="normalized_value", fill="data_type")
        )
        + geom_bar(stat="identity", position="dodge")
        + facet_wrap("~overall_variant", scales="free", ncol=4)
        + labs(x="Metric", y="Quality Score", fill="Data type")
        + ylim(0, 100)
        + theme_bw()
        + theme(
            axis_text_x=element_text(angle=45, hjust=1),
            strip_text_x=element_text(size=8),
            legend_position="top",
        )
    )
    # One facet row per four variants; round up and avoid a zero-height figure
    nrow = max(math.ceil(len(merged["overall_variant"].unique()) / 4), 1)
    height = 8 * nrow

    plot.save(output_plot, width=29, height=height, units="cm", limitsize=False)

    plot = (
        ggplot(
            merged_auc_jsd,
            aes(x="metric_type", y="normalized_value", fill="data_type"),
        )
        + geom_boxplot(
            position="dodge",
            outlier_alpha=0.1,
            outlier_shape=".",
            outlier_size=1,
        )
        + facet_wrap("~overall_variant", scales="free", ncol=4)
        + labs(x="Metric", y="Quality Score", fill="Data type")
        + ylim(0, 100)
        + theme_bw()
        + theme(
            axis_text_x=element_text(angle=45, hjust=1),
            strip_text_x=element_text(size=8),
            legend_position="top",
        )
    )
    nrow = max(math.ceil(len(merged_auc_jsd["overall_variant"].unique()) / 4), 1)
    height = 8 * nrow
    plot.save(
        output_plot.with_name(output_plot.stem + "_boxplot.png"),
        width=29,
        height=height,
        units="cm",
        limitsize=False,
    )
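
For orientation, the normalization above maps every metric onto a shared 0-100 quality scale where higher is better. The following minimal sketch repeats the same arithmetic outside the DataFrame plumbing; quality_score is a hypothetical helper and the input values are illustrative only.

import math

def quality_score(value: float, metric_type: str) -> int:
    # Hypothetical helper mirroring _normalize above (illustrative only)
    if metric_type in ("pearson-corr", "spearman-corr"):
        return math.floor((1 - min(value, 1)) * 100)
    if metric_type == "jsd":
        return math.floor((1 - value) * 100)
    if metric_type == "auc":
        auc = 1 - value if value < 0.5 else value
        return max(math.floor((1 - auc) * 200), 1)
    raise ValueError(f"Unknown metric type {metric_type}")

print(quality_score(0.10, "jsd"))           # 90: low divergence scores high
print(quality_score(0.75, "auc"))           # 50: discriminator partly separates the data
print(quality_score(0.50, "auc"))           # 100: real and synthetic are indistinguishable
print(quality_score(0.20, "pearson-corr"))  # 80: small relative correlation error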

generate_tsne_plot

generate_tsne_plot(grouping_file, real_data, decoded_data, virtual_data, output_decoded, output_virtual, max_samples=1000)

Generate t-SNE plots comparing real, decoded, and virtual data.

Parameters:

Name            Type  Description                                                Default
grouping_file   Path  The path to the CSV file containing column name mappings.  required
real_data       Path  The path to the CSV file containing real data.             required
decoded_data    Path  The path to the CSV file containing decoded data.          required
virtual_data    Path  The path to the CSV file containing virtual data.          required
output_decoded  Path  The path to the output file for the decoded t-SNE plot.   required
output_virtual  Path  The path to the output file for the virtual t-SNE plot.   required
max_samples     int   The maximum number of samples to use for the t-SNE plot.  1000
Source code in vambn/visualization/generate_tsne_plot.py
def generate_tsne_plot(
    grouping_file: Path,
    real_data: Path,
    decoded_data: Path,
    virtual_data: Path,
    output_decoded: Path,
    output_virtual: Path,
    max_samples: int = 1000,
):
    """
    Generate t-SNE plots comparing real, decoded, and virtual data.

    Args:
        grouping_file (Path): The path to the CSV file containing column name mappings.
        real_data (Path): The path to the CSV file containing real data.
        decoded_data (Path): The path to the CSV file containing decoded data.
        virtual_data (Path): The path to the CSV file containing virtual data.
        output_decoded (Path): The path to the output file for the decoded t-SNE plot.
        output_virtual (Path): The path to the output file for the virtual t-SNE plot.
        max_samples (int): The maximum number of samples to use for the t-SNE plot.
    """
    grouping = pd.read_csv(grouping_file)
    real = pd.read_csv(real_data)
    decoded = pd.read_csv(decoded_data)
    virtual = pd.read_csv(virtual_data)

    def _prepare_data(df: pd.DataFrame, cols: List[str]) -> pd.DataFrame:
        """
        Prepare the data by encoding numerical columns and handling inf values.

        Args:
            df (pd.DataFrame): The dataframe to prepare.
            cols (List[str]): List of relevant columns to keep.

        Returns:
            pd.DataFrame: The prepared dataframe.
        """
        x = df.loc[df["VISIT"] == 1, cols]
        x = encode_numerical_columns(x)
        if "subjid" in x.columns:
            x = x.rename(columns={"subjid": "SUBJID"})
        x.sort_values(by=["SUBJID", "VISIT"], inplace=True)
        # Map infinities to large finite sentinels (+inf -> 1e6, -inf -> -1e6)
        x = x.replace([float("inf"), float("-inf")], [1e6, -1e6])
        return x

    # Derive relevant columns from grouping file
    subset_without_stalone = grouping.loc[
        ~grouping["technical_group_name"].str.startswith("stalone"), :
    ]
    relevant_columns = subset_without_stalone["column_names"].tolist()

    # Filter relevant columns from real and synthetic data
    available_columns = [
        col
        for col in relevant_columns
        if col in real.columns
        and col in virtual.columns
        and col in decoded.columns
    ] + ["SUBJID", "VISIT"]
    real = _prepare_data(real, available_columns)
    decoded = _prepare_data(decoded, available_columns)
    virtual = _prepare_data(virtual, available_columns)
    assert real.shape[1] == decoded.shape[1] == virtual.shape[1]

    real_1, decoded = handle_nan_values(real, decoded)
    real_2, virtual = handle_nan_values(real, virtual)

    # Drop SUBJID and VISIT columns
    real_1 = real_1.drop(columns=["SUBJID", "VISIT"])
    real_2 = real_2.drop(columns=["SUBJID", "VISIT"])
    decoded = decoded.drop(columns=["SUBJID", "VISIT"])
    virtual = virtual.drop(columns=["SUBJID", "VISIT"])

    # Sample data if necessary
    if real_1.shape[0] > max_samples:
        real_1 = real_1.sample(n=max_samples, random_state=42)
    if real_2.shape[0] > max_samples:
        real_2 = real_2.sample(n=max_samples, random_state=42)
    if decoded.shape[0] > max_samples:
        decoded = decoded.sample(n=max_samples, random_state=42)
    if virtual.shape[0] > max_samples:
        virtual = virtual.sample(n=max_samples, random_state=42)

    get_tsne_plot_data(output_decoded, real_1, decoded)
    get_tsne_plot_data(output_virtual, real_2, virtual)

    print("Done!")

get_tsne_plot_data(output_file, real_data, virtual_data)

Generate and save t-SNE plot data.

Parameters:

Name          Type       Description                                       Default
output_file   Path       The path to the output file for the t-SNE plot.  required
real_data     DataFrame  Dataframe containing real data.                   required
virtual_data  DataFrame  Dataframe containing virtual data.                required
Source code in vambn/visualization/generate_tsne_plot.py
def get_tsne_plot_data(
    output_file: Path,
    real_data: pd.DataFrame,
    virtual_data: pd.DataFrame,
) -> None:
    """
    Generate and save t-SNE plot data.

    Args:
        output_file (Path): The path to the output file for the t-SNE plot.
        real_data (pd.DataFrame): Dataframe containing real data.
        virtual_data (pd.DataFrame): Dataframe containing virtual data.
    """
    x = pd.concat([real_data, virtual_data], ignore_index=True)
    perplexity = 30
    # t-SNE requires perplexity < n_samples; shrink it for small inputs
    if x.shape[0] <= perplexity:
        perplexity = x.shape[0] - 1

    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    tsne_result = tsne.fit_transform(x)
    border = real_data.shape[0]
    x_real = tsne_result[:border, 0]
    y_real = tsne_result[:border, 1]
    x_virtual = tsne_result[border:, 0]
    y_virtual = tsne_result[border:, 1]

    fig, ax = plt.subplots()
    ax.scatter(x_real, y_real, c="#d62728", label="real", alpha=0.35, s=5)
    ax.scatter(
        x_virtual, y_virtual, c="#17becf", label="virtual", alpha=0.35, s=5
    )
    ax.set_xlabel("tSNE1")
    ax.set_ylabel("tSNE2")
    ax.legend()
    plt.savefig(output_file, dpi=300, bbox_inches="tight")
    plt.close()
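
Note that both datasets are embedded in a single t-SNE fit and split afterwards by row index, so the real and virtual point clouds share one coordinate system. A toy example with random data (purely illustrative, assuming the imports used above):

import numpy as np
import pandas as pd
from pathlib import Path

rng = np.random.default_rng(0)
real = pd.DataFrame(rng.normal(0.0, 1.0, size=(50, 5)))
virtual = pd.DataFrame(rng.normal(0.5, 1.0, size=(50, 5)))

get_tsne_plot_data(Path("plots/tsne_toy.png"), real, virtual)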

generate_umap_plot

generate_plot(output_file, real_data, synthetic_data)

Generate and save UMAP plot data.

Parameters:

Name            Type       Description                                      Default
output_file     Path       The path to the output file for the UMAP plot.  required
real_data       DataFrame  Dataframe containing real data.                  required
synthetic_data  DataFrame  Dataframe containing synthetic data.             required
Source code in vambn/visualization/generate_umap_plot.py
def generate_plot(
    output_file: Path,
    real_data: pd.DataFrame,
    synthetic_data: pd.DataFrame,
) -> None:
    """
    Generate and save UMAP plot data.

    Args:
        output_file (Path): The path to the output file for the UMAP plot.
        real_data (pd.DataFrame): Dataframe containing real data.
        synthetic_data (pd.DataFrame): Dataframe containing synthetic data.
    """
    # Clip extreme values so UMAP's float32 internals do not overflow
    concat = pd.concat([real_data, synthetic_data]).clip(-1e15, 1e15)
    reducer = umap.UMAP(n_components=2, random_state=42)
    dim_reduct = reducer.fit_transform(concat)
    border = real_data.shape[0]
    x_real = dim_reduct[:border, 0]
    y_real = dim_reduct[:border, 1]
    x_virtual = dim_reduct[border:, 0]
    y_virtual = dim_reduct[border:, 1]

    fig, ax = plt.subplots()
    ax.scatter(x_real, y_real, c="#d62728", label="real", alpha=0.35, s=5)
    ax.scatter(
        x_virtual, y_virtual, c="#17becf", label="virtual", alpha=0.35, s=5
    )
    ax.set_xlabel("UMAP1")
    ax.set_ylabel("UMAP2")
    ax.legend()
    plt.savefig(output_file, dpi=300, bbox_inches="tight")
    plt.close()

generate_umap_plot(grouping_file, real_data, decoded_data, virtual_data, output_decoded, output_virtual, max_samples=1000)

Generate UMAP plots comparing real, decoded, and virtual data.

Parameters:

Name            Type  Description                                                Default
grouping_file   Path  The path to the CSV file containing column name mappings.  required
real_data       Path  The path to the CSV file containing real data.             required
decoded_data    Path  The path to the CSV file containing decoded data.          required
virtual_data    Path  The path to the CSV file containing virtual data.          required
output_decoded  Path  The path to the output file for the decoded UMAP plot.    required
output_virtual  Path  The path to the output file for the virtual UMAP plot.    required
max_samples     int   The maximum number of samples to use for the UMAP plot.   1000
Source code in vambn/visualization/generate_umap_plot.py
def generate_umap_plot(
    grouping_file: Path,
    real_data: Path,
    decoded_data: Path,
    virtual_data: Path,
    output_decoded: Path,
    output_virtual: Path,
    max_samples: int = 1000,
):
    """
    Generate UMAP plots comparing real, decoded, and virtual data.

    Args:
        grouping_file (Path): The path to the CSV file containing column name mappings.
        real_data (Path): The path to the CSV file containing real data.
        decoded_data (Path): The path to the CSV file containing decoded data.
        virtual_data (Path): The path to the CSV file containing virtual data.
        output_decoded (Path): The path to the output file for the decoded UMAP plot.
        output_virtual (Path): The path to the output file for the virtual UMAP plot.
        max_samples (int): The maximum number of samples to use for the UMAP plot.
    """
    grouping = pd.read_csv(grouping_file)
    real = pd.read_csv(real_data)
    decoded = pd.read_csv(decoded_data)
    virtual = pd.read_csv(virtual_data)

    def _prepare_data(df: pd.DataFrame, cols: List[str]) -> pd.DataFrame:
        """
        Prepare the data by encoding numerical columns and sorting by subject and visit.

        Args:
            df (pd.DataFrame): The dataframe to prepare.
            cols (List[str]): List of relevant columns to keep.

        Returns:
            pd.DataFrame: The prepared dataframe.
        """
        x = df.loc[df["VISIT"] == 1, cols]
        x = encode_numerical_columns(x)
        if "subjid" in x.columns:
            x = x.rename(columns={"subjid": "SUBJID"})
        x.sort_values(by=["SUBJID", "VISIT"], inplace=True)
        return x

    # Derive relevant columns from grouping file
    subset_without_stalone = grouping.loc[
        ~grouping["technical_group_name"].str.startswith("stalone"), :
    ]
    relevant_columns = subset_without_stalone["column_names"].tolist()

    # Filter relevant columns from real and synthetic data
    available_columns = [
        col
        for col in relevant_columns
        if col in real.columns
        and col in virtual.columns
        and col in decoded.columns
    ] + ["SUBJID", "VISIT"]
    real = _prepare_data(real, available_columns)
    decoded = _prepare_data(decoded, available_columns)
    virtual = _prepare_data(virtual, available_columns)
    assert real.shape[1] == decoded.shape[1] == virtual.shape[1]

    real_1, decoded = handle_nan_values(real, decoded)
    real_2, virtual = handle_nan_values(real, virtual)

    # Drop SUBJID and VISIT columns
    real_1 = real_1.drop(columns=["SUBJID", "VISIT"])
    real_2 = real_2.drop(columns=["SUBJID", "VISIT"])
    decoded = decoded.drop(columns=["SUBJID", "VISIT"])
    virtual = virtual.drop(columns=["SUBJID", "VISIT"])

    # Sample data if necessary
    if real_1.shape[0] > max_samples:
        real_1 = real_1.sample(n=max_samples, random_state=42)
    if real_2.shape[0] > max_samples:
        real_2 = real_2.sample(n=max_samples, random_state=42)
    if decoded.shape[0] > max_samples:
        decoded = decoded.sample(n=max_samples, random_state=42)
    if virtual.shape[0] > max_samples:
        virtual = virtual.sample(n=max_samples, random_state=42)

    generate_plot(output_decoded, real_1, decoded)
    generate_plot(output_virtual, real_2, virtual)

    print("Done!")

jsd_by_module

main(jsd_metrics, grouping, output)

Generate a bar plot with error bars showing the Jensen-Shannon Divergence (JSD) for different modules and types, based on the provided JSD metrics and grouping file.

Parameters:

Name         Type  Description                                                                                  Default
jsd_metrics  Path  Path to the JSON file containing JSD metrics.                                                required
grouping     Path  Path to the CSV file containing column names and their respective types and module mappings. required
output       Path  Path to save the generated plot.                                                             required
Source code in vambn/visualization/jsd_by_module.py
def main(jsd_metrics: Path, grouping: Path, output: Path):
    """
    Generate a bar plot with error bars showing the Jensen-Shannon Divergence (JSD)
    for different modules and types, based on the provided JSD metrics and grouping file.

    Args:
        jsd_metrics (Path): Path to the JSON file containing JSD metrics.
        grouping (Path): Path to the CSV file containing column names and their respective types and module mappings.
        output (Path): Path to save the generated plot.
    """
    with jsd_metrics.open("r") as f:
        jsd_metrics_dict = json.load(f)
    grouping_df = pd.read_csv(grouping)

    type_dict = {
        row["column_names"]: row["hivae_types"]
        for _, row in grouping_df.iterrows()
    }
    module_dict = {
        row["column_names"]: row["technical_group_name"]
        for _, row in grouping_df.iterrows()
    }
    jsd_df = pd.DataFrame(jsd_metrics_dict)
    jsd_df["types"] = jsd_df["column"].map(type_dict)
    jsd_df["module"] = jsd_df["column"].map(module_dict)

    agg_metrics = jsd_df.groupby(["module", "types"]).aggregate(
        {
            "jsd_decoded": ["mean", "std"],
            "jsd_virtual": ["mean", "std"],
        }
    )

    # Melt the aggregated DataFrame to long format
    agg_metrics_melted = agg_metrics.reset_index()
    agg_metrics_melted.columns = [
        "_".join(col).strip("_") for col in agg_metrics_melted.columns.values
    ]
    agg_metrics_melted = agg_metrics_melted.melt(
        id_vars=["module", "types"], var_name="metric", value_name="value"
    )

    # Split the "metric" column into separate "variant" and "stat" columns
    metric_stat_df = agg_metrics_melted["metric"].str.split("_", expand=True)
    metric_stat_df.columns = ["metric", "variant", "stat"]
    agg_metrics_melted = pd.concat(
        [
            agg_metrics_melted.loc[:, ["module", "types", "value"]],
            metric_stat_df,
        ],
        axis=1,
    )
    agg_metrics_melted = agg_metrics_melted.drop("metric", axis=1)

    agg_wide = agg_metrics_melted.pivot_table(
        index=["module", "types", "variant"],
        columns="stat",
        values="value",
    ).reset_index()
    agg_wide["lower"] = agg_wide["mean"] - agg_wide["std"]
    agg_wide["upper"] = agg_wide["mean"] + agg_wide["std"]

    # Generate bar plot with error bars and facets per type (virtual/decoded)
    plot = (
        ggplot(
            agg_wide,
            aes(x="module", y="mean", fill="types"),
        )
        + geom_bar(stat="identity", position="dodge")
        + geom_errorbar(
            aes(ymin="lower", ymax="upper"),
            position=position_dodge(width=0.9),
            width=0.2,
        )
        + ylim(0, 1)
        + labs(y="Jensen-Shannon Divergence", title="JSD by module")
        + facet_wrap("~variant", scales="free_x")
        + theme(
            axis_text_x=element_text(angle=45, hjust=1),
            axis_title_x=element_blank(),
            figure_size=(10, 6),
        )
    )

    # Save the plot to the output file
    plot.save(output, dpi=300)
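
The JSD metrics file is expected to deserialize into something pandas can turn into a frame with at least the columns accessed in main ("column", "jsd_decoded", "jsd_virtual"). A minimal hand-written input that satisfies the function; the file names and values are illustrative assumptions, and the grouping CSV must map the listed columns to hivae_types and technical_group_name entries.

import json
from pathlib import Path

# Hypothetical metrics for two variables; keys mirror the columns used in main()
metrics = {
    "column": ["age", "score"],
    "jsd_decoded": [0.05, 0.12],
    "jsd_virtual": [0.08, 0.15],
}
Path("jsd_metrics.json").write_text(json.dumps(metrics))

main(
    jsd_metrics=Path("jsd_metrics.json"),
    grouping=Path("data/grouping.csv"),
    output=Path("plots/jsd_by_module.png"),
)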