# adapted from:
# https://medium.com/@james.irving.phd/creating-your-personal-chatbot-using-hugging-face-spaces-and-streamlit-596a54b9e3ed
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, pipeline
from transformers import LlamaTokenizer
import streamlit as st
import torch
# Define the model repository
REPO_NAME = 'schuler/experimental-JP47D20'
# REPO_NAME = 'schuler/experimental-JP47D21-KPhi-3-micro-4k-instruct'
# Configure the Streamlit app
st.set_page_config(page_title="Experimental KPhi3 Model - Currently in Training", page_icon="🤖")
st.title("Experimental KPhi3 Model - Currently in Training")
# Load tokenizer and model
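# st.cache_resource keeps a single copy of the tokenizer/model per server
# process, so Streamlit reruns of this script do not reload the weights.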
@st.cache_resource(show_spinner="Loading model...")
def load_model(local_repo_name):
    # tokenizer = AutoTokenizer.from_pretrained(local_repo_name, trust_remote_code=True)
    tokenizer = LlamaTokenizer.from_pretrained(local_repo_name, trust_remote_code=True)
    generator_conf = GenerationConfig.from_pretrained(local_repo_name)
    model = AutoModelForCausalLM.from_pretrained(local_repo_name, trust_remote_code=True, torch_dtype=torch.bfloat16)
    return tokenizer, generator_conf, model
tokenizer, generator_conf, model = load_model(REPO_NAME)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
embed_params = sum(p.numel() for p in model.model.embed_tokens.parameters())*2
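# NOTE: the factor of 2 presumably counts the input embedding and the output
# projection (lm_head) as a tied pair; this is an assumption about the
# checkpoint's architecture, not something read from its config.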
non_embed_params = (trainable_params - embed_params) / 1e6
st.markdown(f"*This chat uses the {REPO_NAME} model with {model.get_memory_footprint() / 1e6:.2f} MB memory footprint. ")
# st.markdown(f"Total number of parameters: {total_params}. ")
# st.markdown(f"Total number of trainable parameters: {trainable_params}. ")
# st.markdown(f"Total number of embed parameters: {embed_params}. ")
st.markdown(f"Total number of non embedding trainable parameters: {non_embed_params:.2f} million. ")
st.markdown(f"You may ask questions such as 'What is biology?' or 'What is the human body?'*")
try:
    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    st.error(f"Failed to create the text-generation pipeline: {str(e)}")
    st.stop()  # generator is used below; do not continue without it
# Initialize session state for avatars
if "avatars" not in st.session_state:
st.session_state.avatars = {'user': None, 'assistant': None}
# Initialize session state for user text input
if 'user_text' not in st.session_state:
st.session_state.user_text = None
# Initialize session state for model parameters
if "max_response_length" not in st.session_state:
st.session_state.max_response_length = 64
if "system_message" not in st.session_state:
st.session_state.system_message = ""
if "starter_message" not in st.session_state:
st.session_state.starter_message = "Hello, there! How can I help you today?"
if "can_continue" not in st.session_state:
st.session_state.can_continue = False
# Initialize state for continue action
need_continue = False
# Initialize the last response
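# last_response keeps the full generated text (prompt + completion) so the
# "Continue" action can feed it back to the model verbatim.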
if "last_response" not in st.session_state:
st.session_state.last_response = ''
# Sidebar for settings
with st.sidebar:
    st.header("System Settings")
    # AI Settings
    st.session_state.system_message = st.text_area(
        "System Message", value=st.session_state.system_message
    )
    st.session_state.starter_message = st.text_area(
        'First AI Message', value=st.session_state.starter_message
    )
    # Model Settings
    st.session_state.max_response_length = st.number_input(
        "Max Response Length", value=st.session_state.max_response_length
    )
    # Avatar Selection
    st.markdown("*Select Avatars:*")
    col1, col2 = st.columns(2)
    with col1:
        st.session_state.avatars['assistant'] = st.selectbox(
            "AI Avatar", options=["🤗", "💬", "🤖"], index=0
        )
    with col2:
        st.session_state.avatars['user'] = st.selectbox(
            "User Avatar", options=["👤", "👱‍♀️", "👨🏾", "👩", "🧑🏾"], index=0
        )
    # Reset Chat History
    reset_history = st.button("Reset Chat History")
# Initialize or reset chat history
if "chat_history" not in st.session_state or reset_history:
st.session_state.chat_history = [] # [{"role": "assistant", "content": st.session_state.starter_message}]
def get_response(system_message, chat_history, user_text, max_new_tokens=256, continue_last=False):
    """
    Generates a response from the chatbot model.

    Args:
        system_message (str): The system message for the conversation.
        chat_history (list): The list of previous chat messages.
        user_text (str): The user's input text.
        max_new_tokens (int): The maximum number of new tokens to generate.
        continue_last (bool): Whether to continue the last assistant response.

    Returns:
        tuple: A tuple containing the generated response and the updated chat history.
    """
    if continue_last:
        # We want to continue the last assistant response
        prompt = st.session_state.last_response
    else:
        # Build the conversation prompt
        if len(system_message) > 0:
            prompt = f"<|assistant|>{system_message}<|end|>"
        else:
            prompt = ''
        for message in chat_history:
            role = "<|assistant|>" if message['role'] == 'assistant' else "<|user|>"
            prompt += f"{role}{message['content']}<|end|>"
        prompt += f"<|user|>{user_text}<|end|><|assistant|>"
    # Generate the response
    response_output = generator(
        prompt,
        generation_config=generator_conf,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.25,
        repetition_penalty=1.2
    )
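    # The low top_p (0.25) restricts sampling to the highest-probability tokens,
    # and repetition_penalty=1.2 discourages the small model from looping.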
    generated_text = response_output[0]['generated_text']
    st.session_state.last_response = generated_text
    # Extract the assistant's response by stripping the prompt prefix
    assistant_response = generated_text[len(prompt):]  # .strip()
    if continue_last:
        # Append the continued text to the last assistant message
        st.session_state.chat_history[-1]['content'] += assistant_response
    else:
        # Update the chat history
        chat_history.append({'role': 'user', 'content': user_text})
        chat_history.append({'role': 'assistant', 'content': assistant_response})
    return assistant_response, chat_history
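# Hypothetical usage for quick testing inside the running app (the argument
# values are illustrative, not taken from the session state above):
#   reply, history = get_response("", [], "What is biology?", max_new_tokens=64)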
# Chat interface
chat_interface = st.container()
def refresh_chat():
    with chat_interface:
        output_container = st.container()
        # Display chat messages
        with output_container:
            for idx, message in enumerate(st.session_state.chat_history):
                if message['role'] == 'system':
                    continue
                with st.chat_message(message['role'], avatar=st.session_state.avatars[message['role']]):
                    st.markdown(message['content'])
                # If this is the last assistant message, add the "Continue" button
                # if idx == len(st.session_state.chat_history) - 1 and message['role'] == 'assistant':
refresh_chat()
# User input area (moved to the bottom)
st.session_state.user_text = st.chat_input(placeholder="Enter your text here.")
# When the user enters new text
if st.session_state.user_text:
    # Display the user's message
    with st.chat_message("user", avatar=st.session_state.avatars['user']):
        st.markdown(st.session_state.user_text)
    # Display a spinner while generating the response
    with st.chat_message("assistant", avatar=st.session_state.avatars['assistant']):
        with st.spinner("Thinking..."):
            # Generate the assistant's response
            response, st.session_state.chat_history = get_response(
                system_message=st.session_state.system_message,
                user_text=st.session_state.user_text,
                chat_history=st.session_state.chat_history,
                max_new_tokens=st.session_state.max_response_length,
                continue_last=False
            )
            st.markdown(response)
    st.session_state.can_continue = True
    # Clear the user input
    st.session_state.user_text = None
if st.session_state.can_continue:
    if st.button("Continue"):
        need_continue = True
    else:
        need_continue = False
# If "Continue" button was pressed
if need_continue:
    # Display a spinner while generating the continuation
    with st.chat_message("assistant", avatar=st.session_state.avatars['assistant']):
        with st.spinner("Continuing..."):
            # Generate the continuation of the assistant's last response
            response, st.session_state.chat_history = get_response(
                system_message=st.session_state.system_message,
                user_text=None,
                chat_history=st.session_state.chat_history,
                max_new_tokens=st.session_state.max_response_length,
                continue_last=True
            )
            st.markdown(response)
    st.rerun()