File size: 4,339 Bytes
814a594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#%%
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import json, os
import seaborn as sns

# Default spine visibility for every new Axes: keep bottom/left, hide top/right.
plt.rcParams.update({
    'axes.spines.bottom': True,
    'axes.spines.left': True,
    'axes.spines.right': False,
    'axes.spines.top': False,
})

# Load data
def load_data(file_path):
    """Read and parse a JSON file.

    Parameters
    ----------
    file_path : str or os.PathLike
        Path to the JSON file.

    Returns
    -------
    object
        The decoded JSON value (for this script, a nested dict of counts).
    """
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # locale-dependent default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
# Folder containing modality_counts.json.
base_dir = 'metadata'
data = load_data(
    os.path.join(base_dir, "modality_counts.json")
)
# When True, plot_data splits each modality bar into per-sub-category segments.
separate_submodality = False

# Transform data for plotting
def transform_data(data):
    """Flatten a nested ``{modality: {sub-category: count}}`` mapping.

    Returns a long-format DataFrame with one row per (modality, sub-category)
    pair and columns 'Modality', 'Sub-category', 'Count', in the mappings'
    iteration order.
    """
    rows = []
    for modality_name, subcategory_counts in data.items():
        for subcategory, n_items in subcategory_counts.items():
            rows.append((modality_name, subcategory, n_items))
    return pd.DataFrame(rows, columns=['Modality', 'Sub-category', 'Count'])

# Flatten the nested {modality: {sub-category: count}} mapping into a long-format DataFrame.
df = transform_data(data)

# Calculate total counts by modality and sort
def calculate_totals(df):
    """Sum 'Count' per 'Modality' and order modalities by ascending total.

    Returns
    -------
    tuple
        ``(totals, order)`` where ``totals`` is a Series of per-modality sums
        sorted ascending and ``order`` is the list of its index labels.
    """
    per_modality = df.groupby("Modality")["Count"].sum()
    per_modality = per_modality.sort_values(ascending=True)
    return per_modality, per_modality.index.tolist()

# Per-modality totals (ascending) and the matching modality order used to draw the bars.
total_counts_by_modality, sorted_modalities = calculate_totals(df)

# Generate color map
def generate_color_map(total_counts_by_modality):
    """Map each modality (index label) to an RGBA colour from the 'cool' colormap.

    Colours are sampled evenly over [0, 1] in the Series' index order, so the
    first modality gets the start of the colormap and the last gets its end.
    """
    modality_names = total_counts_by_modality.index
    palette = plt.cm.cool(np.linspace(0, 1, len(total_counts_by_modality)))
    return dict(zip(modality_names, palette))

# Assign each modality a colour from the 'cool' colormap, following ascending-total order.
modality_color_map = generate_color_map(total_counts_by_modality)

# Format total count for display
def format_total_count(total_count):
    """Format a count for the label printed next to a bar.

    Counts below 1000 are shown verbatim. Larger counts use scientific notation
    written in matplotlib mathtext, e.g. ``1500 -> '1.50 x 10$^{3}$'``.

    Parameters
    ----------
    total_count : int or float
        The count to format.

    Returns
    -------
    tuple[str, int]
        The formatted label and the power of ten used (0 for counts < 1000).
    """
    if total_count >= 1000:
        exponent = int(np.floor(np.log10(total_count)))
        mantissa = total_count / 10**exponent
        # Keep the *displayed* mantissa in [1, 10): rounding to 2 decimals can
        # push e.g. 9.999 to "10.00", and log10 can be off by one ulp at exact
        # powers of ten.
        if round(mantissa, 2) >= 10:
            exponent += 1
            mantissa = total_count / 10**exponent
        elif mantissa < 1:
            exponent -= 1
            mantissa = total_count / 10**exponent
        # Brace the exponent so multi-digit powers are fully superscripted.
        formatted_total = f'{mantissa:.2f} x 10$^{{{exponent}}}$'
    else:
        exponent = 0
        formatted_total = str(total_count)
    return formatted_total, exponent

