from citekit.cite_modules.LLM import LLM
from citekit.cite_modules.augment_model import (
    Retriever,
    CitationSimplyfier,
    Verifier,
    Ranker,
)
from citekit.pipeline.pipeline import Pipeline, PIPELINE_OUTPUT, PIPELINE_DOC_CACHE
from citekit.prompt.prompt import Prompt, ALCEDocPrompt, DocPrompt, NewALCEVanillaPrompt
from citekit.Dataset.Dataset import PromptDataset
from citekit.evaluator.evaluator import (
    DefaultEvaluator,
    compute_autoais,
    test_compute_autoais,
)
from citekit.utils.utils import (
    sentence,
    one_paragraph,
    each_make_as,
    make_as,
    remove_citations,
    compute_str_em,
)
import json
import argparse
from parser import *


def segment(text):
    # Split retrieved text into per-line pieces, each wrapped under the "docs" key.
    return [make_as("docs")(doc) for doc in text.split("\n") if doc]


def segment_query(text):
    # Split generated queries (one per line) into pieces wrapped under the "query" key.
    return [make_as("query")(doc) for doc in text.split("\n") if doc]


if __name__ == "__main__":
    # SETTING ARGS
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path", type=str, default="resu.json", help="path to save results"
    )
    parser.add_argument(
        "--model", type=str, default="gpt-3.5-turbo", help="model name or path"
    )
    parser.add_argument("--shots", type=int, default=2, help="number of shots")
    parser.add_argument("--ndoc", type=int, default=5, help="number of docs")
    parser.add_argument("--pr", action="store_true", help="use cite PR")
    parser.add_argument("--rouge", action="store_true", help="use rouge")
    parser.add_argument("--temp", type=float, default=0.5, help="temperature")
    parser.add_argument("--qa", action="store_true", help="eval qa")
    parser.add_argument("--mauve", action="store_true", help="eval mauve")
    # NOTE: type=bool with default=True makes this flag effectively always on
    # (argparse converts any non-empty string to True); kept to preserve behavior.
    parser.add_argument("--length", type=bool, default=True, help="eval length")
    parser.add_argument("--claims", action="store_true", help="eval claims")
    parser.add_argument("--qampari", type=str, default=False, help="eval qampari")
    parser.add_argument(
        "--dataset", type=str, default="data/asqa_eval_gtr_top100.json", help="dataset"
    )
    parser.add_argument(
        "--demo", type=str, default="prompts/asqa_default.json", help="demo"
    )
    parser.add_argument("--doctype", type=str, default="text", help="doc type")
    parser.add_argument("--data_num", type=int, default=1000, help="num of data")
    parser.add_argument(
        "--mode",
        type=str,
        default="text",
        help="mode-granularity: text, extraction or summary",
    )
    parser.add_argument("--k", type=float, default=1.5, help="coefficient of em")
    parser.add_argument("--topk", type=int, default=2, help="topk")
    args = parser.parse_args()

    def score(data):
        # Ranking score: citation precision + citation recall + weighted string EM.
        pr = compute_autoais(data)
        p = pr["citation_prec"]
        r = pr["citation_rec"]
        em = compute_str_em(data)
        return p + r + args.k * em

    # DATA LOADING
    file_path = args.dataset
    demo_path = args.demo
    with open(file_path, "r", encoding="utf-8") as file:
        dataset = json.load(file)
    with open(demo_path, "r", encoding="utf-8") as file:
        demo = json.load(file)
    data_num = min(args.data_num, len(dataset))

    llm_instruction = demo["one_sentence_instruction"]
    query_inst = demo["query_instruction"]
    shots = "\n\n".join(
        NewALCEVanillaPrompt().load_data(
            [demo["demos"][1], demo["demos"][3]],
            "question",
            answer=lambda data: remove_citations(
                sentence("first")(data["answer"])["first"]
            ),
            INST=lambda _: llm_instruction,
            docs=lambda data: "".join(
                ALCEDocPrompt().default_load_data(data["docs"][1:2])
            ),
        )
    )
    documents = [
        DocPrompt().load_data(
            list(enumerate(data["docs"])),
            Title=lambda data: data[1]["title"],
            Passage=lambda data: data[1][args.mode],
        )
        for data in dataset
    ]
    dataset = PromptDataset(
        dataset,
        "question",
        "answer",
        "answers",
        "qa_pairs",
        "claims",
        docs=lambda data: ALCEDocPrompt().default_load_data(data["docs"][: args.ndoc]),
    )[:data_num]

    prompt = Prompt(
        template="\nAnswer:",
        components={
            "INST": "{INST}\n\n",
            "shots": "{shots}\n",
            "question": "Question:{question}\n\n",
            "ans": "Prefix:{ans}\n\n",
            "docs": "{docs}\n",
        },
    )
    queryprompt = Prompt(
        template="Please generate at most three queries regarding possible subquestions of the given question. Your queries should be diverse and informative in natural language, separated by a new line.\n",
        components={
            "question": "Given the original question: {question}\n",
            "ans": "The context is: {ans}\n",
            "prev": "\nPrevious queries:\n{prev}\n\n",
            "INST": "{INST}\n\n",
        },
    )
    retriever_prompt = Prompt(template="", components={"query": "{query}"})

    query_generator = LLM(
        model=args.model, prompt_maker=queryprompt, self_prompt={"INST": query_inst}
    )
    evaluator = DefaultEvaluator(args)

    ranker = Ranker(max_turn=6, iterative=True, fixed_turn=2)
    # ranker.set_eval('length', output='answer')
    # ranker.new_eval('score', score, output='answer', docs='doc_cache', qa_pairs='qa_pairs')
    ranker.new_eval("score", score, output="answer", docs="doc_cache")

    # PIPELINE CONSTRUCTING
    llm = LLM(
        model=args.model,
        prompt_maker=prompt,
        self_prompt={"INST": llm_instruction, "shots": shots},
        max_turn=30,
        auto_cite=True,
        share_model_with=query_generator,
        parallel=True,
    )
    pipeline = Pipeline(
        save_path=args.save_path,
        llm=llm,
        module=[ranker, query_generator],
        head_prompt_maker=prompt,
        evaluator=evaluator,
        dataset=dataset,
    )
    retriever = Retriever(
        prompt_maker=retriever_prompt,
        pipeline=pipeline,
        retrieve_by="bm25",
        documents=documents,
        topk=args.topk,
    )

    # Wire the modules: query generation -> retrieval -> generation -> ranking -> simplification.
    query_generator.set_target(retriever, post_processing=segment_query)
    query_generator.add_to_head("prev", sub=False)
    retriever.set_target(llm, post_processing=segment)
    llm.set_target(ranker, post_processing=make_as("answer"))
    # ranker.set_output(post_processing=lambda x: x["answer"], end=False)
    ranker.add_to_head(
        "ans", sub=True, process=lambda text: one_paragraph(text["answer"])
    )
    # ranker.set_target(query_generator, post_processing=lambda x: {"ans": x["answer"]})
    pipeline.set_initial_module(query_generator)
    pipeline.set_data_keys(["question"])

    simplifier = CitationSimplyfier()
    ranker.set_target(simplifier)
    simplifier.set_output(end=False)

    # Optional pipeline visualization:
    # graph = PipelineGraph(pipeline=pipeline)
    # html = graph.generate_html_embed(results='old/res_attr.json')
    # graph.visualize()
    # print(html)
    # with open('pipeline_.html', 'w') as file:
    #     file.write(html)

    pipeline.run_on_dataset(datakeys=["question"], initial_module=query_generator)
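    # Example invocation (a minimal sketch; the script filename is an assumption,
    # and only the flags defined by the argparse block above are used):
    #
    #   python run_pipeline.py \
    #       --dataset data/asqa_eval_gtr_top100.json \
    #       --demo prompts/asqa_default.json \
    #       --model gpt-3.5-turbo \
    #       --ndoc 5 --topk 2 --pr --rouge \
    #       --save_path resu.json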