Spaces:
Running
Running
import os | |
import argparse | |
import asyncio | |
from dotenv import load_dotenv | |
from .models import NetworkXStorage, JsonKVStorage, OpenAIModel | |
from .operators import judge_statement | |
# Pull variables from a .env file into the process environment so the
# TRAINEE_* settings read in the __main__ section below can be supplied there.
load_dotenv()

# Absolute directory of this module; the default cache locations are built from it.
sys_path = os.path.abspath(os.path.dirname(__file__))
def calculate_average_loss(graph: "NetworkXStorage") -> float:
    """
    Calculate the average ``'loss'`` value over all edges of the graph.

    :param graph: NetworkXStorage whose ``get_all_edges()`` coroutine returns a
        sized collection of edge tuples; the third element of each tuple is a
        dict that must contain a ``'loss'`` key.
    :return: float, the mean edge loss, or ``0.0`` when the graph has no edges.
    :raises KeyError: if an edge's attribute dict has no ``'loss'`` entry.
    """
    # get_all_edges is a coroutine, so drive it to completion on a fresh event
    # loop. Consequently this helper must not be called from inside a running loop.
    edges = asyncio.run(graph.get_all_edges())
    if not edges:
        # An empty graph previously crashed with ZeroDivisionError.
        return 0.0
    total_loss = sum(edge[2]['loss'] for edge in edges)
    return total_loss / len(edges)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, default=os.path.join(sys_path, "cache"), help='path to load input graph')
    parser.add_argument('--output', type=str, default='cache/output/new_graph.graphml', help='path to save output')
    args = parser.parse_args()

    # Trainee model connection settings are read from the environment.
    llm_client = OpenAIModel(
        model_name=os.getenv("TRAINEE_MODEL"),
        api_key=os.getenv("TRAINEE_API_KEY"),
        base_url=os.getenv("TRAINEE_BASE_URL")
    )

    graph_storage = NetworkXStorage(
        args.input,
        namespace="graph"
    )
    # Report the edge loss before judging.
    average_loss = calculate_average_loss(graph_storage)
    print(f"Average loss of the graph: {average_loss}")

    rephrase_storage = JsonKVStorage(
        os.path.join(sys_path, "cache"),
        namespace="rephrase"
    )

    # Judge the graph with the trainee model. NOTE(review): re_judge=True
    # presumably forces already-judged edges to be scored again — confirm
    # against judge_statement.
    new_graph = asyncio.run(judge_statement(llm_client, graph_storage, rephrase_storage, re_judge=True))

    # Export the graph held by the storage that judge_statement returned, so
    # the written file matches the graph whose loss is reported below.
    graph_file = asyncio.run(new_graph.get_graph())
    # The default output path is nested under a relative directory that may
    # not exist yet; create it so the write does not fail.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    new_graph.write_nx_graph(graph_file, args.output)

    # Report the edge loss after judging.
    average_loss = calculate_average_loss(new_graph)
    print(f"Average loss of the graph: {average_loss}")