|
import os |
|
import requests |
|
import argparse |
|
from huggingface_hub import hf_hub_download |
|
from tqdm import tqdm |
|
|
|
def _remove_quietly(path):
    """Delete ``path`` if it exists, ignoring filesystem errors."""
    try:
        os.remove(path)
    except OSError:
        pass


def _stream_to_file(response, save_path, total_size, chunk_size):
    """Write the streamed body of ``response`` to ``save_path``; return the byte count.

    If streaming is interrupted (network error, Ctrl-C, disk error), the
    partially written file is deleted before the exception propagates.
    """
    bytes_written = 0
    # Opened outside the try so a failed open() never deletes a pre-existing file.
    file = open(save_path, "wb")
    try:
        # total=None makes tqdm show a plain byte counter when the size is unknown.
        with file, tqdm(total=total_size or None, unit="B", unit_scale=True) as progress_bar:
            for data in response.iter_content(chunk_size):
                file.write(data)
                bytes_written += len(data)
                progress_bar.update(len(data))
    except BaseException:
        # The file is already closed by the ``with`` at this point.
        _remove_quietly(save_path)
        raise
    return bytes_written


def download_from_url(url, save_path, chunk_size=1024 * 1024, timeout=30):
    """Download a file from a given URL and save it locally.

    The response body is streamed to ``save_path`` in chunks while a tqdm
    progress bar reports the bytes received.

    Args:
        url: HTTP(S) URL to fetch.
        save_path: Destination file path (including the filename). Its parent
            directory must already exist.
        chunk_size: Number of bytes requested per streamed chunk.
        timeout: Seconds ``requests`` waits for the connection and for each
            read before giving up.

    Returns:
        True if the file was downloaded completely, False otherwise. On a
        failed download the partially written file is removed so a truncated
        checkpoint is never left at ``save_path``.
    """
    try:
        # ``with`` returns the connection to the pool even if streaming aborts.
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this, an HTTP error page (404, 403, ...) would be written
            # to disk and reported as a successful download.
            response.raise_for_status()

            # Content-Length counts the bytes on the wire, while iter_content()
            # yields decoded bytes; the header is only a usable total when the
            # body is not content-encoded (e.g. gzip).
            encoding = response.headers.get("content-encoding", "identity").lower()
            if encoding == "identity":
                total_size = int(response.headers.get("content-length", 0))
            else:
                total_size = 0

            bytes_written = _stream_to_file(response, save_path, total_size, chunk_size)
    except requests.RequestException as e:
        print(f"Download failed: {e}")
        return False

    if total_size != 0 and bytes_written != total_size:
        print("Download failed!")
        _remove_quietly(save_path)
        return False

    print(f"File downloaded to: {save_path}")
    return True
|
|
|
def download_from_hf(repo_id, filename, save_path):
    """Download a file from Hugging Face Hub and place it at ``save_path``.

    ``hf_hub_download`` with ``local_dir`` stores the file under its
    repository-relative name inside ``local_dir``. When that differs from
    ``save_path`` (another basename, or ``filename`` containing sub-folders),
    the file is moved to ``save_path`` afterwards.

    Args:
        repo_id: Hugging Face repository ID, e.g. ``google/bert-base-uncased``.
        filename: Path of the file inside the repository.
        save_path: Destination file path (including the filename).

    Returns:
        True on success. On failure, the error is printed and False is returned.
    """
    print(f"Downloading from Hugging Face Hub: {repo_id}/{filename}")
    # dirname() is "" for a bare file name; download into the current directory.
    local_dir = os.path.dirname(save_path) or "."
    try:
        # local_dir_use_symlinks=False asks older huggingface_hub releases to
        # write a real file rather than a symlink into the cache (the argument
        # is deprecated in newer releases).
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )
        # Move the file if the hub's naming does not match the requested path.
        if os.path.abspath(downloaded_path) != os.path.abspath(save_path):
            os.replace(downloaded_path, save_path)
    except Exception as e:  # CLI boundary: report hub/network/filesystem errors
        print(f"Download failed: {e}")
        return False

    print(f"File downloaded to: {save_path}")
    return True
|
|
|
def main():
    """Parse command-line arguments and download the requested checkpoint.

    Exits with status 2 (via ``argparse``) when the arguments required by the
    chosen ``--source`` are missing. Exits with status 1 when a download helper
    explicitly returns ``False``.
    """
    parser = argparse.ArgumentParser(description="Automatically download model checkpoints")
    parser.add_argument("--source", type=str, required=True, choices=["hf", "url"], help="Download source: hf (Hugging Face Hub) or url (custom URL)")
    parser.add_argument("--repo_id", type=str, help="Hugging Face model repository ID (e.g., google/bert-base-uncased)")
    parser.add_argument("--filename", type=str, help="Filename in the Hugging Face repository")
    parser.add_argument("--url", type=str, help="Custom download URL")
    parser.add_argument("--save_path", type=str, required=True, help="Path to save the file (including filename)")
    args = parser.parse_args()

    # Validate source-specific arguments before touching the filesystem.
    # parser.error() prints usage to stderr and exits with status 2.
    if args.source == "hf" and not (args.repo_id and args.filename):
        parser.error("Please provide a Hugging Face repository ID and filename!")
    if args.source == "url" and not args.url:
        parser.error("Please provide a download URL!")

    # dirname() is "" for a bare file name, and os.makedirs("") raises.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    if args.source == "hf":
        succeeded = download_from_hf(args.repo_id, args.filename, args.save_path)
    else:  # "url": argparse restricts --source to these two choices
        succeeded = download_from_url(args.url, args.save_path)

    # Only an explicit False counts as failure. Any other return value
    # (e.g. None) is treated as success.
    if succeeded is False:
        raise SystemExit(1)
|
|
|
if __name__ == "__main__":
    # Script entry point: run the CLI and hand main()'s return value to the
    # interpreter as the exit status (None means success).
    raise SystemExit(main())