File size: 2,898 Bytes
fa98f1c
 
958473e
 
 
fa98f1c
958473e
2bdb6ed
fa98f1c
 
958473e
 
 
 
 
 
2022859
 
 
958473e
fa98f1c
 
c7d9dc5
 
fa98f1c
 
 
 
 
52f268c
fa98f1c
 
 
 
 
 
 
 
 
c7d9dc5
c9d4907
fa98f1c
52f268c
fa98f1c
 
 
c9d4907
958473e
 
 
2bdb6ed
958473e
fa98f1c
2bdb6ed
958473e
c9d4907
2bdb6ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
958473e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import typer
import torch
import subprocess
from pathlib import Path

from expert import UpstreamExpert

# Files that `validate` requires to exist (relative to the current working directory).
SUBMISSION_FILES = ["expert.py", "model.pt"]
# Multiplied by each duration in SECONDS to get the number of samples of a test waveform.
# NOTE(review): presumably a sampling rate in Hz -- confirm against the upstream specification.
SAMPLE_RATE = 16000
# Durations of the random test waveforms built by `validate`; the values differ, so the
# resulting batch contains inputs of unequal length.
SECONDS = [2, 1.8, 3.7]

# Typer application object; functions decorated with `@app.command()` become its sub-commands.
app = typer.Typer()

