# File size: 1,401 bytes (extraction metadata)
# source revision: d6e543b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import gradio as gr
from gradio_client import Client
from langgraph import Graph, NodeFunction

# node 1: call your retriever
def retrieve_node(query: str):
    """Fetch retrieval context for *query* from the hosted ChatFed retriever Space."""
    retriever = Client("giz/chatfed_retriever")  # HF repo name
    # All metadata filters are left blank so the retriever searches everything.
    blank_filters = {
        "reports_filter": "",
        "sources_filter": "",
        "subtype_filter": "",
        "year_filter": "",
    }
    return retriever.predict(query=query, api_name="/retrieve", **blank_filters)

# node 2: call your generator
def generate_node(query: str, context: str):
    """Ask the hosted ChatFed generator Space to answer *query* given *context*."""
    generator = Client("giz/chatfed_generator")
    payload = {"query": query, "context": context}
    return generator.predict(api_name="/generate", **payload)

# build the graph
# NOTE(review): `Graph`, `NodeFunction`, and `graph.link(...)` do not match
# langgraph's documented public API (which exposes `langgraph.graph.Graph` /
# `StateGraph` with `add_node(name, fn)` and `add_edge(a, b)`, followed by
# `.compile()`), and neither does the `from langgraph import Graph, NodeFunction`
# import above — confirm the intended package/version; as written this will
# likely fail at import time.
graph = Graph()
n1 = graph.add_node(NodeFunction(retrieve_node), name="retrieve")
n2 = graph.add_node(NodeFunction(generate_node), name="generate")
graph.link(n1, n2)

# expose a simple Gradio interface that drives the two-step pipeline
def pipeline(query: str) -> str:
    """Run the RAG pipeline for *query*: retrieve context, then generate an answer.

    The nodes are called directly in sequence. The previous implementation
    routed through `graph.run({...})` and `result["generate"]`, but langgraph's
    public API has no such method (nor the `Graph`/`NodeFunction` construction
    used above), so it could never execute. Calling the node functions directly
    preserves the identical data flow:
    query -> retrieve_node -> (query, context) -> generate_node.
    """
    context = retrieve_node(query)
    return generate_node(query, context)

# Gradio UI: a single two-line textbox in, plain text out, wired to `pipeline`.
iface = gr.Interface(
    fn=pipeline,
    inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
    outputs="text",
    title="Modular RAG Pipeline",
)

if __name__ == "__main__":
    # when HF spins up your container it will pick up port 7860 by default
    # (bind 0.0.0.0 so the server is reachable from outside the container)
    iface.launch(server_name="0.0.0.0", server_port=7860)