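# Gradio demo for hertz-dev: upload an audio prompt, encode it to latents with
# the audio tokenizer, autoregressively generate a continuation with the
# pretrained generator, and decode the result back to 16 kHz audio for playback.
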
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from utils import load_ckpt, print_colored
from tokenizer import make_tokenizer
from model import get_hertz_dev_config
import spaces
import gradio as gr

device = 'cuda' if T.cuda.is_available() else 'cpu'
#T.cuda.set_device(0)
print(f"Using device: {device}")

audio_tokenizer = make_tokenizer(device)

TWO_SPEAKER = False
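# is_split=True would select the two-speaker model variant, which (as the code
# below suggests) operates on stereo audio with a 32-dim latent stream per
# channel; False uses the mono single-stream model.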

model_config = get_hertz_dev_config(is_split=TWO_SPEAKER)

generator = model_config()
generator = generator.eval().to(T.bfloat16).to(device)
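# Inference only: casting the weights to bfloat16 roughly halves memory and
# matches the autocast dtype used for encoding and generation below.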



##############
# Load audio

def load_and_preprocess_audio(audio_path):
    gr.Info("Loading and preprocessing audio...")
    # Load audio file
    audio_tensor, sr = torchaudio.load(audio_path)
    gr.Info(f"Loaded audio shape: {audio_tensor.shape}")
    
    if TWO_SPEAKER:
        if audio_tensor.shape[0] == 1:
            gr.Info("Converting mono to stereo...")
            audio_tensor = audio_tensor.repeat(2, 1)
            gr.Info(f"Stereo audio shape: {audio_tensor.shape}")
    else:
        if audio_tensor.shape[0] == 2:
            gr.Info("Converting stereo to mono...")
            audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0)
            gr.Info(f"Mono audio shape: {audio_tensor.shape}")
        
    # Resample to 16kHz if needed
    if sr != 16000:
        gr.Info(f"Resampling from {sr}Hz to 16000Hz...")
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        audio_tensor = resampler(audio_tensor)
        
    # Reject prompts longer than 5 minutes
    max_samples = 16000 * 60 * 5
    if audio_tensor.shape[1] > max_samples:
        raise gr.Error("Maximum prompt is 5 minutes")

    # Audio is guaranteed to be at 16 kHz by this point
    duration_seconds = audio_tensor.shape[1] / 16000

    gr.Info("Audio preprocessing complete!")
    return audio_tensor.unsqueeze(0), duration_seconds
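
# Example (with a hypothetical "prompt.wav" at any sample rate):
#   tensor, secs = load_and_preprocess_audio("prompt.wav")
#   tensor.shape -> (1, channels, num_samples) at 16 kHz, secs is the duration.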

##############
# Return audio to gradio

def display_audio(audio_tensor):
    audio_tensor = audio_tensor.cpu().squeeze()
    if audio_tensor.ndim == 1:
        audio_tensor = audio_tensor.unsqueeze(0)
    audio_tensor = audio_tensor.float()

    # Return (sample_rate, ndarray) so Gradio renders an audio player
    return (16000, audio_tensor.numpy())
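
# hertz-dev latents run at 8 frames per second, so at 16 kHz each latent frame
# spans 16000 / 8 = 2000 audio samples (the basis for the prompt trim below).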

def get_completion(encoded_prompt_audio, prompt_len):
    prompt_len_seconds = prompt_len / 8  # latents run at 8 frames per second
    gr.Info(f"Prompt length: {prompt_len_seconds:.2f}s")
    with T.autocast(device_type=device, dtype=T.bfloat16):
        completed_audio_batch = generator.completion(
            encoded_prompt_audio, 
            temps=(.8, (0.5, 0.1)), # (token_temp, (categorical_temp, gaussian_temp))
            use_cache=True)
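        # Lower temperatures generally give more conservative continuations;
        # the inner pair controls the categorical and gaussian sampling heads.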

        completed_audio = completed_audio_batch
        print_colored("Decoding completion...", "blue")
        if TWO_SPEAKER:
            decoded_completion_ch1 = audio_tokenizer.data_from_latent(completed_audio[:, :, :32].bfloat16())
            decoded_completion_ch2 = audio_tokenizer.data_from_latent(completed_audio[:, :, 32:].bfloat16())
            decoded_completion = T.cat([decoded_completion_ch1, decoded_completion_ch2], dim=0)
        else:
            decoded_completion = audio_tokenizer.data_from_latent(completed_audio.bfloat16())
        gr.Info(f"Decoded completion shape: {decoded_completion.shape}")

    gr.Info("Preparing audio for playback...")

    audio_tensor = decoded_completion.cpu().squeeze()
    if audio_tensor.ndim == 1:
        audio_tensor = audio_tensor.unsqueeze(0)
    audio_tensor = audio_tensor.float()

    # Peak-normalize only if the decoded audio would clip outside [-1, 1]
    if audio_tensor.abs().max() > 1:
        audio_tensor = audio_tensor / audio_tensor.abs().max()

    # Trim the prompt from the output, keeping its last second for continuity.
    # prompt_len is in latent frames; each frame covers 2000 samples at 16 kHz.
    return audio_tensor[:, max(int(prompt_len * 2000) - 16000, 0):]

@spaces.GPU
def run(audio_path):
    prompt_audio, prompt_len_seconds = load_and_preprocess_audio(audio_path)
    prompt_len = prompt_len_seconds * 8  # seconds -> latent frames (8 Hz)
    gr.Info("Encoding prompt...")
    with T.autocast(device_type=device, dtype=T.bfloat16):
        if TWO_SPEAKER:
            encoded_prompt_audio_ch1 = audio_tokenizer.latent_from_data(prompt_audio[:, 0:1].to(device))
            encoded_prompt_audio_ch2 = audio_tokenizer.latent_from_data(prompt_audio[:, 1:2].to(device))
            encoded_prompt_audio = T.cat([encoded_prompt_audio_ch1, encoded_prompt_audio_ch2], dim=-1)
        else:
            encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))
    gr.Info(f"Encoded prompt shape: {encoded_prompt_audio.shape}")
    gr.Info("Prompt encoded successfully!")
    completion = get_completion(encoded_prompt_audio, prompt_len)
    return display_audio(completion)



with gr.Blocks() as demo:
    gr.Markdown("# hertz-dev")
    inp = gr.Audio(label="Input Audio", type="filepath", interactive=True)
    btn = gr.Button("Continue", variant="primary")
    out = gr.Audio(label="Output", interactive=False)
    btn.click(run, inputs=inp, outputs=out)

demo.queue().launch()