# Hugging Face Space status: Sleeping
import streamlit as st | |
from transformers import GPT2Tokenizer, GPT2LMHeadModel | |
# Function to generate a response | |
def generate_response(input_text):
    """Generate a chatbot reply for ``input_text`` using the module-level GPT-2.

    The prompt is tokenized as-is (no ``[Bot]`` suffix is appended), a
    continuation is sampled with ``model.generate``, and the reply is taken
    from the text that follows the first ``[Bot]`` marker in the decoded
    output. If the output contains no marker, the reply is the newly
    generated text only, with the prompt tokens excluded. Finally the reply
    is cut at its last period, when it has one, so that it does not end in
    the middle of a sentence.

    Relies on the module-level ``tokenizer`` and ``model`` objects.

    Args:
        input_text: The user's message.

    Returns:
        The reply with surrounding whitespace stripped. May be an empty
        string if nothing usable was generated.
    """
    marker = '[Bot]'

    inputs = tokenizer(input_text, return_tensors="pt")
    input_ids = inputs['input_ids']

    output_sequences = model.generate(
        input_ids=input_ids,
        attention_mask=inputs['attention_mask'],
        max_length=100,  # budget for prompt tokens plus generated tokens
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        no_repeat_ngram_size=2,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )
    generated_ids = output_sequences[0]

    full_generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # partition() reports whether the marker was actually present. The
    # previous find()-based slicing turned a missing marker (-1) into a
    # start index of 4 and silently chopped the first four characters.
    _, found_marker, after_marker = full_generated_text.partition(marker)
    if found_marker:
        bot_response = after_marker
    else:
        # No marker: the returned sequence starts with the prompt tokens, so
        # skip them instead of echoing the user's own message back.
        prompt_length = input_ids.shape[-1]
        bot_response = tokenizer.decode(
            generated_ids[prompt_length:], skip_special_tokens=True
        )

    # Drop a trailing sentence fragment; leave the text alone if it
    # contains no period at all.
    last_period_index = bot_response.rfind('.')
    if last_period_index != -1:
        bot_response = bot_response[:last_period_index + 1]

    return bot_response.strip()
# Checkpoint providing both the tokenizer (vocabulary) and the model weights.
model_name = 'KhantKyaw/Chat_GPT-2'

# Streamlit re-executes this script on every interaction, so an uncached
# from_pretrained() would reload the model for each message. Use
# st.cache_resource where the installed Streamlit provides it; otherwise
# fall back to a no-op decorator, which keeps the old load-per-run behavior.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_tokenizer_and_model(name):
    """Load the tokenizer and language model for the checkpoint ``name``."""
    return (
        GPT2Tokenizer.from_pretrained(name),
        GPT2LMHeadModel.from_pretrained(name),
    )


# Module-level names used by generate_response().
tokenizer, model = _load_tokenizer_and_model(model_name)

# One label per stored role, so a message reads the same when it is first
# rendered and when it is replayed from history.
_ROLE_LABELS = {"user": "User", "assistant": "GPT-2"}

st.title("Chat with GPT-2")

# Initialize chat history once per browser session.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the stored conversation on every rerun.
for message in st.session_state.messages:
    label = _ROLE_LABELS.get(message['role'], message['role'].capitalize())
    with st.container():
        st.markdown(f"**{label}**: {message['content']}")

# React to user input.
prompt = st.text_input("What is up?", key="chat_input")
if prompt:
    with st.container():
        st.markdown(f"**{_ROLE_LABELS['user']}**: {prompt}")
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Ask the model for a reply, then show it and record it in the history.
    response = generate_response(prompt)
    with st.container():
        st.markdown(f"**{_ROLE_LABELS['assistant']}**: {response}")
    st.session_state.messages.append({"role": "assistant", "content": response})