Nepjune commited on
Commit
813ddfe
·
verified ·
1 Parent(s): 8343365

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -15
app.py CHANGED
@@ -1,21 +1,31 @@
1
- Run pip install transformers
2
- Run pip install gradio
3
  from PIL import Image
4
- import requests
5
- import gradio as gr
6
 
7
- from transformers import BlipProcessor, BlipForConditionalGeneration
 
 
8
 
9
- model_id = "Salesforce/blip-image-captioning-base"
 
10
 
11
- model = BlipForConditionalGeneration.from_pretrained(model_id)
12
- processor = BlipProcessor.from_pretrained(model_id)
 
13
 
14
- def launch(input):
15
- image = Image.open(requests.get(input, stream=True).raw).convert('RGB')
16
- inputs = processor(image, return_tensors="pt")
17
- out = model.generate(**inputs)
18
- return processor.decode(out[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
19
 
20
- iface = gr.Interface(launch, inputs="text", outputs="text")
21
- iface.launch()
 
1
+ from tranformers import VisionEncoderDecoderModle, ViTImageProcer, Autotokenizer
2
+ import torch
3
  from PIL import Image
 
 
4
 
5
+ model = VisionEncoderDecoderModle.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
6
+ feature_external = ViTImageProcer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
7
+ tokenizer = Autotokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
8
 
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ model.to(device)
11
 
12
+ max_length = 16
13
+ num_beams = 4
14
+ gen_kwargs = ("max_length" : max_length, "num_beams" : num_beams)
15
 
16
+ def predict_caption(image_paths):
17
+ images = []
18
+ for image_path in image_paths:
19
+ image = Image.open(image_path)
20
+ if image.mode != "RGB":
21
+ image = image.convert(mode="RGB")
22
+ images.append(image)
23
+
24
+ pixel_values = feature_extractor(images=images, return_pixel_values=True).pixel_values
25
+ pixel_values = pixel_values.to(device)
26
+
27
+ output_ids = model.generate(pixel_values, **gen_kwargs)
28
+
29
+ preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
30
+ return preds
31