yummyu commited on
Commit
091616b
·
verified ·
1 Parent(s): bb40a26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -104
app.py CHANGED
@@ -1,20 +1,22 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from transformers import pipeline
4
  import os
5
  from typing import Optional
6
 
7
- # FastAPIアプリの初期化
8
- app = FastAPI(
9
- title="Hugging Face API on Spaces",
10
- description="Hugging Face Transformersを使ったAPI",
11
- version="1.0.0"
12
- )
 
 
 
 
13
 
14
- # リクエスト用のPydanticモデル
15
  class TextRequest(BaseModel):
16
  text: str
17
- max_length: Optional[int] = 100
18
 
19
  class SentimentResponse(BaseModel):
20
  text: str
@@ -27,62 +29,59 @@ class GenerateResponse(BaseModel):
27
  generated_text: str
28
  model_name: str
29
 
30
- # グローバル変数でモデルを保持
31
  sentiment_classifier = None
32
  text_generator = None
33
 
34
  @app.on_event("startup")
35
  async def load_models():
36
- """アプリ起動時にモデルをロード"""
37
  global sentiment_classifier, text_generator
38
 
39
- print("モデルをロード中...")
40
 
41
  try:
42
- # 感情分析モデル(軽量版を使用)
43
- print("Loading sentiment analysis model...")
 
 
44
  sentiment_classifier = pipeline(
45
  "sentiment-analysis",
46
- model="cardiffnlp/twitter-roberta-base-sentiment-latest",
47
- return_all_scores=False
48
  )
49
- print("✅ Sentiment model loaded")
50
 
51
- # テキスト生成モデル(軽量版)
52
- print("Loading text generation model...")
53
  text_generator = pipeline(
54
  "text-generation",
55
- model="distilgpt2",
56
- pad_token_id=50256 # GPT-2のEOSトークンID
57
  )
58
- print("✅ Text generation model loaded")
59
 
60
- print("✅ 全てのモデルのロードが完了しました")
61
 
62
  except Exception as e:
63
- print(f"❌ モデルロードエラー: {e}")
64
- print(f"エラーの詳細: {type(e).__name__}")
65
  import traceback
66
  traceback.print_exc()
67
 
68
  @app.get("/")
69
  async def root():
70
- """ヘルスチェック用エンドポイント"""
71
  return {
72
- "message": "🤗 Hugging Face API is running on Spaces!",
73
  "status": "healthy",
74
- "endpoints": ["/sentiment", "/generate", "/models"]
75
  }
76
 
77
  @app.post("/sentiment", response_model=SentimentResponse)
78
  async def analyze_sentiment(request: TextRequest):
79
- """感情分析エンドポイント"""
80
  try:
81
  if sentiment_classifier is None:
82
- raise HTTPException(
83
- status_code=503,
84
- detail="Sentiment model not loaded. Please try again later."
85
- )
86
 
87
  result = sentiment_classifier(request.text)
88
 
@@ -90,7 +89,7 @@ async def analyze_sentiment(request: TextRequest):
90
  text=request.text,
91
  sentiment=result[0]["label"],
92
  confidence=round(result[0]["score"], 4),
93
- model_name="cardiffnlp/twitter-roberta-base-sentiment-latest"
94
  )
95
 
96
  except Exception as e:
@@ -98,98 +97,46 @@ async def analyze_sentiment(request: TextRequest):
98
 
99
  @app.post("/generate", response_model=GenerateResponse)
100
  async def generate_text(request: TextRequest):
101
- """テキスト生成エンドポイント"""
102
  try:
103
  if text_generator is None:
104
- raise HTTPException(
105
- status_code=503,
106
- detail="Text generation model not loaded. Please try again later."
107
- )
108
 
109
- # 入力テキストの検証
110
- if not request.text or len(request.text.strip()) == 0:
111
- raise HTTPException(status_code=400, detail="Text cannot be empty")
112
-
113
- # Spacesの制限を考慮して短めに設定
114
  max_length = min(request.max_length, 100)
115
- input_length = len(request.text.split())
116
 
