# vlm_comparer / app.py
# Author: sflindrs
# Last commit: e329bce "Update app.py"
import os
import gradio as gr
from transformers import pipeline
import spaces # This module is available when deploying on HF Spaces with ZeroGPU
import multiprocessing
# Force the "spawn" start method for any worker processes created later
# (force=True overrides a method that may already have been chosen).
# NOTE(review): presumably needed so subprocesses don't inherit CUDA state via
# fork — confirm against the Space's runtime requirements.
multiprocessing.set_start_method(method="spawn", force=True)
# --- Trending models for image text-to-text tasks ---
TRENDING_MODELS = [
"Salesforce/blip2-opt-2.7b", # Uses Blip2Config
"Salesforce/blip2-flan-t5-xl", # Uses Blip2Config
"Salesforce/instructblip-vicuna-7b", # Uses InstructBlipConfig
"llava-hf/llava-1.5-7b-hf", # Uses LlavaConfig
"liuhaotian/llava-v1.5-13b", # Uses LlavaConfig
"llava-hf/llava-v1.6-mistral-7b-hf", # Uses LlavaNextConfig
"Qwen/Qwen2-VL-7B-Instruct", # Uses Qwen2VLConfig
"google/pix2struct-ai2d-base", # Uses Pix2StructConfig
"nlpconnect/vit-gpt2-image-captioning", # Uses VisionEncoderDecoderConfig
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf", # Uses LlavaOnevisionConfig
"mosaicml/mpt-7b-chat", # Uses MllamaConfig
"ibm-granite/granite-vision-3.1-2b-preview",
"allenai/Molmo-7B-D-0924"
]
# --- Helper: if the user selects "Custom", then they can enter any model identifier ---
def resolve_model(chosen, custom):
    """Return the Hub model id to load for one dropdown/textbox pair.

    Args:
        chosen: Value of the model dropdown — a repo id, or the literal
            string "Custom".
        custom: Text from the matching "Custom Model" textbox; may be None.

    Returns:
        ``chosen`` unchanged, or the whitespace-stripped ``custom`` text when
        ``chosen`` is "Custom".

    Raises:
        ValueError: if no model is selected, or "Custom" is selected with a
            blank identifier. Handing an empty/None name to
            ``transformers.pipeline`` would otherwise fail with an opaque
            loading error (or, for None, could let it pick a default model).
    """
    if not chosen:
        raise ValueError("No model selected.")
    if chosen != "Custom":
        return chosen
    # Treat a missing textbox value like an empty one instead of crashing on .strip().
    name = (custom or "").strip()
    if not name:
        raise ValueError('"Custom" was selected but no model identifier was entered.')
    return name
# --- Main inference function ---


def _build_messages(image, prompt):
    """Return a fresh single-turn chat (image + prompt) for the pipeline.

    A new structure is built for every model call so that one pipeline run can
    never see changes another run might make to a shared list.
    """
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": image},
                {"type": "text", "text": prompt},
            ],
        },
        {
            # Empty assistant turn, as in the original request format.
            "role": "assistant",
            "content": [{"type": "text", "text": ""}],
        },
    ]


