Ali2206 commited on
Commit
7e095f4
·
verified ·
1 Parent(s): 520104a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -30
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import sys
3
  import json
4
  import shutil
 
5
  from fastapi import FastAPI, HTTPException, UploadFile, File
6
  from fastapi.responses import JSONResponse
7
  from fastapi.middleware.cors import CORSMiddleware
@@ -10,26 +11,41 @@ import torch
10
  from datetime import datetime
11
  from pydantic import BaseModel
12
 
 
 
 
 
 
 
 
 
 
 
 
13
  # Add src directory to Python path
14
  current_dir = os.path.dirname(os.path.abspath(__file__))
15
  src_path = os.path.abspath(os.path.join(current_dir, "src"))
16
  sys.path.insert(0, src_path)
17
 
18
- # Now import TxAgent after adding to path
19
- from txagent.txagent import TxAgent
 
 
 
 
20
 
21
  # Configuration
22
  persistent_dir = "/data/hf_cache"
 
23
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
24
  tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
25
  file_cache_dir = os.path.join(persistent_dir, "cache")
26
  report_dir = os.path.join(persistent_dir, "reports")
27
 
28
  # Create directories if they don't exist
29
- os.makedirs(model_cache_dir, exist_ok=True)
30
- os.makedirs(tool_cache_dir, exist_ok=True)
31
- os.makedirs(file_cache_dir, exist_ok=True)
32
- os.makedirs(report_dir, exist_ok=True)
33
 
34
  # Set environment variables
35
  os.environ["HF_HOME"] = model_cache_dir
@@ -71,92 +87,116 @@ agent = None
71
  async def startup_event():
72
  global agent
73
  try:
 
74
  agent = init_agent()
75
  logger.info("TxAgent initialized successfully")
76
  except Exception as e:
77
- logger.error(f"Failed to initialize agent: {str(e)}")
78
  raise RuntimeError(f"Failed to initialize agent: {str(e)}")
79
 
80
  def init_agent():
81
  """Initialize and return the TxAgent instance"""
82
- tool_path = os.path.join(tool_cache_dir, "new_tool.json")
83
- if not os.path.exists(tool_path):
84
- shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
85
-
86
- agent = TxAgent(
87
- model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
88
- rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
89
- tool_files_dict={"new_tool": tool_path},
90
- enable_finish=True,
91
- enable_rag=False,
92
- force_finish=True,
93
- enable_checker=True,
94
- step_rag_num=4,
95
- seed=100
96
- )
97
- agent.init_model()
98
- return agent
 
 
 
 
 
 
 
 
 
 
99
 
100
  @app.post("/chat")
101
  async def chat_endpoint(request: ChatRequest):
102
  """Handle chat conversations"""
103
  try:
 
104
  response = agent.chat(
105
  message=request.message,
106
  history=request.history,
107
  temperature=request.temperature,
108
  max_new_tokens=request.max_new_tokens
109
  )
 
110
  return JSONResponse({
111
  "status": "success",
112
  "response": response,
113
  "timestamp": datetime.now().isoformat()
114
  })
115
  except Exception as e:
 
116
  raise HTTPException(status_code=500, detail=str(e))
117
 
118
  @app.post("/multistep")
119
  async def multistep_endpoint(request: MultistepRequest):
120
  """Run multi-step reasoning"""
121
  try:
 
122
  response = agent.run_multistep_agent(
123
  message=request.message,
124
  temperature=request.temperature,
125
  max_new_tokens=request.max_new_tokens,
126
  max_round=request.max_round
127
  )
 
128
  return JSONResponse({
129
  "status": "success",
130
  "response": response,
131
  "timestamp": datetime.now().isoformat()
132
  })
133
  except Exception as e:
 
134
  raise HTTPException(status_code=500, detail=str(e))
135
 
136
  @app.post("/analyze")
137
  async def analyze_document(file: UploadFile = File(...)):
138
  """Analyze a medical document"""
139
  try:
 
 
140
  # Save the uploaded file temporarily
141
  temp_path = os.path.join(file_cache_dir, file.filename)
142
  with open(temp_path, "wb") as f:
143
  f.write(await file.read())
 
144
 
145
  # Process the document
146
  text = agent.extract_text_from_file(temp_path)
