Upload modeling_internlm2.py with huggingface_hub

modeling_internlm2.py  CHANGED  (+236 -9)
@@ -59,10 +59,6 @@ try:
 except:
     pass
 
-try:
-    support_bf16_triu = torch.__version__ >= "2.1.0"
-except Exception:
-    support_bf16_triu = False
 
 logger = logging.get_logger(__name__)
 
@@ -1097,11 +1093,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
         else:
             causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
             if sequence_length != 1:
-                if support_bf16_triu or dtype == torch.float32:
-                    causal_mask = torch.triu(causal_mask, diagonal=1)
-                else:
-                    triu_mask = torch.triu(torch.ones(causal_mask.size(), device=device), diagonal=1).bool()
-                    causal_mask.masked_fill_(~triu_mask, 0)
+                causal_mask = torch.triu(causal_mask, diagonal=1)
             causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
             if attention_mask is not None:
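The fallback deleted in the two hunks above existed because bfloat16 `torch.triu` was not available on older PyTorch builds; the deleted `support_bf16_triu` flag treated torch >= 2.1.0 as safe. A minimal sketch of the equivalence the deleted branch provided, assuming a PyTorch build where bfloat16 `triu` works:

import torch

# Assumption: a PyTorch build where triu accepts bfloat16 tensors
# (the deleted flag treated torch >= 2.1.0 as such).
mask = torch.full((4, 4), torch.finfo(torch.bfloat16).min, dtype=torch.bfloat16)

direct = torch.triu(mask, diagonal=1)

# The deleted fallback path: build a boolean strictly-upper-triangular
# mask and zero out everything on or below the diagonal in place.
triu_mask = torch.triu(torch.ones(mask.size()), diagonal=1).bool()
fallback = mask.clone()
fallback.masked_fill_(~triu_mask, 0)

assert torch.equal(direct, fallback)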
@@ -1806,3 +1798,238 @@ class InternLM2ForTokenClassification(InternLM2PreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+# Modified from transformers.models.llama.modeling_llama.LlamaForTokenClassification
+class InternLM2ForRewardModel(InternLM2PreTrainedModel):
+
+    _auto_class = "AutoModel"
+    _tied_weights_keys = ["v_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = InternLM2Model(config)
+        self.vocab_size = config.vocab_size
+        self.v_head = nn.Linear(config.hidden_size, 1, bias=False)
+        self.reward_token_id = config.reward_token_id
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.tok_embeddings
+
+    def set_input_embeddings(self, value):
+        self.model.tok_embeddings = value
+
+    def get_output_embeddings(self):
+        return self.v_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.v_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        """
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss.
+
+        Returns:
+
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        hidden_states = self.v_head(hidden_states)
+        # get end reward token's score
+        ends = attention_mask.cumsum(dim=1).argmax(dim=1).view(-1, 1)
+
+        reward_scores = torch.gather(hidden_states.squeeze(-1), 1, ends)
+
+        loss = None
+
+        if not return_dict:
+            output = (reward_scores,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=reward_scores,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
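The score extraction in forward above rests on one indexing trick: attention_mask.cumsum(dim=1).argmax(dim=1) returns, per row, the position of the last attended token (the appended reward token), and torch.gather then reads the v_head output at exactly that position. A standalone illustration with made-up values:

import torch

# Hypothetical right-padded batch: row 0 attends to 3 tokens, row 1 to 5.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
# The cumulative sum peaks at the last attended position, and argmax
# returns the first index where that maximum occurs.
ends = attention_mask.cumsum(dim=1).argmax(dim=1).view(-1, 1)  # tensor([[2], [4]])

# Stand-in for v_head(hidden_states).squeeze(-1): one scalar per token.
per_token = torch.tensor([[0.1, 0.2, 0.3, 0.0, 0.0],
                          [0.5, 0.4, 0.3, 0.2, 0.9]])
reward_scores = torch.gather(per_token, 1, ends)  # tensor([[0.3000], [0.9000]])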
+    @torch.no_grad()
+    def get_score(
+        self,
+        tokenizer,
+        conversation: List[dict],
+        **kwargs,
+    ):
+        """
+        Computes the reward score for a given conversation.
+
+        This function takes a conversation represented as a list of dictionaries, formats it into a string using the chat
+        template from the tokenizer, and passes it through the model to compute the score. A special token representing
+        the reward score is appended to the input sequence. The reward score is then extracted from the model's output.
+
+        Args:
+            tokenizer: The tokenizer to be used for formatting and tokenizing the conversation.
+            conversation (List[dict]): A list of dictionaries where each dictionary represents a message in the conversation.
+
+        Returns:
+            float: The computed reward score from the model.
+        """
+        conversation_str = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
+        input_ids = tokenizer.encode(conversation_str, return_tensors="pt", add_special_tokens=False)
+        # add reward score token at the end of the input_ids if it is not already there
+        if input_ids[0, -1] != self.reward_token_id:
+            input_ids = torch.cat([input_ids, torch.tensor([[self.reward_token_id]], dtype=torch.long)], dim=1)
+        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+
+        outputs = self.forward(input_ids=input_ids.to(self.device), attention_mask=attention_mask.to(self.device), **kwargs)
+        score = outputs[0].cpu().item()
+        return score
+
+    @torch.no_grad()
+    def get_scores(
+        self,
+        tokenizer,
+        conversations: List[List[dict]],
+        **kwargs,
+    ):
+        """
+        Computes the reward scores for multiple conversations in a batched manner.
+
+        This function takes multiple conversations, each represented as a list of dictionaries, formats them into strings using the chat
+        template from the tokenizer, and passes these formatted strings through the model to compute scores for each conversation.
+        Each input sequence has a special token representing the reward score appended before passing to the model.
+        The reward scores are then extracted from the model's output.
+
+        Args:
+            tokenizer: The tokenizer to be used for formatting and tokenizing the conversation.
+            conversations (List[List[dict]]): A list of conversations, with each conversation represented as a list of dictionaries where each dictionary contains a message.
+
+        Returns:
+            List[float]: A list of computed reward scores for each conversation in the input batch.
+        """
+        conversation_strs = [tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False) for conversation in conversations]
+        batch_input_ids = []
+        attention_masks = []
+
+        for conversation_str in conversation_strs:
+            input_ids = tokenizer.encode(conversation_str, return_tensors="pt", add_special_tokens=False)
+            # add reward score token at the end of the input_ids if it is not already there
+            if input_ids[0, -1] != self.reward_token_id:
+                input_ids = torch.cat([input_ids, torch.tensor([[self.reward_token_id]], dtype=torch.long)], dim=1)
+            input_ids = input_ids.squeeze(0)
+            attention_mask = torch.ones(input_ids.shape, dtype=torch.bool)
+            batch_input_ids.append(input_ids)
+            attention_masks.append(attention_mask)
+
+        r_pad_batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+        r_pad_attention_masks = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=False)
+
+        outputs = self.forward(input_ids=r_pad_batch_input_ids.to(self.device), attention_mask=r_pad_attention_masks.to(self.device), **kwargs)
+        scores = outputs[0].squeeze().cpu().tolist()
+        return scores
+
+    @torch.no_grad()
+    def compare(
+        self,
+        tokenizer,
+        conversation1: List[dict],
+        conversation2: List[dict],
+        return_logits: bool = False,
+        **kwargs,
+    ):
+        """
+        Compares the reward scores of two conversations and determines which conversation has a higher score.
+
+        This function computes reward scores for two given conversations using the `get_score` method and compares the scores to determine which conversation has a higher score.
+        The function can optionally return the actual scores (logits) along with the comparison result.
+
+        Parameters:
+            tokenizer: The tokenizer used for formatting and tokenizing the conversation.
+            conversation1 (List[dict]): The first conversation to compare, represented as a list of dictionaries where each dictionary contains a message.
+            conversation2 (List[dict]): The second conversation to compare, similarly represented.
+            return_logits (bool, optional): If True, the function returns both the comparison result and the actual scores of the two conversations. Defaults to False.
+
+        Returns:
+
+            bool: True if the score of the first conversation is greater than the second, otherwise False.
+            List[float] (optional): A list containing the scores of the first and second conversations respectively.
+
+        Note:
+            - This function is designed for inference, with `@torch.no_grad()` used to disable gradient calculations to optimize performance.
+        """
+        score1 = self.get_score(tokenizer, conversation1, **kwargs)
+        score2 = self.get_score(tokenizer, conversation2, **kwargs)
+        if return_logits:
+            return score1 > score2, [score1, score2]
+        else:
+            return score1 > score2
+
+    @torch.no_grad()
+    def rank(
+        self,
+        tokenizer,
+        conversations: List[List[dict]],
+        return_logits: bool = False,
+        **kwargs,
+    ):
+        """
+        Ranks the conversations based on their scores.
+
+        Args:
+            tokenizer: The tokenizer to be used for formatting and tokenizing the conversation.
+            conversations: A list of conversations, where each conversation is represented as a list of dictionaries. Each dictionary contains the necessary information for the conversation.
+            return_logits: If True, returns the conversation indices along with their logits. Defaults to False.
+
+        Returns:
+            list: A list of conversation rank indices based on their scores. Smaller index means higher score.
+            List[float] (optional): If return_logits is True, a list of conversation indices and their corresponding logits.
+
+        """
+        scores = self.get_scores(tokenizer, conversations, **kwargs)
+        if return_logits:
+            return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True), scores
+        else:
+            return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
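For context, a usage sketch for the new class, assuming a checkpoint that ships this modeling file along with a chat-template tokenizer and a reward_token_id in its config; the repository id below is illustrative:

import torch
from transformers import AutoModel, AutoTokenizer

model_id = "internlm/internlm2-7b-reward"  # illustrative repo id

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# _auto_class = "AutoModel" above is what lets AutoModel resolve to
# InternLM2ForRewardModel when trust_remote_code=True.
model = AutoModel.from_pretrained(model_id, torch_dtype=torch.float16, trust_remote_code=True).eval()

good = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "2 + 2 = 4."},
]
bad = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "No idea."},
]

print(model.get_score(tokenizer, good))            # float
print(model.get_scores(tokenizer, [good, bad]))    # [float, float]
print(model.compare(tokenizer, good, bad))         # True if `good` scores higher
print(model.rank(tokenizer, [good, bad]))          # e.g. [0, 1]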