import copy
import os
import random
import re

import torch

from citekit.prompt.prompt import Prompt
from citekit.utils.utils import one_paragraph, first_sentence, make_as


class Module:
    """Base building block of a pipeline: wraps a prompt maker and keeps the
    routing (destination) and output conditions used to pass data between modules."""

    module_count = 1

    def __init__(self, prompt_maker: Prompt = None, pipeline=None, self_prompt={}, iterative=False,
                 merge=False, max_turn=6, output_as=None, parallel=False) -> None:
        self.self_prompt = self_prompt
        self.use_head_prompt = True
        self.connect_to(pipeline)
        self.prompt_maker = prompt_maker
        self.last_message = ''
        self.destinations = []
        self.conditions = {}
        self.head_key = None
        self.parallel = parallel
        self.iterative = iterative
        self.merge = merge
        self.head_process = one_paragraph
        self.max_turn = max_turn
        self.multi_process = False
        self.output_cond = {}
        self.count = Module.module_count
        Module.module_count += 1
        self.if_add_output_to_head = False
        self.turns = 0
        self.end = False

    def __str__(self) -> str:
        # getattr avoids an AttributeError on subclasses that do not define model_type
        model_type = getattr(self, 'model_type', None)
        if model_type:
            return f'{model_type}-[{self.count}]'
        else:
            return f'Unknown-type module-[{self.count}]'

    def get_json_config(self, config):
        print('get_json_config:', config)
        available_mapping = {
            'max turn': 'max_turn',
            'prompt': 'prompt',
            'destination': 'destination',
            'global prompt': 'head_key',
        }
        if config == 'prompt':
            prompt_info = {
                'template': self.prompt_maker.template,
                'components': self.prompt_maker.components
            }
            self_info = self.self_prompt

            return {
                'prompt_info': prompt_info,
                'self_info': self_info
            }
        elif config == 'destination':
            return {
                'destination': str(self.destinations[0])
            }
        elif config in ['max turn', 'global prompt']:
            config = available_mapping[config]
            print('getting the config:', config)
            return getattr(self, config)
        else:
            raise NotImplementedError(f'get_json_config for {config} is not implemented')

    def get_destinations(self):
        return self.destinations

    def update(self, config, update_info):
        """Update this module's prompt, destinations, header behaviour or turn limit in place."""
        if config == 'prompt':
            template = update_info['template']
            components = update_info['components']
            self_prompt = update_info['self_prompt']

            self.prompt_maker = copy.deepcopy(self.prompt_maker)
            self.prompt_maker.update(template=template, components=components)
            self.self_prompt = self_prompt

        elif config == 'destination':
            print('update destination:', update_info[0], 'post_processing:', update_info[1])
            if update_info[1] == 'None':
                self.set_target(update_info[0])
            else:
                self.set_target(update_info[0], post_processing=make_as(update_info[1]))

        elif config == 'delete_destination':
            for i, d in enumerate(self.destinations):
                if str(d) == str(update_info):
                    self.destinations.remove(d)
                    del self.conditions[d]
                    break
        elif config == 'header':
            self.add_to_head(update_info, sub=True)
        elif config == 'max turn':
            self.max_turn = update_info
        else:
            raise NotImplementedError(f'update for {config} is not implemented')

    def end_multi(self):
        return

    def set_use_head_prompt(self, use):
        assert isinstance(use, bool)
        self.use_head_prompt = use

    def reset(self):
        self.end = False
        self.turns = 0

    def change_to_multi_process(self, bool_value):
        if bool_value:
            self.last_message = []
        else:
            self.last_message = ''
        self.multi_process = bool_value

    @property
    def get_use_head_prompt(self):
        return self.use_head_prompt

    def generate(self, head_prompt: dict = {}, dynamic_prompt: dict = {}):
        raise NotImplementedError

    def send(self):
        """Return the first registered destination whose routing condition holds, else None."""
        for destination in self.destinations:
            cond = self.conditions[destination]['condition']
            if cond(self):
                return destination
        return None

    def set_target(self, destination, condition=lambda self: True, post_processing=lambda x: x) -> None:
        """Register a destination module, with an optional routing condition and post-processing."""
        self.conditions[destination] = {'condition': condition, 'post_processing': post_processing}
        self.destinations = [destination] + self.destinations
        destination.connect_to(self.pipeline)

    def clear_destination(self):
        self.destinations = []
        self.conditions = {}

    def add_output_to_head(self, outputs):
        if self.if_add_output_to_head:
            if not self.head_sub:
                if self.head_key not in self.pipeline.head.keys():
                    self.pipeline.head.update({self.head_key: self.head_process(outputs)})
                else:
                    self.pipeline.head[self.head_key] += '\n'
                    self.pipeline.head[self.head_key] += self.head_process(outputs)
            else:
                self.pipeline.head[self.head_key] = self.head_process(outputs)

    def connect_to(self, pipeline=None) -> None:
        self.pipeline = pipeline
        if pipeline:
            pipeline.module.append(self)

    def output(self):
        outed = False
        for cond, post_and_end in self.output_cond.items():
            if cond(self):
                if not outed:
                    if not self.merge:
                        self.pipeline.output.append(post_and_end['post_processing'](self.last_message))
                    else:
                        self.pipeline.output.append(post_and_end['post_processing'](''.join(self.last_message)))
                    outed = True
                if post_and_end['end']:
                    self.end = True

    def set_output(self, cond=lambda self: True, post_processing=lambda x: x, end=True):
        self.output_cond[cond] = {'post_processing': post_processing, 'end': end}

    def get_first_module(self):
        return self

    def add_to_head(self, datakey, sub=False, process=None):
        self.if_add_output_to_head = True
        self.head_key = datakey
        self.head_sub = sub
        if process:
            self.head_process = process

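# Routing/output usage sketch (illustrative only; `generator`, `verifier` and `pipeline`
# stand for concrete Module subclasses and the Pipeline they share, built elsewhere):
#
#     generator.set_target(verifier, condition=lambda self: not self.end)
#     generator.set_output(cond=lambda self: self.turns >= self.max_turn,
#                          post_processing=one_paragraph, end=True)
#     generator.add_to_head('Answer', sub=True)   # sub=True overwrites the shared head each turn
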
def load_model(model_name_or_path, dtype=torch.float16):
    """Load a Hugging Face causal LM and its tokenizer, placing the weights with device_map='auto'."""
    from transformers import AutoModelForCausalLM, AutoTokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=dtype,
        device_map='auto',
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model.eval()
    return model, tokenizer

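# Loading sketch (illustrative; the checkpoint name is an assumption, not prescribed by this file):
#
#     model, tokenizer = load_model('meta-llama/Llama-2-7b-chat-hf', dtype=torch.float16)
#
# The LLM class below performs the same loading internally; pass `share_model_with=<another LLM>`
# to reuse an already-loaded model/tokenizer instead of calling from_pretrained a second time.
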
class LLM(Module):
    """Generator module backed either by an OpenAI chat model (name containing 'gpt')
    or by a locally loaded Hugging Face causal LM."""

    model_type = 'Generator'

    def __init__(self, model=None, prompt_maker: Prompt = None, pipeline=None, post_processing=None,
                 self_prompt={}, device='cpu', temperature=0.5, stop=None, max_turn=6,
                 share_model_with=None, iterative=False, auto_cite=False, output=None, merge=False,
                 noisy=True, parallel=False, output_as='Answer', auto_cite_from='docs') -> None:
        super().__init__(prompt_maker, pipeline, self_prompt, iterative, merge, parallel=parallel)
        self.max_turn = max_turn
        if post_processing:
            self.post_processing = post_processing
        else:
            self.post_processing = lambda x: {output_as: x}
        self.model_name = model  # always set (may be None) so that __str__ does not fail
        self.stop = stop
        self.multi_process = False
        self.noisy = noisy
        self.head_process = one_paragraph
        self.auto_cite = auto_cite
        if auto_cite:
            self.cite_from = auto_cite_from
        if model:
            if 'gpt' not in model.lower():
                if not share_model_with:
                    print('loading model...')
                    self.model, self.tokenizer = self.load_model(model)
                else:
                    print('sharing model...')
                    self.model, self.tokenizer = share_model_with.model, share_model_with.tokenizer
                self.temperature = temperature
                self.device = device
            else:
                self.openai_key = os.getenv('OPENAI_API_KEY')
        self.output_cond = {}
        self.if_add_output_to_head = False

        self.token_used = 0

    def reset(self):
        self.end = False
        self.turns = 0
        self.token_used = 0

    def __str__(self) -> str:
        if self.model_name:
            return f'{self.model_name}-[{self.count}]'
        else:
            return 'unknown model'

    def __repr__(self) -> str:
        return (f'{self.prompt_maker}\n|\n|\nV\n{self}\n|\n|\nV\n' + '/'.join([str(des) for des in self.destinations] + ['output']))

    def load_model(self, model_name_or_path, dtype=torch.float16):
        from transformers import AutoModelForCausalLM, AutoTokenizer
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=dtype,
            device_map='auto',
        )

        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        model.eval()
        return model, tokenizer

    def set_cite(self, key):
        self.cite_from = key
        self.auto_cite = True

    def generate_content(self, prompt):
        if 'gpt' in self.model_name.lower():
            import openai
            openai.api_key = self.openai_key
            prompt = [
                {'role': 'system',
                 'content': "You are a good helper who follows the instructions."},
                {'role': 'user', 'content': prompt}
            ]
            response = openai.ChatCompletion.create(
                model=self.model_name,
                messages=prompt,
                max_tokens=500,
                stop=self.stop
            )
            self.token_used += response['usage']['completion_tokens'] + response['usage']['prompt_tokens']
            return response['choices'][0]['message']['content']

        else:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            # note: self.stop is currently only applied in the OpenAI branch above

            outputs = self.model.generate(
                **inputs,
                do_sample=True,
                max_new_tokens=200,
                temperature=self.temperature
            )
            self.token_used += len(outputs[0])

            outputs = self.tokenizer.decode(outputs[0][inputs['input_ids'].size(1):], skip_special_tokens=True)
            return one_paragraph(outputs)

    def generate(self, head_prompt: dict = {}, dynamic_prompt: dict = {}):
        """Build the prompt, query the model, optionally auto-cite, then route or return the result."""
        if self.use_head_prompt:
            prompt = self.prompt_maker(head_prompt, self.self_prompt, dynamic_prompt)
        else:
            prompt = self.prompt_maker(self.self_prompt, dynamic_prompt)
        if self.noisy:
            print(f'prompt to {str(self)}:\n', prompt, '\n\n')
        self.turns += 1

        outputs = self.generate_content(prompt)

        if self.noisy:
            print('OUTPUT:')
            print(outputs)
        if self.auto_cite:
            outputs = self.cite_from_prompt({**head_prompt, **self.self_prompt, **dynamic_prompt}, outputs)
        if self.multi_process:
            self.last_message.append(outputs)
        else:
            self.last_message = outputs

        self.add_output_to_head(outputs)

        destination = self.send()

        if self.turns > self.max_turn:
            self.end = True
        if destination in self.conditions:
            return self.conditions[destination]['post_processing'](outputs)
        else:
            return self.post_processing(outputs)

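    # Call sketch (illustrative; the keys must match the Prompt template components, e.g. the
    # 'INST'/'Question'/'Docs' template shown in the __main__ block at the bottom of this file):
    #
    #     llm.generate(head_prompt={'INST': instruction, 'Question': question},
    #                  dynamic_prompt={'Docs': retrieved_docs_text})
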
    def add_output_to_head(self, outputs):
        if self.if_add_output_to_head:
            if not self.head_sub:
                if self.head_key not in self.pipeline.head.keys():
                    self.pipeline.head.update({self.head_key: self.head_process(outputs)})
                else:
                    self.pipeline.head[self.head_key] += '\n'
                    self.pipeline.head[self.head_key] += self.head_process(outputs)
            else:
                self.pipeline.head[self.head_key] = self.head_process(outputs)

    def output(self):
        outed = False
        for cond, post_and_end in self.output_cond.items():
            if cond(self):
                if not outed:
                    if not self.merge and not self.iterative:
                        self.pipeline.output.append(post_and_end['post_processing'](self.last_message))
                    else:
                        self.pipeline.output.append(post_and_end['post_processing'](' '.join(self.last_message)))
                    outed = True
                if post_and_end['end']:
                    self.end = True

    def set_output(self, cond=lambda self: True, post_processing=lambda x: x, end=True):
        self.output_cond[cond] = {'post_processing': post_processing, 'end': end}

    def cite_from_prompt(self, prompt_dict, input):
        """Attach every citation marker (e.g. [1], [2]) found in the cited documents
        to the first sentence of the model output."""
        input = first_sentence(input)
        cite_docs = prompt_dict[self.cite_from]
        refs = re.findall(r'\[\d+\]', cite_docs)
        pattern = r'([.!?])\s*$'
        if refs:
            cite = ''.join(refs)
        else:
            cite = ''
        output = re.sub(pattern, rf' {cite}\1 ', input)
        if cite not in output:
            output += cite
        return output

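    # Behaviour sketch (illustrative values, assuming the default auto_cite_from='docs'):
    #
    #     llm.cite_from_prompt({'docs': 'Document [1]: ... Document [2]: ...'},
    #                          'Paris is the capital of France. More text.')
    #     -> 'Paris is the capital of France [1][2]. '
    #
    # Only the first sentence is kept, and the markers found under `self.cite_from` are attached to it.
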
    def add_to_head(self, datakey, sub=False, process=None):
        self.if_add_output_to_head = True
        self.head_key = datakey
        self.head_sub = sub
        if process:
            self.head_process = process


class TestLLM(LLM):
    """Stub generator that returns a canned answer without calling any model, for pipeline testing."""

    def __init__(self, model='gpt-4', prompt_maker: Prompt = None, pipeline=None, post_processing=lambda x: x,
                 self_prompt={}, device='cpu', temperature=0.5, stop=None, max_turn=6,
                 share_model_with=None, iterative=False, ans=None) -> None:
        super().__init__(model, prompt_maker, pipeline, self_prompt=self_prompt,
                         share_model_with=share_model_with, iterative=iterative)
        self.max_turn = max_turn
        self.post_processing = post_processing
        self.model_name = model
        self.last_message = ''
        self.stop = stop
        self.output_cond = {}
        self.if_add_output_to_head = False

        self.token_used = 0
        self.ans = 'Strain[1], turns:, heat[2][4]. Sent2[5]. Sent3.\n\n rdd' if not ans else ans

    def generate_content(self, prompt):
        return self.ans


class AutoAISLLM(LLM):
    """Entailment judge for AutoAIS-style citation checking: given a premise and a claim,
    it is prompted to answer '1' if the premise entails the claim and '0' otherwise."""

    def __init__(self, model=None, prompt_maker: Prompt = None, pipeline=None, post_processing=None,
                 self_prompt={}, device='cpu', temperature=0.5, stop=None, max_turn=6,
                 share_model_with=None, iterative=False, auto_cite=False, output=None, merge=False,
                 noisy=False, output_as='Answer') -> None:
        # output_as is passed by keyword so it is not mistaken for the `parallel` positional argument
        super().__init__(model, prompt_maker, pipeline, post_processing, self_prompt, device, temperature,
                         stop, max_turn, share_model_with, iterative, auto_cite, output, merge, noisy,
                         output_as=output_as)

        self.prompt_maker = Prompt('<INST><premise><claim>\n Answer: ', components={
            'INST': '{INST}\n\n',
            'premise': 'Premise: {premise}\n\n',
            'claim': 'Claim: {claim}\n',
        })
        self.self_prompt = {'INST': 'In this task, you will be presented a premise and a claim. If the premise entails the claim, output "1", otherwise output "0". Your answer should only contain one number without any other letters or punctuation.'}

    def generate(self, premise, claim):
        dict_answer = super().generate({'premise': premise, 'claim': claim})
        return dict_answer.get('Answer')

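# Usage sketch (illustrative; the model name and variables are assumptions):
#
#     judge = AutoAISLLM(model='gpt-3.5-turbo', pipeline=pipeline)
#     verdict = judge.generate(premise=document_text, claim=answer_sentence)  # '1' if entailed, else '0'
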
if __name__ == '__main__':
    prompt = Prompt(template='<INST><Question><Docs><feedback><Answer>',
                    components={'INST': '{INST}\n\n',
                                'Question': 'Question:{Question}\n\n',
                                'Docs': '{Docs}\n',
                                'feedback': 'Here is the feedback of your last response:{feedback}\n',
                                'Answer': 'Here is the answer and you have to give feedback:{Answer}'})
    m = LLM('gpt')
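    # Illustrative wiring only; no API call is made here (generation would require OPENAI_API_KEY):
    m.prompt_maker = prompt
    m.set_output(post_processing=one_paragraph, end=True)
    print(repr(m))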