|  | from pathlib import Path | 
					
						
						|  |  | 
					
						
						|  | import argbind | 
					
						
						|  | from audiotools import ml | 
					
						
						|  |  | 
					
						
						|  | import dac | 
					
						
						|  |  | 
					
						
						|  | DAC = dac.model.DAC | 
					
						
						|  | Accelerator = ml.Accelerator | 
					
						
						|  |  | 
					
						
						|  | __MODEL_LATEST_TAGS__ = { | 
					
						
						|  | ("44khz", "8kbps"): "0.0.1", | 
					
						
						|  | ("24khz", "8kbps"): "0.0.4", | 
					
						
						|  | ("16khz", "8kbps"): "0.0.5", | 
					
						
						|  | ("44khz", "16kbps"): "1.0.0", | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | __MODEL_URLS__ = { | 
					
						
						|  | ( | 
					
						
						|  | "44khz", | 
					
						
						|  | "0.0.1", | 
					
						
						|  | "8kbps", | 
					
						
						|  | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", | 
					
						
						|  | ( | 
					
						
						|  | "24khz", | 
					
						
						|  | "0.0.4", | 
					
						
						|  | "8kbps", | 
					
						
						|  | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", | 
					
						
						|  | ( | 
					
						
						|  | "16khz", | 
					
						
						|  | "0.0.5", | 
					
						
						|  | "8kbps", | 
					
						
						|  | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", | 
					
						
						|  | ( | 
					
						
						|  | "44khz", | 
					
						
						|  | "1.0.0", | 
					
						
						|  | "16kbps", | 
					
						
						|  | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @argbind.bind(group="download", positional=True, without_prefix=True) | 
					
						
						|  | def download( | 
					
						
						|  | model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" | 
					
						
						|  | ): | 
					
						
						|  | """ | 
					
						
						|  | Function that downloads the weights file from URL if a local cache is not found. | 
					
						
						|  |  | 
					
						
						|  | Parameters | 
					
						
						|  | ---------- | 
					
						
						|  | model_type : str | 
					
						
						|  | The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". | 
					
						
						|  | model_bitrate: str | 
					
						
						|  | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". | 
					
						
						|  | Only 44khz model supports 16kbps. | 
					
						
						|  | tag : str | 
					
						
						|  | The tag of the model to download. Defaults to "latest". | 
					
						
						|  |  | 
					
						
						|  | Returns | 
					
						
						|  | ------- | 
					
						
						|  | Path | 
					
						
						|  | Directory path required to load model via audiotools. | 
					
						
						|  | """ | 
					
						
						|  | model_type = model_type.lower() | 
					
						
						|  | tag = tag.lower() | 
					
						
						|  |  | 
					
						
						|  | assert model_type in [ | 
					
						
						|  | "44khz", | 
					
						
						|  | "24khz", | 
					
						
						|  | "16khz", | 
					
						
						|  | ], "model_type must be one of '44khz', '24khz', or '16khz'" | 
					
						
						|  |  | 
					
						
						|  | assert model_bitrate in [ | 
					
						
						|  | "8kbps", | 
					
						
						|  | "16kbps", | 
					
						
						|  | ], "model_bitrate must be one of '8kbps', or '16kbps'" | 
					
						
						|  |  | 
					
						
						|  | if tag == "latest": | 
					
						
						|  | tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] | 
					
						
						|  |  | 
					
						
						|  | download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) | 
					
						
						|  |  | 
					
						
						|  | if download_link is None: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"Could not find model with tag {tag} and model type {model_type}" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | local_path = ( | 
					
						
						|  | Path.home() | 
					
						
						|  | / ".cache" | 
					
						
						|  | / "descript" | 
					
						
						|  | / "dac" | 
					
						
						|  | / f"weights_{model_type}_{model_bitrate}_{tag}.pth" | 
					
						
						|  | ) | 
					
						
						|  | if not local_path.exists(): | 
					
						
						|  | local_path.parent.mkdir(parents=True, exist_ok=True) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import requests | 
					
						
						|  |  | 
					
						
						|  | response = requests.get(download_link) | 
					
						
						|  |  | 
					
						
						|  | if response.status_code != 200: | 
					
						
						|  | raise ValueError( | 
					
						
						|  | f"Could not download model. Received response code {response.status_code}" | 
					
						
						|  | ) | 
					
						
						|  | local_path.write_bytes(response.content) | 
					
						
						|  |  | 
					
						
						|  | return local_path | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def load_model( | 
					
						
						|  | model_type: str = "44khz", | 
					
						
						|  | model_bitrate: str = "8kbps", | 
					
						
						|  | tag: str = "latest", | 
					
						
						|  | load_path: str = None, | 
					
						
						|  | ): | 
					
						
						|  | if not load_path: | 
					
						
						|  | load_path = download( | 
					
						
						|  | model_type=model_type, model_bitrate=model_bitrate, tag=tag | 
					
						
						|  | ) | 
					
						
						|  | generator = DAC.load(load_path) | 
					
						
						|  | return generator | 
					
						
						|  |  |