# CiSiMi-At-Home / app.py
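# Gradio demo for KandirResearch/CiSiMi-v0.1: a speech-to-speech pipeline that
# transcribes spoken input with Whisper, generates a reply (text interleaved
# with audio tokens) using an OuteTTS GGUF model, and decodes the audio tokens
# back into a waveform with the OuteTTS audio codec.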
import torch
import gradio as gr
import whisper
import outetts
from huggingface_hub import hf_hub_download
from outetts.wav_tokenizer.audio_codec import AudioCodec
from outetts.version.v2.prompt_processor import PromptProcessor
from outetts.version.playback import ModelOutput
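# Download the quantized GGUF weights from the Hugging Face Hub.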
model_path = hf_hub_download(
repo_id="KandirResearch/CiSiMi-v0.1",
filename="unsloth.Q8_0.gguf", # unsloth.Q4_K_M.gguf
)
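# Point the OuteTTS GGUF interface at the downloaded weights; the tokenizer is
# loaded from the model repo itself.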
model_config = outetts.GGUFModelConfig_v2(
model_path=model_path,
tokenizer_path="KandirResearch/CiSiMi-v0.1",
)
interface = outetts.InterfaceGGUF(model_version="0.3", cfg=model_config)
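# Audio codec (audio tokens <-> waveform), prompt processor (prompt and
# audio-token formatting), and Whisper for transcribing spoken input.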
audio_codec = AudioCodec()
prompt_processor = PromptProcessor("KandirResearch/CiSiMi-v0.1")
whisper_model = whisper.load_model("base.en")
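# Decode audio on GPU when available, and grab the underlying GGUF model so we
# can drive token generation directly instead of through the high-level interface.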
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gguf_model = interface.get_model()
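# Decode a sequence of audio tokens into a playable waveform.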
def get_audio(tokens):
outputs = prompt_processor.extract_audio_from_tokens(tokens)
if not outputs:
return None
audio_tensor = audio_codec.decode(torch.tensor([[outputs]], dtype=torch.int64).to(device))
return ModelOutput(audio_tensor, audio_codec.sr)
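# Recover the plain-text transcript from the raw TTS output: each line carries
# text followed by special audio tokens, so keep only what precedes the first
# "<|" marker and skip the end-marker lines entirely.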
def extract_text_from_tts_output(tts_output):
text = ""
for line in tts_output.strip().split('\n'):
if '<|audio_end|>' in line or '<|im_end|>' in line:
continue
if '<|' in line:
word = line.split('<|')[0].strip()
if word:
text += word + " "
else:
text += line.strip() + " "
return text.strip()
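# Route the request: prefer audio input when both audio and text are given.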
def process_input(audio_input, text_input):
if audio_input is None and (text_input is None or text_input.strip() == ""):
return "Please provide either audio or text input.", None
if audio_input is not None:
return process_audio(audio_input)
else:
return process_text(text_input)
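# Transcribe the uploaded or recorded audio with Whisper, then answer it.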
def process_audio(audio):
result = whisper_model.transcribe(audio)
instruction = result["text"]
return generate_response(instruction)
def process_text(text):
    return generate_response(text)
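# Core pipeline: build the prompt, run GGUF generation, then split the output
# into response text and decodable audio tokens.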
def generate_response(instruction):
prompt = f"<|im_start|>\nInstructions:\n{instruction}\n<|im_end|>\nAnswer:\n"
gen_cfg = outetts.GenerationConfig(
text=prompt,
temperature=0.6,
repetition_penalty=1.1,
max_length=4096,
speaker=None
)
input_ids = prompt_processor.tokenizer.encode(prompt)
tokens = gguf_model.generate(input_ids, gen_cfg)
output_text = prompt_processor.tokenizer.decode(tokens, skip_special_tokens=False)
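    # A complete spoken answer wraps its audio tokens in <|audio_start|> ...
    # <|audio_end|> followed by <|im_end|>; only attempt audio decoding when
    # generation actually reached the end marker.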
if "<|audio_end|>" in output_text:
first_part, _, _ = output_text.partition("<|audio_end|>")
if "<|audio_end|>\n<|im_end|>\n" not in first_part:
first_part += "<|audio_end|>\n<|im_end|>\n"
extracted_text = extract_text_from_tts_output(first_part)
audio_start_pos = first_part.find("<|audio_start|>\n") + len("<|audio_start|>\n")
audio_end_pos = first_part.find("<|audio_end|>\n<|im_end|>\n") + len("<|audio_end|>\n<|im_end|>\n")
        # str.find() returns -1 on a miss, which this guard filters out.
        if audio_start_pos >= len("<|audio_start|>\n") and audio_end_pos > audio_start_pos:
audio_tokens_text = first_part[audio_start_pos:audio_end_pos]
audio_tokens = prompt_processor.tokenizer.encode(audio_tokens_text)
#print(f"Decoding audio with {len(audio_tokens)} tokens")
#print(f"audio_tokens: {audio_tokens_text}")
audio_output = get_audio(audio_tokens)
if audio_output is not None and hasattr(audio_output, 'audio') and audio_output.audio is not None:
audio_numpy = audio_output.audio.cpu().numpy()
if audio_numpy.ndim > 1:
audio_numpy = audio_numpy.squeeze()
return extracted_text, (audio_output.sr, audio_numpy)
    # Fallback: generation never reached <|audio_end|> (or decoding failed),
    # so return the raw model output with no audio.
    return output_text, None
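# Build the Gradio UI: optional microphone/file audio input plus an optional text box.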
iface = gr.Interface(
fn=process_input,
inputs=[
gr.Audio(type="filepath", label="Audio Input (Optional)"),
gr.Textbox(label="Text Input (Optional)")
],
outputs=[
gr.Textbox(label="Response Text"),
gr.Audio(type="numpy", label="Generated Speech")
],
title="CiSiMi-v0.1 @ Home Demo",
description="Me: Mom can we have CSM locally! Mom: we have CSM locally. CSM locally:",
examples=[
[None, "Hello, what are you capable of?"],
[None, "Explain to me how gravity works!"]
]
)
iface.launch(debug=True)