Ali2206 commited on
Commit
ea3d9f9
·
verified ·
1 Parent(s): 60e4c3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -26
app.py CHANGED
@@ -6,11 +6,12 @@ import logging
6
  from datetime import datetime
7
  from typing import List, Dict, Optional
8
  from fastapi import FastAPI, HTTPException, UploadFile, File
9
- from fastapi.responses import JSONResponse
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel
12
  import markdown
13
  import PyPDF2
 
14
 
15
  # Setup logging
16
  logging.basicConfig(
@@ -102,31 +103,44 @@ async def startup_event():
102
  except Exception as e:
103
  logger.error(f"Startup error: {str(e)}")
104
 
105
- @app.post("/chat")
106
- async def chat_endpoint(request: ChatRequest):
107
- try:
108
- raw_response = agent.chat(
109
- message=request.message,
110
- history=request.history,
111
- temperature=request.temperature,
112
- max_new_tokens=request.max_new_tokens
113
- )
114
- formatted_response = {
115
- "raw": raw_response,
116
- "clean": clean_text_response(raw_response),
117
- "structured": structure_medical_response(raw_response),
118
- "html": markdown.markdown(raw_response)
119
- }
120
- return JSONResponse({
121
- "status": "success",
122
- "format": request.format,
123
- "response": formatted_response.get(request.format, formatted_response["clean"]),
124
- "timestamp": datetime.now().isoformat(),
125
- "available_formats": list(formatted_response.keys())
126
- })
127
- except Exception as e:
128
- logger.error(f"Chat error: {str(e)}")
129
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  @app.post("/upload")
132
  async def upload_file(file: UploadFile = File(...)):
 
6
  from datetime import datetime
7
  from typing import List, Dict, Optional
8
  from fastapi import FastAPI, HTTPException, UploadFile, File
9
+ from fastapi.responses import JSONResponse, StreamingResponse
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel
12
  import markdown
13
  import PyPDF2
14
+ import asyncio
15
 
16
  # Setup logging
17
  logging.basicConfig(
 
103
  except Exception as e:
104
  logger.error(f"Startup error: {str(e)}")
105
 
106
+ @app.post("/chat-stream")
107
+ async def chat_stream_endpoint(request: ChatRequest):
108
+ async def token_stream():
109
+ try:
110
+ conversation = []
111
+ conversation.append({"role": "system", "content": agent.chat_prompt})
112
+ if request.history:
113
+ for msg in request.history:
114
+ conversation.append({"role": msg["role"], "content": msg["content"]})
115
+ conversation.append({"role": "user", "content": request.message})
116
+
117
+ input_ids = agent.tokenizer.apply_chat_template(
118
+ conversation,
119
+ add_generation_prompt=True,
120
+ return_tensors="pt"
121
+ ).to(agent.device)
122
+
123
+ streamer = agent.model.generate(
124
+ input_ids,
125
+ do_sample=True,
126
+ temperature=request.temperature,
127
+ max_new_tokens=request.max_new_tokens,
128
+ pad_token_id=agent.tokenizer.eos_token_id,
129
+ return_dict_in_generate=True,
130
+ output_scores=False
131
+ )
132
+
133
+ output = agent.tokenizer.decode(streamer["sequences"][0][input_ids.shape[1]:], skip_special_tokens=True)
134
+
135
+ for chunk in output.split():
136
+ yield chunk + " "
137
+ await asyncio.sleep(0.05)
138
+
139
+ except Exception as e:
140
+ logger.error(f"Streaming chat error: {str(e)}")
141
+ yield f"\n⚠️ Error: {str(e)}"
142
+
143
+ return StreamingResponse(token_stream(), media_type="text/plain")
144
 
145
  @app.post("/upload")
146
  async def upload_file(file: UploadFile = File(...)):