|
import asyncio
import io
import os
import shutil
import tempfile

import cv2
import gradio as gr
import numpy as np
import onnxruntime as rt
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import pipeline
|
|
|
|
|
|
|
|
|
try:
    from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip
except ImportError:
    print("Warning: aesthetic_predictor_v2_5.py not found. Using a mock for AestheticPredictorV25.")

    def convert_v2_5_from_siglip(low_cpu_mem_usage=True, trust_remote_code=True):
        """Return a ``(model, preprocessor)`` stand-in for the real V2.5 predictor.

        Only the call shapes used by ``AestheticPredictorV25`` are mimicked: the
        model returns random logits of shape ``(batch, 1)``, so any scores it
        yields are meaningless placeholders. Both keyword arguments are accepted
        for signature compatibility and ignored.
        """

        class MockModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # Zero-sized parameter whose device records where the "model" lives.
                self.dummy_param = torch.nn.Parameter(torch.empty(0))

            def forward(self, pixel_values):
                batch_size = pixel_values.size(0)

                class Output:
                    pass

                output = Output()
                output.logits = torch.randn(batch_size, 1).to(self.dummy_param.device)
                return output

            def to(self, *args, **kwargs):
                """Move the placeholder parameter to a device; dtype requests are ignored.

                Accepts the common ``nn.Module.to`` call shapes: ``to(device)``,
                ``to(dtype)``, ``to(device, dtype)`` and ``to(device=...)``.
                """
                device = kwargs.get("device")
                for arg in args:
                    if isinstance(arg, (str, torch.device)):
                        device = arg
                if device is not None:
                    self.dummy_param = torch.nn.Parameter(torch.empty(0, device=device))
                return self

            def cuda(self):
                return self.to(torch.device('cuda'))

        def mock_preprocessor(images, return_tensors="pt"):
            """Return random 224x224 RGB pixel values, one row per image in ``images``."""
            count = len(images) if isinstance(images, list) else 1
            return {"pixel_values": torch.randn(count, 3, 224, 224)}

        return MockModel(), mock_preprocessor
|
|
|
|
|
# Use the GPU whenever PyTorch can see one; BaseImageScorer defaults to this device.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Precision of the WaifuScorer MLP head; CLIP embeddings are cast to it before scoring.
DTYPE_WAIFU = torch.float32

# Forwarded as ``cache_dir`` to hf_hub_download; None keeps the library's default cache.
CACHE_DIR = None
|
|
|
|
|
|
|
class MLP(torch.nn.Module):
    """Fully connected regression head used by the WaifuScorer.

    Maps an ``input_size``-wide embedding to one unbounded output through hidden
    widths 2048 -> 512 -> 256 -> 128 -> 32. The first four hidden stages are
    Linear -> ReLU -> (BatchNorm1d or Identity) -> Dropout. The module order in
    ``self.layers`` must not change: pretrained state-dict keys are the
    Sequential indices (``layers.<i>.*``).
    """

    def __init__(self, input_size: int, batch_norm: bool = True):
        super().__init__()
        self.input_size = input_size

        # (width, dropout probability) for each normalised hidden stage.
        stages = ((2048, 0.3), (512, 0.3), (256, 0.2), (128, 0.1))
        blocks = []
        fan_in = input_size
        for width, drop_p in stages:
            blocks.append(torch.nn.Linear(fan_in, width))
            blocks.append(torch.nn.ReLU())
            blocks.append(torch.nn.BatchNorm1d(width) if batch_norm else torch.nn.Identity())
            blocks.append(torch.nn.Dropout(drop_p))
            fan_in = width
        blocks.extend([torch.nn.Linear(fan_in, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1)])
        self.layers = torch.nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
