Spaces: Running on Zero
Update main.py
main.py CHANGED
@@ -36,6 +36,18 @@ abstract_is_null = [
 data = data[~pandas.Series(abstract_is_null)]
 data.reset_index(inplace=True)
 
+# Load the model for later use in embeddings
+model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
+
+# Create an LLM pipeline that we can send queries to
+tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
+streamer = transformers.TextIteratorStreamer(
+    tokenizer, skip_prompt=True, skip_special_tokens=True
+)
+chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
+    LLM_MODEL_NAME, torch_dtype="auto", device_map="auto"
+)
+
 # Create a FAISS index for fast similarity search
 metric = faiss.METRIC_INNER_PRODUCT
 vectors = numpy.stack(data["embedding"].tolist(), axis=0)
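This first hunk moves the embedding model and the LLM pipeline to the top of main.py, before the FAISS index is built, so `model` is already loaded when the new `preprocess` function (below) encodes queries. As a rough sketch of how the embedding object gets used, assuming a small stand-in checkpoint since the real `EMBEDDING_MODEL_NAME` constant is defined outside this diff:

    import sentence_transformers

    # Stand-in value; the Space defines EMBEDDING_MODEL_NAME elsewhere in main.py.
    EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

    model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)

    # Encode a query into the same vector space as the indexed abstracts.
    query_vector = model.encode("haptic feedback in prosthetic design")
    print(query_vector.shape)  # (384,) for this MiniLM-class model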
@@ -45,34 +57,35 @@ faiss.normalize_L2(vectors)
 index.train(vectors)
 index.add(vectors)
 
-# Load the model for later use in embeddings
-model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
 
-
-def search(query: str, k: int) -> tuple[str, str]:
+def preprocess(query: str, k: int) -> tuple[str, str]:
     """
-    Searches the dataset for the top k most relevant papers to the query
+    Searches the dataset for the top k most relevant papers to the query and returns a prompt and references
     Args:
         query (str): The user's query
         k (int): The number of results to return
     Returns:
-        tuple[str, str]: A tuple containing the
+        tuple[str, str]: A tuple containing the prompt and references
     """
-
-    faiss.normalize_L2(
-    D, I = index.search(
+    encoded_query = numpy.expand_dims(model.encode(query), axis=0)
+    faiss.normalize_L2(encoded_query)
+    D, I = index.search(encoded_query, k)
     top_five = data.loc[I[0]]
 
-
+    prompt = (
         "You are an AI assistant who delights in helping people learn about research from the Design "
-        "Research Collective. Your main task is to provide an ANSWER to the USER_QUERY based on the
-        "RESEARCH_ABSTRACTS
+        "Research Collective. Your main task is to provide an ANSWER to the USER_QUERY based on the "
+        "RESEARCH_ABSTRACTS.\n\n"
+        "RESEARCH_ABSTRACTS:\n{{ABSTRACTS_GO_HERE}}\n\n"
+        "USER_QUERY:\n{{QUERY_GOES_HERE}}\n\n"
+        "ANSWER:\n"
    )
 
     references = "\n\n## References\n\n"
+    research_abstracts = ""
 
     for i in range(k):
-
+        research_abstracts += top_five["bib_dict"].values[i]["abstract"] + "\n"
         references += (
             str(i + 1)
             + ". "
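A note on the retrieval in this hunk: the index is built with `faiss.METRIC_INNER_PRODUCT`, and both the corpus vectors and the query are passed through `faiss.normalize_L2`, so the inner-product scores returned in `D` are cosine similarities. A self-contained sketch of that pattern on toy data (the Space's own index construction sits outside this hunk):

    import faiss
    import numpy

    # Toy stand-ins for the abstract embeddings: 100 vectors of dimension 8.
    vectors = numpy.random.rand(100, 8).astype("float32")
    faiss.normalize_L2(vectors)  # in-place; unit vectors make inner product == cosine

    index = faiss.IndexFlatIP(8)  # flat inner-product (IP) index
    index.add(vectors)

    query = numpy.random.rand(1, 8).astype("float32")  # FAISS expects a 2-D batch
    faiss.normalize_L2(query)
    D, I = index.search(query, 5)  # D: similarity scores, I: row indices of the top 5
    print(I[0], D[0])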
@@ -93,36 +106,12 @@ def search(query: str, k: int) -> tuple[str, str]:
             + top_five["author_pub_id"].values[i]
             + ").\n"
         )
+
+    prompt = prompt.replace("{{ABSTRACTS_GO_HERE}}", research_abstracts)
+    prompt = prompt.replace("{{QUERY_GOES_HERE}}", query)
 
-
-
-    )
-
-    return search_results, references
-
-
-# Create an LLM pipeline that we can send queries to
-tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
-streamer = transformers.TextIteratorStreamer(
-    tokenizer, skip_prompt=True, skip_special_tokens=True
-)
-chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
-    LLM_MODEL_NAME, torch_dtype="auto", device_map="auto"
-)
-
-
-def preprocess(message: str) -> tuple[str, str]:
-    """
-    Applies a preprocessing step to the user's message before the LLM receives it
-    Args:
-        message (str): The user's message
-    Returns:
-        tuple[str, str]: A tuple containing the preprocessed message and a bypass variable
-    """
-    block_search_results, formatted_search_results = search(message, 5)
-    return block_search_results + message + "\nANSWER: ", formatted_search_results
-
-
+    return prompt, references
+
 def postprocess(response: str, bypass_from_preprocessing: str) -> str:
     """
     Applies a postprocessing step to the LLM's response before the user receives it
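The tail of the new `preprocess` fills the `{{ABSTRACTS_GO_HERE}}` and `{{QUERY_GOES_HERE}}` placeholders with plain `str.replace`, a simple choice that cannot misfire on braces or newlines inside the abstracts, and returns the finished prompt alongside the Markdown reference list. The substitution step in isolation, with hypothetical inputs:

    prompt = (
        "RESEARCH_ABSTRACTS:\n{{ABSTRACTS_GO_HERE}}\n\n"
        "USER_QUERY:\n{{QUERY_GOES_HERE}}\n\n"
        "ANSWER:\n"
    )

    # Hypothetical stand-ins for the retrieved abstracts and the user's query.
    research_abstracts = "We study soft robotic grippers...\n"
    query = "What work exists on soft robotics?"

    prompt = prompt.replace("{{ABSTRACTS_GO_HERE}}", research_abstracts)
    prompt = prompt.replace("{{QUERY_GOES_HERE}}", query)
    print(prompt)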
@@ -147,7 +136,7 @@ def reply(message: str, history: list[str]) -> str:
     """
 
     # Apply preprocessing
-    message, bypass = preprocess(message)
+    message, bypass = preprocess(message, 5)
 
     # This is some handling that is applied to the history variable to put it in a good format
     history_transformer_format = [
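Since `preprocess` now takes `k` explicitly, `reply` requests the top 5 abstracts on each turn. The diff does not show how `reply` is served, but in a ZeroGPU Space like this one it is typically handed to a Gradio chat component; a minimal sketch under that assumption:

    import gradio

    # Assumes reply(message, history) is defined as in main.py above.
    demo = gradio.ChatInterface(reply)
    demo.launch()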