def _check(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is true.

    Used instead of ``assert`` statements so the checks still run under
    ``python -O`` while the raised exception type stays the same.
    """
    if not condition:
        raise AssertionError(message)


# Maintainer notes for `validate` (kept out of the docstring, which Typer shows as help text):
#   Raises ValueError      -- a file listed in SUBMISSION_FILES is missing.
#   Raises AssertionError  -- the upstream output violates one of the checks below.
#   Any other exception raised by the upstream itself is re-raised after printing the hint.
@app.command()
def validate():
    """Check that the submission files exist and that the upstream output is well-formed."""
    # Check that all the expected files exist
    for file in SUBMISSION_FILES:
        if not Path(file).is_file():
            raise ValueError(f"File {file} not found! Please include {file} in your submission")

    try:
        upstream = UpstreamExpert(ckpt="model.pt")
        # One random waveform per entry of SECONDS, so the batch has unequal lengths.
        samples = [round(SAMPLE_RATE * sec) for sec in SECONDS]
        wavs = [torch.rand(sample) for sample in samples]
        results = upstream(wavs)

        _check(
            isinstance(results, dict),
            f"The upstream must return a dict, got {type(results).__name__}",
        )
        tasks = ["PR", "SID", "ER", "ASR", "ASV", "SD", "QbE", "ST", "SS", "SE", "secret"]
        for task in tasks:
            # Prefer the task-specific entry; only look up the shared "hidden_states"
            # entry when it is actually needed, so its absence is reported clearly
            # instead of surfacing as a bare KeyError.
            if task in results:
                hidden_states = results[task]
            else:
                _check(
                    "hidden_states" in results,
                    f'The returned dict has neither a "{task}" nor a "hidden_states" key',
                )
                hidden_states = results["hidden_states"]

            _check(
                isinstance(hidden_states, list),
                f"[{task}] hidden states must be a list, got {type(hidden_states).__name__}",
            )
            # hidden_states[0] is used as the reference below, so the list must not be empty.
            _check(len(hidden_states) > 0, f"[{task}] the list of hidden states is empty")

            for state in hidden_states:
                _check(
                    isinstance(state, torch.Tensor),
                    f"[{task}] every hidden state must be a torch.Tensor, got {type(state).__name__}",
                )
                _check(
                    state.dim() == 3,
                    f"[{task}] expected a 3-dim tensor "
                    "(batch_size, max_sequence_length_of_batch, hidden_size)",
                )
                _check(
                    state.shape == hidden_states[0].shape,
                    f"[{task}] all hidden states must share the same shape",
                )

            downsample_rate = upstream.get_downsample_rates(task)
            _check(
                isinstance(downsample_rate, int),
                f"[{task}] get_downsample_rates must return an int",
            )
            # Guard the division below against a zero or negative rate.
            _check(downsample_rate > 0, f"[{task}] the downsample rate must be positive")
            # The frame count of the longest waveform must match the declared rate,
            # with a tolerance of fewer than 5 frames.
            expected_frames = round(max(samples) / downsample_rate)
            _check(
                abs(expected_frames - hidden_states[0].size(1)) < 5,
                f"[{task}] wrong downsample rate",
            )

    except Exception:
        # Only ordinary errors get the hint; Ctrl-C / SystemExit propagate untouched.
        typer.echo("Please check the Upstream Specification on https://superbbenchmark.org/challenge-slt2022/upstream")
        raise

    typer.echo("All submission files validated!")
    typer.echo("Now you can upload these files to huggingface's Hub.")


def _run_git_step(*args):
    """Run ``git`` with *args*; stop the command with git's exit code on failure."""
    code = subprocess.call(["git", *args])
    if code != 0:
        typer.echo(f"Command failed (exit code {code}): git {' '.join(args)}", err=True)
        raise typer.Exit(code=code)


# Maintainer notes for `upload` (kept out of the docstring, which Typer shows as help text):
#   commit_message -- text appended to the "Upload Upstream: " commit subject.
#   Raises typer.Exit with a non-zero code as soon as a git step fails, so the
#   success message is printed only when the push really went through.
@app.command()
def upload(commit_message: str):
    """Pull, commit all local changes and push them to the remote repository."""
    _run_git_step("pull", "origin", "main")
    _run_git_step("add", ".")

    # `git diff --cached --quiet` exits with 0 when nothing is staged and 1 when
    # there are staged changes. `git commit` fails when there is nothing to
    # commit, so it is skipped in that case and earlier commits can still be pushed.
    staged = subprocess.call(["git", "diff", "--cached", "--quiet"])
    if staged == 1:
        _run_git_step("commit", "-m", f"Upload Upstream: {commit_message} ")
    elif staged == 0:
        typer.echo("No new changes to commit.")
    else:
        typer.echo(f"Command failed (exit code {staged}): git diff --cached --quiet", err=True)
        raise typer.Exit(code=staged)

    _run_git_step("push")
    typer.echo("Upload successful!")
    typer.echo("Please go to https://superbbenchmark.org/submit to make a submission with the following information:")
    typer.echo("1. Organization Name")
    typer.echo("2. Repository Name")
    typer.echo("3. Commit Hash (full 40 characters)")
    typer.echo("These information can be shown by: python cli.py info")

def _git_stdout(*args):
    """Return the stripped stdout of ``git <args>``; exit with code 1 if git fails."""
    result = subprocess.run(["git", *args], capture_output=True)
    if result.returncode != 0:
        detail = result.stderr.decode("utf-8", errors="replace").strip()
        message = f"Command failed (exit code {result.returncode}): git {' '.join(args)}"
        if detail:
            message = f"{message}\n{detail}"
        typer.echo(message, err=True)
        raise typer.Exit(code=1)
    return result.stdout.decode("utf-8").strip()


def _split_remote_url(url):
    """Return ``(organization, repository)`` taken from the end of a git remote URL.

    Trailing slashes and a trailing ``.git`` are dropped, and ``:`` is treated
    like ``/`` so that scp-style remotes (``git@host:org/repo``) work as well.
    Raises ValueError when fewer than two path components remain.
    """
    trimmed = url.strip().rstrip("/")
    if trimmed.endswith(".git"):
        trimmed = trimmed[: -len(".git")]
    parts = [part for part in trimmed.replace(":", "/").split("/") if part]
    if len(parts) < 2:
        raise ValueError(f"Cannot find an organization and repository name in remote URL {url!r}")
    return parts[-2], parts[-1]


# Maintainer notes for `info` (kept out of the docstring, which Typer shows as help text):
#   Reads `remote.origin.url` and the HEAD commit from git.
#   Raises typer.Exit(1) when git fails or the remote URL cannot be split.
@app.command()
def info():
    """Show the organization name, repository name and commit hash for the submission."""
    url = _git_stdout("config", "--get", "remote.origin.url")
    try:
        organization, repo = _split_remote_url(url)
    except ValueError as error:
        typer.echo(str(error), err=True)
        raise typer.Exit(code=1)

    commit_hash = _git_stdout("rev-parse", "HEAD")

    typer.echo(f"Organization Name: {organization}")
    typer.echo(f"Repository Name: {repo}")
    typer.echo(f"Commit Hash: {commit_hash}")

# Dispatch to the Typer application only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    app()