Spaces:
Running
Running
Commit
·
af9e948
1
Parent(s):
3ed97b9
upgrade to gradio blocks (#1)
Browse files- upgrade to gradio blocks (274546c2d9b7fbb87d302c2d76789da4d490ca33)
Co-authored-by: Abdullah Meda <[email protected]>
app.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
-
import
|
| 2 |
import re
|
|
|
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
-
from pathlib import Path
|
| 5 |
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
|
| 6 |
|
| 7 |
-
|
| 8 |
# Pattern to ignore all the text after 2 or more full stops
|
| 9 |
regex_pattern = "[.]{2,}"
|
| 10 |
|
|
@@ -19,6 +19,10 @@ def post_process(text):
|
|
| 19 |
return text
|
| 20 |
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
def predict(image, max_length=64, num_beams=4):
|
| 23 |
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
|
| 24 |
pixel_values = pixel_values.to(device)
|
|
@@ -52,29 +56,29 @@ print("Loaded feature_extractor")
|
|
| 52 |
tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
|
| 53 |
if model.decoder.name_or_path == "gpt2":
|
| 54 |
tokenizer.pad_token = tokenizer.eos_token
|
| 55 |
-
|
| 56 |
print("Loaded tokenizer")
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
)
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
)
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
import re
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
import gradio as gr
|
|
|
|
| 6 |
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
|
| 7 |
|
|
|
|
| 8 |
# Pattern to ignore all the text after 2 or more full stops
|
| 9 |
regex_pattern = "[.]{2,}"
|
| 10 |
|
|
|
|
| 19 |
return text
|
| 20 |
|
| 21 |
|
| 22 |
+
def set_example_image(example: list) -> dict:
|
| 23 |
+
return gr.Image.update(value=example[0])
|
| 24 |
+
|
| 25 |
+
|
| 26 |
def predict(image, max_length=64, num_beams=4):
|
| 27 |
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
|
| 28 |
pixel_values = pixel_values.to(device)
|
|
|
|
| 56 |
tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
|
| 57 |
if model.decoder.name_or_path == "gpt2":
|
| 58 |
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
| 59 |
print("Loaded tokenizer")
|
| 60 |
|
| 61 |
+
examples = [[f"examples/{filename}"] for filename in next(os.walk('examples'), (None, None, []))[2]]
|
| 62 |
+
print(f"Loaded {len(examples)} example images")
|
| 63 |
+
|
| 64 |
+
with gr.Blocks(css="#title { margin: 0 auto; padding: 25px 25px 25px 25px }") as poster2plot:
|
| 65 |
+
with gr.Column():
|
| 66 |
+
with gr.Row():
|
| 67 |
+
gr.Markdown("# Poster2Plot: Upload a Movie/T.V show poster to generate a plot", elem_id='title')
|
| 68 |
+
with gr.Row():
|
| 69 |
+
with gr.Column():
|
| 70 |
+
with gr.Row():
|
| 71 |
+
input_image = gr.Image(label='Input Image', type='numpy')
|
| 72 |
+
with gr.Row():
|
| 73 |
+
submit_button = gr.Button(value="Submit", variant='primary')
|
| 74 |
+
with gr.Column():
|
| 75 |
+
plot = gr.Textbox(label="Plot")
|
| 76 |
+
with gr.Row():
|
| 77 |
+
example_images = gr.Dataset(components=[input_image], samples=examples)
|
| 78 |
+
with gr.Row():
|
| 79 |
+
gr.Markdown("Made by: [dk-crazydiv](https://twitter.com/kartik_godawat) and [dsr](https://twitter.com/dsr_ai)")
|
| 80 |
+
|
| 81 |
+
submit_button.click(fn=predict, inputs=[input_image], outputs=[plot])
|
| 82 |
+
example_images.click(fn=set_example_image, inputs=[example_images], outputs=example_images.components)
|
| 83 |
+
|
| 84 |
+
poster2plot.launch()
|