File size: 9,068 Bytes
f7a2a53
 
 
8deab94
 
c97bd5e
8deab94
833e8c8
8deab94
 
3cb80db
 
8deab94
4492796
 
 
 
8deab94
23b0879
8ee116f
 
 
 
 
23b0879
 
 
 
3cb80db
 
 
ef24b05
3cb80db
 
 
 
 
 
 
ef24b05
833e8c8
 
0d93012
 
 
 
8deab94
 
 
 
 
 
 
 
 
 
 
16afc16
8deab94
 
ce443de
8deab94
 
 
 
15c527a
 
 
dadebad
 
0539125
bf3fcf5
 
 
 
8deab94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e529ec7
8deab94
0539125
8deab94
 
 
 
 
 
 
 
0539125
8deab94
 
 
0d93012
0539125
 
0d93012
0539125
0d93012
7db15b6
1eab9d3
b3317ea
7db15b6
0d93012
 
 
1eab9d3
08b4c0f
0539125
8deab94
 
 
 
 
 
89985c7
8deab94
 
 
 
 
0d93012
 
8deab94
f6f4325
8deab94
0539125
 
 
 
 
 
 
8deab94
 
 
 
 
ac6adc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1473469
8c42253
 
16c2a6d
8deab94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0539125
8deab94
 
15c527a
8deab94
 
 
0539125
15c527a
bb4dea4
 
 
 
dadebad
0539125
dadebad
0539125
 
 
 
 
 
 
 
 
 
 
bf3fcf5
e99c60b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
# adapted from:
# https://medium.com/@james.irving.phd/creating-your-personal-chatbot-using-hugging-face-spaces-and-streamlit-596a54b9e3ed

import os
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, pipeline
from transformers import LlamaTokenizer
import streamlit as st
import torch

# Hugging Face repository passed to load_model() for the tokenizer,
# generation config and model weights.
REPO_NAME = 'schuler/experimental-JP47D20'
# Alternative checkpoint, kept for quick switching:
# REPO_NAME = 'schuler/experimental-JP47D21-KPhi-3-micro-4k-instruct'

# Configure the Streamlit app.
# The page icon must be a real emoji character; the previous value was the
# emoji's UTF-8 bytes decoded with the wrong codec, which is not a valid icon.
st.set_page_config(page_title="Experimental KPhi3 Model - Currently in Training", page_icon="🤗")
st.title("Experimental KPhi3 Model - Currently in Training")

