# NOTE(review): the three lines below were page chrome scraped from the
# Hugging Face Spaces UI ("Spaces: / Running / Running"), not source code;
# kept here as a comment so the module parses.
from functools import lru_cache

import gradio as gr
from gradio_client import Client
# NOTE(review): the official langgraph package does not export Graph /
# NodeFunction from its top level (it provides langgraph.graph.StateGraph) —
# confirm this import works against the pinned langgraph version.
from langgraph import Graph, NodeFunction
# node 1: call your retriever
@lru_cache(maxsize=1)
def _retriever_client() -> Client:
    """Create the Gradio client for the retriever Space once and reuse it.

    Building a ``Client`` performs a network handshake with the Space, so
    doing it per request (as the original code did) adds avoidable latency.
    """
    return Client("giz/chatfed_retriever")  # HF repo name


def retrieve_node(query: str):
    """Fetch retrieval context for ``query`` from the chatfed_retriever Space.

    All optional filters are deliberately left empty so the retriever
    searches its full corpus.

    Args:
        query: The user's natural-language question.

    Returns:
        Whatever the Space's ``/retrieve`` endpoint returns (opaque here;
        it is fed to the generator as ``context`` downstream).
    """
    return _retriever_client().predict(
        query=query,
        reports_filter="",
        sources_filter="",
        subtype_filter="",
        year_filter="",
        api_name="/retrieve",
    )
# node 2: call your generator
@lru_cache(maxsize=1)
def _generator_client() -> Client:
    """Create the Gradio client for the generator Space once and reuse it.

    Mirrors ``_retriever_client``: ``Client`` construction does a network
    handshake, so it should not happen on every request.
    """
    return Client("giz/chatfed_generator")


def generate_node(query: str, context: str):
    """Generate an answer for ``query`` grounded in ``context``.

    Args:
        query: The user's natural-language question.
        context: Retrieved supporting text (output of the retrieve node).

    Returns:
        Whatever the Space's ``/generate`` endpoint returns — presumably
        the answer text; confirm against the Space's API.
    """
    return _generator_client().predict(
        query=query,
        context=context,
        api_name="/generate",
    )
# Wire the two nodes into a linear graph: retrieve -> generate.
graph = Graph()
retrieve_handle = graph.add_node(NodeFunction(retrieve_node), name="retrieve")
generate_handle = graph.add_node(NodeFunction(generate_node), name="generate")
graph.link(retrieve_handle, generate_handle)
# expose a simple Gradio interface that drives the graph
def pipeline(query: str):
    """Run the RAG graph end-to-end for one user query.

    Seeds the "retrieve" node with the query; the graph forwards the
    retriever's output into "generate" and reports per-node results.
    """
    node_outputs = graph.run({"retrieve": {"query": query}})
    # The generator's output is the user-facing answer.
    return node_outputs["generate"]
# Single-textbox UI in front of the pipeline.
iface = gr.Interface(
    fn=pipeline,
    title="Modular RAG Pipeline",
    inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
    outputs="text",
)
if __name__ == "__main__":
    # HF Spaces containers route traffic to port 7860 by default;
    # bind all interfaces so the container's port mapping works.
    iface.launch(server_name="0.0.0.0", server_port=7860)