anon5's picture
Update app.py
aae03c9 verified
raw
history blame contribute delete
840 Bytes
import torch
from transformers import pipeline
import torch
import gradio as gr
# Shared image-classification pipeline. It is created lazily on the first
# request and then reused, so the checkpoint is loaded once rather than on
# every button click.
_classifier = None


def _get_classifier():
    """Return the shared classification pipeline, building it on first use."""
    global _classifier
    if _classifier is None:
        # Enable TF32 for CUDA matmul and cuDNN before the model is created.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        _classifier = pipeline("image-classification", model="./checkpoint-600")
    return _classifier


with gr.Blocks() as demo:

    def submit(image):
        """Classify one image and return a ``{label: score}`` mapping.

        Args:
            image: Value of the image input component, which is declared with
                ``type="filepath"``. Expected to be ``None`` when the user has
                not provided an image.

        Returns:
            A dict mapping each predicted label to its score, suitable for
            the ``gr.Label`` output component.

        Raises:
            gr.Error: If no image was provided.
        """
        if image is None:
            # Fail with a readable message instead of letting the pipeline
            # raise on a missing input.
            raise gr.Error("Please provide an image before submitting.")

        # A one-element batch is submitted; output[0] holds the predictions
        # for that single image.
        output = _get_classifier()(images=[image])
        return {item["label"]: item["score"] for item in output[0]}

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Input image", type="filepath")
        with gr.Column():
            label = gr.Label()
            # NOTE(review): the original nesting of the button could not be
            # recovered; it is assumed to sit in this column under the label.
            submit_button = gr.Button(value="Submit", variant="primary")

    submit_button.click(submit, inputs=[image_input], outputs=label)

demo.launch()