|
|
|
class BaseImageScorer:
    """Shared scaffolding for every image scorer.

    Subclasses implement ``_load_model`` (populating ``self.model`` and, when
    needed, ``self.preprocessor``) and ``predict`` (receiving RGB PIL images and
    returning one score or ``None`` per image). Loading runs eagerly from
    ``__init__``.
    """

    def __init__(self, model_key: str, model_display_name: str, device: str = DEVICE, verbose: bool = False):
        self.model_key, self.model_display_name = model_key, model_display_name
        self.device, self.verbose = device, verbose
        # Start "unloaded"; a successful _load_model replaces these.
        self.model = None
        self.preprocessor = None
        self._load_model()

    def _load_model(self):
        raise NotImplementedError

    def predict(self, images: list[Image.Image]) -> list[float | None]:
        raise NotImplementedError

    def __call__(self, images: list[Image.Image]) -> list[float | None]:
        """Score ``images``; returns ``None`` for each image when no model is loaded."""
        if self.model:
            # Normalise colour mode so every backend receives 3-channel input.
            return self.predict([img if img.mode == "RGB" else img.convert("RGB") for img in images])
        if self.verbose:
            print(f"{self.model_display_name} model not loaded.")
        return [None] * len(images)
|
|
|
class WaifuScorerModel(BaseImageScorer):
    """Waifu Scorer v3: an ``MLP`` head on L2-normalised CLIP ViT-L/14 image embeddings.

    Requires the optional ``clip`` package. If it is missing, or any download or
    load step fails, the scorer stays unloaded (``model`` and ``mlp`` are ``None``)
    and ``predict`` returns ``None`` for every image.
    """

    def _load_model(self):
        # Set before anything can fail so predict() can always check it.
        self.mlp = None
        try:
            import clip
        except ImportError:
            if self.verbose: print("CLIP library not found. WaifuScorer will not be available.")
            return

        try:
            model_hf_path = "Eugeoter/waifu-scorer-v3/model.pth"
            # "<owner>/<repo>/<file>" -> repo id "<owner>/<repo>" and file name "<file>".
            repo_id, filename = os.path.split(model_hf_path)
            actual_model_path = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=CACHE_DIR)
            if self.verbose: print(f"Loading WaifuScorer MLP from: {actual_model_path}")

            mlp = MLP(input_size=768)
            if actual_model_path.endswith(".safetensors"):
                from safetensors.torch import load_file
                state_dict = load_file(actual_model_path, device=self.device)
            else:
                # NOTE(review): torch.load unpickles the downloaded checkpoint. On torch
                # versions where weights_only defaults to False, this can execute code
                # embedded in the file. Consider weights_only=True once the minimum
                # supported torch version is confirmed.
                state_dict = torch.load(actual_model_path, map_location=self.device)
            mlp.load_state_dict(state_dict)
            # Publish the head only after its weights loaded successfully.
            self.mlp = mlp.to(self.device).eval()

            if self.verbose: print("Loading CLIP model ViT-L/14 for WaifuScorer.")
            self.model, self.preprocessor = clip.load("ViT-L/14", device=self.device)
            self.model.eval()
        except Exception as e:
            if self.verbose: print(f"Error loading WaifuScorer ({self.model_display_name}): {e}")
            # Leave the scorer consistently unloaded rather than half-initialised.
            self.model = None
            self.preprocessor = None
            self.mlp = None

    @torch.no_grad()
    def predict(self, images: list[Image.Image]) -> list[float | None]:
        """Return one score in [0, 10] per image, or ``None`` for all images on failure."""
        if not self.model or getattr(self, "mlp", None) is None: return [None] * len(images)

        original_n = len(images)
        processed_images = list(images)
        # Pad a single image to a batch of two. BatchNorm1d rejects batch size 1 in
        # training mode; the head is in eval mode here, so this is only defensive.
        if original_n == 1: processed_images.append(images[0])

        try:
            image_tensors = torch.cat([self.preprocessor(img).unsqueeze(0) for img in processed_images]).to(self.device)
            image_features = self.model.encode_image(image_tensors)
            # L2-normalise the embeddings, guarding against a zero norm.
            norm = image_features.norm(p=2, dim=-1, keepdim=True)
            norm[norm == 0] = 1e-6
            im_emb = (image_features / norm).to(device=self.device, dtype=DTYPE_WAIFU)

            predictions = self.mlp(im_emb)
            scores = predictions.clamp(0, 10).cpu().numpy().flatten().tolist()
            return scores[:original_n]
        except Exception as e:
            if self.verbose: print(f"Error during {self.model_display_name} prediction: {e}")
            return [None] * original_n
