Spaces:
Running
Running
from fastapi import FastAPI, File, UploadFile, HTTPException, Request | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse | |
from pydantic import BaseModel | |
import os | |
import tempfile | |
import uvicorn | |
from typing import List, Optional | |
import logging | |
from contextlib import asynccontextmanager | |
# Import your existing RAG system | |
from .rag import RAG | |
from .vector_store import VectorStore | |
# Logging setup: create this module's logger and configure the root logger
# to emit records at INFO level and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Pydantic models | |
class QuestionRequest(BaseModel):
    # Request body for asking a question.
    # `question` is the free-text question forwarded to the RAG system.
    question: str
class QuestionResponse(BaseModel):
    # Response body holding an answer and, optionally, its sources.
    # `answer`  - the generated answer text.
    # `sources` - optional list of source strings; defaults to an empty list.
    answer: str
    sources: Optional[List[str]] = []
class SearchRequest(BaseModel):
    # Request body for a similarity search.
    # `query` - text to search for.
    # `limit` - optional result cap passed through to the vector store; defaults to 5.
    query: str
    limit: Optional[int] = 5
class StatusResponse(BaseModel):
    # Response body describing system status.
    # `status`  - short state label (this file uses "healthy" / "unhealthy").
    # `message` - human-readable explanation of the state.
    # `version` - API version string.
    status: str
    message: str
    version: str
