import os
import tempfile
from pathlib import Path

import gradio as gr
import torchaudio
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from stepaudio import StepAudio

CACHE_DIR = "/tmp/gradio/"

# System prompt (Chinese): "Adapt to the user's language; answer in short,
# colloquial sentences."
system_prompt = {"role": "system", "content": "适配用户的语言,用简短口语化的文字回答"}


class CustomAsr:
    def __init__(self, model_name="iic/SenseVoiceSmall", device="cuda"):
        self.model = AutoModel(
            model=model_name,
            vad_model="fsmn-vad",
            vad_kwargs={"max_single_segment_time": 30000},
            device=device,
        )

    def run(self, audio_path):
        res = self.model.generate(
            input=audio_path,
            cache={},
            language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
            use_itn=True,
            batch_size_s=60,
            merge_vad=True,
            # merge_length_s=15,
        )
        text = rich_transcription_postprocess(res[0]["text"])
        return text


def add_message(chatbot, history, mic, text, asr_model):
    if not mic and not text:
        return chatbot, history, "Input is empty"
    if text:
        chatbot.append({"role": "user", "content": text})
        history.append({"role": "user", "content": text})
    elif mic and Path(mic).exists():
        chatbot.append({"role": "user", "content": {"path": mic}})
        # Feed the ASR transcript of the user's speech to the model to speed up
        # inference.
        text = asr_model.run(mic)
        chatbot.append({"role": "user", "content": text})
        history.append({"role": "user", "content": text})
    print(f"{history=}")
    return chatbot, history, None


def reset_state():
    """Reset the chat history."""
    return [], [system_prompt]


def save_tmp_audio(audio, sr):
    # The cache directory must exist before NamedTemporaryFile can write into it.
    os.makedirs(CACHE_DIR, exist_ok=True)
    with tempfile.NamedTemporaryFile(
        dir=CACHE_DIR, delete=False, suffix=".wav"
    ) as temp_audio:
        temp_audio_path = temp_audio.name
        torchaudio.save(temp_audio_path, audio, sr)
    return temp_audio_path


def predict(chatbot, history, audio_model):
    """Generate a response from the model."""
    try:
        text, audio, sr = audio_model(history, "闫雨婷")  # speaker name for TTS
        print(f"predict {text=}")
        audio_path = save_tmp_audio(audio, sr)
        chatbot.append({"role": "assistant", "content": {"path": audio_path}})
        chatbot.append({"role": "assistant", "content": text})
        history.append({"role": "assistant", "content": text})
    except Exception as e:
        print(e)
        gr.Warning("Something went wrong, please retry.")
    return chatbot, history


def _launch_demo(args, audio_model, asr_model):
    with gr.Blocks(delete_cache=(86400, 86400)) as demo:
        gr.Markdown("""
Step Audio Chat
""") chatbot = gr.Chatbot( elem_id="chatbot", avatar_images=["assets/user.png", "assets/assistant.png"], min_height=800, type="messages", ) # 保存 chat 历史,不需要每次再重新拼格式 history = gr.State([system_promtp]) mic = gr.Audio(type="filepath") text = gr.Textbox(placeholder="Enter message ...") with gr.Row(): clean_btn = gr.Button("🧹 Clear History (清除历史)") regen_btn = gr.Button("🤔️ Regenerate (重试)") submit_btn = gr.Button("🚀 Submit") def on_submit(chatbot, history, mic, text): chatbot, history, error = add_message( chatbot, history, mic, text, asr_model ) if error: gr.Warning(error) # 显示警告消息 return chatbot, history, None, None else: chatbot, history = predict(chatbot, history, audio_model) return chatbot, history, None, None submit_btn.click( fn=on_submit, inputs=[chatbot, history, mic, text], outputs=[chatbot, history, mic, text], concurrency_limit=4, concurrency_id="gpu_queue", ) clean_btn.click( reset_state, outputs=[chatbot, history], show_progress=True, ) def regenerate(chatbot, history): while chatbot and chatbot[-1]["role"] == "assistant": chatbot.pop() while history and history[-1]["role"] == "assistant": print(f"discard {history[-1]}") history.pop() return predict(chatbot, history, audio_model) regen_btn.click( regenerate, [chatbot, history], [chatbot, history], show_progress=True, concurrency_id="gpu_queue", ) demo.queue().launch( share=False, server_port=args.server_port, server_name=args.server_name, ) if __name__ == "__main__": from argparse import ArgumentParser import os parser = ArgumentParser() parser.add_argument("--model-path", type=str, required=True, help="Model path.") parser.add_argument( "--server-port", type=int, default=7860, help="Demo server port." ) parser.add_argument( "--server-name", type=str, default="0.0.0.0", help="Demo server name." ) args = parser.parse_args() audio_model = StepAudio( tokenizer_path=os.path.join(args.model_path, "Step-Audio-Tokenizer"), tts_path=os.path.join(args.model_path, "Step-Audio-TTS-3B"), llm_path=os.path.join(args.model_path, "Step-Audio-Chat"), ) asr_model = CustomAsr() _launch_demo(args, audio_model, asr_model)