yummyu commited on
Commit
7b7a628
·
verified ·
1 Parent(s): 870565f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from transformers import pipeline
4
+ import os
5
+ from typing import Optional
6
+
7
+ # FastAPIアプリの初期化
8
+ app = FastAPI(
9
+ title="Hugging Face API on Spaces",
10
+ description="Hugging Face Transformersを使ったAPI",
11
+ version="1.0.0"
12
+ )
13
+
14
+ # リクエスト用のPydanticモデル
15
+ class TextRequest(BaseModel):
16
+ text: str
17
+ max_length: Optional[int] = 100
18
+
19
+ class SentimentResponse(BaseModel):
20
+ text: str
21
+ sentiment: str
22
+ confidence: float
23
+ model_name: str
24
+
25
+ class GenerateResponse(BaseModel):
26
+ input_text: str
27
+ generated_text: str
28
+ model_name: str
29
+
30
+ # グローバル変数でモデルを保持
31
+ sentiment_classifier = None
32
+ text_generator = None
33
+
34
+ @app.on_event("startup")
35
+ async def load_models():
36
+ """アプリ起動時にモデルをロード"""
37
+ global sentiment_classifier, text_generator
38
+
39
+ print("モデルをロード中...")
40
+
41
+ try:
42
+ # 感情分析モデル(軽量版を使用)
43
+ sentiment_classifier = pipeline(
44
+ "sentiment-analysis",
45
+ model="cardiffnlp/twitter-roberta-base-sentiment-latest"
46
+ )
47
+
48
+ # テキスト生成モデル(軽量版)
49
+ text_generator = pipeline(
50
+ "text-generation",
51
+ model="distilgpt2" # GPT-2より軽量
52
+ )
53
+
54
+ print("✅ モデルのロードが完了しました")
55
+
56
+ except Exception as e:
57
+ print(f"❌ モデルロードエラー: {e}")
58
+
59
+ @app.get("/")
60
+ async def root():
61
+ """ヘルスチェック用エンドポイント"""
62
+ return {
63
+ "message": "🤗 Hugging Face API is running on Spaces!",
64
+ "status": "healthy",
65
+ "endpoints": ["/sentiment", "/generate", "/models"]
66
+ }
67
+
68
+ @app.post("/sentiment", response_model=SentimentResponse)
69
+ async def analyze_sentiment(request: TextRequest):
70
+ """感情分析エンドポイント"""
71
+ try:
72
+ if sentiment_classifier is None:
73
+ raise HTTPException(
74
+ status_code=503,
75
+ detail="Sentiment model not loaded. Please try again later."
76
+ )
77
+
78
+ result = sentiment_classifier(request.text)
79
+
80
+ return SentimentResponse(
81
+ text=request.text,
82
+ sentiment=result[0]["label"],
83
+ confidence=round(result[0]["score"], 4),
84
+ model_name="cardiffnlp/twitter-roberta-base-sentiment-latest"
85
+ )
86
+
87
+ except Exception as e:
88
+ raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
89
+
90
+ @app.post("/generate", response_model=GenerateResponse)
91
+ async def generate_text(request: TextRequest):
92
+ """テキスト生成エンドポイント"""
93
+ try:
94
+ if text_generator is None:
95
+ raise HTTPException(
96
+ status_code=503,
97
+ detail="Text generation model not loaded. Please try again later."
98
+ )
99
+
100
+ # Spacesの制限を考慮して短めに設定
101
+ max_length = min(request.max_length, 150)
102
+
103
+ result = text_generator(
104
+ request.text,
105
+ max_length=max_length,
106
+ num_return_sequences=1,
107
+ temperature=0.7,
108
+ do_sample=True,
109
+ pad_token_id=text_generator.tokenizer.eos_token_id
110
+ )
111
+
112
+ generated_text = result[0]["generated_text"]
113
+
114
+ return GenerateResponse(
115
+ input_text=request.text,
116
+ generated_text=generated_text,
117
+ model_name="distilgpt2"
118
+ )
119
+
120
+ except Exception as e:
121
+ raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
122
+
123
+ @app.get("/models")
124
+ async def get_models():
125
+ """利用可能なモデル情報を取得"""
126
+ return {
127
+ "sentiment_analysis": {
128
+ "model": "cardiffnlp/twitter-roberta-base-sentiment-latest",
129
+ "status": "loaded" if sentiment_classifier else "not loaded"
130
+ },
131
+ "text_generation": {
132
+ "model": "distilgpt2",
133
+ "status": "loaded" if text_generator else "not loaded"
134
+ },
135
+ "platform": "Hugging Face Spaces"
136
+ }
137
+
138
+ @app.get("/health")
139
+ async def health_check():
140
+ """詳細なヘルスチェック"""
141
+ return {
142
+ "status": "healthy",
143
+ "models": {
144
+ "sentiment": sentiment_classifier is not None,
145
+ "generation": text_generator is not None
146
+ },
147
+ "memory_usage": "optimized for Spaces"
148
+ }
149
+
150
+ # Spaces用の追加設定
151
+ if __name__ == "__main__":
152
+ import uvicorn
153
+ # Spacesでは通常ポート7860を使用
154
+ port = int(os.environ.get("PORT", 7860))
155
+ uvicorn.run(app, host="0.0.0.0", port=port)