|
import gradio as gr |
|
import numpy as np |
|
import requests |
|
import tensorflow as tf |
|
from fastapi import FastAPI |
|
from io import BytesIO |
|
from PIL import Image |
|
from pydantic import BaseModel |
|
from cat_breeds_dict import CAT_BREEDS, CAT_DESCRIPTIONS |
|
from scripts.crop_image import crop_image |
|
|
|
|
|
class Url(BaseModel): |
|
link: str |
|
|
|
|
|
MODEL = tf.keras.models.load_model('./models/cats_18_EfficientNetB0.h5') |
|
GRADIO_PATH = '/' |
|
INPUT_SHAPE = MODEL.layers[0].input_shape[1] |
|
NUM_CLASSES = MODEL.layers[-1].output_shape[1] |
|
app = FastAPI() |
|
|
|
|
|
def predict(image, api_mode=False): |
|
image = crop_image(image, INPUT_SHAPE, INPUT_SHAPE) |
|
image = image.resize((INPUT_SHAPE, INPUT_SHAPE)) |
|
image = np.asarray(image) |
|
image = image.reshape(1, INPUT_SHAPE, INPUT_SHAPE, 3) |
|
|
|
prediction = MODEL.predict(image)[0] |
|
predicted_breed = CAT_BREEDS[np.argmax(prediction)] |
|
breed_description = CAT_DESCRIPTIONS[predicted_breed] |
|
|
|
all_predictions = { |
|
CAT_BREEDS[i]: float(prediction[i]) for i in range(NUM_CLASSES) |
|
} |
|
|
|
if api_mode: |
|
breed_description = ' '.join(breed_description.replace('\n', '.') |
|
.replace('#', '') |
|
.split()) |
|
|
|
return { |
|
'breed': predicted_breed, |
|
'description': breed_description, |
|
'predictions': all_predictions |
|
} |
|
return all_predictions, breed_description, gr.HTML.update(visible=True), gr.Markdown.update(visible=True) |
|
|
|
|
|
@app.post('/predict_breed/') |
|
def predict_api(url: Url): |
|
try: |
|
image = requests.get(url.link).content |
|
except Exception as e: |
|
return {'error': 'Invalid link', 'exception': str(e)} |
|
image = Image.open(BytesIO(image)) |
|
return predict(image, api_mode=True) |
|
|
|
|
|
with gr.Blocks(css='./static/style.css', title="Cat Classifier") as gradio_ui: |
|
|
|
gr.Markdown( |
|
""" |
|
# Классификатор пород котов |
|
Разработано студентами Шершневым А.А, Ивановым С.С, Шалаевой И.Г. и |
|
Ильиным С.С. |
|
Группы: РИМ-120906, РИМ-120907 |
|
""", |
|
elem_id='md-text' |
|
) |
|
|
|
with gr.Row(elem_id='main-row') as row: |
|
|
|
with gr.Column(scale=2, elem_id='first-col') as col_1: |
|
user_image = gr.Image( |
|
label='Загрузите фотографию котика сюда', |
|
type='pil', |
|
elem_id='user-image' |
|
) |
|
predict_button = gr.Button(value='Определить породу') |
|
|
|
with gr.Column(scale=1, elem_id='second-col') as col_2: |
|
predicted_labels = gr.Label( |
|
num_top_classes=5, |
|
label='Результат определения породы', |
|
elem_id='predictions-text' |
|
) |
|
|
|
breed_description = gr.Markdown(elem_id='breed-description') |
|
banner_text = gr.Markdown( |
|
""" |
|
# <center>Места, которые будут Вам интересны</center> |
|
""", visible=False |
|
) |
|
embedded_map = gr.HTML(''' |
|
<iframe src="https://yandex.ru/map-widget/v1/?um=constructor%3A8cead4799165c7f6356c4f269f2847032ef2803cb46871dbfd6dd68c09834f4c&source=constructor" width="100%" height="500" frameborder="0"></iframe> |
|
''', visible=False, elem_id='embedded-map') |
|
|
|
predict_button.click( |
|
fn=predict, |
|
inputs=[user_image], |
|
outputs=[ |
|
predicted_labels, breed_description, |
|
embedded_map, banner_text |
|
] |
|
) |
|
|
|
app = gr.mount_gradio_app(app, gradio_ui, path=GRADIO_PATH) |
|
|