Nepjune committed on
Commit
fe24d04
·
verified ·
1 Parent(s): 7872b1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -26
app.py CHANGED
@@ -1,31 +1,19 @@
1
- from transformers import ViTFeatureExtractor, ViTForImageToText, AutoTokenizer
 
2
 
3
- import torch
4
- from PIL import Image
 
5
 
6
- model = ViTForImageToText.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
7
- feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
8
- tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
9
 
10
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
- model.to(device)
12
 
13
- max_length = 16
14
- num_beams = 4
15
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
16
 
17
- def predict_caption(image_paths):
18
- images = []
19
- for image_path in image_paths:
20
- image = Image.open(image_path)
21
- if image.mode != "RGB":
22
- image = image.convert(mode="RGB")
23
- images.append(image)
24
-
25
- pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
26
- pixel_values = pixel_values.to(device)
27
-
28
- output_ids = model.generate(pixel_values, **gen_kwargs)
29
-
30
- preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
31
- return preds
 
1
+ import gradio as gr
2
+ from transformers import BlipProcessor, BlipForConditionalGeneration
3
 
4
+ model_id = "dblasko/blip-dalle3-img2prompt"
5
+ model = BlipForConditionalGeneration.from_pretrained(model_id)
6
+ processor = BlipProcessor.from_pretrained(model_id)
7
 
8
+ def generate_caption(image):
9
+ inputs = processor(images=image, return_tensors="pt")
10
+ pixel_values = inputs.pixel_values
11
 
12
+ generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
13
+ generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True, temperature=0.8, top_k=40, top_p=0.9)[0]
14
 
15
+ return generated_caption
 
 
16
 
17
+ # Create a gradio interface with an image input and a textbox output
18
+ demo = gr.Interface(fn=generate_caption, inputs=gr.Image(shape=(224, 224)), outputs=gr.Textbox(label="Generated caption"))
19
+ demo.launch()