# Global variables
# Shared RAG instance used by the request handlers; it stays None until the
# startup phase of `lifespan` has finished successfully.
rag_system = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run application startup and shutdown around the serving phase.

    Startup reads ``GOOGLE_API_KEY`` and ``COLLECTION_NAME`` (default
    ``"ca-documents"``) from the environment, builds a ``RAG`` instance,
    awaits its ``initialize()`` and publishes it as the module-level
    ``rag_system``. Shutdown only logs a message.

    The function is wrapped with ``asynccontextmanager`` because the
    ``lifespan=`` argument of ``FastAPI`` expects an async context manager
    factory rather than a bare async generator function.

    Args:
        app: The FastAPI application being started (unused here).

    Raises:
        ValueError: If ``GOOGLE_API_KEY`` is unset or empty.
        Exception: Anything raised while constructing or initializing the
            RAG system is logged and re-raised.
    """
    # Startup
    global rag_system
    try:
        google_api_key = os.getenv("GOOGLE_API_KEY")
        if not google_api_key:
            raise ValueError("GOOGLE_API_KEY environment variable not set")
        collection_name = os.getenv("COLLECTION_NAME", "ca-documents")

        # Publish the instance only after initialize() succeeded, so the
        # global never refers to a half-initialized system.
        candidate = RAG(google_api_key, collection_name)
        await candidate.initialize()
        rag_system = candidate
        logger.info("RAG system initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize RAG system: {e}")
        raise

    yield

    # Shutdown
    logger.info("Shutting down...")
# Application object: metadata for the generated API docs plus the
# startup/shutdown hook defined by `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    version="2.0.0",
    title="CA Study Assistant API",
    description="Backend API for the CA Study Assistant RAG system",
)
# Cross-origin access: every origin, method and header is accepted, and
# credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Liveness probe handler.
# NOTE(review): no route decorator is visible in this view — confirm how this
# handler is registered on `app`.
async def health_check():
    """Return a fixed payload signalling that the API process is running."""
    payload = {
        "status": "healthy",
        "message": "CA Study Assistant API is running",
    }
    return payload
# NOTE(review): no route decorator is visible in this view — confirm how this
# handler is registered on `app`.
async def ask_question_stream(request: QuestionRequest):
    """
    Ask a question to the RAG system and get a streaming response.

    Non-empty chunks produced by ``rag_system.ask_question_stream`` are
    forwarded to the client as ``text/plain``. If the stream itself fails,
    the error is logged and a final ``"Error generating answer: ..."`` chunk
    is emitted, because the response headers have already been sent.

    Args:
        request: Payload whose ``question`` is passed to the RAG system.

    Returns:
        A ``StreamingResponse`` wrapping the chunk generator.

    Raises:
        HTTPException: 500 when the RAG system is not initialized or when
            setting up the stream fails.
    """
    try:
        if not rag_system:
            raise HTTPException(status_code=500, detail="RAG system not initialized")
        logger.info(f"Processing streaming question: {request.question[:100]}...")

        async def event_generator():
            try:
                async for chunk in rag_system.ask_question_stream(request.question):
                    if chunk:  # Only yield non-empty chunks
                        yield chunk
            except Exception as e:
                logger.error(f"Error during stream generation: {e}")
                # This part may not be sent if the connection is already closed.
                yield f"Error generating answer: {str(e)}"

        return StreamingResponse(event_generator(), media_type="text/plain")
    except HTTPException:
        # Deliberate HTTP errors pass through unchanged instead of being
        # re-wrapped by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Error processing streaming question: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing streaming question: {str(e)}")
# NOTE(review): no route decorator is visible in this view — confirm how this
# handler is registered on `app`.
async def upload_document(file: UploadFile = File(...)):
    """
    Upload a document to the RAG system.

    The upload is validated by extension (``.pdf``, ``.docx``, ``.txt``),
    written to a temporary file, handed to ``rag_system.upload_document``
    and the temporary file is removed afterwards in every case.

    Args:
        file: The uploaded file.

    Returns:
        A dict with ``status``, ``message``, ``filename`` and ``size``
        (number of bytes read from the upload).

    Raises:
        HTTPException: 400 for a missing or unsupported file extension;
            500 when the RAG system is not initialized, when processing
            reports failure, or on any unexpected error.
    """
    try:
        if not rag_system:
            raise HTTPException(status_code=500, detail="RAG system not initialized")

        # Validate file type. A missing filename is treated as "no extension"
        # so the client gets a 400 instead of an internal TypeError.
        allowed_extensions = ['.pdf', '.docx', '.txt']
        file_extension = os.path.splitext(file.filename or "")[1].lower()
        if file_extension not in allowed_extensions:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type. Allowed types: {', '.join(allowed_extensions)}"
            )

        # Read the upload before creating anything on disk.
        content = await file.read()

        temp_file_path = None
        try:
            # delete=False keeps the file after the `with` closes it, so it
            # can be processed by path; cleanup happens in `finally`.
            with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
                temp_file_path = temp_file.name
                temp_file.write(content)

            # Process the uploaded file
            logger.info(f"Processing uploaded file: {file.filename}")
            success = await rag_system.upload_document(temp_file_path)
            if not success:
                raise HTTPException(status_code=500, detail="Failed to process uploaded file")
            return {
                "status": "success",
                "message": f"File '{file.filename}' uploaded and processed successfully",
                "filename": file.filename,
                "size": len(content)
            }
        finally:
            # Clean up temporary file, also when writing or processing failed.
            if temp_file_path and os.path.exists(temp_file_path):
                os.unlink(temp_file_path)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error uploading document: {e}")
        raise HTTPException(status_code=500, detail=f"Error uploading document: {str(e)}")
# NOTE(review): no route decorator is visible in this view — confirm how this
# handler is registered on `app`.
async def search_documents(request: SearchRequest):
    """
    Search for similar documents.

    Delegates to ``rag_system.vector_store.search_similar`` with the request's
    ``query`` and ``limit``.

    Args:
        request: Search payload (``query`` text and optional ``limit``).

    Returns:
        A dict with ``status``, the raw ``results`` and their ``count``.

    Raises:
        HTTPException: 500 when the RAG system is not initialized or the
            search fails.
    """
    try:
        if not rag_system:
            raise HTTPException(status_code=500, detail="RAG system not initialized")
        results = await rag_system.vector_store.search_similar(request.query, limit=request.limit)
        return {
            "status": "success",
            "results": results,
            "count": len(results)
        }
    except HTTPException:
        # Keep deliberate HTTP errors intact rather than re-wrapping them.
        raise
    except Exception as e:
        logger.error(f"Error searching documents: {e}")
        raise HTTPException(status_code=500, detail=f"Error searching documents: {str(e)}")
# NOTE(review): no route decorator is visible in this view — confirm how this
# handler is registered on `app`.
async def get_status():
    """
    Report whether the RAG system is available.

    Returns a ``StatusResponse`` that is "healthy" when the module-level
    ``rag_system`` is set and "unhealthy" otherwise, always with version
    "2.0.0". Unexpected errors are logged and surfaced as a 500.
    """
    try:
        if rag_system:
            state, detail = "healthy", "RAG system is operational"
        else:
            state, detail = "unhealthy", "RAG system not initialized"
        return StatusResponse(status=state, message=detail, version="2.0.0")
    except Exception as err:
        logger.error(f"Error getting status: {err}")
        raise HTTPException(status_code=500, detail=f"Error getting status: {str(err)}")
# NOTE(review): no route decorator is visible in this view — confirm how this
# handler is registered on `app`.
async def get_collection_info():
    """
    Get information about the vector collection.

    Returns:
        A dict with ``status`` and ``collection_info``, the latter being
        whatever ``rag_system.vector_store.get_collection_info()`` returns.

    Raises:
        HTTPException: 500 when the RAG system is not initialized or the
            lookup fails.
    """
    try:
        if not rag_system:
            raise HTTPException(status_code=500, detail="RAG system not initialized")
        info = await rag_system.vector_store.get_collection_info()
        return {
            "status": "success",
            "collection_info": info
        }
    except HTTPException:
        # Keep deliberate HTTP errors intact rather than re-wrapping them.
        raise
    except Exception as e:
        logger.error(f"Error getting collection info: {e}")
        raise HTTPException(status_code=500, detail=f"Error getting collection info: {str(e)}")
# Location of the compiled React frontend, relative to the working directory.
frontend_build_path = "../frontend/build"

# The frontend is only wired up when a build is present.
if os.path.exists(frontend_build_path):
    app.mount("/static", StaticFiles(directory=f"{frontend_build_path}/static"), name="static")

    # NOTE(review): original indentation and the route decorator are not
    # visible in this view; this handler is assumed to live inside the `if`
    # above — confirm, and confirm how it is registered on `app`.
    async def serve_react_app(request: Request, full_path: str):
        """
        Serve React app for all non-API routes.

        Args:
            request: The incoming request (unused here).
            full_path: Requested path relative to the site root.

        Returns:
            A ``FileResponse`` for the requested build file when the path
            contains a dot, otherwise for ``index.html``.

        Raises:
            HTTPException: 404 for ``api/`` paths, for files that do not
                exist, and for paths that resolve outside the build
                directory.
        """
        # If it's an API route, let FastAPI handle it
        if full_path.startswith("api/"):
            raise HTTPException(status_code=404, detail="API endpoint not found")

        # For static files (images, etc.)
        if "." in full_path:
            # `full_path` is client-controlled: normalise it and refuse
            # anything that escapes the build directory (e.g. via "..").
            build_root = os.path.abspath(frontend_build_path)
            candidate = os.path.abspath(f"{frontend_build_path}/{full_path}")
            inside_build = candidate.startswith(build_root + os.sep)
            # isfile (not exists) so a directory is never handed to FileResponse.
            if inside_build and os.path.isfile(candidate):
                return FileResponse(candidate)
            raise HTTPException(status_code=404, detail="File not found")

        # For all other routes, serve index.html (React Router will handle it)
        return FileResponse(f"{frontend_build_path}/index.html")
# Error handlers
# NOTE(review): no exception-handler decorator is visible in this view —
# confirm how this handler is registered on `app`.
async def not_found_handler(request: Request, exc: HTTPException):
    """Answer a not-found error with JSON for API paths or the React shell otherwise.

    Paths under ``/api/`` get a JSON 404. Any other path gets the built
    ``index.html`` when it exists, or a JSON 404 explaining that the frontend
    has not been built.
    """
    if request.url.path.startswith("/api/"):
        return JSONResponse(
            status_code=404,
            content={"detail": "API endpoint not found"},
        )

    # For non-API routes, serve React app
    index_html = f"{frontend_build_path}/index.html"
    if not os.path.exists(index_html):
        return JSONResponse(
            status_code=404,
            content={"detail": "React app not built. Run 'npm run build' in the frontend directory."},
        )
    return FileResponse(f"{frontend_build_path}/index.html")
# NOTE(review): no exception-handler decorator is visible in this view —
# confirm how this handler is registered on `app`.
async def internal_error_handler(request: Request, exc: Exception):
    """Log the exception and reply with a generic JSON 500 body.

    The exception text goes to the log only; the client always receives the
    fixed detail "Internal server error".
    """
    logger.error(f"Internal server error: {exc}")
    body = {"detail": "Internal server error"}
    return JSONResponse(status_code=500, content=body)
if __name__ == "__main__":
    # Direct execution: start uvicorn on all interfaces with auto-reload.
    # The port is read from the PORT environment variable, 8000 when unset.
    server_options = {
        "host": "0.0.0.0",
        "port": int(os.getenv("PORT", 8000)),
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("backend_api:app", **server_options)