Spaces: Running on Zero
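The app.py below is a complete example of a Space that runs on ZeroGPU: it serves the Salesforce/instructblip-vicuna-7b model behind a Gradio interface, and the inference function is decorated with @spaces.GPU so a GPU is only attached to the Space while a request is being processed.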
#!/usr/bin/env python

import os

import gradio as gr
import PIL.Image
import spaces
import torch
from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor

DESCRIPTION = "# InstructBLIP"

MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_id = "Salesforce/instructblip-vicuna-7b"
processor = InstructBlipProcessor.from_pretrained(model_id)
model = InstructBlipForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")


@spaces.GPU  # ZeroGPU: a GPU is attached only while this function runs
def run(
    image: PIL.Image.Image,
    prompt: str,
    text_decoding_method: str = "Nucleus sampling",
    num_beams: int = 5,
    max_length: int = 256,
    min_length: int = 1,
    top_p: float = 0.9,
    repetition_penalty: float = 1.5,
    length_penalty: float = 1.0,
    temperature: float = 1.0,
) -> str:
    # PIL.Image.size is (width, height); downscale so the longer side fits MAX_IMAGE_SIZE
    w, h = image.size
    scale = MAX_IMAGE_SIZE / max(h, w)
    if scale < 1:
        new_w = int(w * scale)
        new_h = int(h * scale)
        image = image.resize((new_w, new_h), resample=PIL.Image.Resampling.LANCZOS)

    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
    generated_ids = model.generate(
        **inputs,
        do_sample=text_decoding_method == "Nucleus sampling",
        num_beams=num_beams,
        max_length=max_length,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()


with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button()
            with gr.Accordion(label="Advanced options", open=False):
                text_decoding_method = gr.Radio(
                    label="Text Decoding Method",
                    choices=["Beam search", "Nucleus sampling"],
                    value="Nucleus sampling",
                )
                num_beams = gr.Slider(
                    label="Number of Beams",
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=5,
                )
                max_length = gr.Slider(
                    label="Max Length",
                    minimum=1,
                    maximum=512,
                    step=1,
                    value=256,
                )
                min_length = gr.Slider(
                    label="Minimum Length",
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=1,
                )
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.9,
                )
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty",
                    info="Larger values penalize repetition more strongly.",
                    minimum=1.0,
                    maximum=5.0,
                    step=0.5,
                    value=1.5,
                )
                length_penalty = gr.Slider(
                    label="Length Penalty",
                    info="Larger values favor longer sequences. Used with beam search.",
                    minimum=-1.0,
                    maximum=2.0,
                    step=0.2,
                    value=1.0,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    info="Used with nucleus sampling.",
                    minimum=0.5,
                    maximum=1.0,
                    step=0.1,
                    value=1.0,
                )
        with gr.Column():
            output = gr.Textbox(label="Result")

    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=run,
        inputs=[
            input_image,
            prompt,
            text_decoding_method,
            num_beams,
            max_length,
            min_length,
            top_p,
            repetition_penalty,
            length_penalty,
            temperature,
        ],
        outputs=output,
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
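
Stripped of the InstructBLIP specifics, the ZeroGPU pattern the app depends on is small: import spaces and decorate every function that needs CUDA with @spaces.GPU, optionally passing duration (in seconds) when a call may exceed the default time window. The sketch below is illustrative rather than part of the app above; the greet function and the 60-second duration are made up for the example.

import gradio as gr
import spaces
import torch


@spaces.GPU(duration=60)  # a GPU is attached only while this call runs
def greet(n: float) -> str:
    # CUDA is available inside the decorated function on ZeroGPU hardware
    x = torch.full((4,), float(n), device="cuda")
    return f"sum computed on {x.device}: {x.sum().item()}"


demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())

if __name__ == "__main__":
    demo.launch()

With this setup the Space idles without a GPU and only requests one when greet (or, in the app above, run) is called.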