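# File-viewer header captured along with the source (not part of the program):
#   author: mtyrrell
#   commit: "initial commit" (d6e543b)
#   viewer links: raw | history | blame
#   size: 1.4 kB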
import gradio as gr
from gradio_client import Client
from langgraph import Graph, NodeFunction
# node 1: query the retriever service
def retrieve_node(query: str):
    """Send *query* to the ``/retrieve`` endpoint of the retriever repo.

    A new ``Client`` is created on every call. All four filter arguments
    (reports, sources, subtype, year) are sent as empty strings.

    Args:
        query: The user's question, forwarded unchanged.

    Returns:
        Whatever ``Client.predict`` returns for the ``/retrieve`` endpoint.
    """
    request = {
        "query": query,
        "reports_filter": "",
        "sources_filter": "",
        "subtype_filter": "",
        "year_filter": "",
        "api_name": "/retrieve",
    }
    retriever = Client("giz/chatfed_retriever")  # HF repo name
    return retriever.predict(**request)
# node 2: query the generator service
def generate_node(query: str, context: str):
    """Send *query* and *context* to the ``/generate`` endpoint.

    A new ``Client`` for ``giz/chatfed_generator`` is created on every call.

    Args:
        query: The user's question, forwarded unchanged.
        context: Text passed as the ``context`` argument of the endpoint.

    Returns:
        Whatever ``Client.predict`` returns for the ``/generate`` endpoint.
    """
    request = {
        "query": query,
        "context": context,
        "api_name": "/generate",
    }
    generator = Client("giz/chatfed_generator")
    return generator.predict(**request)
# Build the graph: two named nodes, "retrieve" and "generate", linked in
# that order.
# NOTE(review): `Graph`, `NodeFunction`, `add_node(..., name=...)` and
# `link(...)` come from the top-of-file `langgraph` import; their behavior is
# not visible in this file -- confirm they exist in the installed version.
graph = Graph()
# Wrap each function and register it under a name; the returned handles are
# what `link` receives below.
n1 = graph.add_node(NodeFunction(retrieve_node), name="retrieve")
n2 = graph.add_node(NodeFunction(generate_node), name="generate")
# Connect retrieve -> generate.
# NOTE(review): generate_node takes both `query` and `context`; how `link`
# supplies each of them is not shown here -- TODO confirm.
graph.link(n1, n2)
# Gradio callback that drives the graph
def pipeline(query: str):
    """Run the module-level ``graph`` for *query* and return its "generate" entry.

    The graph is started with ``{"retrieve": {"query": query}}``; the value
    stored under the ``"generate"`` key of the run result is returned.

    NOTE(review): presumably the run result is keyed by node name -- verify
    against the graph library in use.

    Args:
        query: The question typed by the user.

    Returns:
        ``graph.run(...)["generate"]``.
    """
    seed_inputs = {"retrieve": {"query": query}}
    outputs = graph.run(seed_inputs)
    return outputs["generate"]
# Web UI: a two-line question box wired to `pipeline`, with a text output.
_question_box = gr.Textbox(lines=2, placeholder="Enter your question here...")
iface = gr.Interface(
    fn=pipeline,
    inputs=_question_box,
    outputs="text",
    title="Modular RAG Pipeline",
)
if __name__ == "__main__":
    # Listen on all interfaces at port 7860 -- per the original note, the
    # port HF picks up by default when it spins up the container.
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )