# Hugging Face Space by ajsbsd — revision e31e6bd (verified)
# app.py
import os
import gradio as gr
import spaces
from datasets import load_dataset
from transformers import MarianTokenizer, MarianMTModel
from huggingface_hub import login
# Point the transformers download cache at /tmp so model files can be
# written even in read-only Space containers.
os.environ['TRANSFORMERS_CACHE'] = "/tmp/cache/models"
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)

# Authenticate with the Hugging Face Hub when a token is configured so that
# private datasets/models become reachable; degrade with a warning otherwise.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    try:
        login(token=hf_token)
        print("✅ Authenticated with Hugging Face Hub")
    except Exception as e:
        print(f"[ERROR] Hugging Face login failed: {e}")
else:
    print("[WARNING] HF_TOKEN not found — dataset may be inaccessible")
# Corpus of Estonian/English legal sentence pairs; may be private,
# hence the login above.
DATASET_NAME = "ajsbsd/legalese-sentences_estonian-english"
try:
    ds = load_dataset(DATASET_NAME)
    print("✅ Dataset loaded successfully")
except Exception as e:
    print(f"[ERROR] Failed to load dataset '{DATASET_NAME}': {e}")
    ds = None  # downstream code checks for this and reports an error
# Pretrained Estonian→English MarianMT translation model.
MODEL_NAME = "Helsinki-NLP/opus-mt-et-en"
try:
    tokenizer = MarianTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    model = MarianMTModel.from_pretrained(MODEL_NAME, use_safetensors=True)
    print("✅ Translation model loaded successfully")
except Exception as e:
    print(f"[ERROR] Failed to load model '{MODEL_NAME}': {e}")
    # Leave both unset so translate() can report a clean error instead of
    # crashing on a half-initialized pipeline.
    tokenizer = None
    model = None
def get_sample(split="train", index=0):
    """Fetch one sentence pair from the dataset.

    Args:
        split: Dataset split name ("train" or "test").
        index: Zero-based row index. Gradio's Number component may hand
            this over as a float, so it is coerced to int before indexing.

    Returns:
        dict: ``{"Estonian": ..., "English": ...}`` on success, or
        ``{"error": message}`` when the dataset is unavailable or the
        lookup fails.
    """
    print(f"[DEBUG] Loading sample from '{split}' split, index {index}")
    if ds is None:
        return {"error": "Dataset not loaded or unavailable."}
    try:
        # gr.Number can deliver a float even with precision=0; datasets
        # indexing requires a real int.
        index = int(index)
        sample = ds[split][index]
        return {
            "Estonian": sample["input"],
            "English": sample["output"],
        }
    except IndexError:
        return {"error": f"Index {index} out of range for '{split}' split"}
    except Exception as e:
        # Covers unknown split names (KeyError) and malformed rows.
        return {"error": str(e)}
@spaces.GPU
def translate(text):
    """Translate Estonian legal text to English.

    Args:
        text: Estonian source text (may span multiple sentences).

    Returns:
        str: English translation, or an "[Error] ..." message when the
        model is unavailable or generation fails.
    """
    if not tokenizer or not model:
        return "[Error] Translation model failed to load."
    try:
        # transformers already depends on torch; imported locally so the
        # module keeps importing even when translation is never used.
        import torch

        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        # Follow the model onto whatever device it lives on — a no-op on
        # CPU, required when @spaces.GPU has placed the model on CUDA.
        inputs = inputs.to(model.device)
        with torch.inference_mode():  # no autograd bookkeeping for generation
            outputs = model.generate(**inputs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"[Error] Translation failed: {str(e)}"
# --- Gradio UI ---
# Dataset Explorer tab: choose a split and row index, inspect the raw pair.
split_selector = gr.Dropdown(choices=["train", "test"], label="Split", value="train")
index_selector = gr.Number(label="Sample Index", value=0, precision=0)
dataset_interface = gr.Interface(
    fn=get_sample,
    inputs=[split_selector, index_selector],
    outputs=gr.JSON(label="Sample"),
    title="📚 Dataset Explorer",
    description="Browse Estonian–English legal sentence pairs.",
    flagging_mode="never",
)
# Translation tab: free-text Estonian in, English out.
source_box = gr.Textbox(lines=5, placeholder="Enter Estonian legal text...", label="Estonian Input")
target_box = gr.Textbox(label="English Translation")
translate_interface = gr.Interface(
    fn=translate,
    inputs=source_box,
    outputs=target_box,
    title="🔤 Translate Legal Text",
    description="Translate Estonian legal sentences to English using Helsinki-NLP/opus-mt-et-en",
    flagging_mode="never",
)
# Combined two-tab app.
demo = gr.TabbedInterface(
    interface_list=[dataset_interface, translate_interface],
    tab_names=["📚 Dataset Explorer", "🔤 Translation Demo"],
)

# Hugging Face Spaces serves the module-level `demo` itself.  In modern
# Gradio, Blocks only gains an `.app` attribute after launch(), so reading
# it eagerly raised AttributeError at import time and crashed the Space;
# fall back to `demo` to keep the alias defined either way.
app = getattr(demo, "app", demo)

if __name__ == "__main__":
    # NOTE(review): share=True is ignored (with a warning) when running on
    # Spaces; it only matters for local tunneled runs.
    demo.launch(share=True)