Spaces:
Running
Running
from pathlib import Path | |
from transformers import pipeline | |
from PIL import Image | |
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import altair as alt | |
# Blank prompt images (MTCI, per the interface description), keyed by prompt
# name. Every entry follows the pattern "blanks/<name>_blank.png"; insertion
# order is preserved and determines the order of the example gallery.
prompt_images = {
    f"Images_{n}": f"blanks/Images_{n}_blank.png"
    for n in (11, 4, 17, 9, 3, 8, 13, 15, 12, 7, 56, 19)
}
# Reference distribution of normalized originality scores, used by
# get_percentile and plotted by base_chart. Every column is parsed as float.
# NOTE: the path is relative to the current working directory.
dist = pd.read_csv(filepath_or_buffer='./score_norm_distribution.csv', dtype=float)
def get_percentile(score, table=None):
    """Look up where a normalized score falls in the reference distribution.

    Selects the reference row with the largest ``score_norm`` that is
    <= ``score`` and returns that row's first column. The first column is
    presumably ``percentile`` (the column ``base_chart`` plots on x); confirm
    against the CSV header.

    Args:
        score: Normalized model score to locate.
        table: Reference DataFrame with a ``score_norm`` column. Defaults to
            the module-level ``dist`` read from ``score_norm_distribution.csv``.

    Returns:
        The first-column value of the matching row, or ``0.0`` when ``score``
        is below every ``score_norm`` in the table (or is NaN).
    """
    if table is None:
        table = dist
    at_or_below = table[table['score_norm'] <= score]
    if at_or_below.empty:
        # Lower than anything in the reference distribution.
        return 0.0
    # A stable sort keeps already-ascending tables picking exactly the last
    # qualifying row (same tie handling), and unsorted tables no longer
    # return an arbitrary row.
    return at_or_below.sort_values('score_norm', kind='stable').iloc[-1, 0]
def inverse_scale(logits):
    """Map a 0-1 min-max-scaled value back onto the JRT score range.

    Inverts the scaling ``(x - min) / range`` with min = -3.024 and
    range = 6.188 (max = 3.164). Only ``*`` and ``+`` are used, so scalars
    and NumPy arrays are both accepted.
    """
    jrt_min = -3.024
    jrt_range = 6.188
    scaled_back = logits * jrt_range
    return scaled_back + jrt_min
def get_predictions(img_dict):
    """Run the originality classifier on the drawing in an ImageEditor value.

    Args:
        img_dict: Value passed by ``gr.ImageEditor``, a dict with
            ``background``, ``layers`` and ``composite`` keys. Only the
            flattened ``composite`` image is classified.

    Returns:
        dict with keys:
          ``originality``: the first prediction's score, rounded to 2 places.
          ``jrt``: that score mapped back to JRT units via ``inverse_scale``,
            rounded to 2 places.
          ``percentile``: ``get_percentile`` of the unrounded score.
    """
    img = img_dict['composite']
    predictions = classifier(img)
    # Presumably a single-output model; the first entry's score is treated
    # as the 0-1 normalized originality.
    score = predictions[0]['score']
    return {
        'originality': np.round(score, 2),
        # Convert the model's 0-1 score back to the JRT scale.
        'jrt': np.round(inverse_scale(score), 2),
        'percentile': get_percentile(score),
    }
# Background line chart of the reference distribution: percentile on x,
# normalized score on y. classify_image layers the user's score on top of it.
base_chart = (
    alt.Chart(dist)
    .mark_line()
    .encode(x='percentile', y='score_norm')
)
def classify_image(img_dict):
    """Score a drawing and plot it against the reference distribution.

    Args:
        img_dict: ``gr.ImageEditor`` value (dict with ``background``,
            ``layers`` and ``composite``), forwarded to ``get_predictions``.

    Returns:
        A layered Altair chart made of the module-level ``base_chart``, a red
        horizontal rule at the drawing's rounded normalized score, and a text
        label showing its percentile and normalized score.
    """
    p = get_predictions(img_dict)
    label = f"Percentile: {p['percentile']}; Normalized Score: {p['originality']}"
    # One single-row frame feeds both overlay layers.
    marker = pd.DataFrame({'y': [p['originality']], 'text': [label]})
    # Horizontal rule, because base_chart puts score_norm on the y axis.
    percentile_mark = alt.Chart(marker).mark_rule(color='red').encode(y='y')
    text = alt.Chart(marker).mark_text(
        align='left',
        baseline='middle',
        dx=7, dy=-8,  # offset the label so it doesn't overlap the rule
    ).encode(
        y='y',
        text='text',
    )
    return base_chart + percentile_mark + text
def update_editor(background, img_editor):
    """Reset an ImageEditor value to a new background with nothing drawn.

    Writes ``background``, an empty ``layers`` list and a ``None``
    ``composite`` into ``img_editor`` in place, then returns that same object.
    """
    fresh_state = (
        ('background', background),
        ('layers', []),
        ('composite', None),
    )
    for key, value in fresh_state:
        img_editor[key] = value
    return img_editor
# Originality model, loaded once at start-up and used by get_predictions.
classifier = pipeline("image-classification", model='POrg/ocsai-d-large')

# Drawing canvas pre-loaded with the Images_11 prompt. The brush is limited to
# the listed colours (black and two greys), images can only come from upload
# or clipboard, and transform tools and extra layers are turned off.
editor = gr.ImageEditor(
    type='pil',
    value=dict(
        background=Image.open(prompt_images['Images_11']),
        composite=None,
        layers=[],
    ),
    brush=gr.Brush(
        default_size=2,
        colors=["#000000", "#333333", "#666666"],
        color_mode="fixed",
    ),
    transforms=[],
    sources=('upload', 'clipboard'),
    layers=False,
)

# One example per blank prompt, in prompt_images order. Each is an empty
# editor value whose background is that prompt.
examples = [
    dict(background=Image.open(path), composite=None, layers=[])
    for path in prompt_images.values()
]

demo = gr.Interface(
    fn=classify_image,
    inputs=[editor],
    outputs=gr.Plot(),
    title="Ocsai-D",
    description="Complete the drawing and classify the originality. Examples are from MTCI ([Barbot 2018](https://pubmed.ncbi.nlm.nih.gov/30618952/)). Choose the brush icon below the image to start editing.",
    examples=examples,
)
demo.launch(debug=True)