File size: 8,367 Bytes
fc8b181
f18f98b
bd40662
a32055a
fc8b181
 
 
5c44be0
e173c02
 
544ae95
24e47df
544ae95
bd40662
24e47df
fc8b181
24e47df
 
 
 
 
fc8b181
24e47df
fc8b181
24e47df
 
 
fc8b181
e173c02
fc8b181
544ae95
4ad7b57
24e47df
4ad7b57
24e47df
 
2c32151
4ad7b57
 
 
 
 
 
 
 
24e47df
 
 
 
 
 
b7d2089
4ad7b57
 
 
 
 
24e47df
b7d2089
24e47df
4ad7b57
 
 
 
 
 
 
24e47df
4ad7b57
fc8b181
f28e066
544ae95
fc8b181
544ae95
 
 
85dc4b0
4ad7b57
24e47df
 
 
 
 
 
 
 
9b8102d
 
 
24e47df
 
 
 
 
 
 
 
4ad7b57
 
24e47df
 
4ad7b57
2c32151
4ad7b57
2c32151
4ad7b57
055ea67
 
 
4ad7b57
055ea67
4ad7b57
055ea67
 
4ad7b57
24e47df
2c32151
fc8b181
f28e066
fc8b181
e173c02
4ad7b57
24e47df
4ad7b57
 
24e47df
 
4ad7b57
 
 
 
 
 
 
 
9b8102d
 
 
2c32151
24e47df
 
 
 
 
 
 
4ad7b57
 
 
24e47df
6ea0ef3
2c32151
4ad7b57
2c32151
4ad7b57
055ea67
 
 
4ad7b57
055ea67
4ad7b57
055ea67
 
4ad7b57
24e47df
2c32151
fc8b181
 
 
 
24e47df
fc8b181
24e47df
fc8b181
 
85dc4b0
fc8b181
 
 
 
24e47df
fc8b181
85dc4b0
fc8b181
 
 
 
24e47df
 
 
 
 
 
 
 
fc8b181
 
e173c02
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os
import gradio as gr
import numpy as np
import soundfile as sf
from semanticodec import SemantiCodec
from huggingface_hub import HfApi
import spaces
import torch
import tempfile
import io
import uuid
import pickle
from pathlib import Path

# Build the codec and place it on the GPU when one is present
def load_model():
    """Construct the SemantiCodec model shared by the handlers in this file.

    The codec is created with ``token_rate=100`` and
    ``semantic_vocab_size=32768`` (the 1.40 kbps configuration). If
    ``torch.cuda.is_available()`` reports a GPU, the model is moved to
    ``cuda:0``; otherwise it stays wherever the constructor placed it.

    Returns:
        The initialized ``SemantiCodec`` instance.
    """
    codec = SemantiCodec(token_rate=100, semantic_vocab_size=32768)  # 1.40 kbps
    if not torch.cuda.is_available():
        # No GPU: hand the model back untouched
        return codec
    codec.to("cuda:0")
    return codec

# Build the shared codec once, at import time
semanticodec = load_model()
# Device string the handlers use when moving tensors and the model
if torch.cuda.is_available():
    model_device = "cuda:0"
else:
    model_device = "cpu"
print(f"Model initialized on device: {model_device}")

@spaces.GPU(duration=20)
def encode_audio(audio_path):
    """Encode an audio file to semantic tokens and write them to a token file.

    The tokens are stored as a pickled dict with the keys ``'tokens'`` (numpy
    array), ``'shape'`` (its shape) and ``'device'`` (the device string in use
    when encoding), in a uniquely named ``.oterin`` file inside the system
    temporary directory.

    Args:
        audio_path: Filesystem path of the audio to encode. ``None`` or an
            empty value means no audio was provided.

    Returns:
        A ``(token_file_path, status_message)`` tuple. On failure the path is
        ``None`` and the message describes the error; errors are reported
        through the status message rather than raised.
    """
    # Nothing to encode: say so directly instead of surfacing whatever error
    # the codec would raise for a missing path.
    if not audio_path:
        return None, "Error encoding audio: no audio file provided"

    temp_file_path = None
    try:
        print(f"Encoding audio on device: {model_device}")
        tokens = semanticodec.encode(audio_path)
        print(f"Tokens device after encode: {tokens.device if isinstance(tokens, torch.Tensor) else 'numpy'}")

        # Move tokens to CPU before converting to numpy; detach() first
        # because .numpy() refuses tensors that still require grad.
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.detach().cpu().numpy()

        # Ensure tokens are in the right shape for later decoding
        if tokens.ndim == 1:
            # Reshape to match expected format [batch, seq_len, features]
            tokens = tokens.reshape(1, -1, 1)

        # Save tokens in a way that preserves shape information
        token_data = {
            'tokens': tokens,
            'shape': tokens.shape,
            'device': str(model_device),  # Store intended device information
        }

        # The system temp directory already exists and is writable, so no
        # makedirs is needed; the uuid keeps concurrent requests from
        # colliding on the same file name.
        temp_file_path = os.path.join(
            tempfile.gettempdir(), f"tokens_{uuid.uuid4()}.oterin"
        )

        # Write using pickle instead of numpy save
        with open(temp_file_path, "wb") as f:
            pickle.dump(token_data, f)

        # Verify the file exists and has content
        if not os.path.exists(temp_file_path) or os.path.getsize(temp_file_path) == 0:
            raise OSError("Failed to create token file")

        return temp_file_path, f"Encoded to {tokens.shape[1]} tokens"
    except Exception as e:
        # Do not leave a partial or unusable token file behind.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError:
                pass  # best-effort cleanup; the original error is what matters
        print(f"Encoding error: {str(e)}")
        return None, f"Error encoding audio: {str(e)}"

