Spaces: Runtime error
Commit 3bf60ff · Update app.py
Parent(s): 461429e

app.py CHANGED
@@ -6,56 +6,67 @@ Automatically generated by Colaboratory.
 Original file is located at
     https://colab.research.google.com/drive/1Uvn7yZCyrMpOYNPb7K0G45tQZJVx8LyX
 """
-!pip install torch
-
-from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
-import gradio as gr
-import torch
 from PIL import Image

-model = VisionEncoderDecoderModel.from_pretrained("
-feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

-
-model.to(device)



-
-
-gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

-
-# images = []
-# for image_path in image_paths:
-#     i_image = Image.open(image_path)
-#     if i_image.mode != "RGB":
-#         i_image = i_image.convert(mode="RGB")

-#

-
-pixel_values = pixel_values.to(device)

-

-
-preds = [pred.strip() for pred in preds]
-return preds

-
-
-
-
-article = " <a href=' https://huggingface.co/sachin/vit2distilgpt2 '>Model Repository on Hugging Face Model Hub</a>"

-
-
-
-
-
-article = article,
-theme = 'huggingface'
-).launch(debug = True, enable_queue = True)
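Note: the removed block above references several names that are never defined in what is shown here — `device`, `max_length`, `num_beams`, `pixel_values`, and `preds` — and the first `from_pretrained` call is truncated in this view, so the old script could not have run as captured, which is consistent with the "Runtime error" status at the top of this page. The surviving fragments (the `gen_kwargs` dict, the commented-out RGB-conversion loop, `pixel_values.to(device)`, the stripped `preds`) match the stock nlpconnect/vit-gpt2-image-captioning example, so a plausible reconstruction of the missing pieces looks like the sketch below. This is an assumption for readability, not the committed code. The new version of the hunk follows the sketch.

import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer

# Hypothetical reconstruction based on the usual nlpconnect example;
# the deleted file may have differed.
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def predict_step(image_paths):
    # Load every image and force RGB, as in the commented-out loop above.
    images = []
    for image_path in image_paths:
        i_image = Image.open(image_path)
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")
        images.append(i_image)

    # Preprocess, move to the model's device, and beam-search a caption.
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    output_ids = model.generate(pixel_values, **gen_kwargs)

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds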
 Original file is located at
     https://colab.research.google.com/drive/1Uvn7yZCyrMpOYNPb7K0G45tQZJVx8LyX
 """
 from PIL import Image
+from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, PreTrainedTokenizerFast
+import requests

+model = VisionEncoderDecoderModel.from_pretrained("sachin/vit2distilgpt2")

+vit_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

+tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")

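The feature extractor and tokenizer are loaded from the base checkpoints rather than from the fine-tuned repo, presumably because `sachin/vit2distilgpt2` does not bundle its own preprocessor and tokenizer files. A quick, hedged sanity check of the preprocessing contract — the variable names below are illustrative, and the expected shape assumes the standard ViT-base preprocessor config:

from PIL import Image

# Illustrative check: ViTFeatureExtractor should resize/normalize any RGB
# image to a (1, 3, 224, 224) float tensor for the ViT-base encoder.
img = Image.new("RGB", (640, 480))
enc = vit_feature_extractor(images=img, return_tensors="pt")
print(enc.pixel_values.shape)  # expected: torch.Size([1, 3, 224, 224])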
+# url = 'https://d2gp644kobdlm6.cloudfront.net/wp-content/uploads/2016/06/bigstock-Shocked-and-surprised-boy-on-t-113798588-300x212.jpg'

+# with Image.open(requests.get(url, stream=True).raw) as img:
+#     pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values

+#encoder_outputs = model.generate(pixel_values.to('cpu'),num_beams=5)

+#generated_sentences = tokenizer.batch_decode(encoder_outputs, skip_special_tokens=True)

+#generated_sentences

+#naive text processing
+#generated_sentences[0].split('.')[0]

+# inference function

+def vit2distilgpt2(img):
+    pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values
+    encoder_outputs = generated_ids = model.generate(pixel_values.to('cpu'),num_beams=5)
+    generated_sentences = tokenizer.batch_decode(encoder_outputs, skip_special_tokens=True)

+    return(generated_sentences[0].split('.')[0])
+
+#!wget https://media.glamour.com/photos/5f171c4fd35176eaedb36823/master/w_2560%2Cc_limit/bike.jpg
+
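Before the Gradio wiring below, the new `vit2distilgpt2` function can be smoke-tested on a local file, for instance the bike.jpg that the commented `!wget` line would download. A minimal sketch, assuming the definitions above are in scope and the file exists on disk:

from PIL import Image

# Converting to RGB guards against greyscale or RGBA uploads, which would
# otherwise reach the encoder with the wrong channel count.
with Image.open("bike.jpg") as img:
    caption = vit2distilgpt2(img.convert("RGB"))
print(caption)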
+import gradio as gr

+inputs = [
+    gr.inputs.Image(type="pil", label="Original Image")
+]
+
+outputs = [
+    gr.outputs.Textbox(label = 'Caption')
+]
+
+title = "Image Captioning using ViT + GPT2"
+description = "ViT and GPT2 are used to generate Image Caption for the uploaded image. COCO Dataset was used for training. This image captioning model might have some biases that we couldn't figure during our stress testing, so if you find any bias (gender, race and so on) please use `Flag` button to flag the image with bias"
+article = " <a href='https://huggingface.co/sachin/vit2distilgpt2'>Model Repo on Hugging Face Model Hub</a>"
+examples = [
+    ["people-walking-street-pedestrian-crossing-traffic-light-city.jpeg"],
+    ["elonmusk.jpeg"]
+
+]
+
+gr.Interface(
+    vit2distilgpt2,
+    inputs,
+    outputs,
+    title=title,
+    description=description,
+    article=article,
+    examples=examples,
+    theme="huggingface",
+).launch(debug=True, enable_queue=True)
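One compatibility note on the added interface code: `gr.inputs.Image`, `gr.outputs.Textbox`, the `"huggingface"` theme string, and `launch(enable_queue=...)` are Gradio 2.x-era APIs that later major versions removed. If this Space is ever migrated, a rough equivalent under Gradio 3+/4 might look like the sketch below — untested against any specific version, so treat the exact keywords as assumptions:

import gradio as gr

demo = gr.Interface(
    fn=vit2distilgpt2,
    inputs=gr.Image(type="pil", label="Original Image"),
    outputs=gr.Textbox(label="Caption"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)

# Queueing moved from launch(enable_queue=True) to an explicit .queue() call.
demo.queue().launch(debug=True)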