Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import torch | |
| import torchaudio | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from tokenizer import StepAudioTokenizer | |
| from tts import StepAudioTTS | |
| from utils import load_audio, speech_adjust, volumn_adjust | |
class StepAudio:
    """End-to-end speech chat pipeline: audio tokenizer -> LLM -> TTS decoder.

    The instance is callable: it takes a chat history (text and/or audio
    messages), lets the LLM generate a textual reply and synthesizes that
    reply with the TTS decoder.
    """

    # (minimum torch version, shared library relative to ``llm_path``),
    # ordered newest first so the first match wins.
    _OPTIMUS_THS_LIBS = (
        ("2.5", "lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so"),
        ("2.3", "lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so"),
        ("2.2", "lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so"),
    )

    def __init__(self, tokenizer_path: str, tts_path: str, llm_path: str):
        """Load the LLM tokenizer, audio tokenizer, TTS decoder and LLM.

        Args:
            tokenizer_path: Path handed to ``StepAudioTokenizer``.
            tts_path: Path handed to ``StepAudioTTS``.
            llm_path: Path of the causal LM checkpoint; also the directory
                that holds the optional ``lib/liboptimus_ths-*.so`` files.
        """
        self._load_optimus_ths(llm_path)

        self.llm_tokenizer = AutoTokenizer.from_pretrained(
            llm_path, trust_remote_code=True
        )
        self.encoder = StepAudioTokenizer(tokenizer_path)
        self.decoder = StepAudioTTS(tts_path, self.encoder)
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )

    @classmethod
    def _load_optimus_ths(cls, llm_path: str) -> None:
        """Best-effort load of the optimus_ths (flash attention) library.

        Never raises: any failure is reported on stdout and the pipeline
        continues without flash attention.
        """
        # load optimus_ths for flash attention, make sure LD_LIBRARY_PATH has `nvidia/cuda_nvrtc/lib`
        # if not, please manually set LD_LIBRARY_PATH=xxx/python3.10/site-packages/nvidia/cuda_nvrtc/lib
        try:
            for min_version, lib_relpath in cls._OPTIMUS_THS_LIBS:
                if torch.__version__ >= min_version:
                    torch.ops.load_library(os.path.join(llm_path, lib_relpath))
                    break
            else:
                # No prebuilt library matches this torch version; do not
                # report success when nothing was actually loaded.
                raise RuntimeError(
                    f"no optimus_ths build for torch {torch.__version__}"
                )
            print("Load optimus_ths successfully and flash attn would be enabled")
        except Exception as err:
            print(f"Fail to load optimus_ths and flash attn is disabled: {err}")

    def __call__(
        self,
        messages: list,
        speaker_id: str,
        speed_ratio: float = 1.0,
        volumn_ratio: float = 1.0,
    ):
        """Generate a spoken reply for a chat history.

        Args:
            messages: Chat messages, see :meth:`apply_chat_template`.
            speaker_id: Speaker identifier forwarded to the TTS decoder.
            speed_ratio: When not ``1.0`` the audio is passed through
                ``speech_adjust`` with this ratio.
            volumn_ratio: When not ``1.0`` the audio is passed through
                ``volumn_adjust`` with this ratio.

        Returns:
            Tuple ``(output_text, output_audio, sr)`` where ``output_audio``
            and ``sr`` are what the TTS decoder returned (after the optional
            adjustments).
        """
        text_with_audio = self.apply_chat_template(messages)
        token_ids = self.llm_tokenizer.encode(text_with_audio, return_tensors="pt")
        # The model is loaded with device_map="auto" and may live on an
        # accelerator, while the tokenizer returns CPU tensors.
        token_ids = token_ids.to(self.llm.device)
        outputs = self.llm.generate(
            token_ids, max_new_tokens=2048, temperature=0.7, top_p=0.9, do_sample=True
        )
        # Strip the prompt tokens and the final generated token
        # (presumably the end-of-turn token -- verify against the model).
        output_token_ids = outputs[:, token_ids.shape[-1] : -1].tolist()[0]
        output_text = self.llm_tokenizer.decode(output_token_ids)
        output_audio, sr = self.decoder(output_text, speaker_id)
        if speed_ratio != 1.0:
            output_audio = speech_adjust(output_audio, sr, speed_ratio)
        if volumn_ratio != 1.0:
            output_audio = volumn_adjust(output_audio, volumn_ratio)
        return output_text, output_audio, sr

    def encode_audio(self, audio_path):
        """Load the audio file at ``audio_path`` and return its encoder tokens."""
        audio_wav, sr = load_audio(audio_path)
        audio_tokens = self.encoder(audio_wav, sr)
        return audio_tokens

    def apply_chat_template(self, messages: list) -> str:
        """Render chat messages into the LLM prompt string.

        Each message is a dict with ``"role"`` and ``"content"``. The role
        ``"user"`` is rendered as ``"human"``. ``content`` may be:

        * a ``str`` -- used verbatim;
        * ``{"type": "text", "text": ...}``;
        * ``{"type": "audio", "audio": path}`` -- the file is tokenized with
          :meth:`encode_audio`;
        * ``None`` -- opens a turn for that role without closing it.

        The prompt always ends with an open assistant turn.

        Raises:
            ValueError: If ``content`` has an unsupported Python type or a
                dict ``"type"`` other than ``"text"``/``"audio"``.
        """
        parts = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                role = "human"

            if isinstance(content, str):
                parts.append(f"<|BOT|>{role}\n{content}<|EOT|>")
            elif isinstance(content, dict):
                content_type = content["type"]
                if content_type == "text":
                    parts.append(f"<|BOT|>{role}\n{content['text']}<|EOT|>")
                elif content_type == "audio":
                    audio_tokens = self.encode_audio(content["audio"])
                    parts.append(f"<|BOT|>{role}\n{audio_tokens}<|EOT|>")
                else:
                    # Reject unknown payloads instead of silently dropping
                    # the whole message from the prompt.
                    raise ValueError(f"Unsupported content type: {content_type!r}")
            elif content is None:
                parts.append(f"<|BOT|>{role}\n")
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")

        text_with_audio = "".join(parts)
        if not text_with_audio.endswith("<|BOT|>assistant\n"):
            text_with_audio += "<|BOT|>assistant\n"
        return text_with_audio
if __name__ == "__main__":
    # Demo: a text question -> spoken answer, then that spoken answer is fed
    # back in as an audio question.
    # Keyword names must match StepAudio.__init__
    # (tokenizer_path / tts_path / llm_path).
    model = StepAudio(
        tokenizer_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-encoder",
        tts_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-decoder",
        llm_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-v18",
    )

    # Make sure the output directory exists before writing the wav files.
    os.makedirs("output", exist_ok=True)

    # Text query -> text + audio answer.
    text, audio, sr = model(
        [{"role": "user", "content": "你好,我是你的朋友,我叫小明,你叫什么名字?"}],
        "闫雨婷",
    )
    torchaudio.save("output/output_e2e_tqta.wav", audio, sr)

    # Audio query (the file written above) -> text + audio answer.
    text, audio, sr = model(
        [
            {
                "role": "user",
                "content": {"type": "audio", "audio": "output/output_e2e_tqta.wav"},
            }
        ],
        "闫雨婷",
    )
    torchaudio.save("output/output_e2e_aqta.wav", audio, sr)