import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
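
# Dependencies assumed for this Space: gradio, torch, transformers, and pillow;
# flash-attn is optional and is only exercised when a CUDA GPU is present.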

# Check if flash_attn is available
def is_flash_attn_available():
    try:
        import flash_attn  # noqa: F401 -- imported only to probe availability
        return True
    except ImportError:
        return False
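
# Note: this probe pairs with the `optimized` flag passed to from_pretrained()
# below; the model's custom code is expected to use flash attention kernels
# only when both a GPU and the flash_attn package are available.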

# Load model, tokenizer, and image processor
def load_model():
    # Enable the model's optimized (flash attention) path only when it can actually run
    use_optimized = torch.cuda.is_available() and is_flash_attn_available()
    model = AutoModel.from_pretrained(
        "visheratin/mexma-siglip2",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        optimized=use_optimized,
    )
    if torch.cuda.is_available():
        model = model.to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip2")
    processor = AutoImageProcessor.from_pretrained("visheratin/mexma-siglip2")
    return model, tokenizer, processor

model, tokenizer, processor = load_model()
device = "cuda" if torch.cuda.is_available() else "cpu"
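# The model and processors are created once at startup, so every Gradio request
# reuses the same instances instead of reloading the weights per call.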

def classify_image(image, text_queries):
    if image is None or not text_queries.strip():
        return None
    # Preprocess the image and cast it to match the model's bfloat16 weights
    processed_image = processor(images=image, return_tensors="pt")["pixel_values"]
    processed_image = processed_image.to(torch.bfloat16).to(device)
    # One query per line; drop blank lines
    queries = [q.strip() for q in text_queries.split("\n") if q.strip()]
    if not queries:
        return None
    text_inputs = tokenizer(queries, return_tensors="pt", padding=True).to(device)
    # Score the image against every query
    with torch.inference_mode():
        image_logits, _ = model.get_logits(
            text_inputs["input_ids"],
            text_inputs["attention_mask"],
            processed_image,
        )
        # image_logits holds one row per image; softmax over the query axis turns
        # that row into a probability distribution over the text queries
        probs = F.softmax(image_logits, dim=-1)[0].cpu().tolist()
    # gr.Label expects a mapping of label -> float confidence, so keep the
    # probabilities numeric rather than formatting them as strings
    return {query: round(prob, 4) for query, prob in zip(queries, probs)}
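
# A minimal sketch of calling the handler directly, bypassing the UI; the image
# path is hypothetical:
#
#   from PIL import Image
#   img = Image.open("example.jpg")
#   print(classify_image(img, "a cat\na dog\na bird"))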

# Create the Gradio interface
with gr.Blocks(title="Mexma-SigLIP2 Zero-Shot Classification") as demo:
    gr.Markdown("# Mexma-SigLIP2 Zero-Shot Classification Demo")
    gr.Markdown("""
    This demo showcases the zero-shot classification capabilities of Mexma-SigLIP2, a state-of-the-art model for multilingual zero-shot classification.

    ### Instructions:
    1. Upload or select an image
    2. Enter text queries (one per line) to classify the image
    3. Click 'Submit' to see the classification probabilities

    The model supports multilingual queries (English, Russian, Hindi, etc.).
    """)
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                placeholder="Enter text queries (one per line)\nExample:\na cat\na dog\nEiffel Tower",
                label="Text Queries",
                lines=5,
            )
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            output = gr.Label(label="Classification Results")

    submit_btn.click(
        fn=classify_image,
        inputs=[image_input, text_input],
        outputs=output,
    )

    gr.Examples(
        [
            [
                "https://static.independent.co.uk/s3fs-public/thumbnails/image/2014/03/25/12/eiffel.jpg",
                "Eiffel Tower\nStatue of Liberty\nTaj Mahal\nкошка\nएफिल टॉवर",
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg",
                "a cat\na dog\na bird\nкошка\nсобака",
            ],
        ],
        inputs=[image_input, text_input],
    )

# Launch the demo
if __name__ == "__main__":
    demo.launch()
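# Assumption: when running locally, demo.launch(share=True) (a standard Gradio
# option) creates a temporary public link for sharing the demo.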