H committed on
Commit
87d8c78
·
1 Parent(s): 970a3e8

Fix multiple generate (#1722)

Browse files

### What problem does this PR solve?

#1625

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

graph/component/answer.py CHANGED
@@ -59,8 +59,10 @@ class Answer(ComponentBase, ABC):
59
  stream = self.get_stream_input()
60
  if isinstance(stream, pd.DataFrame):
61
  res = stream
 
62
  for ii, row in stream.iterrows():
63
- yield row.to_dict()
 
64
  else:
65
  for st in stream():
66
  res = st
 
59
  stream = self.get_stream_input()
60
  if isinstance(stream, pd.DataFrame):
61
  res = stream
62
+ answer = ""
63
  for ii, row in stream.iterrows():
64
+ answer += row.to_dict()["content"]
65
+ yield {"content": answer}
66
  else:
67
  for st in stream():
68
  res = st
graph/component/generate.py CHANGED
@@ -67,6 +67,34 @@ class Generate(ComponentBase):
67
  cpnts = [para["component_id"] for para in self._param.parameters]
68
  return cpnts
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def _run(self, history, **kwargs):
71
  chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
72
  prompt = self._param.prompt
@@ -87,9 +115,8 @@ class Generate(ComponentBase):
87
  prompt = re.sub(r"\{%s\}" % n, str(v), prompt)
88
 
89
  downstreams = self._canvas.get_component(self._id)["downstream"]
90
- if kwargs.get("stream") \
91
- and len(downstreams) == 1 \
92
- and self._canvas.get_component(downstreams[0])["obj"].component_name.lower() == "answer":
93
  return partial(self.stream_output, chat_mdl, prompt, retrieval_res)
94
 
95
  if "empty_response" in retrieval_res.columns:
@@ -97,27 +124,8 @@ class Generate(ComponentBase):
97
 
98
  ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
99
  self._param.gen_conf())
100
-
101
  if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
102
- ans, idx = retrievaler.insert_citations(ans,
103
- [ck["content_ltks"]
104
- for _, ck in retrieval_res.iterrows()],
105
- [ck["vector"]
106
- for _, ck in retrieval_res.iterrows()],
107
- LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
108
- self._canvas.get_embedding_model()),
109
- tkweight=0.7,
110
- vtweight=0.3)
111
- del retrieval_res["vector"]
112
- retrieval_res = retrieval_res.to_dict("records")
113
- df = []
114
- for i in idx:
115
- df.append(retrieval_res[int(i)])
116
- r = re.search(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), ans)
117
- assert r, f"{i} => {ans}"
118
- df[-1]["content"] = r.group(1)
119
- ans = re.sub(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), "", ans)
120
- if ans: df.append({"content": ans})
121
  return pd.DataFrame(df)
122
 
123
  return Generate.be_output(ans)
@@ -138,34 +146,7 @@ class Generate(ComponentBase):
138
  yield res
139
 
140
  if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
141
- answer, idx = retrievaler.insert_citations(answer,
142
- [ck["content_ltks"]
143
- for _, ck in retrieval_res.iterrows()],
144
- [ck["vector"]
145
- for _, ck in retrieval_res.iterrows()],
146
- LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
147
- self._canvas.get_embedding_model()),
148
- tkweight=0.7,
149
- vtweight=0.3)
150
- doc_ids = set([])
151
- recall_docs = []
152
- for i in idx:
153
- did = retrieval_res.loc[int(i), "doc_id"]
154
- if did in doc_ids: continue
155
- doc_ids.add(did)
156
- recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
157
-
158
- del retrieval_res["vector"]
159
- del retrieval_res["content_ltks"]
160
-
161
- reference = {
162
- "chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()],
163
- "doc_aggs": recall_docs
164
- }
165
-
166
- if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
167
- answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
168
- res = {"content": answer, "reference": reference}
169
  yield res
170
 
171
  self.set_output(res)
 
67
  cpnts = [para["component_id"] for para in self._param.parameters]
68
  return cpnts
69
 
70
+ def set_cite(self, retrieval_res, answer):
71
+ answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
72
+ [ck["vector"] for _, ck in retrieval_res.iterrows()],
73
+ LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
74
+ self._canvas.get_embedding_model()), tkweight=0.7,
75
+ vtweight=0.3)
76
+ doc_ids = set([])
77
+ recall_docs = []
78
+ for i in idx:
79
+ did = retrieval_res.loc[int(i), "doc_id"]
80
+ if did in doc_ids: continue
81
+ doc_ids.add(did)
82
+ recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})
83
+
84
+ del retrieval_res["vector"]
85
+ del retrieval_res["content_ltks"]
86
+
87
+ reference = {
88
+ "chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()],
89
+ "doc_aggs": recall_docs
90
+ }
91
+
92
+ if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
93
+ answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
94
+ res = {"content": answer, "reference": reference}
95
+
96
+ return res
97
+
98
  def _run(self, history, **kwargs):
99
  chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
100
  prompt = self._param.prompt
 
115
  prompt = re.sub(r"\{%s\}" % n, str(v), prompt)
116
 
117
  downstreams = self._canvas.get_component(self._id)["downstream"]
118
+ if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
119
+ "obj"].component_name.lower() == "answer":
 
120
  return partial(self.stream_output, chat_mdl, prompt, retrieval_res)
121
 
122
  if "empty_response" in retrieval_res.columns:
 
124
 
125
  ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
126
  self._param.gen_conf())
 
127
  if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
128
+ df = self.set_cite(retrieval_res, ans)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  return pd.DataFrame(df)
130
 
131
  return Generate.be_output(ans)
 
146
  yield res
147
 
148
  if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
149
+ res = self.set_cite(retrieval_res, answer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  yield res
151
 
152
  self.set_output(res)