Malaji71 commited on
Commit
24c3479
·
verified ·
1 Parent(s): 8d6efc2

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +134 -214
models.py CHANGED
@@ -1,22 +1,16 @@
1
  """
2
  Model management for Frame 0 Laboratory for MIA
3
- BAGEL 7B integration for advanced image analysis
4
  """
5
 
6
  import logging
 
7
  import os
8
- import subprocess
9
- import spaces
10
- import torch
11
  from typing import Optional, Dict, Any, Tuple
12
  from PIL import Image
13
- from huggingface_hub import snapshot_download
14
- from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
15
 
16
- from config import (
17
- BAGEL_CONFIG, get_device_config, get_bagel_device_map,
18
- BAGEL_PROMPTS, FLASH_ATTN_INSTALL
19
- )
20
  from utils import clean_memory, safe_execute
21
 
22
  logger = logging.getLogger(__name__)
@@ -26,7 +20,6 @@ class BaseImageAnalyzer:
26
  """Base class for image analysis models"""
27
 
28
  def __init__(self):
29
- self.model = None
30
  self.is_initialized = False
31
  self.device_config = get_device_config()
32
 
@@ -40,235 +33,153 @@ class BaseImageAnalyzer:
40
 
41
  def cleanup(self) -> None:
42
  """Clean up model resources"""
43
- if hasattr(self, 'model') and self.model is not None:
44
- del self.model
45
- self.model = None
46
  clean_memory()
47
 
48
 
49
- class BagelAnalyzer(BaseImageAnalyzer):
50
- """BAGEL 7B model for advanced image analysis"""
51
 
52
  def __init__(self):
53
  super().__init__()
54
- self.inferencer = None
55
- self.tokenizer = None
56
- self.vae_model = None
57
- self.vae_transform = None
58
- self.vit_transform = None
59
- self._install_flash_attn()
60
 
61
- def _install_flash_attn(self):
62
- """Install flash attention dynamically"""
63
- try:
64
- logger.info("Installing flash attention...")
65
- result = subprocess.run(
66
- FLASH_ATTN_INSTALL["command"],
67
- env=FLASH_ATTN_INSTALL["env"],
68
- shell=FLASH_ATTN_INSTALL["shell"],
69
- capture_output=True,
70
- text=True
71
- )
72
- if result.returncode == 0:
73
- logger.info("Flash attention installed successfully")
74
- else:
75
- logger.warning(f"Flash attention installation warning: {result.stderr}")
76
- except Exception as e:
77
- logger.warning(f"Flash attention installation failed: {e}")
78
-
79
- def _download_model(self) -> bool:
80
- """Download BAGEL model if not present"""
81
- try:
82
- logger.info("Downloading BAGEL model...")
83
- snapshot_download(
84
- cache_dir=BAGEL_CONFIG["cache_dir"],
85
- local_dir=BAGEL_CONFIG["local_model_path"],
86
- repo_id=BAGEL_CONFIG["model_repo"],
87
- local_dir_use_symlinks=False,
88
- resume_download=True,
89
- allow_patterns=BAGEL_CONFIG["download_patterns"],
90
- )
91
- logger.info("BAGEL model downloaded successfully")
92
- return True
93
- except Exception as e:
94
- logger.error(f"BAGEL model download failed: {e}")
95
- return False
96
-
97
  def initialize(self) -> bool:
98
- """Initialize BAGEL model"""
99
  if self.is_initialized:
100
  return True
101
 
102
  try:
