# NOTE(review): lines above this point were web-scrape residue (Hugging Face
# Spaces page header, commit hashes, and a line-number gutter) — not Python.
# Replaced with this comment so the file parses.
import numpy as np
import redis
import streamlit as st
from langchain import HuggingFaceHub
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from redis.commands.search.query import Query
from sentence_transformers import SentenceTransformer
from constants import (
EMBEDDING_MODEL_NAME,
FALCON_MAX_TOKENS,
FALCON_REPO_ID,
FALCON_TEMPERATURE,
HUGGINGFACEHUB_API_TOKEN,
ITEM_KEYWORD_EMBEDDING,
OPENAI_API_KEY,
OPENAI_MODEL_NAME,
OPENAI_TEMPERATURE,
TEMPLATE_1,
TEMPLATE_2,
TOPK,
)
from database import create_redis
# connect to redis database
@st.cache_resource()
def connect_to_redis():
    """Return a Redis client backed by the shared connection pool.

    Cached by Streamlit so the pool is created once per server process.
    """
    connection_pool = create_redis()
    client = redis.Redis(connection_pool=connection_pool)
    return client
# the encoding keywords chain
@st.cache_resource()
def encode_keywords_chain():
    """Build the keyword-extraction chain (Falcon via HuggingFace Hub).

    Takes a ``product_description`` input variable (TEMPLATE_1) and is cached
    by Streamlit so the chain is constructed once per server process.
    """
    falcon_llm = HuggingFaceHub(
        repo_id=FALCON_REPO_ID,
        model_kwargs={"temperature": FALCON_TEMPERATURE, "max_new_tokens": FALCON_MAX_TOKENS},
        huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
    )
    keyword_prompt = PromptTemplate(
        input_variables=["product_description"],
        template=TEMPLATE_1,
    )
    return LLMChain(llm=falcon_llm, prompt=keyword_prompt)
# the present products chain
def present_products_chain():
    """Build the conversational product-presentation chain.

    Uses an OpenAI chat model with a conversation buffer memory keyed on
    ``chat_history``; the prompt (TEMPLATE_2) expects ``chat_history`` and
    ``user_msg`` variables. Not cached: one instance per chat session.
    """
    chat_prompt = PromptTemplate(
        input_variables=["chat_history", "user_msg"],
        template=TEMPLATE_2,
    )
    chat_llm = ChatOpenAI(
        openai_api_key=OPENAI_API_KEY,
        temperature=OPENAI_TEMPERATURE,
        model=OPENAI_MODEL_NAME,
    )
    return LLMChain(
        llm=chat_llm,
        prompt=chat_prompt,
        verbose=False,
        memory=ConversationBufferMemory(memory_key="chat_history"),
    )
@st.cache_resource()
def instance_embedding_model():
    """Load the sentence-embedding model once per server process (cached)."""
    return SentenceTransformer(EMBEDDING_MODEL_NAME)
def main():
    """Streamlit entry point: a chat UI that recommends products.

    Flow per user message: extract search keywords with the Falcon chain,
    embed them, run a KNN vector search in Redis, then feed the matched
    products plus the user's message to the conversational chain.
    """
    st.title("My Amazon shopping buddy π·οΈ")
    st.caption("π€ Powered by Falcon Open Source AI model")

    redis_conn = connect_to_redis()
    keywords_chain = encode_keywords_chain()

    # One-time session setup: keep a single conversational chain (and its
    # memory) alive across Streamlit reruns.
    if "window_refreshed" not in st.session_state:
        st.session_state.window_refreshed = True
        st.session_state.chat_chain = present_products_chain()

    embedding_model = instance_embedding_model()

    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {"role": "assistant", "content": "Hey im your online shopping buddy, how can i help you today?"}
        ]

    # Replay the conversation so far on every rerun.
    for message in st.session_state["messages"]:
        st.chat_message(message["role"]).write(message["content"])

    user_prompt = st.chat_input(key="user_input")
    if not user_prompt:
        return

    st.session_state["messages"].append({"role": "user", "content": user_prompt})
    st.chat_message("user").write(user_prompt)
    # NOTE(review): `disabled` is set but never read in this file — presumably
    # consumed by a widget elsewhere; confirm before removing.
    st.session_state.disabled = True

    keywords = keywords_chain.run(user_prompt)
    # Vectorize the extracted keywords; Redis expects raw float32 bytes.
    query_vector = np.array(embedding_model.encode(keywords)).astype(np.float32)
    vector_bytes = query_vector.tobytes()

    # KNN search over the keyword-embedding field, nearest first.
    knn_query = (
        Query(f"*=>[KNN {TOPK} @{ITEM_KEYWORD_EMBEDDING} $vec_param AS vector_score]")
        .sort_by("vector_score")
        .paging(0, TOPK)
        .return_fields("vector_score", "item_name", "item_id", "item_keywords")
        .dialect(2)
    )
    results = redis_conn.ft().search(knn_query, query_params={"vec_param": vector_bytes})

    # Flatten the matches into a plain-text context block for the chat chain.
    result_output = "".join(
        f"product_name:{doc.item_name}, product_description:{doc.item_keywords} \n"
        for doc in results.docs
    )

    answer = st.session_state.chat_chain.predict(user_msg=f"{result_output}\n{user_prompt}")
    st.session_state.messages.append({"role": "assistant", "content": answer})
    st.chat_message("assistant").write(answer)


if __name__ == "__main__":
    main()