Spaces:
Running
Running
| import os | |
| import random | |
| from statistics import mean | |
| from typing import Iterator, Union, Any | |
| import fasttext | |
| import gradio as gr | |
| from dotenv import load_dotenv | |
| from httpx import Client, Timeout | |
| from huggingface_hub import hf_hub_download | |
| from huggingface_hub.utils import logging | |
| from toolz import concat, groupby, valmap | |
| from fastapi import FastAPI | |
| from httpx import AsyncClient | |
| from pathlib import Path | |
# FastAPI application object created at import time; it is not referenced
# again by any code visible in this file.
app = FastAPI()
# Module logger, obtained through the huggingface_hub logging utilities.
logger = logging.get_logger(__name__)
# Presumably pulls settings from a local .env file into the environment
# (python-dotenv) -- confirm which variables the rest of the app relies on.
load_dotenv()
# Hub repo id of a fastText language-ID model; no active code visible in this
# file reads it.
DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"
def load_model(repo_id: str) -> fasttext.FastText._FastText:
    """Fetch ``model.bin`` from a Hub repository and load it with fastText.

    Args:
        repo_id: Hugging Face Hub repository id that hosts a ``model.bin`` file.

    Returns:
        The fastText model loaded from the downloaded file.
    """
    local_file = hf_hub_download(repo_id, filename="model.bin")
    loaded = fasttext.load_model(local_file)
    return loaded
def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
    """Yield cleaned text lines extracted from raw column values.

    Each element of ``rows`` is handled according to its type:

    * ``str``: split on ``"\\n"`` and every non-empty line is yielded.
    * ``list``: its items are joined with a single space and the result is
      yielded when it is at least ``min_length`` characters long. Lists that
      contain non-string items are skipped.
    * anything else is ignored.

    Args:
        rows: An iterable of row values, or a single string, which is treated
            as one row rather than being iterated character by character.
        min_length: Minimum length of a joined list row. It is not applied to
            lines that come from string rows.

    Yields:
        The cleaned text lines, in input order.
    """
    if isinstance(rows, str):
        # Iterating a bare string would walk it one character at a time and
        # yield single letters; wrap it so it is processed as a single row.
        rows = [rows]
    for row in rows:
        if isinstance(row, str):
            # Split on newlines and drop the empty pieces.
            yield from (piece for piece in row.split("\n") if piece)
        elif isinstance(row, list):
            try:
                joined = " ".join(row)
            except TypeError:
                # The list holds something other than strings; skip the row.
                continue
            if len(joined) >= min_length:
                yield joined
# Number of leading characters to strip from a fastText label: labels are
# formatted like "__label__eng_Latn", and "__label__" is nine characters long.
FASTTEXT_PREFIX_LENGTH = 9
# Alternative loader, currently disabled:
# model = load_model(DEFAULT_FAST_TEXT_MODEL)
_MODEL_DIR = "code/models"
# The download below writes into this directory, so create it first.
Path(_MODEL_DIR).mkdir(parents=True, exist_ok=True)
_model_file = hf_hub_download(
    "facebook/fasttext-language-identification",
    "model.bin",
    cache_dir=_MODEL_DIR,
    local_dir=_MODEL_DIR,
    local_dir_use_symlinks=False,
)
# Module-level fastText model used by model_predict.
model = fasttext.load_model(_model_file)
def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
    """Run the module-level fastText model on ``inputs``.

    Args:
        inputs: Text passed to ``model.predict``.
        k: Forwarded to ``model.predict`` as ``k``.

    Returns:
        One ``{"label": ..., "score": ...}`` dict per predicted label. The
        label has its first ``FASTTEXT_PREFIX_LENGTH`` characters removed and
        the score is the value returned by the model, unchanged.
    """
    raw = model.predict(inputs, k=k)
    labels, scores = raw[0], raw[1]
    results = []
    for raw_label, score in zip(labels, scores):
        # Drop the "__label__" prefix so only the language code remains.
        results.append({"label": raw_label[FASTTEXT_PREFIX_LENGTH:], "score": score})
    return results
def get_label(x):
    """Return the ``"label"`` entry of a prediction mapping, or ``None`` if absent."""
    label = x.get("label")
    return label
def get_mean_score(preds):
    """Return the arithmetic mean of the ``"score"`` entries in ``preds``.

    Args:
        preds: A sequence of prediction mappings, each with a ``"score"`` key.

    Returns:
        The mean score, as computed by ``statistics.mean``.
    """
    scores = [entry.get("score") for entry in preds]
    return mean(scores)
def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
    """Return the keys whose count is at least a given share of the total.

    Args:
        counts_dict: Mapping from key to a numeric count.
        threshold_percent: Fraction of the summed counts (``0.2`` means 20%)
            that a key's count must reach. The comparison is inclusive.

    Returns:
        A set with the keys of ``counts_dict`` that meet the threshold.
    """
    cutoff = sum(counts_dict.values()) * threshold_percent
    kept = set()
    for key, count in counts_dict.items():
        if count >= cutoff:
            kept.add(key)
    return kept
