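"""Streamlit chat app: retrieval-augmented chat with Mistral-7B-Instruct,
using context fetched from a Neo4j vector index."""
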
import os
import streamlit as st
from huggingface_hub import InferenceClient
from langchain_community.vectorstores import Neo4jVector
from transformers import AutoTokenizer, AutoModel
import torch

# Custom Embedding Class
class CustomHuggingFaceEmbeddings:
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def embed_text(self, text):
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Mean-pool the token embeddings into a single sentence vector
        return outputs.last_hidden_state.mean(dim=1).squeeze().tolist()

    def embed_query(self, text):
        return self.embed_text(text)
    
    def embed_documents(self, texts):
        # LangChain calls this with a list of texts and expects one embedding per text
        return [self.embed_text(t) for t in texts]
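
# Hypothetical sanity check (not part of the app): all-MiniLM-L6-v2 produces
# 384-dimensional vectors, e.g.
#   CustomHuggingFaceEmbeddings().embed_query("graph databases")  # len == 384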

# Set up the Neo4j vector index (cached across Streamlit reruns)
@st.cache_resource
def setup_vector_index():
    return Neo4jVector.from_existing_graph(
        CustomHuggingFaceEmbeddings(),
        url=os.environ['NEO4J_URI'],
        username=os.environ['NEO4J_USERNAME'],
        password=os.environ['NEO4J_PASSWORD'],
        index_name='articles',
        node_label="Article",
        text_node_properties=['topic', 'title', 'abstract'],
        embedding_node_property='embedding',
    )
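
# NOTE: setup_vector_index() reads NEO4J_URI, NEO4J_USERNAME, and NEO4J_PASSWORD
# directly from the environment; os.environ[...] raises KeyError if any is missing.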

# Hugging Face API Setup
API_TOKEN = os.environ.get("HUGGINGFACE_API_TOKEN")
MISTRAL_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
client = InferenceClient(api_key=API_TOKEN)
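# If HUGGINGFACE_API_TOKEN is unset, API_TOKEN is None and requests from this
# client will fail with an authentication error.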

# Query Mistral
def query_from_mistral(context: str, user_input: str):
    messages = [
        {"role": "system", "content": f"Use the following context to answer the query:\n{context}"},
        {"role": "user", "content": user_input},
    ]
    completion = client.chat.completions.create(
        model=MISTRAL_MODEL_NAME,
        messages=messages,
        max_tokens=500,
    )
    return completion.choices[0].message.content

# extract data from retriever response
def extract_data(documents):
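    """Flatten retriever Documents into formatted context strings.

    Assumes page_content in the shape Neo4jVector.from_existing_graph emits:
    newline-separated "property: value" lines for topic, title, and abstract.
    """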
    result = []

    for doc in documents:
        # Extract metadata; fall back to "N/A" to match the other fields
        publication_date = doc.metadata.get('publication_date')
        publication_date = publication_date.isoformat() if publication_date else "N/A"
        
        # Extract page content; strip() removes the leading newline that
        # Neo4jVector prepends, so the properties sit at indices 0, 1, and 2
        lines = doc.page_content.strip().split("\n")
        topic = lines[0].strip() if len(lines) > 0 else "N/A"
        title = lines[1].strip() if len(lines) > 1 else "N/A"
        abstract = lines[2].strip() if len(lines) > 2 else "N/A"
        
        # Format the extracted data as a string
        doc_data = (
            f"Publication Date: {publication_date}\n"
            f"Topic: {topic}\n"
            f"Title: {title}\n"
            f"Abstract: {abstract}\n"
        )
        result.append(doc_data)
    
    return result

# Main Streamlit Application
def main():
    st.set_page_config(page_title="Vector Chat with Mistral", layout="centered")
    st.title("🤖 Vector Chat with Mistral")
    st.markdown("Chat with **Mistral-7B-Instruct** using context retrieved from a Neo4j vector index.")

    # Initialize the vector index
    vector_index = setup_vector_index()

    if "messages" not in st.session_state:
        st.session_state.messages = []

    with st.form(key="chat_form", clear_on_submit=True):
        user_input = st.text_input("You:", "")
        submit = st.form_submit_button("Send")

    if submit and user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})

        with st.spinner("Fetching response..."):
            try:
                # Retrieve context from the vector index
                context_results = vector_index.similarity_search(user_input, k=3)
                context = "\n\n".join(extract_data(context_results))

                # Get response from Mistral
                response = query_from_mistral(context, user_input)
                st.session_state.messages.append({"role": "bot", "content": response})
            except Exception as e:
                st.error(f"Error: {e}")

    # Display chat history
    for message in st.session_state.messages:
        if message["role"] == "user":
            st.markdown(f"**You:** {message['content']}")
        elif message["role"] == "bot":
            st.markdown(f"**Bot:** {message['content']}")

if __name__ == "__main__":
    main()
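
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py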