Ehrii committed on
Commit
7eced3d
·
1 Parent(s): 589cfa5

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -52
main.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  from transformers import pipeline, AutoTokenizer
5
  from langdetect import detect, DetectorFactory
@@ -7,52 +7,24 @@ from langdetect import detect, DetectorFactory
7
  # Ensure consistent language detection results
8
  DetectorFactory.seed = 0
9
 
10
- # Set Hugging Face cache directory
11
- os.environ["HF_HOME"] = "/tmp/huggingface_cache"
12
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
13
-
14
- # Create cache directory if it doesn't exist
15
- cache_dir = os.environ["HF_HOME"]
16
- os.makedirs(cache_dir, exist_ok=True)
17
-
18
- # Retrieve Hugging Face token from environment variable
19
- HF_TOKEN = os.getenv("HF_TOKEN")
20
- if not HF_TOKEN:
21
- raise RuntimeError("Hugging Face token is missing! Please set the HF_TOKEN environment variable.")
22
-
23
- # Set the Hugging Face token in the environment variable
24
- os.environ["HUGGINGFACE_HUB_TOKEN"] = HF_TOKEN
25
 
26
  app = FastAPI()
27
 
28
- # Model names
29
- MULTILINGUAL_MODEL_NAME = "Ehrii/sentiment"
30
- MULTILINGUAL_TOKENIZER_NAME = "tabularisai/multilingual-sentiment-analysis"
31
- ENGLISH_MODEL_NAME = "siebert/sentiment-roberta-large-english"
32
-
33
- # Load multilingual sentiment model
34
- try:
35
- multilingual_tokenizer = AutoTokenizer.from_pretrained(
36
- MULTILINGUAL_TOKENIZER_NAME,
37
- cache_dir=cache_dir
38
- )
39
 
40
- multilingual_model = pipeline(
41
- "sentiment-analysis",
42
- model=MULTILINGUAL_MODEL_NAME,
43
- tokenizer=multilingual_tokenizer
44
- )
45
- except Exception as e:
46
- raise RuntimeError(f"Failed to load multilingual model: {e}")
47
 
48
- # Load English sentiment model
49
- try:
50
- english_model = pipeline(
51
- "sentiment-analysis",
52
- model=ENGLISH_MODEL_NAME
53
- )
54
- except Exception as e:
55
- raise RuntimeError(f"Failed to load English sentiment model: {e}")
56
 
57
  class SentimentRequest(BaseModel):
58
  text: str
@@ -64,7 +36,6 @@ class SentimentResponse(BaseModel):
64
  confidence_score: float
65
 
66
  def detect_language(text):
67
- """Detect the language of the given text."""
68
  try:
69
  return detect(text)
70
  except Exception:
@@ -76,17 +47,15 @@ def home():
76
 
77
  @app.post("/analyze/", response_model=SentimentResponse)
78
  def analyze_sentiment(request: SentimentRequest):
79
- text = request.text.strip()
80
- if not text:
81
- raise HTTPException(status_code=400, detail="Text input cannot be empty.")
82
-
83
  language = detect_language(text)
84
-
85
- # Use English model if detected language is English; otherwise, use multilingual model
86
- model = english_model if language == "en" else multilingual_model
87
 
88
- result = model(text)
89
-
 
 
 
 
90
  return SentimentResponse(
91
  original_text=text,
92
  language_detected=language,
 
1
  import os
2
+ from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  from transformers import pipeline, AutoTokenizer
5
  from langdetect import detect, DetectorFactory
 
7
  # Ensure consistent language detection results
8
  DetectorFactory.seed = 0
9
 
10
+ # Set Hugging Face cache directory to a writable location
11
+ os.environ["HF_HOME"] = "/tmp/huggingface"
12
+ os.makedirs(os.environ["HF_HOME"], exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  app = FastAPI()
15
 
16
+ # Load the original tokenizer from the base model
17
+ original_tokenizer = AutoTokenizer.from_pretrained("tabularisai/multilingual-sentiment-analysis")
 
 
 
 
 
 
 
 
 
18
 
19
+ # Load the fine-tuned model and pass the tokenizer explicitly
20
+ multilingual_model = pipeline(
21
+ "sentiment-analysis",
22
+ model="Ehrii/sentiment",
23
+ tokenizer=original_tokenizer
24
+ )
 
25
 
26
+ # English model remains unchanged
27
+ english_model = pipeline("sentiment-analysis", model="siebert/sentiment-roberta-large-english")
 
 
 
 
 
 
28
 
29
  class SentimentRequest(BaseModel):
30
  text: str
 
36
  confidence_score: float
37
 
38
  def detect_language(text):
 
39
  try:
40
  return detect(text)
41
  except Exception:
 
47
 
48
  @app.post("/analyze/", response_model=SentimentResponse)
49
  def analyze_sentiment(request: SentimentRequest):
50
+ text = request.text
 
 
 
51
  language = detect_language(text)
 
 
 
52
 
53
+ # Choose the appropriate model based on language
54
+ if language == "en":
55
+ result = english_model(text)
56
+ else:
57
+ result = multilingual_model(text)
58
+
59
  return SentimentResponse(
60
  original_text=text,
61
  language_detected=language,