|
|
|
import streamlit as st |
|
import uuid |
|
import sys |
|
import requests |
|
from peft import * |
|
import bitsandbytes as bnb |
|
import pandas as pd |
|
import torch |
|
import torch.nn as nn |
|
import transformers |
|
from datasets import load_dataset |
|
from huggingface_hub import notebook_login |
|
from peft import ( |
|
LoraConfig, |
|
PeftConfig, |
|
get_peft_model, |
|
prepare_model_for_kbit_training, |
|
) |
|
from transformers import ( |
|
AutoConfig, |
|
AutoModelForCausalLM, |
|
AutoTokenizer, |
|
BitsAndBytesConfig, |
|
) |
|
import pickle |
|
|
|
# Avatar images shown beside user questions and assistant answers. The paths
# are relative, so they are resolved against the process's working directory.
USER_ICON: str = "images/user-icon.png"
AI_ICON: str = "images/ai-icon.png"

# Maximum number of (question, answer) pairs kept in
# st.session_state["chat_history"].
MAX_HISTORY_LENGTH: int = 5
|
|
|
# Seed per-session defaults. Each key is written only when it is absent, so
# values stored on earlier runs are kept. Defaults are produced by factories so
# every session gets its own fresh list/dict objects. Order matches the keys'
# original initialisation order.
for _state_key, _make_default in (
    ('user_id', lambda: str(uuid.uuid4())),
    ('chat_history', list),
    ('chats', lambda: [{'id': 0, 'question': '', 'answer': ''}]),
    ('questions', list),
    ('answers', list),
    ('input', str),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
del _state_key, _make_default

# Random identifier for this session, generated once by the loop above.
user_id = st.session_state['user_id']
|
|
|
# Inject page-wide CSS as raw HTML (hence unsafe_allow_html=True):
#   .block-container       - 32px top/bottom padding, no left/right padding.
#   .element-container img - black background behind images inside
#                            `.element-container` elements.
#   .main-header           - 24px font size for the title <h3> that
#                            write_top_bar() renders with class 'main-header'.
st.markdown("""
<style>
    .block-container {
        padding-top: 32px;
        padding-bottom: 32px;
        padding-left: 0;
        padding-right: 0;
    }
    .element-container img {
        background-color: #000000;
    }

    .main-header {
        font-size: 24px;
    }
</style>
""", unsafe_allow_html=True)
|
|
|
|
|
# Load the pickled "model" from model_saved.pkl in the working directory.
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the
# file; only ship/load a model_saved.pkl that comes from a trusted source.
# NOTE(review): this statement runs on every script rerun, so the file is
# re-read each time; consider caching it if loading becomes slow.
try:
    with open('model_saved.pkl', 'rb') as f:
        model = pickle.load(f)
except (OSError, EOFError, pickle.UnpicklingError) as err:
    # Nothing below can produce an answer without the model: show a readable
    # message instead of a raw traceback and halt this run.
    st.error(f"Could not load the model from 'model_saved.pkl': {err}")
    st.stop()

# The loaded object is later used directly as the answer text, so a str is
# expected; anything else is reported (the app keeps running, as before).
if not isinstance(model, str):
    st.error("The loaded model is not valid.")
|
|
|
def write_top_bar():
    """Draw the header row and report whether "Clear Chat" was pressed.

    The row is split 1:10:2 into the assistant icon, the page title (an <h3>
    with class 'main-header') and the Clear Chat button.

    Returns:
        The value returned by the "Clear Chat" button widget.
    """
    icon_col, title_col, button_col = st.columns([1, 10, 2])
    icon_col.image(AI_ICON, use_column_width='always')
    title = "Cogwise Intelligent Assistant"
    title_col.write(f"<h3 class='main-header'>{title}</h3>", unsafe_allow_html=True)
    return button_col.button("Clear Chat")
|
|
|
clear = write_top_bar()

if clear:
    # Wipe the displayed conversation, the bounded history and the text box.
    for _state_key in ('questions', 'answers'):
        st.session_state[_state_key] = []
    st.session_state['input'] = ""
    st.session_state['chat_history'] = []
    del _state_key
|
|
|
def handle_input(): |
|
input = st.session_state.input |
|
question_with_id = { |
|
'question': input, |
|
'id': len(st.session_state.questions) |
|
} |
|
st.session_state.questions.append(question_with_id) |
|
|
|
chat_history = st.session_state["chat_history"] |
|
if len(chat_history) == MAX_HISTORY_LENGTH: |
|
chat_history = chat_history[:-1] |
|
|
|
prompt = input |
|
answer = model |
|
|
|
chat_history.append((input, answer)) |
|
|
|
st.session_state.answers.append({ |
|
'answer': answer, |
|
'id': len(st.session_state.questions) |
|
}) |
|
st.session_state.input = "" |
|
|
|
def write_user_message(md):
    """Render one user question as a row: user avatar, then a warning-styled box.

    Args:
        md: Question record; only its ``'question'`` entry is displayed.
    """
    avatar_col, bubble_col = st.columns([1, 12])
    avatar_col.image(USER_ICON, use_column_width='always')
    bubble_col.warning(md['question'])
|
|
|
def render_answer(answer):
    """Render an assistant reply as a row: AI avatar, then an info-styled box.

    Args:
        answer: Content passed to ``st.info``.
    """
    avatar_col, reply_col = st.columns([1, 12])
    avatar_col.image(AI_ICON, use_column_width='always')
    reply_col.info(answer)
|
|
|
def write_chat_message(md, q):
    """Render the answer stored in *md* inside its own container.

    Args:
        md: Answer record; its ``'answer'`` entry is passed to render_answer().
        q: The matching question record. Not used here; kept so existing
            callers keep working.
    """
    with st.container():
        render_answer(md['answer'])
|
|
|
# Replay the stored conversation: each question row is followed by its answer
# row, with the two lists paired by position.
with st.container():
    for question_md, answer_md in zip(st.session_state.questions,
                                      st.session_state.answers):
        write_user_message(question_md)
        write_chat_message(answer_md, question_md)

st.markdown('---')

# The text box is bound to st.session_state["input"]; handle_input runs as its
# on_change callback.
# NOTE(review): this module-level name shadows the builtin input().
input = st.text_input(
    "You are talking to an AI, ask any question.",
    key="input",
    on_change=handle_input,
)