Malaji71 committed on
Commit
a7d8c02
·
verified ·
1 Parent(s): 05bb27a

Create models.py

Browse files
Files changed (1) hide show
  1. models.py +362 -0
models.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model management for FLUX Prompt Optimizer
3
+ Handles Florence-2 and Bagel model integration
4
+ """
5
+
6
+ import logging
7
+ import requests
8
+ import spaces
9
+ import torch
10
+ from typing import Optional, Dict, Any, Tuple
11
+ from PIL import Image
12
+ from transformers import AutoProcessor, AutoModelForCausalLM
13
+
14
+ from config import MODEL_CONFIG, get_device_config
15
+ from utils import clean_memory, safe_execute
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class BaseImageAnalyzer:
    """Base class for image analysis models.

    Holds the state every analyzer shares (model, processor, device
    configuration and an initialization flag) and declares the interface
    that concrete analyzers implement.
    """

    def __init__(self):
        self.model = None
        self.processor = None
        # Device settings as returned by config.get_device_config().
        self.device_config = get_device_config()
        self.is_initialized = False

    def initialize(self) -> bool:
        """Initialize the model.

        Returns:
            True when the analyzer is ready for use, False otherwise.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze an image and return its description.

        Args:
            image: PIL image to analyze.

        Returns:
            Tuple of (description, metadata).

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def cleanup(self) -> None:
        """Release the model and processor and mark the analyzer uninitialized."""
        # Rebinding to None drops the references; a separate `del` is not needed.
        self.model = None
        self.processor = None
        # Without this reset a later analyze_image() would skip initialize()
        # and try to run with a None model/processor.
        self.is_initialized = False
        clean_memory()
