Upload 6 files
- .gitattributes +1 -0
- Image 2.jpg +3 -0
- Image.jpg +0 -0
- app.py +56 -0
- image_captioning_model_state_dict.pt +3 -0
- model_architecture.py +54 -0
- requirements.txt +4 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+Image[[:space:]]2.jpg filter=lfs diff=lfs merge=lfs -text
Image 2.jpg
ADDED
(example image, stored via Git LFS)
Image.jpg
ADDED
(example image)
app.py
ADDED
@@ -0,0 +1,56 @@
import gradio as gr
from model_architecture import ImageCaptionGenerationWithAttention
from transformers import BartForConditionalGeneration, BartTokenizer, ViTModel, ViTImageProcessor
import torch
from PIL import Image
from dotenv import load_dotenv
import os
import traceback

load_dotenv()
HF_TOKEN = os.getenv('hf_token')


class GenerateCaptions:
    def __init__(self):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        vit_model = ViTModel.from_pretrained(
            "google/vit-base-patch16-224", token=HF_TOKEN).to(self.device)
        bart_model = BartForConditionalGeneration.from_pretrained(
            "facebook/bart-base").to(self.device)
        self.processor = ViTImageProcessor.from_pretrained(
            "google/vit-base-patch16-224")
        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
        self.model = ImageCaptionGenerationWithAttention(
            vit_model, bart_model, self.tokenizer)
        self.model.load_state_dict(torch.load(
            'image_captioning_model_state_dict.pt', map_location=self.device))
        # Move the whole wrapper (including the visual projection layer,
        # which is created on the CPU) onto the selected device.
        self.model.to(self.device)
        self.model.eval()

    def generate_caption(self, frame, max_length=50, num_beams=5):
        try:
            # Preprocess the PIL image and place the tensor on the same
            # device as the model before generating.
            image_pixel_values = self.processor(
                frame, return_tensors="pt").pixel_values.to(self.device)
            generated_caption_ids = self.model.generate(
                image_pixel_values, max_length, num_beams)
            return self.tokenizer.decode(generated_caption_ids[0], skip_special_tokens=True)
        except Exception as e:
            print(e)
            print(traceback.format_exc())


gc = GenerateCaptions()

demo = gr.Interface(
    fn=gc.generate_caption,
    inputs=gr.Image(type='pil'),
    outputs="text",
    title="Image Caption with Attention",
    examples=['Image.jpg', 'Image 2.jpg'],
    submit_btn='Generate Caption',
    flagging_mode='never'
)


demo.launch()
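app.py reads a Hugging Face access token from the environment and passes it to ViTModel.from_pretrained: load_dotenv() picks up a local .env file and HF_TOKEN = os.getenv('hf_token') resolves the value. A minimal sketch of that setup (the variable name hf_token comes from the code above; the value itself is a placeholder you supply):

# Contents of a .env file placed next to app.py (placeholder value):
#   hf_token=<your Hugging Face access token>
#
# Quick check that the token resolves the same way app.py reads it:
import os
from dotenv import load_dotenv

load_dotenv()  # loads .env from the current working directory
assert os.getenv("hf_token") is not None, "hf_token missing from .env"

With the token and the committed files in place, `python app.py` starts the Gradio interface, with the two committed images wired in as click-to-run examples.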
image_captioning_model_state_dict.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:52b231076ad851a143939d672135b36965c3fec9d9d2531c9bd6417207e5a6e0
size 905995498
model_architecture.py
ADDED
@@ -0,0 +1,54 @@
import torch
from transformers.modeling_outputs import BaseModelOutput
import torch.nn as nn


class ImageCaptionGenerationWithAttention(nn.Module):
    def __init__(self, vit_model, bart_model, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.vit = vit_model
        self.bart = bart_model
        # Project ViT hidden states into BART's embedding space so the
        # BART decoder can cross-attend to the image patch features.
        self.visual_projection = nn.Linear(
            vit_model.config.hidden_size, bart_model.config.d_model)

    def forward(self, pixel_values, input_ids=None, attention_mask=None, labels=None):
        vit_outputs = self.vit(pixel_values)
        if isinstance(vit_outputs, tuple):
            last_hidden_state = vit_outputs[0]
        else:
            last_hidden_state = vit_outputs.last_hidden_state

        visual_features = self.visual_projection(last_hidden_state)

        if input_ids is not None:
            # The projected patch features stand in for BART's encoder output;
            # BART derives its decoder inputs by shifting the labels.
            decoder_outputs = self.bart(
                labels=input_ids,
                encoder_outputs=BaseModelOutput(
                    last_hidden_state=visual_features),
                return_dict=True
            )
            return decoder_outputs
        else:
            return visual_features

    def generate(self, pixel_values, max_length=50, num_beams=5, early_stopping=True):
        self.eval()
        with torch.no_grad():
            vit_outputs = self.vit(pixel_values)
            if isinstance(vit_outputs, tuple):
                last_hidden_state = vit_outputs[0]
            else:
                last_hidden_state = vit_outputs.last_hidden_state
            visual_features = self.visual_projection(last_hidden_state)
            generated_ids = self.bart.generate(
                encoder_outputs=BaseModelOutput(
                    last_hidden_state=visual_features),
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=early_stopping,
                decoder_start_token_id=self.tokenizer.bos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                return_dict_in_generate=False
            )
        return generated_ids
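The class wires a ViT encoder to a BART decoder: ViT patch features (hidden size 768 for vit-base) are linearly projected into BART's d_model (also 768 for bart-base) and handed to BART as encoder_outputs, so the decoder cross-attends to image patches instead of text-encoder states. A minimal usage sketch, mirroring what app.py does and assuming the checkpoint and example image committed here are in the working directory (CPU only):

# Minimal sketch: caption one image with the class above, outside the Gradio app.
import torch
from PIL import Image
from transformers import (BartForConditionalGeneration, BartTokenizer,
                          ViTImageProcessor, ViTModel)

from model_architecture import ImageCaptionGenerationWithAttention

# Same public checkpoints app.py uses.
vit = ViTModel.from_pretrained("google/vit-base-patch16-224")
bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

model = ImageCaptionGenerationWithAttention(vit, bart, tokenizer)
model.load_state_dict(torch.load("image_captioning_model_state_dict.pt",
                                 map_location="cpu"))
model.eval()

# (1, 3, 224, 224) pixel values -> (1, 197, 768) ViT sequence (196 patches
# plus [CLS]), projected to BART's space and decoded with beam search.
pixel_values = processor(Image.open("Image.jpg"), return_tensors="pt").pixel_values
caption_ids = model.generate(pixel_values, max_length=50, num_beams=5)
print(tokenizer.decode(caption_ids[0], skip_special_tokens=True))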
requirements.txt
ADDED
@@ -0,0 +1,4 @@
torch==2.4.1
transformers==4.35.2
gradio==5.0.2
python-dotenv==1.0.1