visheratin committed
Commit cea4766 · verified · 1 Parent(s): 288a63c

Create app.py

Files changed (1)
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
+import gradio as gr
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
+
+# Check if flash_attn is available
+def is_flash_attn_available():
+    try:
+        import flash_attn
+        return True
+    except ImportError:
+        return False
+
+# Load the model, tokenizer, and image processor once at startup
+@torch.inference_mode()
+def load_model():
+    use_optimized = torch.cuda.is_available() and is_flash_attn_available()
+
+    model = AutoModel.from_pretrained(
+        "visheratin/mexma-siglip2",
+        torch_dtype=torch.bfloat16,
+        trust_remote_code=True,
+        optimized=use_optimized,
+    )
+    if torch.cuda.is_available():
+        model = model.to("cuda")
+    tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip2")
+    processor = AutoImageProcessor.from_pretrained("visheratin/mexma-siglip2")
+    return model, tokenizer, processor
+
+model, tokenizer, processor = load_model()
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+def classify_image(image, text_queries):
+    if image is None or not text_queries.strip():
+        return None
+
+    # Preprocess the image and cast it to the model's bfloat16 dtype
+    processed_image = processor(images=image, return_tensors="pt")["pixel_values"]
+    processed_image = processed_image.to(torch.bfloat16)
+    if torch.cuda.is_available():
+        processed_image = processed_image.to("cuda")
+
+    # One text query per line; ignore empty lines
+    queries = [q.strip() for q in text_queries.split("\n") if q.strip()]
+    if not queries:
+        return None
+
+    text_inputs = tokenizer(queries, return_tensors="pt", padding=True)
+    if torch.cuda.is_available():
+        text_inputs = text_inputs.to("cuda")
+
+    # Get image-text logits and convert them to probabilities
+    with torch.inference_mode():
+        image_logits, _ = model.get_logits(
+            text_inputs["input_ids"],
+            text_inputs["attention_mask"],
+            processed_image
+        )
+    probs = F.softmax(image_logits, dim=-1)[0].cpu().tolist()
+
+    # gr.Label expects label -> numeric confidence, so keep the probabilities as floats
+    results = {queries[i]: round(probs[i], 4) for i in range(len(queries))}
+    return results
+
+# Create Gradio interface
+with gr.Blocks(title="Mexma-SigLIP2 Zero-Shot Classification") as demo:
+    gr.Markdown("# Mexma-SigLIP2 Zero-Shot Classification Demo")
+    gr.Markdown("""
+    This demo showcases the zero-shot classification capabilities of the Mexma-SigLIP2 model.
+
+    ### Instructions:
+    1. Upload or select an image.
+    2. Enter text queries (one per line) to classify the image.
+    3. Click 'Submit' to see the classification probabilities.
+
+    The model supports multilingual queries (English, Russian, Hindi, etc.).
+    """)
+
+    with gr.Row():
+        with gr.Column():
+            image_input = gr.Image(type="pil", label="Upload Image")
+            text_input = gr.Textbox(
+                placeholder="Enter text queries (one per line)\nExample:\na cat\na dog\nEiffel Tower",
+                label="Text Queries",
+                lines=5
+            )
+            submit_btn = gr.Button("Submit", variant="primary")
+
+        with gr.Column():
+            output = gr.Label(label="Classification Results")
+
+    submit_btn.click(
+        fn=classify_image,
+        inputs=[image_input, text_input],
+        outputs=output
+    )
+
+    gr.Examples(
+        [
+            [
+                "https://static.independent.co.uk/s3fs-public/thumbnails/image/2014/03/25/12/eiffel.jpg",
+                "Eiffel Tower\nStatue of Liberty\nTaj Mahal\nкошка\nएफिल टॉवर"
+            ],
+            [
+                "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg",
+                "a cat\na dog\na bird\nкошка\nсобака"
+            ]
+        ],
+        inputs=[image_input, text_input]
+    )
+
+# Launch the demo
+if __name__ == "__main__":
+    demo.launch()
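
For reference, the demo logic can also be exercised without launching the Gradio UI. The snippet below is a minimal sketch, assuming app.py (as added above) is importable from the working directory and that a local image file is available; the file name sample.jpg and the query list are illustrative placeholders, not part of this commit.

# Minimal sketch: call classify_image directly instead of going through the UI.
# Assumes app.py (above) is on the import path; importing it loads the model once.
# "sample.jpg" and the queries below are placeholders for illustration only.
from PIL import Image

from app import classify_image

image = Image.open("sample.jpg")
queries = "a cat\na dog\nEiffel Tower"  # one query per line, as in the textbox

results = classify_image(image, queries)
print(results)  # dict mapping each query to its softmax probability

Because demo.launch() is guarded by the __main__ check, importing app this way starts no server; it only downloads and loads the model weights.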