# Spaces: Running on Zero
import time | |
from pathlib import Path | |
import shutil | |
import json | |
import argbind | |
import audiotools as at | |
from gradio_client import Client, handle_file | |
from pythonosc.osc_server import ThreadingOSCUDPServer | |
from pythonosc.udp_client import SimpleUDPClient | |
from pythonosc.dispatcher import Dispatcher | |
import torch | |
class Timer:
    """Minimal stopwatch that tracks several named intervals at once."""

    def __init__(self):
        # label -> wall-clock timestamp (seconds) stored by the latest tick()
        self.times = {}

    def tick(self, name: str):
        """Remember the current wall-clock time as the start of ``name``."""
        self.times[name] = time.time()

    def tock(self, name: str):
        """Print and return the seconds elapsed since ``tick(name)``.

        Raises:
            KeyError: if ``name`` was never passed to ``tick``.
        """
        started_at = self.times[name]
        elapsed = time.time() - started_at
        print(f"{name} took {elapsed} seconds")
        return elapsed

    def __str__(self):
        # Same text as the underlying dict of start timestamps.
        return f"{self.times}"
# Directory name handed to the gradio client as its download location and
# removed by the script entry point when the program fails.
DOWNLOADS_DIR = ".gradio"

# Module-wide stopwatch; used to time the "predict" round-trip.
timer = Timer()
def clear_file(file):
    """Delete ``file`` if it is present; do nothing when it is already gone.

    Args:
        file: path of the file to remove (``str`` or ``Path``).

    The removal is attempted directly instead of checking ``exists()`` first:
    that avoids a race when the file disappears between the check and the
    unlink, and it also removes dangling symlinks (for which ``exists()``
    reports False). Errors other than "not found" (e.g. the path is a
    directory) still propagate.
    """
    try:
        Path(file).unlink()
    except FileNotFoundError:
        # Already gone -- nothing to clean up.
        pass
class OSCManager:
    """Owns the OSC plumbing: a UDP client for replies and a blocking server.

    Outgoing messages go to ``ip:s_port``; incoming messages are received on
    ``ip:r_port`` once ``start_server`` is called.
    """

    def __init__(
        self,
        ip: str,
        s_port: int,
        r_port: int,
        process_fn: callable,
    ):
        """
        Args:
            ip: address used both for sending and for binding the server.
            s_port: UDP port that outgoing messages are sent to.
            r_port: UDP port the server listens on.
            process_fn: handler registered for the "/process" address; it is
                mapped on the dispatcher, so it is called as
                ``process_fn(address, *osc_args)``.
        """
        self.ip = ip
        self.s_port = s_port
        self.r_port = r_port

        # register the process_fn
        self.process_fn = process_fn

        print(f"will send to {ip}:{s_port}")
        self.client = SimpleUDPClient(ip, s_port)

    def _handle_heartbeat(self, address: str, *args):
        """Answer a "/heartbeat" message with "pong"."""
        self.client.send_message("/heartbeat", "pong")

    def _handle_cleanup(self, address: str, *args):
        """Delete the file named by the first argument of a "/cleanup" message.

        A message without arguments is reported (stdout and "/error") instead
        of raising IndexError inside the server's handler thread.
        """
        if not args:
            msg = f"{address} requires a file path argument"
            print(msg)
            self.error(msg)
            return
        # NOTE(review): the path comes straight from the network and is
        # deleted as-is; confirm the server is only reachable by trusted peers.
        clear_file(args[0])

    def start_server(self,):
        """Register the handlers and serve OSC messages until interrupted.

        Blocks in ``serve_forever``; the server socket is closed on the way
        out, however the loop ends.
        """
        dispatcher = Dispatcher()
        dispatcher.map("/process", self.process_fn)
        dispatcher.map("/heartbeat", self._handle_heartbeat)
        dispatcher.map("/cleanup", self._handle_cleanup)
        # Anything unmapped is just echoed to stdout.
        dispatcher.set_default_handler(lambda a, *r: print(a, r))

        server = ThreadingOSCUDPServer((self.ip, self.r_port), dispatcher)
        try:
            print(f"Serving on {server.server_address}")
            server.serve_forever()
        finally:
            server.server_close()

    def error(self, msg: str):
        """Send ``msg`` to the peer on the "/error" address."""
        self.client.send_message("/error", msg)

    def log(self, msg: str):
        """Send ``msg`` to the peer on the "/log" address."""
        self.client.send_message("/log", msg)
