sabahat-shakeel committed on
Commit
b81e260
·
verified ·
1 Parent(s): 9b0d5d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -23
app.py CHANGED
@@ -47,7 +47,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
47
  # Load the model and tokenizer
48
  @st.cache_resource
49
  def load_model_and_tokenizer():
50
- model_name = "EleutherAI/gpt-neo-2.7B" # Model suited for longer responses
51
  tokenizer = AutoTokenizer.from_pretrained(model_name)
52
  model = AutoModelForCausalLM.from_pretrained(model_name)
53
  return tokenizer, model
@@ -55,43 +55,46 @@ def load_model_and_tokenizer():
55
  tokenizer, model = load_model_and_tokenizer()
56
 
57
  # Streamlit App
58
- st.title("General Chatbot with Detailed Responses")
59
  st.write("A chatbot powered by an open-source model from Hugging Face.")
60
 
61
- # Initialize the conversation
62
  if "conversation_history" not in st.session_state:
63
  st.session_state["conversation_history"] = []
64
 
65
- # Input for user query
66
  user_input = st.text_input("You:", placeholder="Ask me anything...", key="user_input")
67
 
68
- # Slider for response length
69
- max_length = st.slider("Set maximum response length:", min_value=100, max_value=1000, step=50, value=300)
70
-
71
  if st.button("Send") and user_input:
72
- # Add user query to the conversation
73
  st.session_state["conversation_history"].append({"role": "user", "content": user_input})
74
-
75
- # Build prompt for the model
76
- prompt = f"{user_input} Please explain in detail."
77
- input_ids = tokenizer.encode(prompt + tokenizer.eos_token, return_tensors="pt")
78
-
79
- # Generate response
 
 
 
 
 
 
80
  chat_history_ids = model.generate(
81
  input_ids,
82
- max_length=max_length,
83
- temperature=0.9,
84
- top_p=0.9,
85
- repetition_penalty=1.2,
86
- pad_token_id=tokenizer.eos_token_id,
87
- early_stopping=False
88
  )
89
- response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
90
 
91
- # Add bot response to the conversation
 
92
  st.session_state["conversation_history"].append({"role": "assistant", "content": response})
93
 
94
- # Display conversation history
95
  for message in st.session_state["conversation_history"]:
96
  if message["role"] == "user":
97
  st.write(f"**You:** {message['content']}")
 
47
  # Load the model and tokenizer
48
  @st.cache_resource
49
  def load_model_and_tokenizer():
50
+ model_name = "microsoft/DialoGPT-medium" # Replace with your chosen model
51
  tokenizer = AutoTokenizer.from_pretrained(model_name)
52
  model = AutoModelForCausalLM.from_pretrained(model_name)
53
  return tokenizer, model
 
55
  tokenizer, model = load_model_and_tokenizer()
56
 
57
  # Streamlit App
58
+ st.title("General Chatbot")
59
  st.write("A chatbot powered by an open-source model from Hugging Face.")
60
 
61
+ # Initialize the conversation history
62
  if "conversation_history" not in st.session_state:
63
  st.session_state["conversation_history"] = []
64
 
65
+ # Input box for user query
66
  user_input = st.text_input("You:", placeholder="Ask me anything...", key="user_input")
67
 
 
 
 
68
  if st.button("Send") and user_input:
69
+ # Append user input to history
70
  st.session_state["conversation_history"].append({"role": "user", "content": user_input})
71
+
72
+ # Prepare the input for the model
73
+ conversation_context = ""
74
+ for message in st.session_state["conversation_history"]:
75
+ if message["role"] == "user":
76
+ conversation_context += f"User: {message['content']}\n"
77
+ elif message["role"] == "assistant":
78
+ conversation_context += f"Bot: {message['content']}\n"
79
+
80
+ input_ids = tokenizer.encode(conversation_context + "Bot:", return_tensors="pt")
81
+
82
+ # Generate the response with adjusted parameters
83
  chat_history_ids = model.generate(
84
  input_ids,
85
+ max_length=500, # Increase maximum length for longer responses
86
+ num_return_sequences=1,
87
+ temperature=0.7, # Adjust for creativity (lower is more focused, higher is more diverse)
88
+ top_p=0.9, # Use nucleus sampling for diversity
89
+ top_k=50, # Limit to top-k tokens for more controlled output
90
+ pad_token_id=tokenizer.eos_token_id
91
  )
 
92
 
93
+ # Decode the response and add it to history
94
+ response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
95
  st.session_state["conversation_history"].append({"role": "assistant", "content": response})
96
 
97
+ # Display the conversation
98
  for message in st.session_state["conversation_history"]:
99
  if message["role"] == "user":
100
  st.write(f"**You:** {message['content']}")