# Streamlit web demo for ChatGLM2-6B.
import streamlit as st
from transformers import AutoModel, AutoTokenizer
import mdtex2html
from utils import load_model_on_gpus

# Page chrome must be configured before any other Streamlit call.
st.set_page_config(page_title="ChatGLM2-6B", page_icon=":robot:")
st.header("ChatGLM2-6B")


@st.cache_resource
def get_model():
    """Load the ChatGLM2-6B tokenizer and model exactly once.

    Streamlit re-executes the whole script on every user interaction;
    without ``st.cache_resource`` the multi-gigabyte model would be
    reloaded on each rerun. ``trust_remote_code`` is required because
    ChatGLM2 ships its own modeling code in the checkpoint repo.
    """
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
    # Alternative: spread the weights across multiple GPUs instead.
    # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
    return tokenizer, model.eval()


tokenizer, model = get_model()
def postprocess(chat):
    """Convert every turn of *chat* from Markdown/TeX to HTML.

    Parameters
    ----------
    chat : list[tuple[str, str]]
        Conversation history as (user, response) pairs of Markdown text.

    Returns
    -------
    list[tuple[str, str]]
        The same list, mutated in place: both sides of each pair are
        replaced by their ``mdtex2html.convert`` HTML rendering.
    """
    for i, (user, response) in enumerate(chat):
        chat[i] = (mdtex2html.convert(user), mdtex2html.convert(response))
    return chat
user_input = st.text_area("Input:", height=200, placeholder="Ask me anything!")

if user_input:
    history = st.session_state.get('history', [])

    # Generation hyper-parameters, adjustable per request.
    max_length = st.slider("Max Length:", 0, 32768, 8192, 1)
    top_p = st.slider("Top P:", 0.0, 1.0, 0.8, 0.01)
    temperature = st.slider("Temperature:", 0.0, 1.0, 0.95, 0.01)

    if 'past_key_values' not in st.session_state:
        st.session_state['past_key_values'] = None

    with st.spinner("Thinking..."):
        # generate() expects a batched tensor of input ids on the model's
        # device; tokenizer.encode() alone returns a plain Python list.
        input_ids = tokenizer(user_input, return_tensors="pt").input_ids.to(model.device)
        # NOTE(review): return_past_key_values / past_key_values kwargs are
        # ChatGLM-specific extensions of generate() — confirm against the
        # checkpoint's custom modeling code.
        response = model.generate(input_ids,
                                  max_length=max_length,
                                  top_p=top_p,
                                  temperature=temperature,
                                  return_dict_in_generate=True,
                                  output_scores=True,
                                  return_past_key_values=True,
                                  past_key_values=st.session_state.past_key_values)
    st.session_state.past_key_values = response.past_key_values

    # response.sequences[0] is a tensor of token ids — decode it to text
    # instead of rendering the raw ids.
    reply = tokenizer.decode(response.sequences[0], skip_special_tokens=True)
    history.append((user_input, reply))

    # Persist the raw (pre-HTML) history so it survives the next Streamlit
    # rerun; the shallow copy keeps postprocess()'s in-place HTML conversion
    # below from leaking converted text into the stored history.
    st.session_state['history'] = list(history)

    history = postprocess(history)
    for user, chatbot in history:
        message = f"**Human:** {user}" if user else ""
        response = f"**AI:** {chatbot}" if chatbot else ""
        st.markdown(message + response, unsafe_allow_html=True)

if st.button("Clear History"):
    st.session_state['history'] = []
    st.session_state['past_key_values'] = None