|
|
|
class AestheticPredictorV25(BaseImageScorer):
    """Aesthetic Predictor V2.5 (SigLIP backbone) built by ``convert_v2_5_from_siglip``.

    On CUDA, the model and its inputs run in bfloat16. Scores are clipped to
    [0, 10] and rounded to four decimals.
    """

    def _load_model(self):
        try:
            if self.verbose:
                print(f"Loading {self.model_display_name}...")
            self.model, self.preprocessor = convert_v2_5_from_siglip(low_cpu_mem_usage=True, trust_remote_code=True)

            self.model = self.model.to(self.device)
            wants_bf16 = self.device == 'cuda' and torch.cuda.is_available() and hasattr(self.model, 'to')
            if wants_bf16:
                self.model = self.model.to(torch.bfloat16)
            self.model.eval()
        except Exception as e:
            if self.verbose:
                print(f"Error loading {self.model_display_name}: {e}")

    @torch.no_grad()
    def predict(self, images: list[Image.Image]) -> list[float | None]:
        """Return one rounded score in [0, 10] per image, or ``None`` for all on failure."""
        if not self.model or not self.preprocessor:
            return [None] * len(images)
        try:
            batch = self.preprocessor(images=images, return_tensors="pt")
            # Follow dummy_param's device when the model has one (the fallback mock does).
            target_device = self.model.dummy_param.device if hasattr(self.model, 'dummy_param') else self.device
            pixels = batch["pixel_values"].to(target_device)

            on_gpu = self.device == 'cuda' and torch.cuda.is_available()
            if on_gpu and pixels.dtype != torch.bfloat16:
                pixels = pixels.to(torch.bfloat16)

            result = self.model(pixels)
            logits = getattr(result, 'logits', result)
            # .float() first: numpy has no bfloat16 dtype.
            raw = logits.squeeze().float().cpu().numpy()
            return [float(np.round(np.clip(value, 0.0, 10.0), 4)) for value in np.atleast_1d(raw)]
        except Exception as e:
            if self.verbose:
                print(f"Error during {self.model_display_name} prediction: {e}")
            return [None] * len(images)
