Spaces:
Sleeping
Sleeping
File size: 8,885 Bytes
a3c7b61 81c40fc a3c7b61 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from backend.models import ChatRequest
from backend.llm_utils import sanitize_history, route_message, get_reply
from backend.rag_utils import get_user_data
from backend.models import ChatRequest, SummaryRequest
from backend.llm_utils import sanitize_history, route_message, get_reply, generate_chat_summary
from backend.voice.stt import transcribe_audio
from backend.voice.tts import synthesize_speech
from fastapi import UploadFile, File, Form
from fastapi.responses import StreamingResponse, JSONResponse
import json
import io
import base64
from backend.cache_utils import get_cached_user_data, cache_user_data, cleanup_expired_cache
import json
import os
from backend.credentials import setup_google_credentials
# One-time credential setup at import time, before the ASGI app is built.
setup_google_credentials()

# Fully permissive CORS policy: every origin, HTTP method and request header.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI()
app.add_middleware(CORSMiddleware, **_cors_settings)
@app.post("/chat")
async def chat_endpoint(req: ChatRequest):
    """Answer a text chat message.

    When ``req.uid`` is set, the user's data is loaded with ``get_user_data``
    and passed to ``get_reply`` together with the routed message and the
    sanitized history.

    Returns:
        ``{"reply": str}`` on success, with a canned fallback if the reply is
        empty and a canned apology if routing or reply generation raises.
        ``{"error": "message is required"}`` (HTTP 200) if the message is
        missing, empty, or whitespace-only.
    """
    user_message = req.message
    if not user_message or not user_message.strip():
        return {"error": "message is required"}

    history = req.history or []
    user_id = req.uid

    user_data = {}
    if user_id:
        try:
            # get_user_data is a plain (non-async) call; run it in the default
            # executor so a slow lookup does not stall the event loop for every
            # other in-flight request.
            loop = asyncio.get_running_loop()
            user_data = await loop.run_in_executor(None, get_user_data, user_id)
        except Exception as e:
            # Best effort: reply without user data, but leave a trace.
            print(f"[CHAT] User data fetch failed for {user_id}: {e}")
            user_data = {}

    try:
        route = await route_message(user_message)
        simple_history = sanitize_history(history)
        simple_history.append({"role": "user", "content": user_message})
        reply = await get_reply(route, simple_history, user_data, user_id)
        if not reply:
            reply = "I'm here to help with your wellness journey! What would you like to work on today?"
        return {"reply": reply}
    except Exception as e:
        # The client always gets a friendly message; the real cause is logged.
        print(f"[CHAT] Reply generation failed: {e}")
        return {"reply": "Sorry, I'm having trouble right now. Could you try again in a moment?"}
import time
import asyncio
@app.post("/summarize")
async def summarize_endpoint(req: SummaryRequest):
    """Summarize a chat transcript via ``generate_chat_summary``.

    Returns:
        ``{"summary": str}``. Falls back to ``"New Chat"`` when there are no
        messages or when summary generation raises.
    """
    start = time.perf_counter()

    def elapsed_ms(since: float) -> float:
        # perf_counter() is in seconds; the log lines report milliseconds.
        return (time.perf_counter() - since) * 1000

    try:
        messages = req.messages
        if not messages:
            print(f"[TIMING] Summary - No messages: {elapsed_ms(start):.2f}ms")
            return {"summary": "New Chat"}

        summary_start = time.perf_counter()
        summary = await generate_chat_summary(messages)
        print(f"[TIMING] Summary - Generation: {elapsed_ms(summary_start):.2f}ms")
        print(f"[TIMING] Summary - Total: {elapsed_ms(start):.2f}ms")
        return {"summary": summary}
    except Exception as e:
        print(f"[TIMING] Summary - Error after {elapsed_ms(start):.2f}ms:", e)
        return {"summary": "New Chat"}
