File size: 2,898 Bytes
fa98f1c 958473e fa98f1c 958473e 2bdb6ed fa98f1c 958473e 2022859 958473e fa98f1c c7d9dc5 fa98f1c 52f268c fa98f1c c7d9dc5 c9d4907 fa98f1c 52f268c fa98f1c c9d4907 958473e 2bdb6ed 958473e fa98f1c 2bdb6ed 958473e c9d4907 2bdb6ed 958473e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import typer
import torch
import subprocess
from pathlib import Path
from expert import UpstreamExpert
# Files that must exist in the current working directory for a submission to be valid.
SUBMISSION_FILES = ["expert.py", "model.pt"]

# Samples per second of the random waveforms that ``validate`` feeds to the upstream.
SAMPLE_RATE = 16_000

# Durations (seconds) of the test waveforms batched together in ``validate``;
# presumably chosen unequal so the upstream has to handle padding — confirm.
SECONDS = [2, 1.8, 3.7]

# Typer application; every ``@app.command()`` function becomes a CLI sub-command.
app = typer.Typer()
def _require(condition, message):
    """Raise ``AssertionError(message)`` when *condition* is falsy.

    Behaves like ``assert condition, message`` but, unlike an ``assert``
    statement, is not removed when Python runs with ``-O``, so the validator
    can never silently report success under optimization.
    """
    if not condition:
        raise AssertionError(message)


@app.command()
def validate():
    """Check the submission files and that ``UpstreamExpert`` follows the upstream spec.

    Verifies that every file in ``SUBMISSION_FILES`` exists, then builds
    ``UpstreamExpert(ckpt="model.pt")``, feeds it a batch of random waveforms
    (``SECONDS`` long at ``SAMPLE_RATE`` samples per second) and, for every task
    key, checks that:

    * the result is a dict;
    * the task's entry (falling back to ``"hidden_states"``) is a non-empty list
      of 3-D tensors that all share one shape;
    * ``get_downsample_rates(task)`` is a positive int consistent (within 5
      frames) with the sequence length of the longest waveform.

    Raises:
        ValueError: a file listed in ``SUBMISSION_FILES`` is missing.
        AssertionError: the upstream output violates one of the checks above.
        Exception: anything raised while building or running the upstream is
            re-raised unchanged after printing a pointer to the specification.
    """
    # Check that all the expected files exist
    for file in SUBMISSION_FILES:
        if not Path(file).is_file():
            raise ValueError(f"File {file} not found! Please include {file} in your submission")

    try:
        upstream = UpstreamExpert(ckpt="model.pt")
        samples = [round(SAMPLE_RATE * sec) for sec in SECONDS]
        wavs = [torch.rand(sample) for sample in samples]
        results = upstream(wavs)
        _require(isinstance(results, dict), "upstream(wavs) must return a dict")

        tasks = ["PR", "SID", "ER", "ASR", "ASV", "SD", "QbE", "ST", "SS", "SE", "secret"]
        for task in tasks:
            # A task-specific key overrides "hidden_states". The fallback is evaluated
            # eagerly, so "hidden_states" must be present even when every task key is.
            hidden_states = results.get(task, results["hidden_states"])
            _require(
                isinstance(hidden_states, list),
                f"hidden states for task {task!r} must be a list of tensors",
            )
            # Guard the hidden_states[0] accesses below against an empty list.
            _require(
                len(hidden_states) > 0,
                f"hidden states for task {task!r} must contain at least one tensor",
            )
            for state in hidden_states:
                _require(
                    isinstance(state, torch.Tensor),
                    f"every hidden state for task {task!r} must be a torch.Tensor",
                )
                _require(state.dim() == 3, "(batch_size, max_sequence_length_of_batch, hidden_size)")
                _require(
                    state.shape == hidden_states[0].shape,
                    f"all hidden states for task {task!r} must share the same shape",
                )

            downsample_rate = upstream.get_downsample_rates(task)
            _require(
                isinstance(downsample_rate, int),
                f"get_downsample_rates({task!r}) must return an int",
            )
            # Checked before dividing so a zero rate gets a clear message, not ZeroDivisionError.
            _require(downsample_rate > 0, f"get_downsample_rates({task!r}) must be positive")
            # The padded sequence length should match the longest input divided by the
            # downsample rate; a tolerance of 5 frames allows for convolution edge effects.
            _require(
                abs(round(max(samples) / downsample_rate) - hidden_states[0].size(1)) < 5,
                "wrong downsample rate",
            )
    except Exception:
        # Not a bare except: Ctrl-C / SystemExit should not print the spec hint.
        typer.echo("Please check the Upstream Specification on https://superbbenchmark.org/challenge-slt2022/upstream")
        raise

    typer.echo("All submission files validated!")
    typer.echo("Now you can upload these files to huggingface's Hub.")