@spaces.GPU(duration=340)
def decode_tokens(token_file):
    """Decode a token file back to audio.

    Args:
        token_file: The uploaded token file. Either a filesystem path
            (``str`` / ``os.PathLike``) or a file-like object that exposes its
            path as ``.name``. ``None`` means nothing was uploaded.

    Returns:
        A ``((sample_rate, audio_data), status_message)`` tuple on success, or
        ``(None, error_message)`` on failure. Errors are reported through the
        status message rather than raised.
    """
    # Normalize the input to a path. The upload may arrive as a path string
    # or as a temp-file object; passing such an object straight to
    # os.path.exists() would raise TypeError before the try block is reached.
    if isinstance(token_file, (str, os.PathLike)):
        token_path = token_file
    else:
        candidate = getattr(token_file, "name", None)
        token_path = candidate if isinstance(candidate, (str, os.PathLike)) else None

    # Ensure the file exists and has content
    if not token_path or not os.path.exists(token_path):
        return None, "Error: Empty or missing token file"

    try:
        # SECURITY NOTE(review): pickle.load can execute arbitrary code
        # embedded in the file, and this file is uploaded by the user. The
        # on-disk format is kept as-is for compatibility with existing
        # .oterin files; consider migrating to a data-only format.
        with open(token_path, "rb") as f:
            token_data = pickle.load(f)

        # Reject payloads that are not the dict layout this app writes,
        # with a readable message instead of a bare KeyError/TypeError.
        if not isinstance(token_data, dict) or 'tokens' not in token_data:
            return None, "Error decoding tokens: file does not contain a token payload"

        tokens = token_data['tokens']
        intended_device = token_data.get('device', model_device)
        print(f"Loaded tokens with shape {tokens.shape}, intended device: {intended_device}")

        # Convert to torch tensor with Long dtype for embedding
        tokens_tensor = torch.tensor(tokens, dtype=torch.long)
        print(f"Tokens tensor created on device: {tokens_tensor.device} with dtype: {tokens_tensor.dtype}")

        # Explicitly move tokens to the model's device
        tokens_tensor = tokens_tensor.to(model_device)
        print(f"Tokens moved to device: {tokens_tensor.device}")

        # Also ensure model is on the expected device
        semanticodec.to(model_device)
        print(f"Model device before decode: {next(semanticodec.parameters()).device}")

        # Decode the tokens
        waveform = semanticodec.decode(tokens_tensor)
        print(f"Waveform device after decode: {waveform.device if isinstance(waveform, torch.Tensor) else 'numpy'}")

        # Move waveform to CPU for audio processing; detach() first because
        # .numpy() refuses tensors that still require grad.
        if isinstance(waveform, torch.Tensor):
            waveform = waveform.detach().cpu().numpy()

        # Take the first item along each of the two leading axes
        # (presumably batch and channel, leaving [time] - TODO confirm)
        audio_data = waveform[0, 0]
        sample_rate = 32000

        print(f"Audio data shape: {audio_data.shape}, dtype: {audio_data.dtype}")

        # Return in Gradio Audio compatible format: (sample_rate, audio_data)
        return (sample_rate, audio_data), f"Decoded {tokens.shape[1]} tokens to audio"
    except Exception as e:
        print(f"Decoding error: {str(e)}")
        return None, f"Error decoding tokens: {str(e)}"

