File size: 1,723 Bytes
0053216
 
 
 
1e414fd
 
48b2ebf
1e414fd
 
0053216
 
 
 
1e414fd
0053216
 
 
1e414fd
0053216
 
 
 
 
 
1e414fd
0053216
 
1e414fd
0053216
 
 
1e414fd
0053216
 
 
 
1e414fd
0053216
 
1e414fd
0053216
1e414fd
0053216
1e414fd
 
0053216
 
1e414fd
0053216
 
1e414fd
 
0053216
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from peft import PeftModel
import torch
import os
import gdown

app = FastAPI()

# Auto-download adapter from Google Drive (if not already present)
ADAPTER_DIR = "adapter"
ADAPTER_PATH = os.path.join(ADAPTER_DIR, "adapter_model.safetensors")
DRIVE_FILE_ID = "1wnuE5t_m4ojI7YqxXZ8lBdtDFoHJJ6_H"  # version 1 model
# Hub id of the base model the LoRA adapter is applied to. The tokenizer is
# loaded from the same id so the two can never drift apart.
BASE_MODEL_ID = "Qwen/Qwen2-0.5B-Instruct"

if not os.path.exists(ADAPTER_PATH):
    os.makedirs(ADAPTER_DIR, exist_ok=True)
    _downloaded = gdown.download(
        f"https://drive.google.com/uc?id={DRIVE_FILE_ID}", ADAPTER_PATH, quiet=False
    )
    # Some gdown versions signal failure (permission / quota problems) by
    # returning None rather than raising. Fail fast with a clear message
    # instead of letting the adapter load below fail on a missing file.
    if _downloaded is None or not os.path.isfile(ADAPTER_PATH):
        raise RuntimeError(
            f"Could not download LoRA adapter from Google Drive (file id {DRIVE_FILE_ID})."
        )

# NOTE(review): only the weights file is fetched above, but loading a PEFT
# adapter from a directory also needs its adapter_config.json. Presumably
# that file ships with this script in ADAPTER_DIR; confirm.

# Use half precision only on CUDA. On a CPU-only host, float16 kernels may be
# unsupported (older PyTorch) or much slower, so fall back to float32 there.
_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    device_map="auto",
    torch_dtype=_dtype,
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

# Load LoRA adapter on top of the base model and switch it to inference mode.
model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
model.eval()

def _generate_reply(full_prompt):
    """Sample a completion for *full_prompt* and return only the newly generated text.

    This call blocks until generation finishes, so async callers should run it
    in a worker thread.
    """
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    # torch.no_grad() is thread-local, so it must be entered here, in the
    # thread that actually runs generate(), not in the async handler.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
        )

    # For a causal LM, generate() returns the prompt tokens followed by the new
    # ones. Decode only the continuation. With skip_special_tokens=True the
    # <|im_start|> markers are removed, so splitting the full decoded text on
    # "<|im_start|>assistant\n" never matches and would echo the system and
    # user turns back to the client.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()


@app.post("/chat")
async def chat(request: Request):
    """Generate one assistant reply from the LoRA-adapted model.

    Expects a JSON object body with a ``"prompt"`` string field.

    Returns:
        ``{"response": <reply text>}`` on success.
        ``{"error": <message>}`` when the body is not valid JSON, is not a
        JSON object, or lacks a non-empty string ``prompt``.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError are ValueErrors
        return {"error": "Invalid JSON body."}

    # A JSON array, number or string body has no .get(); treat it as "no prompt".
    prompt = data.get("prompt") if isinstance(data, dict) else None

    if not prompt:
        return {"error": "No prompt provided."}
    if not isinstance(prompt, str):
        return {"error": "Prompt must be a string."}

    full_prompt = f"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

    # model.generate() is blocking; running it in a worker thread keeps the
    # event loop free to serve other requests while a reply is generated.
    response = await run_in_threadpool(_generate_reply, full_prompt)
    return {"response": response}