ashish-001 committed on
Commit
906f611
·
verified ·
1 Parent(s): dc6ad72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -56
app.py CHANGED
@@ -1,56 +1,56 @@
1
- import gradio as gr
2
- from model_architecture import ImageCaptionGenerationWithAttention
3
- from transformers import BartForConditionalGeneration, BartTokenizer, ViTModel, ViTImageProcessor
4
- import torch
5
- from PIL import Image
6
- from dotenv import load_dotenv
7
- import os
8
- import traceback
9
-
10
- load_dotenv()
11
- HF_TOKEN = os.getenv('hf_token')
12
-
13
-
14
- class GenerateCaptions:
15
- def __init__(self):
16
- self.device = torch.device(
17
- "cuda" if torch.cuda.is_available() else "cpu")
18
- vit_model = ViTModel.from_pretrained(
19
- "google/vit-base-patch16-224", token=HF_TOKEN).to(self.device)
20
- bart_model = BartForConditionalGeneration.from_pretrained(
21
- "facebook/bart-base").to(self.device)
22
- self.processor = ViTImageProcessor.from_pretrained(
23
- "google/vit-base-patch16-224")
24
- self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
25
- self.model = ImageCaptionGenerationWithAttention(
26
- vit_model, bart_model, self.tokenizer)
27
- self.model.load_state_dict(torch.load(
28
- 'image_captioning_model_state_dict.pt', map_location=self.device))
29
- self.model.eval()
30
-
31
- def generate_caption(self, frame, max_length=50, num_beams=5):
32
- try:
33
- image_pixel_values = self.processor(
34
- frame, return_tensors="pt").pixel_values
35
- generated_caption_ids = self.model.generate(
36
- image_pixel_values, max_length, num_beams)
37
- return self.tokenizer.decode(generated_caption_ids[0], skip_special_tokens=True)
38
- except Exception as e:
39
- print(e)
40
- print(traceback.format_exc())
41
-
42
-
43
- gc = GenerateCaptions()
44
-
45
- demo = gr.Interface(
46
- fn=gc.generate_caption,
47
- inputs=gr.Image(type='pil'),
48
- outputs="text",
49
- title="Image Caption with Attention",
50
- examples=['Image.jpg', 'Image 2.jpg'],
51
- submit_btn='Generate Caption',
52
- flagging_mode='never'
53
- )
54
-
55
-
56
- demo.launch()
 
1
+ import gradio as gr
2
+ from model_architecture import ImageCaptionGenerationWithAttention
3
+ from transformers import BartForConditionalGeneration, BartTokenizer, ViTModel, ViTImageProcessor
4
+ import torch
5
+ from PIL import Image
6
+ from dotenv import load_dotenv
7
+ import os
8
+ import traceback
9
+
10
+ load_dotenv()
11
+ HF_TOKEN = os.getenv('hf_token')
12
+
13
+
14
class GenerateCaptions:
    """Load the image-captioning model once and caption images on demand.

    The constructor builds an ``ImageCaptionGenerationWithAttention`` model
    from a pretrained ViT encoder ("google/vit-base-patch16-224") and a
    pretrained BART model ("facebook/bart-base"), then restores the weights
    stored in ``image_captioning_model_state_dict.pt``.
    """

    def __init__(self):
        # Prefer the GPU when one is available; everything below (weights and
        # inputs) must end up on this same device.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        vit_model = ViTModel.from_pretrained(
            "google/vit-base-patch16-224", token=HF_TOKEN).to(self.device)
        bart_model = BartForConditionalGeneration.from_pretrained(
            "facebook/bart-base").to(self.device)
        self.processor = ViTImageProcessor.from_pretrained(
            "google/vit-base-patch16-224")
        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
        self.model = ImageCaptionGenerationWithAttention(
            vit_model, bart_model, self.tokenizer)
        self.model.load_state_dict(torch.load(
            'image_captioning_model_state_dict.pt', map_location=self.device))
        # The two sub-models were moved individually above; moving the wrapper
        # as well covers any parameters it owns itself.
        # NOTE(review): assumes the wrapper is a torch.nn.Module (it already
        # exposes load_state_dict/eval) -- confirm in model_architecture.
        self.model.to(self.device)
        self.model.eval()

    def generate_caption(self, frame, max_length=50, num_beams=5):
        """Return a caption for ``frame``, or ``None`` if captioning fails.

        Args:
            frame: The image to caption, passed straight to the ViT image
                processor. ``None`` (no image supplied) yields ``None``.
            max_length: Forwarded positionally to ``self.model.generate``.
            num_beams: Forwarded positionally to ``self.model.generate``.

        Returns:
            The first generated sequence decoded to text with special tokens
            stripped, or ``None`` when there is no input or an error occurs.
            Errors are printed with their traceback rather than raised.
        """
        if frame is None:
            # Nothing to caption; skip the noisy traceback the processor
            # would otherwise produce for a missing image.
            return None
        try:
            # Inputs must live on the same device as the model weights,
            # otherwise the forward pass fails on CUDA hosts.
            image_pixel_values = self.processor(
                frame, return_tensors="pt").pixel_values.to(self.device)
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                generated_caption_ids = self.model.generate(
                    image_pixel_values, max_length, num_beams)
            return self.tokenizer.decode(
                generated_caption_ids[0], skip_special_tokens=True)
        except Exception as e:
            # Best-effort boundary for the UI callback: log and return None
            # instead of propagating the error.
            print(e)
            print(traceback.format_exc())
            return None
41
+
42
+
43
# Build the captioner once at start-up so the model weights are loaded a
# single time and reused for every request.
gc = GenerateCaptions()

# Single-image-in / text-out web UI around the captioner.
demo = gr.Interface(
    fn=gc.generate_caption,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Image Caption Generation",
    examples=["Image.jpg", "Image 2.jpg"],
    submit_btn="Generate Caption",
    flagging_mode="never",
)

# Start serving the interface.
demo.launch()