# Tokenizer / generation config / model loader, wrapped in Streamlit's resource cache.
@st.cache_resource(show_spinner="Loading model...")
def load_model(local_repo_name):
    """
    Load everything needed for generation from one repository.

    Args:
        local_repo_name (str): Repository name (or path) handed to each
            ``from_pretrained`` call.

    Returns:
        tuple: ``(tokenizer, generation_config, model)`` in that order. The
        model is loaded with ``torch.bfloat16`` weights and, like the
        tokenizer, with ``trust_remote_code=True``.
    """
    llama_tokenizer = LlamaTokenizer.from_pretrained(
        local_repo_name,
        trust_remote_code=True,
    )
    generation_settings = GenerationConfig.from_pretrained(local_repo_name)
    causal_lm = AutoModelForCausalLM.from_pretrained(
        local_repo_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    return llama_tokenizer, generation_settings, causal_lm

tokenizer, generator_conf, model = load_model(REPO_NAME)

# Parameter statistics for the summary shown under the page title.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# The embedding parameter count is doubled before it is subtracted.
# NOTE(review): presumably this also discounts an output projection of the same
# size as the input embedding -- confirm against the model architecture.
embed_params = sum(p.numel() for p in model.model.embed_tokens.parameters())*2
non_embed_params = (trainable_params - embed_params) / 1e6  # in millions

# Every st.markdown call is rendered as its own element, so an emphasis span
# cannot open in one call and close in another. Each line therefore carries
# its own pair of '*' markers, with no space before the closing one.
st.markdown(f"*This chat uses the {REPO_NAME} model with {model.get_memory_footprint() / 1e6:.2f} MB memory footprint.*")

# Extra statistics, disabled by default:
# st.markdown(f"Total number of parameters: {total_params}. ")
# st.markdown(f"Total number of trainable parameters: {trainable_params}. ")
# st.markdown(f"Total number of embed parameters: {embed_params}. ")

st.markdown(f"*Total number of non embedding trainable parameters: {non_embed_params:.2f} million.*")
st.markdown("*You may ask questions such as 'What is biology?' or 'What is the human body?'*")

# Build the text-generation pipeline from the loaded model and tokenizer.
try:
    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    st.error(f"Failed to load model: {e}")
    # Without a pipeline, `generator` is undefined and the chat code below
    # would fail with a NameError; end this script run after showing the error.
    st.stop()

# Default values for the session-state entries this script uses. A key is only
# written when it is missing, so existing values are left untouched.
_SESSION_DEFAULTS = {
    "avatars": {'user': None, 'assistant': None},  # avatar per chat role
    'user_text': None,                              # most recent chat input
    "max_response_length": 64,                      # passed as max_new_tokens
    "system_message": "",
    "starter_message": "Hello, there! How can I help you today?",
    "can_continue": False,                          # gates the "Continue" button
    "last_response": '',                            # full text of the last generation
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Plain per-run flag (not kept in session state): set when "Continue" is pressed.
need_continue = False

# Sidebar for settings
with st.sidebar:
    st.header("System Settings")

    # AI settings: each widget is seeded from session state and its current
    # value is written back on every run.
    st.session_state.system_message = st.text_area(
        "System Message", value=st.session_state.system_message
    )
    st.session_state.starter_message = st.text_area(
        'First AI Message', value=st.session_state.starter_message
    )

    # Model settings. The value is forwarded as max_new_tokens, so zero or
    # negative entries are rejected at the widget via min_value.
    st.session_state.max_response_length = st.number_input(
        "Max Response Length", min_value=1, value=st.session_state.max_response_length
    )

    # Avatar selection. The options must be real emoji characters; the previous
    # literals were UTF-8 emoji bytes decoded with the wrong codec.
    st.markdown("*Select Avatars:*")
    col1, col2 = st.columns(2)
    with col1:
        st.session_state.avatars['assistant'] = st.selectbox(
            "AI Avatar", options=["🤗", "💬", "🤖"], index=0
        )
    with col2:
        st.session_state.avatars['user'] = st.selectbox(
            "User Avatar", options=["👤", "👱‍♂️", "👨🏾", "👩", "👧🏾"], index=0
        )
    # Reset Chat History
    reset_history = st.button("Reset Chat History")

# Initialize or reset chat history
if "chat_history" not in st.session_state or reset_history:
    # The history starts empty; the starter message is currently not inserted:
    # [{"role": "assistant", "content": st.session_state.starter_message}]
    st.session_state.chat_history = []
    # With no assistant message left there is nothing to continue, so drop the
    # continuation state too; otherwise "Continue" would index an empty history.
    st.session_state.can_continue = False
    st.session_state.last_response = ''

def _build_prompt(system_message, chat_history, user_text):
    """Return the tagged prompt for a new user turn, ending with an open assistant tag."""
    parts = []
    if system_message:
        # The system message is emitted under the assistant tag.
        # NOTE(review): presumably the model has no dedicated system tag -- confirm.
        parts.append(f"<|assistant|>{system_message}<|end|>")
    for message in chat_history:
        role = "<|assistant|>" if message['role'] == 'assistant' else "<|user|>"
        parts.append(f"{role}{message['content']}<|end|>")
    parts.append(f"<|user|>{user_text}<|end|><|assistant|>")
    return "".join(parts)


def get_response(system_message, chat_history, user_text, max_new_tokens=256, continue_last=False):
    """
    Generates a response from the chatbot model.

    In normal mode a prompt is built from the system message, the chat history
    and the new user text, and both the user turn and the generated reply are
    appended to ``chat_history``. In continue mode the full text of the
    previous generation (``st.session_state.last_response``) is used as the
    prompt and the newly generated text is appended to the last assistant
    message of ``chat_history``.

    Args:
        system_message (str): The system message for the conversation.
        chat_history (list): The list of previous chat messages, each a dict
            with 'role' and 'content' keys. It is modified in place.
        user_text (str): The user's input text (ignored when continuing).
        max_new_tokens (int): The maximum number of new tokens to generate.
        continue_last (bool): Whether to continue the last assistant response.

    Returns:
        tuple: A tuple containing the generated response and the updated chat
        history. When continuing is requested but there is no assistant
        message or no stored generation to extend, the response is an empty
        string and the history is returned unchanged.
    """
    if continue_last:
        # Nothing to extend: return early instead of failing on chat_history[-1].
        if not chat_history or chat_history[-1]['role'] != 'assistant':
            return "", chat_history
        # We want to continue the last assistant response
        prompt = st.session_state.last_response
        if not prompt:
            return "", chat_history
    else:
        # Build the conversation prompt
        prompt = _build_prompt(system_message, chat_history, user_text)

    # Generate the response
    response_output = generator(
        prompt,
        generation_config=generator_conf,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.25,
        repetition_penalty=1.2
    )

    generated_text = response_output[0]['generated_text']

    # Keep the full text so a later "continue" call can resume from it.
    st.session_state.last_response = generated_text

    # Extract the assistant's response. This assumes the generated text starts
    # with the prompt, so everything after the prompt's length is the new part.
    assistant_response = generated_text[len(prompt):]

    if continue_last:
        # Append the continued text to the last assistant message of the
        # history that was passed in.
        chat_history[-1]['content'] += assistant_response
    else:
        # Update the chat history
        chat_history.append({'role': 'user', 'content': user_text})
        chat_history.append({'role': 'assistant', 'content': assistant_response})

    return assistant_response, chat_history

# Container that the conversation is rendered into.
chat_interface = st.container()


def refresh_chat():
    """
    Render the stored chat history inside ``chat_interface``.

    A fresh inner container is created on each call and every message from
    ``st.session_state.chat_history`` whose role is not 'system' is drawn as a
    chat bubble, using the avatar configured for that role.
    """
    with chat_interface, st.container():
        for entry in st.session_state.chat_history:
            speaker = entry['role']
            if speaker != 'system':
                with st.chat_message(speaker, avatar=st.session_state.avatars[speaker]):
                    st.markdown(entry['content'])

# Draw the conversation stored so far.
refresh_chat()

# Chat input box; its value is kept in session state for the rest of this run.
st.session_state.user_text = st.chat_input(placeholder="Enter your text here.")

# A new message was submitted: echo it, then generate and show the reply.
if st.session_state.user_text:
    with st.chat_message("user", avatar=st.session_state.avatars['user']):
        st.markdown(st.session_state.user_text)

    # The spinner is shown inside the assistant bubble while the model generates.
    with st.chat_message("assistant", avatar=st.session_state.avatars['assistant']), st.spinner("Thinking..."):
        response, st.session_state.chat_history = get_response(
            st.session_state.system_message,
            st.session_state.chat_history,
            st.session_state.user_text,
            max_new_tokens=st.session_state.max_response_length,
        )
        st.markdown(response)
        # A reply now exists, so the "Continue" button may be offered.
        st.session_state.can_continue = True

    # Clear the stored input now that it has been handled.
    st.session_state.user_text = None

# The "Continue" button is only rendered while can_continue is set; the flag
# is true for this run only when that button was pressed.
need_continue = st.session_state.can_continue and st.button("Continue")

# Extend the last assistant reply, show the new text, then rerun the script.
if need_continue:
    with st.chat_message("assistant", avatar=st.session_state.avatars['assistant']), st.spinner("Continuing..."):
        response, st.session_state.chat_history = get_response(
            st.session_state.system_message,
            st.session_state.chat_history,
            None,
            max_new_tokens=st.session_state.max_response_length,
            continue_last=True,
        )
        st.markdown(response)
        st.rerun()