farzadab committed on
Commit
5e75763
·
verified ·
1 Parent(s): e3c2ab5

Upload 5 files

Browse files
Files changed (1) hide show
  1. ultravox_config.py +13 -6
ultravox_config.py CHANGED
@@ -32,6 +32,8 @@ class LossFunction(str, Enum):
32
  class LossConfig:
33
  loss_function: LossFunction = LossFunction.CrossEntropy
34
  kl_temperature: float = 2.0
 
 
35
 
36
  @property
37
  def requires_alt_fields(self):
@@ -47,7 +49,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
47
  documentation from [`PretrainedConfig`] for more information.
48
 
49
  Args:
50
- audio_config (`Wav2Vec2Config`, *optional*):
51
  Custom audio config or dict
52
  text_config (`Union[AutoConfig, dict]`, *optional*):
53
  The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
@@ -72,10 +74,10 @@ class UltravoxConfig(transformers.PretrainedConfig):
72
  Example:
73
 
74
  ```python
75
- >>> from transformers import UltravoxForConditionalGeneration, Wav2Vec2Config, UltravoxConfig, LlamaConfig
76
 
77
  >>> # Initializing an audio encoder config
78
- >>> audio_config = Wav2Vec2Config()
79
 
80
  >>> # Initializing a Llama config
81
  >>> text_config = LlamaConfig()
@@ -84,13 +86,13 @@ class UltravoxConfig(transformers.PretrainedConfig):
84
  >>> configuration = UltravoxConfig(audio_config, text_config)
85
 
86
  >>> # Initializing a completely untrained model from the configuration
87
- >>> model = UltravoxForConditionalGeneration(configuration)
88
 
89
  >>> # Accessing the model configuration
90
  >>> configuration = model.config
91
 
92
  >>> # Initialize a model from pretrained checkpoints and random projector weights
93
- >>> config = UltravoxConfig(audio_model_id="facebook/wav2vec2-base-960h", text_model_id="meta-llama/Llama-2-7b-chat-hf")
94
  ```"""
95
 
96
  model_type = "ultravox"
@@ -140,7 +142,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
140
  else:
141
  audio_config = audio_config or {}
142
  self.audio_config = transformers.CONFIG_MAPPING[
143
- audio_config.get("model_type", "wav2vec2")
144
  ](**audio_config)
145
 
146
  self.text_model_lora_config = (
@@ -167,7 +169,12 @@ class UltravoxConfig(transformers.PretrainedConfig):
167
  # remove text_config and audio_config if text_model_id and audio_model_id are present
168
  if self.text_model_id is not None:
169
  diff_dict.pop("text_config", None)
 
 
 
170
  if self.audio_model_id is not None:
171
  diff_dict.pop("audio_config", None)
 
 
172
 
173
  return diff_dict
 
32
  class LossConfig:
33
  loss_function: LossFunction = LossFunction.CrossEntropy
34
  kl_temperature: float = 2.0
35
+ # Number of tokens to ignore from the beginning of the sequence. Only used in LSM
36
+ initial_tokens_to_ignore: int = 0
37
 
38
  @property
39
  def requires_alt_fields(self):
 
49
  documentation from [`PretrainedConfig`] for more information.
50
 
51
  Args:
52
+ audio_config (`WhisperConfig`, *optional*):
53
  Custom audio config or dict
54
  text_config (`Union[AutoConfig, dict]`, *optional*):
55
  The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
 
74
  Example:
75
 
76
  ```python
77
+ >>> from transformers import UltravoxModel, WhisperConfig, UltravoxConfig, LlamaConfig
78
 
79
  >>> # Initializing an audio encoder config
80
+ >>> audio_config = WhisperConfig()
81
 
82
  >>> # Initializing a Llama config
83
  >>> text_config = LlamaConfig()
 
86
  >>> configuration = UltravoxConfig(audio_config, text_config)
87
 
88
  >>> # Initializing a completely untrained model from the configuration
89
+ >>> model = UltravoxModel(configuration)
90
 
91
  >>> # Accessing the model configuration
92
  >>> configuration = model.config
93
 
94
  >>> # Initialize a model from pretrained checkpoints and random projector weights
95
+ >>> config = UltravoxConfig(audio_model_id="openai/whisper-tiny", text_model_id="meta-llama/Llama-2-7b-chat-hf")
96
  ```"""
97
 
98
  model_type = "ultravox"
 
142
  else:
143
  audio_config = audio_config or {}
144
  self.audio_config = transformers.CONFIG_MAPPING[
145
+ audio_config.get("model_type", "whisper")
146
  ](**audio_config)
147
 
148
  self.text_model_lora_config = (
 
169
  # remove text_config and audio_config if text_model_id and audio_model_id are present
170
  if self.text_model_id is not None:
171
  diff_dict.pop("text_config", None)
172
+ elif "text_config" in diff_dict:
173
+ diff_dict["text_config"].pop("_attn_implementation_autoset", None)
174
+
175
  if self.audio_model_id is not None:
176
  diff_dict.pop("audio_config", None)
177
+ elif "audio_config" in diff_dict:
178
+ diff_dict["audio_config"].pop("_attn_implementation_autoset", None)
179
 
180
  return diff_dict