@app.command()
def upload(commit_message: str):
    """Commit the working tree and push it to the ``origin`` remote.

    Runs ``git pull origin main``, ``git add .``, ``git commit`` and ``git push``
    in that order. If the pull, add or push step exits non-zero, the command
    stops with exit code 1 instead of reporting success. A failing commit is
    tolerated, because git exits non-zero when there is nothing new to commit
    while earlier local commits may still need to be pushed.

    Args:
        commit_message: Text appended to the ``"Upload Upstream: "`` commit-message prefix.

    Raises:
        typer.Exit: a required git step failed.
    """
    steps = [
        # (command, abort_on_failure)
        (["git", "pull", "origin", "main"], True),
        (["git", "add", "."], True),
        (["git", "commit", "-m", f"Upload Upstream: {commit_message} "], False),
        (["git", "push"], True),
    ]
    for command, abort_on_failure in steps:
        returncode = subprocess.call(command)
        if returncode != 0 and abort_on_failure:
            shown = " ".join(command)
            typer.echo(f"`{shown}` failed with exit code {returncode}; upload aborted.", err=True)
            raise typer.Exit(code=1)

    typer.echo("Upload successful!")
    typer.echo("Please go to https://superbbenchmark.org/submit to make a submission with the following information:")
    typer.echo("1. Organization Name")
    typer.echo("2. Repository Name")
    typer.echo("3. Commit Hash (full 40 characters)")
    typer.echo("These information can be shown by: python cli.py info")
def _split_remote_url(url):
    """Return ``(organization, repository)`` parsed from a git remote URL.

    Accepts HTTPS remotes (``https://huggingface.co/org/repo``) and scp-style
    SSH remotes (``git@hf.co:org/repo``). Surrounding whitespace, a trailing
    ``/`` and a trailing ``.git`` are ignored.

    Raises:
        ValueError: *url* does not contain at least two path components.
    """
    path = url.strip().rstrip("/")
    if path.endswith(".git"):
        path = path[: -len(".git")]
    # scp-style SSH remotes separate host and path with ":" rather than "/".
    parts = [part for part in path.replace(":", "/").split("/") if part]
    if len(parts) < 2:
        raise ValueError(f"Cannot parse organization/repository from remote URL {url!r}")
    return parts[-2], parts[-1]


@app.command()
def info():
    """Print the organization, repository name and HEAD commit hash for the submission form.

    The organization and repository come from ``remote.origin.url``; the commit
    hash is the output of ``git rev-parse HEAD``.

    Raises:
        typer.Exit: the remote URL or the HEAD commit cannot be determined.
    """
    remote = subprocess.run(
        ["git", "config", "--get", "remote.origin.url"],
        capture_output=True,
        encoding="utf-8",
    )
    url = remote.stdout.strip()
    if remote.returncode != 0 or not url:
        typer.echo("Could not read remote.origin.url; run this inside your cloned submission repository.", err=True)
        raise typer.Exit(code=1)
    try:
        organization, repo = _split_remote_url(url)
    except ValueError as err:
        typer.echo(str(err), err=True)
        raise typer.Exit(code=1)

    head = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, encoding="utf-8")
    commit_hash = head.stdout.strip()
    if head.returncode != 0 or not commit_hash:
        typer.echo("Could not resolve HEAD; make sure the repository has at least one commit.", err=True)
        raise typer.Exit(code=1)

    typer.echo(f"Organization Name: {organization}")
    typer.echo(f"Repository Name: {repo}")
    typer.echo(f"Commit Hash: {commit_hash}")
# Script entry point: run the Typer app so ``python cli.py <command>`` dispatches
# to the validate / upload / info commands registered above.
if __name__ == "__main__":
    app()
|