research14 committed on
Commit
9272cb4
·
1 Parent(s): 2ba8da5
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import random
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
  # Load Vicuna 7B model and tokenizer
@@ -30,7 +31,6 @@ with gr.Blocks() as demo:
30
  prompt = gr.Textbox(show_label=False, placeholder="Enter prompt")
31
  send_button_POS = gr.Button("Send", scale=0)
32
  clear = gr.ClearButton([prompt, vicuna_chatbot1])
33
-
34
  with gr.Tab("Chunk"):
35
  gr.Markdown("Strategy 1 QA")
36
  with gr.Row():
@@ -52,12 +52,21 @@ with gr.Blocks() as demo:
52
  send_button_Chunk = gr.Button("Send", scale=0)
53
  clear = gr.ClearButton([prompt_chunk, vicuna_chatbot1_chunk])
54
 
 
 
 
 
 
 
 
55
  # Define the Gradio interface
56
  def chatbot_interface(prompt):
57
  vicuna_response = generate_response(model, tokenizer, prompt)
 
 
58
  return {"Vicuna-7B": vicuna_response}
59
 
60
- # Use the chatbot_interface function in the prompt.submit() method
61
  prompt.submit(chatbot_interface, [prompt, vicuna_chatbot1, vicuna_chatbot1_chunk])
62
 
63
  demo.launch()
 
1
  import gradio as gr
2
  import random
3
+ import time
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
  # Load Vicuna 7B model and tokenizer
 
31
  prompt = gr.Textbox(show_label=False, placeholder="Enter prompt")
32
  send_button_POS = gr.Button("Send", scale=0)
33
  clear = gr.ClearButton([prompt, vicuna_chatbot1])
 
34
  with gr.Tab("Chunk"):
35
  gr.Markdown("Strategy 1 QA")
36
  with gr.Row():
 
52
  send_button_Chunk = gr.Button("Send", scale=0)
53
  clear = gr.ClearButton([prompt_chunk, vicuna_chatbot1_chunk])
54
 
55
def generate_response(model, tokenizer, prompt):
    """Generate a chat reply for *prompt* with a causal language model.

    Args:
        model: a transformers causal LM exposing ``generate``.
        tokenizer: the matching tokenizer (callable on a string, provides
            ``decode`` and ``eos_token_id``).
        prompt: the user prompt string.

    Returns:
        The decoded model output as a plain string.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # pad_token_id is set explicitly: LLaMA-family models such as Vicuna
    # define no pad token, which otherwise makes generate() warn or fail.
    outputs = model.generate(
        **inputs, max_length=500, pad_token_id=tokenizer.eos_token_id
    )
    # skip_special_tokens=True — without it the reply shown in the chat UI
    # contains marker tokens such as "<s>"/"</s>".
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
61
+
62
def chatbot_interface(prompt, *_extra_inputs):
    """Gradio callback: answer *prompt* with the Vicuna model.

    Extra positional values are accepted (and ignored) because this callback
    is registered with three input components
    (``[prompt, vicuna_chatbot1, vicuna_chatbot1_chunk]``); Gradio passes one
    positional argument per component, so the previous single-parameter
    signature raised ``TypeError`` on every submit.

    Returns:
        A dict mapping model name to its generated reply.
    """
    vicuna_response = generate_response(model, tokenizer, prompt)
    # llama_response = generate_response(llama_model, llama_tokenizer, prompt)
    return {"Vicuna-7B": vicuna_response}
68
 
69
+ # Replace the old respond function with the new general function for Vicuna
70
  prompt.submit(chatbot_interface, [prompt, vicuna_chatbot1, vicuna_chatbot1_chunk])
71
 
72
  demo.launch()