Ali2206 committed
Commit bdcc052 · verified · 1 Parent(s): 3cfe99a

Update app.py

Files changed (1): app.py  +15  -140
app.py CHANGED
@@ -1,24 +1,18 @@
 import os
 import sys
 import json
-import shutil
 import logging
 from fastapi import FastAPI, HTTPException, UploadFile, File
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from typing import List, Dict, Optional
-import torch
 from datetime import datetime
 from pydantic import BaseModel
 
-# Configure logging
+# Configure logging for Hugging Face Spaces
 logging.basicConfig(
     level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-    handlers=[
-        logging.StreamHandler(),
-        logging.FileHandler('txagent_api.log')
-    ]
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
 )
 logger = logging.getLogger("TxAgentAPI")
 
@@ -34,36 +28,6 @@ except ImportError as e:
     logger.error(f"Failed to import TxAgent: {str(e)}")
     raise
 
-# Configuration
-persistent_dir = "/data/hf_cache"
-os.makedirs(persistent_dir, exist_ok=True)
-model_cache_dir = os.path.join(persistent_dir, "txagent_models")
-tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
-file_cache_dir = os.path.join(persistent_dir, "cache")
-report_dir = os.path.join(persistent_dir, "reports")
-
-# Create directories if they don't exist
-for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
-    os.makedirs(directory, exist_ok=True)
-    logger.info(f"Created directory: {directory}")
-
-# Set environment variables
-os.environ["HF_HOME"] = model_cache_dir
-os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
-
-# Request models
-class ChatRequest(BaseModel):
-    message: str
-    temperature: float = 0.7
-    max_new_tokens: int = 512
-    history: Optional[List[Dict]] = None
-
-class MultistepRequest(BaseModel):
-    message: str
-    temperature: float = 0.7
-    max_new_tokens: int = 512
-    max_round: int = 5
-
 # Initialize FastAPI app
 app = FastAPI(
     title="TxAgent API",
@@ -88,29 +52,10 @@ async def startup_event():
     global agent
     try:
         logger.info("Initializing TxAgent...")
-        agent = init_agent()
-        logger.info("TxAgent initialized successfully")
-    except Exception as e:
-        logger.error(f"Failed to initialize agent: {str(e)}", exc_info=True)
-        raise RuntimeError(f"Failed to initialize agent: {str(e)}")
-
-def init_agent():
-    """Initialize and return the TxAgent instance"""
-    try:
-        tool_path = os.path.join(tool_cache_dir, "new_tool.json")
-        if not os.path.exists(tool_path):
-            logger.info(f"Copying tool file to {tool_path}")
-            default_tool = os.path.abspath("data/new_tool.json")
-            if os.path.exists(default_tool):
-                shutil.copy(default_tool, tool_path)
-            else:
-                raise FileNotFoundError(f"Default tool file not found at {default_tool}")
-
-        logger.info("Creating TxAgent instance")
         agent = TxAgent(
             model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
             rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-            tool_files_dict={"new_tool": tool_path},
+            tool_files_dict={},  # No tool files in this example
             enable_finish=True,
             enable_rag=False,
             force_finish=True,
@@ -119,10 +64,16 @@ def init_agent():
             seed=100
         )
         agent.init_model()
-        return agent
+        logger.info("TxAgent initialized successfully")
     except Exception as e:
-        logger.error(f"Error in init_agent: {str(e)}", exc_info=True)
-        raise
+        logger.error(f"Failed to initialize agent: {str(e)}")
+        raise RuntimeError(f"Failed to initialize agent: {str(e)}")
+
+class ChatRequest(BaseModel):
+    message: str
+    temperature: float = 0.7
+    max_new_tokens: int = 512
+    history: Optional[List[Dict]] = None
 
 @app.post("/chat")
 async def chat_endpoint(request: ChatRequest):
@@ -135,97 +86,21 @@ async def chat_endpoint(request: ChatRequest):
             temperature=request.temperature,
             max_new_tokens=request.max_new_tokens
         )
-        logger.info("Chat response generated successfully")
-        return JSONResponse({
-            "status": "success",
-            "response": response,
-            "timestamp": datetime.now().isoformat()
-        })
-    except Exception as e:
-        logger.error(f"Chat error: {str(e)}", exc_info=True)
-        raise HTTPException(status_code=500, detail=str(e))
-
-@app.post("/multistep")
-async def multistep_endpoint(request: MultistepRequest):
-    """Run multi-step reasoning"""
-    try:
-        logger.info(f"Multistep request received: {request.message[:50]}...")
-        response = agent.run_multistep_agent(
-            message=request.message,
-            temperature=request.temperature,
-            max_new_tokens=request.max_new_tokens,
-            max_round=request.max_round
-        )
-        logger.info("Multistep reasoning completed successfully")
         return JSONResponse({
             "status": "success",
             "response": response,
             "timestamp": datetime.now().isoformat()
         })
     except Exception as e:
-        logger.error(f"Multistep error: {str(e)}", exc_info=True)
-        raise HTTPException(status_code=500, detail=str(e))
-
-@app.post("/analyze")
-async def analyze_document(file: UploadFile = File(...)):
-    """Analyze a medical document"""
-    try:
-        logger.info(f"Document analysis request received for: {file.filename}")
-
-        # Save the uploaded file temporarily
-        temp_path = os.path.join(file_cache_dir, file.filename)
-        with open(temp_path, "wb") as f:
-            f.write(await file.read())
-        logger.info(f"File saved temporarily at {temp_path}")
-
-        # Process the document
-        text = agent.extract_text_from_file(temp_path)
-        analysis = agent.analyze_text(text)
-        logger.info("Document analysis completed successfully")
-
-        # Generate report
-        report_filename = f"{os.path.splitext(file.filename)[0]}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
-        report_path = os.path.join(report_dir, report_filename)
-        with open(report_path, "w") as f:
-            json.dump({
-                "filename": file.filename,
-                "analysis": analysis,
-                "timestamp": datetime.now().isoformat()
-            }, f, indent=2)
-        logger.info(f"Report generated at {report_path}")
-
-        # Clean up
-        os.remove(temp_path)
-        logger.info(f"Temporary file {temp_path} removed")
-
-        return JSONResponse({
-            "status": "success",
-            "analysis": analysis,
-            "report_path": report_path,
-            "timestamp": datetime.now().isoformat()
-        })
-    except Exception as e:
-        logger.error(f"Document analysis error: {str(e)}", exc_info=True)
+        logger.error(f"Chat error: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
 
 @app.get("/status")
 async def service_status():
     """Check service status"""
-    status = {
+    return {
         "status": "running",
         "version": "1.0.0",
         "model": agent.model_name if agent else "not loaded",
-        "device": str(agent.device) if agent else "unknown",
         "timestamp": datetime.now().isoformat()
-    }
-    logger.info(f"Status check: {status}")
-    return status
-
-if __name__ == "__main__":
-    try:
-        logger.info("Starting TxAgent API server")
-        import uvicorn
-        uvicorn.run(app, host="0.0.0.0", port=8000, log_config=None)
-    except Exception as e:
-        logger.error(f"Failed to start server: {str(e)}", exc_info=True)
-        raise
+    }
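
For reference, a minimal client sketch against the API that app.py exposes after this commit. The base URL, example message, and timeout values are assumptions (the commit does not say where the service is hosted); the payload fields mirror the ChatRequest model and the response keys mirror the /chat and /status handlers in the diff above.

import requests

BASE_URL = "http://localhost:8000"  # assumed host/port; adjust to wherever the Space serves the app

# Check that the service is up and which model is loaded (see the /status handler above)
status = requests.get(f"{BASE_URL}/status", timeout=10).json()
print(status["status"], status["model"])

# Single-turn chat request; fields match the ChatRequest model defined in app.py
payload = {
    "message": "Example question about a possible drug interaction",  # placeholder message
    "temperature": 0.7,
    "max_new_tokens": 512,
    "history": None,
}
resp = requests.post(f"{BASE_URL}/chat", json=payload, timeout=300)
resp.raise_for_status()
data = resp.json()
print(data["status"], data["timestamp"])
print(data["response"])

Since the if __name__ == "__main__" launcher was removed in this commit, the app is presumably started by the Space's own entrypoint, for example a command along the lines of uvicorn app:app --host 0.0.0.0 --port 7860; the exact command and port are assumptions, not part of the commit.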