sabahat-shakeel committed on
Commit 0136084 · verified · 1 Parent(s): 00d9971

Update app.py

Files changed (1)
  1. app.py +21 -41
app.py CHANGED
@@ -41,56 +41,36 @@
  # else:
  #     st.write(f"**Bot:** {message['content']}")
  import streamlit as st
- from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import pipeline

- st.title("🤖 Improved Chatbot")
+ st.title("🤖 Conversational Chatbot")

- # Initialize model and tokenizer
  @st.cache_resource
- def load_model():
-     model_name = "microsoft/DialoGPT-medium"
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     model = AutoModelForCausalLM.from_pretrained(model_name)
-     return model, tokenizer
+ def load_chatbot():
+     return pipeline("conversational", model="facebook/blenderbot-400M-distill")

- model, tokenizer = load_model()
+ chatbot = load_chatbot()

- # Initialize chat history
- if "history" not in st.session_state:
-     st.session_state.history = []
+ if "conversation" not in st.session_state:
+     st.session_state.conversation = []

- # Display chat history
- for message in st.session_state.history:
-     with st.chat_message(message["role"]):
-         st.markdown(message["content"])
+ # Display history
+ for msg in st.session_state.conversation:
+     with st.chat_message(msg["role"]):
+         st.markdown(msg["content"])

- # User input
- if prompt := st.chat_input("Type your message..."):
-     # Add user message to history
-     st.session_state.history.append({"role": "user", "content": prompt})
-
-     # Prepare context for the model
-     input_ids = tokenizer.encode(
-         "\n".join([f"{msg['role']}: {msg['content']}" for msg in st.session_state.history[-5:]]) + "\nassistant:",
-         return_tensors="pt"
-     )
+ if prompt := st.chat_input("Say something"):
+     # Add user message
+     st.session_state.conversation.append({"role": "user", "content": prompt})

      # Generate response
      with st.spinner("Thinking..."):
-         output = model.generate(
-             input_ids,
-             max_length=1000,
-             pad_token_id=tokenizer.eos_token_id,
-             no_repeat_ngram_size=3,
-             do_sample=True,
-             top_k=50,
-             top_p=0.95,
-             temperature=0.7
-         )
-
-         response = tokenizer.decode(output[0], skip_special_tokens=True).split("assistant:")[-1].strip()
-
-         # Add assistant response to history
-         st.session_state.history.append({"role": "assistant", "content": response})
+         result = chatbot(str(st.session_state.conversation))
+
+         # Extract bot response
+         response = result.generated_responses[-1]
+
+         # Add to conversation
+         st.session_state.conversation.append({"role": "assistant", "content": response})

      st.rerun()
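
Note on the new code path: the "conversational" pipeline normally works with a Conversation object rather than a raw string, and generated_responses is an attribute of that object; the task has also been dropped from recent transformers releases, so running this app requires an older version that still ships ConversationalPipeline. A minimal sketch of that assumed API, outside Streamlit, with the same facebook/blenderbot-400M-distill checkpoint:

# Sketch only: assumes a transformers release that still includes the
# "conversational" pipeline and the Conversation helper class.
from transformers import pipeline, Conversation

chatbot = pipeline("conversational", model="facebook/blenderbot-400M-distill")

# A Conversation accumulates user inputs and generated responses across turns.
conversation = Conversation("Hello, how are you?")
conversation = chatbot(conversation)
print(conversation.generated_responses[-1])  # latest bot reply

# Next turn: add the new user message and run the pipeline again.
conversation.add_user_input("Tell me something interesting.")
conversation = chatbot(conversation)
print(conversation.generated_responses[-1])

In the committed version, str(st.session_state.conversation) flattens the whole history (role/dict syntax included) into one prompt string; whether the pipeline accepts a plain string depends on the transformers version, so wrapping the turns in a Conversation as above is the safer pattern.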