Adventure123 commited on
Commit
c17af4e
·
verified ·
1 Parent(s): cbe9d31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -28
app.py CHANGED
@@ -1,31 +1,53 @@
1
- import streamlit as st
2
- import time
3
- import requests
4
  import os
 
5
  from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # Hugging Face API Setup
8
  API_TOKEN = os.environ.get("HUGGINGFACE_API_TOKEN")
9
- GPT2XL_API_URL = "https://api-inference.huggingface.co/models/openai-community/gpt2-xl"
10
  MISTRAL_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
11
  client = InferenceClient(api_key=API_TOKEN)
12
 
13
- # Query GPT-2 XL
14
- def query_from_gpt2xl(text: str):
15
- headers = {"Authorization": f"Bearer {API_TOKEN}"}
16
- while True:
17
- response = requests.post(GPT2XL_API_URL, headers=headers, json={"inputs": text})
18
- response_data = response.json()
19
- if "error" in response_data and "loading" in response_data["error"]:
20
- wait_time = response_data.get("estimated_time", 10)
21
- st.info(f"Model is loading. Waiting for {wait_time:.2f} seconds...")
22
- time.sleep(wait_time)
23
- else:
24
- return response_data[0]["generated_text"]
25
-
26
  # Query Mistral
27
- def query_from_mistral(text: str):
28
- messages = [{"role": "user", "content": text}]
 
 
 
29
  completion = client.chat.completions.create(
30
  model=MISTRAL_MODEL_NAME,
31
  messages=messages,
@@ -33,32 +55,65 @@ def query_from_mistral(text: str):
33
  )
34
  return completion.choices[0].message["content"]
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def main():
37
- st.set_page_config(page_title="Multi-Model Chat", layout="centered")
38
- st.title("🤖 Multi-Model Chat")
39
- st.markdown("Chat with either **GPT-2 XL** or **Mistral-7B-Instruct** via Hugging Face API.")
 
 
 
40
 
41
  if "messages" not in st.session_state:
42
  st.session_state.messages = []
43
 
44
- model_choice = st.selectbox("Select a model:", ["GPT-2 XL", "Mistral-7B-Instruct"])
45
-
46
  with st.form(key="chat_form", clear_on_submit=True):
47
  user_input = st.text_input("You:", "")
48
  submit = st.form_submit_button("Send")
49
 
50
  if submit and user_input:
51
  st.session_state.messages.append({"role": "user", "content": user_input})
 
52
  with st.spinner("Fetching response..."):
53
  try:
54
- if model_choice == "GPT-2 XL":
55
- response = query_from_gpt2xl(user_input)
56
- elif model_choice == "Mistral-7B-Instruct":
57
- response = query_from_mistral(user_input)
 
 
58
  st.session_state.messages.append({"role": "bot", "content": response})
59
  except Exception as e:
60
  st.error(f"Error: {e}")
61
 
 
62
  for message in st.session_state.messages:
63
  if message["role"] == "user":
64
  st.markdown(f"**You:** {message['content']}")
 
 
 
 
1
  import os
2
+ import streamlit as st
3
  from huggingface_hub import InferenceClient
4
+ from langchain_community.vectorstores import Neo4jVector
5
+ from transformers import AutoTokenizer, AutoModel
6
+ import torch
7
+
8
+ # Custom Embedding Class
9
+ class CustomHuggingFaceEmbeddings:
10
+ def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
11
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ self.model = AutoModel.from_pretrained(model_name)
13
+
14
+ def embed_text(self, text):
15
+ inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
16
+ with torch.no_grad():
17
+ outputs = self.model(**inputs)
18
+ return outputs.last_hidden_state.mean(dim=1).squeeze().tolist()
19
+
20
+ def embed_query(self, text):
21
+ return self.embed_text(text)
22
+
23
+ def embed_documents(self, text):
24
+ return self.embed_text(text)
25
+
26
+ # Function to set up the Neo4j Vector Index
27
+ @st.cache_resource
28
+ def setup_vector_index():
29
+ return Neo4jVector.from_existing_graph(
30
+ CustomHuggingFaceEmbeddings(),
31
+ url=os.environ['NEO4J_URI'],
32
+ username=os.environ['NEO4J_USERNAME'],
33
+ password=os.environ['NEO4J_PASSWORD'],
34
+ index_name='articles',
35
+ node_label="Article",
36
+ text_node_properties=['topic', 'title', 'abstract'],
37
+ embedding_node_property='embedding',
38
+ )
39
 
40
  # Hugging Face API Setup
41
  API_TOKEN = os.environ.get("HUGGINGFACE_API_TOKEN")
 
42
  MISTRAL_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
43
  client = InferenceClient(api_key=API_TOKEN)
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  # Query Mistral
46
+ def query_from_mistral(context: str, user_input: str):
47
+ messages = [
48
+ {"role": "system", "content": f"Use the following context to answer the query:\n{context}"},
49
+ {"role": "user", "content": user_input},
50
+ ]
51
  completion = client.chat.completions.create(
52
  model=MISTRAL_MODEL_NAME,
53
  messages=messages,
 
55
  )
56
  return completion.choices[0].message["content"]
57
 
58
+ # extract data from retriever response
59
+ def extract_data(documents):
60
+ result = []
61
+
62
+ for doc in documents:
63
+ # Extract metadata
64
+ publication_date = doc.metadata.get('publication_date')
65
+ if publication_date:
66
+ publication_date = publication_date.isoformat()
67
+
68
+ # Extract page content
69
+ page_content = doc.page_content.strip().split("\n")
70
+ topic = page_content[1].strip() if len(page_content) > 1 else "N/A"
71
+ title = page_content[2].strip() if len(page_content) > 2 else "N/A"
72
+ abstract = page_content[3].strip() if len(page_content) > 3 else "N/A"
73
+
74
+ # Format the extracted data as a string
75
+ doc_data = (
76
+ f"Publication Date: {publication_date}\n"
77
+ f"Topic: {topic}\n"
78
+ f"Title: {title}\n"
79
+ f"Abstract: {abstract}\n"
80
+ )
81
+ result.append(doc_data)
82
+
83
+ return result
84
+
85
+ # Main Streamlit Application
86
  def main():
87
+ st.set_page_config(page_title="Vector Chat with Mistral", layout="centered")
88
+ st.title("🤖 Vector Chat with Mistral")
89
+ st.markdown("Chat with **Mistral-7B-Instruct** using context retrieved from a Neo4j vector index.")
90
+
91
+ # Initialize the vector index
92
+ vector_index = setup_vector_index()
93
 
94
  if "messages" not in st.session_state:
95
  st.session_state.messages = []
96
 
 
 
97
  with st.form(key="chat_form", clear_on_submit=True):
98
  user_input = st.text_input("You:", "")
99
  submit = st.form_submit_button("Send")
100
 
101
  if submit and user_input:
102
  st.session_state.messages.append({"role": "user", "content": user_input})
103
+
104
  with st.spinner("Fetching response..."):
105
  try:
106
+ # Retrieve context from the vector index
107
+ context_results = vector_index.similarity_search(user_input, top_k=3)
108
+ context = extract_data(context_results)[0]
109
+
110
+ # Get response from Mistral
111
+ response = query_from_mistral(context, user_input)
112
  st.session_state.messages.append({"role": "bot", "content": response})
113
  except Exception as e:
114
  st.error(f"Error: {e}")
115
 
116
+ # Display chat history
117
  for message in st.session_state.messages:
118
  if message["role"] == "user":
119
  st.markdown(f"**You:** {message['content']}")