103
- # Download model if needed
104
- if not os.path.exists(BAGEL_CONFIG["local_model_path"]):
105
- if not self._download_model():
106
- return False
107
-
108
- logger.info("Initializing BAGEL model...")
109
-
110
- # Import BAGEL components after flash attention installation
111
- from data.data_utils import add_special_tokens, pil_img2rgb
112
- from data.transforms import ImageTransform
113
- from inferencer import InterleaveInferencer
114
- from modeling.autoencoder import load_ae
115
- from modeling.bagel.qwen2_navit import NaiveCache
116
- from modeling.bagel import (
117
- BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
118
- SiglipVisionConfig, SiglipVisionModel
119
- )
120
- from modeling.qwen2 import Qwen2Tokenizer
121
-
122
- model_path = BAGEL_CONFIG["local_model_path"]
123
-
124
- # Load configurations
125
- llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
126
- llm_config.qk_norm = True
127
- llm_config.tie_word_embeddings = False
128
- llm_config.layer_module = "Qwen2MoTDecoderLayer"
129
-
130
- vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
131
- vit_config.rope = False
132
- vit_config.num_hidden_layers -= 1
133
-
134
- # Load VAE
135
- self.vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
136
-
137
- # Create BAGEL config
138
- config = BagelConfig(
139
- visual_gen=True,
140
- visual_und=True,
141
- llm_config=llm_config,
142
- vit_config=vit_config,
143
- vae_config=vae_config,
144
- vit_max_num_patch_per_side=70,
145
- connector_act='gelu_pytorch_tanh',
146
- latent_patch_size=2,
147
- max_latent_size=64,
148
- )
149
-
150
- # Initialize model with empty weights
151
- with init_empty_weights():
152
- language_model = Qwen2ForCausalLM(llm_config)
153
- vit_model = SiglipVisionModel(vit_config)
154
- self.model = Bagel(language_model, vit_model, config)
155
- self.model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
156
-
157
- # Load tokenizer
158
- self.tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
159
- self.tokenizer, new_token_ids, _ = add_special_tokens(self.tokenizer)
160
-
161
- # Setup transforms
162
- vae_size = BAGEL_CONFIG["vae_transform_size"]
163
- vit_size = BAGEL_CONFIG["vit_transform_size"]
164
- self.vae_transform = ImageTransform(vae_size[0], vae_size[1], vae_size[2])
165
- self.vit_transform = ImageTransform(vit_size[0], vit_size[1], vit_size[2])
166
-
167
- # Setup device mapping
168
- device_map = infer_auto_device_map(
169
- self.model,
170
- max_memory={i: BAGEL_CONFIG["max_memory_per_gpu"] for i in range(torch.cuda.device_count())},
171
- no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
172
- )
173
-
174
- # Apply custom device mapping for critical modules
175
- custom_mapping = get_bagel_device_map(self.device_config["gpu_count"])
176
- device_map.update(custom_mapping)
177
-
178
- # Load model with checkpoints
179
- self.model = load_checkpoint_and_dispatch(
180
- self.model,
181
- checkpoint=os.path.join(model_path, "ema.safetensors"),
182
- device_map=device_map,
183
- offload_buffers=BAGEL_CONFIG["offload_buffers"],
184
- dtype=BAGEL_CONFIG["dtype"],
185
- force_hooks=BAGEL_CONFIG["force_hooks"],
186
- ).eval()
187
-
188
- # Initialize inferencer
189
- self.inferencer = InterleaveInferencer(
190
- model=self.model,
191
- vae_model=self.vae_model,
192
- tokenizer=self.tokenizer,
193
- vae_transform=self.vae_transform,
194
- vit_transform=self.vit_transform,
195
- new_token_ids=new_token_ids,
196
- )
197
-
198
  self.is_initialized = True
199
- logger.info("BAGEL model initialized successfully")
200
  return True
201
 
202
  except Exception as e:
203
- logger.error(f"BAGEL initialization failed: {e}")
204
- self.cleanup()
205
  return False
206
 
