import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModel, AutoTokenizer, AutoImageProcessor


# Check if flash_attn is available
def is_flash_attn_available():
    try:
        import flash_attn
        return True
    except ImportError:
        return False

# Load model, tokenizer, and image processor
def load_model():
    # Use the optimized (flash-attention) code path only when a GPU
    # and the flash_attn package are both available
    use_optimized = torch.cuda.is_available() and is_flash_attn_available()
    model = AutoModel.from_pretrained(
        "visheratin/mexma-siglip2",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        optimized=use_optimized,
    )
    if torch.cuda.is_available():
        model = model.to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip2")
    processor = AutoImageProcessor.from_pretrained("visheratin/mexma-siglip2")
    return model, tokenizer, processor


model, tokenizer, processor = load_model()
device = "cuda" if torch.cuda.is_available() else "cpu"

def classify_image(image, text_queries):
    if image is None or not text_queries.strip():
        return None
    # Process image
    processed_image = processor(images=image, return_tensors="pt")["pixel_values"]
    processed_image = processed_image.to(torch.bfloat16).to(device)
    # Process text queries (one per line)
    queries = [q.strip() for q in text_queries.split("\n") if q.strip()]
    if not queries:
        return None
    text_inputs = tokenizer(queries, return_tensors="pt", padding=True).to(device)
    # Get predictions
    with torch.inference_mode():
        image_logits, _ = model.get_logits(
            text_inputs["input_ids"],
            text_inputs["attention_mask"],
            processed_image,
        )
        probs = F.softmax(image_logits, dim=-1)[0].cpu().tolist()
    # Format results: gr.Label expects a dict of label -> float confidence,
    # so return raw floats rather than pre-formatted strings
    return {query: float(prob) for query, prob in zip(queries, probs)}
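
# Illustrative call (the file name "cat.jpg" is hypothetical): passing a PIL
# image and newline-separated queries returns a dict of probabilities, e.g.
# roughly {"a cat": 0.97, "a dog": 0.02, "a bird": 0.01} for a cat photo:
#   classify_image(Image.open("cat.jpg"), "a cat\na dog\na bird")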

# Create Gradio interface
with gr.Blocks(title="Mexma-SigLIP2 Zero-Shot Classification") as demo:
    gr.Markdown("# Mexma-SigLIP2 Zero-Shot Classification Demo")
    gr.Markdown("""
    This demo showcases the zero-shot classification capabilities of Mexma-SigLIP2, a state-of-the-art model for multilingual zero-shot classification.

    ### Instructions:
    1. Upload or select an image
    2. Enter text queries (one per line) to classify the image
    3. Click 'Submit' to see the classification probabilities

    The model supports multilingual queries (English, Russian, Hindi, etc.).
    """)
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                placeholder="Enter text queries (one per line)\nExample:\na cat\na dog\nEiffel Tower",
                label="Text Queries",
                lines=5,
            )
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            output = gr.Label(label="Classification Results")

    submit_btn.click(
        fn=classify_image,
        inputs=[image_input, text_input],
        outputs=output,
    )

    gr.Examples(
        [
            [
                "https://static.independent.co.uk/s3fs-public/thumbnails/image/2014/03/25/12/eiffel.jpg",
                "Eiffel Tower\nStatue of Liberty\nTaj Mahal\nкошка\nएफिल टॉवर",
            ],
            [
                "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg",
                "a cat\na dog\na bird\nкошка\nсобака",
            ],
        ],
        inputs=[image_input, text_input],
    )

# Launch the demo
if __name__ == "__main__":
    demo.launch()
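
# Illustrative only (not part of the app): once the Space is running, the same
# endpoint can be queried from another process with gradio_client -- a minimal
# sketch, assuming the Space id is "visheratin/mexma-siglip2" (substitute the
# actual id if it differs). Run this from a separate script:
#
#   from gradio_client import Client, handle_file
#
#   client = Client("visheratin/mexma-siglip2")
#   result = client.predict(
#       handle_file("https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"),
#       "a cat\na dog\na bird",  # one query per line, as in the UI
#       fn_index=0,              # the single click handler defined above
#   )
#   print(result)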