46
+
47
+
48
class Florence2Analyzer(BaseImageAnalyzer):
    """Florence-2 model for image analysis.

    Loads the checkpoint named in ``MODEL_CONFIG["florence2"]`` and captions
    a PIL image at three levels of detail.
    """

    def __init__(self):
        super().__init__()
        self.config = MODEL_CONFIG["florence2"]

    def initialize(self) -> bool:
        """Load the Florence-2 processor and model.

        Returns:
            True when the model is ready (or was already loaded). False when
            loading failed; the failure is logged and anything partially
            loaded is released through ``cleanup()``.
        """
        if self.is_initialized:
            return True

        try:
            logger.info("Initializing Florence-2 model...")

            model_id = self.config["model_id"]
            trust_remote_code = self.config["trust_remote_code"]
            use_gpu = self.device_config["use_gpu"]

            # Load processor
            self.processor = AutoProcessor.from_pretrained(
                model_id,
                trust_remote_code=trust_remote_code,
            )

            # Load model; the configured dtype is used only on GPU,
            # CPU always loads in float32.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                trust_remote_code=trust_remote_code,
                torch_dtype=self.config["torch_dtype"] if use_gpu else torch.float32,
            )

            # Move to appropriate device
            target_device = self.device_config["device"] if use_gpu else "cpu"
            self.model = self.model.to(target_device)

            self.model.eval()
            self.is_initialized = True

            logger.info(f"Florence-2 initialized on {self.device_config['device']}")
            return True

        except Exception as e:
            logger.error(f"Florence-2 initialization failed: {e}")
            self.cleanup()
            return False

    def _generate(self, inputs, num_beams: int):
        """Run beam-search generation (no sampling) on prepared processor inputs."""
        return self.model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=self.config["max_new_tokens"],
            num_beams=num_beams,
            do_sample=False,
        )

    def _extract_caption(self, generated_ids, image: Image.Image, task_prompt: str) -> str:
        """Decode generated ids and return the caption parsed for ``task_prompt``."""
        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        parsed = self.processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )

        if task_prompt in parsed:
            return parsed[task_prompt]
        # Fall back to the raw parsed structure when the task key is absent.
        return str(parsed) if parsed else ""

    # NOTE(review): presumably reserves a GPU for up to 60 s per call on
    # Hugging Face Spaces - confirm against the `spaces` package docs.
    @spaces.GPU(duration=60)
    def _gpu_inference(self, image: Image.Image, task_prompt: str) -> str:
        """Run one Florence-2 task on the GPU.

        Args:
            image: PIL image to caption.
            task_prompt: Florence-2 task token, e.g. ``"<CAPTION>"``.

        Returns:
            The caption text, or ``""`` when inference failed (the error is
            logged, not raised).
        """
        use_gpu = self.device_config["use_gpu"]
        try:
            # Move model to GPU for inference
            if use_gpu:
                self.model = self.model.to("cuda")

            # Prepare inputs and move them to the inference device
            inputs = self.processor(text=task_prompt, images=image, return_tensors="pt")
            device = "cuda" if use_gpu else self.device_config["device"]
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Generate response
            with torch.no_grad():
                if use_gpu:
                    # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
                    with torch.autocast(device_type="cuda", dtype=torch.float16):
                        generated_ids = self._generate(inputs, num_beams=3)
                else:
                    generated_ids = self._generate(inputs, num_beams=3)

            return self._extract_caption(generated_ids, image, task_prompt)

        except Exception as e:
            logger.error(f"Florence-2 GPU inference failed: {e}")
            return ""
        finally:
            # Move model back to CPU to free GPU memory; guard against a
            # model that was never loaded so the finally block cannot raise
            # an AttributeError of its own.
            if use_gpu and self.model is not None:
                self.model = self.model.to("cpu")
            clean_memory()

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze an image using Florence-2.

        Runs the ``<DETAILED_CAPTION>``, ``<MORE_DETAILED_CAPTION>`` and
        ``<CAPTION>`` tasks and returns the most detailed non-empty caption.

        Args:
            image: PIL image to analyze. Non-RGB images are converted to RGB
                before being handed to the processor.

        Returns:
            Tuple of (description, metadata). On success the metadata holds
            the model name, the device, every task result under
            ``"all_results"`` and a fixed confidence value. On failure the
            description is an error message and the metadata has an
            ``"error"`` key.
        """
        if not self.is_initialized:
            success = self.initialize()
            if not success:
                return "Model initialization failed", {"error": "Florence-2 not available"}

        try:
            # NOTE(review): Florence-2 preprocessing is assumed to expect
            # 3-channel RGB input - normalise RGBA / palette / grayscale images.
            if image.mode != "RGB":
                image = image.convert("RGB")

            # Define analysis tasks
            tasks = {
                "detailed": "<DETAILED_CAPTION>",
                "more_detailed": "<MORE_DETAILED_CAPTION>",
                "caption": "<CAPTION>",
            }

            # Run analysis for each task on the configured backend
            if self.device_config["use_gpu"]:
                run_task = self._gpu_inference
            else:
                run_task = self._cpu_inference
            results = {
                task_name: run_task(image, task_prompt)
                for task_name, task_prompt in tasks.items()
            }

            # Choose best result: most detailed non-empty caption first
            main_description = (
                results["more_detailed"]
                or results["detailed"]
                or results["caption"]
                or "A photograph"
            )

            # Prepare metadata
            metadata = {
                "model": "Florence-2",
                "device": self.device_config["device"],
                "all_results": results,
                "confidence": 0.85,  # Fixed constant, not computed: Florence-2 generally reliable
            }

            logger.info(f"Florence-2 analysis complete: {len(main_description)} chars")
            return main_description, metadata

        except Exception as e:
            logger.error(f"Florence-2 analysis failed: {e}")
            return "Analysis failed", {"error": str(e)}

    def _cpu_inference(self, image: Image.Image, task_prompt: str) -> str:
        """Run one Florence-2 task on the CPU.

        Args:
            image: PIL image to caption.
            task_prompt: Florence-2 task token, e.g. ``"<CAPTION>"``.

        Returns:
            The caption text, or ``""`` when inference failed (the error is
            logged, not raised).
        """
        try:
            inputs = self.processor(text=task_prompt, images=image, return_tensors="pt")

            with torch.no_grad():
                generated_ids = self._generate(inputs, num_beams=2)  # Reduced for CPU

            return self._extract_caption(generated_ids, image, task_prompt)

        except Exception as e:
            logger.error(f"Florence-2 CPU inference failed: {e}")
            return ""
230
+
231
+
232
class BagelAnalyzer(BaseImageAnalyzer):
    """Bagel-7B model analyzer via API.

    Owns a ``requests.Session`` used to talk to the endpoint configured in
    ``MODEL_CONFIG["bagel"]``. The analysis itself is still a placeholder.
    """

    def __init__(self):
        super().__init__()
        self.config = MODEL_CONFIG["bagel"]
        self.session = requests.Session()

    def initialize(self) -> bool:
        """Check that the Bagel API endpoint is reachable.

        Returns:
            True when the endpoint answered with HTTP 200 (or a previous
            check already succeeded), False otherwise. Failures are logged,
            not raised.
        """
        if self.is_initialized:
            return True

        try:
            # Test API connectivity
            test_response = self.session.get(
                self.config["api_url"],
                timeout=self.config["timeout"],
            )
        except Exception as e:
            logger.error(f"Bagel initialization failed: {e}")
            return False

        if test_response.status_code != 200:
            logger.error(f"Bagel API not accessible: {test_response.status_code}")
            return False

        self.is_initialized = True
        logger.info("Bagel API connection established")
        return True

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze an image using the Bagel-7B API (placeholder).

        Args:
            image: PIL image to analyze. Currently unused because the API
                call is not implemented yet.

        Returns:
            Tuple of (description, metadata). When the connectivity check
            fails the metadata has an ``"error"`` key; otherwise a fixed
            placeholder description is returned.
        """
        if not self.is_initialized:
            success = self.initialize()
            if not success:
                return "Bagel API not available", {"error": "API connection failed"}

        # TODO: placeholder - the actual implementation depends on the Bagel
        # API format. A real implementation would:
        #   1. Convert image to required format (e.g. base64)
        #   2. Make API call to Bagel endpoint
        #   3. Parse response
        description = "Detailed image analysis via Bagel-7B (API implementation needed)"
        metadata = {
            "model": "Bagel-7B",
            "method": "API",
            "confidence": 0.8,
        }

        logger.info("Bagel analysis complete (placeholder)")
        return description, metadata

    def cleanup(self) -> None:
        """Close the HTTP session, then release the inherited resources."""
        # The base cleanup knows nothing about the session, so without this
        # override its pooled connections were never released.
        self.session.close()
        self.is_initialized = False
        super().cleanup()
