import gradio as gr
from transformers import pipeline, set_seed
import re
import numpy as np
import pandas as pd

# Set a seed for reproducibility
set_seed(42)

# Define five small models for generation (free, lightweight)
small_models = [
    "distilgpt2",                 # ~82M parameters
    "gpt2",                       # ~124M parameters
    "EleutherAI/gpt-neo-125M",    # ~125M parameters
    "sshleifer/tiny-gpt2",        # extremely small GPT-2 variant
    "microsoft/DialoGPT-small"    # DialoGPT in its small size
]

# Define five languages (English, German, Spanish, French, Portuguese)
languages = {
    "en": "English",
    "de": "German",
    "es": "Spanish",
    "fr": "French",
    "pt": "Portuguese"
}

# Define two cost-effective grammar evaluation models
grammar_model_names = [
    "vennify/t5-base-grammar-correction",
    "hassaanik/grammar-correction-model"
]

# Functions to load pipelines on demand
def load_generation_pipeline(model_name):
    try:
        # Use the text-generation pipeline for causal LM models
        return pipeline("text-generation", model=model_name)
    except Exception as e:
        print(f"Error loading generation model {model_name}: {e}")
        return None

def load_grammar_pipeline(model_name):
    try:
        return pipeline("text2text-generation", model=model_name)
    except Exception as e:
        print(f"Error loading grammar model {model_name}: {e}")
        return None

# Pre-load grammar evaluator pipelines
rater_models = []
for model_name in grammar_model_names:
    p = load_grammar_pipeline(model_name)
    if p is not None:
        rater_models.append(p)

# Utility functions for cleaning text and checking palindromes
def clean_text(text):
    # Keep only alphanumeric characters, lowercased
    return re.sub(r'[^a-zA-Z0-9]', '', text.lower())

def is_palindrome(text):
    cleaned = clean_text(text)
    return cleaned == cleaned[::-1]

def grammar_prompt(pal, lang):
    return f'''Rate from 0 to 100 how grammatically correct this palindrome is in {lang}. Only return a number with no explanation:\n\n"{pal}"\n'''

def extract_score(text):
    # Grab the first 1-3 digit number in the rater's output and clamp it to [0, 100]
    match = re.search(r"\d{1,3}", text)
    if match:
        score = int(match.group())
        return min(max(score, 0), 100)
    return 0

# Main benchmark function that runs all tests at once
def run_benchmark_all():
    results = []
    for model_name in small_models:
        # Load the generation pipeline for the current small model
        gen_pipeline = load_generation_pipeline(model_name)
        if gen_pipeline is None:
            continue  # Skip this model if it fails to load
        for code, lang in languages.items():
            # Prompt for generating a palindrome in the given language
            prompt = (
                f"Write the longest original palindrome you can in {lang}. "
                "It should be creative and not a known palindrome. "
                "If it is not a correct palindrome, you will lose points according to how correct it is."
            )
            try:
                gen_output = gen_pipeline(
                    prompt,
                    max_new_tokens=50,
                    do_sample=True,
                    return_full_text=False,  # score only the continuation, not the prompt
                )[0]['generated_text'].strip()
            except Exception as e:
                gen_output = f"Error generating text: {e}"

            valid = is_palindrome(gen_output)
            cleaned_len = len(clean_text(gen_output))

            # Score grammar with both rater models
            scores = []
            for rater in rater_models:
                rprompt = grammar_prompt(gen_output, lang)
                try:
                    rtext = rater(rprompt, max_new_tokens=10)[0]['generated_text']
                    scores.append(extract_score(rtext))
                except Exception:
                    scores.append(0)
            avg_score = np.mean(scores) if scores else 0

            # Halve the grammar weighting if the text is not a valid palindrome
            penalty = (avg_score / 100) if valid else (avg_score / 100) * 0.5
            final_score = round(cleaned_len * penalty, 2)

            results.append({
                "Model": model_name,
                "Language": lang,
                "Palindrome": gen_output,
                "Valid": "✅" if valid else "❌",
                "Length": cleaned_len,
                "Grammar Score": avg_score,
                "Final Score": final_score
            })

    df = pd.DataFrame(results).sort_values(by="Final Score", ascending=False).reset_index(drop=True)
    return df

# Build Gradio UI using Blocks (canvas layout)
with gr.Blocks(title="Small Model Palindrome Benchmark") as demo:
    gr.Markdown("# Small Model Palindrome Benchmark")
    gr.Markdown(
        "This benchmark evaluates 5 small text-generation models across 5 languages "
        "(English, German, Spanish, French, Portuguese). All tests run in a single pass "
        "when you click the button."
    )
    with gr.Row():
        run_button = gr.Button("Run All Benchmarks")
    output_table = gr.Dataframe(label="Benchmark Results")
    run_button.click(fn=run_benchmark_all, inputs=[], outputs=output_table)

demo.launch()