# consent_project/app.py
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
import uvicorn
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
app = FastAPI()
# Load the model only if CUDA is available
if torch.cuda.is_available():
    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
else:
    model = None
    tokenizer = None
MAX_INPUT_TOKEN_LENGTH = 4096

def generate_response(message: str, history: list) -> str:
    # Fail gracefully when no GPU is available and the model was not loaded.
    if model is None or tokenizer is None:
        return "Model unavailable: a CUDA-capable GPU is required to load it."
    conversation = history + [{"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    # Keep only the most recent tokens if the prompt exceeds the context budget.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)
    # Stream tokens from a background generation thread and collect them into one string.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": 1024,
        "do_sample": True,
        "top_p": 0.9,
        "top_k": 50,
        "temperature": 0.6,
        "num_beams": 1,
        "repetition_penalty": 1.2,
    }
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    response_text = ""
    for text in streamer:
        response_text += text
    return response_text
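
# Illustrative (hypothetical) direct call, assuming the model was loaded; the
# history is a list of {"role": ..., "content": ...} dicts, as built above:
#   reply = generate_response(
#       "What is informed consent?",
#       history=[{"role": "user", "content": "Hi"},
#                {"role": "assistant", "content": "Hello! How can I help?"}],
#   )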

@app.post("/chat")
async def chat_endpoint(request: Request):
    data = await request.json()
    message = data.get("message", "")
    # Use an empty history to keep things simple
    response_text = generate_response(message, history=[])
    return JSONResponse({"response": response_text})
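
# Example request (hypothetical values), assuming the server listens on port 8000:
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello"}'
# The response has the shape {"response": "<generated text>"}.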

@app.get("/", response_class=HTMLResponse)
async def root():
    with open("index.html", "r", encoding="utf-8") as f:
        html_content = f.read()
    return HTMLResponse(content=html_content, status_code=200)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)