sabahat-shakeel commited on
Commit
d3a021c
·
verified ·
1 Parent(s): 99abf93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -35
app.py CHANGED
@@ -47,54 +47,46 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
47
  # Load the model and tokenizer
48
  @st.cache_resource
49
  def load_model_and_tokenizer():
50
- model_name = "microsoft/DialoGPT-medium" # You can replace with any Hugging Face conversational model
51
  tokenizer = AutoTokenizer.from_pretrained(model_name)
52
  model = AutoModelForCausalLM.from_pretrained(model_name)
53
  return tokenizer, model
54
 
55
  tokenizer, model = load_model_and_tokenizer()
56
 
57
- # Streamlit App Title
58
- st.title("General Chatbot")
59
- st.markdown("This chatbot is powered by an open-source model from Hugging Face. Feel free to ask me anything!")
60
 
61
- # Initialize the session state for conversation history
62
- if "chat_history" not in st.session_state:
63
- st.session_state["chat_history"] = ""
64
 
65
- # User Input Section
66
- user_input = st.text_input("You:", placeholder="Type your message here...", key="user_input")
67
 
68
- if st.button("Send") and user_input:
69
- # Add user input to the conversation history
70
- st.session_state["chat_history"] += f"User: {user_input}\n"
71
-
72
- # Tokenize the input with conversation history
73
- input_ids = tokenizer.encode(st.session_state["chat_history"], return_tensors="pt")
74
 
75
- # Generate a response
 
 
 
 
 
76
  chat_history_ids = model.generate(
77
- input_ids,
78
- max_length=1500, # Allow long responses
79
- min_length=200, # Ensure responses are not too short
80
- temperature=1.0, # Adjust for creativity
81
- top_p=0.9, # Nucleus sampling for focused responses
82
- repetition_penalty=1.2, # Penalize repeated phrases
83
  pad_token_id=tokenizer.eos_token_id
84
  )
85
-
86
- # Decode the model's response
87
  response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
88
 
89
- # Add the response to the conversation history
90
- st.session_state["chat_history"] += f"Bot: {response}\n"
91
-
92
- # Display the conversation
93
- st.markdown(f"**You:** {user_input}")
94
- st.markdown(f"**Bot:** {response}")
95
-
96
- # Display Full Conversation History
97
- st.divider()
98
- st.subheader("Conversation History:")
99
- st.text(st.session_state["chat_history"])
100
 
 
 
 
 
 
 
 
47
  # Load the model and tokenizer
48
  @st.cache_resource
49
  def load_model_and_tokenizer():
50
+ model_name = "microsoft/DialoGPT-medium" # Replace with your chosen model
51
  tokenizer = AutoTokenizer.from_pretrained(model_name)
52
  model = AutoModelForCausalLM.from_pretrained(model_name)
53
  return tokenizer, model
54
 
55
  tokenizer, model = load_model_and_tokenizer()
56
 
57
+ # Streamlit App
58
+ st.title("General Chatbot with Adjustable Answer Length")
59
+ st.write("A chatbot powered by an open-source model from Hugging Face.")
60
 
61
+ # Initialize the conversation
62
+ if "conversation_history" not in st.session_state:
63
+ st.session_state["conversation_history"] = []
64
 
65
+ # Input for user query
66
+ user_input = st.text_input("You:", placeholder="Ask me anything...", key="user_input")
67
 
68
+ # Slider for setting max response length
69
+ max_length = st.slider("Set the maximum response length:", min_value=50, max_value=500, step=50, value=150)
 
 
 
 
70
 
71
+ if st.button("Send") and user_input:
72
+ # Append user input to history
73
+ st.session_state["conversation_history"].append({"role": "user", "content": user_input})
74
+
75
+ # Tokenize and generate response
76
+ input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
77
  chat_history_ids = model.generate(
78
+ input_ids,
79
+ max_length=max_length, # Use the user-specified max length
 
 
 
 
80
  pad_token_id=tokenizer.eos_token_id
81
  )
 
 
82
  response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
83
 
84
+ # Append model response to history
85
+ st.session_state["conversation_history"].append({"role": "assistant", "content": response})
 
 
 
 
 
 
 
 
 
86
 
87
+ # Display the conversation
88
+ for message in st.session_state["conversation_history"]:
89
+ if message["role"] == "user":
90
+ st.write(f"**You:** {message['content']}")
91
+ else:
92
+ st.write(f"**Bot:** {message['content']}")