Ali2206 commited on
Commit
7757822
·
verified ·
1 Parent(s): 61c414a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -15
app.py CHANGED
@@ -1,13 +1,15 @@
 
1
  import os
2
  import sys
3
  import json
4
  import shutil
5
- from fastapi import FastAPI, UploadFile, File, HTTPException
6
- from fastapi.responses import JSONResponse, FileResponse
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from typing import List, Dict, Optional
9
  import torch
10
  from datetime import datetime
 
11
 
12
  # Configuration
13
  persistent_dir = "/data/hf_cache"
@@ -31,10 +33,23 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
31
  src_path = os.path.abspath(os.path.join(current_dir, "src"))
32
  sys.path.insert(0, src_path)
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # Initialize FastAPI app
35
  app = FastAPI(
36
- title="Clinical Patient Support System API",
37
- description="API for analyzing medical documents",
38
  version="1.0.0"
39
  )
40
 
@@ -59,52 +74,98 @@ async def startup_event():
59
  raise RuntimeError(f"Failed to initialize agent: {str(e)}")
60
 
61
  def init_agent():
62
- """Initialize and return the TxAgent instance."""
63
  tool_path = os.path.join(tool_cache_dir, "new_tool.json")
64
  if not os.path.exists(tool_path):
65
  shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
66
 
67
- from txagent.txagent import TxAgent
68
  agent = TxAgent(
69
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
70
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
71
  tool_files_dict={"new_tool": tool_path},
 
 
72
  force_finish=True,
73
  enable_checker=True,
74
  step_rag_num=4,
75
- seed=100,
76
- use_vllm=False # Disable vLLM for Hugging Face Spaces
77
  )
78
  agent.init_model()
79
  return agent
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  @app.post("/analyze")
82
  async def analyze_document(file: UploadFile = File(...)):
83
- """Analyze a medical document and return results."""
84
  try:
85
  # Save the uploaded file temporarily
86
  temp_path = os.path.join(file_cache_dir, file.filename)
87
  with open(temp_path, "wb") as f:
88
  f.write(await file.read())
89
 
90
- # Process the file and generate response
91
- result = agent.process_document(temp_path)
 
 
 
 
 
 
 
 
 
 
92
 
93
  # Clean up
94
  os.remove(temp_path)
95
 
96
  return JSONResponse({
97
  "status": "success",
98
- "result": result,
 
99
  "timestamp": datetime.now().isoformat()
100
  })
101
-
102
  except Exception as e:
103
  raise HTTPException(status_code=500, detail=str(e))
104
 
105
  @app.get("/status")
106
  async def service_status():
107
- """Check service status."""
108
  return {
109
  "status": "running",
110
  "version": "1.0.0",
@@ -114,4 +175,4 @@ async def service_status():
114
 
115
  if __name__ == "__main__":
116
  import uvicorn
117
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ # app.py - FastAPI application
2
  import os
3
  import sys
4
  import json
5
  import shutil
6
+ from fastapi import FastAPI, HTTPException, UploadFile, File
7
+ from fastapi.responses import JSONResponse
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from typing import List, Dict, Optional
10
  import torch
11
  from datetime import datetime
12
+ from pydantic import BaseModel
13
 
14
  # Configuration
15
  persistent_dir = "/data/hf_cache"
 
33
  src_path = os.path.abspath(os.path.join(current_dir, "src"))
34
  sys.path.insert(0, src_path)
35
 
36
+ # Request models
37
+ class ChatRequest(BaseModel):
38
+ message: str
39
+ temperature: float = 0.7
40
+ max_new_tokens: int = 512
41
+ history: Optional[List[Dict]] = None
42
+
43
+ class MultistepRequest(BaseModel):
44
+ message: str
45
+ temperature: float = 0.7
46
+ max_new_tokens: int = 512
47
+ max_round: int = 5
48
+
49
  # Initialize FastAPI app
50
  app = FastAPI(
51
+ title="TxAgent API",
52
+ description="API for TxAgent medical document analysis",
53
  version="1.0.0"
54
  )
55
 
 
74
  raise RuntimeError(f"Failed to initialize agent: {str(e)}")
75
 
76
  def init_agent():
77
+ """Initialize and return the TxAgent instance"""
78
  tool_path = os.path.join(tool_cache_dir, "new_tool.json")
79
  if not os.path.exists(tool_path):
80
  shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
81
 
 
82
  agent = TxAgent(
83
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
84
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
85
  tool_files_dict={"new_tool": tool_path},
86
+ enable_finish=True,
87
+ enable_rag=False,
88
  force_finish=True,
89
  enable_checker=True,
90
  step_rag_num=4,
91
+ seed=100
 
92
  )
93
  agent.init_model()
94
  return agent
95
 
96
+ @app.post("/chat")
97
+ async def chat_endpoint(request: ChatRequest):
98
+ """Handle chat conversations"""
99
+ try:
100
+ response = agent.chat(
101
+ message=request.message,
102
+ history=request.history,
103
+ temperature=request.temperature,
104
+ max_new_tokens=request.max_new_tokens
105
+ )
106
+ return JSONResponse({
107
+ "status": "success",
108
+ "response": response,
109
+ "timestamp": datetime.now().isoformat()
110
+ })
111
+ except Exception as e:
112
+ raise HTTPException(status_code=500, detail=str(e))
113
+
114
+ @app.post("/multistep")
115
+ async def multistep_endpoint(request: MultistepRequest):
116
+ """Run multi-step reasoning"""
117
+ try:
118
+ response = agent.run_multistep_agent(
119
+ message=request.message,
120
+ temperature=request.temperature,
121
+ max_new_tokens=request.max_new_tokens,
122
+ max_round=request.max_round
123
+ )
124
+ return JSONResponse({
125
+ "status": "success",
126
+ "response": response,
127
+ "timestamp": datetime.now().isoformat()
128
+ })
129
+ except Exception as e:
130
+ raise HTTPException(status_code=500, detail=str(e))
131
+
132
  @app.post("/analyze")
133
  async def analyze_document(file: UploadFile = File(...)):
134
+ """Analyze a medical document"""
135
  try:
136
  # Save the uploaded file temporarily
137
  temp_path = os.path.join(file_cache_dir, file.filename)
138
  with open(temp_path, "wb") as f:
139
  f.write(await file.read())
140
 
141
+ # Process the document
142
+ text = agent.extract_text_from_file(temp_path)
143
+ analysis = agent.analyze_text(text)
144
+
145
+ # Generate report
146
+ report_path = os.path.join(report_dir, f"{file.filename}.json")
147
+ with open(report_path, "w") as f:
148
+ json.dump({
149
+ "filename": file.filename,
150
+ "analysis": analysis,
151
+ "timestamp": datetime.now().isoformat()
152
+ }, f)
153
 
154
  # Clean up
155
  os.remove(temp_path)
156
 
157
  return JSONResponse({
158
  "status": "success",
159
+ "analysis": analysis,
160
+ "report_path": report_path,
161
  "timestamp": datetime.now().isoformat()
162
  })
 
163
  except Exception as e:
164
  raise HTTPException(status_code=500, detail=str(e))
165
 
166
  @app.get("/status")
167
  async def service_status():
168
+ """Check service status"""
169
  return {
170
  "status": "running",
171
  "version": "1.0.0",
 
175
 
176
  if __name__ == "__main__":
177
  import uvicorn
178
+ uvicorn.run(app, host="0.0.0.0", port=8000)