import os import json import time import argparse from importlib.resources import files import yaml from dotenv import load_dotenv from .graphgen import GraphGen from .models import OpenAIModel, Tokenizer, TraverseStrategy from .utils import set_logger sys_path = os.path.abspath(os.path.dirname(__file__)) load_dotenv() def set_working_dir(folder): os.makedirs(folder, exist_ok=True) os.makedirs(os.path.join(folder, "data", "graphgen"), exist_ok=True) os.makedirs(os.path.join(folder, "logs"), exist_ok=True) def save_config(config_path, global_config): if not os.path.exists(os.path.dirname(config_path)): os.makedirs(os.path.dirname(config_path)) with open(config_path, "w", encoding='utf-8') as config_file: yaml.dump(global_config, config_file, default_flow_style=False, allow_unicode=True) def main(): parser = argparse.ArgumentParser() parser.add_argument('--config_file', help='Config parameters for GraphGen.', # default=os.path.join(sys_path, "configs", "graphgen_config.yaml"), default=files('graphgen').joinpath("configs", "graphgen_config.yaml"), type=str) parser.add_argument('--output_dir', help='Output directory for GraphGen.', default=sys_path, required=True, type=str) args = parser.parse_args() working_dir = args.output_dir set_working_dir(working_dir) unique_id = int(time.time()) set_logger(os.path.join(working_dir, "logs", f"graphgen_{unique_id}.log"), if_stream=False) with open(args.config_file, "r", encoding='utf-8') as f: config = yaml.load(f, Loader=yaml.FullLoader) input_file = config['input_file'] if config['data_type'] == 'raw': with open(input_file, "r", encoding='utf-8') as f: data = [json.loads(line) for line in f] elif config['data_type'] == 'chunked': with open(input_file, "r", encoding='utf-8') as f: data = json.load(f) else: raise ValueError(f"Invalid data type: {config['data_type']}") synthesizer_llm_client = OpenAIModel( model_name=os.getenv("SYNTHESIZER_MODEL"), api_key=os.getenv("SYNTHESIZER_API_KEY"), base_url=os.getenv("SYNTHESIZER_BASE_URL") ) trainee_llm_client = OpenAIModel( model_name=os.getenv("TRAINEE_MODEL"), api_key=os.getenv("TRAINEE_API_KEY"), base_url=os.getenv("TRAINEE_BASE_URL") ) traverse_strategy = TraverseStrategy( **config['traverse_strategy'] ) graph_gen = GraphGen( working_dir=working_dir, unique_id=unique_id, synthesizer_llm_client=synthesizer_llm_client, trainee_llm_client=trainee_llm_client, if_web_search=config['web_search'], tokenizer_instance=Tokenizer( model_name=config['tokenizer'] ), traverse_strategy=traverse_strategy ) graph_gen.insert(data, config['data_type']) graph_gen.quiz(max_samples=config['quiz_samples']) graph_gen.judge(re_judge=config["re_judge"]) graph_gen.traverse() path = os.path.join(working_dir, "data", "graphgen", str(unique_id), f"config-{unique_id}.yaml") save_config(path, config) if __name__ == '__main__': main()