File size: 2,352 Bytes
18aea39
 
 
 
6668ea3
0053216
 
6668ea3
 
18aea39
 
 
 
 
 
6668ea3
 
 
48b2ebf
6668ea3
0053216
6668ea3
 
0053216
18aea39
 
0053216
1e414fd
6668ea3
 
1e414fd
6668ea3
18aea39
1e414fd
6668ea3
1e414fd
18aea39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# app/main.py
import html
import os

import torch
from fastapi import FastAPI, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from app.download_adapter import download_latest_adapter

# === Step 1: Download Adapter ===
# Fetch the latest LoRA adapter files before the model is assembled below.
# NOTE(review): presumably writes into the local "adapter" directory used as
# ADAPTER_FOLDER — confirm against app/download_adapter.py.
download_latest_adapter()

# === Step 2: Load Model and Tokenizer ===
# All of this runs at import time, so the app only starts serving once the
# model is fully loaded.
BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
ADAPTER_FOLDER = "adapter"
# Optional Hugging Face access token; None means anonymous download.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

print("🚀 Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,      # half precision to cut memory use
    device_map="auto",              # let accelerate place layers on GPU/CPU
    token=HF_TOKEN,
    trust_remote_code=True
)

print("🔧 Applying LoRA adapter...")
# Wrap the base model with the downloaded LoRA weights.
model = PeftModel.from_pretrained(base_model, ADAPTER_FOLDER)

print("🧠 Loading tokenizer...")
# Tokenizer comes from the base model, not the adapter folder.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)

# Shared text-generation pipeline used by the POST endpoint below.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# === Step 3: FastAPI App ===
app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by the CORS spec (browsers won't send credentials to "*");
# acceptable for local testing only — lock down origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for testing
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/", response_class=HTMLResponse)
async def form():
    """Serve the landing page: a minimal HTML form for entering a prompt.

    The form posts back to "/" where the generate endpoint handles it.
    """
    page = """
    <html>
        <head><title>Qwen Chat</title></head>
        <body>
            <h2>Ask something:</h2>
            <form method="post">
                <textarea name="prompt" rows="4" cols="60"></textarea><br>
                <input type="submit" value="Generate">
            </form>
        </body>
    </html>
    """
    return page

@app.post("/", response_class=HTMLResponse)
async def generate(prompt: str = Form(...)):
    """Run the chat model on *prompt* and render the reply as an HTML page.

    Parameters:
        prompt: user-submitted text from the landing-page form.

    Returns:
        An HTML page echoing the prompt and the model's reply, with both
        values HTML-escaped (fixes a reflected-XSS hole: the original
        interpolated raw user/model text straight into the markup).
    """
    # Qwen2 ChatML-style template: system + user turns, then an open
    # assistant turn for the model to complete.
    full_prompt = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    output = pipe(full_prompt, max_new_tokens=256, do_sample=True, temperature=0.7)
    # The pipeline echoes the full prompt; keep only what follows the last
    # assistant marker.
    response = output[0]["generated_text"].split("<|im_start|>assistant\n")[-1].strip()

    # Escape untrusted text before embedding it in HTML.
    safe_prompt = html.escape(prompt)
    safe_response = html.escape(response)

    return f"""
    <html>
        <head><title>Qwen Chat</title></head>
        <body>
            <h2>Your Prompt:</h2>
            <p>{safe_prompt}</p>
            <h2>Response:</h2>
            <p>{safe_response}</p>
            <a href="/">Ask again</a>
        </body>
    </html>
    """