Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import GPT2Tokenizer, GPT2LMHeadModel | |
def generate_response(input_text, *, max_new_tokens=100):
    """Generate a chat reply to ``input_text`` using the module-level model.

    Only the tokens generated *after* the prompt are decoded, so the user's
    message is never echoed back. If that continuation contains a ``[Bot]``
    marker, the reply is the text after its first occurrence; otherwise the
    whole continuation is used. The reply is then cut after its last period
    (when there is one) and stripped of surrounding whitespace.

    Args:
        input_text: The user's message, tokenized as-is.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt. Counting only new tokens (rather than a total
            ``max_length``) means a long prompt still leaves room for a reply.

    Returns:
        The reply text; may be an empty string if nothing usable was generated.
    """
    bot_marker = "[Bot]"

    inputs = tokenizer(input_text, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    output_sequences = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        no_repeat_ngram_size=2,
        # Pad with EOS (GPT-2 tokenizers typically define no pad token).
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )

    # For a decoder-only model the output is prompt + continuation; decode
    # only the continuation so the prompt is not part of the reply.
    generated_text = tokenizer.decode(
        output_sequences[0][prompt_length:], skip_special_tokens=True
    )

    # partition() makes the "marker missing" case explicit instead of relying
    # on str.find()'s -1 sentinel in offset arithmetic.
    _, marker, after_marker = generated_text.partition(bot_marker)
    bot_response = after_marker if marker else generated_text

    # Drop a trailing incomplete sentence cut off by the token limit.
    last_period_index = bot_response.rfind(".")
    if last_period_index != -1:
        bot_response = bot_response[: last_period_index + 1]
    return bot_response.strip()
# Hugging Face Hub identifier of the fine-tuned GPT-2 chat model.
model_name = 'KhantKyaw/Chat_GPT-2'


@st.cache_resource
def _load_model(name):
    """Load the GPT-2 tokenizer (vocabulary) and language model for ``name``.

    Streamlit re-executes this script from the top on every interaction;
    ``st.cache_resource`` keeps the returned objects across reruns (and
    sessions), so the weights are loaded once instead of on every rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        GPT2Tokenizer.from_pretrained(name),
        GPT2LMHeadModel.from_pretrained(name),
    )


# Module-level names used by generate_response().
tokenizer, model = _load_model(model_name)
st.title("Chat with GPT-2")

# Display label for each stored role. Both the history replay and the live
# reply go through this mapping, so a message keeps the same label on rerun.
ROLE_LABELS = {"user": "User", "assistant": "GPT-2"}


def _render_message(role, content):
    """Render one chat line as ``**<label>**: <content>`` in its own container."""
    label = ROLE_LABELS.get(role, role.capitalize())
    with st.container():
        st.markdown(f"**{label}**: {content}")


def _queue_prompt():
    """``on_change`` callback: stash the just-submitted input for this rerun.

    Runs only when the text box value actually changes. Because the widget
    keeps its value across reruns, this prevents a rerun that was not caused
    by a new submission from processing the same prompt a second time.
    """
    st.session_state.pending_prompt = st.session_state.chat_input


# Initialize chat history (session_state persists across reruns of a session).
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the stored conversation on every rerun.
for message in st.session_state.messages:
    _render_message(message["role"], message["content"])

# React to user input: handle each submitted prompt exactly once.
st.text_input("What is up?", key="chat_input", on_change=_queue_prompt)
prompt = st.session_state.pop("pending_prompt", None)
if prompt:
    _render_message("user", prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Ask the model for a reply to the new prompt.
    response = generate_response(prompt)
    _render_message("assistant", response)
    st.session_state.messages.append({"role": "assistant", "content": response})