# Plotting function
def _display_label(modality):
    """Return the y-axis label for a modality: all-lowercase names are capitalized."""
    return modality.capitalize() if modality.islower() else modality


def plot_data(df, total_counts_by_modality, sorted_modalities, modality_color_map, separate_submodality):
    """Draw a horizontal log-scale bar chart of counts per modality and save it as a PDF.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format counts with 'Modality', 'Sub-category' and 'Count' columns.
    total_counts_by_modality : pandas.Series
        Total count per modality, indexed by modality name.
    sorted_modalities : list
        Modalities in the bottom-to-top order in which bars are drawn.
    modality_color_map : dict
        Maps each modality to an RGBA colour array.
    separate_submodality : bool
        If True, each modality's bar is split into one segment per sub-category
        (largest first, zero counts skipped); otherwise one bar shows the total.

    Side effects: creates the ``plots`` directory if needed, writes
    ``plots/data_dist_modality_bar_subbar.pdf`` or ``plots/data_dist_modality_bar.pdf``,
    and calls ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(10, 12))
    current_bottom = np.zeros(len(sorted_modalities))
    gap = 0.005 if separate_submodality else 0
    # One brightness factor per distinct sub-category, in [0.5625, 1].
    shades = np.power(np.linspace(0.75, 1, df.groupby("Sub-category").ngroups), 2)
    # NOTE(review): shades multiply the full RGBA colour, so alpha is scaled too
    # (bars become partly transparent) — confirm this is intended.

    # Bars are drawn at numeric y positions 0..n-1 so that the text label at
    # row i always lines up with its bar, even when a modality draws no segment
    # (a categorical axis would then skip that row and shift every later one).
    for i, modality in enumerate(sorted_modalities):
        total_count = total_counts_by_modality[modality]
        base_color = np.array(modality_color_map[modality])
        if separate_submodality:
            subdf = df[df["Modality"] == modality].sort_values(by='Count', ascending=False)
            for j, count in enumerate(subdf['Count']):
                if count > 0:
                    color = base_color * shades[j % len(shades)]
                    ax.barh(i, count, left=current_bottom[i], color=color, height=0.8, log=True, edgecolor='white', linewidth=0.5)
                    current_bottom[i] += count + gap
            current_bottom[i] -= gap
            bar_end = current_bottom[i]
        else:
            color = base_color * shades[0]
            ax.barh(i, total_count, color=color, height=0.8, log=True, edgecolor='white', linewidth=0.5)
            bar_end = total_count
        formatted_total, exponent = format_total_count(total_count)
        ax.text(bar_end + (10**exponent)*0.05, i, formatted_total, va='center', fontsize=20, ha='left')

    # Same labels in both modes (previously only the single-bar mode capitalized).
    ax.set_yticks(np.arange(len(sorted_modalities)))
    ax.set_yticklabels([_display_label(m) for m in sorted_modalities])

    configure_plot(ax, sorted_modalities)

    plt.tight_layout()
    out_path = "plots/data_dist_modality_bar_subbar.pdf" if separate_submodality else "plots/data_dist_modality_bar.pdf"
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path, bbox_inches="tight", pad_inches=0)
    plt.show()

# Configure plot aesthetics
def configure_plot(ax, sorted_modalities):
    """Apply the title, log x-scale, tick styling and spine removal to ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes to style.
    sorted_modalities : list
        Accepted for interface compatibility; not used.
    """
    ax.set_xscale('log')
    ax.set_title("Number of images per modality", fontsize=28)
    # Style the axes we were given rather than pyplot's implicit "current" axes.
    ax.tick_params(axis='y', labelrotation=0, labelsize=24)
    for label in ax.get_yticklabels():
        label.set_verticalalignment('center')
    ax.tick_params(axis='x', which='major', length=8, labelsize=24)
    ax.tick_params(axis='x', which='minor', length=5)
    sns.despine(ax=ax)

# Main script execution
# Main script execution: draw the chart, save it under plots/ and display it.
plot_data(df, total_counts_by_modality, sorted_modalities, modality_color_map, separate_submodality)

# %%