Shriti09 committed
Commit 89ef257 · verified · 1 Parent(s): 5010915

Update app.py

Files changed (1)
  1. app.py +18 -33
app.py CHANGED
@@ -2,7 +2,6 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import PeftModel
 import gradio as gr
-import os
 
 # Use GPU if available
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -12,16 +11,18 @@ base_model_name = "microsoft/phi-2" # Pull from HF Hub directly
 adapter_path = "Shriti09/Microsoft-Phi-QLora" # Update with your Hugging Face repo path
 
 print("🔧 Loading base model...")
-# Using the Accelerator to load the model and dispatch to the correct devices
+# Load the base model
 base_model = AutoModelForCausalLM.from_pretrained(
     base_model_name,
     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
 )
 
 print("🔧 Loading LoRA adapter...")
+# Load the LoRA adapter
 adapter_model = PeftModel.from_pretrained(base_model, adapter_path)
 
 print("🔗 Merging adapter into base model...")
+# Merge adapter into the base model
 merged_model = adapter_model.merge_and_unload()
 merged_model.eval()
 
@@ -29,16 +30,10 @@ merged_model.eval()
 tokenizer = AutoTokenizer.from_pretrained(base_model_name)
 print("✅ Model ready for inference!")
 
-# Chat function with history
-def chat_fn(message, history):
-    # Convert history to the required format for gr.Chatbot (list of dictionaries with role and content)
-    full_prompt = ""
-    for user_msg, bot_msg in history:
-        full_prompt += f"User: {user_msg}\nAI: {bot_msg}\n"
-    full_prompt += f"User: {message}\nAI:"
-
-    # Tokenize inputs
-    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
+# Text generation function
+def generate_text(prompt):
+    # Tokenize the input
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
     with torch.no_grad():
         outputs = merged_model.generate(
@@ -50,30 +45,20 @@ def chat_fn(message, history):
             pad_token_id=tokenizer.eos_token_id
         )
 
-    # Decode and return only the AI's latest response
+    # Decode and return the generated response
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    response = response.split("AI:")[-1].strip()
-
-    # Append to history in the correct format for gr.Chatbot (list of dictionaries)
-    history.append({"role": "user", "content": message})
-    history.append({"role": "assistant", "content": response})
-
-    return history, history
+    return response
 
 # Gradio UI
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("<h1>🧠 Phi-2 QLoRA Chatbot</h1>")
-
-    # Use 'type' parameter to specify message format for gr.Chatbot()
-    chatbot = gr.Chatbot(type="messages") # Use 'messages' type for structured messages
-    message = gr.Textbox(label="Your message:")
-    clear = gr.Button("Clear chat")
-
-    state = gr.State([])
+    gr.Markdown("<h1>🧠 Phi-2 QLoRA Text Generator</h1>")
+
+    # Textbox for user input and a button to generate text
+    prompt = gr.Textbox(label="Enter your prompt:", lines=2)
+    output = gr.Textbox(label="Generated text:", lines=5)
 
-    message.submit(chat_fn, [message, state], [chatbot, state])
-    clear.click(lambda: [], None, chatbot)
-    clear.click(lambda: [], None, state)
+    # Generate text when the button is clicked
+    prompt.submit(generate_text, prompt, output)
 
-# Run the app without the 'concurrency_count' argument and share the app publicly
-demo.queue().launch(share=True)
+# Launch the app
+demo.launch(share=True)
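
For reference, below is a sketch of the full app.py as it stands after this commit, assembled from the hunks above: the multi-turn chat interface (chat_fn, gr.Chatbot, gr.State) is replaced by a single-turn generate_text function wired to two Textboxes, and demo.queue().launch(share=True) becomes a plain demo.launch(share=True). Lines that fall outside the diff context are assumptions and are flagged in comments; in particular the generate() settings (max_new_tokens, do_sample, temperature) and the merged_model.to(device) call are illustrative placeholders, not the committed values.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import gradio as gr

# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"

base_model_name = "microsoft/phi-2" # Pull from HF Hub directly
adapter_path = "Shriti09/Microsoft-Phi-QLora" # Update with your Hugging Face repo path

print("🔧 Loading base model...")
# Load the base model
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
)

print("🔧 Loading LoRA adapter...")
# Load the LoRA adapter
adapter_model = PeftModel.from_pretrained(base_model, adapter_path)

print("🔗 Merging adapter into base model...")
# Merge adapter into the base model
merged_model = adapter_model.merge_and_unload()
merged_model.eval()
# Not shown in the diff (assumption): keep the model on the same device as the inputs
merged_model.to(device)

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
print("✅ Model ready for inference!")

# Text generation function
def generate_text(prompt):
    # Tokenize the input
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = merged_model.generate(
            **inputs,
            # The settings below are placeholders; the committed values sit
            # outside the diff context and are not shown in this commit.
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode and return the generated response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

# Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1>🧠 Phi-2 QLoRA Text Generator</h1>")

    # Textbox for user input and a button to generate text
    prompt = gr.Textbox(label="Enter your prompt:", lines=2)
    output = gr.Textbox(label="Generated text:", lines=5)

    # Generate text when the button is clicked
    prompt.submit(generate_text, prompt, output)

# Launch the app
demo.launch(share=True)

Run locally with the dependencies the file already imports (torch, transformers, peft, gradio) via python app.py; pressing Enter in the prompt Textbox triggers generate_text and writes the decoded output into the second Textbox.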