File size: 3,401 Bytes
8b51148
 
a8acb9b
8b51148
 
 
 
 
 
 
a8acb9b
8b51148
a8acb9b
8b51148
 
 
a8acb9b
8b51148
 
 
a8acb9b
8b51148
 
a8acb9b
8b51148
 
 
9c0b768
8b51148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00d9971
8b51148
00d9971
9c0b768
00d9971
8b51148
00d9971
 
8b51148
 
00d9971
8b51148
00d9971
a8acb9b
00d9971
 
 
9c0b768
8b51148
00d9971
8b51148
 
 
00d9971
8b51148
 
00d9971
8b51148
00d9971
 
 
 
 
 
 
8b51148
00d9971
 
8b51148
00d9971
 
 
 
 
 
 
 
 
 
 
 
8b51148
00d9971
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# NOTE: the commented-out block below is the superseded first version of this
# app, kept for reference; the active implementation starts at the
# uncommented imports further down.
# import streamlit as st
# from transformers import AutoModelForCausalLM, AutoTokenizer

# # Load the model and tokenizer
# @st.cache_resource
# def load_model_and_tokenizer():
#     model_name = "microsoft/DialoGPT-medium"  # Replace with your chosen model
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
#     model = AutoModelForCausalLM.from_pretrained(model_name)
#     return tokenizer, model

# tokenizer, model = load_model_and_tokenizer()

# # Streamlit App
# st.title("General Chatbot")
# st.write("A chatbot powered by an open-source model from Hugging Face.")

# # Initialize the conversation
# if "conversation_history" not in st.session_state:
#     st.session_state["conversation_history"] = []

# # Input box for user query
# user_input = st.text_input("You:", placeholder="Ask me anything...", key="user_input")

# if st.button("Send") and user_input:
#     # Append user input to history
#     st.session_state["conversation_history"].append({"role": "user", "content": user_input})
    
#     # Tokenize and generate response
#     input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
#     chat_history_ids = model.generate(input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
#     response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

#     # Append model response to history
#     st.session_state["conversation_history"].append({"role": "assistant", "content": response})

# # Display the conversation
# for message in st.session_state["conversation_history"]:
#     if message["role"] == "user":
#         st.write(f"**You:** {message['content']}")
#     else:
#         st.write(f"**Bot:** {message['content']}")
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

# Page header rendered at the top of the Streamlit app.
st.title("🤖 Improved Chatbot")

# Initialize model and tokenizer
# Cached loader: st.cache_resource ensures the checkpoint is downloaded and
# instantiated once per server process, then reused across Streamlit reruns.
@st.cache_resource
def load_model():
    """Return the ``(model, tokenizer)`` pair for the DialoGPT checkpoint."""
    checkpoint = "microsoft/DialoGPT-medium"
    return (
        AutoModelForCausalLM.from_pretrained(checkpoint),
        AutoTokenizer.from_pretrained(checkpoint),
    )

model, tokenizer = load_model()

# Initialize chat history: a list of {"role": ..., "content": ...} dicts,
# persisted across script reruns in Streamlit session state.
if "history" not in st.session_state:
    st.session_state.history = []

# Replay the stored conversation so it survives each rerun of the script.
for message in st.session_state.history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# User input
# User input: handle one chat turn per script run.
if prompt := st.chat_input("Type your message..."):
    # Record the user's message in the persisted history.
    st.session_state.history.append({"role": "user", "content": prompt})

    # Build a compact prompt from the last five turns.  NOTE(review): DialoGPT
    # was trained on raw turn-by-turn text, not a "role: content" transcript —
    # this framing is heuristic and only works because the output is split on
    # the trailing "assistant:" marker below; confirm quality on a real model.
    context = "\n".join(
        f"{msg['role']}: {msg['content']}"
        for msg in st.session_state.history[-5:]
    )
    input_ids = tokenizer.encode(context + "\nassistant:", return_tensors="pt")

    # Generate response
    with st.spinner("Thinking..."):
        output = model.generate(
            input_ids,
            # BUGFIX: the original used max_length=1000, which caps the
            # prompt *plus* the reply.  Once the conversation context grows
            # toward 1000 tokens, generation silently shrinks to nothing.
            # max_new_tokens bounds only the reply, regardless of how long
            # the context is.
            max_new_tokens=256,
            pad_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=3,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )

        # The decoded output echoes the prompt (which ends with "assistant:"),
        # so keep only the text after the final marker.
        response = tokenizer.decode(output[0], skip_special_tokens=True).split("assistant:")[-1].strip()

    # Add assistant response to history
    st.session_state.history.append({"role": "assistant", "content": response})

    # Rerun the script so the replay loop above redraws the full conversation,
    # including this turn.
    st.rerun()