# Hugging Face Spaces UI header captured during extraction ("Spaces: Running");
# kept as a comment so the file remains valid Python.
import os

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from langdetect import detect, DetectorFactory

# Seed langdetect's factory so repeated detection on the same text is
# deterministic (langdetect is otherwise non-deterministic).
DetectorFactory.seed = 0

# Route all Hugging Face downloads to a writable tmp cache (the default
# home directory on Spaces is not writable). TRANSFORMERS_CACHE is
# deprecated in recent transformers releases in favor of HF_HOME; both
# are set here for compatibility with older versions.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"

# Create the cache directory if it doesn't exist.
cache_dir = os.environ["HF_HOME"]
os.makedirs(cache_dir, exist_ok=True)

# A Hugging Face access token is required to download the models below;
# fail fast at startup rather than at first request.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("Hugging Face token is missing! Please set the HF_TOKEN environment variable.")

app = FastAPI()
# Model identifiers on the Hugging Face Hub. The multilingual setup pairs
# model weights from one repo with a tokenizer from another, so the
# tokenizer is loaded explicitly.
MULTILINGUAL_MODEL_NAME = "Ehrii/sentiment"
MULTILINGUAL_TOKENIZER_NAME = "tabularisai/multilingual-sentiment-analysis"
ENGLISH_MODEL_NAME = "siebert/sentiment-roberta-large-english"

# Load the multilingual sentiment model.
try:
    multilingual_tokenizer = AutoTokenizer.from_pretrained(
        MULTILINGUAL_TOKENIZER_NAME,
        token=HF_TOKEN,  # `use_auth_token` is deprecated; `token` is the current name
        cache_dir=cache_dir,
    )
    multilingual_model = pipeline(
        "sentiment-analysis",
        model=MULTILINGUAL_MODEL_NAME,
        tokenizer=multilingual_tokenizer,
        token=HF_TOKEN,
    )
except Exception as e:
    # Re-raise with context so startup failures are easy to diagnose in logs.
    raise RuntimeError(f"Failed to load multilingual model: {e}") from e

# Load the English-only sentiment model.
try:
    english_model = pipeline(
        "sentiment-analysis",
        model=ENGLISH_MODEL_NAME,
        token=HF_TOKEN,
    )
except Exception as e:
    raise RuntimeError(f"Failed to load English sentiment model: {e}") from e
class SentimentRequest(BaseModel):
    """Request payload: the raw text to run sentiment analysis on."""

    text: str
class SentimentResponse(BaseModel):
    """Response payload: echo of the input plus detection results."""

    # original_text: the stripped input text that was analyzed
    original_text: str
    # language_detected: ISO 639-1 code from langdetect, or "unknown"
    language_detected: str
    # sentiment: lower-cased label emitted by the model
    sentiment: str
    # confidence_score: model score for the predicted label
    confidence_score: float
def detect_language(text):
    """Best-effort language detection.

    Returns the ISO 639-1 code reported by langdetect, or "unknown" when
    detection fails (e.g. empty or undecidable input).
    """
    try:
        detected = detect(text)
    except Exception:
        # Deliberate best-effort: a detection failure should not fail the
        # request — fall back to the multilingual model downstream.
        return "unknown"
    return detected
def home():
    """Health-check endpoint body.

    NOTE(review): no route decorator (e.g. ``@app.get("/")``) is visible
    here — it may have been lost in extraction; confirm the route is
    actually registered on the app.
    """
    status_message = "Sentiment Analysis API is running!"
    return {"message": status_message}
def analyze_sentiment(request: SentimentRequest) -> SentimentResponse:
    """Classify the sentiment of the request text.

    Routes English text to the English-only model and everything else to
    the multilingual model. Raises HTTP 400 for empty/whitespace input.

    NOTE(review): no route decorator (e.g. ``@app.post("/analyze")``) is
    visible here — confirm the endpoint is registered on the app.
    """
    text = request.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Text input cannot be empty.")

    language = detect_language(text)
    model = english_model if language == "en" else multilingual_model

    # truncation=True clips inputs longer than the model's maximum sequence
    # length; without it, over-length text raises at inference time.
    result = model(text, truncation=True)
    top = result[0]
    return SentimentResponse(
        original_text=text,
        language_detected=language,
        sentiment=top["label"].lower(),
        confidence_score=top["score"],
    )