tuan243 commited on
Commit
433d5e2
·
verified ·
1 Parent(s): 4a48a1c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +99 -62
main.py CHANGED
@@ -1,26 +1,47 @@
1
- from fastapi import FastAPI
2
  from fastapi.responses import JSONResponse
3
  import firebase_admin
4
  from firebase_admin import credentials, firestore
5
- from transformers import pipeline
6
  from pydantic import BaseModel
7
  import os
8
  from huggingface_hub import login
9
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # AutoModelForCausalLM,
 
 
 
 
10
 
11
  # Load Firebase
12
- cred = credentials.Certificate("firebase_config.json")
13
- firebase_admin.initialize_app(cred)
14
- db = firestore.client()
 
 
 
 
 
15
 
16
  # Đăng nhập vào Hugging Face (nếu cần)
17
  HF_TOKEN = os.getenv("HF_TOKEN")
18
  if HF_TOKEN:
19
- login(HF_TOKEN)
20
-
21
- # Load AI Model
22
- tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-large",token=HF_TOKEN)
23
- ai_model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-large",token=HF_TOKEN)
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  app = FastAPI()
26
 
@@ -36,75 +57,91 @@ class FocusHistoryRequest(BaseModel):
36
 
37
  class AIRequest(BaseModel):
38
  user_id: str
 
39
 
40
  # API cập nhật bios
41
  @app.post("/update_bios")
42
  async def update_bios(request: UpdateBiosRequest):
43
- user_ref = db.collection("user_profiles").document(request.user_id)
44
- user_ref.set({"bios": request.bios}, merge=True)
45
- return {"message": "Cập nhật bios thành công"}
 
 
 
 
46
 
47
  # API thêm lịch sử focus
48
  @app.post("/add_focus_history")
49
  async def add_focus_history(request: FocusHistoryRequest):
50
- user_ref = db.collection("user_profiles").document(request.user_id)
51
- user_doc = user_ref.get()
52
- data = user_doc.to_dict() or {}
53
- focus_history = data.get("focus_history", [])
54
- focus_history.append({"time_start": request.time_start, "total_time": request.total_time})
55
- user_ref.set({"focus_history": focus_history}, merge=True)
56
- return {"message": "Thêm lịch sử focus thành công"}
 
 
 
 
57
 
58
  # API lấy dữ liệu người dùng
59
  @app.get("/get_user_data")
60
  async def get_user_data(user_id: str):
61
- user_doc = db.collection("user_profiles").document(user_id).get()
62
- data = user_doc.to_dict() or {}
63
- return {
64
- "bios": data.get("bios", "Chưa có bios."),
65
- "focus_history": data.get("focus_history", [])
66
- }
 
 
 
 
67
 
68
  # API AI tư vấn
69
  @app.post("/ai_personal_advice")
70
  async def ai_personal_advice(request: AIRequest):
