similar-cards / app /download_models.py
uripper
it needs to be able to see a function named main
f490dc1
import os
from huggingface_hub import hf_hub_download
# download_models.py
import os
import shutil

from huggingface_hub import hf_hub_download
def create_directory(dir_path):
    """
    Ensure that ``dir_path`` exists as a directory.

    Missing parent directories are created too. A status line is printed
    saying whether the directory was created or was already present.

    Args:
        dir_path: Path of the directory to create.

    Raises:
        FileExistsError: If ``dir_path`` exists but is not a directory.
    """
    try:
        # Attempt the creation directly rather than testing for existence
        # first: a directory that appears between a check and the call can
        # then no longer make this function fail.
        os.makedirs(dir_path)
    except FileExistsError:
        if not os.path.isdir(dir_path):
            # The path is taken by something that is not a directory;
            # reporting "already exists" would hide the problem.
            raise
        print(f"Directory {dir_path} already exists.")
    else:
        print(f"Directory {dir_path} created.")
def download_model(repo_id, filename, destination_dir, token):
    """
    Download a file from the Hugging Face Hub into ``destination_dir``.

    The file is fetched with ``hf_hub_download`` and then copied to
    ``destination_dir/filename``; the Hub cache entry is left in place.

    Args:
        repo_id: Hub repository identifier, e.g. ``"user/repo"``.
        filename: Name of the file inside the repository.
        destination_dir: Directory that receives the file; created when
            missing.
        token: Hugging Face access token used for the download.

    Returns:
        The absolute path of the file inside ``destination_dir``.
    """
    # Create destination directory if it doesn't exist
    create_directory(destination_dir)

    # ``token`` replaces the deprecated ``use_auth_token`` keyword.
    model_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)

    destination_path = os.path.join(destination_dir, filename)
    # ``filename`` may contain a sub-folder of the repository.
    os.makedirs(os.path.dirname(destination_path), exist_ok=True)

    # The returned path lives in the Hub cache, where (per the
    # huggingface_hub cache documentation) it is normally a relative symlink
    # to a blob file. Renaming that entry would leave a dangling link at the
    # destination and fails across filesystems, so copy the resolved file.
    # Staging the copy and swapping it in with os.replace also overwrites a
    # stale file or broken link left at the destination by an earlier run.
    staging_path = destination_path + ".part"
    shutil.copyfile(os.path.realpath(model_path), staging_path)
    os.replace(staging_path, destination_path)

    # Print the full path of the downloaded file
    full_path = os.path.abspath(destination_path)
    print(f"Downloaded {filename} and moved to: {full_path}")
    return full_path
def main():
    """
    Download the graph files used by the application.

    Reads the Hugging Face token from the ``HF_TOKEN`` environment variable,
    then downloads ``card_graph.pkl`` from ``uripper/magic-gnn`` into
    ``./models/magic-gnn`` and ``oracle_cards_graph.pkl`` from
    ``uripper/oracle_cards_graph`` into ``./data/oracle_cards_graph``.
    Both destination paths are relative to the current working directory.
    The resulting absolute paths are printed at the end.

    Raises:
        ValueError: If ``HF_TOKEN`` is unset or empty.
    """
    hf_token = os.getenv("HF_TOKEN")
    # An empty value is as unusable as a missing one, so reject both here
    # instead of forwarding a blank credential to the download call.
    if not hf_token:
        raise ValueError(
            "Hugging Face token is missing. Please set the 'HF_TOKEN' environment variable."
        )

    # Download the card_graph.pkl from uripper/magic-gnn
    card_graph_path = download_model(
        "uripper/magic-gnn",
        "card_graph.pkl",
        "./models/magic-gnn",
        hf_token,
    )

    # Download the oracle_cards_graph.pkl from uripper/oracle_cards_graph
    oracle_cards_graph_path = download_model(
        "uripper/oracle_cards_graph",
        "oracle_cards_graph.pkl",
        "./data/oracle_cards_graph",
        hf_token,
    )

    # Print final message with full paths
    print("\nDownload complete:")
    print(f"Card Graph: {card_graph_path}")
    print(f"Oracle Cards Graph: {oracle_cards_graph_path}")
# Run the downloads only when this file is executed as a script; importing
# the module just makes ``main`` available to the caller.
if __name__ == "__main__":
    main()