from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import asyncio
import torch
import logging
import aiofiles
import json
from typing import List, Optional
from datetime import datetime
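
# The service exposes the same text-generation model three ways: a one-shot
# REST endpoint (/inference), a WebSocket endpoint (/ws) for interactive use,
# and a static frontend served from ./static.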

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

app.mount("/static", StaticFiles(directory="static"), name="static")

# Mixtral 8x7B in half precision, moved to the GPU when one is available
# (assumes the host has enough accelerator memory for the model).
MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to(device)

# In-memory history of completed generations; cleared on restart.
search_history: List["InferenceResponse"] = []


class InferenceRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 100


class InferenceResponse(BaseModel):
    generated_text: str
    timestamp: str


class SearchHistoryResponse(BaseModel):
    history: List[InferenceResponse]

@app.post("/inference", response_model=InferenceResponse)
def run_inference(request: InferenceRequest):
    """Run inference using the AI model.

    Declared sync so FastAPI runs it in a worker thread, keeping the
    blocking generate() call off the event loop.
    """
    try:
        inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=request.max_length,
        )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        search_entry = InferenceResponse(
            generated_text=generated_text,
            timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        )
        search_history.append(search_entry)

        logger.info(f"Inference completed for prompt: {request.prompt}")
        return search_entry
    except Exception as e:
        logger.error(f"Error during inference: {e}")
        raise HTTPException(status_code=500, detail="Failed to run inference.")
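
# Example request (assuming the server is reachable on localhost:8000):
#   curl -X POST http://localhost:8000/inference \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "max_length": 50}'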


@app.get("/search-history", response_model=SearchHistoryResponse)
async def get_search_history():
    """Get the history of all searches."""
    return SearchHistoryResponse(history=search_history)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            try:
                request = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_text(json.dumps({"error": "Invalid JSON."}))
                continue
            prompt = request.get("prompt")
            max_length = request.get("max_length", 100)

            if not prompt:
                await websocket.send_text(json.dumps({"error": "Prompt is required."}))
                continue

            # Run the blocking generate() call in a worker thread so the
            # event loop stays responsive to other connections.
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            outputs = await asyncio.to_thread(
                model.generate,
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=max_length,
            )
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

            response = {
                "generated_text": generated_text,
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
            await websocket.send_text(json.dumps(response))
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected.")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        await websocket.send_text(json.dumps({"error": str(e)}))
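
# Clients send JSON messages shaped like {"prompt": "...", "max_length": 100}
# and receive {"generated_text": ..., "timestamp": ...} on success or
# {"error": ...} on failure.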


@app.get("/")
async def serve_frontend():
    """Serve the frontend HTML file."""
    async with aiofiles.open("static/index.html", mode="r") as file:
        # Wrap the contents in an HTMLResponse; returning the raw string would
        # be JSON-encoded by FastAPI and rendered as escaped text.
        return HTMLResponse(content=await file.read())
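

# A typical way to run the app locally (assumes uvicorn is installed):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)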