tstone87 committed
Commit 8dae174 · verified · 1 Parent(s): 9111270

Update app.py

Files changed (1)
  1. app.py +15 -32
app.py CHANGED
@@ -1,41 +1,24 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import pipeline
 import gradio as gr
-import torch
 
-# Load the model and tokenizer
-model_name = "microsoft/DialoGPT-small"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name)
+# Load the pipeline for text generation
+generator = pipeline("text-generation", model="microsoft/DialoGPT-small")
 
 # Function to generate a response
 def dialoGPT_response(user_input, history):
-    # Convert history to tensor if it's not None
-    if history:
-        history_tensor = torch.LongTensor(history)
-    else:
-        history_tensor = torch.LongTensor([])
-
-    # Encode the new user input, with the history
-    new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
-
-    # Append the new user input tokens to the chat history
-    bot_input_ids = torch.cat([history_tensor, new_user_input_ids], dim=-1)
-
-    # Generate a response
-    chat_history_ids = model.generate(
-        bot_input_ids,
-        max_length=1000,  # You might want to adjust this based on your needs
-        pad_token_id=tokenizer.eos_token_id,
-        no_repeat_ngram_size=3  # This prevents repeating phrases
-    )
-
-    # Decode the response, keeping only the new tokens
-    response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
+    # Build the chat-style message list; the pipeline handles the rest
+    conversation = [{"role": "user", "content": user_input}] if history is None else history + [{"role": "user", "content": user_input}]
 
-    # Update history with new input and response
-    new_history = chat_history_ids.tolist()[0]  # Convert tensor to list for Gradio state
+    # Generate a response using the pipeline, which manages pre- and post-processing
+    response = generator(conversation, return_full_text=False, max_length=1000)
 
-    return response, new_history
+    # Extract the generated assistant response
+    assistant_response = response[0]['generated_text']
+
+    # Append this response to the history
+    new_history = conversation + [{"role": "assistant", "content": assistant_response}]
+
+    return assistant_response, new_history
 
 # Gradio interface
 iface = gr.Interface(
@@ -50,7 +33,7 @@ iface = gr.Interface(
     ],
     title="DialoGPT Chat",
    description="Chat with DialoGPT-small model. Your conversation history is maintained.",
-    allow_flagging="never"  # Disabling flagging since this is a chat model
+    allow_flagging="never"
 )
 
 iface.launch()
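
The middle of the gr.Interface(...) call (its inputs/outputs arguments) is not shown in this diff, so the following is only a minimal sketch of how a two-output chat function like dialoGPT_response is commonly wired up with gr.State. echo_reply, the labels, and every Interface argument below are illustrative assumptions, not lines from the commit; the model call is stubbed out because passing role/content message lists to a text-generation pipeline depends on the model's tokenizer providing a chat template, which microsoft/DialoGPT-small may not include.

import gradio as gr

def echo_reply(user_input, history):
    # Stand-in for dialoGPT_response: same signature and the same
    # list-of-message-dicts history format as the new app.py
    history = (history or []) + [{"role": "user", "content": user_input}]
    reply = f"echo: {user_input}"  # a real app would call generator(...) here
    return reply, history + [{"role": "assistant", "content": reply}]

iface = gr.Interface(
    fn=echo_reply,
    # gr.State() starts as None, matching the `history is None` check in
    # dialoGPT_response; returning it as an output persists the conversation
    inputs=[gr.Textbox(label="Your message"), gr.State()],
    outputs=[gr.Textbox(label="Response"), gr.State()],
    title="DialoGPT Chat",
    allow_flagging="never",
)

iface.launch()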