# app/main.py - FastAPI LoRA chatbot (Mistral-7B-Instruct + PEFT adapter)
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch
app = FastAPI()
# ✅ Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tokenizer.pad_token = tokenizer.eos_token  # tokenizer ships without a pad token; reuse EOS for padding
# ✅ Setup quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
# ✅ Load base model
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    device_map="auto",
    quantization_config=bnb_config,
)
# ✅ Load LoRA adapter (ensure it's downloaded)
ADAPTER_DIR = "./adapter/version 1"
model = PeftModel.from_pretrained(model, ADAPTER_DIR)
model.eval()
# ✅ Build prompt from messages
def build_prompt(messages):
    prompt = ""
    for msg in messages:
        if msg["role"] == "user":
            prompt += f"### User:\n{msg['content']}\n"
        elif msg["role"] == "assistant":
            prompt += f"### Assistant:\n{msg['content']}\n"
    prompt += "### Assistant:\n"
    return prompt
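# For illustration, build_prompt([{"role": "user", "content": "Hi"}]) yields:
#   ### User:
#   Hi
#   ### Assistant:
# The trailing "### Assistant:" header cues the model to produce the next reply.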
# ✅ Input format
class ChatRequest(BaseModel):
    messages: list  # list of {"role": "user" | "assistant", "content": "..."}
@app.post("/chat")
async def chat(req: ChatRequest):
prompt = build_prompt(req.messages)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(
**inputs,
max_new_tokens=256,
do_sample=True,
temperature=0.7,
top_p=0.95,
eos_token_id=tokenizer.eos_token_id,
)
response = tokenizer.decode(output[0], skip_special_tokens=True)
reply = response.split("### Assistant:")[-1].strip()
return {"response": reply}