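"""Streamlit RAG demo: chat with Mistral-7B-Instruct over a Neo4j vector index.

Article nodes are embedded with a sentence-transformers model; for each user
question the app retrieves the most similar articles, enriches them with their
linked Keyword nodes, and passes the result to Mistral as context.
"""
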
import warnings
warnings.simplefilter("ignore", category=FutureWarning)

import os
import streamlit as st
from neo4j import GraphDatabase
from huggingface_hub import InferenceClient
from langchain_community.vectorstores import Neo4jVector
from transformers import AutoTokenizer, AutoModel
import torch

# Hugging Face API Setup
API_TOKEN = os.environ.get("HUGGINGFACE_API_TOKEN")
MISTRAL_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
client = InferenceClient(api_key=API_TOKEN)

# Neo4j driver
driver = GraphDatabase.driver(
    os.environ['NEO4J_URI'],
    auth=(os.environ['NEO4J_USERNAME'], os.environ['NEO4J_PASSWORD'])
)
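
# All configuration comes from the environment: HUGGINGFACE_API_TOKEN, NEO4J_URI,
# NEO4J_USERNAME and NEO4J_PASSWORD must be set before launching the app.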

# Custom Embedding Class
class CustomHuggingFaceEmbeddings:
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def embed_text(self, text):
        try:
            inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        except Exception as e:
            print(f"Error during tokenization: {e}")
            return []
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Mean-pool the token embeddings into one fixed-size sentence vector
        return outputs.last_hidden_state.mean(dim=1).squeeze().tolist()

    def embed_query(self, text):
        return self.embed_text(text)

    def embed_documents(self, texts):
        # LangChain expects one embedding per document, so embed each text separately
        return [self.embed_text(text) for text in texts]
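
# A minimal, hypothetical usage sketch of the class above (not called directly
# by this app; LangChain invokes embed_query/embed_documents internally):
#   emb = CustomHuggingFaceEmbeddings()
#   emb.embed_query("knowledge graphs")  # -> list of 384 floats for MiniLM-L6-v2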

# Set up the Neo4j vector index. from_existing_graph embeds Article nodes that
# are still missing the 'embedding' property and reuses the 'articles' index on
# subsequent runs.
@st.cache_resource
def setup_vector_index():
    return Neo4jVector.from_existing_graph(
        CustomHuggingFaceEmbeddings(),
        url=os.environ['NEO4J_URI'],
        username=os.environ['NEO4J_USERNAME'],
        password=os.environ['NEO4J_PASSWORD'],
        index_name='articles',
        node_label="Article",
        text_node_properties=['name', 'abstract'],
        embedding_node_property='embedding',
    )

# Query Mistral, prepending the retrieved context as a system message
def query_from_mistral(context: str, user_input: str):
    messages = [
        {"role": "system", "content": f"Use the following context to answer the query:\n{context}"},
        {"role": "user", "content": user_input},
    ]
    completion = client.chat.completions.create(
        model=MISTRAL_MODEL_NAME,
        messages=messages,
        max_tokens=500,
    )
    return completion.choices[0].message.content

# Fetch the Keyword nodes linked to an Article by its name
def query_article_keywords(name):
    with driver.session() as session:
        query = """
        MATCH (a:Article)-[:CONTAIN]->(k:Keyword)
        WHERE a.name = $name
        RETURN k
        """
        result = session.run(query, name=name)
        return [record["k"] for record in result]
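
# The Cypher above assumes a (:Article)-[:CONTAIN]->(:Keyword) schema in which
# each Keyword node stores its label in a 'text' property (used by extract_data).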

# Extract structured fields from the retriever's documents
def extract_data(documents):
    result = []

    for doc in documents:
        publication_date = doc.metadata.get('date_publication', "N/A")
        page_content = doc.page_content.strip().split("\n")
        
        title = "N/A"
        abstract = "N/A"

        for line in page_content:
            if line.lower().startswith("name:"):
                title = line[len("name:"):].strip()
            elif line.lower().startswith("abstract:"):
                abstract = line[len("abstract:"):].strip()

        keywords = query_article_keywords(title)
        keywords = [dict(node)['text'] for node in keywords]

        doc_data = {
            "Publication Date": publication_date,
            "Title": title,
            "Abstract": abstract,
            "keywords": ','.join(keywords)
        }
        result.append(doc_data)

    return result
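
# Illustrative shape of one extract_data entry (values are placeholders, not
# taken from the database):
#   {"Publication Date": "...", "Title": "...", "Abstract": "...", "keywords": "a,b"}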

# Main Streamlit Application
def main():
    st.set_page_config(page_title="Vector Chat with Mistral", layout="centered")
    
    # App description and features
    st.title("🤖 RAG with Mistral")
    st.markdown("""
        ## Description:
        Chat with **Mistral-7B-Instruct** using context retrieved from a **Neo4j** vector index. Ask a question and the assistant answers in real time, grounding its response in the most relevant articles and their keywords from the database.
    """)

    st.image(image="image.jpg", caption="Neo4j")

    st.markdown("""
        ## Key Features:
        - **Real-time context search** from a Neo4j vector index.
        - **Integration with Mistral-7B-Instruct model** for natural language processing.
        - **Keyword extraction** from relevant articles for enhanced context-based responses.

        ## GitHub Repository:
        You can find the source code and more information about this app on GitHub: [GitHub Repository Link](https://github.com/yourusername/your-repository-name)
    """)

    # Initialize the vector index
    vector_index = setup_vector_index()

    if "messages" not in st.session_state:
        st.session_state.messages = []

    with st.form(key="chat_form", clear_on_submit=True):
        user_input = st.text_input("You:", "")
        submit = st.form_submit_button("Send")

    if submit and user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})

        with st.spinner("Fetching response..."):
            try:
                context_results = vector_index.similarity_search(user_input, k=5)

                if not context_results:
                    st.warning("No relevant context found. Please refine your query.")
                    response = "I'm sorry, I couldn't find any relevant information to answer your question."
                else:
                    data_dict = extract_data(context_results)

                    # Flatten the retrieved documents into a single context string
                    context = '\n'.join([
                        f"Title: {doc['Title']}\n"
                        f"Abstract: {doc['Abstract']}\n"
                        f"Publication Date: {doc['Publication Date']}\n"
                        f"Keywords: {doc['keywords']}"
                        for doc in data_dict
                    ])

                    response = query_from_mistral(context.strip(), user_input)

                st.session_state.messages.append({"role": "bot", "content": response})
            except Exception as e:
                st.error(f"Error: {e}")

    # Display chat history
    for message in st.session_state.messages:
        if message["role"] == "user":
            st.markdown(f"**You:** {message['content']}")
        elif message["role"] == "bot":
            st.markdown(f"**Bot:** {message['content']}")

if __name__ == "__main__":
    main()