mtyrrell commited on
Commit
0e839b6
·
1 Parent(s): 2a6a9cf

chatui node scaffolding

Browse files
Files changed (1) hide show
  1. app/main.py +18 -2
app/main.py CHANGED
@@ -9,7 +9,20 @@ class GraphState(TypedDict):
9
  context: str
10
  result: str
11
 
12
- # node 1: retriever
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def retrieve_node(state: GraphState) -> GraphState:
14
  client = Client("giz/chatfed_retriever") # HF repo name
15
  context = client.predict(
@@ -22,7 +35,7 @@ def retrieve_node(state: GraphState) -> GraphState:
22
  )
23
  return {"context": context}
24
 
25
- # node 2: generator
26
  def generate_node(state: GraphState) -> GraphState:
27
  client = Client("giz/chatfed_generator")
28
  result = client.predict(
@@ -36,10 +49,13 @@ def generate_node(state: GraphState) -> GraphState:
36
  workflow = StateGraph(GraphState)
37
 
38
  # Add nodes
 
39
  workflow.add_node("retrieve", retrieve_node)
40
  workflow.add_node("generate", generate_node)
41
 
42
  # Add edges
 
 
43
  workflow.add_edge(START, "retrieve")
44
  workflow.add_edge("retrieve", "generate")
45
  workflow.add_edge("generate", END)
 
9
  context: str
10
  result: str
11
 
12
+
13
+ ### OPEN QUESTION: SHOULD WE PASS ALL PARAMS FROM THE ORCHESTRATOR TO THE NODES INSTEAD OF SETTING IN EACH MODULE?
14
+
15
+ # node 1: chatUI
16
+ # To be finalized
17
+ # def chatui_node(state: GraphState) -> GraphState:
18
+ # client = Client("giz/chatfed_chatui")
19
+ # context = client.predict(
20
+ # query=state["query"], # not sure if we need to pass the query here
21
+ # api_name="/chat"
22
+ # )
23
+ # return {"query": query}
24
+
25
+ # node 2: retriever
26
  def retrieve_node(state: GraphState) -> GraphState:
27
  client = Client("giz/chatfed_retriever") # HF repo name
28
  context = client.predict(
 
35
  )
36
  return {"context": context}
37
 
38
+ # node 3: generator
39
  def generate_node(state: GraphState) -> GraphState:
40
  client = Client("giz/chatfed_generator")
41
  result = client.predict(
 
49
  workflow = StateGraph(GraphState)
50
 
51
  # Add nodes
52
+ # workflow.add_node("chatui", chatui_node)
53
  workflow.add_node("retrieve", retrieve_node)
54
  workflow.add_node("generate", generate_node)
55
 
56
  # Add edges
57
+ # workflow.add_edge(START, "chatui")
58
+ # workflow.add_edge("chatui", "retrieve")
59
  workflow.add_edge(START, "retrieve")
60
  workflow.add_edge("retrieve", "generate")
61
  workflow.add_edge("generate", END)