117
- # 入力より長い出力を生成するように調整
118
- if max_length <= input_length:
119
- max_length = input_length + 20
 
 
 
 
 
120
 
121
- try:
122
- result = text_generator(
123
- request.text,
124
- max_length=max_length,
125
- num_return_sequences=1,
126
- temperature=0.7,
127
- do_sample=True,
128
- truncation=True,
129
- pad_token_id=text_generator.tokenizer.eos_token_id
130
- )
131
-
132
- generated_text = result[0]["generated_text"]
133
-
134
- return GenerateResponse(
135
- input_text=request.text,
136
- generated_text=generated_text,
137
- model_name="distilgpt2"
138
- )
139
-
140
- except Exception as model_error:
141
- # モデル固有のエラーをキャッチ
142
- print(f"Model error: {model_error}")
143
- raise HTTPException(
144
- status_code=500,
145
- detail=f"Model processing failed: {str(model_error)}"
146
- )
147
 
148
- except HTTPException:
149
- # HTTPExceptionは再発生
150
- raise
151
  except Exception as e:
152
- print(f"Unexpected error: {e}")
153
  raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
154
 
155
  @app.get("/models")
156
  async def get_models():
157
- """利用可能なモデル情報を取得"""
158
  return {
159
  "sentiment_analysis": {
160
- "model": "cardiffnlp/twitter-roberta-base-sentiment-latest",
161
  "status": "loaded" if sentiment_classifier else "not loaded"
162
  },
163
  "text_generation": {
164
- "model": "distilgpt2",
165
  "status": "loaded" if text_generator else "not loaded"
166
  },
167
- "platform": "Hugging Face Spaces"
168
- }
169
-
170
- @app.get("/debug")
171
- async def debug_info():
172
- """デバッグ情報を取得"""
173
- import sys
174
- import torch
175
-
176
- return {
177
- "python_version": sys.version,
178
- "torch_version": torch.__version__ if 'torch' in sys.modules else "not installed",
179
- "models_loaded": {
180
- "sentiment": sentiment_classifier is not None,
181
- "generator": text_generator is not None
182
- },
183
- "generator_tokenizer": {
184
- "vocab_size": text_generator.tokenizer.vocab_size if text_generator else None,
185
- "eos_token_id": text_generator.tokenizer.eos_token_id if text_generator else None,
186
- "pad_token_id": text_generator.tokenizer.pad_token_id if text_generator else None
187
- } if text_generator else None
188
  }
189
 
190
- # Spaces用の追加設定
191
  if __name__ == "__main__":
192
  import uvicorn
193
- # Spacesでは通常ポート7860を使用
194
  port = int(os.environ.get("PORT", 7860))
195
  uvicorn.run(app, host="0.0.0.0", port=port)
 
1
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
from typing import Optional

# Hugging Face Spaces containers are only writable under /app, so point
# every HF cache variable at a directory we are allowed to create.
cache_dir = "/app/.cache"
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir

# Make sure the cache directory exists before any model download starts.
os.makedirs(cache_dir, exist_ok=True)

app = FastAPI(title="Lightweight Hugging Face API")
16
 
 
17
class TextRequest(BaseModel):
    """Request body shared by the /sentiment and /generate endpoints."""

    # Input text: analyzed directly by /sentiment, used as the prompt by /generate.
    text: str
    # Requested maximum generated length; /generate caps it server-side at 100.
    max_length: Optional[int] = 50
20
 
21
  class SentimentResponse(BaseModel):
22
  text: str
 
29
  generated_text: str
30
  model_name: str
31
 
32
+ # グローバル変数
33
# Pipeline singletons; populated by load_models() at startup.
# They stay None (and the endpoints answer 503) if loading fails.
sentiment_classifier = None
text_generator = None
35
 
