Spaces:
Sleeping
Sleeping
File size: 4,711 Bytes
3941168 3ac47d5 c17af4e 3ac47d5 c17af4e 2bcdd30 c17af4e 5064436 c17af4e 2bcdd30 c17af4e e2a2d53 3ac47d5 c17af4e 3ac47d5 c17af4e 3ac47d5 c17af4e 3ac47d5 c17af4e 3ac47d5 c17af4e 3ac47d5 c17af4e 3ac47d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import warnings
warnings.simplefilter("ignore", category=FutureWarning)
import os
import streamlit as st
from huggingface_hub import InferenceClient
from langchain_community.vectorstores import Neo4jVector
from transformers import AutoTokenizer, AutoModel
import torch
# Startup diagnostic: print the Neo4j username found in the environment
# (prints "None" when NEO4J_USERNAME is not set).
print(f"Username Neo4j: {os.environ.get('NEO4J_USERNAME')}")
# Custom Embedding Class
class CustomHuggingFaceEmbeddings:
    """Sentence embeddings computed by mean-pooling a Hugging Face model.

    Exposes ``embed_query`` and ``embed_documents`` so an instance can be
    handed to a vector store as its embedding object.
    """

    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        """Load the tokenizer and model identified by ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def _embed_batch(self, texts):
        """Return one pooled vector (list of floats) per entry of ``texts``.

        Returns an empty list when ``texts`` is empty or tokenization fails
        (the failure is printed, matching the class's best-effort behavior).
        """
        if not texts:
            return []
        try:
            inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
        except Exception as e:
            print(f"Error during tokenization: {e}")
            return []
        with torch.no_grad():
            outputs = self.model(**inputs)
        hidden = outputs.last_hidden_state
        mask = inputs.get("attention_mask")
        if mask is None:
            # No mask available: fall back to a plain mean over all tokens.
            return hidden.mean(dim=1).tolist()
        # Weight each token by its attention mask so padding added for
        # batching does not pull the average towards the pad embedding.
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        token_counts = weights.sum(dim=1).clamp(min=1e-9)
        pooled = (hidden * weights).sum(dim=1) / token_counts
        # No squeeze(): a batch of one must stay a list containing one vector.
        return pooled.tolist()

    def embed_text(self, text):
        """Embed a single string or a sequence of strings.

        Returns a flat list of floats for a string, a list of such vectors
        for a sequence, and ``[]`` if tokenization fails.
        """
        if isinstance(text, str):
            vectors = self._embed_batch([text])
            return vectors[0] if vectors else []
        return self._embed_batch(list(text))

    def embed_query(self, text):
        """Embed one query string and return its vector."""
        return self.embed_text(text)

    def embed_documents(self, text):
        """Embed a sequence of documents and return one vector per document.

        A bare string is treated as a single document, so the result is
        always a list of vectors.
        """
        if isinstance(text, str):
            text = [text]
        return self._embed_batch(list(text))
# Function to set up the Neo4j Vector Index
@st.cache_resource
def setup_vector_index():
    """Build the Neo4j vector index wrapper used for context retrieval.

    Connection settings come from the NEO4J_URI, NEO4J_USERNAME and
    NEO4J_PASSWORD environment variables. The index is named 'articles' and
    covers the 'topic', 'title' and 'abstract' properties of 'Article' nodes,
    with vectors kept in the 'embedding' property.
    """
    embedder = CustomHuggingFaceEmbeddings()
    connection = {
        "url": os.environ.get('NEO4J_URI'),
        "username": os.environ.get('NEO4J_USERNAME'),
        "password": os.environ.get('NEO4J_PASSWORD'),
    }
    index_settings = {
        "index_name": 'articles',
        "node_label": "Article",
        "text_node_properties": ['topic', 'title', 'abstract'],
        "embedding_node_property": 'embedding',
    }
    return Neo4jVector.from_existing_graph(embedder, **connection, **index_settings)
# Hugging Face API Setup
# The token is read from the HUGGINGFACE_API_TOKEN environment variable
# (None when the variable is unset).
API_TOKEN = os.environ.get("HUGGINGFACE_API_TOKEN")
# Model identifier passed to the chat-completion call in query_from_mistral.
MISTRAL_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
# Module-level inference client, created once at import time with the token above.
client = InferenceClient(api_key=API_TOKEN)
# Query Mistral
def query_from_mistral(context: str, user_input: str):
    """Ask the Mistral chat model to answer ``user_input`` using ``context``.

    The context is placed in a system message and the user's text in a user
    message; the content of the first returned choice is returned. At most
    500 tokens are requested.
    """
    system_prompt = f"Use the following context to answer the query:\n{context}"
    chat = [
        dict(role="system", content=system_prompt),
        dict(role="user", content=user_input),
    ]
    reply = client.chat.completions.create(
        messages=chat,
        model=MISTRAL_MODEL_NAME,
        max_tokens=500,
    )
    first_choice = reply.choices[0]
    return first_choice.message["content"]
# extract data from retriever response
# Labels looked for at the start of a page-content line ("topic: ...").
_FIELD_LABELS = ("topic", "title", "abstract")


def _parse_labeled_fields(text):
    """Collect "label: value" lines for the known labels into a dict.

    Lines that follow a recognised label and carry no new label are treated
    as a continuation of that field. Returns an empty dict when no known
    label is present.
    """
    fields = {}
    current = None
    for line in text.strip().split("\n"):
        label, sep, value = line.partition(":")
        key = label.strip().lower()
        if sep and key in _FIELD_LABELS and key not in fields:
            current = key
            fields[key] = value.strip()
        elif current is not None:
            fields[current] = f"{fields[current]}\n{line.strip()}".strip()
    return fields


def extract_data(documents):
    """Format retrieved documents as plain-text context blocks.

    Args:
        documents: iterable of objects exposing ``metadata`` (a mapping) and
            ``page_content`` (a string).

    Returns:
        A list with one string per document, each holding "Publication Date",
        "Topic", "Title" and "Abstract" lines. Missing values are "N/A".
    """
    result = []
    for doc in documents:
        # Extract metadata. Anything with an isoformat() method (date-like
        # objects) is rendered that way; other values are stringified
        # instead of raising AttributeError.
        publication_date = doc.metadata.get('publication_date')
        if publication_date:
            isoformat = getattr(publication_date, "isoformat", None)
            publication_date = isoformat() if callable(isoformat) else str(publication_date)
        else:
            publication_date = "N/A"

        # Extract page content
        text = doc.page_content or ""
        fields = _parse_labeled_fields(text)
        if fields:
            # Labelled "topic:/title:/abstract:" lines: read each field by
            # name, so a stripped leading blank line cannot shift the fields.
            topic = fields.get("topic") or "N/A"
            title = fields.get("title") or "N/A"
            abstract = fields.get("abstract") or "N/A"
        else:
            # Unlabelled content: keep the positional layout (line 0 is a
            # header, followed by topic, title and abstract).
            page_content = text.strip().split("\n")
            topic = page_content[1].strip() if len(page_content) > 1 else "N/A"
            title = page_content[2].strip() if len(page_content) > 2 else "N/A"
            abstract = page_content[3].strip() if len(page_content) > 3 else "N/A"

        # Format the extracted data as a string
        doc_data = (
            f"Publication Date: {publication_date}\n"
            f"Topic: {topic}\n"
            f"Title: {title}\n"
            f"Abstract: {abstract}\n"
        )
        result.append(doc_data)
    return result
# Main Streamlit Application
def main():
    """Run the Streamlit chat page.

    Each submitted message is used to search the vector index; the formatted
    results become the context for the Mistral request, and both the message
    and the reply are appended to the session's chat history, which is
    rendered at the bottom of the page.
    """
    st.set_page_config(page_title="Vector Chat with Mistral", layout="centered")
    st.title("🤖 Vector Chat with Mistral")
    st.markdown("Chat with **Mistral-7B-Instruct** using context retrieved from a Neo4j vector index.")

    # Initialize the vector index
    vector_index = setup_vector_index()

    if "messages" not in st.session_state:
        st.session_state.messages = []

    with st.form(key="chat_form", clear_on_submit=True):
        user_input = st.text_input("You:", "")
        submit = st.form_submit_button("Send")

    if submit and user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.spinner("Fetching response..."):
            # Top-level UI boundary: any failure is shown to the user rather
            # than crashing the page.
            try:
                # Retrieve context from the vector index. The result limit is
                # the ``k`` keyword; ``top_k`` is not a similarity_search
                # parameter.
                context_results = vector_index.similarity_search(user_input, k=3)
                context_blocks = extract_data(context_results)
                # Use every retrieved document, and avoid an IndexError when
                # the search returns nothing.
                if context_blocks:
                    context = "\n".join(context_blocks)
                else:
                    context = "No relevant context was found."
                # Get response from Mistral
                response = query_from_mistral(context, user_input)
                st.session_state.messages.append({"role": "bot", "content": response})
            except Exception as e:
                st.error(f"Error: {e}")

    # Display chat history
    for message in st.session_state.messages:
        if message["role"] == "user":
            st.markdown(f"**You:** {message['content']}")
        elif message["role"] == "bot":
            st.markdown(f"**Bot:** {message['content']}")


if __name__ == "__main__":
    main()
|