rd211 committed
Commit aa215e0 · verified · 1 Parent(s): b50e055

Upload modeling_internlm2.py with huggingface_hub

Files changed (1): modeling_internlm2.py (+236 -9)
modeling_internlm2.py CHANGED
@@ -59,10 +59,6 @@ try:
 except:
     pass
 
-try:
-    support_bf16_triu = torch.__version__ >= "2.1.0"
-except Exception:
-    support_bf16_triu = False
 
 logger = logging.get_logger(__name__)
 
@@ -1097,11 +1093,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
         else:
             causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
             if sequence_length != 1:
-                if support_bf16_triu or dtype == torch.float32:
-                    causal_mask = torch.triu(causal_mask, diagonal=1)
-                else:
-                    triu_mask = torch.triu(torch.ones(causal_mask.size(), device=device), diagonal=1).bool()
-                    causal_mask.masked_fill_(~triu_mask, 0)
+                causal_mask = torch.triu(causal_mask, diagonal=1)
             causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
             if attention_mask is not None:
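
For readers comparing the two hunks above: the removed branch existed only because older PyTorch builds could not apply torch.triu to bfloat16 tensors, so a boolean mask plus masked_fill_ was used as a fallback. Below is a minimal, self-contained sketch (not part of the commit; hypothetical sizes, float32 for portability) of what the simplified construction computes, including a check that the deleted fallback produced the same mask:

    import torch

    dtype = torch.float32
    min_dtype = torch.finfo(dtype).min
    sequence_length, target_length = 4, 6                   # hypothetical sizes
    cache_position = torch.arange(2, 2 + sequence_length)   # absolute positions of the new tokens

    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    new_path = torch.triu(causal_mask.clone(), diagonal=1)  # the single line kept by this commit

    # the deleted fallback built the same strictly-upper-triangular mask
    triu_mask = torch.triu(torch.ones(causal_mask.size()), diagonal=1).bool()
    old_path = causal_mask.clone().masked_fill_(~triu_mask, 0)
    assert torch.equal(new_path, old_path)

    # the unchanged next line keeps min_dtype only at columns beyond each query's cache position
    causal_mask = new_path * (torch.arange(target_length) > cache_position.reshape(-1, 1))
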
@@ -1806,3 +1798,238 @@ class InternLM2ForTokenClassification(InternLM2PreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+# Modified from transformers.models.llama.modeling_llama.LlamaForTokenClassification
+class InternLM2ForRewardModel(InternLM2PreTrainedModel):
+
+    _auto_class = "AutoModel"
+    _tied_weights_keys = ["v_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = InternLM2Model(config)
+        self.vocab_size = config.vocab_size
+        self.v_head = nn.Linear(config.hidden_size, 1, bias=False)
+        self.reward_token_id = config.reward_token_id
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.tok_embeddings
+
+    def set_input_embeddings(self, value):
+        self.model.tok_embeddings = value
+
+    def get_output_embeddings(self):
+        return self.v_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.v_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        """
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss.
+
+        Returns:
+
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        hidden_states = self.v_head(hidden_states)
+        # get end reward token's score
+        ends = attention_mask.cumsum(dim=1).argmax(dim=1).view(-1,1)
+
+        reward_scores = torch.gather(hidden_states.squeeze(-1), 1, ends)
+
+        loss = None
+
+        if not return_dict:
+            output = (reward_scores,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=reward_scores,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
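# Editorial note (not part of the commit): in `forward` above, the reward score is read at the
# last attended position. `attention_mask.cumsum(dim=1).argmax(dim=1)` returns, per row, the index
# of the last 1 in a right-padded mask, which is where `get_score`/`get_scores` append the reward
# token; `torch.gather` then selects the v_head output at that index. Tiny hypothetical check:
#     mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
#     mask.cumsum(dim=1)                  # tensor([[1, 2, 3, 4, 4, 4]])
#     mask.cumsum(dim=1).argmax(dim=1)    # tensor([3])  -> index of the appended reward token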
+    @torch.no_grad()
+    def get_score(
+        self,
+        tokenizer,
+        conversation: List[dict],
+        **kwargs,
+    ):
+        """
+        Computes the reward score for a given conversation.
+
+        This function takes a conversation represented as a list of dictionaries, formats it into a string using the chat
+        template from the tokenizer, and passes it through the model to compute the score. A special token representing
+        the reward score is appended to the input sequence. The reward score is then extracted from the model's output.
+
+        Args:
+            tokenizer: The tokenizer to be used for formatting and tokenizing the conversation.
+            conversation (List[dict]): A list of dictionaries where each dictionary represents a message in the conversation.
+
+        Returns:
+            float: The computed reward score from the model.
+        """
+        conversation_str = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
+        input_ids = tokenizer.encode(conversation_str, return_tensors="pt", add_special_tokens=False)
+        # add reward score token at the end of the input_ids if it is not already there
+        if input_ids[0, -1] != self.reward_token_id:
+            input_ids = torch.cat([input_ids, torch.tensor([[self.reward_token_id]], dtype=torch.long)], dim=1)
+        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+
+        outputs = self.forward(input_ids=input_ids.to(self.device), attention_mask=attention_mask.to(self.device), **kwargs)
+        score = outputs[0].cpu().item()
+        return score
+
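# Editorial note (not part of the commit): `get_score` appends the reward token only when the chat
# template has not already placed it last, so it is never duplicated. With a hypothetical
# reward_token_id of 5:
#     ids = torch.tensor([[11, 22, 33]])
#     if ids[0, -1] != 5:
#         ids = torch.cat([ids, torch.tensor([[5]], dtype=torch.long)], dim=1)
#     # ids -> tensor([[11, 22, 33, 5]])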
+    @torch.no_grad()
+    def get_scores(
+        self,
+        tokenizer,
+        conversations: List[List[dict]],
+        **kwargs,
+    ):
+        """
+        Computes the reward scores for multiple conversations in a batched manner.
+
+        This function takes multiple conversations, each represented as a list of dictionaries, formats them into strings using the chat
+        template from the tokenizer, and passes these formatted strings through the model to compute scores for each conversation.
+        Each input sequence has a special token representing the reward score appended before passing to the model.
+        The reward scores are then extracted from the model's output.
+
+        Args:
+            tokenizer: The tokenizer to be used for formatting and tokenizing the conversation.
+            conversations (List[List[dict]]): A list of conversations, with each conversation represented as a list of dictionaries where each dictionary contains a message.
+
+        Returns:
+            List[float]: A list of computed reward scores for each conversation in the input batch.
+        """
+        conversation_strs = [tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False) for conversation in conversations]
+        batch_input_ids = []
+        attention_masks = []
+
+        for conversation_str in conversation_strs:
+            input_ids = tokenizer.encode(conversation_str, return_tensors="pt", add_special_tokens=False)
+            # add reward score token at the end of the input_ids if it is not already there
+            if input_ids[0, -1] != self.reward_token_id:
+                input_ids = torch.cat([input_ids, torch.tensor([[self.reward_token_id]], dtype=torch.long)], dim=1)
+            input_ids = input_ids.squeeze(0)
+            attention_mask = torch.ones(input_ids.shape, dtype=torch.bool)
+            batch_input_ids.append(input_ids)
+            attention_masks.append(attention_mask)
+
+        r_pad_batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+        r_pad_attention_masks = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=False)
+
+        outputs = self.forward(input_ids=r_pad_batch_input_ids.to(self.device), attention_mask=r_pad_attention_masks.to(self.device), **kwargs)
+        scores = outputs[0].squeeze().cpu().tolist()
+        return scores
+
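# Editorial note (not part of the commit): `get_scores` right-pads the batch with `pad_sequence`,
# so shorter conversations receive `pad_token_id` ids and `False` attention entries on the right.
# Because `forward` gathers the score at the last True mask position of each row, every row is
# still read at its own reward token rather than at the padded end of the batch, e.g. for
# lengths 4 and 2 the masks become [[1, 1, 1, 1], [1, 1, 0, 0]] and the scores are read at
# indices 3 and 1 respectively.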
+    @torch.no_grad()
+    def compare(
+        self,
+        tokenizer,
+        conversation1: List[dict],
+        conversation2: List[dict],
+        return_logits: bool = False,
+        **kwargs,
+    ):
+        """
+        Compares the reward scores of two conversations and determines which conversation has a higher score.
+
+        This function computes reward scores for two given conversations using the `get_score` method and compares the scores to determine which conversation has a higher score.
+        The function can optionally return the actual scores (logits) along with the comparison result.
+
+        Parameters:
+            tokenizer: The tokenizer used for formatting and tokenizing the conversation.
+            conversation1 (List[dict]): The first conversation to compare, represented as a list of dictionaries where each dictionary contains a message.
+            conversation2 (List[dict]): The second conversation to compare, similarly represented.
+            return_logits (bool, optional): If True, the function returns both the comparison result and the actual scores of the two conversations. Defaults to False.
+
+        Returns:
+
+            bool: True if the score of the first conversation is greater than the second, otherwise False.
+            List[float] (optional): A list containing the scores of the first and second conversations respectively.
+
+        Note:
+            - This function is designed for inference, with `@torch.no_grad()` used to disable gradient calculations to optimize performance.
+        """
+        score1 = self.get_score(tokenizer, conversation1, **kwargs)
+        score2 = self.get_score(tokenizer, conversation2, **kwargs)
+        if return_logits:
+            return score1 > score2, [score1, score2]
+        else:
+            return score1 > score2
+
+    @torch.no_grad()
+    def rank(
+        self,
+        tokenizer,
+        conversations: List[List[dict]],
+        return_logits: bool = False,
+        **kwargs,
+    ):
+        """
+        Ranks the conversations based on their scores.
+
+        Args:
+            tokenizer: The tokenizer to be used for formatting and tokenizing the conversation.
+            conversations: A list of conversations, where each conversation is represented as a list of dictionaries. Each dictionary contains the necessary information for the conversation.
+            return_logits: If True, returns the conversation indices along with their logits. Defaults to False.
+
+        Returns:
+            list: A list of conversation rank indices based on their scores. Smaller index means higher score.
+            List[float] (optional): If return_logits is True, a list of conversation indices and their corresponding logits.
+
+        """
+        scores = self.get_scores(tokenizer, conversations, **kwargs)
+        if return_logits:
+            return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True), scores
+        else:
+            return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
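
A hedged usage sketch of the class this commit adds. The checkpoint name below is a placeholder (any repository that ships this modeling file and sets `reward_token_id` in its config), and `trust_remote_code=True` is assumed because the class registers itself through `_auto_class = "AutoModel"`:

    import torch
    from transformers import AutoModel, AutoTokenizer

    model_path = "your-org/your-internlm2-reward-model"  # placeholder, substitute a real checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()

    chat_1 = [
        {"role": "user", "content": "Explain what a reward model does."},
        {"role": "assistant", "content": "It scores candidate responses so better ones can be preferred."},
    ]
    chat_2 = [
        {"role": "user", "content": "Explain what a reward model does."},
        {"role": "assistant", "content": "No idea."},
    ]

    score = model.get_score(tokenizer, chat_1)              # float for a single conversation
    scores = model.get_scores(tokenizer, [chat_1, chat_2])  # batched scores, List[float]
    first_wins = model.compare(tokenizer, chat_1, chat_2)   # True if chat_1 scores higher
    order = model.rank(tokenizer, [chat_1, chat_2])         # indices sorted by descending score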