@app.post("/voice-chat")
async def voice_chat_endpoint(
    file: UploadFile = File(...),
    history: str = Form(None),
    uid: str = Form(None),
    voice: str = Form("alloy"),
):
    """Run the voice pipeline: transcribe, generate a reply, synthesize speech.

    Form fields:
        file: Recorded audio. It is always passed to ``transcribe_audio`` with
            a ``".m4a"`` suffix.
        history: Optional JSON array of prior messages. The new user turn is
            appended to it as ``{"role": "user", "content": ...}``.
        uid: Optional user id. When set, user data is loaded concurrently with
            transcription.
        voice: Voice name forwarded to ``synthesize_speech``.

    Returns:
        ``{"user_transcript", "reply", "audio_base64"}`` on success, or
        ``{"user_transcript": "", "reply": "I didn't catch that",
        "audio_base64": ""}`` when nothing was transcribed. On any other
        failure, a 500 ``JSONResponse`` with ``{"error": str}``.
    """
    pipeline_start = time.perf_counter()
    # Per-stage durations in milliseconds, printed as the final breakdown.
    stage_ms = {}

    def elapsed_ms(since):
        # perf_counter() is in seconds; every log line reports milliseconds.
        return (time.perf_counter() - since) * 1000

    transcription_task = None
    user_data_task = None
    try:
        # Step 1: read the uploaded audio.
        file_start = time.perf_counter()
        audio_bytes = await file.read()
        stage_ms["File read"] = elapsed_ms(file_start)
        print(f"[TIMING] Voice - File read: {stage_ms['File read']:.2f}ms ({len(audio_bytes)} bytes)")

        # Step 2: start transcription in the background so steps 3-4 overlap it.
        transcription_start = time.perf_counter()
        transcription_task = asyncio.create_task(transcribe_audio(audio_bytes, ".m4a"))

        # Step 3: load user data concurrently with transcription.
        if uid:
            user_data_start = time.perf_counter()
            user_data_task = asyncio.create_task(get_user_data_async(uid))
            print(f"[TIMING] Voice - User data task started: {elapsed_ms(user_data_start):.2f}ms")

        # Step 4: parse the history while transcription runs.
        history_start = time.perf_counter()
        simple_history = json.loads(history) if history else []
        print(f"[TIMING] Voice - History parsing: {elapsed_ms(history_start):.2f}ms ({len(simple_history)} messages)")

        # Step 5: wait for the transcript.
        transcription_wait_start = time.perf_counter()
        user_message = await transcription_task
        stage_ms["Transcription"] = elapsed_ms(transcription_start)
        print(f"[TIMING] Voice - Transcription total: {stage_ms['Transcription']:.2f}ms")
        print(f"[TIMING] Voice - Transcription wait: {elapsed_ms(transcription_wait_start):.2f}ms")
        print("WHISPER transcript:", repr(user_message))
        # Guard against a None transcript as well as an empty or blank one.
        if not user_message or not user_message.strip():
            print(f"[TIMING] Voice - Empty transcript, returning early: {elapsed_ms(pipeline_start):.2f}ms")
            return {"user_transcript": "", "reply": "I didn't catch that", "audio_base64": ""}

        # Step 6: collect user data. A failure degrades to replying without it.
        user_data = {}
        if user_data_task is not None:
            user_data_wait_start = time.perf_counter()
            try:
                user_data = await user_data_task
                print(f"[TIMING] Voice - User data retrieval: {elapsed_ms(user_data_wait_start):.2f}ms")
            except Exception as e:
                print(f"[TIMING] Voice - User data error after {elapsed_ms(user_data_wait_start):.2f}ms: {e}")
                user_data = {}

        # Step 7: route the message.
        simple_history.append({"role": "user", "content": user_message})
        routing_start = time.perf_counter()
        route = await route_message(user_message)
        stage_ms["Routing"] = elapsed_ms(routing_start)
        print(f"[TIMING] Voice - Message routing: {stage_ms['Routing']:.2f}ms (route: {route})")

        # Step 8: generate the text reply.
        reply_start = time.perf_counter()
        reply = await get_reply(route, simple_history, user_data, uid)
        if not reply:
            reply = "I'm here to help with your wellness journey! What would you like to work on today?"
        stage_ms["Reply"] = elapsed_ms(reply_start)
        print(f"[TIMING] Voice - Reply generation: {stage_ms['Reply']:.2f}ms")

        # Step 9: synthesize speech and base64-encode it for the JSON body.
        tts_start = time.perf_counter()
        audio_data = await synthesize_speech(reply, voice)
        stage_ms["TTS"] = elapsed_ms(tts_start)
        print(f"[TIMING] Voice - TTS generation: {stage_ms['TTS']:.2f}ms")

        encoding_start = time.perf_counter()
        base64_audio = base64.b64encode(audio_data).decode()
        print(f"[TIMING] Voice - Base64 encoding: {elapsed_ms(encoding_start):.2f}ms")

        print(f"[TIMING] Voice - TOTAL PIPELINE: {elapsed_ms(pipeline_start):.2f}ms")
        print("[TIMING] Voice - BREAKDOWN:")
        for stage, duration in stage_ms.items():
            print(f" • {stage}: {duration:.2f}ms")

        return {
            "user_transcript": user_message,
            "reply": reply,
            "audio_base64": base64_audio,
        }
    except Exception as e:
        print(f"[TIMING] Voice - ERROR after {elapsed_ms(pipeline_start):.2f}ms:", e)
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        # On an early return or an error, don't leave background tasks running
        # or leave their failures unretrieved ("Task exception was never
        # retrieved" warnings).
        for task in (transcription_task, user_data_task):
            if task is None:
                continue
            if not task.done():
                task.cancel()
            elif not task.cancelled():
                task.exception()  # marks any stored exception as retrieved
# Async wrapper for get_user_data: checks the cache first, then fetches off the event loop.
async def get_user_data_async(uid: str):
    """Return user data for *uid*, preferring the cached copy.

    On a cache miss, the blocking ``get_user_data`` call runs in the default
    executor. This lets a caller that wraps this coroutine in a task overlap
    it with other awaits instead of blocking the event loop.
    """
    start = time.perf_counter()

    cached_data = get_cached_user_data(uid)
    if cached_data:
        print(f"[TIMING] User data (cached): {(time.perf_counter() - start) * 1000:.2f}ms")
        return cached_data

    print("[CACHE] User data cache miss, fetching fresh data...")
    # NOTE(review): cache_user_data is imported but never called in this file.
    # Unless get_user_data populates the cache itself, this path never warms
    # the cache. Confirm.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, get_user_data, uid)
    print(f"[TIMING] User data fetch: {(time.perf_counter() - start) * 1000:.2f}ms")
    return result
@app.get("/cache/stats")
async def cache_stats_endpoint():
    """Return cache statistics, purging expired entries before collecting them."""
    from backend.cache_utils import get_cache_stats

    # cleanup_expired_cache is the module-level import from backend.cache_utils.
    cleanup_expired_cache()
    return get_cache_stats()
@app.post("/cache/clear")
async def clear_cache_endpoint(user_id: str = None):
    """Clear the cache for a single user, or for all users when no id is given.

    NOTE(review): this route declares no authentication dependency, so any
    caller can evict cached data for every user.
    """
    from backend.cache_utils import clear_user_cache

    clear_user_cache(user_id)
    target = f"user {user_id}" if user_id else "all users"
    return {"message": f"Cache cleared for {target}"}
if __name__ == "__main__":
    # Direct-run entry point: serve on all interfaces, port from $PORT (default 3000).
    import uvicorn

    listen_port = int(os.environ.get("PORT", "3000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|