randomarnab committed on
Commit 3bf60ff · 1 Parent(s): 461429e

Update app.py

Files changed (1)
  1. app.py +50 -39
app.py CHANGED
@@ -6,56 +6,67 @@ Automatically generated by Colaboratory.
  Original file is located at
  https://colab.research.google.com/drive/1Uvn7yZCyrMpOYNPb7K0G45tQZJVx8LyX
  """
- !pip install torch
-
- from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
- import gradio as gr
- import torch
  from PIL import Image
-
- model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
- feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
- tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model.to(device)
-
- max_length = 16
- num_beams = 4
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-
- def predict_step(image):
-     # images = []
-     # for image_path in image_paths:
-     #     i_image = Image.open(image_path)
-     #     if i_image.mode != "RGB":
-     #         i_image = i_image.convert(mode="RGB")
-
-     #     images.append(i_image)
-
-     pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
-     pixel_values = pixel_values.to(device)
-
-     output_ids = model.generate(pixel_values, **gen_kwargs)
-
-     preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-     preds = [pred.strip() for pred in preds]
-     return preds
-
- inputs = [gr.inputs.Image(type="pil", label="Original Image")]
- outputs = [gr.outputs.Textbox(label="Caption")]
- title = "Image Captioning using ViT + GPT2"
- description = "ViT and GPT2 are used here to generate an image caption for the user-uploaded image."
- article = "<a href='https://huggingface.co/sachin/vit2distilgpt2'>Model Repository on Hugging Face Model Hub</a>"
-
- gr.Interface(
-     predict_step,
-     inputs, outputs,
-     title=title,
-     description=description,
-     article=article,
-     theme="huggingface"
- ).launch(debug=True, enable_queue=True)
+ from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, PreTrainedTokenizerFast
+ import requests
+
+ model = VisionEncoderDecoderModel.from_pretrained("sachin/vit2distilgpt2")
+ vit_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
+ tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")
+
+ # url = 'https://d2gp644kobdlm6.cloudfront.net/wp-content/uploads/2016/06/bigstock-Shocked-and-surprised-boy-on-t-113798588-300x212.jpg'
+ # with Image.open(requests.get(url, stream=True).raw) as img:
+ #     pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values
+ # encoder_outputs = model.generate(pixel_values.to('cpu'), num_beams=5)
+ # generated_sentences = tokenizer.batch_decode(encoder_outputs, skip_special_tokens=True)
+ # generated_sentences
+
+ # naive text processing
+ # generated_sentences[0].split('.')[0]
+
+ # inference function
+ def vit2distilgpt2(img):
+     pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values
+     encoder_outputs = model.generate(pixel_values.to('cpu'), num_beams=5)
+     generated_sentences = tokenizer.batch_decode(encoder_outputs, skip_special_tokens=True)
+     return generated_sentences[0].split('.')[0]
+
+ # !wget https://media.glamour.com/photos/5f171c4fd35176eaedb36823/master/w_2560%2Cc_limit/bike.jpg
+
+ import gradio as gr
+
+ inputs = [
+     gr.inputs.Image(type="pil", label="Original Image")
+ ]
+ outputs = [
+     gr.outputs.Textbox(label="Caption")
+ ]
+
+ title = "Image Captioning using ViT + GPT2"
+ description = "ViT and GPT2 are used to generate an image caption for the uploaded image. The COCO dataset was used for training. This image captioning model might have some biases that we couldn't figure out during our stress testing, so if you find any bias (gender, race and so on) please use the `Flag` button to flag the image with bias."
+ article = "<a href='https://huggingface.co/sachin/vit2distilgpt2'>Model Repo on Hugging Face Model Hub</a>"
+ examples = [
+     ["people-walking-street-pedestrian-crossing-traffic-light-city.jpeg"],
+     ["elonmusk.jpeg"]
+ ]
+
+ gr.Interface(
+     vit2distilgpt2,
+     inputs,
+     outputs,
+     title=title,
+     description=description,
+     article=article,
+     examples=examples,
+     theme="huggingface",
+ ).launch(debug=True, enable_queue=True)
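
Reviewer note: for anyone who wants to sanity-check the new inference path outside the Space, below is a minimal, self-contained sketch distilled from the added lines above. The three checkpoint names come straight from the diff; the local file name "bike.jpg" is only a placeholder (the commented-out wget line in the diff downloads a matching sample image).

from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, PreTrainedTokenizerFast

# Same three checkpoints the new app.py loads: the fine-tuned
# encoder-decoder, plus the preprocessors of its two halves.
model = VisionEncoderDecoderModel.from_pretrained("sachin/vit2distilgpt2")
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")

def caption(path):
    # ViT expects RGB input, preprocessed into a pixel tensor.
    image = Image.open(path).convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # Beam search with num_beams=5, mirroring the app's generate() call.
    output_ids = model.generate(pixel_values, num_beams=5)
    sentence = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    # Same naive post-processing as the app: keep only the first sentence.
    return sentence.split(".")[0]

print(caption("bike.jpg"))  # placeholder path; any local image works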
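
A second note, hedged because the commit does not pin a Gradio version: gr.inputs.*, gr.outputs.*, theme="huggingface", and launch(enable_queue=True) belong to the legacy Gradio API and were removed in later major releases. Should the Space ever be upgraded, a roughly equivalent interface (reusing the function and strings defined in the diff) might look like this:

import gradio as gr

demo = gr.Interface(
    fn=vit2distilgpt2,  # same inference function as in the diff
    inputs=gr.Image(type="pil", label="Original Image"),
    outputs=gr.Textbox(label="Caption"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)
# Queueing moved from launch(enable_queue=True) to an explicit .queue() call.
demo.queue().launch(debug=True)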