import os
from datetime import datetime

import streamlit as st
from langchain.chains import LLMChain
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
from langchain.chat_models.base import BaseChatModel
from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from streamlit_feedback import streamlit_feedback

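# Chat history is kept in Streamlit session state (under "langchain_messages"),
# so the conversation survives script reruns.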
_STMEMORY = StreamlitChatMessageHistory(key="langchain_messages")
_MEMORY = ConversationBufferMemory(
    chat_memory=_STMEMORY,
    return_messages=True,
    memory_key="chat_history",
)

_DEFAULT_SYSTEM_PROMPT = os.environ.get(
    "DEFAULT_SYSTEM_PROMPT",
    "You are a helpful chatbot.",
)

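# Supported chat models, mapped to their provider; get_llm() dispatches on the value.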
_MODEL_DICT = {
    "gpt-3.5-turbo": "OpenAI",
    "gpt-4": "OpenAI",
    "claude-instant-v1": "Anthropic",
    "claude-2": "Anthropic",
    "meta-llama/Llama-2-7b-chat-hf": "Anyscale Endpoints",
    "meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
    "meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
}
_SUPPORTED_MODELS = list(_MODEL_DICT.keys())
_DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-3.5-turbo")

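# Generation-parameter defaults and bounds, all overridable via environment variables.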
_DEFAULT_TEMPERATURE = float(os.environ.get("DEFAULT_TEMPERATURE", 0.7))
_MIN_TEMPERATURE = float(os.environ.get("MIN_TEMPERATURE", 0.0))
_MAX_TEMPERATURE = float(os.environ.get("MAX_TEMPERATURE", 1.0))

_DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
_MIN_TOKENS = int(os.environ.get("MIN_MAX_TOKENS", 1))
_MAX_TOKENS = int(os.environ.get("MAX_MAX_TOKENS", 100000))


def get_llm(
    model: str,
    provider_api_key: str,
    temperature: float,
    max_tokens: int = _DEFAULT_MAX_TOKENS,
) -> BaseChatModel:
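    """Instantiate a streaming chat model for `model`, dispatching on its provider."""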
    if _MODEL_DICT[model] == "OpenAI":
        return ChatOpenAI(
            model=model,
            openai_api_key=provider_api_key,
            temperature=temperature,
            streaming=True,
            max_tokens=max_tokens,
        )
    elif _MODEL_DICT[model] == "Anthropic":
        return ChatAnthropic(
            model_name=model,
            anthropic_api_key=provider_api_key,
            temperature=temperature,
            streaming=True,
            max_tokens_to_sample=max_tokens,
        )
    elif _MODEL_DICT[model] == "Anyscale Endpoints":
        return ChatAnyscale(
            model=model,
            anyscale_api_key=provider_api_key,
            temperature=temperature,
            streaming=True,
            max_tokens=max_tokens,
        )
    else:
        raise NotImplementedError(f"Unknown model {model}")


def get_llm_chain(
    model: str,
    provider_api_key: str,
    system_prompt: str = _DEFAULT_SYSTEM_PROMPT,
    temperature: float = _DEFAULT_TEMPERATURE,
    max_tokens: int = _DEFAULT_MAX_TOKENS,
) -> LLMChain:
    """Return a basic LLMChain with memory."""
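    # The prompt combines the system message, the running chat history, and the
    # latest user input; the callable `time` partial re-evaluates the timestamp
    # each time the prompt is rendered.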
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                system_prompt + "\nIt's currently {time}.",
            ),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
        ],
    ).partial(time=lambda: str(datetime.now()))
    llm = get_llm(model, provider_api_key, temperature, max_tokens)
    return LLMChain(prompt=prompt, llm=llm, memory=_MEMORY)


class StreamHandler(BaseCallbackHandler):
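    """Callback handler that streams LLM tokens into a Streamlit container as they arrive."""
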
    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.text += token
        self.container.markdown(self.text)


def feedback_component(client):
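    """Render a faces feedback widget and record the score via client.create_feedback."""
    # Map the emoji faces returned by streamlit_feedback to scores in [0, 1].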
    scores = {"😀": 1, "🙂": 0.75, "😐": 0.5, "🙁": 0.25, "😞": 0}
    if feedback := streamlit_feedback(
        feedback_type="faces",
        optional_text_label="[Optional] Please provide an explanation",
        key=f"feedback_{st.session_state.run_id}",
    ):
        score = scores[feedback["score"]]
        feedback = client.create_feedback(
            st.session_state.run_id,
            feedback["type"],
            score=score,
            comment=feedback.get("text", None),
        )
        st.session_state.feedback = {"feedback_id": str(feedback.id), "score": score}
        st.toast("Feedback recorded!", icon="📝")
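

# Minimal usage sketch (assumptions: the Streamlit app that wires these helpers
# together is not shown here; `client` would be a `langsmith.Client`, `api_key`
# a provider API key, `user_input` the value from `st.chat_input`, and
# `st.session_state.run_id` is assumed to be set by the app from the traced run):
#
#     chain = get_llm_chain(model=_DEFAULT_MODEL, provider_api_key=api_key)
#     with st.chat_message("assistant"):
#         handler = StreamHandler(st.empty())
#         response = chain.run(input=user_input, callbacks=[handler])
#     feedback_component(client)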