juancho72h committed
Commit e5e0026 · verified · 1 Parent(s): 79041c3

Upload app.py

Files changed (1): app.py +26 -5
app.py CHANGED
@@ -6,6 +6,12 @@ import torch
 from dotenv import load_dotenv
 from pinecone import Pinecone
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from fuzzywuzzy import fuzz
+import logging
+import re  # To help with preprocessing
+
+# Set up logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 # Detect GPU availability and set device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -42,6 +48,12 @@ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/msmarc
 # Initialize chat history manually
 chat_history = []
 
+# Helper function to preprocess text (removing unnecessary words)
+def preprocess_text(text):
+    # Convert text to lowercase and remove special characters
+    text = re.sub(r'[^\w\s]', '', text.lower())
+    return text.strip()
+
 # Helper function to recursively flatten any list to a string
 def flatten_to_string(data):
     if isinstance(data, list):
@@ -53,12 +65,15 @@ def flatten_to_string(data):
 # Function to interact with Pinecone and OpenAI GPT-4
 def get_model_response(human_input):
     try:
+        # Preprocess the human input (cleaning up unnecessary words)
+        processed_input = preprocess_text(human_input)
+
         # Embed the query
         query_embedding = torch.tensor(embedding_model.embed_query(human_input)).to(device)
         query_embedding = query_embedding.cpu().numpy().tolist()
 
-        # Query Pinecone index
-        search_results = index.query(vector=query_embedding, top_k=2, include_metadata=True)
+        # Query Pinecone index with top_k=5 to get more potential matches
+        search_results = index.query(vector=query_embedding, top_k=5, include_metadata=True)
 
         context_list, images = [], []
         for ind, result in enumerate(search_results['matches']):
@@ -66,9 +81,15 @@ def get_model_response(human_input):
             image_url = flatten_to_string(result.get('metadata', {}).get('image_path', None))
             figure_desc = flatten_to_string(result.get('metadata', {}).get('figure_description', ''))
 
-            context_list.append(f"Relevant information: {document_content}")
-            if image_url and figure_desc:
-                images.append((figure_desc, image_url))
+            # Preprocess the figure description and match keywords
+            processed_figure_desc = preprocess_text(figure_desc)
+            similarity_score = fuzz.token_set_ratio(processed_input, processed_figure_desc)
+            logging.info(f"Matching '{processed_input}' with '{processed_figure_desc}', similarity score: {similarity_score}")
+
+            if similarity_score >= 80:  # Keep the threshold at 80 for now
+                context_list.append(f"Relevant information: {document_content}")
+                if image_url and figure_desc:
+                    images.append((figure_desc, image_url))
 
         context_string = '\n\n'.join(context_list)
 
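For context on the gating logic this commit adds: a Pinecone match (and its figure image) is kept only when the fuzzy-match score between the cleaned query and the cleaned figure description reaches 80. Below is a minimal standalone sketch of that path, using the same re and fuzzywuzzy calls as the diff; the sample query and figure-description strings are invented for illustration.

import re
from fuzzywuzzy import fuzz

def preprocess_text(text):
    # Same cleanup as the diff: lowercase, strip punctuation, trim whitespace
    return re.sub(r'[^\w\s]', '', text.lower()).strip()

# Hypothetical stand-ins for a user query and a stored figure description
query = "How do I replace the pump impeller?"
figure_desc = "Figure 12: Pump impeller replacement steps."

score = fuzz.token_set_ratio(preprocess_text(query), preprocess_text(figure_desc))
print(score)  # token_set_ratio returns an int in 0..100

# Mirror of the diff's gate: only attach the figure when the score clears 80
if score >= 80:
    print("figure attached to context")
else:
    print("figure skipped")

Note that in the committed code only the fuzzy gate uses processed_input; the embedding call still receives the raw human_input, so preprocessing affects figure selection but not retrieval.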