Spaces:
Sleeping
Sleeping
File size: 2,342 Bytes
7fdb8e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.templating import Jinja2Templates
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import os
from rag_demo.pipeline import process_pdf
import nest_asyncio
from rag_demo.rag.retriever import RAGPipeline
from loguru import logger
# Patch asyncio (via nest_asyncio) so an already-running event loop can be
# re-entered; done before anything else at import time.
nest_asyncio.apply()

app = FastAPI()

# Expose the relative "static" directory under the /static URL prefix.
# Starlette's StaticFiles checks by default that this directory exists.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Page templates are looked up in the relative "templates" directory
# (resolved against the process working directory); nothing is created here.
templates = Jinja2Templates(directory="templates")
class ChatRequest(BaseModel):
    """Request body model for the ``POST /chat`` endpoint."""

    # The user's question text.
    question: str
@app.get("/", response_class=HTMLResponse)
async def upload_page(request: Request):
    """Render the upload form (``upload.html``) at the site root."""
    context = {"request": request}
    return templates.TemplateResponse("upload.html", context)
@app.get("/chat", response_class=HTMLResponse)
async def chat_page(request: Request):
    """Render the chat UI (``chat.html``)."""
    context = {"request": request}
    return templates.TemplateResponse("chat.html", context)
@app.post("/upload")
async def upload_pdf(request: Request, file: UploadFile = File(...)):
    """Save an uploaded file under ``data/`` and run ``process_pdf`` on it.

    Always renders ``upload.html`` with ``processing`` set to ``False``:
    with a ``message`` on success, or with an ``error`` (the exception text)
    on any failure.

    Args:
        request: The incoming request, passed through to the template.
        file: The uploaded file from the multipart form field ``file``.
    """
    try:
        # ``file.filename`` is supplied by the client. Keep only its final
        # path component so names like "../../app.py" cannot escape
        # ``data/``. Backslashes are normalised first so Windows-style paths
        # are stripped on POSIX hosts too.
        safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
        if safe_name in ("", ".", ".."):
            raise ValueError("Uploaded file has no usable filename")

        # Make sure the upload directory exists before writing into it.
        os.makedirs("data", exist_ok=True)
        file_path = os.path.join("data", safe_name)

        # Read the upload before opening the target so a failed read does
        # not leave an empty file behind.
        content = await file.read()
        with open(file_path, "wb") as buffer:
            buffer.write(content)

        # Hand the saved file to the project's PDF processing pipeline.
        await process_pdf(file_path)

        return templates.TemplateResponse(
            "upload.html",
            {
                "request": request,
                "message": f"Successfully processed {file.filename}",
                "processing": False,
            },
        )
    except Exception as e:
        # Top-level request boundary: show the error on the page, but also
        # keep the traceback in the server log instead of discarding it.
        logger.exception("Failed to process upload {!r}", file.filename)
        return templates.TemplateResponse(
            "upload.html", {"request": request, "error": str(e), "processing": False}
        )
@app.post("/chat")
async def chat(chat_request: ChatRequest):
    """Answer ``chat_request.question`` using a freshly built ``RAGPipeline``.

    Returns:
        ``{"answer": <answer>}`` on success, or ``{"error": <message>}`` if
        building the pipeline or answering the question raises.
    """
    try:
        # Built per request and inside the ``try`` so that setup failures are
        # reported through the same ``{"error": ...}`` payload as query failures.
        rag_pipeline = RAGPipeline()
        # NOTE(review): ``rag`` is called synchronously and blocks the event
        # loop while it runs. The module applies nest_asyncio, which suggests
        # ``rag`` may drive an event loop itself — confirm before moving it
        # to a worker thread.
        answer = rag_pipeline.rag(chat_request.question)
    except Exception as e:
        # Keep the client-facing error payload, but log the traceback too.
        logger.exception("RAG pipeline failed for question {!r}", chat_request.question)
        return {"error": str(e)}

    logger.info(answer)
    return {"answer": answer}
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this module is executed
    # directly rather than served by an external ASGI runner.
    from uvicorn import run

    # Listen on all interfaces; 7860 is presumably the port the hosting
    # platform (Hugging Face Spaces) expects — confirm before changing.
    run(app, port=7860, host="0.0.0.0")
|