import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import gradio as gr
# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
# Base model and adapter paths (updated for Hugging Face repo)
base_model_name = "microsoft/phi-2" # Pull from HF Hub directly
adapter_path = "Shriti09/Microsoft-Phi-QLora" # Update with your Hugging Face repo path
print("π§ Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name,
device_map="auto",
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
)
print("π§ Loading LoRA adapter...")
adapter_model = PeftModel.from_pretrained(base_model, adapter_path)
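# merge_and_unload() folds the LoRA weight deltas into the base weights and strips
# the PEFT wrappers, so inference below runs on a plain transformers model.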
print("π Merging adapter into base model...")
merged_model = adapter_model.merge_and_unload()
merged_model.eval()
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
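# Phi-2's tokenizer has no dedicated pad token (an assumption worth re-checking for
# other base models); fall back to the EOS token so padded generation is well-defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token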
print("β
Model ready for inference!")
# Chat function with history
def chat_fn(message, history):
    # Combine the conversation history into one prompt
    full_prompt = ""
    for user_msg, bot_msg in history:
        full_prompt += f"User: {user_msg}\nAI: {bot_msg}\n"
    full_prompt += f"User: {message}\nAI:"

    # Tokenize the prompt and move it to the model's device
    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = merged_model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode and keep only the AI's latest response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response.split("AI:")[-1].strip()

    # Keep the raw (user, bot) pairs in the session state
    history.append((message, response))

    # gr.Chatbot(type="messages") expects role/content dicts, so convert for display
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    return messages, history
# Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1>🧠 Phi-2 QLoRA Chatbot</h1>")
    # 'type="messages"' makes gr.Chatbot expect role/content dicts
    chatbot = gr.Chatbot(type="messages")
    message = gr.Textbox(label="Your message:")
    clear = gr.Button("Clear chat")
    state = gr.State([])

    message.submit(chat_fn, [message, state], [chatbot, state])
    clear.click(lambda: [], None, chatbot)
    clear.click(lambda: [], None, state)
# Launch with queuing enabled (newer Gradio versions no longer accept the 'concurrency_count' argument)
demo.queue().launch()
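# To try this Space locally (a sketch; assumes these dependencies match requirements.txt):
#   pip install torch transformers peft gradio
#   python app.py
# Gradio serves the chat UI at http://127.0.0.1:7860 by default.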