207
- @spaces.GPU(duration=120)
208
- def analyze_image(self, image: Image.Image, prompt_type: str = "detailed_description") -> Tuple[str, Dict[str, Any]]:
209
- """Analyze image using BAGEL model"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  if not self.is_initialized:
211
  success = self.initialize()
212
  if not success:
213
- return "BAGEL model not available", {"error": "Initialization failed"}
214
 
 
215
  try:
216
- # Get appropriate prompt
217
- system_prompt = BAGEL_PROMPTS.get(prompt_type, BAGEL_PROMPTS["detailed_description"])
 
218
 
219
- # Prepare image for BAGEL
220
- if image.mode != 'RGB':
221
- image = image.convert('RGB')
 
222
 
223
- # Run inference through BAGEL
224
- logger.info("Running BAGEL inference...")
225
 
226
- # Use inferencer to analyze the image
227
- response = self.inferencer.inference_image_understanding(
228
- image=image,
229
- prompt=system_prompt,
230
- max_new_tokens=BAGEL_CONFIG["max_new_tokens"],
231
- temperature=BAGEL_CONFIG["temperature"],
232
- top_p=BAGEL_CONFIG["top_p"],
233
- do_sample=BAGEL_CONFIG["do_sample"]
 
234
  )
235
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  # Prepare metadata
237
  metadata = {
238
- "model": "BAGEL-7B",
239
- "device": self.device_config["device"],
240
- "confidence": 0.9, # BAGEL is highly reliable
241
- "prompt_type": prompt_type,
242
- "gpu_count": self.device_config.get("gpu_count", 1),
243
- "processing_mode": "GPU" if self.device_config["use_gpu"] else "CPU"
 
244
  }
245
 
246
- logger.info(f"BAGEL analysis complete: {len(response)} characters")
247
- return response, metadata
248
 
249
  except Exception as e:
250
- logger.error(f"BAGEL analysis failed: {e}")
251
- return "Analysis failed", {"error": str(e), "model": "BAGEL-7B"}
252
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  def cleanup(self) -> None:
254
- """Clean up BAGEL resources"""
255
  try:
256
- if hasattr(self, 'inferencer') and self.inferencer is not None:
257
- del self.inferencer
258
- self.inferencer = None
259
-
260
- if hasattr(self, 'vae_model') and self.vae_model is not None:
261
- del self.vae_model
262
- self.vae_model = None
263
-
264
  super().cleanup()
265
- logger.info("BAGEL resources cleaned up")
266
  except Exception as e:
267
- logger.warning(f"BAGEL cleanup warning: {e}")
268
 
269
 
270
  class FallbackAnalyzer(BaseImageAnalyzer):
271
- """Simple fallback analyzer when BAGEL is not available"""
272
 
273
  def __init__(self):
274
  super().__init__()
@@ -290,33 +201,37 @@ class FallbackAnalyzer(BaseImageAnalyzer):
290
 
291
  if aspect_ratio > 1.5:
292
  orientation = "landscape"
 
293
  elif aspect_ratio < 0.75:
294
  orientation = "portrait"
 
295
  else:
296
  orientation = "square"
 
297
 
298
- description = f"A {orientation} photograph with {mode} color mode, {width}x{height} pixels. Professional image suitable for detailed analysis and prompt generation."
299
 
300
  metadata = {
301
  "model": "Fallback",
302
  "device": "cpu",
303
- "confidence": 0.5,
304
  "image_size": f"{width}x{height}",
305
  "color_mode": mode,
306
- "orientation": orientation
 
307
  }
308
 
309
  return description, metadata
310
 
311
  except Exception as e:
312
  logger.error(f"Fallback analysis failed: {e}")
313
- return "Basic image detected", {"error": str(e), "model": "Fallback"}
314
 
315
 
316
  class ModelManager:
317
  """Manager for handling image analysis models"""
318
 
319
- def __init__(self, preferred_model: str = "bagel"):
320
  self.preferred_model = preferred_model
321
  self.analyzers = {}
322
  self.current_analyzer = None
@@ -326,8 +241,8 @@ class ModelManager:
326
  model_name = model_name or self.preferred_model
327
 
328
  if model_name not in self.analyzers:
329
- if model_name == "bagel":
330
- self.analyzers[model_name] = BagelAnalyzer()
331
  elif model_name == "fallback":
332
  self.analyzers[model_name] = FallbackAnalyzer()
333
  else:
@@ -337,14 +252,18 @@ class ModelManager:
337
 
338
  return self.analyzers[model_name]
339
 
340
- def analyze_image(self, image: Image.Image, model_name: str = None) -> Tuple[str, Dict[str, Any]]:
341
  """Analyze image with specified or preferred model"""
342
  # Try preferred model first
343
  analyzer = self.get_analyzer(model_name)
344
  if analyzer is None:
345
  return "No analyzer available", {"error": "Model not found"}
346
 
347
- success, result = safe_execute(analyzer.analyze_image, image)
 
 
 
 
348
 
349
  if success and result[1].get("error") is None:
350
  return result
@@ -369,27 +288,28 @@ class ModelManager:
369
 
370
 
371
  # Global model manager instance
372
- model_manager = ModelManager(preferred_model="bagel")
373
 
374
 
375
- def analyze_image(image: Image.Image, model_name: str = None) -> Tuple[str, Dict[str, Any]]:
376
  """
