ashish-001 committed
Commit dc6ad72 · verified · 1 Parent(s): 9b097df

Upload 6 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+Image[[:space:]]2.jpg filter=lfs diff=lfs merge=lfs -text
Image 2.jpg ADDED

Git LFS Details

  • SHA256: e8de3170abe960ff6df25cdfa0832a95764d97839228969e5c454055abd6b4f4
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB
Image.jpg ADDED
app.py ADDED
@@ -0,0 +1,59 @@
+import gradio as gr
+from model_architecture import ImageCaptionGenerationWithAttention
+from transformers import BartForConditionalGeneration, BartTokenizer, ViTModel, ViTImageProcessor
+import torch
+from dotenv import load_dotenv
+import os
+import traceback
+
+# Read the Hugging Face access token from a .env file (key: hf_token).
+load_dotenv()
+HF_TOKEN = os.getenv('hf_token')
+
+
+class GenerateCaptions:
+    def __init__(self):
+        self.device = torch.device(
+            "cuda" if torch.cuda.is_available() else "cpu")
+        # ViT encodes the image; BART decodes a caption. The two are tied
+        # together by the custom wrapper in model_architecture.py.
+        vit_model = ViTModel.from_pretrained(
+            "google/vit-base-patch16-224", token=HF_TOKEN).to(self.device)
+        bart_model = BartForConditionalGeneration.from_pretrained(
+            "facebook/bart-base").to(self.device)
+        self.processor = ViTImageProcessor.from_pretrained(
+            "google/vit-base-patch16-224")
+        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+        self.model = ImageCaptionGenerationWithAttention(
+            vit_model, bart_model, self.tokenizer)
+        self.model.load_state_dict(torch.load(
+            'image_captioning_model_state_dict.pt', map_location=self.device))
+        self.model.eval()
+
+    def generate_caption(self, frame, max_length=50, num_beams=5):
+        try:
+            # Move pixel values to the model's device so GPU inference works.
+            image_pixel_values = self.processor(
+                frame, return_tensors="pt").pixel_values.to(self.device)
+            generated_caption_ids = self.model.generate(
+                image_pixel_values, max_length, num_beams)
+            return self.tokenizer.decode(
+                generated_caption_ids[0], skip_special_tokens=True)
+        except Exception as e:
+            print(traceback.format_exc())
+            return f"Caption generation failed: {e}"
+
+
+gc = GenerateCaptions()
+
+demo = gr.Interface(
+    fn=gc.generate_caption,
+    inputs=gr.Image(type='pil'),
+    outputs="text",
+    title="Image Caption with Attention",
+    examples=['Image.jpg', 'Image 2.jpg'],
+    submit_btn='Generate Caption',
+    flagging_mode='never'
+)
+
+demo.launch()
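Once the app is running, the same endpoint can be exercised programmatically. A minimal client-side sketch (not part of this commit), assuming the app is live at Gradio's default local address and exposes the auto-generated /predict endpoint:

# Hypothetical client-side check; assumes a running instance of app.py.
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860")
caption = client.predict(
    handle_file('Image.jpg'),  # the single image input of the Interface
    api_name="/predict",       # gr.Interface's default endpoint name
)
print(caption)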
image_captioning_model_state_dict.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52b231076ad851a143939d672135b36965c3fec9d9d2531c9bd6417207e5a6e0
+size 905995498
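The pointer above records only the oid and byte size of the real weights file. A small hypothetical check (not part of the commit) that the file fetched by git lfs pull matches the pointer:

# Verify the downloaded weights against the LFS pointer's oid and size.
import hashlib
import os

path = 'image_captioning_model_state_dict.pt'
expected_oid = '52b231076ad851a143939d672135b36965c3fec9d9d2531c9bd6417207e5a6e0'

assert os.path.getsize(path) == 905995498, "size differs from LFS pointer"
sha = hashlib.sha256()
with open(path, 'rb') as f:
    for chunk in iter(lambda: f.read(1 << 20), b''):
        sha.update(chunk)
assert sha.hexdigest() == expected_oid, "sha256 differs from LFS pointer"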
model_architecture.py ADDED
@@ -0,0 +1,63 @@
+import torch
+import torch.nn as nn
+from transformers.modeling_outputs import BaseModelOutput
+
+
+class ImageCaptionGenerationWithAttention(nn.Module):
+    """ViT patch features are projected into BART's embedding space and
+    consumed by BART's cross-attention as precomputed encoder outputs."""
+
+    def __init__(self, vit_model, bart_model, tokenizer):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.vit = vit_model
+        self.bart = bart_model
+        # Bridge the dimensionality gap between the ViT encoder and the
+        # BART decoder.
+        self.visual_projection = nn.Linear(
+            vit_model.config.hidden_size, bart_model.config.d_model)
+
+    def forward(self, pixel_values, input_ids=None, attention_mask=None, labels=None):
+        vit_outputs = self.vit(pixel_values)
+        if isinstance(vit_outputs, tuple):
+            last_hidden_state = vit_outputs[0]
+        else:
+            last_hidden_state = vit_outputs.last_hidden_state
+
+        visual_features = self.visual_projection(last_hidden_state)
+
+        if input_ids is not None:
+            # Training path: pass the caption tokens as labels so BART
+            # computes the language-modeling loss internally.
+            decoder_outputs = self.bart(
+                labels=input_ids,
+                encoder_outputs=BaseModelOutput(
+                    last_hidden_state=visual_features),
+                return_dict=True
+            )
+            return decoder_outputs
+        else:
+            return visual_features
+
+    def generate(self, pixel_values, max_length=50, num_beams=5, early_stopping=True):
+        self.eval()
+        with torch.no_grad():
+            vit_outputs = self.vit(pixel_values)
+            if isinstance(vit_outputs, tuple):
+                last_hidden_state = vit_outputs[0]
+            else:
+                last_hidden_state = vit_outputs.last_hidden_state
+            visual_features = self.visual_projection(last_hidden_state)
+            # Inference path: beam-search decode conditioned on the image
+            # features, skipping BART's own text encoder.
+            generated_ids = self.bart.generate(
+                encoder_outputs=BaseModelOutput(
+                    last_hidden_state=visual_features),
+                max_length=max_length,
+                num_beams=num_beams,
+                early_stopping=early_stopping,
+                decoder_start_token_id=self.tokenizer.bos_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+                return_dict_in_generate=False
+            )
+        return generated_ids
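The forward path is only exercised during training, which this commit does not include. A hypothetical single training step, showing how the module above is meant to be driven (the sample image and caption are illustrative):

# Illustrative training step; none of this is part of the commit.
import torch
from PIL import Image
from transformers import (BartForConditionalGeneration, BartTokenizer,
                          ViTImageProcessor, ViTModel)
from model_architecture import ImageCaptionGenerationWithAttention

vit = ViTModel.from_pretrained("google/vit-base-patch16-224")
bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

model = ImageCaptionGenerationWithAttention(vit, bart, tokenizer)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

pixel_values = processor(Image.open('Image.jpg').convert('RGB'),
                         return_tensors="pt").pixel_values
caption_ids = tokenizer("a person walking on a beach",
                        return_tensors="pt").input_ids

# forward() routes caption_ids to BART as labels, so .loss is populated.
loss = model(pixel_values, input_ids=caption_ids).loss
loss.backward()
optimizer.step()
optimizer.zero_grad()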
requirements.txt ADDED
@@ -0,0 +1,4 @@
+torch==2.4.1
+transformers==4.35.2
+gradio==5.0.2
+python-dotenv==1.0.1