# app.py: Gradio demo for zero-shot classification with visheratin/mexma-siglip2
import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor

# Check if flash_attn is available
def is_flash_attn_available():
    try:
        import flash_attn  # noqa: F401
        return True
    except ImportError:
        return False
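
# flash_attn is an optional dependency that, in practice, only installs on
# CUDA machines, hence the import probe above instead of a hard dependency.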

# Load model and tokenizer
@torch.inference_mode()
def load_model():
    use_optimized = torch.cuda.is_available() and is_flash_attn_available()
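    # NOTE: `optimized` is a custom kwarg handled by the model's remote code
    # (trust_remote_code=True); judging by the guard above, it presumably
    # toggles a flash-attention fast path, so it is enabled only when CUDA
    # and flash_attn are both available.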
    model = AutoModel.from_pretrained(
        "visheratin/mexma-siglip2",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        optimized=use_optimized,
    )
    if torch.cuda.is_available():
        model = model.to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip2")
    processor = AutoImageProcessor.from_pretrained("visheratin/mexma-siglip2")
    return model, tokenizer, processor

model, tokenizer, processor = load_model()
device = "cuda" if torch.cuda.is_available() else "cpu"
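
# Inputs below are moved with `.to(device)` so the demo runs on CPU or GPU.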
def classify_image(image, text_queries):
    if image is None or not text_queries.strip():
        return None
    # Process image
    processed_image = processor(images=image, return_tensors="pt")["pixel_values"]
    processed_image = processed_image.to(device=device, dtype=torch.bfloat16)
    # Process text queries (one per line)
    queries = [q.strip() for q in text_queries.split("\n") if q.strip()]
    if not queries:
        return None
    text_inputs = tokenizer(queries, return_tensors="pt", padding=True).to(device)
    # Get predictions
    with torch.inference_mode():
        image_logits, _ = model.get_logits(
            text_inputs["input_ids"],
            text_inputs["attention_mask"],
            processed_image,
        )
        # Softmax over the query axis turns per-query logits into probabilities
        probs = F.softmax(image_logits, dim=-1)[0].cpu().tolist()
    # Format results
    results = {queries[i]: f"{probs[i]:.4f}" for i in range(len(queries))}
    return results
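
# Quick local sanity check (a sketch: assumes an image file "cat.jpg" exists
# next to this script; uncomment to try the classifier without the UI):
# print(classify_image(Image.open("cat.jpg"), "a cat\na dog\na bird"))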

# Create Gradio interface
with gr.Blocks(title="Mexma-SigLIP2 Zero-Shot Classification") as demo:
    gr.Markdown("# Mexma-SigLIP2 Zero-Shot Classification Demo")
    gr.Markdown("""
    This demo showcases the zero-shot classification capabilities of Mexma-SigLIP2,
    a state-of-the-art model for multilingual zero-shot classification.

    ### Instructions:
    1. Upload or select an image
    2. Enter text queries (one per line) to classify the image
    3. Click 'Submit' to see the classification probabilities

    The model supports multilingual queries (English, Russian, Hindi, etc.).
    """)
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                placeholder="Enter text queries (one per line)\nExample:\na cat\na dog\nEiffel Tower",
                label="Text Queries",
                lines=5,
            )
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            output = gr.Label(label="Classification Results")
    submit_btn.click(
        fn=classify_image,
        inputs=[image_input, text_input],
        outputs=output,
    )
    gr.Examples(
        [
            [
                "https://static.independent.co.uk/s3fs-public/thumbnails/image/2014/03/25/12/eiffel.jpg",
                "Eiffel Tower\nStatue of Liberty\nTaj Mahal\nкошка\nएफिल टॉवर",
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg",
                "a cat\na dog\na bird\nкошка\nсобака",
            ],
        ],
        inputs=[image_input, text_input],
    )

# Launch the demo
if __name__ == "__main__":
    demo.launch()