import os
import sys
import json
import tempfile

import pandas as pd
import gradio as gr
from gradio_i18n import Translate, gettext as _

from webui.base import GraphGenParams
from webui.test_api import test_api_connection
from webui.cache_utils import setup_workspace, cleanup_workspace
from webui.count_tokens import count_tokens

# pylint: disable=wrong-import-position
root_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(root_dir)

from graphgen.graphgen import GraphGen
from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
from graphgen.models.llm.limitter import RPM, TPM
from graphgen.utils import set_logger

css = """
.center-row {
    display: flex;
    justify-content: center;
    align-items: center;
}
"""


def init_graph_gen(config: dict, env: dict) -> GraphGen:
    # Set up working directory
    log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
    set_logger(log_file, if_stream=False)

    graph_gen = GraphGen(working_dir=working_dir)

    # Set up LLM clients
    graph_gen.synthesizer_llm_client = OpenAIModel(
        model_name=env.get("SYNTHESIZER_MODEL", ""),
        base_url=env.get("SYNTHESIZER_BASE_URL", ""),
        api_key=env.get("SYNTHESIZER_API_KEY", ""),
        request_limit=True,
        rpm=RPM(env.get("RPM", 1000)),
        tpm=TPM(env.get("TPM", 50000)),
    )
    graph_gen.trainee_llm_client = OpenAIModel(
        model_name=env.get("TRAINEE_MODEL", ""),
        base_url=env.get("TRAINEE_BASE_URL", ""),
        api_key=env.get("TRAINEE_API_KEY", ""),
        request_limit=True,
        rpm=RPM(env.get("RPM", 1000)),
        tpm=TPM(env.get("TPM", 50000)),
    )
    graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))

    strategy_config = config.get("traverse_strategy", {})
    graph_gen.traverse_strategy = TraverseStrategy(
        qa_form=config.get("qa_form"),
        expand_method=strategy_config.get("expand_method"),
        bidirectional=strategy_config.get("bidirectional"),
        max_extra_edges=strategy_config.get("max_extra_edges"),
        max_tokens=strategy_config.get("max_tokens"),
        max_depth=strategy_config.get("max_depth"),
        edge_sampling=strategy_config.get("edge_sampling"),
        isolated_node_strategy=strategy_config.get("isolated_node_strategy"),
        loss_strategy=str(strategy_config.get("loss_strategy")),
    )

    return graph_gen
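
# Minimal standalone usage sketch for init_graph_gen; the values mirror the UI
# defaults below and are illustrative, not required:
#
#   gg = init_graph_gen(
#       config={"tokenizer": "cl100k_base", "qa_form": "aggregated",
#               "traverse_strategy": {"expand_method": "max_tokens"}},
#       env={"SYNTHESIZER_BASE_URL": "https://api.siliconflow.cn/v1",
#            "SYNTHESIZER_MODEL": "Qwen/Qwen2.5-7B-Instruct",
#            "SYNTHESIZER_API_KEY": "..."},
#   )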

# pylint: disable=too-many-statements
def run_graphgen(params, progress=gr.Progress()):
    def sum_tokens(client):
        return sum(u["total_tokens"] for u in client.token_usage)

    config = {
        "if_trainee_model": params.if_trainee_model,
        "input_file": params.input_file,
        "tokenizer": params.tokenizer,
        "qa_form": params.qa_form,
        "web_search": False,
        "quiz_samples": params.quiz_samples,
        "traverse_strategy": {
            "bidirectional": params.bidirectional,
            "expand_method": params.expand_method,
            "max_extra_edges": params.max_extra_edges,
            "max_tokens": params.max_tokens,
            "max_depth": params.max_depth,
            "edge_sampling": params.edge_sampling,
            "isolated_node_strategy": params.isolated_node_strategy,
            "loss_strategy": params.loss_strategy,
        },
        "chunk_size": params.chunk_size,
    }

    env = {
        "SYNTHESIZER_BASE_URL": params.synthesizer_url,
        "SYNTHESIZER_MODEL": params.synthesizer_model,
        "TRAINEE_BASE_URL": params.trainee_url,
        "TRAINEE_MODEL": params.trainee_model,
        "SYNTHESIZER_API_KEY": params.api_key,
        "TRAINEE_API_KEY": params.trainee_api_key,
        "RPM": params.rpm,
        "TPM": params.tpm,
    }

    # Test API connections before doing any work
    test_api_connection(
        env["SYNTHESIZER_BASE_URL"],
        env["SYNTHESIZER_API_KEY"],
        env["SYNTHESIZER_MODEL"],
    )
    if config["if_trainee_model"]:
        test_api_connection(
            env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
        )

    # Initialize GraphGen
    graph_gen = init_graph_gen(config, env)
    graph_gen.clear()

    graph_gen.progress_bar = progress

    try:
        # Load input data
        file = config["input_file"]
        if isinstance(file, list):
            file = file[0]

        data = []

        if file.endswith(".jsonl"):
            data_type = "raw"
            with open(file, "r", encoding="utf-8") as f:
                data.extend(json.loads(line) for line in f)
        elif file.endswith(".json"):
            data_type = "chunked"
            with open(file, "r", encoding="utf-8") as f:
                data.extend(json.load(f))
        elif file.endswith(".txt"):
            # Read the file, then split it into raw-format chunks of
            # chunk_size characters
            data_type = "raw"
            content = ""
            with open(file, "r", encoding="utf-8") as f:
                for line in f.readlines():
                    content += line.strip() + " "
            size = int(config.get("chunk_size", 512))
            chunks = [content[i : i + size] for i in range(0, len(content), size)]
            data.extend([{"content": chunk} for chunk in chunks])
        else:
            raise ValueError(f"Unsupported file type: {file}")

        # Process the data
        graph_gen.insert(data, data_type)

        if config["if_trainee_model"]:
            # Generate quiz
            graph_gen.quiz(max_samples=config["quiz_samples"])

            # Judge statements
            graph_gen.judge()
        else:
            graph_gen.traverse_strategy.edge_sampling = "random"
            # Skip judging statements
            graph_gen.judge(skip=True)

        # Traverse graph
        graph_gen.traverse()

        # Save output
        output_data = graph_gen.qa_storage.data
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
        ) as tmpfile:
            json.dump(output_data, tmpfile, ensure_ascii=False)
            output_file = tmpfile.name

        synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
        trainee_tokens = (
            sum_tokens(graph_gen.trainee_llm_client)
            if config["if_trainee_model"]
            else 0
        )
        total_tokens = synthesizer_tokens + trainee_tokens

        data_frame = params.token_counter
        try:
            _update_data = [
                [data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)]
            ]
            new_df = pd.DataFrame(_update_data, columns=data_frame.columns)
            data_frame = new_df
        except Exception as e:
            raise gr.Error(f"DataFrame operation error: {str(e)}")

        return output_file, gr.DataFrame(
            label="Token Stats",
            headers=[
                "Source Text Token Count",
                "Estimated Token Usage",
                "Token Used",
            ],
            datatype="str",
            interactive=False,
            value=data_frame,
            visible=True,
            wrap=True,
        )

    except Exception as e:  # pylint: disable=broad-except
        raise gr.Error(f"Error occurred: {str(e)}")
    finally:
        # Clean up workspace
        cleanup_workspace(graph_gen.working_dir)
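
# The "raw" .jsonl input consumed above is one JSON object per line; the schema
# is assumed to match the .txt branch, which wraps each chunk as {"content": ...}:
#
#   {"content": "Paris is the capital of France."}
#   {"content": "The Seine flows through Paris."}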


with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
    # Header
    gr.Image(
        value="https://github.com/open-sciencelab/GraphGen/blob/main/resources/images/logo.png?raw=true",
        label="GraphGen Banner",
        elem_id="banner",
        interactive=False,
        container=False,
        show_download_button=False,
        show_fullscreen_button=False,
    )

    lang_btn = gr.Radio(
        choices=[
            ("English", "en"),
            ("简体中文", "zh"),
        ],
        value="en",
        # label=_("Language"),
        render=False,
        container=False,
        elem_classes=["center-row"],
    )

    gr.HTML("""
    <!-- Badges: Version | Documentation | GitHub Stars | arXiv -->
""") with Translate( os.path.join(root_dir, 'webui', 'translation.json'), lang_btn, placeholder_langs=["en", "zh"], persistant= False, # True to save the language setting in the browser. Requires gradio >= 5.6.0 ): lang_btn.render() gr.Markdown( value = "# " + _("Title") + "\n\n" + \ "### [GraphGen](https://github.com/open-sciencelab/GraphGen) " + _("Intro") ) if_trainee_model = gr.Checkbox(label=_("Use Trainee Model"), value=False, interactive=True) with gr.Accordion(label=_("Model Config"), open=False): synthesizer_url = gr.Textbox(label="Synthesizer URL", value="https://api.siliconflow.cn/v1", info=_("Synthesizer URL Info"), interactive=True) synthesizer_model = gr.Textbox(label="Synthesizer Model", value="Qwen/Qwen2.5-7B-Instruct", info=_("Synthesizer Model Info"), interactive=True) trainee_url = gr.Textbox(label="Trainee URL", value="https://api.siliconflow.cn/v1", info=_("Trainee URL Info"), interactive=True, visible=if_trainee_model.value is True) trainee_model = gr.Textbox( label="Trainee Model", value="Qwen/Qwen2.5-7B-Instruct", info=_("Trainee Model Info"), interactive=True, visible=if_trainee_model.value is True) trainee_api_key = gr.Textbox( label=_("SiliconCloud Token for Trainee Model"), type="password", value="", info="https://cloud.siliconflow.cn/account/ak", visible=if_trainee_model.value is True) with gr.Accordion(label=_("Generation Config"), open=False): chunk_size = gr.Slider(label="Chunk Size", minimum=256, maximum=4096, value=512, step=256, interactive=True) tokenizer = gr.Textbox(label="Tokenizer", value="cl100k_base", interactive=True) qa_form = gr.Radio(choices=["atomic", "multi_hop", "aggregated"], label="QA Form", value="aggregated", interactive=True) quiz_samples = gr.Number(label="Quiz Samples", value=2, minimum=1, interactive=True, visible=if_trainee_model.value is True) bidirectional = gr.Checkbox(label="Bidirectional", value=True, interactive=True) expand_method = gr.Radio(choices=["max_width", "max_tokens"], label="Expand Method", value="max_tokens", interactive=True) max_extra_edges = gr.Slider( minimum=1, maximum=10, value=5, label="Max Extra Edges", step=1, interactive=True, visible=expand_method.value == "max_width") max_tokens = gr.Slider(minimum=64, maximum=1024, value=256, label="Max Tokens", step=64, interactive=True, visible=(expand_method.value != "max_width")) max_depth = gr.Slider(minimum=1, maximum=5, value=2, label="Max Depth", step=1, interactive=True) edge_sampling = gr.Radio( choices=["max_loss", "min_loss", "random"], label="Edge Sampling", value="max_loss", interactive=True, visible=if_trainee_model.value is True) isolated_node_strategy = gr.Radio(choices=["add", "ignore"], label="Isolated Node Strategy", value="ignore", interactive=True) loss_strategy = gr.Radio(choices=["only_edge", "both"], label="Loss Strategy", value="only_edge", interactive=True) with gr.Row(equal_height=True): with gr.Column(scale=3): api_key = gr.Textbox( label=_("SiliconCloud Token"), type="password", value="", info="https://cloud.siliconflow.cn/account/ak") with gr.Column(scale=1): test_connection_btn = gr.Button(_("Test Connection")) with gr.Blocks(): with gr.Row(equal_height=True): with gr.Column(): rpm = gr.Slider( label="RPM", minimum=10, maximum=10000, value=1000, step=100, interactive=True, visible=True) with gr.Column(): tpm = gr.Slider( label="TPM", minimum=5000, maximum=5000000, value=50000, step=1000, interactive=True, visible=True) with gr.Blocks(): with gr.Row(equal_height=True): with gr.Column(scale=1): upload_file = gr.File( label=_("Upload 
File"), file_count="single", file_types=[".txt", ".json", ".jsonl"], interactive=True, ) examples_dir = os.path.join(root_dir, 'webui', 'examples') gr.Examples(examples=[ [os.path.join(examples_dir, "txt_demo.txt")], [os.path.join(examples_dir, "raw_demo.jsonl")], [os.path.join(examples_dir, "chunked_demo.json")], ], inputs=upload_file, label=_("Example Files"), examples_per_page=3) with gr.Column(scale=1): output = gr.File( label="Output(See Github FAQ)", file_count="single", interactive=False, ) with gr.Blocks(): token_counter = gr.DataFrame(label='Token Stats', headers=["Source Text Token Count", "Estimated Token Usage", "Token Used"], datatype="str", interactive=False, visible=False, wrap=True) submit_btn = gr.Button(_("Run GraphGen")) # Test Connection test_connection_btn.click( test_api_connection, inputs=[synthesizer_url, api_key, synthesizer_model], outputs=[]) if if_trainee_model.value: test_connection_btn.click(test_api_connection, inputs=[trainee_url, api_key, trainee_model], outputs=[]) expand_method.change(lambda method: (gr.update(visible=method == "max_width"), gr.update(visible=method != "max_width")), inputs=expand_method, outputs=[max_extra_edges, max_tokens]) if_trainee_model.change( lambda use_trainee: [gr.update(visible=use_trainee)] * 5, inputs=if_trainee_model, outputs=[trainee_url, trainee_model, quiz_samples, edge_sampling, trainee_api_key]) upload_file.change( lambda x: (gr.update(visible=True)), inputs=[upload_file], outputs=[token_counter], ).then( count_tokens, inputs=[upload_file, tokenizer, token_counter], outputs=[token_counter], ) # run GraphGen submit_btn.click( lambda x: (gr.update(visible=False)), inputs=[token_counter], outputs=[token_counter], ) submit_btn.click( lambda *args: run_graphgen(GraphGenParams( if_trainee_model=args[0], input_file=args[1], tokenizer=args[2], qa_form=args[3], bidirectional=args[4], expand_method=args[5], max_extra_edges=args[6], max_tokens=args[7], max_depth=args[8], edge_sampling=args[9], isolated_node_strategy=args[10], loss_strategy=args[11], synthesizer_url=args[12], synthesizer_model=args[13], trainee_model=args[14], api_key=args[15], chunk_size=args[16], rpm=args[17], tpm=args[18], quiz_samples=args[19], trainee_url=args[20], trainee_api_key=args[21], token_counter=args[22], )), inputs=[ if_trainee_model, upload_file, tokenizer, qa_form, bidirectional, expand_method, max_extra_edges, max_tokens, max_depth, edge_sampling, isolated_node_strategy, loss_strategy, synthesizer_url, synthesizer_model, trainee_model, api_key, chunk_size, rpm, tpm, quiz_samples, trainee_url, trainee_api_key, token_counter ], outputs=[output, token_counter], ) if __name__ == "__main__": demo.queue(api_open=False, default_concurrency_limit=2) demo.launch(server_name='0.0.0.0')