Kevin Hu commited on
Commit
919f3a7
·
1 Parent(s): 22c8a6e

refine generate (#1562)

Browse files

### What problem does this PR solve?



### Type of change

- [x] Refactoring

graph/component/base.py CHANGED
@@ -445,6 +445,11 @@ class ComponentBase(ABC):
445
  if DEBUG: print(self.component_name, reversed_cpnts[::-1])
446
  for u in reversed_cpnts[::-1]:
447
  if self.get_component_name(u) in ["switch"]: continue
 
 
 
 
 
448
  if u not in self._canvas.get_component(self._id)["upstream"]: continue
449
  if self.component_name.lower().find("switch") < 0 \
450
  and self.get_component_name(u) in ["relevant", "categorize"]:
 
445
  if DEBUG: print(self.component_name, reversed_cpnts[::-1])
446
  for u in reversed_cpnts[::-1]:
447
  if self.get_component_name(u) in ["switch"]: continue
448
+ if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
449
+ o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
450
+ if o is not None:
451
+ upstream_outs.append(o)
452
+ continue
453
  if u not in self._canvas.get_component(self._id)["upstream"]: continue
454
  if self.component_name.lower().find("switch") < 0 \
455
  and self.get_component_name(u) in ["relevant", "categorize"]:
graph/component/generate.py CHANGED
@@ -72,7 +72,7 @@ class Generate(ComponentBase):
72
  prompt = self._param.prompt
73
 
74
  retrieval_res = self.get_input()
75
- input = "\n- ".join(retrieval_res["content"])
76
  for para in self._param.parameters:
77
  cpn = self._canvas.get_component(para["component_id"])["obj"]
78
  _, out = cpn.output(allow_partial=False)
 
72
  prompt = self._param.prompt
73
 
74
  retrieval_res = self.get_input()
75
+ input = "\n- ".join(retrieval_res["content"]) if "content" in retrieval_res else ""
76
  for para in self._param.parameters:
77
  cpn = self._canvas.get_component(para["component_id"])["obj"]
78
  _, out = cpn.output(allow_partial=False)