import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import gradio as gr

# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base model and adapter paths
base_model_name = "microsoft/phi-2"  # Pull from HF Hub directly
adapter_path = "Shriti09/Microsoft-Phi-QLora"  # QLoRA adapter (HF Hub repo id, or a local folder in the Space repo)

print("πŸ”§ Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
)
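# bfloat16 is best supported on Ampere-or-newer GPUs; torch.float16 is a
# common fallback on older CUDA cards.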

print("πŸ”§ Loading LoRA adapter...")
adapter_model = PeftModel.from_pretrained(base_model, adapter_path)

print("πŸ”— Merging adapter into base model...")
merged_model = adapter_model.merge_and_unload()
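# merge_and_unload() folds the LoRA weights into the base model, so inference
# below runs on a plain transformers model with no PEFT wrapper overhead.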
merged_model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
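# phi-2's tokenizer has no pad token by default, so generate() below reuses
# eos_token_id as the pad token.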
print("βœ… Model ready for inference!")

# Chat function with history
def chat_fn(message, history):
    # Combine conversation history into one prompt
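    # The "User:/AI:" template below is an assumption; ideally it matches the
    # prompt format used during QLoRA fine-tuning.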
    full_prompt = ""
    for user_msg, bot_msg in history:
        full_prompt += f"User: {user_msg}\nAI: {bot_msg}\n"
    full_prompt += f"User: {message}\nAI:"

    # Tokenize inputs
    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = merged_model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode and keep only the AI's latest reply (the prompt is echoed back)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response.split("AI:")[-1].strip()
    # Trim any follow-up "User:" turn the model may hallucinate
    response = response.split("User:")[0].strip()

    # Append to history
    history.append((message, response))
    return history, history

# Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1>🧠 Phi-2 QLoRA Chatbot</h1>")

    chatbot = gr.Chatbot()
    message = gr.Textbox(label="Your message:")
    clear = gr.Button("Clear chat")

    state = gr.State([])

    message.submit(chat_fn, [message, state], [chatbot, state])
    clear.click(lambda: [], None, chatbot)
    clear.click(lambda: [], None, state)

# Run with queue for multiple users
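# Note: concurrency_count applies to Gradio 3.x; Gradio 4+ uses
# default_concurrency_limit in queue() instead.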
demo.queue(concurrency_count=2).launch()