import os
import random
from pathlib import Path
from statistics import mean
from typing import Iterator, Union, Any

import fasttext
import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import logging
from toolz import concat, groupby, valmap

logger = logging.get_logger(__name__)

load_dotenv()

DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"
def load_model(repo_id: str) -> fasttext.FastText._FastText:
    """Download `model.bin` from the given Hub repo and load it with fastText."""
    model_path = hf_hub_download(repo_id, filename="model.bin")
    return fasttext.load_model(model_path)
def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
    """Yield non-empty lines from `rows`, skipping anything shorter than `min_length`."""
    for row in rows:
        if isinstance(row, str):
            # Split on newlines and drop empty or very short lines
            for line in row.split("\n"):
                if line and len(line) >= min_length:
                    yield line
        elif isinstance(row, list):
            try:
                line = " ".join(row)
                if len(line) >= min_length:
                    yield line
            except TypeError:
                continue
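
# Illustrative example (not part of the app logic): a single multi-line string is
# split into its individual non-trivial lines, e.g.
#
#     >>> list(yield_clean_rows(["Hello world\n\nHola mundo"]))
#     ['Hello world', 'Hola mundo']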
FASTTEXT_PREFIX_LENGTH = 9  # fasttext labels are formatted like "__label__eng_Latn"

# Load the model
Path("code/models").mkdir(parents=True, exist_ok=True)
model = fasttext.load_model(
    hf_hub_download(
        "facebook/fasttext-language-identification",
        "model.bin",
        cache_dir="code/models",
        local_dir="code/models",
        local_dir_use_symlinks=False,
    )
)
def model_predict(inputs: str, k: int = 1) -> list[dict[str, float]]:
    """Return the top-k language predictions for a single line of text."""
    predictions = model.predict(inputs, k=k)
    return [
        # Cast to float so the scores serialize cleanly to JSON
        {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": float(prob)}
        for label, prob in zip(predictions[0], predictions[1])
    ]
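
# Illustrative example (scores are approximate and depend on the downloaded model):
#
#     >>> model_predict("Bonjour tout le monde", k=2)
#     [{'label': 'fra_Latn', 'score': 0.99}, {'label': '...', 'score': ...}]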
def get_label(x):
    return x.get("label")


def get_mean_score(preds):
    return mean([pred.get("score") for pred in preds])
def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2) -> set:
    """Return the keys whose count is at least `threshold_percent` of the total count."""
    total = sum(counts_dict.values())
    threshold = total * threshold_percent
    return {k for k, v in counts_dict.items() if v >= threshold}
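
# Illustrative example: with counts {'eng_Latn': 8, 'fra_Latn': 1} and the default
# threshold of 0.2, the cutoff is 9 * 0.2 = 1.8, so only 'eng_Latn' is kept:
#
#     >>> filter_by_frequency({'eng_Latn': 8, 'fra_Latn': 1})
#     {'eng_Latn'}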
def simple_predict(text, num_predictions=3):
    """Simple language detection function for Gradio interface"""
    if not text or not text.strip():
        return {"error": "Please enter some text for language detection."}
    try:
        # Clean the text
        cleaned_lines = list(yield_clean_rows([text]))
        if not cleaned_lines:
            return {"error": "No valid text found after cleaning."}

        # Get predictions for each line (cast the slider value to int for fastText's k)
        all_predictions = []
        for line in cleaned_lines:
            predictions = model_predict(line, k=int(num_predictions))
            all_predictions.extend(predictions)

        if not all_predictions:
            return {"error": "No predictions could be made."}

        # Group predictions by language
        predictions_by_lang = groupby(get_label, all_predictions)
        language_counts = valmap(len, predictions_by_lang)

        # Calculate average scores for each language
        language_scores = valmap(get_mean_score, predictions_by_lang)

        # Format results
        results = {
            "detected_languages": dict(language_scores),
            "language_counts": dict(language_counts),
            "total_predictions": len(all_predictions),
            "text_lines_analyzed": len(cleaned_lines),
        }
        return results
    except Exception as e:
        return {"error": f"Error during prediction: {str(e)}"}
def batch_predict(text, threshold_percent=0.2):
    """More advanced prediction with frequency-based filtering"""
    if not text or not text.strip():
        return {"error": "Please enter some text for language detection."}
    try:
        # Clean the text
        cleaned_lines = list(yield_clean_rows([text]))
        if not cleaned_lines:
            return {"error": "No valid text found after cleaning."}

        # Get predictions
        predictions = [model_predict(line) for line in cleaned_lines]
        predictions = [pred for pred in predictions if pred is not None]
        predictions = list(concat(predictions))
        if not predictions:
            return {"error": "No predictions could be made."}

        # Group and filter
        predictions_by_lang = groupby(get_label, predictions)
        language_counts = valmap(len, predictions_by_lang)
        keys_to_keep = filter_by_frequency(language_counts, threshold_percent=threshold_percent)
        filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}

        results = {
            "predictions": dict(valmap(get_mean_score, filtered_dict)),
            "all_language_counts": dict(language_counts),
            "filtered_languages": list(keys_to_keep),
            "threshold_used": threshold_percent,
        }
        return results
    except Exception as e:
        return {"error": f"Error during prediction: {str(e)}"}
def build_demo_interface():
    app_title = "Language Detection Tool"
    with gr.Blocks(title=app_title) as demo:
        gr.Markdown(f"# {app_title}")
        gr.Markdown("Enter text below to detect the language(s) it contains.")

        with gr.Tab("Simple Detection"):
            with gr.Row():
                with gr.Column():
                    text_input1 = gr.Textbox(
                        label="Enter text for language detection",
                        placeholder="Type or paste your text here...",
                        lines=5,
                    )
                    num_predictions = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        label="Number of top predictions per line",
                    )
                    predict_btn1 = gr.Button("Detect Language")
                with gr.Column():
                    output1 = gr.JSON(label="Detection Results")

            predict_btn1.click(
                simple_predict,
                inputs=[text_input1, num_predictions],
                outputs=output1,
            )

        with gr.Tab("Advanced Detection"):
            with gr.Row():
                with gr.Column():
                    text_input2 = gr.Textbox(
                        label="Enter text for advanced language detection",
                        placeholder="Type or paste your text here...",
                        lines=5,
                    )
                    threshold = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.2,
                        step=0.1,
                        label="Threshold percentage for filtering",
                    )
                    predict_btn2 = gr.Button("Advanced Detect")
                with gr.Column():
                    output2 = gr.JSON(label="Advanced Detection Results")

            predict_btn2.click(
                batch_predict,
                inputs=[text_input2, threshold],
                outputs=output2,
            )

        gr.Markdown("### About")
        gr.Markdown(
            "This tool uses Facebook's FastText language identification model to detect languages in text."
        )

    return demo
if __name__ == "__main__":
    demo = build_demo_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
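
# Illustrative client-side usage once the app is running locally (python app.py serves
# it on port 7860). This sketch assumes Gradio's default API naming, where each click
# handler is exposed under its function name:
#
#     from gradio_client import Client
#     client = Client("http://localhost:7860")
#     result = client.predict("Hello world, bonjour le monde", 3, api_name="/simple_predict")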