291
+
292
+
293
class ModelManager:
    """Manager for handling multiple analysis models.

    Creates analyzers lazily, caches one instance per model name and routes
    analysis requests to the requested or preferred model.
    """

    def __init__(self, preferred_model: Optional[str] = None):
        """Create a manager.

        Args:
            preferred_model: Model used when a call names none; falls back
                to ``MODEL_CONFIG["primary_model"]`` when omitted or empty.
        """
        self.preferred_model = preferred_model or MODEL_CONFIG["primary_model"]
        self.analyzers: Dict[str, BaseImageAnalyzer] = {}
        # Analyzer most recently handed out by get_analyzer(), or None.
        self.current_analyzer: Optional[BaseImageAnalyzer] = None

    def get_analyzer(self, model_name: Optional[str] = None) -> Optional[BaseImageAnalyzer]:
        """Get or create the analyzer for the specified model.

        Args:
            model_name: ``"florence2"`` or ``"bagel"``; the preferred model
                is used when omitted or empty.

        Returns:
            The cached or newly created analyzer, or None (after logging an
            error) when the model name is unknown.
        """
        model_name = model_name or self.preferred_model

        analyzer = self.analyzers.get(model_name)
        if analyzer is None:
            factories = {
                "florence2": Florence2Analyzer,
                "bagel": BagelAnalyzer,
            }
            factory = factories.get(model_name)
            if factory is None:
                logger.error(f"Unknown model: {model_name}")
                return None
            analyzer = factory()
            self.analyzers[model_name] = analyzer

        self.current_analyzer = analyzer
        return analyzer

    def analyze_image(
        self, image: Image.Image, model_name: Optional[str] = None
    ) -> Tuple[str, Dict[str, Any]]:
        """Analyze an image with the specified or preferred model.

        Args:
            image: PIL image to analyze.
            model_name: Optional model name; see ``get_analyzer``.

        Returns:
            Tuple of (description, metadata). The metadata has an
            ``"error"`` key when no analyzer exists for the model or the
            analysis call did not succeed.
        """
        analyzer = self.get_analyzer(model_name)
        if analyzer is None:
            return "No analyzer available", {"error": "Model not found"}

        # safe_execute is unpacked as (success, payload); the payload is used
        # as the analyzer's return value on success and as the error otherwise.
        success, result = safe_execute(analyzer.analyze_image, image)
        if not success:
            return "Analysis failed", {"error": result}
        return result

    def cleanup_all(self) -> None:
        """Clean up every cached analyzer and forget them."""
        for analyzer in self.analyzers.values():
            analyzer.cleanup()
        self.analyzers.clear()
        self.current_analyzer = None
        clean_memory()
334
+
335
+
336
# Global model manager instance
model_manager = ModelManager()


def analyze_image(
    image: Image.Image, model_name: Optional[str] = None
) -> Tuple[str, Dict[str, Any]]:
    """
    Convenience function for image analysis.

    Delegates to the module-level ``model_manager``.

    Args:
        image: PIL Image to analyze
        model_name: Optional model name ("florence2" or "bagel"); the
            manager's preferred model is used when omitted

    Returns:
        Tuple of (description, metadata)
    """
    return model_manager.analyze_image(image, model_name)
352
+
353
+
354
# Public API of this module
__all__ = [
    "BaseImageAnalyzer",
    "Florence2Analyzer",
    "BagelAnalyzer",
    "ModelManager",
    "model_manager",
    "analyze_image",
]