147
  analysis = agent.analyze_text(text)
 
148
 
149
  # Generate report
150
- report_path = os.path.join(report_dir, f"{file.filename}.json")
 
151
  with open(report_path, "w") as f:
152
  json.dump({
153
  "filename": file.filename,
154
  "analysis": analysis,
155
  "timestamp": datetime.now().isoformat()
156
- }, f)
 
157
 
158
  # Clean up
159
  os.remove(temp_path)
 
160
 
161
  return JSONResponse({
162
  "status": "success",
@@ -165,18 +205,27 @@ async def analyze_document(file: UploadFile = File(...)):
165
  "timestamp": datetime.now().isoformat()
166
  })
167
  except Exception as e:
 
168
  raise HTTPException(status_code=500, detail=str(e))
169
 
170
  @app.get("/status")
171
  async def service_status():
172
  """Check service status"""
173
- return {
174
  "status": "running",
175
  "version": "1.0.0",
176
  "model": agent.model_name if agent else "not loaded",
177
- "device": str(agent.device) if agent else "unknown"
 
178
  }
 
 
179
 
180
  if __name__ == "__main__":
181
- import uvicorn
182
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
 
 
 
 
 
2
  import sys
3
  import json
4
  import shutil
5
+ import logging
6
  from fastapi import FastAPI, HTTPException, UploadFile, File
7
  from fastapi.responses import JSONResponse
8
  from fastapi.middleware.cors import CORSMiddleware
 
11
  from datetime import datetime
12
  from pydantic import BaseModel
13
 
14
+ # Configure logging
15
+ logging.basicConfig(
16
+ level=logging.INFO,
17
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
18
+ handlers=[
19
+ logging.StreamHandler(),
20
+ logging.FileHandler('txagent_api.log')
21
+ ]
22
+ )
23
+ logger = logging.getLogger("TxAgentAPI")
24
+
25
  # Add src directory to Python path
26
  current_dir = os.path.dirname(os.path.abspath(__file__))
27
  src_path = os.path.abspath(os.path.join(current_dir, "src"))
28
  sys.path.insert(0, src_path)
29
 
30
+ # Import TxAgent after setting up path
31
+ try:
32
+ from txagent.txagent import TxAgent
33
+ except ImportError as e:
34
+ logger.error(f"Failed to import TxAgent: {str(e)}")
35
+ raise
36
 
37
  # Configuration
38
  persistent_dir = "/data/hf_cache"
39
+ os.makedirs(persistent_dir, exist_ok=True)
40
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
41
  tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
42
  file_cache_dir = os.path.join(persistent_dir, "cache")
43
  report_dir = os.path.join(persistent_dir, "reports")
44
 
45
  # Create directories if they don't exist
46
+ for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
47
+ os.makedirs(directory, exist_ok=True)
48
+ logger.info(f"Created directory: {directory}")
 
49
 
50
  # Set environment variables
51
  os.environ["HF_HOME"] = model_cache_dir
 
87
  async def startup_event():
88
  global agent
89
  try:
90
+ logger.info("Initializing TxAgent...")
91
  agent = init_agent()
92
  logger.info("TxAgent initialized successfully")
93
  except Exception as e:
94
+ logger.error(f"Failed to initialize agent: {str(e)}", exc_info=True)
95
  raise RuntimeError(f"Failed to initialize agent: {str(e)}")
96
 
97
  def init_agent():
98
  """Initialize and return the TxAgent instance"""
99
+ try:
100
+ tool_path = os.path.join(tool_cache_dir, "new_tool.json")
101
+ if not os.path.exists(tool_path):
102
+ logger.info(f"Copying tool file to {tool_path}")
103
+ default_tool = os.path.abspath("data/new_tool.json")
104
+ if os.path.exists(default_tool):
105
+ shutil.copy(default_tool, tool_path)
106
+ else:
107
+ raise FileNotFoundError(f"Default tool file not found at {default_tool}")
108
+
109
+ logger.info("Creating TxAgent instance")
110
+ agent = TxAgent(
111
+ model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
112
+ rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
113
+ tool_files_dict={"new_tool": tool_path},
114
+ enable_finish=True,
115
+ enable_rag=False,
116
+ force_finish=True,
117
+ enable_checker=True,
118
+ step_rag_num=4,
119
+ seed=100
120
+ )
121
+ agent.init_model()
122
+ return agent
123
+ except Exception as e:
124
+ logger.error(f"Error in init_agent: {str(e)}", exc_info=True)
125
+ raise
126
 
