VOIDER committed · verified
Commit 36c6ae2 · 1 Parent(s): fab9033

Update app.py

Files changed (1)
app.py +93 -181
app.py CHANGED
@@ -12,15 +12,16 @@ import torch
 import onnxruntime as rt
 from PIL import Image
 from huggingface_hub import hf_hub_download
-from transformers import pipeline, Pipeline
+from transformers import pipeline, Pipeline, AutoModel, AutoProcessor
 from tqdm import tqdm
 
-# Suppress a specific PIL warning about image size
+# Suppress a specific PIL warning about image size to handle large images
 Image.MAX_IMAGE_PIXELS = None
 
 # --- Configuration ---
 CACHE_DIR = "./hf_cache"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+# Use bfloat16 for modern GPUs, float32 for others (including CPU)
 DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float32
 
 print(f"Using device: {DEVICE} with dtype: {DTYPE}")
@@ -31,21 +32,20 @@ print(f"Using device: {DEVICE} with dtype: {DTYPE}")
 
 class AestheticScorer(ABC):
     """Abstract base class for all aesthetic scoring models."""
-
     def __init__(self, model_name: str, repo_id: str, filename: str = None):
         self.model_name = model_name
         self.repo_id = repo_id
         self.filename = filename
         self._model = None
-        print(f"Initializing scorer: {self.model_name}")
+        print(f"Initializing scorer definition: {self.model_name}")
 
     @property
    def model(self):
         """Lazy-loads the model on first access."""
         if self._model is None:
-            print(f"Loading model for '{self.model_name}'...")
+            print(f"Loading model weights for '{self.model_name}'...")
             self._model = self.load_model()
-            print(f"'{self.model_name}' model loaded.")
+            print(f"'{self.model_name}' model weights loaded.")
         return self._model
 
     def _download_model(self) -> str:
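The `model` property gives every scorer lazy loading: nothing is downloaded or pushed to the GPU until the first scoring call touches `self.model`. A stripped-down sketch of the pattern; the class and the stub load below are illustrative, not from app.py:

    class LazyModel:
        def __init__(self):
            self._model = None  # weights intentionally not loaded yet

        @property
        def model(self):
            # first access pays the load cost; later accesses reuse the instance
            if self._model is None:
                self._model = self.load_model()
            return self._model

        def load_model(self):
            return object()  # stand-in for a real checkpoint load

    m = LazyModel()
    assert m.model is m.model  # loaded once, cached thereafter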
@@ -63,9 +63,9 @@ class AestheticScorer(ABC):
         pass
 
     def release_model(self):
-        """Releases model from memory."""
+        """Releases model from memory to conserve VRAM/RAM."""
         if self._model is not None:
-            print(f"Releasing model: {self.model_name}")
+            print(f"Releasing model from memory: {self.model_name}")
             del self._model
             self._model = None
             gc.collect()
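`release_model` clears the only strong reference and forces a collection so the weights are freed promptly. On CUDA it is common to also return cached allocator blocks to the driver; the `torch.cuda.empty_cache()` call in this sketch is my assumption and is not shown in the hunk above:

    import gc
    import torch

    class Holder:
        def __init__(self, model):
            self._model = model

        def release(self):
            # drop the attribute reference, then reclaim Python-side memory now
            self._model = None
            gc.collect()
            # assumption: also hand cached CUDA blocks back to the driver
            if torch.cuda.is_available():
                torch.cuda.empty_cache()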
@@ -74,23 +74,15 @@ class AestheticScorer(ABC):
 
 class PipelineScorer(AestheticScorer):
     """Scorer for models compatible with Hugging Face pipelines."""
-
     def load_model(self) -> Pipeline:
-        """Loads a pipeline model."""
-        return pipeline(
-            "image-classification",
-            model=self.repo_id,
-            device=DEVICE,
-        )
+        return pipeline("image-classification", model=self.repo_id, device=DEVICE)
 
     @torch.no_grad()
     def score_batch(self, image_batch: List[Image.Image]) -> List[float]:
-        """Scores a batch using the pipeline and extracts the 'hq' score."""
-        results = self.model(image_batch)
+        results = self.model(image_batch, top_k=None)  # Get all class scores
         scores = []
         for res in results:
             try:
-                # Find the score for the 'hq' (high quality) label
                 hq_score = next(item['score'] for item in res if item['label'] == 'hq')
                 scores.append(round(hq_score * 10.0, 4))
             except (StopIteration, TypeError):
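With `top_k=None`, the image-classification pipeline returns one list per image containing a `{'label', 'score'}` dict for every class, so the `next(...)` lookup finds the 'hq' entry even when it is not the top prediction. A self-contained check of the extraction logic against mocked pipeline output:

    # Mocked result for one image, shaped like image-classification pipeline output
    res = [{"label": "hq", "score": 0.87}, {"label": "lq", "score": 0.13}]

    try:
        hq_score = next(item["score"] for item in res if item["label"] == "hq")
        print(round(hq_score * 10.0, 4))  # 8.7 -- rescaled to a 0-10 range
    except StopIteration:
        print("no 'hq' label")  # the diff's except branch handles this case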
@@ -99,21 +91,16 @@ class PipelineScorer(AestheticScorer):
 
 class ONNXScorer(AestheticScorer):
     """Scorer for ONNX-based models."""
-
     def load_model(self) -> rt.InferenceSession:
-        """Loads an ONNX inference session."""
         model_path = self._download_model()
         return rt.InferenceSession(model_path, providers=['CUDAExecutionProvider' if DEVICE == 'cuda' else 'CPUExecutionProvider'])
 
     def _preprocess(self, img: Image.Image) -> np.ndarray:
-        """Preprocesses a single image for the Anime Aesthetic model."""
         img_np = np.array(img.convert("RGB")).astype(np.float32) / 255.0
         s = 768
         h, w = img_np.shape[:2]
-        if h > w:
-            new_h, new_w = s, int(s * w / h)
-        else:
-            new_h, new_w = int(s * h / w), s
+        ratio = s / max(h, w)
+        new_h, new_w = int(h * ratio), int(w * ratio)
 
         resized = cv2.resize(img_np, (new_w, new_h), interpolation=cv2.INTER_AREA)
         canvas = np.zeros((s, s, 3), dtype=np.float32)
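The rewritten resize computes the same letterbox dimensions as the old two-branch version: dividing by `max(h, w)` scales the longer side to exactly 768 and the shorter side proportionally, before padding onto the square canvas. A quick equivalence check:

    def old_dims(h, w, s=768):
        return (s, int(s * w / h)) if h > w else (int(s * h / w), s)

    def new_dims(h, w, s=768):
        ratio = s / max(h, w)
        return int(h * ratio), int(w * ratio)

    # identical on typical shapes (floating point could differ by a pixel in rare cases)
    for h, w in [(1080, 1920), (4000, 3000), (768, 768)]:
        assert old_dims(h, w) == new_dims(h, w)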
@@ -123,7 +110,6 @@ class ONNXScorer(AestheticScorer):
         return np.transpose(canvas, (2, 0, 1))[np.newaxis, :]
 
     def score_batch(self, image_batch: List[Image.Image]) -> List[float]:
-        """Scores images one by one as this model doesn't support batching."""
         scores = []
         for img in image_batch:
             try:
@@ -135,264 +121,192 @@ class ONNXScorer(AestheticScorer):
         return scores
 
 class CLIPMLPScorer(AestheticScorer):
-    """Scorer for models using a CLIP backbone and an MLP head."""
-
+    """Scorer for models using a CLIP backbone and a custom MLP head."""
     class MLP(torch.nn.Module):
+        """Re-implementation of the exact MLP from the original code."""
         def __init__(self, input_size: int):
             super().__init__()
             self.layers = torch.nn.Sequential(
-                torch.nn.Linear(input_size, 1024),
-                torch.nn.ReLU(),
-                torch.nn.Dropout(0.2),
-                torch.nn.Linear(1024, 128),
-                torch.nn.ReLU(),
-                torch.nn.Dropout(0.2),
-                torch.nn.Linear(128, 64),
-                torch.nn.ReLU(),
-                torch.nn.Linear(64, 16),
-                torch.nn.ReLU(),
-                torch.nn.Linear(16, 1),
+                torch.nn.Linear(input_size, 2048),
+                torch.nn.ReLU(),
+                torch.nn.BatchNorm1d(2048),
+                torch.nn.Dropout(0.3),
+                torch.nn.Linear(2048, 512),
+                torch.nn.ReLU(),
+                torch.nn.BatchNorm1d(512),
+                torch.nn.Dropout(0.3),
+                torch.nn.Linear(512, 256),
+                torch.nn.ReLU(),
+                torch.nn.BatchNorm1d(256),
+                torch.nn.Dropout(0.2),
+                torch.nn.Linear(256, 128),
+                torch.nn.ReLU(),
+                torch.nn.BatchNorm1d(128),
+                torch.nn.Dropout(0.1),
+                torch.nn.Linear(128, 32),
+                torch.nn.ReLU(),
+                torch.nn.Linear(32, 1)
             )
         def forward(self, x):
            return self.layers(x)
 
     def load_model(self) -> Dict[str, Any]:
-        """Loads both the CLIP model and the custom MLP head."""
-        import clip  # Lazy import
-
+        import clip
         model_path = self._download_model()
-
         mlp = self.MLP(input_size=768)  # ViT-L/14 has 768 features
         state_dict = torch.load(model_path, map_location=DEVICE)
         mlp.load_state_dict(state_dict)
-        mlp.to(device=DEVICE, dtype=DTYPE)
+        mlp.to(device=DEVICE)
         mlp.eval()
-
         clip_model, preprocess = clip.load("ViT-L/14", device=DEVICE)
-
         return {"mlp": mlp, "clip": clip_model, "preprocess": preprocess}
 
     @torch.no_grad()
     def score_batch(self, image_batch: List[Image.Image]) -> List[float]:
-        """Scores a batch using CLIP features and the MLP head."""
         preprocess = self.model['preprocess']
+        # Handle single-image batches correctly for CLIP
+        if len(image_batch) == 1:
+            image_batch = image_batch * 2
+            single_image_mode = True
+        else:
+            single_image_mode = False
+
         image_tensors = torch.cat([preprocess(img).unsqueeze(0) for img in image_batch]).to(DEVICE)
-
-        image_features = self.model['clip'].encode_image(image_tensors)
+        image_features = self.model['clip'].encode_image(image_tensors).to(torch.float32)
         image_features /= image_features.norm(dim=-1, keepdim=True)
-
-        # Pass features through MLP
-        predictions = self.model['mlp'](image_features.to(DTYPE)).squeeze(-1)
-        scores = predictions.float().cpu().numpy()
-
-        return [round(float(s), 4) for s in scores]
+        predictions = self.model['mlp'](image_features).squeeze(-1)
+        scores = predictions.clamp(0, 10).float().cpu().numpy()
+
+        final_scores = [round(float(s), 4) for s in scores]
+        return final_scores[:1] if single_image_mode else final_scores
+
+class SigLIPScorer(AestheticScorer):
+    """Scorer for the Aesthetic Predictor V2.5 SigLIP model."""
+    def load_model(self) -> Dict[str, Any]:
+        model = AutoModel.from_pretrained(self.repo_id, trust_remote_code=True).to(DEVICE, DTYPE).eval()
+        processor = AutoProcessor.from_pretrained(self.repo_id, trust_remote_code=True)
+        return {"model": model, "processor": processor}
+
+    @torch.no_grad()
+    def score_batch(self, image_batch: List[Image.Image]) -> List[float]:
+        inputs = self.model['processor'](
+            images=[img.convert("RGB") for img in image_batch],
+            return_tensors="pt"
+        )
+        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
+        inputs['pixel_values'] = inputs['pixel_values'].to(DTYPE)
+        logits = self.model['model'](**inputs).logits.squeeze(-1)
+        scores = logits.float().cpu().numpy()
+        return [round(float(s), 4) for s in scores]
 
 # --- Model Registry ---
-MODEL_REGISTRY: Dict[str, Type[AestheticScorer]] = {
-    "Aesthetic Shadow V2": PipelineScorer(
-        "Aesthetic Shadow V2", "shadowlilac/aesthetic-shadow-v2"
-    ),
-    "Waifu Scorer V2": CLIPMLPScorer(
-        "Waifu Scorer V2", "skytnt/waifu-aesthetic-scorer", "model.pth"
-    ),
-    "Anime Scorer": ONNXScorer(
-        "Anime Scorer", "skytnt/anime-aesthetic", "model.onnx"
-    )
+MODEL_REGISTRY: Dict[str, AestheticScorer] = {
+    "Aesthetic Shadow V2": PipelineScorer("Aesthetic Shadow V2", "NeoChen1024/aesthetic-shadow-v2-backup"),
+    "Waifu Scorer V3": CLIPMLPScorer("Waifu Scorer V3", "Eugeoter/waifu-scorer-v3", "model.pth"),
+    "Aesthetic V2.5 SigLIP": SigLIPScorer("Aesthetic V2.5 SigLIP", "জিংוניत्र/Aesthetic-Predictor-V2-5-SigLIP"),
+    "Anime Scorer": ONNXScorer("Anime Scorer", "skytnt/anime-aesthetic", "model.onnx")
 }
-
-# In-memory cache for loaded model instances
 _loaded_models_cache: Dict[str, AestheticScorer] = {}
 
-
 # ==================================================================================
 # 2. CORE PROCESSING LOGIC
 # ==================================================================================
 
 def get_scorers(model_names: List[str]) -> List[AestheticScorer]:
     """Retrieves and caches scorer instances based on selected names."""
-    # Release models that are no longer selected
-    for name, scorer in list(_loaded_models_cache.items()):
+    for name in list(_loaded_models_cache.keys()):
         if name not in model_names:
-            scorer.release_model()
+            _loaded_models_cache[name].release_model()
             del _loaded_models_cache[name]
-
-    # Load newly selected models
-    scorers = []
-    for name in model_names:
-        if name in _loaded_models_cache:
-            scorers.append(_loaded_models_cache[name])
-        elif name in MODEL_REGISTRY:
-            scorer = MODEL_REGISTRY[name]
-            _loaded_models_cache[name] = scorer
-            scorers.append(scorer)
-    return scorers
+    return [_loaded_models_cache.setdefault(name, MODEL_REGISTRY[name]) for name in model_names]
 
 def evaluate_images(
-    files: List[gr.File],
-    selected_model_names: List[str],
-    batch_size: int,
-    progress: gr.Progress = gr.Progress(track_tqdm=True),
+    files: List[gr.File], selected_model_names: List[str], batch_size: int, progress=gr.Progress(track_tqdm=True)
 ) -> pd.DataFrame:
-    """
-    Main function to process images, run them through selected models,
-    and return results as a Pandas DataFrame.
-    """
+    """Main function to process images and return results as a Pandas DataFrame."""
     if not files:
         gr.Warning("No images uploaded. Please upload files to evaluate.")
         return pd.DataFrame()
-
     if not selected_model_names:
         gr.Warning("No models selected. Please select at least one model.")
         return pd.DataFrame()
 
     try:
         image_paths = [Path(f.name) for f in files]
-        all_results = []
-        scorers = get_scorers(selected_model_names)
-
-        # Use a single tqdm instance for progress tracking
-        pbar = tqdm(total=len(image_paths), desc="Processing images")
+        all_results, scorers = [], get_scorers(selected_model_names)
 
-        for i in range(0, len(image_paths), batch_size):
+        for i in tqdm(range(0, len(image_paths), batch_size), desc="Processing Batches"):
             batch_paths = image_paths[i : i + batch_size]
-
-            # Load images for the current batch
             try:
                 batch_images = [Image.open(p).convert("RGB") for p in batch_paths]
             except Exception as e:
                 gr.Warning(f"Skipping a batch due to an error loading an image: {e}")
-                pbar.update(len(batch_paths))
                 continue
-
-            # Get scores from all selected models for the batch
-            batch_scores: Dict[str, List[float]] = {}
-            for scorer in scorers:
-                batch_scores[scorer.model_name] = scorer.score_batch(batch_images)
 
-            # Collate results for the batch
+            batch_scores = {scorer.model_name: scorer.score_batch(batch_images) for scorer in scorers}
+
             for j, path in enumerate(batch_paths):
-                result_row = {"Image": Image.open(path), "Filename": path.name}
-
-                scores_for_avg = []
+                result_row = {"Image": str(path), "Filename": path.name}
+                scores_for_avg = [batch_scores[s.model_name][j] for s in scorers]
                 for scorer in scorers:
-                    score = batch_scores[scorer.model_name][j]
-                    result_row[scorer.model_name] = score
-                    scores_for_avg.append(score)
-
-                # Calculate average score
-                if scores_for_avg:
-                    result_row["Average Score"] = round(np.mean(scores_for_avg), 4)
-                else:
-                    result_row["Average Score"] = 0.0
-
+                    result_row[scorer.model_name] = batch_scores[scorer.model_name][j]
+                result_row["Average Score"] = round(np.mean(scores_for_avg), 4) if scores_for_avg else 0.0
                 all_results.append(result_row)
-
-            pbar.update(len(batch_paths))
 
-        pbar.close()
-
-        if not all_results:
-            gr.Warning("Processing completed, but no results were generated.")
-            return pd.DataFrame()
-
-        return pd.DataFrame(all_results)
+        return pd.DataFrame(all_results) if all_results else pd.DataFrame()
 
     except Exception as e:
         gr.Error(f"A critical error occurred: {e}")
-        # Clean up in case of failure
-        for scorer in _loaded_models_cache.values():
-            scorer.release_model()
-        _loaded_models_cache.clear()
         return pd.DataFrame()
 
-
 # ==================================================================================
 # 3. GRADIO USER INTERFACE
 # ==================================================================================
 
 def create_ui() -> gr.Blocks:
     """Creates and configures the Gradio web interface."""
-
     all_model_names = list(MODEL_REGISTRY.keys())
-
-    # Define headers and datatypes for the results table
     dataframe_headers = ["Image", "Filename"] + all_model_names + ["Average Score"]
     dataframe_datatypes = ["image", "str"] + ["number"] * (len(all_model_names) + 1)
 
-    with gr.Blocks(theme=gr.themes.Soft(), title="Image Aesthetic Scorer") as demo:
-        gr.Markdown(
-            """
-            # 🖼️ Modern Image Aesthetic Scorer
-            Upload your images, select the scoring models, and click 'Evaluate'.
-            The results table supports **interactive sorting** (click on headers) and can be **downloaded as a CSV**.
-            """
-        )
+    with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="Image Aesthetic Scorer") as demo:
+        gr.Markdown("# 🖼️ Modern Image Aesthetic Scorer")
+        gr.Markdown("Upload images, select models, and click 'Evaluate'. Results table supports **interactive sorting** and **downloading as CSV**.")
 
         with gr.Row():
             with gr.Column(scale=1):
-                gr.Markdown("### ⚙️ Settings")
-                input_files = gr.Files(
-                    label="Upload Images",
-                    file_count="multiple",
-                    file_types=["image"],
-                )
-
-                with gr.Accordion("Advanced Configuration", open=False):
-                    model_checkboxes = gr.CheckboxGroup(
-                        choices=all_model_names,
-                        value=all_model_names,
-                        label="Scoring Models",
-                        info="Choose which models to use for evaluation.",
-                    )
-                    batch_size_slider = gr.Slider(
-                        minimum=1,
-                        maximum=64,
-                        value=8,
-                        step=1,
-                        label="Batch Size",
-                        info="Adjust based on your VRAM. Higher is faster.",
-                    )
-
+                input_files = gr.Files(label="Upload Images", file_count="multiple", file_types=["image"])
+                model_checkboxes = gr.CheckboxGroup(choices=all_model_names, value=all_model_names, label="Scoring Models")
+                batch_size_slider = gr.Slider(minimum=1, maximum=64, value=8, step=1, label="Batch Size", info="Adjust based on your VRAM.")
                 with gr.Row():
                     process_button = gr.Button("🚀 Evaluate Images", variant="primary")
                     clear_button = gr.Button("🧹 Clear All")
 
             with gr.Column(scale=3):
-                gr.Markdown("### 📊 Results")
+                # CORRECTED LINE: height and show_download_button are passed directly here.
                 results_dataframe = gr.DataFrame(
                     headers=dataframe_headers,
                     datatype=dataframe_datatypes,
                     label="Evaluation Scores",
                     interactive=True,
-                    # Enable the download button directly on the component
+                    height=800,
+                    show_download_button=True
                 )
-                # This is a cleaner way to show the download button
-                results_dataframe.style(height=800, show_download_button=True)
 
-        # --- Event Handlers ---
         process_button.click(
             fn=evaluate_images,
             inputs=[input_files, model_checkboxes, batch_size_slider],
-            outputs=[results_dataframe],
-            concurrency_limit=1  # Only one evaluation at a time
+            outputs=[results_dataframe]
         )
 
         def clear_outputs():
-            # Release all models from memory when clearing
-            for scorer in _loaded_models_cache.values():
+            for scorer in list(_loaded_models_cache.values()):
                 scorer.release_model()
             _loaded_models_cache.clear()
             gr.Info("Cleared results and released models from memory.")
-            # Return an empty DataFrame to clear the table
-            return pd.DataFrame()
-
-        clear_button.click(
-            fn=clear_outputs,
-            inputs=[],
-            outputs=[results_dataframe],
-        )
+            return pd.DataFrame(), None  # Clear dataframe and file input
+
+        clear_button.click(fn=clear_outputs, outputs=[results_dataframe, input_files])
         return demo
 
 # ==================================================================================
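One detail in the new `CLIPMLPScorer.score_batch` deserves a note: the single-image duplication most likely guards against `BatchNorm1d` in the new MLP head, which cannot compute batch statistics over one sample in training mode; in eval mode, as used here, running statistics make a batch of one legal, so the guard is defensive. A quick standalone check of both behaviors:

    import torch

    bn = torch.nn.BatchNorm1d(8)
    x = torch.randn(1, 8)  # batch of one

    bn.train()
    try:
        bn(x)
    except ValueError as e:
        print("train mode, batch=1 fails:", e)  # variance of one sample is undefined

    bn.eval()  # eval mode uses running statistics, not batch statistics
    print("eval mode, batch=1 works:", bn(x).shape)  # torch.Size([1, 8])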
@@ -400,8 +314,6 @@ def create_ui() -> gr.Blocks:
 # ==================================================================================
 
 if __name__ == "__main__":
-    # Ensure cache directory exists
     os.makedirs(CACHE_DIR, exist_ok=True)
-
     app = create_ui()
-    app.queue().launch(share=False)
+    app.queue().launch()
 
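Finally, the condensed `get_scorers` is behavior-preserving with one nuance: `dict.setdefault` admits any requested registry name into the cache in a single pass after evicted names are released, and an unknown name now raises KeyError instead of being silently skipped as in the old version. A minimal sketch with strings standing in for scorer instances:

    REGISTRY = {"a": "scorer-a", "b": "scorer-b", "c": "scorer-c"}
    cache = {"a": "scorer-a", "b": "scorer-b"}  # models loaded for a previous run

    def get(names):
        # evict anything no longer selected (a real scorer would release weights here)
        for name in list(cache.keys()):
            if name not in names:
                del cache[name]
        # fetch cached entries, lazily admitting registry entries on first request
        return [cache.setdefault(name, REGISTRY[name]) for name in names]

    print(get(["b", "c"]))  # ['scorer-b', 'scorer-c']; 'a' evicted, 'c' admitted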