"""FastAPI service exposing an Orator chat session over HTTP.

Endpoints:
    POST /query/         -- run a query and return the full answer as JSON.
    POST /query/stream/  -- run a query and stream intermediate events as text.
    GET  /               -- static welcome message.
"""

import asyncio
import os

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain.chat_models import init_chat_model
from orator import DocumentDatabase, Session, SQLDatabase
from pydantic import BaseModel

# Initialize FastAPI app
app = FastAPI(title="Orator Chat API")

# Comma-separated list of allowed CORS origins, e.g.
# CORS_ALLOW_ORIGINS="https://app.example.com,https://admin.example.com".
# Defaults to "*" (any origin) so local development keeps working. Set explicit
# origins in production: a wildcard combined with allow_credentials=True lets
# any site make credentialed requests to this API.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods (GET, POST, etc.)
    allow_headers=["*"],  # Allow all headers
)

# Initialize LLM and databases (file paths are relative to the working directory).
llm = init_chat_model("o3-mini", model_provider="openai")
chinook_db = SQLDatabase.from_uri("sqlite:///data/Chinook.db")
pricegram_db = DocumentDatabase("data/pricegram.json", top_k=10)

# One module-level session shared by every request.
# NOTE(review): the handlers below run session calls in worker threads so the
# event loop is not blocked; confirm Session is safe to use from several
# threads at once, or guard these calls with a lock.
session = Session(llm=llm, datasources=[chinook_db, pricegram_db])


class QueryRequest(BaseModel):
    """Request body shared by the query endpoints."""

    # Natural-language question to answer.
    query: str
    # Forwarded to Session.invoke as `datasource`; presumably an index into the
    # session's datasources (0 = Chinook, 1 = pricegram) -- TODO confirm.
    source: int


@app.post("/query/")
async def get_response(request: QueryRequest):
    """Answer a query and return ``{"response": <answer>}``.

    Raises:
        HTTPException: status 500 carrying the error message if the session
            call (or anything else in the handler) fails.
    """
    try:
        print("Got Request:", request)
        # Session.invoke is synchronous; run it in a worker thread so a slow
        # LLM/database call does not freeze the event loop for other requests.
        answer, _logs = await asyncio.to_thread(
            session.invoke, request.query, datasource=request.source
        )
        response = {"response": answer}
        print("Sending Response:", response)
        return response
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/query/stream/")
async def stream_response(request: QueryRequest):
    """Stream one ``"<name>: <latest message text>"`` line per event of a query.

    Errors raised while streaming are reported in-band as a final
    ``"Error: ..."`` line instead of an HTTP error status.
    """
    # NOTE(review): unlike /query/, request.source is not forwarded here;
    # confirm whether Session.stream accepts a datasource argument.

    def event_generator():
        # A plain (synchronous) generator: StreamingResponse iterates it in a
        # threadpool, so the blocking session.stream() iteration never runs on
        # the event loop.
        try:
            for event in session.stream(request.query):
                for person, quote in event.items():
                    yield f"{person}: {quote['messages'][-1].text}\n"
        except Exception as e:
            yield f"Error: {str(e)}"

    return StreamingResponse(event_generator(), media_type="text/plain")


@app.get("/")
async def root():
    """Return a static welcome message."""
    return {"message": "Welcome to the Orator Chat API"}