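"""FastAPI service that exposes a causal language model over REST and WebSocket.

Routes:
    POST /inference       -- run a single generation request
    GET  /search-history  -- list previous generations (in-memory only)
    WS   /ws              -- interactive generation over a WebSocket
    GET  /                -- serve the static frontend
"""
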
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
import aiofiles
import json
from typing import List, Optional
from datetime import datetime

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Enable CORS (Cross-Origin Resource Sharing)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve static files (HTML/CSS/JavaScript)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Load AI model and tokenizer once at startup (loading a checkpoint of this
# size can take minutes and requires substantial GPU/CPU memory)
MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"  # Replace with your AI model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)

# In-memory storage for search history
search_history = []

# Pydantic models
class InferenceRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 100

class InferenceResponse(BaseModel):
    generated_text: str
    timestamp: str

class SearchHistoryResponse(BaseModel):
    history: List[InferenceResponse]

# API Endpoints
@app.post("/inference", response_model=InferenceResponse)
async def run_inference(request: InferenceRequest):
    """Run inference using the AI model."""
    try:
        # NOTE: generate() is blocking; in production consider offloading it
        # to a thread pool so it does not stall the event loop.
        inputs = tokenizer(request.prompt, return_tensors="pt")
        outputs = model.generate(**inputs, max_length=request.max_length)
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Log the search
        search_entry = InferenceResponse(
            generated_text=generated_text,
            timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        )
        search_history.append(search_entry)

        logger.info(f"Inference completed for prompt: {request.prompt}")
        return search_entry
    except Exception as e:
        logger.error(f"Error during inference: {e}")
        raise HTTPException(status_code=500, detail="Failed to run inference.")
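
# Example request against the endpoint above (a sketch; assumes a local
# deployment on port 8000):
#   curl -X POST http://localhost:8000/inference \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Once upon a time", "max_length": 50}'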

@app.get("/search-history", response_model=SearchHistoryResponse)
async def get_search_history():
    """Get the history of all searches."""
    return SearchHistoryResponse(history=search_history)

# WebSocket for real-time interaction
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            request = json.loads(data)
            prompt = request.get("prompt")
            max_length = request.get("max_length", 100)

            if not prompt:
                await websocket.send_text(json.dumps({"error": "Prompt is required."}))
                continue

            # Run inference (same caveat as above: generate() is blocking)
            inputs = tokenizer(prompt, return_tensors="pt")
            outputs = model.generate(**inputs, max_length=max_length)
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Send response
            response = {
                "generated_text": generated_text,
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
            await websocket.send_text(json.dumps(response))

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected.")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        try:  # the socket may already be closed
            await websocket.send_text(json.dumps({"error": str(e)}))
        except Exception:
            pass
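
# Example client exchange for the endpoint above (a sketch; assumes the
# third-party `websockets` package and a local server on port 8000):
#   async with websockets.connect("ws://localhost:8000/ws") as ws:
#       await ws.send(json.dumps({"prompt": "Hello", "max_length": 50}))
#       print(await ws.recv())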

# Serve frontend
@app.get("/")
async def serve_frontend():
    """Serve the frontend HTML file."""
    async with aiofiles.open("static/index.html", mode="r") as file:
        # Return an HTMLResponse so the browser renders the page instead of
        # receiving a JSON-encoded string.
        return HTMLResponse(content=await file.read())
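
# Minimal entry-point sketch (assumes uvicorn is installed; host and port
# are illustrative defaults, adjust for your deployment):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)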