71
- user_doc = db.collection("user_profiles").document(request.user_id).get()
72
- data = user_doc.to_dict() or {}
73
- bios = data.get("bios", "Chưa bios.")
74
- focus_history = data.get("focus_history", [])
75
- focus_text = "\n".join([f"- {f['time_start']}: {f['total_time']} phút" for f in focus_history])
76
-
77
- prompt = f"""
78
- Thông tin người dùng:
79
- - Bios: {bios}
80
- - Lịch sử focus:
81
- {focus_text}
82
-
83
- Hãy vấn cách cải thiện hiệu suất làm việc dựa trên thông tin trên.
84
- """
85
- input_ids = tokenizer(prompt, return_tensors="pt")
86
- response = ai_model.generate(**input_ids, max_new_tokens=500)
87
- # response = ai_model(prompt, max_length=200)
88
- return {"advice": tokenizer.decode(response[0], skip_special_tokens=True)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  # Trang chủ
91
  @app.get("/")
92
  async def home():
93
  return JSONResponse(content={"message": "Welcome to the Recommendation API!"})
94
-
95
-
96
-
97
- # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
98
-
99
- # # Kiểm tra biến môi trường HF_TOKEN
100
- # import os
101
- # HF_TOKEN = os.getenv("HF_TOKEN")
102
- # print("HF_TOKEN:", HF_TOKEN is not None)
103
-
104
- # # Load model
105
- # model_name = "google/flan-t5-small"
106
- # tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
107
- # ai_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=HF_TOKEN)
108
-
109
- # print("Model loaded successfully!")
110
-
 
1
+ from fastapi import FastAPI, HTTPException
2
  from fastapi.responses import JSONResponse
3
  import firebase_admin
4
  from firebase_admin import credentials, firestore
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  from pydantic import BaseModel
7
  import os
8
  from huggingface_hub import login
9
+ import traceback
10
+ import logging
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
14
 
15
  # Load Firebase
16
+ try:
17
+ cred = credentials.Certificate("firebase_config.json")
18
+ firebase_admin.initialize_app(cred)
19
+ db = firestore.client()
20
+ logging.info("Firebase initialized successfully.")
21
+ except Exception as e:
22
+ logging.error(f"Error initializing Firebase: {e}")
23
+ db = None
24
 
25
  # Đăng nhập vào Hugging Face (nếu cần)
26
  HF_TOKEN = os.getenv("HF_TOKEN")
27
  if HF_TOKEN:
28
+ try:
29
+ login(token=HF_TOKEN) # Pass HF_TOKEN as keyword argument
30
+ logging.info("Hugging Face login successful.")
31
+ except Exception as e:
32
+ logging.error(f"Error logging into Hugging Face: {e}")
33
+
34
+ # Initialize tokenizer and model outside the request handler
35
+ tokenizer = None
36
+ ai_model = None
37
+
38
+ try:
39
+ # Load AI Model
40
+ tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-large") # Removed token argument here
41
+ ai_model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-large") # Removed token argument here
42
+ logging.info("Tokenizer and model loaded successfully.")
43
+ except Exception as e:
44
+ logging.error(f"Error loading tokenizer/model: {e}")
45
 
46
  app = FastAPI()
47
 
 
57
 
58
  class AIRequest(BaseModel):
59
  user_id: str
60
+ bios: str
61
 
62
  # API cập nhật bios
63
  @app.post("/update_bios")
64
  async def update_bios(request: UpdateBiosRequest):
65
+ try:
66
+ user_ref = db.collection("user_profiles").document(request.user_id)
67
+ user_ref.set({"bios": request.bios}, merge=True)
68
+ return {"message": "Cập nhật bios thành công"}
69
+ except Exception as e:
70
+ logging.error(f"Error in /update_bios: {e}\n{traceback.format_exc()}")
71
+ raise HTTPException(status_code=500, detail=f"Error updating bios: {e}")
72
 
73
  # API thêm lịch sử focus
74
  @app.post("/add_focus_history")
75
  async def add_focus_history(request: FocusHistoryRequest):
76
+ try:
77
+ user_ref = db.collection("user_profiles").document(request.user_id)
78
+ user_doc = user_ref.get()
79
+ data = user_doc.to_dict() or {}
80
+ focus_history = data.get("focus_history", [])
81
+ focus_history.append({"time_start": request.time_start, "total_time": request.total_time})
82
+ user_ref.set({"focus_history": focus_history}, merge=True)
83
+ return {"message": "Thêm lịch sử focus thành công"}
84
+ except Exception as e:
85
+ logging.error(f"Error in /add_focus_history: {e}\n{traceback.format_exc()}")
86
+ raise HTTPException(status_code=500, detail=f"Error adding focus history: {e}")
87
 
88
  # API lấy dữ liệu người dùng
89
  @app.get("/get_user_data")
90
  async def get_user_data(user_id: str):
91
+ try:
92
+ user_doc = db.collection("user_profiles").document(user_id).get()
93
+ data = user_doc.to_dict() or {}
94
+ return {
95
+ "bios": data.get("bios", "Chưa có bios."),
96
+ "focus_history": data.get("focus_history", [])
97
+ }
98
+ except Exception as e:
99
+ logging.error(f"Error in /get_user_data: {e}\n{traceback.format_exc()}")
100
+ raise HTTPException(status_code=500, detail=f"Error getting user data: {e}")
101
 
102
  # API AI tư vấn
103
  @app.post("/ai_personal_advice")
104
  async def ai_personal_advice(request: AIRequest):
105
+ try:
106
+ if tokenizer is None or ai_model is None:
107
+ logging.error("Tokenizer or AI model not loaded.")
108
+ raise HTTPException(status_code=500, detail="Tokenizer or AI model not loaded.")
109
+
110
+ if db is None:
111
+ logging.error("Firebase not initialized.")
112
+ raise HTTPException(status_code=500, detail="Firebase not initialized.")
113
+
114
+ user_doc = db.collection("user_profiles").document(request.user_id).get()
115
+ if not user_doc.exists:
116
+ logging.warning(f"User profile not found for user_id: {request.user_id}")
117
+ return {"advice": "Không tìm thấy thông tin người dùng."}
118
+
119
+ data = user_doc.to_dict() or {}
120
+ bios = request.bios if request.bios else data.get("bios", "Chưa có bios.")
121
+ focus_history = data.get("focus_history", [])
122
+ focus_text = "\n".join([f"- {f['time_start']}: {f['total_time']} phút" for f in focus_history])
123
+
124
+ prompt = f"""
125
+ Thông tin người dùng:
126
+ - Bios: {bios}
127
+ - Lịch sử focus:
128
+ {focus_text}
129
+
130
+ Hãy tư vấn cách cải thiện hiệu suất làm việc dựa trên thông tin trên.
131
+ """
132
+ input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
133
+ response = ai_model.generate(**input_ids, max_new_tokens=500)
134
+ advice = tokenizer.decode(response[0], skip_special_tokens=True)
135
+
136
+ return {"advice": advice}
137
+
138
+ except Exception as e:
139
+ error_message = f"Error in /ai_personal_advice: {e}"
140
+ logging.error(error_message)
141
+ logging.error(traceback.format_exc())
142
+ raise HTTPException(status_code=500, detail=error_message)
143
 
144
  # Trang chủ
145
  @app.get("/")
146
  async def home():
147
  return JSONResponse(content={"message": "Welcome to the Recommendation API!"})