Update modeling_codeshell.py

modeling_codeshell.py  +121 -5  CHANGED
@@ -32,14 +32,17 @@
 """PyTorch CodeShell model."""
 import os
 import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Callable
+from threading import Thread
+from queue import Queue
+

 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from transformers import PreTrainedModel, PretrainedConfig
+from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel, PretrainedConfig
 from transformers.generation.utils import GenerationConfig

 from transformers.activations import ACT2FN
@@ -54,7 +57,6 @@ from transformers.utils import (
 )
 from .configuration_codeshell import CodeShellConfig

-
 # Fused kernels
 # Use separate functions for each case because conditionals prevent kernel fusion.
 # TODO: Could have better fused kernels depending on scaling, dropout and head mask.
@@ -743,6 +745,62 @@ class CodeShellModel(CodeShellPreTrainedModel):
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
         )
+
+class EndOfFunctionCriteria(StoppingCriteria):
+    """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""
+    def __init__(self, input_lengths, eof_strings, tokenizer):
+        self.input_lengths = input_lengths
+        self.eof_strings = eof_strings
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs):
+        """Returns true if all generated sequences contain any of the end-of-function strings."""
+        decoded_generations = []
+        for _input_ids, input_length in zip(input_ids, self.input_lengths):
+            decoded_generations.append(self.tokenizer.decode(_input_ids[input_length:]))
+        done = []
+        for decoded_generation in decoded_generations:
+            done.append(
+                any(
+                    [
+                        stop_string in decoded_generation
+                        for stop_string in self.eof_strings
+                    ]
+                )
+            )
+        return all(done)
+
+class TextIterStreamer:
+    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
+        self.skip_special_tokens = skip_special_tokens
+        self.tokens = []
+        self.text_queue = Queue()
+        self.next_tokens_are_prompt = True
+
+    def put(self, value):
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+        else:
+            if len(value.shape) > 1:
+                value = value[0]
+            self.tokens.extend(value.tolist())
+            self.text_queue.put(
+                self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
+
+    def end(self):
+        self.text_queue.put(None)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        value = self.text_queue.get()
+        if value is None:
+            raise StopIteration()
+        else:
+            return value


 @add_start_docstrings(
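The two classes added in this hunk are self-contained helpers: EndOfFunctionCriteria stops generation once every sequence in the batch contains one of the configured stop strings, and TextIterStreamer is a queue-backed streamer that generate() feeds through put()/end() while a consumer simply iterates it. A rough sketch of that contract, assuming the two classes above are in scope; DummyTokenizer is a made-up stand-in, not part of this change:

    # Sketch only: DummyTokenizer is a placeholder that decodes token ids as character codes.
    import torch
    from threading import Thread

    class DummyTokenizer:
        def decode(self, ids, skip_special_tokens=False):
            return "".join(chr(int(i)) for i in ids)

    # generate() calls put() once per step and end() when finished; each iteration of the
    # streamer yields the full text decoded so far.
    streamer = TextIterStreamer(DummyTokenizer(), skip_prompt=True)

    def fake_generate():
        streamer.put(torch.tensor([[72, 105]]))        # prompt tokens, dropped (skip_prompt=True)
        for token_id in (72, 101, 108, 108, 111):      # spells "Hello", one token per step
            streamer.put(torch.tensor([token_id]))
        streamer.end()

    Thread(target=fake_generate).start()
    for partial_text in streamer:
        print(partial_text)                            # "H", "He", "Hel", "Hell", "Hello"

    # EndOfFunctionCriteria reports done once every continuation contains a stop string.
    criteria = EndOfFunctionCriteria(input_lengths=[0], eof_strings=["o"], tokenizer=DummyTokenizer())
    print(criteria(torch.tensor([[72, 101, 108, 108, 111]]), scores=None))  # True: "Hello" contains "o"
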
@@ -886,6 +944,65 @@ class CodeShellForCausalLM(CodeShellPreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
+
+
+    def build_chat_input(self, query, history, tokenizer, max_new_tokens=None):
+        user_name = "\n## human:"
+        ai_name = "\n## assistant: "
+        stop = '|<end>|'
+
+        prompt = ''
+        for q, r in history:
+            prompt += f"{user_name}{q}{stop}"
+            prompt += f"{ai_name}{r}{stop}"
+        prompt += f"{user_name}{query}{stop}"
+        prompt += ai_name.rstrip()
+
+        max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
+        max_input_tokens = self.config.n_positions - max_new_tokens
+
+        input_tokens = tokenizer.encode(prompt)
+        input_tokens = input_tokens[-max_input_tokens:]  # truncate left
+        return torch.LongTensor([input_tokens]).to(self.device)
+
+    def chat(self, query, history, tokenizer, stream=False,
+             generation_config: Optional[GenerationConfig] = None):
+        generation_config = generation_config or self.generation_config
+        input_ids = self.build_chat_input(query, history, tokenizer, generation_config.max_new_tokens)
+        stopping_criteria = StoppingCriteriaList(
+            [EndOfFunctionCriteria([len(input_ids[0])], ['|<end>|', '|end|', '<|endoftext|>'], tokenizer)]
+        )
+
+        if stream:
+            streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+            Thread(target=self.generate, kwargs=dict(
+                inputs=input_ids, streamer=streamer,
+                stopping_criteria=stopping_criteria,
+                generation_config=generation_config,
+            )).start()
+            return streamer
+        else:
+            outputs = self.generate(input_ids, generation_config=generation_config, stopping_criteria=stopping_criteria)
+            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
+            return response
+
+    def generate_stream(self, prompt, tokenizer, generation_config=None, **kwargs):
+        generation_config = generation_config or self.generation_config
+        max_input_tokens = self.config.n_positions - self.generation_config.max_new_tokens
+
+        input_ids = tokenizer.encode(prompt)
+        input_ids = input_ids[-max_input_tokens:]  # truncate left
+
+        stopping_criteria = StoppingCriteriaList(
+            [EndOfFunctionCriteria([len(input_ids[0])], ['|end|', '|<end>|', '<|endoftext|>'], tokenizer)]
+        )
+
+        streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+        Thread(target=self.generate, kwargs=dict(
+            inputs=input_ids, stopping_criteria=stopping_criteria, **kwargs
+        )).start()
+        return streamer
+

 class CodeShell4bitForCausalLM(CodeShellForCausalLM):
     def __init__(self, config):
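Together these give CodeShellForCausalLM a small chat API on top of generate(): build_chat_input renders the (query, response) history into the "## human:" / "## assistant:" prompt and truncates it from the left to fit n_positions, chat() runs a blocking or streaming generation with the |<end>|-style stop strings, and generate_stream() does the same for a raw prompt. A hedged usage sketch; the checkpoint id is a placeholder assumption (not part of this diff), and AutoModelForCausalLM with trust_remote_code=True is assumed to resolve to this class:

    # Sketch under assumptions: "WisdomShell/CodeShell-7B-Chat" is a placeholder checkpoint id.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("WisdomShell/CodeShell-7B-Chat", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained("WisdomShell/CodeShell-7B-Chat", trust_remote_code=True)
    model.generation_config.max_new_tokens = 256  # build_chat_input needs this to budget the prompt length

    history = []  # list of (query, response) pairs, as consumed by build_chat_input

    # Blocking call: returns the decoded response with special tokens stripped.
    response = model.chat("Write a function that reverses a string.", history, tokenizer)
    print(response)

    # Streaming call: generation runs in a background thread; the returned TextIterStreamer
    # yields the full text decoded so far on every step.
    history.append(("Write a function that reverses a string.", response))
    for partial in model.chat("Now add type hints.", history, tokenizer, stream=True):
        print(partial)
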
@@ -966,5 +1083,4 @@ class CodeShell4bitForCausalLM(CodeShellForCausalLM):
         if device_map is not None:
             model = model.to(torch.device(device_map))

-        return model
-
+        return model