import torch
from citekit.prompt.prompt import Prompt
import re
from citekit.utils.utils import one_paragraph, first_sentence, make_as
import random
import os
class Module:
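    """Base class for a pipeline node.

    A Module holds a Prompt maker, a routing table of destinations with their
    conditions and post-processing, optional output conditions, and an optional
    hook that writes its output back into the shared pipeline head.
    """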
module_count = 1
def __init__(self,prompt_maker: Prompt = None, pipeline = None, self_prompt = {}, iterative = False, merge = False, max_turn =6, output_as = None, parallel = False) -> None:
self.self_prompt = self_prompt
self.use_head_prompt = True
self.connect_to(pipeline)
self.prompt_maker = prompt_maker
self.last_message = ''
self.destinations = []
self.conditions = {}
self.head_key = None
self.parallel = parallel
self.iterative = iterative
self.merge = merge
self.head_process = one_paragraph
self.max_turn = max_turn
self.multi_process = False
self.output_cond = {} # {cond : {'post_processing':post, 'end':end}}
self.count = Module.module_count
Module.module_count += 1
self.if_add_output_to_head = False
self.turns = 0
self.end = False
    def __str__(self) -> str:
        model_type = getattr(self, 'model_type', None)
        if model_type:
            return f'{model_type}-[{self.count}]'
        else:
            return f'Unknown-type module-[{self.count}]'
def get_json_config(self, config):
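        # Return the requested piece of this module's configuration (prompt,
        # destination, max turn, or global-prompt key) so it can be exported as JSON.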
print('get_json_config:',config)
        available_mapping = {
'max turn': 'max_turn',
'prompt': 'prompt',
'destination': 'destination',
'global prompt': 'head_key',
}
if config == 'prompt':
prompt_info = {
'template': self.prompt_maker.template,
'components': self.prompt_maker.components
}
self_info = self.self_prompt
return {
'prompt_info': prompt_info,
'self_info': self_info
}
elif config == 'destination':
return {
'destination': str(self.destinations[0])
}
elif config in ['max turn','global prompt']:
            config = available_mapping[config]
print('getting the config:',config)
return getattr(self, config)
else:
raise NotImplementedError(f'get_json_config for {config} is not implemented')
def get_destinations(self):
return self.destinations
def update(self, config, update_info):
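        # Apply a runtime configuration change: rebuild the prompt, add or
        # delete a destination, change the head key, or set the turn limit.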
if config == 'prompt':
template = update_info['template']
components = update_info['components']
self_prompt = update_info['self_prompt']
import copy
# avoid changing the original prompt_maker
self.prompt_maker = copy.deepcopy(self.prompt_maker)
self.prompt_maker.update(template=template, components=components)
self.self_prompt = self_prompt
elif config == 'destination':
print('update destination:',update_info[0], 'post_processing:',update_info[1])
if update_info[1] == 'None':
self.set_target(update_info[0])
else:
self.set_target(update_info[0], post_processing=make_as(update_info[1]))
elif config == 'delete_destination':
for i, d in enumerate(self.destinations):
if str(d) == str(update_info):
self.destinations.remove(d)
del self.conditions[d]
break
elif config == 'header':
self.add_to_head(update_info, sub = True)
elif config == 'max turn':
self.max_turn = update_info
else:
raise NotImplementedError(f'update for {config} is not implemented')
def end_multi(self):
return
def set_use_head_prompt(self,use):
assert isinstance(use,bool)
self.use_head_prompt = use
def reset(self):
self.end = False
self.turns = 0
def change_to_multi_process(self,bool_value):
if bool_value:
self.last_message = []
else:
self.last_message = ''
self.multi_process = bool_value
@property
def get_use_head_prompt(self):
return self.use_head_prompt
def generate(self, head_prompt: dict = {}, dynamic_prompt: dict = {}):
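        """Produce output from the shared head prompt and the dynamic prompt.
        Subclasses must implement this."""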
raise NotImplementedError
def send(self):
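        # Return the first destination whose routing condition holds for the
        # current state of this module, or None if no condition matches.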
for destination in self.destinations:
cond = self.conditions[destination]['condition']
if cond(self):
return destination
return None
def set_target(self,destination, condition = lambda self: True, post_processing = lambda x:x) -> None:
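        # Register `destination` as a routing target: `condition` decides when it
        # is chosen and `post_processing` shapes the payload handed to it.
        # Newly added targets take precedence over earlier ones.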
self.conditions[destination] = {'condition': condition, 'post_processing' : post_processing}
self.destinations = [destination] + self.destinations
destination.connect_to(self.pipeline)
def clear_destination(self):
self.destinations = []
self.conditions = {}
def add_output_to_head(self, outputs):
if self.if_add_output_to_head:
if not self.head_sub:
if self.head_key not in self.pipeline.head.keys():
self.pipeline.head.update({self.head_key: self.head_process(outputs)})
else:
self.pipeline.head[self.head_key] += '\n'
self.pipeline.head[self.head_key] += self.head_process(outputs)
else:
self.pipeline.head[self.head_key] = self.head_process(outputs)
def connect_to(self, pipeline = None) -> None:
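        # Attach this module to `pipeline` and register it in the pipeline's
        # module list (a None pipeline leaves the module detached).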
self.pipeline = pipeline
if pipeline:
pipeline.module.append(self)
def output(self):
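        # For the first matching output condition, append the post-processed
        # last message to the pipeline output; any matching condition flagged
        # with 'end' also marks this module as finished.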
outed = False
for cond, post_and_end in self.output_cond.items():
if cond(self):
if not outed:
if not self.merge:
self.pipeline.output.append(post_and_end['post_processing'](self.last_message))
else:
self.pipeline.output.append(post_and_end['post_processing'](''.join(self.last_message)))
outed = True
if post_and_end['end']:
self.end = True
def set_output(self, cond = lambda self: True, post_processing = lambda x:x, end = True):
self.output_cond[cond] = {'post_processing': post_processing, 'end' : end}
def get_first_module(self):
return self
def add_to_head(self, datakey, sub = False, process = None):
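        # Mirror this module's output into the pipeline head under `datakey`.
        # With sub=True the head entry is overwritten each turn instead of being
        # appended to; `process` overrides the default transform (one_paragraph).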
self.if_add_output_to_head = True
self.head_key = datakey
self.head_sub = sub
if process:
self.head_process = process
def load_model(model_name_or_path,dtype = torch.float16):
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
torch_dtype=dtype,
device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model.eval()
return model, tokenizer
class LLM(Module):
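    """Generator module backed by either an OpenAI chat model (any model name
    containing 'gpt') or a locally loaded HuggingFace causal LM.

    Builds prompts from the pipeline head, its own self_prompt and the dynamic
    prompt, optionally appends citations taken from the source documents, and
    post-processes the result for the destination chosen by its routing conditions.
    """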
model_type = 'Generator'
def __init__(self, model = None, prompt_maker: Prompt =None, pipeline = None, post_processing = None, self_prompt = {}, device = 'cpu',temperature = 0.5 ,stop = None, max_turn = 6, share_model_with = None, iterative = False, auto_cite = False, output = None,merge = False, noisy = True, parallel = False, output_as ='Answer', auto_cite_from = 'docs') -> None:
super().__init__(prompt_maker,pipeline,self_prompt, iterative, merge, parallel = parallel)
self.max_turn = max_turn
if post_processing:
self.post_processing = post_processing
else:
self.post_processing = lambda x: {output_as:x}
if model:
self.model_name = model
self.stop = stop
self.multi_process = False
self.noisy = noisy
self.head_process = one_paragraph
self.auto_cite = auto_cite
if auto_cite:
self.cite_from = auto_cite_from
if model:
if 'gpt' not in model.lower():
if not share_model_with:
print('loading model...')
self.model, self.tokenizer = self.load_model(model)
else:
print('sharing model...')
self.model, self.tokenizer = share_model_with.model, share_model_with.tokenizer
self.temperature = temperature
self.device = device
else:
self.openai_key = os.getenv('OPENAI_API_KEY')
self.output_cond = {} # {cond : {'post_processing':post, 'end':end}}
self.if_add_output_to_head = False
self.token_used = 0
def reset(self):
self.end = False
self.turns = 0
self.token_used = 0
    def __str__(self) -> str:
        if getattr(self, 'model_name', None):
            return f'{self.model_name}-[{self.count}]'
        else:
            return 'unknown model'
def __repr__(self) -> str:
return (f'{self.prompt_maker}\n|\n|\nV\n{self}\n|\n|\nV\n'+ '/'.join([str(des) for des in self.destinations]+['output']))
def load_model(self, model_name_or_path,dtype = torch.float16):
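        """Load a causal LM and its tokenizer with transformers, spread the model
        across available devices (device_map='auto') and switch it to eval mode."""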
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
torch_dtype=dtype,
device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model.eval()
return model, tokenizer
def set_cite(self,key):
self.cite_from = key
self.auto_cite = True
def generate_content(self, prompt):
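        # For 'gpt' models, call the OpenAI ChatCompletion API; otherwise run
        # sampled decoding with the local HF model. Token usage is accumulated
        # in self.token_used in both cases.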
if 'gpt' in self.model_name.lower():
import openai
openai.api_key = self.openai_key
prompt = [
{'role': 'system',
                'content': "You are a good helper who follows the instructions"},
{'role': 'user', 'content': prompt}
]
response = openai.ChatCompletion.create(
model=self.model_name,
messages=prompt,
max_tokens=500,
stop = self.stop
)
self.token_used += response['usage']['completion_tokens'] + response['usage']['prompt_tokens']
return response['choices'][0]['message']['content']
else:
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
stop = [] if self.stop is None else self.stop
outputs = self.model.generate(
**inputs,
do_sample = True,
max_new_tokens = 200,
temperature = self.temperature
)
self.token_used += len(outputs[0])
outputs = self.tokenizer.decode(outputs[0][inputs['input_ids'].size(1):], skip_special_tokens=True)
return one_paragraph(outputs)
def generate(self, head_prompt: dict = {}, dynamic_prompt: dict = {}):
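        # Assemble the prompt (head + self + dynamic parts), query the model,
        # optionally attach citations, record the output, update the pipeline
        # head, and return the post-processed result for the chosen destination.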
if self.use_head_prompt:
#print(head_prompt,self.self_prompt,dynamic_prompt)
prompt = self.prompt_maker(head_prompt,self.self_prompt,dynamic_prompt)
else:
prompt = self.prompt_maker(self.self_prompt,dynamic_prompt)
if self.noisy:
print(f'prompt to {str(self)}:\n',prompt,'\n\n')
self.turns += 1
outputs = self.generate_content(prompt)
#print('DEBUG:',outputs)
if self.noisy:
print('OUTPUT:')
print(outputs)
if self.auto_cite:
outputs = self.cite_from_prompt({**head_prompt,**self.self_prompt,**dynamic_prompt},outputs)
if self.multi_process:
self.last_message.append(outputs)
else:
self.last_message = outputs
self.add_output_to_head(outputs)
destination = self.send()
if self.turns > self.max_turn:
self.end = True
if destination in self.conditions:
return self.conditions[destination]['post_processing'](outputs)
else:
return self.post_processing(outputs)
def add_output_to_head(self, outputs):
if self.if_add_output_to_head:
if not self.head_sub:
if self.head_key not in self.pipeline.head.keys():
self.pipeline.head.update({self.head_key: self.head_process(outputs)})
else:
self.pipeline.head[self.head_key] += '\n'
self.pipeline.head[self.head_key] += self.head_process(outputs)
else:
self.pipeline.head[self.head_key] = self.head_process(outputs)
def output(self):
outed = False
for cond, post_and_end in self.output_cond.items():
if cond(self):
if not outed:
if not self.merge and not self.iterative:
self.pipeline.output.append(post_and_end['post_processing'](self.last_message))
else:
self.pipeline.output.append(post_and_end['post_processing'](' '.join(self.last_message)))
outed = True
if post_and_end['end']:
self.end = True
def set_output(self, cond = lambda self: True, post_processing = lambda x:x, end = True):
self.output_cond[cond] = {'post_processing': post_processing, 'end' : end}
def cite_from_prompt(self,prompt_dict,input):
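        # Collect citation markers such as [1][2] from the cited documents in the
        # prompt and attach them to the first sentence of the model output.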
input = first_sentence(input)
cite_docs = prompt_dict[self.cite_from]
refs = re.findall(r'\[\d+\]', cite_docs)
pattern = r'([.!?])\s*$'
if refs:
cite = ''.join(refs)
else:
cite = ''
output = re.sub(pattern, rf' {cite}\1 ', input)
if cite not in output:
output += cite
return output
def add_to_head(self, datakey, sub = False, process = None):
self.if_add_output_to_head = True
self.head_key = datakey
self.head_sub = sub
if process:
self.head_process = process
class TestLLM(LLM):
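    """Stub generator for testing: generate_content returns a canned answer, so
    pipelines can be exercised without loading a model or calling an API."""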
def __init__(self, model='gpt-4', prompt_maker: Prompt = None, pipeline=None, post_processing=lambda x: x, self_prompt={}, device='cpu', temperature=0.5, stop=None, max_turn=6,share_model_with = None, iterative= False, ans = None) -> None:
super().__init__(model,prompt_maker,pipeline,self_prompt=self_prompt,share_model_with=share_model_with,iterative=iterative)
self.max_turn = max_turn
self.post_processing = post_processing
self.model_name = model
self.last_message = ''
self.stop = stop
self.output_cond = {} # {cond : {'post_processing':post, 'end':end}}
self.if_add_output_to_head = False
self.token_used = 0
self.ans = 'Strain[1], turns:, heat[2][4]. Sent2[5]. Sent3.\n\n rdd' if not ans else ans
def generate_content(self, prompt):
return self.ans
class AutoAISLLM(LLM):
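    """Entailment judge in the AutoAIS style: prompts the underlying model with a
    premise and a claim and expects a single digit indicating entailment."""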
def __init__(self, model=None, prompt_maker: Prompt = None, pipeline=None, post_processing=None, self_prompt={}, device='cpu', temperature=0.5, stop=None, max_turn=6, share_model_with=None, iterative=False, auto_cite=False, output=None, merge=False, noisy=False, output_as='Answer') -> None:
        super().__init__(model, prompt_maker, pipeline, post_processing, self_prompt, device, temperature, stop, max_turn, share_model_with, iterative, auto_cite, output, merge, noisy, output_as=output_as)
self.prompt_maker = Prompt('<INST><premise><claim>\n Answer: ',components={
'INST':'{INST}\n\n',
'premise':'Premise: {premise}\n\n',
'claim':'Claim: {claim}\n',
})
        self.self_prompt = {'INST': 'In this task, you will be presented with a premise and a claim. If the premise entails the claim, output "1"; otherwise output "0". Your answer should contain only one number, without any other letters or punctuation.'}
def generate(self, premise, claim):
dict_answer = super().generate({'premise':premise,'claim':claim})
return dict_answer.get('Answer')
if __name__ == '__main__':
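    # Minimal smoke test: build a feedback-style prompt template and create a
    # GPT-backed generator (real calls require OPENAI_API_KEY to be set).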
prompt = Prompt(template='<INST><Question><Docs><feedback><Answer>',components={'INST':'{INST}\n\n',
'Question':'Question:{Question}\n\n',
'Docs':'{Docs}\n',
                                        'feedback':'Here is the feedback on your last response:{feedback}\n',
                                        'Answer':'Here is the answer and you have to give feedback on it:{Answer}'})
m = LLM('gpt')
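
    # Illustrative offline check (not part of the original demo): TestLLM returns
    # a fixed answer, so the generator interface can be exercised without
    # downloading a model or calling an API.
    t = TestLLM(prompt_maker=prompt)
    print(t.generate_content('any prompt'))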