File size: 2,632 Bytes
ccbd581
 
 
d78b29e
f490dc1
 
 
 
 
ccbd581
 
 
 
 
 
 
 
 
 
 
 
d78b29e
ccbd581
 
 
 
 
 
d78b29e
 
 
 
ccbd581
 
 
 
 
 
 
 
 
 
 
f490dc1
 
 
 
 
 
 
 
 
ccbd581
 
 
 
 
 
 
 
 
 
 
 
 
d78b29e
 
 
ccbd581
 
 
d78b29e
 
 
 
ccbd581
 
 
 
 
 
f490dc1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
from huggingface_hub import hf_hub_download


# download_models.py

import os
import shutil
from huggingface_hub import hf_hub_download


def create_directory(dir_path):
    """
    Create ``dir_path`` (including any missing parents) if it doesn't exist.

    Prints whether the directory was created or was already present.
    Creation is attempted directly instead of checking for existence first,
    so a directory created concurrently between the check and the
    ``os.makedirs`` call is reported as existing rather than crashing.

    Args:
        dir_path: Path of the directory to create.
    """
    try:
        os.makedirs(dir_path)
    except FileExistsError:
        # Also raised when dir_path exists but is not a directory; that case
        # is reported as "already exists" too.
        print(f"Directory {dir_path} already exists.")
    else:
        print(f"Directory {dir_path} created.")


def download_model(repo_id, filename, destination_dir, token):
    """
    Download a file from the Hugging Face Hub and copy it into destination_dir.

    The file is first fetched into the local Hugging Face cache by
    ``hf_hub_download`` and then copied (not moved) to
    ``destination_dir/filename``. The cached path is typically a symlink into
    the cache's blob store, so renaming it would leave a dangling link at the
    destination and remove the cache entry; a rename also fails when the
    cache and the destination are on different filesystems.

    Args:
        repo_id: Hub repository id, e.g. "uripper/magic-gnn".
        filename: Path of the file inside the repository.
        destination_dir: Local directory to place the file in (created if
            missing).
        token: Hugging Face access token passed to the download.

    Returns:
        The absolute path of the copied file.
    """
    # Create destination directory if it doesn't exist
    create_directory(destination_dir)

    # Download into the HF cache; `token` replaces the deprecated
    # `use_auth_token` argument.
    cached_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)

    destination_path = os.path.join(destination_dir, filename)
    # `filename` may include sub-directories within the repo; make sure the
    # matching local parent directories exist before copying.
    os.makedirs(os.path.dirname(destination_path), exist_ok=True)

    # copyfile follows symlinks, so the real contents are written to the
    # destination and the cache is left intact; it also works across devices.
    shutil.copyfile(cached_path, destination_path)

    # Print the full path of the copied file
    full_path = os.path.abspath(destination_path)
    print(f"Downloaded {filename} and copied to: {full_path}")
    return full_path


def main():
    """
    Download the card graph and oracle cards graph pickles from the Hub.

    Reads the access token from the ``HF_TOKEN`` environment variable,
    downloads each file into its local destination directory, and prints a
    summary of the resulting absolute paths.

    Raises:
        ValueError: If ``HF_TOKEN`` is unset or empty.
    """
    # Get the Hugging Face token from the environment (e.g. provided from a secret).
    token = os.getenv("HF_TOKEN")

    # An empty HF_TOKEN is treated the same as a missing one.
    if not token:
        raise ValueError(
            "Hugging Face token is missing. Please set the 'HF_TOKEN' environment variable."
        )

    # (summary label, repository id, file in repository, local destination dir)
    downloads = [
        ("Card Graph", "uripper/magic-gnn", "card_graph.pkl", "./models/magic-gnn"),
        (
            "Oracle Cards Graph",
            "uripper/oracle_cards_graph",
            "oracle_cards_graph.pkl",
            "./data/oracle_cards_graph",
        ),
    ]

    # Downloads run in table order; each one prints its own progress line.
    results = [
        (label, download_model(repo_id, filename, dest_dir, token))
        for label, repo_id, filename, dest_dir in downloads
    ]

    # Print final message with full paths
    print("\nDownload complete:")
    for label, path in results:
        print(f"{label}: {path}")


if __name__ == "__main__":
    # Only run the downloads when executed as a script, not when imported.
    main()