# Charm_15 / web_api.py
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
import aiofiles
import json
from typing import List, Optional
from datetime import datetime
# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI()
# Enable CORS (Cross-Origin Resource Sharing)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve static files (HTML/CSS/JavaScript)
app.mount("/static", StaticFiles(directory="static"), name="static")
# Load AI model and tokenizer
MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"  # Replace with your AI model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
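# NOTE: the model is loaded eagerly at import time. Mixtral-8x7B is a large
# mixture-of-experts model (roughly 90 GB of weights in float16), so on most
# machines you will want to shard it across devices or swap in a smaller
# checkpoint. A minimal sketch, assuming the `accelerate` package is installed:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_NAME, torch_dtype=torch.float16, device_map="auto"
#   )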
# In-memory storage for search history
search_history = []
# Pydantic models
class InferenceRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 100


class InferenceResponse(BaseModel):
    generated_text: str
    timestamp: str


class SearchHistoryResponse(BaseModel):
    history: List[InferenceResponse]
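# Example request body for POST /inference (fields from InferenceRequest):
#   {"prompt": "Hello, world", "max_length": 100}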
# API Endpoints
@app.post("/inference")
async def run_inference(request: InferenceRequest):
"""Run inference using the AI model."""
try:
inputs = tokenizer(request.prompt, return_tensors="pt")
outputs = model.generate(inputs.input_ids, max_length=request.max_length)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Log the search
search_entry = InferenceResponse(
generated_text=generated_text,
timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
)
search_history.append(search_entry)
logger.info(f"Inference completed for prompt: {request.prompt}")
return search_entry
except Exception as e:
logger.error(f"Error during inference: {e}")
raise HTTPException(status_code=500, detail="Failed to run inference.")
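# Example call (hypothetical host/port, assuming the app is served by uvicorn):
#   curl -X POST http://localhost:8000/inference \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "max_length": 50}'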
@app.get("/search-history")
async def get_search_history():
"""Get the history of all searches."""
return SearchHistoryResponse(history=search_history)
# WebSocket for real-time interaction
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            request = json.loads(data)
            prompt = request.get("prompt")
            max_length = request.get("max_length", 100)
            if not prompt:
                await websocket.send_text(json.dumps({"error": "Prompt is required."}))
                continue
            # Run inference in a worker thread so the event loop stays responsive
            inputs = tokenizer(prompt, return_tensors="pt")
            outputs = await run_in_threadpool(
                model.generate, **inputs, max_length=max_length
            )
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Send response
            response = {
                "generated_text": generated_text,
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
            await websocket.send_text(json.dumps(response))
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected.")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        await websocket.send_text(json.dumps({"error": str(e)}))
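# Messages on /ws are JSON with the same shape as InferenceRequest. A minimal
# client sketch (assumes the third-party `websockets` package and a
# hypothetical localhost:8000 deployment):
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:8000/ws") as ws:
#           await ws.send(json.dumps({"prompt": "Hello", "max_length": 50}))
#           print(json.loads(await ws.recv()))
#
#   asyncio.run(main())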
# Serve frontend
@app.get("/")
async def serve_frontend():
"""Serve the frontend HTML file."""
async with aiofiles.open("static/index.html", mode="r") as file:
return await file.read()
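# To run the API locally (hypothetical host/port), install the dependencies
# (fastapi, uvicorn, aiofiles, transformers, torch) and start uvicorn:
#   uvicorn web_api:app --host 0.0.0.0 --port 8000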