import gradio as gr
import torchaudio
import torch
import numpy as np
from huggingface_hub import hf_hub_download

# βœ… Map of model names to files on Hugging Face
RAVE_MODELS = {
    "Guitar": "guitar_iil_b2048_r48000_z16.ts",
    "Soprano Sax": "sax_soprano_franziskaschroeder_b2048_r48000_z20.ts",
    "Organ (Archive)": "organ_archive_b2048_r48000_z16.ts",
    "Organ (Bach)": "organ_bach_b2048_r48000_z16.ts",
    "Voice Multivoice": "voice-multi-b2048-r48000-z11.ts",
    "Birds Dawn Chorus": "birds_dawnchorus_b2048_r48000_z8.ts",
    "Magnets": "magnets_b2048_r48000_z8.ts",
    "Whale Songs": "humpbacks_pondbrain_b2048_r48000_z20.ts"
}

MODEL_CACHE = {}  # loaded models keyed by display name, so repeat requests reuse them

def load_rave_model(model_name):
    """Load TorchScript RAVE model from Hugging Face Hub."""
    if model_name in MODEL_CACHE:
        return MODEL_CACHE[model_name]

    model_file = hf_hub_download(
        repo_id="Intelligent-Instruments-Lab/rave-models",
        filename=RAVE_MODELS[model_name]
    )
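    # (hf_hub_download caches the file locally, so only the first call downloads)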

    model = torch.jit.load(model_file, map_location="cpu")
    model.eval()
    MODEL_CACHE[model_name] = model
    return model

def apply_rave(audio, model_name):
    """Apply selected RAVE model to uploaded audio."""
    model = load_rave_model(model_name)

    # Gradio's numpy audio components pass (sample_rate, samples) tuples
    sr, data = audio

    # Uploads often arrive as int16; convert to float32 in [-1, 1]
    if np.issubdtype(data.dtype, np.integer):
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    else:
        data = data.astype(np.float32)

    # Mix stereo down to mono, then shape as [batch, channels, samples],
    # which is what RAVE TorchScript models expect
    if data.ndim == 2:
        data = data.mean(axis=1)
    audio_tensor = torch.from_numpy(data).reshape(1, 1, -1)

    # βœ… Resample if needed (most RAVE models expect 48kHz)
    if sr != 48000:
        audio_tensor = torchaudio.functional.resample(audio_tensor, sr, 48000)
        sr = 48000

    with torch.no_grad():
        # βœ… TorchScript models have encode & decode methods
        z = model.encode(audio_tensor)
        processed_audio = model.decode(z)
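        # (Many RAVE exports also expose a combined forward(); the split
        # encode/decode form is kept here so the latent z stays accessible.)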

    # Gradio expects (sample_rate, samples) back for numpy audio output
    return (sr, processed_audio.squeeze().cpu().numpy())

# πŸŽ› Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## πŸŽ› RAVE Style Transfer on Stems")
    gr.Markdown("Upload audio, pick a RAVE model, and get a transformed version.")

    with gr.Row():
        audio_input = gr.Audio(type="numpy", label="Upload Audio", sources=["upload", "microphone"])
        model_selector = gr.Dropdown(list(RAVE_MODELS.keys()), label="Select Style", value="Guitar")

    with gr.Row():
        output_audio = gr.Audio(type="numpy", label="Transformed Audio")

    process_btn = gr.Button("Apply Style Transfer")
    process_btn.click(fn=apply_rave, inputs=[audio_input, model_selector], outputs=output_audio)

demo.launch()
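
# Optional offline sanity check, shown as a sketch ("test.wav" is a
# hypothetical path, not part of this app):
#
#   waveform, in_sr = torchaudio.load("test.wav")             # [channels, samples]
#   out_sr, out = apply_rave((in_sr, waveform.numpy().T), "Guitar")
#   torchaudio.save("styled.wav", torch.from_numpy(out)[None, :], out_sr)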