127
  @app.post("/chat")
128
  async def chat_endpoint(request: ChatRequest):
129
  """Handle chat conversations"""
130
  try:
131
+ logger.info(f"Chat request received: {request.message[:50]}...")
132
  response = agent.chat(
133
  message=request.message,
134
  history=request.history,
135
  temperature=request.temperature,
136
  max_new_tokens=request.max_new_tokens
137
  )
138
+ logger.info("Chat response generated successfully")
139
  return JSONResponse({
140
  "status": "success",
141
  "response": response,
142
  "timestamp": datetime.now().isoformat()
143
  })
144
  except Exception as e:
145
+ logger.error(f"Chat error: {str(e)}", exc_info=True)
146
  raise HTTPException(status_code=500, detail=str(e))
147
 
148
  @app.post("/multistep")
149
  async def multistep_endpoint(request: MultistepRequest):
150
  """Run multi-step reasoning"""
151
  try:
152
+ logger.info(f"Multistep request received: {request.message[:50]}...")
153
  response = agent.run_multistep_agent(
154
  message=request.message,
155
  temperature=request.temperature,
156
  max_new_tokens=request.max_new_tokens,
157
  max_round=request.max_round
158
  )
159
+ logger.info("Multistep reasoning completed successfully")
160
  return JSONResponse({
161
  "status": "success",
162
  "response": response,
163
  "timestamp": datetime.now().isoformat()
164
  })
165
  except Exception as e:
166
+ logger.error(f"Multistep error: {str(e)}", exc_info=True)
167
  raise HTTPException(status_code=500, detail=str(e))
168
 
169
  @app.post("/analyze")
170
  async def analyze_document(file: UploadFile = File(...)):
171
  """Analyze a medical document"""
172
  try:
173
+ logger.info(f"Document analysis request received for: {file.filename}")
174
+
175
  # Save the uploaded file temporarily
176
  temp_path = os.path.join(file_cache_dir, file.filename)
177
  with open(temp_path, "wb") as f:
178
  f.write(await file.read())
179
+ logger.info(f"File saved temporarily at {temp_path}")
180
 
181
  # Process the document
182
  text = agent.extract_text_from_file(temp_path)
183
  analysis = agent.analyze_text(text)
184
+ logger.info("Document analysis completed successfully")
185
 
186
  # Generate report
187
+ report_filename = f"{os.path.splitext(file.filename)[0]}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
188
+ report_path = os.path.join(report_dir, report_filename)
189
  with open(report_path, "w") as f:
190
  json.dump({
191
  "filename": file.filename,
192
  "analysis": analysis,
193
  "timestamp": datetime.now().isoformat()
194
+ }, f, indent=2)
195
+ logger.info(f"Report generated at {report_path}")
196
 
197
  # Clean up
198
  os.remove(temp_path)
199
+ logger.info(f"Temporary file {temp_path} removed")
200
 
201
  return JSONResponse({
202
  "status": "success",
 
205
  "timestamp": datetime.now().isoformat()
206
  })
207
  except Exception as e:
208
+ logger.error(f"Document analysis error: {str(e)}", exc_info=True)
209
  raise HTTPException(status_code=500, detail=str(e))
210
 
211
  @app.get("/status")
212
  async def service_status():
213
  """Check service status"""
214
+ status = {
215
  "status": "running",
216
  "version": "1.0.0",
217
  "model": agent.model_name if agent else "not loaded",
218
+ "device": str(agent.device) if agent else "unknown",
219
+ "timestamp": datetime.now().isoformat()
220
  }
221
+ logger.info(f"Status check: {status}")
222
+ return status
223
 
224
  if __name__ == "__main__":
225
+ try:
226
+ logger.info("Starting TxAgent API server")
227
+ import uvicorn
228
+ uvicorn.run(app, host="0.0.0.0", port=8000, log_config=None)
229
+ except Exception as e:
230
+ logger.error(f"Failed to start server: {str(e)}", exc_info=True)
231
+ raise