Spaces:
Runtime error
Runtime error
Commit
·
1c3da59
1
Parent(s):
375dc5c
Attempt to fix `InstructBlipForConditionalGeneration`
Browse files
app.py
CHANGED
@@ -21,7 +21,7 @@ blip2_processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-6.7b")
|
|
21 |
blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-6.7b", device_map="auto", torch_dtype=torch.float16)
|
22 |
|
23 |
instructblip_processor = AutoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
|
24 |
-
instructblip_model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b"
|
25 |
|
26 |
def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
|
27 |
inputs = processor(images=image, return_tensors="pt").to(device)
|
@@ -45,7 +45,8 @@ def generate_caption_blip2(processor, model, image, replace_token=False):
|
|
45 |
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device=device, dtype=torch.float16)
|
46 |
|
47 |
generated_ids = model.generate(pixel_values=inputs.pixel_values,
|
48 |
-
num_beams=5, max_length=50, min_length=1, top_p=0.9,
|
|
|
49 |
if replace_token:
|
50 |
# TODO remove once https://github.com/huggingface/transformers/pull/24492 is merged
|
51 |
generated_ids[generated_ids == 0] = 2
|
|
|
21 |
blip2_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-6.7b", device_map="auto", torch_dtype=torch.float16)
|
22 |
|
23 |
instructblip_processor = AutoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
|
24 |
+
instructblip_model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b").to(device)
|
25 |
|
26 |
def generate_caption(processor, model, image, tokenizer=None, use_float_16=False):
|
27 |
inputs = processor(images=image, return_tensors="pt").to(device)
|
|
|
45 |
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device=device, dtype=torch.float16)
|
46 |
|
47 |
generated_ids = model.generate(pixel_values=inputs.pixel_values,
|
48 |
+
num_beams=5, max_length=50, min_length=1, top_p=0.9,
|
49 |
+
repetition_penalty=1.5, length_penalty=1.0, temperature=1)
|
50 |
if replace_token:
|
51 |
# TODO remove once https://github.com/huggingface/transformers/pull/24492 is merged
|
52 |
generated_ids[generated_ids == 0] = 2
|