class GradioOSCClient:
    """Receives "/process" OSC requests and forwards them to a Gradio app.

    Results are reported back to the OSC peer on "/progress",
    "/process-result", "/log" and "/error".
    """

    # Sample rate (Hz) forced onto the input and used for every output file.
    SAMPLE_RATE = 48000

    # Number of positional OSC arguments a vampnet request carries.
    NUM_VAMPNET_ARGS = 18

    def __init__(self,
        ip: str,
        s_port: int, r_port: int,
        vampnet_url: str = None,  # url for vampnet
    ):
        """
        Args:
            ip: address for the OSC client and server.
            s_port: UDP port replies are sent to.
            r_port: UDP port requests are received on.
            vampnet_url: source passed to the gradio ``Client`` for vampnet.

        Raises:
            ValueError: if no client url was given. (Raised explicitly rather
                than via ``assert`` so the check survives ``python -O``.)
        """
        self.osc_manager = OSCManager(
            ip=ip, s_port=s_port, r_port=r_port,
            process_fn=self.process,
        )

        self.clients = {}
        if vampnet_url is not None:
            self.clients["vampnet"] = Client(src=vampnet_url, download_files=DOWNLOADS_DIR)

        if not self.clients:
            raise ValueError("At least one client must be specified!")

        self.batch_size = 2  # TODO: automatically get batch size from client.

        self.osc_manager.log("hello from gradio client!")
        self.inf_idx = 0

    def param_changed(self, param_name, new_value):
        """Print a notice that ``param_name`` now has ``new_value``."""
        print(f"Parameter {param_name} changed to {new_value}")

    @staticmethod
    def _split_at_loop_end(samples, end_sample: int):
        """Split ``samples`` along the last axis into (loop, remainder).

        The remainder is ``samples[..., end_sample:]``, so it is empty when
        the loop covers the whole signal (or more). Slicing with a negative
        count (``[..., -n:]``) would instead return the entire signal for
        ``n == 0``.
        """
        return samples[..., :end_sample], samples[..., end_sample:]

    @staticmethod
    def _ensure_wav_suffix(audio_files):
        """Rename every path lacking a ".wav" suffix; return the final paths.

        Each file is handled independently and each distinct path is moved
        only once, so mixed and duplicated entries keep their position in the
        returned list.
        """
        renamed = {}
        fixed = []
        for audio_file in audio_files:
            if audio_file.endswith(".wav"):
                fixed.append(audio_file)
                continue
            if audio_file not in renamed:
                target = f"{audio_file}.wav"
                shutil.move(audio_file, target)
                renamed[audio_file] = target
            fixed.append(renamed[audio_file])
        return fixed

    def vampnet_process(self, address: str, *args):
        """Run one vampnet request and send the resulting files back.

        Positional ``args``, in order: query id, client type, audio file
        path, model choice, periodic prompt, dropout, seed, loop length (ms),
        typical filter, typical mass, typical min tokens, upper codebook
        mask, onset mask width, sampling steps, temperature, top_p, beat mask
        (ms), number of feedback steps.

        The input file is overwritten in place with its mono, loop-length
        prefix; whatever was cut off is appended again to each output file.
        """
        if len(args) < self.NUM_VAMPNET_ARGS:
            msg = (
                f"vampnet request needs {self.NUM_VAMPNET_ARGS} arguments, "
                f"got {len(args)}"
            )
            print(msg)
            self.osc_manager.error(msg)
            return

        (
            query_id,
            _client_type,
            audio_path,
            model_choice,
            periodic_p,
            dropout,
            seed,
            looplength_ms,
            typical_filter,
            typical_mass,
            typical_min_tokens,
            upper_codebook_mask,
            onset_mask_width,
            sampling_steps,
            temperature,
            top_p,
            beat_mask_ms,
            num_feedback_steps,
        ) = args[:self.NUM_VAMPNET_ARGS]
        audio_path = Path(audio_path)

        client = self.clients["vampnet"]

        if not audio_path.exists():
            print(f"File {audio_path} does not exist")
            self.osc_manager.error(f"File {audio_path} does not exist")
            return

        sig = at.AudioSignal(audio_path)
        sig.to_mono()
        # HOT PATCH (FIXME IN MAX: sample rate is being forced to 48k)
        sig.sample_rate = self.SAMPLE_RATE

        # grab the looplength only
        # TODO: although I added this,
        # the max patch is still configured to crop anything past the looplength off
        # so we'll have to change that in order to make an effect.
        end_sample = int((looplength_ms * sig.sample_rate) / 1000)

        # keep the remainder of the waveform so it can be re-attached later
        loop_wav, cut_wav = self._split_at_loop_end(sig.samples, end_sample)
        sig.samples = loop_wav

        # write the file back
        sig.write(audio_path)

        timer.tick("predict")
        print(f"Processing {address} with args {args}")
        job = client.submit(
            input_audio=handle_file(audio_path),
            sampletemp=temperature,
            top_p=top_p,
            periodic_p=periodic_p,
            dropout=dropout,
            stretch_factor=1,
            onset_mask_width=onset_mask_width,
            typical_filtering=bool(typical_filter),
            typical_mass=typical_mass,
            typical_min_tokens=typical_min_tokens,
            seed=seed,
            model_choice=model_choice,
            n_mask_codebooks=upper_codebook_mask,
            pitch_shift_amt=0,
            sample_cutoff=1.0,
            sampling_steps=sampling_steps,
            beat_mask_ms=int(beat_mask_ms),
            num_feedback_steps=num_feedback_steps,
            api_name="/vamp_1"
        )

        # Poll the job, relaying its status code to the peer as progress.
        while not job.done():
            time.sleep(0.1)
            self.osc_manager.client.send_message("/progress", [query_id, str(job.status().code)])

        try:
            result = job.result()
        except Exception as exc:
            # Tell the peer the query failed before letting the error surface.
            self.osc_manager.error(f"query {query_id} failed: {exc}")
            raise

        # The first ``batch_size`` entries of the result are treated as audio
        # file paths; make sure each of them ends in ".wav".
        audio_files = self._ensure_wav_suffix(list(result[:self.batch_size]))
        if not audio_files:
            msg = f"query {query_id} returned no audio files"
            print(msg)
            self.osc_manager.error(msg)
            return

        # De-duplicate so a path listed twice gets the tail appended only once.
        for audio_file in dict.fromkeys(audio_files):
            # load the file, add the cut samples back
            sig = at.AudioSignal(audio_file)
            sig.resample(self.SAMPLE_RATE)
            sig.samples = torch.cat([sig.samples, cut_wav], dim=-1)
            sig.write(audio_file)

        timer.tock("predict")

        # send a message that the process is done
        self.osc_manager.log(f"query {query_id} has been processed")
        self.osc_manager.client.send_message("/process-result", [query_id] + audio_files)

    def process(self, address: str, *args):
        """Dispatch a "/process" message according to its client type.

        ``args[1]`` selects the backend. "vampnet" goes to
        ``vampnet_process``. "sketch2sound" is forwarded to ``process_s2s``
        only if such a method exists on this object (this class does not
        define one).

        Raises:
            ValueError: for an unknown or unsupported client type; the same
                message is also sent to the peer on "/error".
        """
        if len(args) < 2:
            msg = f"{address} needs at least a query id and a client type"
            print(msg)
            self.osc_manager.error(msg)
            return

        client_type = args[1]

        if client_type == "vampnet":
            self.vampnet_process(address, *args)
            return

        if client_type == "sketch2sound":
            handler = getattr(self, "process_s2s", None)
            if handler is not None:
                handler(address, *args)
                return
            msg = f"Client type {client_type} is not supported by this client"
        else:
            msg = f"Unknown client type {client_type}"

        self.osc_manager.error(msg)
        raise ValueError(msg)
