File size: 3,908 Bytes
23770d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from pathlib import Path
from transformers import pipeline
from PIL import Image
import gradio as gr
import pandas as pd
import numpy as np
import altair as alt


# Blank prompt images (the MTCI examples) that users draw on, keyed by prompt
# id. Values are relative paths, so they resolve against the current working
# directory, not this file's location.
prompt_images = {
    "Images_11": "blanks/Images_11_blank.png",
    "Images_4": "blanks/Images_4_blank.png",
    "Images_17": "blanks/Images_17_blank.png",
    "Images_9": "blanks/Images_9_blank.png",
    "Images_3": "blanks/Images_3_blank.png",
    "Images_8": "blanks/Images_8_blank.png",
    "Images_13": "blanks/Images_13_blank.png",
    "Images_15": "blanks/Images_15_blank.png",
    "Images_12": "blanks/Images_12_blank.png",
    "Images_7": "blanks/Images_7_blank.png",
    "Images_56": "blanks/Images_56_blank.png",
    "Images_19": "blanks/Images_19_blank.png",
}

# Reference score distribution. get_percentile and the background chart read
# its 'score_norm' and 'percentile' columns; every column is parsed as float.
# NOTE(review): get_percentile relies on row order (presumably sorted ascending
# by score_norm, with percentile as the first column) -- confirm against the CSV.
dist = pd.read_csv('./score_norm_distribution.csv', dtype=float)


def get_percentile(score, table=None):
    """Return the reference percentile for a normalized originality score.

    Finds the last row of the distribution table whose ``score_norm`` is
    <= ``score`` and returns that row's first column.

    Args:
        score: Normalized score, i.e. the classifier's output.
        table: Optional distribution DataFrame. It needs a ``score_norm``
            column and is expected to hold the percentile in its first
            column. Defaults to the module-level ``dist``.

    Returns:
        The percentile value taken from ``table``. A score below every
        tabulated ``score_norm``, or a NaN score, returns the first row's
        value instead of raising IndexError.
    """
    if table is None:
        table = dist
    at_or_below = table[table['score_norm'] <= score]
    if at_or_below.empty:
        # Nothing at or below the score: clamp to the lowest tabulated row.
        # NOTE(review): assumes the table is sorted ascending by score_norm,
        # the same assumption the iloc[-1] lookup below relies on.
        return table.iloc[0, 0]
    return at_or_below.iloc[-1, 0]


def inverse_scale(logits):
    """Map a min-max-scaled value in [0, 1] back onto the JRT range.

    The scores were scaled from the JRT range [-3.024, 3.164] down to [0, 1].
    This reverses that as ``logits * 6.188 + -3.024``. It works for scalars
    and numpy arrays alike.
    """
    jrt_min, jrt_span = -3.024, 6.188  # span == max - min
    return logits * jrt_span + jrt_min


def get_predictions(img_dict):
    """Score a drawing and return its originality metrics.

    Args:
        img_dict: The value Gradio's ImageEditor passes, a dict with
            'background', 'composite' and 'layers'. Only 'composite' (the
            flattened drawing) is scored.

    Returns:
        dict with:
            'originality': score of the classifier's first prediction,
                rounded to 2 places;
            'jrt': that same score mapped back to the JRT scale via
                ``inverse_scale``, rounded to 2 places;
            'percentile': percentile of the unrounded score, from
                ``get_percentile``.
    """
    img = img_dict['composite']
    predictions = classifier(img)
    score = predictions[0]['score']
    return {
        'originality': np.round(score, 2),
        # Convert the model's score itself. The previous inverse_scale(0)
        # returned the same constant for every drawing.
        'jrt': np.round(inverse_scale(score), 2),
        'percentile': get_percentile(score),
    }


# Background curve for every result plot: the reference distribution drawn as
# a line, with percentile on x and normalized score on y.
base_chart = (
    alt.Chart(dist)
    .mark_line()
    .encode(x='percentile', y='score_norm')
)

def classify_image(img_dict):
    """Score a drawing and plot it against the reference distribution.

    Args:
        img_dict: The ImageEditor value (dict with 'background', 'composite'
            and 'layers'). It is forwarded unchanged to ``get_predictions``,
            which reads the 'composite' entry.

    Returns:
        An Altair layered chart. It combines ``base_chart``, a red horizontal
        rule at the drawing's rounded normalized score, and a text label
        giving its percentile and score.
    """
    p = get_predictions(img_dict)

    # A single row shared by the rule and its label, so both sit at the same y.
    marker = pd.DataFrame({
        'y': [p['originality']],
        'text': [f"Percentile: {p['percentile']}; Normalized Score: {p['originality']}"],
    })

    percentile_mark = alt.Chart(marker).mark_rule(color='red').encode(y='y')

    # Text annotation for the percentile mark
    text = alt.Chart(marker).mark_text(
        align='left',
        baseline='middle',
        dx=7, dy=-8,  # Nudges text to right so it doesn't overlap with the line
    ).encode(
        y='y',
        text='text',
    )
    return base_chart + percentile_mark + text


def update_editor(background, img_editor):
    """Reset an ImageEditor value to an empty canvas over ``background``.

    Mutates ``img_editor`` in place: it sets 'background', empties 'layers'
    and clears 'composite'. It then returns that same object.
    """
    resets = (
        ('background', background),
        ('layers', []),
        ('composite', None),
    )
    for key, value in resets:
        img_editor[key] = value
    return img_editor


# Hugging Face image-classification pipeline for the Ocsai-D model.
# get_predictions uses the 'score' of its first prediction as the normalized
# originality score.
classifier = pipeline("image-classification", model='POrg/ocsai-d-large')

# Drawing canvas. It starts on the Images_11 blank, accepts uploads and
# clipboard pastes, and offers a fixed black/grey palette with no transform
# tools and layers disabled.
editor = gr.ImageEditor(
    type='pil',
    value={
        'background': Image.open(prompt_images['Images_11']),
        'composite': None,
        'layers': [],
    },
    brush=gr.Brush(
        default_size=2,
        colors=["#000000", '#333333', '#666666'],
        color_mode="fixed",
    ),
    transforms=[],
    sources=('upload', 'clipboard'),
    layers=False,
)

# One example per blank prompt, using the same ImageEditor value shape as the
# editor's initial value. Only the image paths are needed, so iterate values.
examples = [
    dict(background=Image.open(path), composite=None, layers=[])
    for path in prompt_images.values()
]

# Single-input interface: the drawing goes to classify_image and the
# resulting Altair chart is shown in a Plot component.
demo = gr.Interface(
    fn=classify_image,
    inputs=[editor],
    outputs=gr.Plot(),
    examples=examples,
    title="Ocsai-D",
    description="Complete the drawing and classify the originality. Examples are from MTCI ([Barbot 2018](https://pubmed.ncbi.nlm.nih.gov/30618952/)). Choose the brush icon below the image to start editing.",
)

# Start the server when this script runs. There is no __main__ guard, so this
# also happens on import.
demo.launch(debug=True)