def predict_rows(rows, target_column, language_threshold_percent=0.2):
    """Predict languages for one column of ``rows`` and summarise the result.

    Args:
        rows: Iterable of mappings; ``target_column`` is read from each with
            ``.get`` and ``None`` values are dropped.
        target_column: Key of the text column to analyse.
        language_threshold_percent: Passed to ``filter_by_frequency``; labels
            that account for a smaller share of all predictions are left out
            of the summary.

    Returns:
        A dict with two entries: ``"predictions"`` maps each retained label
        to the mean score of its predictions, and ``"pred"`` is the flat list
        of every individual prediction.
    """
    column_values = [row.get(target_column) for row in rows]
    texts = list(yield_clean_rows(value for value in column_values if value is not None))
    per_text = [model_predict(text) for text in texts]
    # Each model_predict call returns a list; flatten them into one list.
    flat_predictions = list(concat(pred for pred in per_text if pred is not None))
    grouped = groupby(get_label, flat_predictions)
    label_counts = valmap(len, grouped)
    frequent_labels = filter_by_frequency(
        label_counts, threshold_percent=language_threshold_percent
    )
    mean_scores = {
        label: get_mean_score(label_preds)
        for label, label_preds in grouped.items()
        if label in frequent_labels
    }
    return {
        "predictions": mean_scores,
        "pred": flat_predictions,
    }
async def predict_language(
    hub_id: str,
    config: str | None = None,
    split: str | None = None,
    max_request_calls: int = 10,
    number_of_rows: int = 1000,
) -> dict[Any, Any]:
    """Predict the language(s) of a Hub dataset from a sample of its rows.

    The first name in ``TARGET_COLUMN_NAMES`` that is also a dataset feature
    is used as the text column. Rows are fetched with ``get_random_rows`` and
    scored with ``predict_rows``.

    Args:
        hub_id: Dataset id on the Hugging Face Hub.
        config: Dataset config name. When empty, both ``config`` and ``split``
            are taken from ``get_first_config_and_split_name``.
        split: Split name to sample from.
        max_request_calls: Forwarded to ``get_random_rows``.
        number_of_rows: Forwarded to ``get_random_rows``.

    Returns:
        The dict produced by ``predict_rows``, extended with ``"hub_id"``,
        ``"config"`` and ``"split"``.

    Raises:
        gr.Error: If the dataset is not accessible, has no dataset info, the
            split is unknown, or none of the target columns is present.
    """
    # NOTE(review): the helpers and TARGET_COLUMN_NAMES used here are not
    # defined in the visible part of this file -- confirm they are provided.
    # NOTE(review): called without ``await``, unlike the other helpers;
    # presumably synchronous -- confirm.
    is_valid = datasets_server_valid_rows(hub_id)
    if not is_valid:
        # gr.Error is an exception: it has to be raised to reach the user.
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    if not config:
        config, split = await get_first_config_and_split_name(hub_id)
    info = await get_dataset_info(hub_id, config)
    if info is None:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    dataset_info = info.get("dataset_info")
    if not dataset_info:
        # Without dataset info there are no splits or features to inspect.
        raise gr.Error(f"Dataset {hub_id} has no dataset info on the datasets server.")
    split_info = (dataset_info.get("splits") or {}).get(split)
    if split_info is None:
        raise gr.Error(f"Split {split} was not found for dataset {hub_id}.")
    total_rows_for_split = split_info.get("num_examples")
    features = dataset_info.get("features") or {}
    column_names = set(features.keys())
    logger.info(f"Column names: {column_names}")
    # The first target column present in the dataset wins.
    target_column = next(
        (column for column in TARGET_COLUMN_NAMES if column in column_names), None
    )
    if target_column is None:
        raise gr.Error(
            f"Dataset {hub_id} {column_names} is not in any of the target columns {TARGET_COLUMN_NAMES}"
        )
    logger.info(f"Using column {target_column} for language detection")
    random_rows = await get_random_rows(
        hub_id,
        total_rows_for_split,
        number_of_rows,
        max_request_calls,
        config,
        split,
    )
    logger.info(f"Predicting language for {len(random_rows)} rows")
    predictions = predict_rows(random_rows, target_column)
    predictions["hub_id"] = hub_id
    predictions["config"] = config
    predictions["split"] = split
    return predictions
app_title = "Language Detection"
# gr.Interface hands the input components to the function positionally, so
# these two boxes feed predict_language's first two parameters: ``hub_id``
# and ``config``.
inputs = [
    gr.Textbox(
        None,
        label="enter content",
    ),
    # Previously labelled "split", but this value arrives as ``config``.
    gr.Textbox(None, label="config"),
]
interface = gr.Interface(
    predict_language,
    inputs=inputs,
    outputs="json",
    title=app_title,
    # article=app_description,
)
# Enable request queuing, then start the app at import time.
interface.queue()
interface.launch()