import warnings
warnings.simplefilter("ignore", category=FutureWarning)
import os
import streamlit as st
from huggingface_hub import InferenceClient
from langchain_community.vectorstores import Neo4jVector
from transformers import AutoTokenizer, AutoModel
import torch
# Custom embedding class exposing the LangChain embeddings interface
# (embed_query / embed_documents) on top of a plain transformers model.
class CustomHuggingFaceEmbeddings:
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def embed_text(self, text):
        try:
            inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        except Exception as e:
            print(f"Error during tokenization: {e}")
            return []
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Mean-pool the token embeddings into a single fixed-size sentence vector.
        return outputs.last_hidden_state.mean(dim=1).squeeze().tolist()

    def embed_query(self, text):
        return self.embed_text(text)

    def embed_documents(self, texts):
        # LangChain passes a list of texts here and expects one vector per text;
        # embedding them one at a time keeps embed_text's single-string contract.
        return [self.embed_text(text) for text in texts]
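# A minimal smoke test for the class above (an illustrative sketch, not part of
# the app flow; it assumes the default all-MiniLM-L6-v2 model, which produces
# 384-dimensional vectors):
#
#   emb = CustomHuggingFaceEmbeddings()
#   vec = emb.embed_query("graph databases")
#   assert len(vec) == 384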
# Set up the Neo4j vector index over existing Article nodes; cached so the
# index wrapper is built once rather than on every Streamlit rerun.
@st.cache_resource
def setup_vector_index():
    return Neo4jVector.from_existing_graph(
        CustomHuggingFaceEmbeddings(),
        url=os.environ['NEO4J_URI'],
        username=os.environ['NEO4J_USERNAME'],
        password=os.environ['NEO4J_PASSWORD'],
        index_name='articles',
        node_label="Article",
        text_node_properties=['name', 'abstract'],
        embedding_node_property='embedding',
    )
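# The connection settings above come from the environment (on Hugging Face
# Spaces, typically set as Space secrets). A local setup might look like the
# following; these values are placeholders, not this deployment's actual config:
#
#   export NEO4J_URI="bolt://localhost:7687"
#   export NEO4J_USERNAME="neo4j"
#   export NEO4J_PASSWORD="<password>"
#   export HUGGINGFACE_API_TOKEN="hf_..."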
# Hugging Face API setup
API_TOKEN = os.environ.get("HUGGINGFACE_API_TOKEN")
MISTRAL_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
client = InferenceClient(api_key=API_TOKEN)
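# Note: InferenceClient can also be pinned to a model at construction time
# (InferenceClient(model=MISTRAL_MODEL_NAME, api_key=API_TOKEN)); here the
# model is passed per request in query_from_mistral instead.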
# Query Mistral via the Inference API, grounding the answer in the retrieved context
def query_from_mistral(context: str, user_input: str):
    messages = [
        {"role": "system", "content": f"Use the following context to answer the query:\n{context}"},
        {"role": "user", "content": user_input},
    ]
    completion = client.chat.completions.create(
        model=MISTRAL_MODEL_NAME,
        messages=messages,
        max_tokens=500,
    )
    return completion.choices[0].message.content
# Extract title, abstract, and publication date from the retriever's documents
def extract_data(documents):
    result = []
    for doc in documents:
        publication_date = doc.metadata.get('date_publication', "N/A")
        page_content = doc.page_content.strip().split("\n")
        title = "N/A"
        abstract = "N/A"
        for line in page_content:
            if line.lower().startswith("name:"):
                title = line[len("name:"):].strip()
            elif line.lower().startswith("abstract:"):
                abstract = line[len("abstract:"):].strip()
        doc_data = {
            "Publication Date": publication_date,
            "Title": title,
            "Abstract": abstract,
        }
        result.append(doc_data)
    return result
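# For reference, Neo4jVector serializes text_node_properties into page_content
# roughly as one "name: <value>" / "abstract: <value>" line per property, which
# is why extract_data above matches those prefixes line by line.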
# Main Streamlit application
def main():
    st.set_page_config(page_title="Vector Chat with Mistral", layout="centered")
    st.title("🤖 Vector Chat with Mistral")
    st.markdown("Chat with **Mistral-7B-Instruct** using context retrieved from a Neo4j vector index.")

    # Initialize the vector index
    vector_index = setup_vector_index()

    if "messages" not in st.session_state:
        st.session_state.messages = []

    with st.form(key="chat_form", clear_on_submit=True):
        user_input = st.text_input("You:", "")
        submit = st.form_submit_button("Send")

    if submit and user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.spinner("Fetching response..."):
            try:
                # Retrieve the five most similar articles from the vector index
                context_results = vector_index.similarity_search(user_input, k=5)
                context = "\n".join(
                    f"Title: {doc['Title']}\nAbstract: {doc['Abstract']}\nPublication Date: {doc['Publication Date']}"
                    for doc in extract_data(context_results)
                )
                # Get the response from Mistral
                response = query_from_mistral(context.strip(), user_input)
                st.session_state.messages.append({"role": "bot", "content": response})
            except Exception as e:
                st.error(f"Error: {e}")

    # Display chat history
    for message in st.session_state.messages:
        if message["role"] == "user":
            st.markdown(f"**You:** {message['content']}")
        elif message["role"] == "bot":
            st.markdown(f"**Bot:** {message['content']}")


if __name__ == "__main__":
    main()