prakhardoneria commited on
Commit
11b75fb
·
verified ·
1 Parent(s): 3738e2a

Implemented Fixes

Browse files
Files changed (1) hide show
  1. app.py +12 -11
app.py CHANGED
@@ -16,17 +16,17 @@ def respond(
16
  temperature,
17
  top_p,
18
  ):
19
- # Format the chat history for the DialoGPT model
20
- full_conversation = ""
21
  for user_msg, bot_msg in history:
22
  if user_msg:
23
- full_conversation += f"User: {user_msg}\n"
24
  if bot_msg:
25
- full_conversation += f"DialoGPT: {bot_msg}\n"
26
- full_conversation += f"User: {message}\nDialoGPT:"
27
 
28
  # Tokenize input and generate response
29
- inputs = tokenizer.encode(full_conversation, return_tensors="pt")
30
  outputs = model.generate(
31
  inputs,
32
  max_length=max_tokens,
@@ -38,9 +38,11 @@ def respond(
38
  return response
39
 
40
  # Gradio Chat Interface
41
- demo = gr.ChatInterface(
42
- respond,
43
- additional_inputs=[
 
 
44
  gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
45
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
46
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
@@ -52,9 +54,8 @@ demo = gr.ChatInterface(
52
  label="Top-p (nucleus sampling)",
53
  ),
54
  ],
 
55
  )
56
 
57
  if __name__ == "__main__":
58
- # Launch the Gradio app without 'enable_api'
59
  demo.launch()
60
-
 
16
  temperature,
17
  top_p,
18
  ):
19
+ # Format the chat history for the DialoGPT model using the 'messages' format
20
+ conversation = [{"role": "system", "content": system_message}]
21
  for user_msg, bot_msg in history:
22
  if user_msg:
23
+ conversation.append({"role": "user", "content": user_msg})
24
  if bot_msg:
25
+ conversation.append({"role": "assistant", "content": bot_msg})
26
+ conversation.append({"role": "user", "content": message})
27
 
28
  # Tokenize input and generate response
29
+ inputs = tokenizer.encode(" ".join([msg["content"] for msg in conversation]), return_tensors="pt")
30
  outputs = model.generate(
31
  inputs,
32
  max_length=max_tokens,
 
38
  return response
39
 
40
  # Gradio Chat Interface
41
+ demo = gr.Interface(
42
+ fn=respond,
43
+ inputs=[
44
+ gr.Textbox(label="Message"),
45
+ gr.State(), # For history
46
  gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
47
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
48
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
 
54
  label="Top-p (nucleus sampling)",
55
  ),
56
  ],
57
+ outputs="text",
58
  )
59
 
60
  if __name__ == "__main__":
 
61
  demo.launch()