Ali2206 committed on
Commit
069d7f4
·
verified ·
1 Parent(s): 97cff3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -7
app.py CHANGED
@@ -1,21 +1,25 @@
1
-
2
  import os
3
  import sys
4
  import json
5
  import logging
6
  import re
7
  import hashlib
 
 
8
  from datetime import datetime
9
  from typing import List, Dict, Optional, Tuple
10
  from enum import Enum
11
 
12
- from fastapi import FastAPI, HTTPException
13
- from fastapi.responses import StreamingResponse
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from pydantic import BaseModel
16
  import asyncio
17
- from fastapi import Query
18
  from bson import ObjectId
 
 
 
 
19
  from txagent.txagent import TxAgent
20
  from db.mongo import get_mongo_client
21
 
@@ -24,7 +28,7 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(
24
  logger = logging.getLogger("TxAgentAPI")
25
 
26
  # App
27
- app = FastAPI(title="TxAgent API", version="2.2.1") # Version for hash-based analysis
28
 
29
  app.add_middleware(
30
  CORSMiddleware,
@@ -32,7 +36,7 @@ app.add_middleware(
32
  allow_methods=["*"], allow_headers=["*"]
33
  )
34
 
35
- # Pydantic
36
  class ChatRequest(BaseModel):
37
  message: str
38
  temperature: float = 0.7
@@ -40,6 +44,16 @@ class ChatRequest(BaseModel):
40
  history: Optional[List[Dict]] = None
41
  format: Optional[str] = "clean"
42
 
 
 
 
 
 
 
 
 
 
 
43
  # Enums
44
  class RiskLevel(str, Enum):
45
  NONE = "none"
@@ -243,6 +257,39 @@ async def analyze_all_patients():
243
  await analyze_patient(patient)
244
  await asyncio.sleep(0.1)
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  @app.on_event("startup")
247
  async def startup_event():
248
  global agent, patients_collection, analysis_collection, alerts_collection
@@ -276,7 +323,8 @@ async def status():
276
  return {
277
  "status": "running",
278
  "timestamp": datetime.utcnow().isoformat(),
279
- "version": "2.2.1"
 
280
  }
281
 
282
  @app.get("/patients/analysis-results")
@@ -342,3 +390,89 @@ async def chat_stream_endpoint(request: ChatRequest):
342
  yield f"⚠️ Error: {e}"
343
 
344
  return StreamingResponse(token_stream(), media_type="text/plain")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import sys
3
  import json
4
  import logging
5
  import re
6
  import hashlib
7
+ import io
8
+ import base64
9
  from datetime import datetime
10
  from typing import List, Dict, Optional, Tuple
11
  from enum import Enum
12
 
13
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Query
14
+ from fastapi.responses import StreamingResponse, JSONResponse
15
  from fastapi.middleware.cors import CORSMiddleware
16
  from pydantic import BaseModel
17
  import asyncio
 
18
  from bson import ObjectId
19
+ import speech_recognition as sr
20
+ from gtts import gTTS
21
+ from pydub import AudioSegment
22
+ from pydub.playback import play
23
  from txagent.txagent import TxAgent
24
  from db.mongo import get_mongo_client
25
 
 
28
  logger = logging.getLogger("TxAgentAPI")
29
 
30
  # App
31
+ app = FastAPI(title="TxAgent API", version="2.3.0") # Updated version for voice support
32
 
33
  app.add_middleware(
34
  CORSMiddleware,
 
36
  allow_methods=["*"], allow_headers=["*"]
37
  )
38
 
39
+ # Pydantic Models
40
  class ChatRequest(BaseModel):
41
  message: str
42
  temperature: float = 0.7
 
44
  history: Optional[List[Dict]] = None
45
  format: Optional[str] = "clean"
46
 
47
class VoiceInputRequest(BaseModel):
    """Options for a speech-to-text request."""
    # Container/codec of the uploaded audio. NOTE(review): only the default
    # "wav" is known to work with sr.AudioFile — confirm other formats.
    audio_format: str = "wav"
    # BCP-47 language code forwarded to the speech recognizer (e.g. "en-US").
    language: str = "en-US"
50
+
51
class VoiceOutputRequest(BaseModel):
    """Request body for the /voice/synthesize endpoint."""
    # Text to render as speech.
    text: str
    # gTTS language code — bare code such as "en", not a region tag like "en-US".
    language: str = "en"
    # When True, gTTS produces slower, more deliberate speech.
    slow: bool = False
    return_format: str = "mp3"  # mp3 or base64
56
+
57
  # Enums
58
  class RiskLevel(str, Enum):
59
  NONE = "none"
 
257
  await analyze_patient(patient)
258
  await asyncio.sleep(0.1)
259
 
260
def recognize_speech(audio_data: bytes, language: str = "en-US") -> str:
    """Transcribe spoken audio to text via Google's free Web Speech API.

    Args:
        audio_data: Raw audio bytes. NOTE(review): sr.AudioFile only parses
            WAV/AIFF/FLAC containers — callers must transcode other formats
            first, otherwise this falls through to the generic 500 branch.
        language: Language code passed to the recognizer (e.g. "en-US").

    Returns:
        The recognized text.

    Raises:
        HTTPException: 400 for empty or unintelligible audio, 503 when the
            recognition service cannot be reached, 500 for anything else.
    """
    # Fail fast on an empty upload instead of surfacing an opaque 500.
    if not audio_data:
        raise HTTPException(status_code=400, detail="Could not understand audio")

    recognizer = sr.Recognizer()

    try:
        # Wrap the bytes in a file-like object so sr.AudioFile can parse it.
        with io.BytesIO(audio_data) as audio_file:
            with sr.AudioFile(audio_file) as source:
                audio = recognizer.record(source)
                # Blocking network call to Google's Web Speech endpoint.
                text = recognizer.recognize_google(audio, language=language)
                return text
    except sr.UnknownValueError:
        logger.error("Google Speech Recognition could not understand audio")
        raise HTTPException(status_code=400, detail="Could not understand audio")
    except sr.RequestError as e:
        logger.error(f"Could not request results from Google Speech Recognition service; {e}")
        raise HTTPException(status_code=503, detail="Speech recognition service unavailable")
    except Exception as e:
        logger.error(f"Error in speech recognition: {e}")
        raise HTTPException(status_code=500, detail="Error processing speech")
280
+
281
def text_to_speech(text: str, language: str = "en", slow: bool = False) -> bytes:
    """Render *text* as spoken audio with gTTS and return the MP3 bytes.

    Args:
        text: Text to synthesize; must be non-empty.
        language: gTTS language code (bare code, e.g. "en").
        slow: When True, gTTS produces slower speech.

    Returns:
        MP3-encoded audio as bytes.

    Raises:
        HTTPException: 400 for empty text, 500 for any synthesis failure.
    """
    # gTTS raises an opaque AssertionError on empty input; fail fast instead
    # of letting it surface as a 500.
    if not text or not text.strip():
        raise HTTPException(status_code=400, detail="Error generating speech")
    try:
        tts = gTTS(text=text, lang=language, slow=slow)
        mp3_fp = io.BytesIO()
        # Blocking network call to Google's TTS endpoint.
        tts.write_to_fp(mp3_fp)
        mp3_fp.seek(0)
        return mp3_fp.read()
    except Exception as e:
        logger.error(f"Error in text-to-speech conversion: {e}")
        raise HTTPException(status_code=500, detail="Error generating speech")
292
+
293
  @app.on_event("startup")
294
  async def startup_event():
295
  global agent, patients_collection, analysis_collection, alerts_collection
 
323
  return {
324
  "status": "running",
325
  "timestamp": datetime.utcnow().isoformat(),
326
+ "version": "2.3.0",
327
+ "features": ["chat", "voice-input", "voice-output", "patient-analysis"]
328
  }
329
 
330
  @app.get("/patients/analysis-results")
 
390
  yield f"⚠️ Error: {e}"
391
 
392
  return StreamingResponse(token_stream(), media_type="text/plain")
393
+
394
@app.post("/voice/transcribe")
async def transcribe_voice(
    audio: UploadFile = File(...),
    language: str = Query("en-US", description="Language code for speech recognition")
):
    """Convert speech to text.

    Accepts a .wav/.mp3/.ogg/.flac upload and returns {"text": ...}.
    Raises 400 for unsupported formats, 500 for unexpected failures.
    """
    try:
        # Validate the extension BEFORE buffering the whole upload in memory;
        # guard against a missing filename (original crashed with
        # AttributeError -> 500 when audio.filename was None).
        filename = (audio.filename or "").lower()
        if not filename.endswith(('.wav', '.mp3', '.ogg', '.flac')):
            raise HTTPException(status_code=400, detail="Unsupported audio format")

        # Read audio file
        audio_data = await audio.read()

        # sr.AudioFile only understands WAV/AIFF/FLAC, so transcode the other
        # advertised formats to WAV first (pydub delegates to ffmpeg).
        if filename.endswith(('.mp3', '.ogg')):
            segment = AudioSegment.from_file(io.BytesIO(audio_data))
            wav_buf = io.BytesIO()
            segment.export(wav_buf, format="wav")
            audio_data = wav_buf.getvalue()

        # Convert speech to text
        text = recognize_speech(audio_data, language)

        return {"text": text}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in voice transcription: {e}")
        raise HTTPException(status_code=500, detail="Error processing voice input")
418
+
419
@app.post("/voice/synthesize")
async def synthesize_voice(request: VoiceOutputRequest):
    """Convert text to speech"""
    try:
        mp3_bytes = text_to_speech(request.text, request.language, request.slow)

        # base64 callers get a JSON payload; everyone else gets a
        # downloadable MP3 stream.
        if request.return_format == "base64":
            encoded = base64.b64encode(mp3_bytes).decode('utf-8')
            return {"audio": encoded}

        return StreamingResponse(
            io.BytesIO(mp3_bytes),
            media_type="audio/mpeg",
            headers={"Content-Disposition": "attachment; filename=speech.mp3"},
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in voice synthesis: {e}")
        raise HTTPException(status_code=500, detail="Error generating voice output")
442
+
443
@app.post("/voice/chat")
async def voice_chat_endpoint(
    audio: UploadFile = File(...),
    language: str = Query("en-US", description="Language code for speech recognition"),
    temperature: float = Query(0.7, ge=0.1, le=1.0),
    max_new_tokens: int = Query(512, ge=50, le=1024)
):
    """Complete voice chat interaction (speech-to-text -> AI -> text-to-speech)

    Fix vs. original: the input and output audio no longer share the single
    `audio_data` name, which shadowed the request bytes with the response.
    """
    try:
        # Step 1: Convert speech to text
        request_audio = await audio.read()
        user_message = recognize_speech(request_audio, language)

        # Step 2: Get AI response. NOTE(review): `agent` is assigned in
        # startup_event; assumes agent.chat returns a plain string — confirm
        # against the TxAgent implementation.
        chat_response = agent.chat(
            message=user_message,
            history=[],
            temperature=temperature,
            max_new_tokens=max_new_tokens
        )

        # Step 3: Convert response to speech. gTTS expects a bare language
        # code, so strip the region suffix (e.g. "en-US" -> "en").
        response_audio = text_to_speech(chat_response, language.split('-')[0])

        # Return as MP3 file
        return StreamingResponse(
            io.BytesIO(response_audio),
            media_type="audio/mpeg",
            headers={"Content-Disposition": "attachment; filename=response.mp3"}
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in voice chat: {e}")
        raise HTTPException(status_code=500, detail="Error processing voice chat")