36
@app.on_event("startup")
async def load_models():
    """Load the lightweight sentiment and text-generation pipelines at startup.

    Best-effort: on any failure the traceback is printed and the app keeps
    serving, with the affected globals left as None so the endpoints can
    answer 503 instead of crashing the process.
    """
    global sentiment_classifier, text_generator

    print("🚀 軽量モデルのロード開始...")

    try:
        # Imported lazily so a missing/broken transformers install is caught
        # by the same handler as a failed model download.
        from transformers import pipeline

        # Tiny checkpoints keep Spaces memory/startup cost low.
        # NOTE(review): prajjwal1/bert-tiny has no sentiment fine-tune —
        # its classification head is effectively untrained; confirm intent.
        print("📥 軽量感情分析モデルをロード中...")
        sentiment_classifier = pipeline(
            "sentiment-analysis",
            model="prajjwal1/bert-tiny",
            cache_dir="/app/.cache",
        )
        print("✅ 感情分析モデル読み込み完了")

        print("📥 軽量テキスト生成モデルをロード中...")
        text_generator = pipeline(
            "text-generation",
            model="sshleifer/tiny-gpt2",
            cache_dir="/app/.cache",
        )
        print("✅ テキスト生成モデル読み込み完了")

        print("✅ 全てのモデル読み込み完了")

    except Exception as e:
        print(f"❌ モデル読み込みエラー: {e}")
        import traceback
        traceback.print_exc()
70
 
71
@app.get("/")
async def root():
    """Health-check endpoint; reports that the service is up."""
    payload = {
        "message": "🤗 Lightweight Hugging Face API is running!",
        "status": "healthy",
        "models": "lightweight versions",
    }
    return payload
78
 
79
@app.post("/sentiment", response_model=SentimentResponse)
async def analyze_sentiment(request: TextRequest):
    """Run lightweight sentiment analysis on the request text.

    Returns label + confidence; 503 while the model is still loading,
    500 on an inference failure.
    """
    # Raise the not-loaded guard OUTSIDE the try block: previously the 503
    # HTTPException was caught by `except Exception` below and re-raised
    # as a misleading 500.
    if sentiment_classifier is None:
        raise HTTPException(status_code=503, detail="Sentiment model not loaded")

    try:
        result = sentiment_classifier(request.text)

        return SentimentResponse(
            text=request.text,
            sentiment=result[0]["label"],
            confidence=round(result[0]["score"], 4),
            model_name="prajjwal1/bert-tiny",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Sentiment analysis failed: {str(e)}")
 
97
 
98
@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: TextRequest):
    """Generate text from the request prompt with the tiny GPT-2 model.

    Returns the generated continuation; 503 while the model is still
    loading, 500 on an inference failure.
    """
    # Guard kept OUTSIDE the try block: previously the 503 HTTPException was
    # caught by `except Exception` below and re-raised as a misleading 500.
    if text_generator is None:
        raise HTTPException(status_code=503, detail="Generation model not loaded")

    # max_length is Optional[int]; an explicit `"max_length": null` used to
    # crash min(None, 100) with a TypeError (-> 500). Fall back to the
    # field's default of 50 and cap at 100 for Spaces resource limits.
    requested = request.max_length if request.max_length is not None else 50
    max_length = min(requested, 100)

    try:
        result = text_generator(
            request.text,
            max_length=max_length,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            # tiny-gpt2 has no pad token; reuse EOS to silence the warning.
            pad_token_id=text_generator.tokenizer.eos_token_id,
        )

        return GenerateResponse(
            input_text=request.text,
            generated_text=result[0]["generated_text"],
            model_name="sshleifer/tiny-gpt2",
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
124
 
125
@app.get("/models")
async def get_models():
    """Report the configured models and whether each one finished loading."""

    def _status(pipe):
        # A pipeline global is None until load_models() succeeds for it.
        return "loaded" if pipe else "not loaded"

    return {
        "sentiment_analysis": {
            "model": "prajjwal1/bert-tiny",
            "status": _status(sentiment_classifier),
        },
        "text_generation": {
            "model": "sshleifer/tiny-gpt2",
            "status": _status(text_generator),
        },
        "note": "Using lightweight models for Spaces compatibility",
    }
138
 
 
139
if __name__ == "__main__":
    import uvicorn

    # Hugging Face Spaces routes traffic to port 7860 unless PORT overrides it.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))