def _message_text(content):
    """Flatten a chat message's ``content`` (a str or a list of parts) to text."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            str(part["text"])
            for part in content
            if isinstance(part, dict) and "text" in part
        )
    return str(content)


def _extract_text(output):
    """Pull the generated reply out of an image-text-to-text pipeline result.

    For chat-message input, ``generated_text`` may be the whole conversation
    (a list of messages) rather than a string; the reply is then taken from
    the last message. Anything unrecognised is returned via ``str()``.
    """
    if not (
        isinstance(output, list)
        and output
        and isinstance(output[0], dict)
        and "generated_text" in output[0]
    ):
        return str(output)
    generated = output[0]["generated_text"]
    if isinstance(generated, str):
        return generated
    if isinstance(generated, list) and generated and isinstance(generated[-1], dict):
        return _message_text(generated[-1].get("content", ""))
    return str(generated)


def _run_model(model_name, image, prompt, device):
    """Load ``model_name``, ask it ``prompt`` about ``image``, return the reply text.

    The pipeline is local to this call, so its last reference is dropped on
    return — models are loaded one at a time rather than both at once.
    """
    pipe = pipeline(task="image-text-to-text", model=model_name, device=device)
    output = pipe(text=_build_messages(image, prompt), max_new_tokens=1024)
    return _extract_text(output)


# The @spaces.GPU() decorator requests a GPU for each call when the app runs
# in a ZeroGPU Space.
@spaces.GPU()
def compare_image_to_text_models(image, prompt, model1_choice, model1_custom, model2_choice, model2_custom):
    """Run the same image + prompt through two models and return both replies.

    Args:
        image: Image from the ``gr.Image(type="pil")`` input; None if nothing
            was uploaded.
        prompt: Text instruction sent alongside the image.
        model1_choice, model1_custom, model2_choice, model2_custom: Dropdown
            value and custom-id textbox for each side, resolved with
            ``resolve_model``.

    Returns:
        Two chat histories in Gradio's tuple format, each ``[(prompt, reply)]``.

    Raises:
        gr.Error: if no image was uploaded or a model selection is invalid.
    """
    if image is None:
        raise gr.Error("Please upload an image before comparing models.")
    try:
        model1_name = resolve_model(model1_choice, model1_custom)
        model2_name = resolve_model(model2_choice, model2_custom)
    except ValueError as err:
        # Surface selection problems in the UI instead of as a generic error.
        raise gr.Error(str(err)) from err

    # GPU 0 only when USE_GPU=1; otherwise CPU (-1).
    # NOTE(review): in a ZeroGPU Space the decorator provides a GPU, but it is
    # not used unless USE_GPU=1 is set — confirm this is intended.
    device = 0 if os.environ.get("USE_GPU", "0") == "1" else -1

    result1 = _run_model(model1_name, image, prompt, device)
    result2 = _run_model(model2_name, image, prompt, device)

    # The tuple-format Chatbot expects (user_message, bot_message) pairs.
    chat1 = [(prompt, result1)]
    chat2 = [(prompt, result2)]
    return chat1, chat2
# --- Build the Gradio interface ---
# Default prompt pre-filled in the text box.
sample_prompt = "Describe the image in explicit detail. Return a nested JSON object in response."
with gr.Blocks(title="Image Text-to-Text Comparison Tool") as demo:
    gr.Markdown(
        """
# Image Text-to-Text Comparison Tool
Compare two trending image text-to-text (instruction-following) models side-by-side.
Select a model from the dropdown (or choose Custom to enter your own model identifier) and see how it describes the image.
"""
    )
    with gr.Row():
        # Left column: the image and the prompt shared by both models.
        with gr.Column(scale=1):
            gr.Markdown("## Input")
            image_input = gr.Image(label="Upload an Image", type="pil")
            prompt_input = gr.Textbox(label="Text Prompt", value=sample_prompt, lines=3)
        # Right column: one dropdown + custom-id textbox per model.
        with gr.Column(scale=1):
            gr.Markdown("## Model Selection")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Model 1")
                    model1_choice = gr.Dropdown(
                        choices=TRENDING_MODELS + ["Custom"],
                        value=TRENDING_MODELS[0],
                        label="Select Model 1",
                    )
                    model1_custom = gr.Textbox(label="Custom Model 1", placeholder="e.g., username/model_name")
                with gr.Column():
                    gr.Markdown("### Model 2")
                    model2_choice = gr.Dropdown(
                        choices=TRENDING_MODELS + ["Custom"],
                        value=TRENDING_MODELS[1],
                        label="Select Model 2",
                    )
                    model2_custom = gr.Textbox(label="Custom Model 2", placeholder="e.g., username/model_name")
            compare_button = gr.Button("Compare Models")
    gr.Markdown("## Chatbot Outputs (Side-by-Side)")
    with gr.Row():
        chatbot1 = gr.Chatbot(label="Model 1 Chatbot")
        chatbot2 = gr.Chatbot(label="Model 2 Chatbot")
    # The input order must match compare_image_to_text_models' parameters.
    compare_button.click(
        fn=compare_image_to_text_models,
        inputs=[image_input, prompt_input, model1_choice, model1_custom, model2_choice, model2_custom],
        outputs=[chatbot1, chatbot2],
    )

# Only start the server when run as a script. The file forces the "spawn"
# start method, under which child processes re-import the main module; an
# unguarded launch() would run again in each of them.
if __name__ == "__main__":
    demo.launch()