@spaces.GPU(duration=360)
def process_both(audio_path):
    """Round-trip audio through the codec in memory: encode, then decode.

    No token file is written; the encoded tokens are fed straight back into
    the decoder.

    Args:
        audio_path: Filesystem path of the audio to process.

    Returns:
        A ``((sample_rate, audio_data), status_message)`` tuple on success, or
        ``(None, error_message)`` if any step raises.
    """
    try:
        print(f"Processing both on device: {model_device}")

        # --- Encode ---------------------------------------------------
        encoded = semanticodec.encode(audio_path)
        encoded_is_tensor = isinstance(encoded, torch.Tensor)
        print(f"Tokens device after encode: {encoded.device if encoded_is_tensor else 'numpy'}")

        # Work on a CPU numpy copy of the tokens
        token_array = encoded.cpu().numpy() if encoded_is_tensor else encoded

        # A flat token vector is reshaped to [batch, seq_len, features]
        if token_array.ndim == 1:
            token_array = token_array.reshape(1, -1, 1)

        # --- Prepare decoder input ------------------------------------
        # Long dtype is used for the embedding lookup
        token_tensor = torch.tensor(token_array, dtype=torch.long)
        print(f"Tokens tensor created on device: {token_tensor.device} with dtype: {token_tensor.dtype}")

        token_tensor = token_tensor.to(model_device)
        print(f"Tokens moved to device: {token_tensor.device}")

        # Put the model on the same device as the tokens
        semanticodec.to(model_device)
        print(f"Model device before decode: {next(semanticodec.parameters()).device}")

        # --- Decode ---------------------------------------------------
        decoded = semanticodec.decode(token_tensor)
        decoded_is_tensor = isinstance(decoded, torch.Tensor)
        print(f"Waveform device after decode: {decoded.device if decoded_is_tensor else 'numpy'}")

        if decoded_is_tensor:
            decoded = decoded.cpu().numpy()

        # Take the first item along each of the two leading axes
        samples = decoded[0, 0]
        rate = 32000

        print(f"Audio data shape: {samples.shape}, dtype: {samples.dtype}")

        # Gradio Audio accepts a (sample_rate, audio_data) pair
        status = f"Encoded to {token_array.shape[1]} tokens\nDecoded {token_array.shape[1]} tokens to audio"
        return (rate, samples), status
    except Exception as exc:
        print(f"Processing error: {str(exc)}")
        return None, f"Error processing audio: {str(exc)}"

# Create Gradio interface
# Three tabs, one per handler above: encode only, decode only, and an
# in-memory encode+decode round trip. Each tab pairs an input component and an
# output component in a row, then a status textbox and a trigger button.
with gr.Blocks(title="Oterin Audio Codec") as demo:
    gr.Markdown("# Oterin Audio Codec")
    gr.Markdown("Upload an audio file to encode it to semantic tokens, decode tokens back to audio, or do both.")
    
    # Tab 1: audio in -> token file out (handled by encode_audio)
    with gr.Tab("Encode Audio"):
        with gr.Row():
            # type="filepath": the handler receives a path rather than audio samples
            encode_input = gr.Audio(type="filepath", label="Input Audio")
            encode_output = gr.File(label="Encoded Tokens (.oterin)", file_types=[".oterin"])
        encode_status = gr.Textbox(label="Status")
        encode_btn = gr.Button("Encode")
        # encode_audio returns (token_file_path, status), matching the two outputs
        encode_btn.click(encode_audio, inputs=encode_input, outputs=[encode_output, encode_status])
    
    # Tab 2: token file in -> audio out (handled by decode_tokens)
    with gr.Tab("Decode Tokens"):
        with gr.Row():
            decode_input = gr.File(label="Token File (.oterin)", file_types=[".oterin"])
            decode_output = gr.Audio(label="Decoded Audio")
        decode_status = gr.Textbox(label="Status")
        decode_btn = gr.Button("Decode")
        # decode_tokens returns ((sample_rate, audio_data), status)
        decode_btn.click(decode_tokens, inputs=decode_input, outputs=[decode_output, decode_status])
    
    # Tab 3: audio in -> reconstructed audio out (handled by process_both)
    with gr.Tab("Both (Encode & Decode)"):
        with gr.Row():
            both_input = gr.Audio(type="filepath", label="Input Audio")
            both_output = gr.Audio(label="Reconstructed Audio")
        both_status = gr.Textbox(label="Status")
        both_btn = gr.Button("Process")
        # process_both returns ((sample_rate, audio_data), status)
        both_btn.click(process_both, inputs=both_input, outputs=[both_output, both_status])

# Launch the UI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link (see the
    # Gradio launch() docs) - confirm that public exposure is intended.
    demo.launch(share=True)