from citekit.cite_modules.LLM import LLM
from citekit.cite_modules.augment_model import Retriever,CitationSimplyfier,Verifier,Ranker, AttributingModule
from citekit.pipeline.pipeline import Pipeline, PIPELINE_OUTPUT, PIPELINE_DOC_CACHE
from citekit.prompt.prompt import Prompt, ALCEDocPrompt,DocPrompt,NewALCEVanillaPrompt
from citekit.Dataset.Dataset import PromptDataset
from citekit.evaluator.evaluator import DefaultEvaluator, compute_autoais, test_compute_autoais
from citekit.utils.utils import sentence, one_paragraph, each_make_as, make_as, remove_citations, compute_str_em

import json
from parser import * 
import argparse
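
# Overview: this script builds an iterative retrieve-generate-rank loop with
# CiteKit. A query-generator LLM proposes retrieval queries, a BM25 retriever
# fetches documents, an answering LLM drafts cited answers, and a Ranker picks
# the best candidate and feeds it back as context for the next query.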

def segment(i,text):
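    # Split the retrieved text into one document per non-empty line and wrap
    # each as a 'docs' input for the answering LLM (the index argument i is unused).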
    return [make_as('docs')(doc) for doc in text.split('\n') if doc]




# SETTING ARGS
parser = argparse.ArgumentParser()
parser.add_argument("--save_path", type=str, default='res.json', help="Path to the config file")
parser.add_argument("--model", type=str, default='gpt-3.5-turbo', help="model name or path")
parser.add_argument("--shots", type=int, default=2, help="number of shots")
parser.add_argument("--ndoc", type=int, default=5, help="number of docs")
parser.add_argument("--pr", action='store_true', help="use cite PR")
parser.add_argument("--rouge", action='store_true', help="use rouge")
parser.add_argument("--temp", type=float, default=0.5, help="temperature")
parser.add_argument("--qa", action='store_true', help="eval qa")
parser.add_argument("--mauve",  action='store_true', help="eval mauve")
parser.add_argument("--length", type=bool, default=True, help="eval length")
parser.add_argument("--claims", action='store_true', help="eval claims")
parser.add_argument("--qampari", type=str, default=False, help="eval qampari")
parser.add_argument("--dataset", type=str, default='data/asqa_eval_gtr_top100.json', help="dataset")
parser.add_argument("--demo", type=str, default='prompts/asqa_default.json', help="demo")
parser.add_argument("--doctype", type=str, default='text', help="demo")
parser.add_argument("--data_num", type=int, default=1000, help="num of data")
parser.add_argument("--mode", type=str, default='text', help="mode-granularity: text, extraction or summary")
parser.add_argument("--k", type=float, default=1.5, help="coefficient of em")
parser.add_argument("--topk", type=int, default=1, help="topk")
args = parser.parse_args()

def score(data):
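    # Placeholder ranking score: the commented lines below sketch the intended
    # objective (citation precision + recall + k * string EM). Returning a
    # constant makes the ranker treat all candidate answers equally.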
    #pr = compute_autoais(data)
    #p = pr["citation_prec"]
    #r = pr["citation_rec"]
    #em = compute_str_em(data)
    #return p + r + args.k * em
    return 1

# DATA LOADING
file_path = args.dataset
demo_path = args.demo


with open(file_path, 'r', encoding='utf-8') as file:
    dataset = json.load(file)
with open(demo_path, 'r', encoding='utf-8') as file:
    demo = json.load(file)
data_num = min(args.data_num, len(dataset))

llm_instruction = demo['one_sentence_instruction']
query_inst = demo["query_instruction"]
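
# Few-shot demos: two ALCE demos, each reduced to the first sentence of its
# answer with citations stripped and rendered with a single demo document.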
shots = '\n\n'.join(NewALCEVanillaPrompt().load_data(
    [demo['demos'][1], demo['demos'][3]],
    'question',
    answer=lambda data: remove_citations(sentence('first')(data['answer'])['first']),
    INST=lambda _: llm_instruction,
    docs=lambda data: ''.join(ALCEDocPrompt().default_load_data(data['docs'][1:2]))
))


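# Per-example document prompts rendered at the chosen granularity
# (--mode: text, extraction, or summary); not referenced again below.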
documents = [DocPrompt().load_data(list(enumerate(data['docs'])),Title = lambda data: data[1]['title'], Passage = lambda data: data[1][args.mode]) for data in dataset]

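# Wrap the evaluation data: keep the question/answer/claims fields, pre-render
# the top --ndoc documents per example, and truncate to --data_num examples.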
dataset = PromptDataset(dataset,'question','answer','answers','qa_pairs','claims', docs = lambda data: ALCEDocPrompt().default_load_data(data['docs'][:args.ndoc]))[:data_num]

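# Answer prompt: instruction, few-shot demos, question, running answer prefix,
# retrieved documents, and highlighted spans. Query prompt: asks the model for
# one new retrieval query given the question, current answer, and prior queries.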
prompt = Prompt(template='<shots><INST><question><ans><docs><span>\nAnswer:', components= {'INST':'{INST}\n\n','shots':'{shots}\n','question':'Question:{question}\n\n', 'ans':'Prefix:{ans}\n\n','docs':'{docs}\n', 'span':'The highlighted spans are: \n{span}\n\n'})
queryprompt = Prompt(template='<INST><question><prev><ans>Please generate one query to help find relevant documents, making sure it is different from previous queries (if provided). Your query is:\n',components={'question':'Given the original question: {question}\n','ans':'The context is: {ans}\n','prev':'\nPrevious queries:\n{prev}\n\n','INST':'{INST}\n\n'})

retriever_prompt = Prompt(template='<query>',components={'query':'{query}'})

query_generator = LLM(model=args.model, prompt_maker=queryprompt, self_prompt={'INST':query_inst})
evaluator = DefaultEvaluator(args)
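# Iterative ranker: up to 5 ranking turns, scoring each candidate answer with
# score() above (currently a constant placeholder).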
ranker = Ranker(max_turn=5, iterative= True)
#ranker.set_eval('length', output = 'answer')
#ranker.new_eval('score', score , output = 'answer', docs = 'doc_cache', qa_pairs = 'qa_pairs')
ranker.new_eval('score', score , output = 'answer', docs = 'doc_cache')
# PIPELINE CONSTRUCTING
llm = LLM(model=args.model,prompt_maker=prompt, self_prompt={'INST':llm_instruction, 'shots':shots}, max_turn= 2, auto_cite=True,share_model_with= query_generator, parallel= True)


pipeline = Pipeline(save_path=args.save_path, llm=llm, module=[ranker, query_generator], head_prompt_maker=prompt, evaluator=evaluator, dataset=dataset)

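# Wire the loop: query_generator -> retriever (BM25, top-k) -> llm -> ranker.
# The ranker emits the current best answer, writes it back to the head as the
# 'ans' prefix, and routes it to the query generator for the next turn.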
retriever = Retriever(prompt_maker=retriever_prompt, pipeline=pipeline, retrieve_by='bm25', topk=args.topk)
query_generator.set_target(retriever, post_processing=make_as('query'))
query_generator.add_to_head('prev', sub = False)
retriever.set_target(llm,post_processing=segment)
llm.set_target(ranker,post_processing=make_as('answer'))
ranker.set_output(post_processing=lambda x: x['answer'], end = False)

ranker.add_to_head('ans',sub = True,  process = lambda text: one_paragraph(text['answer']) )
ranker.set_target(query_generator,post_processing=lambda x:{'ans': x['answer']})
pipeline.set_initial_module(query_generator)
pipeline.set_data_keys(['question'])

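# Optional post-hoc attribution with a separate model; the alternative
# retriever -> attributer -> llm routing below is left commented out.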
attributer = AttributingModule(model = 'gpt-4o-mini')
attributer.connect_to(pipeline)

# To add to the interactive interface: predefined output formats such as make_as, the add_to_head settings, and the special formats for prompt inputs.
#retriever.set_target(attributer, post_processing=make_as('docs'))  
#retriever.add_to_head('docs', sub=True)
#attributer.set_target(llm)

#pipeline.run(datakeys=['question'],initial_module=query_generator)




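# Build a graph view of the pipeline (PipelineGraph is assumed to come from the
# wildcard `from parser import *` import above); HTML export, visualization,
# and the dataset run are left commented out.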
graph = PipelineGraph(pipeline=pipeline)

#html = graph.generate_html(results='results.json')
#graph.visualize()
#print(html)
#with open('pipeline_.html','w') as file:
#    file.write(html)
# RUN PIPELINE
#pipeline.run_on_dataset(datakeys=['question'],initial_module=query_generator)