"""Gradio demo for BLIP-2 (Flan-T5-XXL): image captioning and visual question answering.

A prompt that is empty or starts with ``Caption:`` produces a caption of the image;
any other prompt is treated as a question about the image.
"""

import gradio as gr
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, Blip2ForConditionalGeneration

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only used on GPU; on CPU the model and its inputs stay in float32.
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Model ID
MODEL_ID_FLAN_T5_XXL = "Salesforce/blip2-flan-t5-xxl"

# Prompt prefix that selects captioning instead of question answering.
CAPTION_PREFIX = "Caption:"
# Maximum generated length for captions and for answers, respectively.
CAPTION_MAX_LENGTH = 50
ANSWER_MAX_LENGTH = 30

# Default decoding settings, shared by the UI widgets and the example rows.
DEFAULT_DECODING_METHOD = "Beam search"
DEFAULT_TEMPERATURE = 1.0
DEFAULT_LENGTH_PENALTY = 1.0
DEFAULT_REPETITION_PENALTY = 1.5

# Load the model and processor
processor = AutoProcessor.from_pretrained(MODEL_ID_FLAN_T5_XXL)
if device.type == "cuda":
    # bitsandbytes 8-bit quantization needs CUDA. A quantized model is placed on the
    # GPU by `device_map` at load time; transformers rejects a later `.to(device)` call.
    model = Blip2ForConditionalGeneration.from_pretrained(
        MODEL_ID_FLAN_T5_XXL,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",
        torch_dtype=torch.float16,
    )
else:
    # NOTE(review): the unquantized XXL model needs a lot of RAM on CPU.
    model = Blip2ForConditionalGeneration.from_pretrained(MODEL_ID_FLAN_T5_XXL).to(device)


def _generate(inputs, *, decoding_method, temperature, length_penalty, repetition_penalty, max_length):
    """Run 5-beam generation (beam-sampling when nucleus sampling is selected) and decode it.

    Args:
        inputs: Processor output already moved to the model's device; unpacked into
            ``model.generate``.
        decoding_method: ``"Nucleus sampling"`` enables sampling; any other value
            keeps plain beam search.
        temperature: Forwarded to ``model.generate``.
        length_penalty: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.
        max_length: Maximum generated length.

    Returns:
        The first decoded sequence with special tokens removed and whitespace stripped.
    """
    generated_ids = model.generate(
        **inputs,
        do_sample=decoding_method == "Nucleus sampling",
        temperature=temperature,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        max_length=max_length,
        min_length=1,
        num_beams=5,
        top_p=0.9,
    )
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()


# Define a function for generating captions and answering questions
def generate_text(image, text, decoding_method, temperature, length_penalty, repetition_penalty):
    """Caption an image, or answer a question about it, with BLIP-2.

    Args:
        image: Image from the Gradio ``Image(type="numpy")`` input, or ``None`` when
            nothing was uploaded.
        text: Prompt. Empty/``None`` or starting with ``"Caption:"`` requests a caption
            (the prompt text is then not passed to the model); anything else is asked
            as a question.
        decoding_method: ``"Beam search"`` or ``"Nucleus sampling"``.
        temperature: Generation temperature.
        length_penalty: Generation length penalty.
        repetition_penalty: Generation repetition penalty.

    Returns:
        The generated caption or answer.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    prompt = (text or "").strip()
    if not prompt or prompt.startswith(CAPTION_PREFIX):
        # Captioning conditions on the image alone.
        inputs = processor(images=image, return_tensors="pt")
        max_length = CAPTION_MAX_LENGTH
    else:
        # Question answering conditions on the image and the question text.
        inputs = processor(images=image, text=prompt, return_tensors="pt")
        max_length = ANSWER_MAX_LENGTH

    # Move inputs to the model's device and cast pixel values to the model's dtype.
    inputs = inputs.to(device, dtype)
    return _generate(
        inputs,
        decoding_method=decoding_method,
        temperature=temperature,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        max_length=max_length,
    )


# Define Gradio input and output components
image_input = gr.Image(type="numpy", label="Image")
text_input = gr.Text(
    label="Prompt",
    placeholder='Type "Caption:" to caption the image, or ask a question about it',
)
output_text = gr.Textbox(label="Output")

# Each example row must supply a value for every input component.
_DEFAULT_SETTINGS = [
    DEFAULT_DECODING_METHOD,
    DEFAULT_TEMPERATURE,
    DEFAULT_LENGTH_PENALTY,
    DEFAULT_REPETITION_PENALTY,
]
_EXAMPLE_PROMPTS = [
    ["house.png", "Caption:"],
    ["flower.jpg", "What is this flower and where is its origin?"],
    ["pizza.jpg", "Caption:"],
    ["sunset.jpg", "Caption:"],
    ["forbidden_city.webp", "In what dynasties was this place built?"],
]

# Define Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        image_input,
        text_input,
        gr.Radio(
            ["Beam search", "Nucleus sampling"],
            value=DEFAULT_DECODING_METHOD,
            label="Text decoding method",
        ),
        # `step` is passed by keyword: gr.Slider's third positional argument is the
        # initial value, not the step (unlike the removed gr.inputs.Slider).
        gr.Slider(0.5, 1.0, value=DEFAULT_TEMPERATURE, step=0.1, label="Temperature"),
        gr.Slider(-1.0, 2.0, value=DEFAULT_LENGTH_PENALTY, step=0.2, label="Length penalty"),
        gr.Slider(1.0, 5.0, value=DEFAULT_REPETITION_PENALTY, step=0.5, label="Repetition penalty"),
    ],
    outputs=output_text,
    examples=[[image, prompt, *_DEFAULT_SETTINGS] for image, prompt in _EXAMPLE_PROMPTS],
    title="BLIP-2",
    description="Gradio demo for BLIP-2, image-to-text generation from Salesforce Research.",
)

if __name__ == "__main__":
    demo.launch()