import re
import io

import torch
import gradio as gr
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection

# Checkpoint to load and the device to run inference on.
model_id = "google/owlvit-base-patch16"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Detections whose best score falls below this value are reported as
# "no confident match" rather than as a hit.
SCORE_THRESHOLD = 0.1

# OWL-ViT is a text-conditioned object detector; the detection head is the
# task model that `transformers` ships for this checkpoint.
model = OwlViTForObjectDetection.from_pretrained(model_id).to(device)
processor = OwlViTProcessor.from_pretrained(model_id)


def _load_rgb_image(image_file):
    """Return *image_file* as an RGB PIL image.

    Accepts an already-loaded PIL image, a filesystem path, or a binary
    file-like object. Files opened here are closed before returning.
    """
    if isinstance(image_file, Image.Image):
        return image_file.convert("RGB")
    with Image.open(image_file) as opened:
        # convert() produces an independent copy, so the handle can be closed.
        return opened.convert("RGB")


def generate_model_response(image_file, user_query):
    """Search the image for the object described by the text query.

    Parameters:
    - image_file: The uploaded image (path, file-like object, or PIL image).
    - user_query: Free-text description of what to look for in the image.

    Returns:
    - str: A description of the best-matching region (score and pixel
      bounding box), a "no confident match" message when the best score is
      below SCORE_THRESHOLD, or an error message if processing failed.
    """
    if image_file is None:
        return "An error occurred: no image was provided."

    query = (user_query or "").strip()
    if not query:
        return "An error occurred: the text query is empty."

    try:
        raw_image = _load_rgb_image(image_file)

        # One image with one text query (nested list = queries per image).
        inputs = processor(
            images=raw_image, text=[[query]], return_tensors="pt"
        ).to(device)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)

        # logits are (batch, num_patches, num_queries); take the single
        # image and single query, then pick the best-scoring patch.
        patch_scores = torch.sigmoid(outputs.logits[0, :, 0])
        best_patch = int(patch_scores.argmax())
        best_score = float(patch_scores[best_patch])

        if best_score < SCORE_THRESHOLD:
            return (
                f"No confident match for '{query}' "
                f"(best score {best_score:.3f})."
            )

        # pred_boxes are normalized (center_x, center_y, width, height);
        # convert to pixel corner coordinates of the original image.
        center_x, center_y, box_w, box_h = outputs.pred_boxes[0, best_patch].tolist()
        img_w, img_h = raw_image.size
        left = round((center_x - box_w / 2) * img_w)
        top = round((center_y - box_h / 2) * img_h)
        right = round((center_x + box_w / 2) * img_w)
        bottom = round((center_y + box_h / 2) * img_h)

        return (
            f"Best match for '{query}': score {best_score:.3f}, "
            f"box (left, top, right, bottom) = ({left}, {top}, {right}, {bottom})"
        )
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        print(f"Error in generating response: {e}")
        return f"An error occurred: {str(e)}"


# Gradio Interface
iface = gr.Interface(
    fn=generate_model_response,
    inputs=[
        # "filepath" hands the callback a path that PIL can open directly.
        gr.Image(type="filepath", label="Upload Image"),
        gr.Textbox(
            label="Enter your question",
            placeholder="What do you want to know about this image?",
        ),
    ],
    outputs="text",
)

iface.launch(share=True)