H committed
Commit 87d8c78 · Parent(s): 970a3e8

Fix multiple generate (#1722)

### What problem does this PR solve?

#1625

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Files changed:
- graph/component/answer.py +3 -1
- graph/component/generate.py +32 -51
graph/component/answer.py CHANGED

```diff
@@ -59,8 +59,10 @@ class Answer(ComponentBase, ABC):
         stream = self.get_stream_input()
         if isinstance(stream, pd.DataFrame):
             res = stream
+            answer = ""
             for ii, row in stream.iterrows():
-                yield row.to_dict()
+                answer += row.to_dict()["content"]
+                yield {"content": answer}
         else:
             for st in stream():
                 res = st
```
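The rewritten branch streams a running total instead of per-row fragments, so the consumer always receives the full answer so far. A minimal sketch of the pattern outside the component, assuming only a DataFrame with a `content` column (`stream_cumulative` is a hypothetical stand-in for the loop above):

```python
import pandas as pd

def stream_cumulative(df: pd.DataFrame):
    """Yield the running concatenation of the 'content' column,
    mirroring the accumulate-then-yield loop in Answer above."""
    answer = ""
    for _, row in df.iterrows():
        answer += row.to_dict()["content"]
        yield {"content": answer}

# Each yielded dict carries everything streamed so far.
chunks = pd.DataFrame({"content": ["Hello", ", ", "world"]})
for part in stream_cumulative(chunks):
    print(part)
# {'content': 'Hello'}
# {'content': 'Hello, '}
# {'content': 'Hello, world'}
```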
graph/component/generate.py CHANGED

```diff
@@ -67,6 +67,34 @@ class Generate(ComponentBase):
         cpnts = [para["component_id"] for para in self._param.parameters]
         return cpnts
 
+    def set_cite(self, retrieval_res, answer):
+        answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
+                                                   [ck["vector"] for _, ck in retrieval_res.iterrows()],
+                                                   LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
+                                                             self._canvas.get_embedding_model()), tkweight=0.7,
+                                                   vtweight=0.3)
+        doc_ids = set([])
+        recall_docs = []
+        for i in idx:
+            did = retrieval_res.loc[int(i), "doc_id"]
+            if did in doc_ids: continue
+            doc_ids.add(did)
+            recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
+
+        del retrieval_res["vector"]
+        del retrieval_res["content_ltks"]
+
+        reference = {
+            "chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()],
+            "doc_aggs": recall_docs
+        }
+
+        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
+            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
+        res = {"content": answer, "reference": reference}
+
+        return res
+
     def _run(self, history, **kwargs):
         chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
         prompt = self._param.prompt
@@ -87,9 +115,8 @@ class Generate(ComponentBase):
         prompt = re.sub(r"\{%s\}" % n, str(v), prompt)
 
         downstreams = self._canvas.get_component(self._id)["downstream"]
-        if kwargs.get("stream") \
-                and len(downstreams) == 1 \
-                and self._canvas.get_component(downstreams[0])["obj"].component_name.lower() == "answer":
+        if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
+            "obj"].component_name.lower() == "answer":
             return partial(self.stream_output, chat_mdl, prompt, retrieval_res)
 
         if "empty_response" in retrieval_res.columns:
```
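A note on the dispatch in `_run` above: when streaming is requested and the single downstream component is an Answer, the method returns `functools.partial(self.stream_output, ...)`, a callable the caller invokes later; with several downstreams it falls through to the blocking `chat` path. A minimal sketch of that shape, where `ToyModel`, `stream_tokens`, and `run` are hypothetical stand-ins (only `partial` and the `len(downstreams) == 1` guard come from the diff):

```python
from functools import partial

class ToyModel:
    """Hypothetical stand-in for the chat-model bundle in the diff."""
    def generate(self, prompt):
        yield from ("echo", ": ", prompt)

def stream_tokens(model, prompt):
    # Analogue of Generate.stream_output: a generator of partial results.
    yield from model.generate(prompt)

def run(model, prompt, downstreams, stream=False):
    # Stream only when exactly one downstream exists; otherwise block.
    if stream and len(downstreams) == 1:
        return partial(stream_tokens, model, prompt)  # caller invokes later
    return "".join(model.generate(prompt))

deferred = run(ToyModel(), "hi", downstreams=["answer-1"], stream=True)
print("".join(deferred()))                             # streamed path: 'echo: hi'
print(run(ToyModel(), "hi", ["a", "b"], stream=True))  # blocking fallback
```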
The remaining hunks replace the inlined citation logic in `_run` and `stream_output` with calls to the new `set_cite`:

```diff
@@ -97,27 +124,8 @@ class Generate(ComponentBase):
 
         ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
                             self._param.gen_conf())
-
         if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
-            ans, idx = retrievaler.insert_citations(ans,
-                                                    [ck["content_ltks"]
-                                                     for _, ck in retrieval_res.iterrows()],
-                                                    [ck["vector"]
-                                                     for _, ck in retrieval_res.iterrows()],
-                                                    LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
-                                                              self._canvas.get_embedding_model()),
-                                                    tkweight=0.7,
-                                                    vtweight=0.3)
-            del retrieval_res["vector"]
-            retrieval_res = retrieval_res.to_dict("records")
-            df = []
-            for i in idx:
-                df.append(retrieval_res[int(i)])
-                r = re.search(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), ans)
-                assert r, f"{i} => {ans}"
-                df[-1]["content"] = r.group(1)
-                ans = re.sub(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), "", ans)
-            if ans: df.append({"content": ans})
+            df = self.set_cite(retrieval_res, ans)
             return pd.DataFrame(df)
 
         return Generate.be_output(ans)
@@ -138,34 +146,7 @@ class Generate(ComponentBase):
         yield res
 
         if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
-            answer, idx = retrievaler.insert_citations(answer,
-                                                       [ck["content_ltks"]
-                                                        for _, ck in retrieval_res.iterrows()],
-                                                       [ck["vector"]
-                                                        for _, ck in retrieval_res.iterrows()],
-                                                       LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
-                                                                 self._canvas.get_embedding_model()),
-                                                       tkweight=0.7,
-                                                       vtweight=0.3)
-            doc_ids = set([])
-            recall_docs = []
-            for i in idx:
-                did = retrieval_res.loc[int(i), "doc_id"]
-                if did in doc_ids: continue
-                doc_ids.add(did)
-                recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
-
-            del retrieval_res["vector"]
-            del retrieval_res["content_ltks"]
-
-            reference = {
-                "chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()],
-                "doc_aggs": recall_docs
-            }
-
-            if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
-                answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
-            res = {"content": answer, "reference": reference}
+            res = self.set_cite(retrieval_res, answer)
         yield res
 
         self.set_output(res)
```
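`set_cite` now holds the citation logic that was previously duplicated in `_run` and `stream_output`; part of it aggregates the cited documents, keeping the first occurrence of each `doc_id` in citation order. That step in isolation with toy data (column names `doc_id`/`docnm_kwd` as in the diff; `retrievaler.insert_citations` and `LLMBundle` are the repo's own APIs and are not reproduced here):

```python
import pandas as pd

def dedup_recall_docs(retrieval_res: pd.DataFrame, idx) -> list:
    """Keep the first occurrence of each cited doc_id, in citation order,
    as set_cite does when building the 'doc_aggs' reference list."""
    doc_ids, recall_docs = set(), []
    for i in idx:
        did = retrieval_res.loc[int(i), "doc_id"]
        if did in doc_ids:
            continue
        doc_ids.add(did)
        recall_docs.append({"doc_id": did,
                            "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
    return recall_docs

chunks = pd.DataFrame({"doc_id": ["d1", "d2", "d1"],
                       "docnm_kwd": ["a.pdf", "b.pdf", "a.pdf"]})
print(dedup_recall_docs(chunks, idx=[0, 2, 1]))
# [{'doc_id': 'd1', 'doc_name': 'a.pdf'}, {'doc_id': 'd2', 'doc_name': 'b.pdf'}]
```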