# Streamlit chatbot demo backed by the KhantKyaw/Chat_GPT-2 model.
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Function to generate a response | |
def generate_response(input_text): | |
# Adjusted input to include the [Bot] marker | |
#adjusted_input = f"{input_text} [Bot]" | |
# Encode the adjusted input | |
inputs = tokenizer(input_text, return_tensors="pt") | |
# Generate a sequence of text with a slightly increased max_length to account for the prompt length | |
output_sequences = model.generate( | |
input_ids=inputs['input_ids'], | |
attention_mask=inputs['attention_mask'], | |
max_length=100, # Adjusted max_length | |
temperature=0.7, | |
top_k=50, | |
top_p=0.95, | |
no_repeat_ngram_size=2, | |
pad_token_id=tokenizer.eos_token_id, | |
#early_stopping=True, | |
do_sample=True | |
) | |
# Decode the generated sequence | |
full_generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True) | |
# Extract the generated response after the [Bot] marker | |
bot_response_start = full_generated_text.find('[Bot]') + len('[Bot]') | |
bot_response = full_generated_text[bot_response_start:] | |
# Trim the response to end at the last period within the specified max_length | |
last_period_index = bot_response.rfind('.') | |
if last_period_index != -1: | |
bot_response = bot_response[:last_period_index + 1] | |
return bot_response.strip() | |
# Load pre-trained model tokenizer (vocabulary) and model | |
model_name = 'KhantKyaw/Chat_GPT-2' | |
tokenizer = GPT2Tokenizer.from_pretrained(model_name) | |
model = GPT2LMHeadModel.from_pretrained(model_name) | |
# Chat loop | |
#print("Chatbot is ready. Type 'quit' to exit.") | |
#while True: | |
#user_input = input("You: ") | |
#if user_input.lower() == "quit": | |
#break | |
#response = generate_response(user_input) | |
#print("Chatbot:", response) | |
prompt = st.text_input("Say Something!", key=None, max_chars=None, disabled=False) | |
if prompt: | |
with st.container(): | |
# Displaying the user's input question. | |
st.markdown(prompt) | |
# Generating and displaying the response. | |
response = generate_response(prompt) | |
st.markdown(generate_response(prompt)) | |
#prompt = st.chat_input(placeholder="Say Something!",key=None, max_chars=None, disabled=False, on_submit=None, args=None, kwargs=None) | |
#if prompt: | |
# with st.chat_message(name="AI",avatar=None): | |
# st.write(generate_response(prompt)) |