Kevin Hu
commited on
Commit
·
919f3a7
1
Parent(s):
22c8a6e
refine generate (#1562)
Browse files### What problem does this PR solve?
### Type of change
- [x] Refactoring
- graph/component/base.py +5 -0
- graph/component/generate.py +1 -1
graph/component/base.py
CHANGED
|
@@ -445,6 +445,11 @@ class ComponentBase(ABC):
|
|
| 445 |
if DEBUG: print(self.component_name, reversed_cpnts[::-1])
|
| 446 |
for u in reversed_cpnts[::-1]:
|
| 447 |
if self.get_component_name(u) in ["switch"]: continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
if u not in self._canvas.get_component(self._id)["upstream"]: continue
|
| 449 |
if self.component_name.lower().find("switch") < 0 \
|
| 450 |
and self.get_component_name(u) in ["relevant", "categorize"]:
|
|
|
|
| 445 |
if DEBUG: print(self.component_name, reversed_cpnts[::-1])
|
| 446 |
for u in reversed_cpnts[::-1]:
|
| 447 |
if self.get_component_name(u) in ["switch"]: continue
|
| 448 |
+
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
|
| 449 |
+
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
|
| 450 |
+
if o is not None:
|
| 451 |
+
upstream_outs.append(o)
|
| 452 |
+
continue
|
| 453 |
if u not in self._canvas.get_component(self._id)["upstream"]: continue
|
| 454 |
if self.component_name.lower().find("switch") < 0 \
|
| 455 |
and self.get_component_name(u) in ["relevant", "categorize"]:
|
graph/component/generate.py
CHANGED
|
@@ -72,7 +72,7 @@ class Generate(ComponentBase):
|
|
| 72 |
prompt = self._param.prompt
|
| 73 |
|
| 74 |
retrieval_res = self.get_input()
|
| 75 |
-
input = "\n- ".join(retrieval_res["content"])
|
| 76 |
for para in self._param.parameters:
|
| 77 |
cpn = self._canvas.get_component(para["component_id"])["obj"]
|
| 78 |
_, out = cpn.output(allow_partial=False)
|
|
|
|
| 72 |
prompt = self._param.prompt
|
| 73 |
|
| 74 |
retrieval_res = self.get_input()
|
| 75 |
+
input = "\n- ".join(retrieval_res["content"]) if "content" in retrieval_res else ""
|
| 76 |
for para in self._param.parameters:
|
| 77 |
cpn = self._canvas.get_component(para["component_id"])["obj"]
|
| 78 |
_, out = cpn.output(allow_partial=False)
|