farzadab committed on
Commit
d0ed04b
·
verified ·
1 Parent(s): 0fdbb39

Upload UltravoxPipeline

Browse files
Files changed (2) hide show
  1. config.json +18 -0
  2. ultravox_pipeline.py +111 -0
config.json CHANGED
@@ -25,6 +25,24 @@
25
  ]
26
  },
27
  "audio_token_index": 32000,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  "hidden_size": 4096,
29
  "ignore_index": -100,
30
  "initializer_range": 0.02,
 
25
  ]
26
  },
27
  "audio_token_index": 32000,
28
+ "custom_pipelines": {
29
+ "ultravox-pipeline": {
30
+ "default": {
31
+ "model": {
32
+ "pt": [
33
+ "fixie-ai/ultravox-v0.2",
34
+ "main"
35
+ ]
36
+ }
37
+ },
38
+ "impl": "ultravox_pipeline.UltravoxPipeline",
39
+ "pt": [
40
+ "UltravoxModel"
41
+ ],
42
+ "tf": [],
43
+ "type": "multimodal"
44
+ }
45
+ },
46
  "hidden_size": 4096,
47
  "ignore_index": -100,
48
  "initializer_range": 0.02,
ultravox_pipeline.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import transformers
5
+
6
+ # We must use relative import in this directory to allow uploading to HF Hub
7
+ from . import ultravox_model
8
+ from . import ultravox_processing
9
+
10
+
11
class UltravoxPipeline(transformers.Pipeline):
    """Multimodal (audio + text) generation pipeline for Ultravox models.

    Wraps an `UltravoxModel` together with a text tokenizer and an audio
    feature extractor (combined into an `UltravoxProcessor`), and produces a
    text completion for an audio clip plus an optional prompt / chat turns.
    """

    def __init__(
        self,
        model: ultravox_model.UltravoxModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs
    ):
        """Build the pipeline.

        Args:
            model: the Ultravox model to run generation with.
            tokenizer: text tokenizer; loaded from the model's repo if omitted.
            audio_processor: audio feature extractor; loaded from
                `model.config.audio_model_id` if omitted.
            **kwargs: forwarded to `transformers.Pipeline.__init__`.
        """
        if tokenizer is None:
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model.config._name_or_path
            )

        if audio_processor is None:
            audio_processor = transformers.Wav2Vec2Processor.from_pretrained(
                model.config.audio_model_id
            )

        # The processor fuses text tokenization and audio feature extraction;
        # stack_factor must match the model so audio embeddings line up.
        self.processor = ultravox_processing.UltravoxProcessor(
            audio_processor, tokenizer=tokenizer, stack_factor=model.config.stack_factor
        )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Route user kwargs to the pipeline stages.

        Only generation-related kwargs are recognized; they are routed to
        `_forward`. Preprocess and postprocess take no parameters.
        """
        generation_kwargs = {}
        if "temperature" in kwargs:
            generation_kwargs["temperature"] = kwargs["temperature"]
        if "max_new_tokens" in kwargs:
            generation_kwargs["max_new_tokens"] = kwargs["max_new_tokens"]
        if "repetition_penalty" in kwargs:
            generation_kwargs["repetition_penalty"] = kwargs["repetition_penalty"]
        return {}, generation_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
        """Convert raw inputs into model tensors.

        Args:
            inputs: dict with required key "audio" and either "turns" (chat
                messages) or "prompt" (plain text, where "<|audio|>" marks
                the audio position). Optional "sampling_rate" (defaults to
                16 kHz with a warning).

        Raises:
            ValueError: if no "audio" key is provided.
        """
        if "turns" in inputs:
            turns = inputs["turns"]
        else:
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )
                prompt += " <|audio|>"
            turns = [{"role": "user", "content": prompt}]

        text = self.processor.tokenizer.apply_chat_template(turns, tokenize=False)

        # TODO: allow text-only mode?
        # Use an explicit exception rather than `assert`: asserts are stripped
        # under `python -O`, which would silently skip this validation.
        if "audio" not in inputs:
            raise ValueError("Audio input is required")

        if "sampling_rate" not in inputs:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        return self.processor(
            text=text,
            audio=inputs["audio"],
            sampling_rate=inputs.get("sampling_rate", 16000),
        )

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> List[int]:
        """Run generation and return only the newly generated token ids."""
        # Treat temperature=0 the same as "not set": greedy decoding.
        temperature = temperature or None
        do_sample = temperature is not None

        # Stop on EOS, and also on '<|eot_id|>' when the tokenizer defines it
        # (Llama-3-style chat models use it as the end-of-turn marker).
        terminators = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        input_len = model_inputs["input_ids"].shape[1]

        outputs = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=terminators
        )
        # Strip the prompt tokens; return only the completion.
        return outputs[0][input_len:]

    def postprocess(self, model_outputs) -> str:
        """Decode generated token ids into the final text answer."""
        output_text = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
        return output_text
102
+
103
+
104
# NOTE(review): this bare attribute access looks like a no-op, but it
# presumably forces transformers' lazy module loader to resolve the
# `pipelines` submodule before PIPELINE_REGISTRY is referenced below —
# confirm before removing.
transformers.pipeline
# Register this class as a loadable custom pipeline under the
# "ultravox-pipeline" task name, with a default checkpoint so
# `pipeline("ultravox-pipeline")` works without an explicit model.
transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    "ultravox-pipeline",  # TODO: make it broader later on
    pipeline_class=UltravoxPipeline,
    pt_model=ultravox_model.UltravoxModel,
    default={"pt": ("fixie-ai/ultravox-v0.2", "main")},
    type="multimodal",
)
+ )