Added pagerank support to infinity (#4059)
Browse files### What problem does this PR solve?
Added pagerank support to infinity
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- api/apps/kb_app.py +1 -0
- rag/utils/infinity_conn.py +17 -9
api/apps/kb_app.py
CHANGED
|
@@ -107,6 +107,7 @@ def update():
|
|
| 107 |
settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
|
| 108 |
search.index_name(kb.tenant_id), kb.id)
|
| 109 |
else:
|
|
|
|
| 110 |
settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
|
| 111 |
search.index_name(kb.tenant_id), kb.id)
|
| 112 |
|
|
|
|
| 107 |
settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
|
| 108 |
search.index_name(kb.tenant_id), kb.id)
|
| 109 |
else:
|
| 110 |
+
# Elasticsearch requires pagerank_fea be non-zero!
|
| 111 |
settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
|
| 112 |
search.index_name(kb.tenant_id), kb.id)
|
| 113 |
|
rag/utils/infinity_conn.py
CHANGED
|
@@ -46,13 +46,14 @@ def equivalent_condition_to_str(condition: dict) -> str|None:
|
|
| 46 |
cond.append(f"{k}='{v}'")
|
| 47 |
else:
|
| 48 |
cond.append(f"{k}={str(v)}")
|
| 49 |
-
return " AND ".join(cond) if cond else
|
| 50 |
|
| 51 |
|
| 52 |
def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
|
| 53 |
"""
|
| 54 |
Concatenate multiple dataframes into one.
|
| 55 |
"""
|
|
|
|
| 56 |
if df_list:
|
| 57 |
return pl.concat(df_list)
|
| 58 |
schema = dict()
|
|
@@ -246,8 +247,9 @@ class InfinityConnection(DocStoreConnection):
|
|
| 246 |
db_instance = inf_conn.get_database(self.dbName)
|
| 247 |
df_list = list()
|
| 248 |
table_list = list()
|
| 249 |
-
|
| 250 |
-
selectFields
|
|
|
|
| 251 |
|
| 252 |
# Prepare expressions common to all tables
|
| 253 |
filter_cond = None
|
|
@@ -331,10 +333,13 @@ class InfinityConnection(DocStoreConnection):
|
|
| 331 |
kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl()
|
| 332 |
if extra_result:
|
| 333 |
total_hits_count += int(extra_result["total_hits_count"])
|
|
|
|
| 334 |
df_list.append(kb_res)
|
| 335 |
self.connPool.release_conn(inf_conn)
|
| 336 |
res = concat_dataframes(df_list, selectFields)
|
| 337 |
-
|
|
|
|
|
|
|
| 338 |
return res, total_hits_count
|
| 339 |
|
| 340 |
def get(
|
|
@@ -350,12 +355,10 @@ class InfinityConnection(DocStoreConnection):
|
|
| 350 |
table_list.append(table_name)
|
| 351 |
table_instance = db_instance.get_table(table_name)
|
| 352 |
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
self.connPool.release_conn(inf_conn)
|
| 357 |
res = concat_dataframes(df_list, ["id"])
|
| 358 |
-
logger.debug(f"INFINITY get tables: {str(table_list)}, result: {str(res)}")
|
| 359 |
res_fields = self.getFields(res, res.columns)
|
| 360 |
return res_fields.get(chunkId, None)
|
| 361 |
|
|
@@ -421,8 +424,10 @@ class InfinityConnection(DocStoreConnection):
|
|
| 421 |
db_instance = inf_conn.get_database(self.dbName)
|
| 422 |
table_name = f"{indexName}_{knowledgebaseId}"
|
| 423 |
table_instance = db_instance.get_table(table_name)
|
|
|
|
|
|
|
| 424 |
filter = equivalent_condition_to_str(condition)
|
| 425 |
-
for k, v in newValue.items():
|
| 426 |
if k.endswith("_kwd") and isinstance(v, list):
|
| 427 |
newValue[k] = " ".join(v)
|
| 428 |
elif k == 'kb_id':
|
|
@@ -435,6 +440,9 @@ class InfinityConnection(DocStoreConnection):
|
|
| 435 |
elif k in ["page_num_int", "top_int"]:
|
| 436 |
assert isinstance(v, list)
|
| 437 |
newValue[k] = "_".join(f"{num:08x}" for num in v)
|
|
|
|
|
|
|
|
|
|
| 438 |
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
|
| 439 |
table_instance.update(filter, newValue)
|
| 440 |
self.connPool.release_conn(inf_conn)
|
|
|
|
| 46 |
cond.append(f"{k}='{v}'")
|
| 47 |
else:
|
| 48 |
cond.append(f"{k}={str(v)}")
|
| 49 |
+
return " AND ".join(cond) if cond else "1=1"
|
| 50 |
|
| 51 |
|
| 52 |
def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
|
| 53 |
"""
|
| 54 |
Concatenate multiple dataframes into one.
|
| 55 |
"""
|
| 56 |
+
df_list = [df for df in df_list if not df.is_empty()]
|
| 57 |
if df_list:
|
| 58 |
return pl.concat(df_list)
|
| 59 |
schema = dict()
|
|
|
|
| 247 |
db_instance = inf_conn.get_database(self.dbName)
|
| 248 |
df_list = list()
|
| 249 |
table_list = list()
|
| 250 |
+
for essential_field in ["id", "score()", "pagerank_fea"]:
|
| 251 |
+
if essential_field not in selectFields:
|
| 252 |
+
selectFields.append(essential_field)
|
| 253 |
|
| 254 |
# Prepare expressions common to all tables
|
| 255 |
filter_cond = None
|
|
|
|
| 333 |
kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl()
|
| 334 |
if extra_result:
|
| 335 |
total_hits_count += int(extra_result["total_hits_count"])
|
| 336 |
+
logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
|
| 337 |
df_list.append(kb_res)
|
| 338 |
self.connPool.release_conn(inf_conn)
|
| 339 |
res = concat_dataframes(df_list, selectFields)
|
| 340 |
+
res = res.sort(pl.col("SCORE") + pl.col("pagerank_fea"), descending=True, maintain_order=True)
|
| 341 |
+
res = res.limit(limit)
|
| 342 |
+
logger.debug(f"INFINITY search final result: {str(res)}")
|
| 343 |
return res, total_hits_count
|
| 344 |
|
| 345 |
def get(
|
|
|
|
| 355 |
table_list.append(table_name)
|
| 356 |
table_instance = db_instance.get_table(table_name)
|
| 357 |
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
|
| 358 |
+
logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
|
| 359 |
+
df_list.append(kb_res)
|
|
|
|
| 360 |
self.connPool.release_conn(inf_conn)
|
| 361 |
res = concat_dataframes(df_list, ["id"])
|
|
|
|
| 362 |
res_fields = self.getFields(res, res.columns)
|
| 363 |
return res_fields.get(chunkId, None)
|
| 364 |
|
|
|
|
| 424 |
db_instance = inf_conn.get_database(self.dbName)
|
| 425 |
table_name = f"{indexName}_{knowledgebaseId}"
|
| 426 |
table_instance = db_instance.get_table(table_name)
|
| 427 |
+
if "exist" in condition:
|
| 428 |
+
del condition["exist"]
|
| 429 |
filter = equivalent_condition_to_str(condition)
|
| 430 |
+
for k, v in list(newValue.items()):
|
| 431 |
if k.endswith("_kwd") and isinstance(v, list):
|
| 432 |
newValue[k] = " ".join(v)
|
| 433 |
elif k == 'kb_id':
|
|
|
|
| 440 |
elif k in ["page_num_int", "top_int"]:
|
| 441 |
assert isinstance(v, list)
|
| 442 |
newValue[k] = "_".join(f"{num:08x}" for num in v)
|
| 443 |
+
elif k == "remove" and v in ["pagerank_fea"]:
|
| 444 |
+
del newValue[k]
|
| 445 |
+
newValue[v] = 0
|
| 446 |
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
|
| 447 |
table_instance.update(filter, newValue)
|
| 448 |
self.connPool.release_conn(inf_conn)
|