|
|
|
class AnimeAestheticONNX(BaseImageScorer):
    """skytnt/anime-aesthetic ONNX model run through onnxruntime.

    Each image is letterboxed onto a black 768x768 canvas with values in [0, 1]
    and fed to the model as a (1, 3, 768, 768) array under the input name "img".
    The scalar output is multiplied by 10 and clipped to [0, 10].
    """

    def _load_model(self):
        try:
            if self.verbose: print(f"Loading {self.model_display_name} (ONNX)...")
            model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx", cache_dir=CACHE_DIR)
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if self.device == 'cuda' else ['CPUExecutionProvider']
            # Keep only providers this onnxruntime build offers, falling back to CPU.
            valid_providers = [p for p in providers if p in rt.get_available_providers()] or ['CPUExecutionProvider']
            self.model = rt.InferenceSession(model_path, providers=valid_providers)
            if self.verbose: print(f"{self.model_display_name} loaded with providers: {self.model.get_providers()}")
        except Exception as e:
            if self.verbose: print(f"Error loading {self.model_display_name}: {e}")

    def _preprocess_image(self, img: Image.Image) -> np.ndarray:
        """Letterbox ``img`` into a (1, 3, 768, 768) float32 array with values in [0, 1]."""
        if img.mode != "RGB":
            # The canvas has exactly three channels; grayscale/RGBA arrays would not fit.
            img = img.convert("RGB")
        img_np = np.array(img).astype(np.float32) / 255.0
        s = 768
        h, w = img_np.shape[:2]

        # Fit the longer side to exactly `s`. Deriving both sides from a float ratio
        # can truncate the long side to s - 1 and leave a black strip. Clamp the short
        # side to >= 1 so extreme aspect ratios never yield a zero-sized resize target.
        if h >= w:
            new_h, new_w = s, max(1, (s * w) // h)
        else:
            new_h, new_w = max(1, (s * h) // w), s

        # Shrink with area averaging, enlarge with Lanczos.
        interpolation = cv2.INTER_AREA if max(h, w) > s else cv2.INTER_LANCZOS4
        resized = cv2.resize(img_np, (new_w, new_h), interpolation=interpolation)

        canvas = np.zeros((s, s, 3), dtype=np.float32)
        pad_h, pad_w = (s - new_h) // 2, (s - new_w) // 2
        canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized
        # HWC -> CHW, plus a leading batch axis.
        return np.transpose(canvas, (2, 0, 1))[np.newaxis, :]

    def predict(self, images: list[Image.Image]) -> list[float | None]:
        """Score images one at a time; a failing image yields ``None`` without affecting the rest."""
        if not self.model: return [None] * len(images)
        scores = []
        for img in images:
            try:
                input_tensor = self._preprocess_image(img)
                pred = self.model.run(None, {"img": input_tensor})[0].item()
                scores.append(float(np.clip(pred * 10.0, 0.0, 10.0)))
            except Exception as e:
                if self.verbose: print(f"Error predicting with {self.model_display_name} for one image: {e}")
                scores.append(None)
        return scores
|
|
|
class AestheticShadowPipeline(BaseImageScorer):
    """Aesthetic Shadow v2, served through a transformers image-classification pipeline.

    An image's score is the probability of its ``'hq'`` label times 10, clipped
    to [0, 10]. An image whose results lack that label gets ``None``.
    """

    def _load_model(self):
        try:
            if self.verbose:
                print(f"Loading {self.model_display_name} pipeline...")
            # The pipeline takes a GPU ordinal, or -1 for CPU.
            self.model = pipeline(
                "image-classification",
                model="NeoChen1024/aesthetic-shadow-v2-backup",
                device=-1 if self.device != 'cuda' else 0,
            )
        except Exception as e:
            if self.verbose:
                print(f"Error loading {self.model_display_name}: {e}")

    def predict(self, images: list[Image.Image]) -> list[float | None]:
        if not self.model:
            return [None] * len(images)

        def hq_score(label_scores):
            # Per-image extraction; malformed or label-less results map to None.
            try:
                hq_entry = next(entry for entry in label_scores if entry['label'] == 'hq')
                return float(np.clip(hq_entry['score'] * 10.0, 0.0, 10.0))
            except (StopIteration, TypeError, KeyError):
                return None

        try:
            results = self.model(images, top_k=None)
            # A flat list of label dicts (one image's results) is wrapped so the
            # loop below always sees one entry per image.
            if images and results and not isinstance(results[0], list):
                results = [results]
            return [hq_score(label_scores) for label_scores in results]
        except Exception as e:
            if self.verbose:
                print(f"Error during {self.model_display_name} prediction: {e}")
            return [None] * len(images)
|
|
|
|
|
# registry key -> scorer class and display name. Insertion order sets the load
# order, the evaluation order and the score-column order in the results table.
MODEL_REGISTRY = dict(
    aesthetic_shadow={"class": AestheticShadowPipeline, "name": "Aesthetic Shadow"},
    waifu_scorer={"class": WaifuScorerModel, "name": "Waifu Scorer"},
    aesthetic_predictor_v2_5={"class": AestheticPredictorV25, "name": "Aesthetic V2.5"},
    anime_aesthetic={"class": AnimeAestheticONNX, "name": "Anime Score"},
)

# Filled by initialize_models(): registry key -> scorer instance.
LOADED_MODELS = {}
|
|
|
def initialize_models(verbose_loading=False):
    """Instantiate (and thereby load) every scorer listed in ``MODEL_REGISTRY``.

    Each instance is stored in ``LOADED_MODELS`` under its registry key, in
    registry order. The bundled scorers catch their own load errors, so a scorer
    that failed to load is still stored, with ``model`` left as ``None``.
    """
    print(f"Using device: {DEVICE}")
    print("Initializing models...")
    LOADED_MODELS.update(
        (key, spec["class"](key, spec['name'], device=DEVICE, verbose=verbose_loading))
        for key, spec in MODEL_REGISTRY.items()
    )
    print("Model initialization complete.")
|
|
|
|
|
@torch.no_grad()
def auto_tune_batch_size(images: list[Image.Image], selected_model_keys: list[str],
                         initial_bs: int = 1, max_bs_limit: int = 64, verbose: bool = False) -> int:
    """Choose a batch size by doubling until a selected model stops coping.

    Each probe batch is ``images[0]`` repeated ``bs`` times. Probing starts at
    ``initial_bs`` (clamped to >= 1) and doubles while ``bs`` stays within both
    ``len(images)`` and ``max_bs_limit``; the final step is capped at the limit.

    A size fails when a model raises, or when a model that fully scored the
    probe at the first size returns any ``None`` at a larger one. The scorers
    catch their own errors (for example CUDA out-of-memory) and report them only
    as ``None`` placeholders.

    Returns:
        The largest size that succeeded, or the clamped starting size when there
        is nothing to probe.
    """
    # A start of 0 would never grow (0 * 2 == 0) and loop forever.
    start_bs = max(1, int(initial_bs))
    if not images or not selected_model_keys: return start_bs
    if verbose: print("Auto-tuning batch size...")

    test_image = images[0]
    active_models = [LOADED_MODELS[key] for key in selected_model_keys
                     if key in LOADED_MODELS and LOADED_MODELS[key].model]
    if not active_models: return start_bs

    # Indices of models that fully scored the probe at the first size. Only these
    # can signal a failed larger batch; a model that never scores this image is ignored.
    baseline_ok = None
    bs = optimal_bs = start_bs
    while bs <= len(images) and bs <= max_bs_limit:
        batch_test_images = [test_image] * bs
        try:
            fully_scored = set()
            for idx, model in enumerate(active_models):
                if verbose: print(f"  Testing {model.model_display_name} with batch size {bs}")
                scores = model.predict(batch_test_images)
                if len(scores) == bs and all(s is not None for s in scores):
                    fully_scored.add(idx)

            if baseline_ok is None:
                baseline_ok = fully_scored
            elif not baseline_ok <= fully_scored:
                if verbose: print(f"  Batch size {bs} returned missing scores. Optimal so far: {optimal_bs}.")
                break

            optimal_bs = bs
            if bs == max_bs_limit: break
            bs = min(bs * 2, max_bs_limit)
        except Exception as e:
            if verbose: print(f"  Failed at batch size {bs} ({type(e).__name__}). Optimal so far: {optimal_bs}. Error: {str(e)[:100]}")
            break
        finally:
            # Release cached allocations after every probe, including failed (e.g. OOM) ones.
            if DEVICE == 'cuda': torch.cuda.empty_cache()

    if verbose: print(f"Auto-tuned batch size: {optimal_bs}")
    return optimal_bs
|
|
|
async def evaluate_images_core(
    pil_images: list[Image.Image], file_names: list[str],
    selected_model_keys: list[str], batch_size: int,
    progress_tracker: gr.Progress
) -> tuple[pd.DataFrame, list[str]]:
    """Score every image with each selected model and build the results table.

    Args:
        pil_images: Images to score, parallel to ``file_names``.
        file_names: Display name for each image.
        selected_model_keys: ``MODEL_REGISTRY`` keys to run, in order.
        batch_size: Images per model call; coerced to an int >= 1.
        progress_tracker: Called with a fraction in [0, 1] and a ``desc`` label.

    Returns:
        ``(df, logs)``. ``df`` has one row per image with a 150x150 thumbnail, the
        file name, one column per registered model (NaN when not run or failed)
        and 'Final Score': the mean of the selected models' scores, clipped to
        [0, 10]. ``logs`` holds human-readable progress and error lines.
    """
    logs = []
    num_images = len(pil_images)
    if num_images == 0: return pd.DataFrame(), ["No images to process."]

    # batch_size may arrive as a float (gr.Number yields e.g. 4.0) or as 0;
    # range() needs a positive int step.
    batch_size = max(1, int(batch_size))

    model_columns = [cfg['name'] for cfg in MODEL_REGISTRY.values()]
    results_data = []
    for fn, img in zip(file_names, pil_images):
        row = {'File Name': fn, 'Thumbnail': img.resize((150, 150)), 'Final Score': np.nan}
        row.update(dict.fromkeys(model_columns, np.nan))
        results_data.append(row)

    progress_tracker(0, desc="Starting evaluation...")
    total_models_to_run = len(selected_model_keys)

    for model_idx, model_key in enumerate(selected_model_keys):
        model = LOADED_MODELS.get(model_key)
        if not model or not model.model:
            logs.append(f"Skipping {MODEL_REGISTRY[model_key]['name']} (not loaded).")
            continue

        model_name = model.model_display_name
        logs.append(f"Processing with {model_name}...")

        for batch_start_idx in range(0, num_images, batch_size):
            overall_progress = (model_idx + batch_start_idx / num_images) / total_models_to_run
            progress_tracker(overall_progress, desc=f"{model_name} (Batch {batch_start_idx // batch_size + 1})")

            batch_images = pil_images[batch_start_idx:batch_start_idx + batch_size]
            try:
                # Inference is blocking; run it in a worker thread so the event loop
                # is not stalled for the duration of the batch.
                scores = await asyncio.to_thread(model, batch_images)
                if len(scores) != len(batch_images):
                    raise ValueError(f"expected {len(batch_images)} scores, got {len(scores)}")
                for offset, score in enumerate(scores):
                    results_data[batch_start_idx + offset][model_name] = np.nan if score is None else score
            except Exception as e:
                logs.append(f"Error with {model_name} on batch: {e}")

        logs.append(f"Finished with {model_name}.")

    selected_names = [MODEL_REGISTRY[mk]['name'] for mk in selected_model_keys]
    for row in results_data:
        img_scores = [row[name] for name in selected_names if pd.notna(row.get(name))]
        if img_scores:
            row['Final Score'] = float(np.clip(np.mean(img_scores), 0.0, 10.0))

    df = pd.DataFrame(results_data)

    ordered_cols = ['Thumbnail', 'File Name'] + \
                   [name for name in model_columns if name in df.columns] + \
                   ['Final Score']
    df = df[[col for col in ordered_cols if col in df.columns]]

    logs.append("Evaluation complete.")
    progress_tracker(1.0, desc="Evaluation complete.")
    return df, logs
|
|
|
def results_df_to_csv_bytes(df: pd.DataFrame, selected_model_display_names: list[str]) -> bytes | None: |
|
if df.empty: return None |
|
|
|
cols_for_csv = ['File Name', 'Final Score'] + \ |
|
[name for name in selected_model_display_names if name in df.columns and name not in cols_for_csv] |
|
|
|
df_csv = df[cols_for_csv].copy() |
|
for col in df_csv.select_dtypes(include=['float']).columns: |
|
df_csv[col] = df_csv[col].apply(lambda x: f"{x:.4f}" if pd.notnull(x) else "N/A") |
|
|
|
s_io = io.StringIO() |
|
df_csv.to_csv(s_io, index=False) |
|
return s_io.getvalue().encode('utf-8') |
|
|
|
|
|
def create_gradio_interface():
    """Build the Gradio Blocks UI for uploading, scoring and exporting images.

    Returns the un-launched ``gr.Blocks`` app. Scorers are looked up in
    ``LOADED_MODELS`` when events fire, so ``initialize_models`` should run first.
    """
    model_name_choices = [config['name'] for config in MODEL_REGISTRY.values()]

    # Column layout shared by the empty placeholder table and evaluation results.
    initial_df_cols = ['Thumbnail', 'File Name'] + model_name_choices + ['Final Score']
    initial_datatypes = ['image', 'str'] + ['number'] * (len(model_name_choices) + 1)

    with gr.Blocks(theme=gr.themes.Glass()) as demo:
        gr.Markdown("## ✨ Comprehensive Image Evaluation Tool ✨")

        # Full results (including thumbnails), kept for re-scoring and CSV export.
        results_state = gr.State(pd.DataFrame(columns=initial_df_cols))

        with gr.Row():
            with gr.Column(scale=1, min_width=300):
                gr.Markdown("#### Controls")
                files_input = gr.Files(label="Upload Images", file_count="multiple", type="filepath")
                models_checkbox_group = gr.CheckboxGroup(choices=model_name_choices, value=model_name_choices, label="Select Models")

                with gr.Accordion("Batch Settings", open=False):
                    auto_batch_toggle = gr.Checkbox(label="Auto-detect Batch Size", value=True)
                    # precision=0 makes the component return an int instead of a float.
                    manual_batch_input = gr.Number(label="Manual Batch Size", value=4, minimum=1, step=1, precision=0, interactive=False)

                evaluate_button = gr.Button("🚀 Evaluate Images", variant="primary")
                with gr.Row():
                    clear_button = gr.Button("🧹 Clear")
                    download_button = gr.Button("💾 Download CSV")

                # Hidden until a CSV has been generated.
                csv_file_output = gr.File(label="Download CSV File", visible=False)

            with gr.Column(scale=3, min_width=600):
                gr.Markdown("#### Results")
                progress_slider = gr.Slider(label="Progress", minimum=0, maximum=1, value=0, interactive=False)
                results_dataframe = gr.DataFrame(
                    label="Evaluation Scores",
                    headers=initial_df_cols,
                    datatype=initial_datatypes,
                    interactive=True,
                    height=500,
                    wrap=True
                )
                logs_textbox = gr.Textbox(label="Process Logs", lines=5, max_lines=10, interactive=False)

        def map_display_names_to_keys(display_names: list[str]) -> list[str]:
            """Translate checkbox display names into registry keys, in registry order."""
            wanted = set(display_names or [])
            return [key for key, cfg in MODEL_REGISTRY.items() if cfg['name'] in wanted]

        def _upload_path(f_obj) -> str:
            # Depending on the Gradio version, uploads arrive as path strings or as
            # tempfile-like objects exposing the path via `.name`.
            if isinstance(f_obj, (str, os.PathLike)):
                return os.fspath(f_obj)
            return f_obj.name

        def _final_score(row, model_names: list[str]) -> float:
            """Mean of the row's non-missing scores for ``model_names``, clipped to [0, 10]; NaN if none."""
            img_scores = [row[name] for name in model_names if pd.notna(row.get(name))]
            return float(np.clip(np.mean(img_scores), 0.0, 10.0)) if img_scores else np.nan

        async def run_evaluation(uploaded_files, selected_model_names, auto_batch, manual_batch,
                                 current_results_df, progress=gr.Progress(track_tqdm=True)):
            # This is an async generator: every exit must *yield* its update and then
            # use a bare `return` (`return <value>` is a SyntaxError here).
            if not uploaded_files:
                yield {
                    results_state: current_results_df, logs_textbox: "No files uploaded. Please upload images first.",
                    progress_slider: gr.update(value=0, label="Progress")
                }
                return

            yield {logs_textbox: "Loading images...", progress_slider: gr.update(value=0.01, label="Loading images...")}

            pil_images, file_names, skipped = [], [], []
            for f_obj in uploaded_files:
                try:
                    path = _upload_path(f_obj)
                    # The context manager closes the file; convert() fully decodes first.
                    with Image.open(path) as img:
                        rgb_image = img.convert("RGB")
                except Exception as e:
                    print(f"Error loading image {f_obj}: {e}")
                    skipped.append(str(f_obj))
                    continue
                pil_images.append(rgb_image)
                file_names.append(os.path.basename(path))

            if not pil_images:
                yield {logs_textbox: "No valid images could be loaded.", progress_slider: gr.update(value=0, label="Error")}
                return

            selected_keys = map_display_names_to_keys(selected_model_names)

            # Batch sizes must be positive ints; gr.Number can yield floats or None.
            try:
                batch_size_to_use = max(1, int(manual_batch))
            except (TypeError, ValueError):
                batch_size_to_use = 1

            if auto_batch:
                yield {logs_textbox: "Auto-tuning batch size...", progress_slider: gr.update(value=0.1, label="Auto-tuning...")}
                # Probing runs models synchronously; keep it off the event loop.
                batch_size_to_use = await asyncio.to_thread(auto_tune_batch_size, pil_images, selected_keys, verbose=True)
                yield {manual_batch_input: gr.update(value=batch_size_to_use)}

            yield {logs_textbox: f"Starting evaluation with batch size {batch_size_to_use}...",
                   progress_slider: gr.update(value=0.15, label=f"Evaluating (Batch: {batch_size_to_use})...")}

            df_new_results, log_messages = await evaluate_images_core(
                pil_images, file_names, selected_keys, batch_size_to_use, progress
            )
            if skipped:
                log_messages = [f"Skipped {len(skipped)} unreadable file(s)."] + list(log_messages)

            if not df_new_results.empty and 'Final Score' in df_new_results.columns:
                df_new_results = df_new_results.sort_values(by='Final Score', ascending=False, na_position='last')

            yield {
                results_state: df_new_results, results_dataframe: df_new_results,
                logs_textbox: "\n".join(log_messages),
                progress_slider: gr.update(value=1.0, label="Evaluation Complete")
            }

        def clear_all_outputs():
            empty_df = pd.DataFrame(columns=initial_df_cols)
            return {
                results_state: empty_df, results_dataframe: empty_df,
                files_input: None, logs_textbox: "Outputs cleared.",
                progress_slider: gr.update(value=0, label="Progress"),
                csv_file_output: gr.update(value=None, visible=False),
            }

        def download_csv_file(current_df, selected_names):
            if current_df is None or current_df.empty:
                gr.Warning("No results available to download.")
                return gr.update(value=None, visible=False)

            csv_data = results_df_to_csv_bytes(current_df, selected_names)
            if not csv_data:
                # gr.Error only shows a message when raised.
                raise gr.Error("Failed to generate CSV.")

            with tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode='wb') as tmp_f:
                tmp_f.write(csv_data)
            gr.Info("CSV file prepared for download.")
            # The output starts hidden; reveal it together with the file.
            return gr.update(value=tmp_f.name, visible=True)

        def update_final_scores_on_model_select(selected_model_names, current_df):
            # Two outputs are wired to this handler, so every path must supply both.
            if current_df is None or current_df.empty:
                return {results_state: current_df, results_dataframe: current_df}

            df_updated = current_df.copy()
            selected_names = [MODEL_REGISTRY[k]['name'] for k in map_display_names_to_keys(selected_model_names)]
            for i, row in df_updated.iterrows():
                df_updated.loc[i, 'Final Score'] = _final_score(row, selected_names)

            if 'Final Score' in df_updated.columns:
                df_updated = df_updated.sort_values(by='Final Score', ascending=False, na_position='last')

            return {results_state: df_updated, results_dataframe: df_updated}

        auto_batch_toggle.change(lambda x: gr.update(interactive=not x), inputs=auto_batch_toggle, outputs=manual_batch_input)

        evaluate_button.click(
            fn=run_evaluation,
            inputs=[files_input, models_checkbox_group, auto_batch_toggle, manual_batch_input, results_state],
            outputs=[results_state, results_dataframe, logs_textbox, manual_batch_input, progress_slider]
        )
        clear_button.click(
            fn=clear_all_outputs,
            outputs=[results_state, results_dataframe, files_input, logs_textbox, progress_slider, csv_file_output]
        )
        download_button.click(fn=download_csv_file, inputs=[results_state, models_checkbox_group], outputs=csv_file_output)
        models_checkbox_group.change(
            fn=update_final_scores_on_model_select,
            inputs=[models_checkbox_group, results_state],
            outputs=[results_state, results_dataframe]
        )

        # Show an empty table with the expected headers on page load.
        demo.load(lambda: pd.DataFrame(columns=initial_df_cols), outputs=[results_dataframe])

    return demo
|
|
|
if __name__ == "__main__":
    # Load every scorer once up front, then serve the UI behind Gradio's request queue.
    initialize_models(verbose_loading=True)
    app = create_gradio_interface()
    app.queue()
    app.launch(debug=False)