Spaces:
Paused
Paused
# import gdown | |
# import os | |
# import argparse | |
# def download_model(model_id, folder, filename): | |
# os.makedirs(folder, exist_ok=True) | |
# url = f"https://drive.google.com/uc?id={model_id}" | |
# output_path = os.path.join(folder, filename) | |
# print(f"Downloading model to {output_path}...") | |
# gdown.download(url, output_path, quiet=False) | |
# print("Download complete!") | |
# def main(): | |
# parser = argparse.ArgumentParser(description="Download models using gdown and organize them into appropriate folders.") | |
# parser.add_argument("-P", "--pretrained", action="store_true", help="Download the pretrained model") | |
# parser.add_argument("-F", "--sft", action="store_true", help="Download the fine-tuned model") | |
# parser.add_argument("-D", "--dpo", action="store_true", help="Download the DPO model") | |
# args = parser.parse_args() | |
# pretrained_model_file_id = "1CwtDjbN6a7tt7mykywxAANHBTvdSr-98" | |
# fine_tuned_model_id = "10bsea7_MFXw6T967iCrp6zSGMfqDljHf" | |
# dpo_model_file_id = "1hIzV_VVdvmplQQuaH9QQCcmUbfolFjyh" | |
# if args.pretrained: | |
# download_model(pretrained_model_file_id, "weights/pretrained", "pretrained_model.pt") | |
# if args.sft: | |
# download_model(fine_tuned_model_id, "weights/fine_tuned", "fine_tuned_model.pt") | |
# if args.dpo: | |
# download_model(dpo_model_file_id, "weights/DPO", "dpo_model.pt") | |
# if __name__ == "__main__": | |
# main() | |
# import os | |
# import argparse | |
# def download_model(model_id, folder, filename, access_token): | |
# os.makedirs(folder, exist_ok=True) | |
# output_path = os.path.join(folder, filename) | |
# url = f"https://www.googleapis.com/drive/v3/files/{model_id}?alt=media" | |
# command = f"curl -H \"Authorization: Bearer {access_token}\" {url} -o {output_path}" | |
# print(f"Downloading model to {output_path}...") | |
# os.system(command) | |
# print("Download complete!") | |
# def main(): | |
# parser = argparse.ArgumentParser(description="Download models using Google Drive API and organize them into appropriate folders.") | |
# parser.add_argument("-P", "--pretrained", action="store_true", help="Download the pretrained model") | |
# parser.add_argument("-F", "--sft", action="store_true", help="Download the fine-tuned model") | |
# parser.add_argument("-D", "--dpo", action="store_true", help="Download the DPO model") | |
# parser.add_argument("--token", type=str, required=True, help="Google Drive API Access Token") | |
# args = parser.parse_args() | |
# pretrained_model_file_id = "1CwtDjbN6a7tt7mykywxAANHBTvdSr-98" | |
# fine_tuned_model_id = "10bsea7_MFXw6T967iCrp6zSGMfqDljHf" | |
# dpo_model_file_id = "1hIzV_VVdvmplQQuaH9QQCcmUbfolFjyh" | |
# if args.pretrained: | |
# download_model(pretrained_model_file_id, "weights/pretrained", "pretrained_model.pt", args.token) | |
# if args.sft: | |
# download_model(fine_tuned_model_id, "weights/fine_tuned", "fine_tuned_model.pt", args.token) | |
# if args.dpo: | |
# download_model(dpo_model_file_id, "weights/DPO", "dpo_model.pt", args.token) | |
# if __name__ == "__main__": | |
# main() | |
# download_model_weight.py
import os | |
import argparse | |
from huggingface_hub import hf_hub_download, login | |
def download_model(repo_id, filename, cache_dir, min_size_bytes=1024 * 1024):
    """Download one file from a Hugging Face Hub repo and sanity-check it.

    Args:
        repo_id: Hub repository identifier, passed to ``hf_hub_download``.
        filename: Name of the file to fetch from the repository.
        cache_dir: Directory handed to ``hf_hub_download`` as ``cache_dir``.
        min_size_bytes: The downloaded file must be strictly larger than
            this many bytes to be accepted. Defaults to 1 MiB.

    Returns:
        The local path returned by ``hf_hub_download``.

    Raises:
        ValueError: If the returned path does not exist or the file is not
            larger than ``min_size_bytes``.
        Exception: Any error raised by ``hf_hub_download`` is printed and
            re-raised unchanged.
    """
    try:
        # NOTE: ``resume_download`` is intentionally not passed. It is
        # deprecated in huggingface_hub (interrupted downloads resume by
        # default) and newer releases may reject it as an unknown argument.
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=cache_dir,
            force_download=False,
            token=os.getenv("HF_TOKEN"),
        )
        # A tiny file is treated as a failed download rather than a real
        # checkpoint.
        if os.path.exists(model_path) and os.path.getsize(model_path) > min_size_bytes:
            return model_path
        raise ValueError("Downloaded file is too small or invalid")
    except Exception as e:
        # Report the failure to the console, then let the caller decide.
        print(f"Download failed: {str(e)}")
        raise
def main():
    """Parse ``--model_type`` and download the matching model file.

    Looks up the repository, filename and cache directory for the chosen
    model type, makes sure the cache directory exists, and delegates the
    download to ``download_model``, printing progress before and after.
    """
    # Download settings keyed by the value accepted for --model_type.
    known_models = {
        "pretrained": {
            "repo_id": "YuvrajSingh9886/StoryLlama",
            "filename": "snapshot_4650.pt",
            "cache_dir": "weights/pretrained",
        },
    }

    cli = argparse.ArgumentParser(description="Download models from Hugging Face Hub")
    cli.add_argument(
        "--model_type",
        choices=["pretrained"],
        required=True,
        help="Type of model to download",
    )
    model_type = cli.parse_args().model_type

    selected = known_models[model_type]
    target_dir = selected["cache_dir"]
    os.makedirs(target_dir, exist_ok=True)

    print(f"Downloading {model_type} model...")
    local_path = download_model(selected["repo_id"], selected["filename"], target_dir)
    print(f"Successfully downloaded to: {local_path}")
if __name__ == "__main__":
    # Authenticate only when a token is configured. Calling login() with
    # token=None falls back to an interactive prompt, which hangs or fails
    # when the script runs non-interactively without HF_TOKEN set.
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    main()