Update app.py
Browse files
app.py
CHANGED
|
@@ -109,7 +109,7 @@ def get_docs(input_query, country = [], vulnerability_cat = []):
|
|
| 109 |
filters = {'vulnerability_cat': {'$in': vulnerability_cat}}
|
| 110 |
else:
|
| 111 |
filters = {'country': {'$in': country},'vulnerability_cat': {'$in': vulnerability_cat}}
|
| 112 |
-
docs = retriever.retrieve(query=
|
| 113 |
# Break out the key fields and convert to pandas for filtering
|
| 114 |
docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs]
|
| 115 |
df_docs = pd.DataFrame(docs)
|
|
@@ -154,11 +154,11 @@ def get_refs(docs, res):
|
|
| 154 |
return result_str
|
| 155 |
|
| 156 |
# define a special function for putting the prompt together (as we can't use haystack)
|
| 157 |
-
def get_prompt(docs,
|
| 158 |
base_prompt=prompt_template
|
| 159 |
# Add the meta data for references
|
| 160 |
context = ' - '.join(['&&& [ref. '+str(d.meta['ref_id'])+'] '+d.meta['document']+' &&&: '+d.content for d in docs])
|
| 161 |
-
prompt = base_prompt+"; Context: "+context+"; Question: "+
|
| 162 |
return(prompt)
|
| 163 |
|
| 164 |
def run_query(input_query, country, model_sel):
|
|
@@ -167,13 +167,13 @@ def run_query(input_query, country, model_sel):
|
|
| 167 |
# st.write('Selected country: ', country) # Debugging country
|
| 168 |
if model_sel == "chatGPT":
|
| 169 |
# res = pipe.run(query=input_text, documents=docs)
|
| 170 |
-
res = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": get_prompt(docs,
|
| 171 |
output = res["results"][0]
|
| 172 |
references = get_refs(docs, output)
|
| 173 |
-
else:
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
st.write('Response')
|
| 178 |
st.success(output)
|
| 179 |
st.write('References')
|
|
|
|
| 109 |
filters = {'vulnerability_cat': {'$in': vulnerability_cat}}
|
| 110 |
else:
|
| 111 |
filters = {'country': {'$in': country},'vulnerability_cat': {'$in': vulnerability_cat}}
|
| 112 |
+
docs = retriever.retrieve(query=input_query, filters = filters, top_k = 10)
|
| 113 |
# Break out the key fields and convert to pandas for filtering
|
| 114 |
docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs]
|
| 115 |
df_docs = pd.DataFrame(docs)
|
|
|
|
| 154 |
return result_str
|
| 155 |
|
| 156 |
# define a special function for putting the prompt together (as we can't use haystack)
|
| 157 |
+
def get_prompt(docs, input_query):
|
| 158 |
base_prompt=prompt_template
|
| 159 |
# Add the meta data for references
|
| 160 |
context = ' - '.join(['&&& [ref. '+str(d.meta['ref_id'])+'] '+d.meta['document']+' &&&: '+d.content for d in docs])
|
| 161 |
+
prompt = base_prompt+"; Context: "+context+"; Question: "+input_query+"; Answer:"
|
| 162 |
return(prompt)
|
| 163 |
|
| 164 |
def run_query(input_query, country, model_sel):
|
|
|
|
| 167 |
# st.write('Selected country: ', country) # Debugging country
|
| 168 |
if model_sel == "chatGPT":
|
| 169 |
# res = pipe.run(query=input_text, documents=docs)
|
| 170 |
+
res = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": get_prompt(docs, input_query)}])
|
| 171 |
output = res["results"][0]
|
| 172 |
references = get_refs(docs, output)
|
| 173 |
+
# else:
|
| 174 |
+
# res = client.text_generation(get_prompt_llama2(docs, query=input_query), max_new_tokens=4000, temperature=0.01, model=model)
|
| 175 |
+
# output = res
|
| 176 |
+
# references = get_refs(docs, res)
|
| 177 |
st.write('Response')
|
| 178 |
st.success(output)
|
| 179 |
st.write('References')
|