Spaces:
Runtime error
Runtime error
jaekookang
commited on
Commit
Β·
21c7de8
1
Parent(s):
39a6dd6
update callback
Browse files
.ipynb_checkpoints/gradio_artist_classifier-checkpoint.py
CHANGED
@@ -5,11 +5,12 @@ prototype
|
|
5 |
---
|
6 |
- 2022-01-18 jkang first created
|
7 |
'''
|
8 |
-
|
9 |
import matplotlib.pyplot as plt
|
10 |
import matplotlib.image as mpimg
|
11 |
import seaborn as sns
|
12 |
|
|
|
13 |
import json
|
14 |
import skimage.io
|
15 |
from loguru import logger
|
@@ -35,6 +36,11 @@ artist_model = from_pretrained_keras("jkang/drawing-artist-classifier")
|
|
35 |
trend_model = from_pretrained_keras("jkang/drawing-artistic-trend-classifier")
|
36 |
logger.info('both models loaded')
|
37 |
|
|
|
|
|
|
|
|
|
|
|
38 |
def load_image_as_array(image_file):
|
39 |
img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
|
40 |
if (img.shape[-1] > 3) & (remove_alpha_channel): # if RGBA
|
@@ -48,17 +54,63 @@ def load_image_as_tensor(image_file):
|
|
48 |
|
49 |
def predict(input_image):
|
50 |
img_3d_array = load_image_as_array(input_image)
|
51 |
-
img_4d_tensor = load_image_as_tensor(input_image)
|
|
|
52 |
logger.info(f'--- {input_image} loaded')
|
53 |
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
iface = gr.Interface(
|
60 |
predict,
|
61 |
-
title='Predict Artist and Artistic
|
62 |
description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
|
63 |
inputs=[
|
64 |
gr.inputs.Image(label='Upload a drawing/image', type='file')
|
|
|
5 |
---
|
6 |
- 2022-01-18 jkang first created
|
7 |
'''
|
8 |
+
from PIL import Image
|
9 |
import matplotlib.pyplot as plt
|
10 |
import matplotlib.image as mpimg
|
11 |
import seaborn as sns
|
12 |
|
13 |
+
import io
|
14 |
import json
|
15 |
import skimage.io
|
16 |
from loguru import logger
|
|
|
36 |
trend_model = from_pretrained_keras("jkang/drawing-artistic-trend-classifier")
|
37 |
logger.info('both models loaded')
|
38 |
|
39 |
+
def load_json_as_dict(json_file):
|
40 |
+
with open(json_file, 'r') as f:
|
41 |
+
out = json.load(f)
|
42 |
+
return dict(out)
|
43 |
+
|
44 |
def load_image_as_array(image_file):
|
45 |
img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
|
46 |
if (img.shape[-1] > 3) & (remove_alpha_channel): # if RGBA
|
|
|
54 |
|
55 |
def predict(input_image):
|
56 |
img_3d_array = load_image_as_array(input_image)
|
57 |
+
# img_4d_tensor = load_image_as_tensor(input_image)
|
58 |
+
img_4d_array = img_3d_array[np.newaxis,...]
|
59 |
logger.info(f'--- {input_image} loaded')
|
60 |
|
61 |
+
artist2id = load_json_as_dict(ARTIST_META)
|
62 |
+
trend2id = load_json_as_dict(TREND_META)
|
63 |
+
id2artist = {artist2id[artist]:artist for artist in artist2id}
|
64 |
+
id2trend = {trend2id[trend]:trend for trend in trend2id}
|
65 |
+
|
66 |
+
# Artist model
|
67 |
+
a_heatmap, a_pred_id, a_pred_out = make_gradcam_heatmap(artist_model,
|
68 |
+
img_4d_array,
|
69 |
+
pred_idx=None)
|
70 |
+
a_img_pil = align_image_with_heatmap(
|
71 |
+
img_4d_array, a_heatmap, alpha=alpha, cmap='jet')
|
72 |
+
a_img = np.asarray(a_img_pil).astype('float32')/255
|
73 |
+
a_label = id2artist[a_pred_id]
|
74 |
+
a_prob = a_pred_out[a_pred_id]
|
75 |
+
|
76 |
+
# Trend model
|
77 |
+
t_heatmap, t_pred_id, t_pred_out = make_gradcam_heatmap(trend_model,
|
78 |
+
img_4d_array,
|
79 |
+
pred_idx=None)
|
80 |
|
81 |
+
t_img_pil = align_image_with_heatmap(
|
82 |
+
img_4d_array, t_heatmap, alpha=alpha, cmap='jet')
|
83 |
+
t_img = np.asarray(t_img_pil).astype('float32')/255
|
84 |
+
t_label = id2trend[t_pred_id]
|
85 |
+
t_prob = t_pred_out[t_pred_id]
|
86 |
+
|
87 |
+
with sns.plotting_context('poster', font_scale=0.7):
|
88 |
+
fig, (ax1, ax2, ax3) = plt.subplots(
|
89 |
+
1, 3, figsize=(12, 6), facecolor='white')
|
90 |
+
for ax in (ax1, ax2, ax3):
|
91 |
+
ax.set_xticks([])
|
92 |
+
ax.set_yticks([])
|
93 |
+
|
94 |
+
ax1.imshow(img_3d_array)
|
95 |
+
ax2.imshow(a_img)
|
96 |
+
ax3.imshow(t_img)
|
97 |
+
|
98 |
+
ax1.set_title(f'Artist: {artist}\nTrend: {trend}', ha='left', x=0, y=1.05)
|
99 |
+
ax2.set_title(f'Artist Prediction:\n =>{a_label} ({a_prob:.2f})', ha='left', x=0, y=1.05)
|
100 |
+
ax3.set_title(f'Trend Prediction:\n =>{t_label} ({t_prob:.2f})', ha='left', x=0, y=1.05)
|
101 |
+
fig.tight_layout()
|
102 |
+
|
103 |
+
buf = io.BytesIO()
|
104 |
+
fig.save(buf, bbox_inces='tight', fotmat='jpg')
|
105 |
+
buf.seek(0)
|
106 |
+
pil_img = Image.open(buf)
|
107 |
+
plt.close()
|
108 |
+
logger.info('--- output generated')
|
109 |
+
return pil_img
|
110 |
|
111 |
iface = gr.Interface(
|
112 |
predict,
|
113 |
+
title='Predict Artist and Artistic Style of Drawings π¨π¨π»βπ¨ (prototype)',
|
114 |
description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
|
115 |
inputs=[
|
116 |
gr.inputs.Image(label='Upload a drawing/image', type='file')
|
gradio_artist_classifier.py
CHANGED
@@ -5,11 +5,12 @@ prototype
|
|
5 |
---
|
6 |
- 2022-01-18 jkang first created
|
7 |
'''
|
8 |
-
|
9 |
import matplotlib.pyplot as plt
|
10 |
import matplotlib.image as mpimg
|
11 |
import seaborn as sns
|
12 |
|
|
|
13 |
import json
|
14 |
import skimage.io
|
15 |
from loguru import logger
|
@@ -35,6 +36,11 @@ artist_model = from_pretrained_keras("jkang/drawing-artist-classifier")
|
|
35 |
trend_model = from_pretrained_keras("jkang/drawing-artistic-trend-classifier")
|
36 |
logger.info('both models loaded')
|
37 |
|
|
|
|
|
|
|
|
|
|
|
38 |
def load_image_as_array(image_file):
|
39 |
img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
|
40 |
if (img.shape[-1] > 3) & (remove_alpha_channel): # if RGBA
|
@@ -48,17 +54,63 @@ def load_image_as_tensor(image_file):
|
|
48 |
|
49 |
def predict(input_image):
|
50 |
img_3d_array = load_image_as_array(input_image)
|
51 |
-
img_4d_tensor = load_image_as_tensor(input_image)
|
|
|
52 |
logger.info(f'--- {input_image} loaded')
|
53 |
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
iface = gr.Interface(
|
60 |
predict,
|
61 |
-
title='Predict Artist and Artistic
|
62 |
description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
|
63 |
inputs=[
|
64 |
gr.inputs.Image(label='Upload a drawing/image', type='file')
|
|
|
5 |
---
|
6 |
- 2022-01-18 jkang first created
|
7 |
'''
|
8 |
+
from PIL import Image
|
9 |
import matplotlib.pyplot as plt
|
10 |
import matplotlib.image as mpimg
|
11 |
import seaborn as sns
|
12 |
|
13 |
+
import io
|
14 |
import json
|
15 |
import skimage.io
|
16 |
from loguru import logger
|
|
|
36 |
trend_model = from_pretrained_keras("jkang/drawing-artistic-trend-classifier")
|
37 |
logger.info('both models loaded')
|
38 |
|
39 |
+
def load_json_as_dict(json_file):
|
40 |
+
with open(json_file, 'r') as f:
|
41 |
+
out = json.load(f)
|
42 |
+
return dict(out)
|
43 |
+
|
44 |
def load_image_as_array(image_file):
|
45 |
img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
|
46 |
if (img.shape[-1] > 3) & (remove_alpha_channel): # if RGBA
|
|
|
54 |
|
55 |
def predict(input_image):
|
56 |
img_3d_array = load_image_as_array(input_image)
|
57 |
+
# img_4d_tensor = load_image_as_tensor(input_image)
|
58 |
+
img_4d_array = img_3d_array[np.newaxis,...]
|
59 |
logger.info(f'--- {input_image} loaded')
|
60 |
|
61 |
+
artist2id = load_json_as_dict(ARTIST_META)
|
62 |
+
trend2id = load_json_as_dict(TREND_META)
|
63 |
+
id2artist = {artist2id[artist]:artist for artist in artist2id}
|
64 |
+
id2trend = {trend2id[trend]:trend for trend in trend2id}
|
65 |
+
|
66 |
+
# Artist model
|
67 |
+
a_heatmap, a_pred_id, a_pred_out = make_gradcam_heatmap(artist_model,
|
68 |
+
img_4d_array,
|
69 |
+
pred_idx=None)
|
70 |
+
a_img_pil = align_image_with_heatmap(
|
71 |
+
img_4d_array, a_heatmap, alpha=alpha, cmap='jet')
|
72 |
+
a_img = np.asarray(a_img_pil).astype('float32')/255
|
73 |
+
a_label = id2artist[a_pred_id]
|
74 |
+
a_prob = a_pred_out[a_pred_id]
|
75 |
+
|
76 |
+
# Trend model
|
77 |
+
t_heatmap, t_pred_id, t_pred_out = make_gradcam_heatmap(trend_model,
|
78 |
+
img_4d_array,
|
79 |
+
pred_idx=None)
|
80 |
|
81 |
+
t_img_pil = align_image_with_heatmap(
|
82 |
+
img_4d_array, t_heatmap, alpha=alpha, cmap='jet')
|
83 |
+
t_img = np.asarray(t_img_pil).astype('float32')/255
|
84 |
+
t_label = id2trend[t_pred_id]
|
85 |
+
t_prob = t_pred_out[t_pred_id]
|
86 |
+
|
87 |
+
with sns.plotting_context('poster', font_scale=0.7):
|
88 |
+
fig, (ax1, ax2, ax3) = plt.subplots(
|
89 |
+
1, 3, figsize=(12, 6), facecolor='white')
|
90 |
+
for ax in (ax1, ax2, ax3):
|
91 |
+
ax.set_xticks([])
|
92 |
+
ax.set_yticks([])
|
93 |
+
|
94 |
+
ax1.imshow(img_3d_array)
|
95 |
+
ax2.imshow(a_img)
|
96 |
+
ax3.imshow(t_img)
|
97 |
+
|
98 |
+
ax1.set_title(f'Artist: {artist}\nTrend: {trend}', ha='left', x=0, y=1.05)
|
99 |
+
ax2.set_title(f'Artist Prediction:\n =>{a_label} ({a_prob:.2f})', ha='left', x=0, y=1.05)
|
100 |
+
ax3.set_title(f'Trend Prediction:\n =>{t_label} ({t_prob:.2f})', ha='left', x=0, y=1.05)
|
101 |
+
fig.tight_layout()
|
102 |
+
|
103 |
+
buf = io.BytesIO()
|
104 |
+
fig.save(buf, bbox_inces='tight', fotmat='jpg')
|
105 |
+
buf.seek(0)
|
106 |
+
pil_img = Image.open(buf)
|
107 |
+
plt.close()
|
108 |
+
logger.info('--- output generated')
|
109 |
+
return pil_img
|
110 |
|
111 |
iface = gr.Interface(
|
112 |
predict,
|
113 |
+
title='Predict Artist and Artistic Style of Drawings π¨π¨π»βπ¨ (prototype)',
|
114 |
description='Upload a drawing and the model will predict how likely it seems given 10 artists and their trend/style',
|
115 |
inputs=[
|
116 |
gr.inputs.Image(label='Upload a drawing/image', type='file')
|
requirements-dev.txt
CHANGED
@@ -3,7 +3,8 @@ huggingface_hub==0.4.0
|
|
3 |
loguru==0.5.3
|
4 |
matplotlib==3.5.1
|
5 |
numpy==1.22.0
|
|
|
6 |
scikit_image==0.19.1
|
7 |
seaborn==0.11.2
|
8 |
-
|
9 |
tensorflow==2.7.0
|
|
|
3 |
loguru==0.5.3
|
4 |
matplotlib==3.5.1
|
5 |
numpy==1.22.0
|
6 |
+
Pillow==9.0.0
|
7 |
scikit_image==0.19.1
|
8 |
seaborn==0.11.2
|
9 |
+
skimage==0.0
|
10 |
tensorflow==2.7.0
|