def gradio_main(
    vampnet_url: str = None,
    ip: str = "127.0.0.1",
    s_port: int = 8003,
    r_port: int = 8001,
):
    """Build the OSC/Gradio bridge and serve OSC requests until interrupted.

    Parameters
    ----------
    vampnet_url : str, optional
        Source passed to the gradio client for vampnet. Construction of the
        bridge fails if it is left unset.
    ip : str, optional
        Address used for sending OSC replies and for binding the OSC server.
    s_port : int, optional
        UDP port OSC replies are sent to.
    r_port : int, optional
        UDP port OSC requests are received on.
    """
    system = GradioOSCClient(
        vampnet_url=vampnet_url,
        ip=ip, s_port=s_port, r_port=r_port,
    )

    # Blocks here: start_server() runs the OSC server loop.
    system.osc_manager.start_server()
if __name__ == "__main__":
    try:
        # Expose gradio_main's keyword arguments as command-line flags.
        gradio_main = argbind.bind(gradio_main, without_prefix=True)

        args = argbind.parse_args()
        with argbind.scope(args):
            gradio_main()
    except Exception:
        # Drop downloaded files when the program dies with an error, then let
        # the original exception (and its traceback) propagate unchanged.
        # NOTE(review): KeyboardInterrupt is not an Exception, so stopping the
        # server with Ctrl-C leaves DOWNLOADS_DIR in place -- confirm intended.
        shutil.rmtree(DOWNLOADS_DIR, ignore_errors=True)
        raise