# Hugging Face Space: Running on ZeroGPU
import os
import gradio as gr
from transformers import pipeline
import spaces  # This module is available when deploying on HF Spaces with ZeroGPU
import multiprocessing

multiprocessing.set_start_method("spawn", force=True)
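# Note: forcing "spawn" is a common precaution for GPU workloads; forked workers that
# inherit an already-initialized CUDA context can fail, and "spawn" avoids that.
# (This explanation is an inference about the original intent, not stated in the source.)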
# --- Trending models for image text-to-text tasks ---
TRENDING_MODELS = [
    "Salesforce/blip2-opt-2.7b",                  # Uses Blip2Config
    "Salesforce/blip2-flan-t5-xl",                # Uses Blip2Config
    "Salesforce/instructblip-vicuna-7b",          # Uses InstructBlipConfig
    "llava-hf/llava-1.5-7b-hf",                   # Uses LlavaConfig
    "llava-hf/llava-1.5-13b-hf",                  # Uses LlavaConfig
    "llava-hf/llava-v1.6-mistral-7b-hf",          # Uses LlavaNextConfig
    "Qwen/Qwen2-VL-7B-Instruct",                  # Uses Qwen2VLConfig
    "google/pix2struct-ai2d-base",                # Uses Pix2StructConfig
    "nlpconnect/vit-gpt2-image-captioning",       # Uses VisionEncoderDecoderConfig
    "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",  # Uses LlavaOnevisionConfig
    "meta-llama/Llama-3.2-11B-Vision-Instruct",   # Uses MllamaConfig
    "ibm-granite/granite-vision-3.1-2b-preview",
    "allenai/Molmo-7B-D-0924",
]
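# Note: not every checkpoint above necessarily supports the "image-text-to-text" pipeline
# task (captioning-only models such as vit-gpt2 are traditionally served through the older
# "image-to-text" task), so custom model IDs are best checked on the Hub before use.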
# --- Helper: if the user selects "Custom", then they can enter any model identifier ---
def resolve_model(chosen, custom):
    if chosen == "Custom":
        return custom.strip()
    else:
        return chosen
# --- Main inference function ---
# The @spaces.GPU() decorator ensures that heavy inference runs on GPU in a ZeroGPU Space.
@spaces.GPU()
def compare_image_to_text_models(image, prompt, model1_choice, model1_custom, model2_choice, model2_custom):
    # Determine which model identifiers to use.
    model1_name = resolve_model(model1_choice, model1_custom)
    model2_name = resolve_model(model2_choice, model2_custom)

    # Use GPU (device 0) if USE_GPU is enabled; otherwise fall back to CPU (-1).
    device = 0 if os.environ.get("USE_GPU", "0") == "1" else -1

    # Create an "image-text-to-text" pipeline for each selected model.
    # These pipelines accept chat-style messages containing the image and the prompt.
    pipe1 = pipeline(task="image-text-to-text", model=model1_name, device=device)
    pipe2 = pipeline(task="image-text-to-text", model=model2_name, device=device)
    # Build a chat-style message list with the uploaded PIL image and the text prompt.
    # The "image" entry accepts a PIL image, a local path, or a URL.
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {"type": "text", "text": prompt},
            ],
        },
    ]
    # Run inference on the image with the provided prompt.
    output1 = pipe1(text=messages, max_new_tokens=1024)
    output2 = pipe2(text=messages, max_new_tokens=1024)

    # Extract the generated text. With chat-style input, "generated_text" may be the full
    # conversation (a list of role/content dicts); in that case the last message holds the reply.
    def extract_text(output):
        if isinstance(output, list) and len(output) > 0 and isinstance(output[0], dict) and "generated_text" in output[0]:
            generated = output[0]["generated_text"]
            if isinstance(generated, list) and len(generated) > 0 and isinstance(generated[-1], dict):
                return str(generated[-1].get("content", generated))
            return str(generated)
        return str(output)

    result1 = extract_text(output1)
    result2 = extract_text(output2)
    # Format results as chat conversations.
    # Each chatbot conversation is a list of (user_message, bot_message) pairs.
    chat1 = [(prompt, result1)]
    chat2 = [(prompt, result2)]
    return chat1, chat2
# --- Build the Gradio interface ---
sample_prompt = "Describe the image in explicit detail. Return the response as a nested JSON object."
with gr.Blocks(title="Image Text-to-Text Comparison Tool") as demo:
    gr.Markdown(
        """
        # Image Text-to-Text Comparison Tool
        Compare two trending image text-to-text (instruction-following) models side by side.
        Select a model from each dropdown (or choose "Custom" to enter your own model identifier) and see how it describes the image.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Input")
            image_input = gr.Image(label="Upload an Image", type="pil")
            prompt_input = gr.Textbox(label="Text Prompt", value=sample_prompt, lines=3)
        with gr.Column(scale=1):
            gr.Markdown("## Model Selection")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Model 1")
                    model1_choice = gr.Dropdown(
                        choices=TRENDING_MODELS + ["Custom"],
                        value=TRENDING_MODELS[0],
                        label="Select Model 1"
                    )
                    model1_custom = gr.Textbox(label="Custom Model 1", placeholder="e.g., username/model_name")
                with gr.Column():
                    gr.Markdown("### Model 2")
                    model2_choice = gr.Dropdown(
                        choices=TRENDING_MODELS + ["Custom"],
                        value=TRENDING_MODELS[1],
                        label="Select Model 2"
                    )
                    model2_custom = gr.Textbox(label="Custom Model 2", placeholder="e.g., username/model_name")
    compare_button = gr.Button("Compare Models")
    gr.Markdown("## Chatbot Outputs (Side-by-Side)")
    with gr.Row():
        chatbot1 = gr.Chatbot(label="Model 1 Chatbot")
        chatbot2 = gr.Chatbot(label="Model 2 Chatbot")
    compare_button.click(
        fn=compare_image_to_text_models,
        inputs=[image_input, prompt_input, model1_choice, model1_custom, model2_choice, model2_custom],
        outputs=[chatbot1, chatbot2]
    )

demo.launch()
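
# Deployment note (an assumption, not part of the original script): a ZeroGPU Space like
# this usually ships a requirements.txt next to app.py. A minimal sketch based only on the
# imports above might look like:
#
#   gradio
#   transformers
#   torch
#   accelerate   # optional, helps when loading larger checkpoints
#
# The `spaces` package is preinstalled on ZeroGPU hardware; exact version pins depend on
# the chosen checkpoints.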