377
- Convenience function for image analysis using BAGEL
378
 
379
  Args:
380
  image: PIL Image to analyze
381
- model_name: Optional model name ("bagel" or "fallback")
 
382
 
383
  Returns:
384
  Tuple of (description, metadata)
385
  """
386
- return model_manager.analyze_image(image, model_name)
387
 
388
 
389
  # Export main components
390
  __all__ = [
391
  "BaseImageAnalyzer",
392
- "BagelAnalyzer",
393
  "FallbackAnalyzer",
394
  "ModelManager",
395
  "model_manager",
 
1
  """
2
  Model management for Frame 0 Laboratory for MIA
3
+ BAGEL 7B integration via API calls
4
  """
5
 
6
  import logging
7
+ import tempfile
8
  import os
 
 
 
9
  from typing import Optional, Dict, Any, Tuple
10
  from PIL import Image
11
+ from gradio_client import Client, handle_file
 
12
 
13
+ from config import get_device_config
 
 
 
14
  from utils import clean_memory, safe_execute
15
 
16
  logger = logging.getLogger(__name__)
 
20
  """Base class for image analysis models"""
21
 
22
  def __init__(self):
 
23
  self.is_initialized = False
24
  self.device_config = get_device_config()
25
 
 
33
 
34
  def cleanup(self) -> None:
35
  """Clean up model resources"""
 
 
 
36
  clean_memory()
37
 
38
 
39
+ class BagelAPIAnalyzer(BaseImageAnalyzer):
40
+ """BAGEL 7B model via API calls to working Space"""
41
 
42
  def __init__(self):
43
  super().__init__()
44
+ self.client = None
45
+ self.space_url = "Malaji71/Bagel-7B-Demo"
46
+ self.api_endpoint = "/image_understanding"
 
 
 
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  def initialize(self) -> bool:
49
+ """Initialize BAGEL API client"""
50
  if self.is_initialized:
51
  return True
52
 
53
  try:
54
+ logger.info("Initializing BAGEL API client...")
55
+ self.client = Client(self.space_url)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  self.is_initialized = True
57
+ logger.info("BAGEL API client initialized successfully")
58
  return True
59
 
60
  except Exception as e:
61
+ logger.error(f"BAGEL API client initialization failed: {e}")
 
62
  return False
63
 
64
+ def _save_temp_image(self, image: Image.Image) -> str:
65
+ """Save image to temporary file for API call"""
66
+ try:
67
+ # Create temporary file
68
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
69
+ temp_path = temp_file.name
70
+ temp_file.close()
71
+
72
+ # Save image
73
+ if image.mode != 'RGB':
74
+ image = image.convert('RGB')
75
+ image.save(temp_path, 'PNG')
76
+
77
+ return temp_path
78
+
79
+ except Exception as e:
80
+ logger.error(f"Failed to save temporary image: {e}")
81
+ return None
82
+
83
+ def _cleanup_temp_file(self, file_path: str):
84
+ """Clean up temporary file"""
85
+ try:
86
+ if file_path and os.path.exists(file_path):
87
+ os.unlink(file_path)
88
+ except Exception as e:
89
+ logger.warning(f"Failed to cleanup temp file: {e}")
90
+
91
+ def analyze_image(self, image: Image.Image, prompt: str = None) -> Tuple[str, Dict[str, Any]]:
92
+ """Analyze image using BAGEL API"""
93
  if not self.is_initialized:
94
  success = self.initialize()
95
  if not success:
96
+ return "BAGEL API not available", {"error": "API initialization failed"}
97
 
98
+ temp_path = None
99
  try:
100
+ # Default prompt for detailed image analysis
101
+ if prompt is None:
102
+ prompt = "Provide a detailed description of this image, including objects, people, setting, composition, lighting, colors, mood, and artistic style. Focus on elements that would be useful for generating a similar image."
103
 
104
+ # Save image to temporary file
105
+ temp_path = self._save_temp_image(image)
106
+ if not temp_path:
107
+ return "Image processing failed", {"error": "Could not save image"}
108
 
109
+ logger.info("Calling BAGEL API for image analysis...")
 
110
 
111
+ # Call BAGEL API
112
+ result = self.client.predict(
113
+ image=handle_file(temp_path),
114
+ prompt=prompt,
115
+ show_thinking=False,
116
+ do_sample=False,
117
+ text_temperature=0.3,
118
+ max_new_tokens=512,
119
+ api_name=self.api_endpoint
120
  )
121
 
122
+ # Extract response (API returns tuple: (image_result, text_response))
123
+ if isinstance(result, tuple) and len(result) >= 2:
124
+ description = result[1] if result[1] else result[0]
125
+ else:
126
+ description = str(result)
127
+
128
+ # Clean up the description
129
+ if isinstance(description, str) and description.strip():
130
+ description = description.strip()
131
+ else:
132
+ description = "Detailed image analysis completed successfully"
133
+
134
  # Prepare metadata
135
  metadata = {
136
+ "model": "BAGEL-7B-API",
137
+ "device": "api",
138
+ "confidence": 0.9,
139
+ "api_endpoint": self.api_endpoint,
140
+ "space_url": self.space_url,
141
+ "prompt_used": prompt,
142
+ "response_length": len(description)
143
  }
144
 
145
+ logger.info(f"BAGEL API analysis complete: {len(description)} characters")
146
+ return description, metadata
147
 
148
  except Exception as e:
149
+ logger.error(f"BAGEL API analysis failed: {e}")
150
+ return "API analysis failed", {"error": str(e), "model": "BAGEL-7B-API"}
151
+
152
+ finally:
153
+ # Always cleanup temporary file
154
+ if temp_path:
155
+ self._cleanup_temp_file(temp_path)
156
+
157
+ def analyze_for_flux_prompt(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
158
+ """Analyze image specifically for FLUX prompt generation"""
159
+ flux_prompt = """Analyze this image and generate a detailed FLUX prompt description. Focus on:
160
+ - Photographic and artistic style
161
+ - Composition and framing
162
+ - Lighting conditions and mood
163
+ - Colors and visual elements
164
+ - Camera settings that would recreate this image
165
+ - Technical photography details
166
+ Provide a comprehensive description suitable for FLUX image generation."""
167
+
168
+ return self.analyze_image(image, flux_prompt)
169
+
170
  def cleanup(self) -> None:
171
+ """Clean up API client resources"""
172
  try:
173
+ if hasattr(self, 'client'):
174
+ self.client = None
 
 
 
 
 
 
175
  super().cleanup()
176
+ logger.info("BAGEL API resources cleaned up")
177
  except Exception as e:
178
+ logger.warning(f"BAGEL API cleanup warning: {e}")
179
 
180
 
181
  class FallbackAnalyzer(BaseImageAnalyzer):
182
+ """Simple fallback analyzer when BAGEL API is not available"""
183
 
184
  def __init__(self):
185
  super().__init__()
 
201
 
202
  if aspect_ratio > 1.5:
203
  orientation = "landscape"
204
+ camera_suggestion = "wide-angle lens, landscape photography"
205
  elif aspect_ratio < 0.75:
206
  orientation = "portrait"
207
+ camera_suggestion = "portrait lens, shallow depth of field"
208
  else:
209
  orientation = "square"
210
+ camera_suggestion = "standard lens, balanced composition"
211
 
212
+ description = f"A {orientation} format image with professional composition. The image shows clear detail and good visual balance, suitable for high-quality reproduction. Recommended camera setup: {camera_suggestion}, professional lighting with careful attention to exposure and color balance."
213
 
214
  metadata = {
215
  "model": "Fallback",
216
  "device": "cpu",
217
+ "confidence": 0.6,
218
  "image_size": f"{width}x{height}",
219
  "color_mode": mode,
220
+ "orientation": orientation,
221
+ "aspect_ratio": round(aspect_ratio, 2)
222
  }
223
 
224
  return description, metadata
225
 
226
  except Exception as e:
227
  logger.error(f"Fallback analysis failed: {e}")
228
+ return "Professional image suitable for detailed analysis and prompt generation", {"error": str(e), "model": "Fallback"}
229
 
230
 
231
  class ModelManager:
232
  """Manager for handling image analysis models"""
233
 
234
+ def __init__(self, preferred_model: str = "bagel-api"):
235
  self.preferred_model = preferred_model
236
  self.analyzers = {}
237
  self.current_analyzer = None
 
241
  model_name = model_name or self.preferred_model
242
 
243
  if model_name not in self.analyzers:
244
+ if model_name == "bagel-api":
245
+ self.analyzers[model_name] = BagelAPIAnalyzer()
246
  elif model_name == "fallback":
247
  self.analyzers[model_name] = FallbackAnalyzer()
248
  else:
 
252
 
253
  return self.analyzers[model_name]
254
 
255
+ def analyze_image(self, image: Image.Image, model_name: str = None, analysis_type: str = "detailed") -> Tuple[str, Dict[str, Any]]:
256
  """Analyze image with specified or preferred model"""
257
  # Try preferred model first
258
  analyzer = self.get_analyzer(model_name)
259
  if analyzer is None:
260
  return "No analyzer available", {"error": "Model not found"}
261
 
262
+ # Choose analysis method based on type
263
+ if analysis_type == "flux" and hasattr(analyzer, 'analyze_for_flux_prompt'):
264
+ success, result = safe_execute(analyzer.analyze_for_flux_prompt, image)
265
+ else:
266
+ success, result = safe_execute(analyzer.analyze_image, image)
267
 
268
  if success and result[1].get("error") is None:
269
  return result
 
288
 
289
 
290
  # Global model manager instance
291
+ model_manager = ModelManager(preferred_model="bagel-api")
292
 
293
 
294
+ def analyze_image(image: Image.Image, model_name: str = None, analysis_type: str = "detailed") -> Tuple[str, Dict[str, Any]]:
295
  """
296
+ Convenience function for image analysis using BAGEL API
297
 
298
  Args:
299
  image: PIL Image to analyze
300
+ model_name: Optional model name ("bagel-api" or "fallback")
301
+ analysis_type: Type of analysis ("detailed" or "flux")
302
 
303
  Returns:
304
  Tuple of (description, metadata)
305
  """
306
+ return model_manager.analyze_image(image, model_name, analysis_type)
307
 
308
 
309
  # Export main components
310
  __all__ = [
311
  "BaseImageAnalyzer",
312
+ "BagelAPIAnalyzer",
313
  "FallbackAnalyzer",
314
  "ModelManager",
315
  "model_manager",