vampnet-music / unloop /client.py
Hugo Flores Garcia
quality of life UI improvements
7a44f56
import time
from pathlib import Path
import shutil
import json
import argbind
import audiotools as at
from gradio_client import Client, handle_file
from pythonosc.osc_server import ThreadingOSCUDPServer
from pythonosc.udp_client import SimpleUDPClient
from pythonosc.dispatcher import Dispatcher
import torch
class Timer:
def __init__(self):
self.times = {}
def tick(self, name: str):
self.times[name] = time.time()
def tock(self, name: str):
toc = time.time() - self.times[name]
print(f"{name} took {toc} seconds")
return toc
def __str__(self):
return str(self.times)
timer = Timer()
DOWNLOADS_DIR = ".gradio"
def clear_file(file):
file = Path(file)
if file.exists():
file.unlink()
class OSCManager:
def __init__(
self,
ip: str,
s_port: str,
r_port: str,
process_fn: callable,
# param_change_callback: callable = None
):
self.ip = ip
self.s_port = s_port
self.r_port = r_port
# register the process_fn
self.process_fn = process_fn
print(f"will send to {ip}:{s_port}")
self.client = SimpleUDPClient(ip, s_port)
def start_server(self,):
dispatcher = Dispatcher()
dispatcher.map("/process", self.process_fn)
def send_heartbeat(_, *args):
# print("Received heartbeat")
self.client.send_message("/heartbeat", "pong")
dispatcher.map("/heartbeat", lambda a, *r: send_heartbeat(a, *r))
dispatcher.map("/cleanup", lambda a, *r: clear_file(r[0]))
dispatcher.set_default_handler(lambda a, *r: print(a, r))
server = ThreadingOSCUDPServer((self.ip, self.r_port), dispatcher)
print(f"Serving on {server.server_address}")
server.serve_forever()
def error(self, msg: str):
self.client.send_message("/error", msg)
def log(self, msg: str):
self.client.send_message("/log", msg)
class GradioOSCClient:
def __init__(self,
ip: str,
s_port: int, r_port: int,
vampnet_url: str = None, # url for vampnet
):
self.osc_manager = OSCManager(
ip=ip, s_port=s_port, r_port=r_port,
process_fn=self.process,
)
self.clients = {}
if vampnet_url is not None:
self.clients["vampnet"] = Client(src=vampnet_url, download_files=DOWNLOADS_DIR)
assert len(self.clients) > 0, "At least one client must be specified!"
self.batch_size = 2# TODO: automatically get batch size from client.
self.osc_manager.log("hello from gradio client!")
self.inf_idx = 0
def param_changed(self, param_name, new_value):
print(f"Parameter {param_name} changed to {new_value}")
def vampnet_process(self, address: str, *args):
client = self.clients["vampnet"]
# query id --- audiofile ---- model_choice --- periodic --- drop --- seed
query_id = args[0]
client_type = args[1]
audio_path = Path(args[2])
model_choice = args[3]
periodic_p = args[4]
dropout = args[5]
seed = args[6]
looplength_ms = args[7]
typical_filter = args[8]
typical_mass = args[9]
typical_min_tokens = args[10]
upper_codebook_mask = args[11]
onset_mask_width = args[12]
sampling_steps = args[13]
temperature = args[14]
top_p = args[15]
beat_mask_ms = args[16]
num_feedback_steps = args[17]
if not audio_path.exists():
print(f"File {audio_path} does not exist")
self.osc_manager.error(f"File {audio_path} does not exist")
return
sig = at.AudioSignal(audio_path)
sig.to_mono()
sig.sample_rate = 48000 # HOT PATCH (FIXME IN MAX: sample rate is being forced to 48k)
# grab the looplength only
# TODO: although I added this,
# the max patch is still configured to crop anything past the looplength off
# so we'll have to change that in order to make an effect.
end_sample = int((looplength_ms * sig.sample_rate) / 1000)
# grab the remainder of the waveform
num_cut_samples = sig.samples.shape[-1] - end_sample
cut_wav = sig.samples[..., -num_cut_samples:]
sig.samples = sig.samples[..., :end_sample]
# write the file back
sig.write(audio_path)
timer.tick("predict")
print(f"Processing {address} with args {args}")
# breakpoint()
job = client.submit(
input_audio=handle_file(audio_path),
sampletemp=temperature,
top_p=top_p,
periodic_p=periodic_p,
dropout=dropout,
stretch_factor=1,
onset_mask_width=onset_mask_width,
typical_filtering=bool(typical_filter),
typical_mass=typical_mass,
typical_min_tokens=typical_min_tokens,
seed=seed,
model_choice=model_choice,
n_mask_codebooks=upper_codebook_mask,
pitch_shift_amt=0,
sample_cutoff=1.0,
sampling_steps=sampling_steps,
beat_mask_ms=int(beat_mask_ms),
num_feedback_steps=num_feedback_steps,
api_name="/vamp_1"
)
while not job.done():
time.sleep(0.1)
self.osc_manager.client.send_message("/progress", [query_id, str(job.status().code)])
result = job.result()
# audio_file = result
# audio_files = [audio_file] * self.batch_size
audio_files = list(result[:self.batch_size])
# if each file is missing a .wav at the end, add it
first_audio = audio_files[0]
if not first_audio.endswith(".wav"):
for audio_file in set(audio_files):
if not audio_file.endswith(".wav"):
shutil.move(audio_file, f"{audio_file}.wav")
audio_file = f"{audio_file}.wav"
audio_files = [f"{audio}.wav" for audio in audio_files if not audio.endswith(".wav")]
for audio_file in audio_files:
# load the file, add the cut samples back
sig = at.AudioSignal(audio_file)
sig.resample(48000)
sig.samples = torch.cat([sig.samples, cut_wav], dim=-1)
sig.write(audio_file)
seed = result[-1]
timer.tock("predict")
# send a message that the process is done
self.osc_manager.log(f"query {query_id} has been processed")
self.osc_manager.client.send_message("/process-result", [query_id] + audio_files)
def process(self, address: str, *args):
query_id = args[0]
client_type = args[1]
audio_path = Path(args[2])
if client_type == "vampnet":
self.vampnet_process(address, *args)
return
elif client_type == "sketch2sound":
self.process_s2s(address, *args)
return
else:
raise ValueError(f"Unknown client type {client_type}")
def gradio_main(
vampnet_url: str = None
):
system = GradioOSCClient(
vampnet_url=vampnet_url,
ip="127.0.0.1", s_port=8003, r_port=8001,
)
system.osc_manager.start_server()
if __name__ == "__main__":
try:
gradio_main = argbind.bind(gradio_main, without_prefix=True)
args = argbind.parse_args()
with argbind.scope(args):
gradio_main()
except Exception as e:
import shutil
shutil.rmtree(DOWNLOADS_DIR, ignore_errors=True)
raise e