ccm commited on
Commit
24dc567
·
verified ·
1 Parent(s): 5694a69

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +32 -43
main.py CHANGED
@@ -36,6 +36,18 @@ abstract_is_null = [
36
  data = data[~pandas.Series(abstract_is_null)]
37
  data.reset_index(inplace=True)
38
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  # Create a FAISS index for fast similarity search
40
  metric = faiss.METRIC_INNER_PRODUCT
41
  vectors = numpy.stack(data["embedding"].tolist(), axis=0)
@@ -45,34 +57,35 @@ faiss.normalize_L2(vectors)
45
  index.train(vectors)
46
  index.add(vectors)
47
 
48
- # Load the model for later use in embeddings
49
- model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
50
 
51
-
52
- def search(query: str, k: int) -> tuple[str, str]:
53
  """
54
- Searches the dataset for the top k most relevant papers to the query
55
  Args:
56
  query (str): The user's query
57
  k (int): The number of results to return
58
  Returns:
59
- tuple[str, str]: A tuple containing the search results and references
60
  """
61
- query = numpy.expand_dims(model.encode(query), axis=0)
62
- faiss.normalize_L2(query)
63
- D, I = index.search(query, k)
64
  top_five = data.loc[I[0]]
65
 
66
- search_results = (
67
  "You are an AI assistant who delights in helping people learn about research from the Design "
68
- "Research Collective. Your main task is to provide an ANSWER to the USER_QUERY based on the RESEARCH_ABSTRACTS.\n\n"
69
- "RESEARCH_ABSTRACTS: \n"
 
 
 
70
  )
71
 
72
  references = "\n\n## References\n\n"
 
73
 
74
  for i in range(k):
75
- search_results += top_five["bib_dict"].values[i]["abstract"] + "\n"
76
  references += (
77
  str(i + 1)
78
  + ". "
@@ -93,36 +106,12 @@ def search(query: str, k: int) -> tuple[str, str]:
93
  + top_five["author_pub_id"].values[i]
94
  + ").\n"
95
  )
 
 
 
96
 
97
- search_results += (
98
- "\n\nUSER_QUERY: "
99
- )
100
-
101
- return search_results, references
102
-
103
-
104
- # Create an LLM pipeline that we can send queries to
105
- tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
106
- streamer = transformers.TextIteratorStreamer(
107
- tokenizer, skip_prompt=True, skip_special_tokens=True
108
- )
109
- chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
110
- LLM_MODEL_NAME, torch_dtype="auto", device_map="auto"
111
- )
112
-
113
-
114
- def preprocess(message: str) -> tuple[str, str]:
115
- """
116
- Applies a preprocessing step to the user's message before the LLM receives it
117
- Args:
118
- message (str): The user's message
119
- Returns:
120
- tuple[str, str]: A tuple containing the preprocessed message and a bypass variable
121
- """
122
- block_search_results, formatted_search_results = search(message, 5)
123
- return block_search_results + message + "\nANSWER: ", formatted_search_results
124
-
125
-
126
  def postprocess(response: str, bypass_from_preprocessing: str) -> str:
127
  """
128
  Applies a postprocessing step to the LLM's response before the user receives it
@@ -147,7 +136,7 @@ def reply(message: str, history: list[str]) -> str:
147
  """
148
 
149
  # Apply preprocessing
150
- message, bypass = preprocess(message)
151
 
152
  # This is some handling that is applied to the history variable to put it in a good format
153
  history_transformer_format = [
 
36
  data = data[~pandas.Series(abstract_is_null)]
37
  data.reset_index(inplace=True)
38
 
39
+ # Load the model for later use in embeddings
40
+ model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
41
+
42
+ # Create an LLM pipeline that we can send queries to
43
+ tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
44
+ streamer = transformers.TextIteratorStreamer(
45
+ tokenizer, skip_prompt=True, skip_special_tokens=True
46
+ )
47
+ chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
48
+ LLM_MODEL_NAME, torch_dtype="auto", device_map="auto"
49
+ )
50
+
51
  # Create a FAISS index for fast similarity search
52
  metric = faiss.METRIC_INNER_PRODUCT
53
  vectors = numpy.stack(data["embedding"].tolist(), axis=0)
 
57
  index.train(vectors)
58
  index.add(vectors)
59
 
 
 
60
 
61
+ def preprocess(query: str, k: int) -> tuple[str, str]:
 
62
  """
63
+ Searches the dataset for the top k most relevant papers to the query and returns a prompt and references
64
  Args:
65
  query (str): The user's query
66
  k (int): The number of results to return
67
  Returns:
68
+ tuple[str, str]: A tuple containing the prompt and references
69
  """
70
+ encoded_query = numpy.expand_dims(model.encode(query), axis=0)
71
+ faiss.normalize_L2(encoded_query)
72
+ D, I = index.search(encoded_query, k)
73
  top_five = data.loc[I[0]]
74
 
75
+ prompt = (
76
  "You are an AI assistant who delights in helping people learn about research from the Design "
77
+ "Research Collective. Your main task is to provide an ANSWER to the USER_QUERY based on the "
78
+ "RESEARCH_ABSTRACTS.\n\n"
79
+ "RESEARCH_ABSTRACTS:\n{{ABSTRACTS_GO_HERE}}\n\n"
80
+ "USER_GUERY:\n{{QUERY_GOES_HERE}}\n\n"
81
+ "ANSWER:\n"
82
  )
83
 
84
  references = "\n\n## References\n\n"
85
+ research_abstracts = ""
86
 
87
  for i in range(k):
88
+ research_abstracts += top_five["bib_dict"].values[i]["abstract"] + "\n"
89
  references += (
90
  str(i + 1)
91
  + ". "
 
106
  + top_five["author_pub_id"].values[i]
107
  + ").\n"
108
  )
109
+
110
+ prompt = prompt.replace("{{ABSTRACTS_GO_HERE}}", research_abstracts)
111
+ prompt = prompt.replace("{{QUERY_GOES_HERE}}", query)
112
 
113
+ return prompt, references
114
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  def postprocess(response: str, bypass_from_preprocessing: str) -> str:
116
  """
117
  Applies a postprocessing step to the LLM's response before the user receives it
 
136
  """
137
 
138
  # Apply preprocessing
139
+ message, bypass = preprocess(message, 5)
140
 
141
  # This is some handling that is applied to the history variable to put it in a good format
142
  history_transformer_format = [