Update app.py
app.py CHANGED
@@ -1,21 +1,25 @@
-
 import os
 import sys
 import json
 import logging
 import re
 import hashlib
+import io
+import base64
 from datetime import datetime
 from typing import List, Dict, Optional, Tuple
 from enum import Enum
 
-from fastapi import FastAPI, HTTPException
-from fastapi.responses import StreamingResponse
+from fastapi import FastAPI, HTTPException, UploadFile, File, Query
+from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 import asyncio
-from fastapi import Query
 from bson import ObjectId
+import speech_recognition as sr
+from gtts import gTTS
+from pydub import AudioSegment
+from pydub.playback import play
 from txagent.txagent import TxAgent
 from db.mongo import get_mongo_client
 
@@ -24,7 +28,7 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
 logger = logging.getLogger("TxAgentAPI")
 
 # App
-app = FastAPI(title="TxAgent API", version="2.
+app = FastAPI(title="TxAgent API", version="2.3.0")  # Updated version for voice support
 
 app.add_middleware(
     CORSMiddleware,
@@ -32,7 +36,7 @@ app.add_middleware(
     allow_methods=["*"], allow_headers=["*"]
 )
 
-# Pydantic
+# Pydantic Models
 class ChatRequest(BaseModel):
     message: str
     temperature: float = 0.7
@@ -40,6 +44,16 @@ class ChatRequest(BaseModel):
     history: Optional[List[Dict]] = None
     format: Optional[str] = "clean"
 
+class VoiceInputRequest(BaseModel):
+    audio_format: str = "wav"
+    language: str = "en-US"
+
+class VoiceOutputRequest(BaseModel):
+    text: str
+    language: str = "en"
+    slow: bool = False
+    return_format: str = "mp3"  # mp3 or base64
+
 # Enums
 class RiskLevel(str, Enum):
     NONE = "none"
@@ -243,6 +257,39 @@ async def analyze_all_patients():
         await analyze_patient(patient)
         await asyncio.sleep(0.1)
 
+def recognize_speech(audio_data: bytes, language: str = "en-US") -> str:
+    """Convert speech to text using Google's speech recognition"""
+    recognizer = sr.Recognizer()
+
+    try:
+        # Convert bytes to AudioFile
+        with io.BytesIO(audio_data) as audio_file:
+            with sr.AudioFile(audio_file) as source:
+                audio = recognizer.record(source)
+                text = recognizer.recognize_google(audio, language=language)
+                return text
+    except sr.UnknownValueError:
+        logger.error("Google Speech Recognition could not understand audio")
+        raise HTTPException(status_code=400, detail="Could not understand audio")
+    except sr.RequestError as e:
+        logger.error(f"Could not request results from Google Speech Recognition service; {e}")
+        raise HTTPException(status_code=503, detail="Speech recognition service unavailable")
+    except Exception as e:
+        logger.error(f"Error in speech recognition: {e}")
+        raise HTTPException(status_code=500, detail="Error processing speech")
+
+def text_to_speech(text: str, language: str = "en", slow: bool = False) -> bytes:
+    """Convert text to speech using gTTS and return as MP3 bytes"""
+    try:
+        tts = gTTS(text=text, lang=language, slow=slow)
+        mp3_fp = io.BytesIO()
+        tts.write_to_fp(mp3_fp)
+        mp3_fp.seek(0)
+        return mp3_fp.read()
+    except Exception as e:
+        logger.error(f"Error in text-to-speech conversion: {e}")
+        raise HTTPException(status_code=500, detail="Error generating speech")
+
 @app.on_event("startup")
 async def startup_event():
     global agent, patients_collection, analysis_collection, alerts_collection
@@ -276,7 +323,8 @@ async def status():
     return {
         "status": "running",
         "timestamp": datetime.utcnow().isoformat(),
-        "version": "2.
+        "version": "2.3.0",
+        "features": ["chat", "voice-input", "voice-output", "patient-analysis"]
     }
 
 @app.get("/patients/analysis-results")
@@ -342,3 +390,89 @@ async def chat_stream_endpoint(request: ChatRequest):
             yield f"⚠️ Error: {e}"
 
     return StreamingResponse(token_stream(), media_type="text/plain")
+
+@app.post("/voice/transcribe")
+async def transcribe_voice(
+    audio: UploadFile = File(...),
+    language: str = Query("en-US", description="Language code for speech recognition")
+):
+    """Convert speech to text"""
+    try:
+        # Read audio file
+        audio_data = await audio.read()
+
+        # Validate audio format
+        if not audio.filename.lower().endswith(('.wav', '.mp3', '.ogg', '.flac')):
+            raise HTTPException(status_code=400, detail="Unsupported audio format")
+
+        # Convert speech to text
+        text = recognize_speech(audio_data, language)
+
+        return {"text": text}
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error in voice transcription: {e}")
+        raise HTTPException(status_code=500, detail="Error processing voice input")
+
+@app.post("/voice/synthesize")
+async def synthesize_voice(request: VoiceOutputRequest):
+    """Convert text to speech"""
+    try:
+        # Generate speech from text
+        audio_data = text_to_speech(request.text, request.language, request.slow)
+
+        if request.return_format == "base64":
+            # Return as base64 encoded string
+            return {"audio": base64.b64encode(audio_data).decode('utf-8')}
+        else:
+            # Return as MP3 file
+            return StreamingResponse(
+                io.BytesIO(audio_data),
+                media_type="audio/mpeg",
+                headers={"Content-Disposition": "attachment; filename=speech.mp3"}
+            )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error in voice synthesis: {e}")
+        raise HTTPException(status_code=500, detail="Error generating voice output")
+
+@app.post("/voice/chat")
+async def voice_chat_endpoint(
+    audio: UploadFile = File(...),
+    language: str = Query("en-US", description="Language code for speech recognition"),
+    temperature: float = Query(0.7, ge=0.1, le=1.0),
+    max_new_tokens: int = Query(512, ge=50, le=1024)
+):
+    """Complete voice chat interaction (speech-to-text -> AI -> text-to-speech)"""
+    try:
+        # Step 1: Convert speech to text
+        audio_data = await audio.read()
+        user_message = recognize_speech(audio_data, language)
+
+        # Step 2: Get AI response
+        chat_response = agent.chat(
+            message=user_message,
+            history=[],
+            temperature=temperature,
+            max_new_tokens=max_new_tokens
+        )
+
+        # Step 3: Convert response to speech
+        audio_data = text_to_speech(chat_response, language.split('-')[0])
+
+        # Return as MP3 file
+        return StreamingResponse(
+            io.BytesIO(audio_data),
+            media_type="audio/mpeg",
+            headers={"Content-Disposition": "attachment; filename=response.mp3"}
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error in voice chat: {e}")
+        raise HTTPException(status_code=500, detail="Error processing voice chat")
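For reference, the three new endpoints can be exercised with a plain HTTP client. The sketch below is illustrative only and not part of the commit: it assumes the API is served at http://localhost:8000 and that sample.wav is a local WAV recording (sr.AudioFile decodes WAV/AIFF/FLAC, so WAV is the safe upload format even though the endpoint also accepts .mp3 and .ogg extensions).

# Illustrative client for the new /voice endpoints (not part of the commit).
# Assumptions: server at http://localhost:8000, local file sample.wav.
import base64
import requests

BASE = "http://localhost:8000"  # hypothetical host/port

# 1) Speech-to-text: multipart upload, language passed as a query parameter
with open("sample.wav", "rb") as f:
    r = requests.post(
        f"{BASE}/voice/transcribe",
        params={"language": "en-US"},
        files={"audio": ("sample.wav", f, "audio/wav")},
    )
r.raise_for_status()
print("Transcript:", r.json()["text"])

# 2) Text-to-speech: JSON body matching VoiceOutputRequest
r = requests.post(
    f"{BASE}/voice/synthesize",
    json={"text": "Hello from TxAgent", "language": "en", "return_format": "base64"},
)
r.raise_for_status()
mp3_bytes = base64.b64decode(r.json()["audio"])  # decode the base64 MP3 payload

# 3) Full voice round trip: upload a question, save the spoken MP3 reply
with open("sample.wav", "rb") as f:
    r = requests.post(
        f"{BASE}/voice/chat",
        params={"language": "en-US", "temperature": 0.7},
        files={"audio": ("sample.wav", f, "audio/wav")},
    )
r.raise_for_status()
with open("response.mp3", "wb") as out:
    out.write(r.content)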
|