from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
import aiofiles
import json
from typing import List, Optional
from datetime import datetime

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Enable CORS (Cross-Origin Resource Sharing)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve static files (HTML/CSS/JavaScript)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Load AI model and tokenizer
MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"  # Replace with your AI model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)

# In-memory storage for search history
search_history = []


# Pydantic models
class InferenceRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 100


class InferenceResponse(BaseModel):
    generated_text: str
    timestamp: str


class SearchHistoryResponse(BaseModel):
    history: List[InferenceResponse]


# API Endpoints
@app.post("/inference", response_model=InferenceResponse)
async def run_inference(request: InferenceRequest):
    """Run inference using the AI model."""
    try:
        inputs = tokenizer(request.prompt, return_tensors="pt")
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=request.max_length,
        )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Log the search
        search_entry = InferenceResponse(
            generated_text=generated_text,
            timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        )
        search_history.append(search_entry)

        logger.info(f"Inference completed for prompt: {request.prompt}")
        return search_entry
    except Exception as e:
        logger.error(f"Error during inference: {e}")
        raise HTTPException(status_code=500, detail="Failed to run inference.")


@app.get("/search-history", response_model=SearchHistoryResponse)
async def get_search_history():
    """Get the history of all searches."""
    return SearchHistoryResponse(history=search_history)


# WebSocket for real-time interaction
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            request = json.loads(data)
            prompt = request.get("prompt")
            max_length = request.get("max_length", 100)

            if not prompt:
                await websocket.send_text(json.dumps({"error": "Prompt is required."}))
                continue

            # Run inference
            inputs = tokenizer(prompt, return_tensors="pt")
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=max_length,
            )
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Send response
            response = {
                "generated_text": generated_text,
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
            await websocket.send_text(json.dumps(response))
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected.")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        await websocket.send_text(json.dumps({"error": str(e)}))


# Serve frontend
@app.get("/")
async def serve_frontend():
    """Serve the frontend HTML file."""
    # Wrap the file contents in HTMLResponse so the browser renders the page
    # instead of receiving a JSON-encoded string.
    async with aiofiles.open("static/index.html", mode="r") as file:
        return HTMLResponse(content=await file.read())
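

# A minimal sketch of how the app could be launched directly. It assumes uvicorn is
# installed and that this file is named main.py; neither is stated in the original
# script. Running `uvicorn main:app --reload` from the command line is equivalent.
if __name__ == "__main__":
    import uvicorn

    # Serve the FastAPI app on all interfaces, port 8000 (both values are assumptions).
    uvicorn.run(app, host="0.0.0.0", port=8000)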