zouzx committed
Commit a870321 · 0 parents

init commit
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ gradio_splatting/frontend/node_modules/@esbuild/linux-x64/bin/esbuild filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,83 @@
+ FROM nvidia/cuda:11.3.1-devel-ubuntu20.04
+
+ ARG DEBIAN_FRONTEND=noninteractive
+
+ ENV PYTHONUNBUFFERED=1
+
+ ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6"
+ ENV TCNN_CUDA_ARCHITECTURES=86;80;75;70;61;60
+ ENV FORCE_CUDA=1
+
+ ENV CUDA_HOME=/usr/local/cuda
+ ENV PATH=${CUDA_HOME}/bin:/home/${USER_NAME}/.local/bin:${PATH}
+ ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
+ ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH}
+
+ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+     build-essential \
+     curl \
+     git \
+     libegl1-mesa-dev \
+     libgl1-mesa-dev \
+     libgles2-mesa-dev \
+     libglib2.0-0 \
+     libsm6 \
+     libxext6 \
+     libxrender1 \
+     python-is-python3 \
+     python3-dev \
+     python3-pip \
+     wget \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set up a new user named "user" with user ID 1000
+ RUN useradd -m -u 1000 user
+ # Switch to the "user" user
+ USER user
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH \
+     PYTHONPATH=$HOME/app \
+     PYTHONUNBUFFERED=1 \
+     GRADIO_ALLOW_FLAGGING=never \
+     GRADIO_NUM_PORTS=1 \
+     GRADIO_SERVER_NAME=0.0.0.0 \
+     GRADIO_THEME=huggingface \
+     SYSTEM=spaces
+
+ RUN pip install --upgrade pip setuptools ninja
+ RUN pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
+
+ RUN python -c "import torch; print(torch.version.cuda)"
+ COPY requirements.txt /tmp
+ RUN cd /tmp && pip install -r requirements.txt
+
+ # install pointnet2_ops from snowflake
+ RUN git clone https://github.com/AllenXiangX/SnowflakeNet.git /home/user/SnowflakeNet
+ WORKDIR /home/user/SnowflakeNet/models/pointnet2_ops_lib
+ RUN python setup.py install --user
+
+ # install pytorch3d
+ RUN git clone -b v0.7.3 https://github.com/facebookresearch/pytorch3d.git /home/user/pytorch3d-0.7.3
+ WORKDIR /home/user/pytorch3d-0.7.3
+ RUN python setup.py install --user
+
+ # install torch-scatter
+ RUN git clone https://github.com/rusty1s/pytorch_scatter.git /home/user/pytorch_scatter
+ WORKDIR /home/user/pytorch_scatter
+ RUN python setup.py install --user
+
+ # install diff-gaussian-rasterization
+ RUN git clone --recursive https://github.com/graphdeco-inria/diff-gaussian-rasterization.git /home/user/diff-gaussian-rasterization
+ WORKDIR /home/user/diff-gaussian-rasterization
+ RUN python setup.py install --user
+
+ # Set the working directory to the user's home directory
+ WORKDIR $HOME/app
+
+ # Copy the current directory contents into the container at $HOME/app, setting the owner to the user
+ COPY --chown=user . $HOME/app
+
+ RUN git clone https://github.com/dylanebert/gradio-splatting.git gradio_splatting
+
+ CMD ["python", "app.py"]
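A quick sanity check after building this image is to confirm that the pinned CUDA PyTorch wheel and the four source-built extensions all import. A minimal sketch, assuming the conventional import names for these packages (`pointnet2_ops`, `pytorch3d`, `torch_scatter`, `diff_gaussian_rasterization`); the script itself is not part of this commit:

```python
# check_env.py -- hypothetical smoke test, not part of this commit
import torch

# The Dockerfile pins torch 1.12.1+cu113, so torch.version.cuda should report 11.3.
print("torch", torch.__version__, "cuda", torch.version.cuda, "available:", torch.cuda.is_available())

# Import names assumed for the extensions built from source in the Dockerfile above.
import pointnet2_ops                 # SnowflakeNet's pointnet2_ops_lib
import pytorch3d                     # facebookresearch/pytorch3d v0.7.3
import torch_scatter                 # rusty1s/pytorch_scatter
import diff_gaussian_rasterization   # graphdeco-inria/diff-gaussian-rasterization

print("all CUDA extensions import cleanly")
```

Something like `docker run --gpus all <image> python check_env.py` would exercise it inside the built container.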
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: TriplaneGaussian
+ emoji: 👀
+ colorFrom: blue
+ colorTo: yellow
+ sdk: docker
+ # sdk: gradio
+ # sdk_version: 4.13.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,177 @@
+ import gradio as gr
+ import argparse
+ import os
+ import glob
+ import torch
+ from PIL import Image
+ from copy import deepcopy
+ import sys
+ import tempfile
+ from huggingface_hub import snapshot_download
+
+ LOCAL_CODE = os.environ.get("LOCAL_CODE", "1") == "1"
+ AUTH = ("admin", os.environ["PASSWD"]) if "PASSWD" in os.environ else None
+
+ code_dir = snapshot_download("zouzx/TriplaneGaussian", local_dir="./code", token=os.environ["HF_TOKEN"]) if not LOCAL_CODE else "./code"
+
+ sys.path.append(code_dir)
+
+ from utils import image_preprocess, pred_bbox, sam_init, sam_out_nosave, todevice
+ from gradio_splatting.backend.gradio_model3dgs import Model3DGS
+ import tgs
+ from tgs.utils.config import ExperimentConfig, load_config
+ from tgs.systems.infer import TGS
+
+ SAM_CKPT_PATH = "code/checkpoints/sam_vit_h_4b8939.pth"
+ MODEL_CKPT_PATH = "code/checkpoints/tgs_lvis_100v_rel.ckpt"
+ CONFIG = "code/configs/single-rel.yaml"
+ EXP_ROOT_DIR = "./outputs-gradio"
+
+ gpu = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
+ device = "cuda:{}".format(gpu) if torch.cuda.is_available() else "cpu"
+
+ print("device: ", device)
+
+ # load SAM checkpoint
+ sam_predictor = sam_init(SAM_CKPT_PATH, gpu)
+ print("load sam ckpt done.")
+
+ # init system
+ base_cfg: ExperimentConfig
+ base_cfg = load_config(CONFIG, cli_args=[], n_gpus=1)
+ base_cfg.system.weights = MODEL_CKPT_PATH
+ system = TGS(cfg=base_cfg.system).to(device)
+ print("load model ckpt done.")
+
+ HEADER = """
+ # Triplane Meets Gaussian Splatting: Fast and Generalizable Single-View 3D Reconstruction with Transformers
+
+ <div>
+ <a style="display: inline-block;" href="https://arxiv.org/abs/2312.09147"><img src="https://img.shields.io/badge/arxiv-2312.09147-B31B1B.svg"></a>
+ </div>
+
+ TGS enables fast reconstruction from a single-view image in a few seconds, based on a hybrid Triplane-Gaussian 3D representation.
+
+ This model is trained only on Objaverse-LVIS (~40K synthetic objects). Note that, following LRM, we normalize the input camera pose to a pre-set viewpoint during training, rather than directly using the pose of the input camera as in our original paper.
+ """
+
+ def preprocess(image_path, save_path=None, lower_contrast=False):
+     input_raw = Image.open(image_path)
+
+     input_raw.thumbnail([512, 512], Image.Resampling.LANCZOS)
+     image_sam = sam_out_nosave(
+         sam_predictor, input_raw.convert("RGB"), pred_bbox(input_raw)
+     )
+
+     if save_path is None:
+         save_path, ext = os.path.splitext(image_path)
+         save_path = save_path + "_rgba.png"
+     image_preprocess(image_sam, save_path, lower_contrast=lower_contrast, rescale=True)
+
+     return save_path
+
+ def init_trial_dir():
+     if not os.path.exists(EXP_ROOT_DIR):
+         os.makedirs(EXP_ROOT_DIR, exist_ok=True)
+     # mkdtemp keeps the directory around; TemporaryDirectory(...).name would remove it as soon as the object is garbage-collected
+     trial_dir = tempfile.mkdtemp(dir=EXP_ROOT_DIR)
+     system.set_save_dir(trial_dir)
+     return trial_dir
+
+ @torch.no_grad()
+ def infer(image_path: str,
+           cam_dist: float,
+           fovy_deg: float,
+           only_3dgs: bool = False):
+     data_cfg = deepcopy(base_cfg.data)
+     data_cfg.only_3dgs = only_3dgs
+     data_cfg.cond_fovy_deg = fovy_deg
+     data_cfg.cond_camera_distance = cam_dist
+     data_cfg.image_list = [image_path]
+     dm = tgs.find(base_cfg.data_cls)(data_cfg)
+
+     dm.setup()
+     for batch_idx, batch in enumerate(dm.test_dataloader()):
+         batch = todevice(batch, device)
+         system.test_step(batch, batch_idx, save_3dgs=only_3dgs)
+     if not only_3dgs:
+         system.on_test_epoch_end()
+
+ def run(image_path: str,
+         cam_dist: float,
+         fov_degree: float):
+     infer(image_path, cam_dist, fov_degree, only_3dgs=True)
+     save_path = system.get_save_dir()
+     gs = glob.glob(os.path.join(save_path, "*.ply"))[0]
+     return gs
+
+ def run_video(image_path: str,
+               cam_dist: float,
+               fov_degree: float):
+     infer(image_path, cam_dist, fov_degree)
+     save_path = system.get_save_dir()
+     video = glob.glob(os.path.join(save_path, "*.mp4"))[0]
+     return video
+
+ def launch(port):
+     with gr.Blocks(
+         title="TGS - Demo",
+         theme=gr.themes.Monochrome()
+     ) as demo:
+         with gr.Row(variant='panel'):
+             gr.Markdown(HEADER)
+
+         with gr.Row(variant='panel'):
+             with gr.Column(scale=1):
+                 input_image = gr.Image(value=None, width=512, height=512, type="filepath", label="Input Image")
+                 fov_deg_slider = gr.Slider(20, 80, value=40, step=1, label="Camera FOV (degrees)")
+                 camera_dist_slider = gr.Slider(1.0, 4.0, value=1.6, step=0.1, label="Camera Distance")
+                 img_run_btn = gr.Button("Reconstruction")
+
+                 gr.Examples(
+                     examples=[
+                         "example_images/green_parrot.webp",
+                         "example_images/rusty_gameboy.webp",
+                         "example_images/a_pikachu_with_smily_face.webp",
+                         "example_images/an_otter_wearing_sunglasses.webp",
+                         "example_images/lumberjack_axe.webp",
+                         "example_images/medieval_shield.webp"
+                     ],
+                     inputs=[input_image],
+                     cache_examples=False,
+                     label="Examples",
+                     examples_per_page=40
+                 )
+
+             with gr.Column(scale=1):
+                 with gr.Row(variant='panel'):
+                     seg_image = gr.Image(value=None, type="filepath", height=256, width=256, image_mode="RGBA", label="Segmented Image", interactive=False)
+                     output_video = gr.Video(value=None, label="Video", height=256, autoplay=True)
+                 output_3dgs = Model3DGS(value=None, label="3DGS")
+
+         img_run_btn.click(
+             fn=preprocess,
+             inputs=[input_image],
+             outputs=[seg_image],
+             concurrency_limit=1,
+         ).success(
+             fn=init_trial_dir,
+             concurrency_limit=1,
+         ).success(
+             fn=run,
+             inputs=[seg_image, camera_dist_slider, fov_deg_slider],
+             outputs=[output_3dgs],
+             concurrency_limit=1,
+         ).success(
+             fn=run_video,
+             inputs=[seg_image, camera_dist_slider, fov_deg_slider],
+             outputs=[output_video],
+             concurrency_limit=1,
+         )
+
+     launch_args = {"server_port": port}
+     demo.queue(max_size=10)
+     demo.launch(auth=AUTH, **launch_args)
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--port", type=int, default=7860)
+     args = parser.parse_args()
+     launch(args.port)
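The click chain above (preprocess → init_trial_dir → run → run_video) can also be driven without the Gradio UI. A minimal headless sketch, assuming the checkpoints are in place and accepting that importing `app` executes the model-loading code at module level; the file name and flow are illustrative, not part of this commit:

```python
# headless_demo.py -- hypothetical usage sketch mirroring the button's success chain
from app import preprocess, init_trial_dir, run, run_video

rgba_path = preprocess("example_images/green_parrot.webp")    # SAM segmentation -> recentered RGBA crop
init_trial_dir()                                               # creates and registers the per-run output directory
ply_path = run(rgba_path, cam_dist=1.6, fov_degree=40)         # 3D Gaussian splat (.ply)
mp4_path = run_video(rgba_path, cam_dist=1.6, fov_degree=40)   # rendered orbit video (.mp4)
print(ply_path, mp4_path)
```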
example_images/a_pikachu_with_smily_face.webp ADDED
example_images/an_otter_wearing_sunglasses.webp ADDED
example_images/green_parrot.webp ADDED
example_images/lumberjack_axe.webp ADDED
example_images/medieval_shield.webp ADDED
example_images/rusty_gameboy.webp ADDED
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ lightning==2.0.7
+ pytorch-lightning==2.0.2
+ plyfile
+ OmegaConf
+ matplotlib
+ einops
+ gradio
+ diffusers==0.19.3
+ transformers==4.34.1
+ rembg
+ segment_anything
+ jaxtyping
+ imageio
+ imageio-ffmpeg
utils.py ADDED
@@ -0,0 +1,135 @@
+ import os
+ import time
+
+ import cv2
+ import numpy as np
+ import torch
+ from PIL import Image
+ from rembg import remove
+ from segment_anything import SamPredictor, sam_model_registry
+ import urllib.request
+ from tqdm import tqdm
+
+
+ def sam_init(sam_checkpoint, device_id=0):
+     # sam_checkpoint = os.path.join(os.path.dirname(__file__), "./sam_vit_h_4b8939.pth")
+     model_type = "vit_h"
+
+     device = "cuda:{}".format(device_id) if torch.cuda.is_available() else "cpu"
+
+     sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
+     predictor = SamPredictor(sam)
+     return predictor
+
+
+ def sam_out_nosave(predictor, input_image, *bbox_sliders):
+     bbox = np.array(bbox_sliders)
+     image = np.asarray(input_image)
+
+     start_time = time.time()
+     predictor.set_image(image)
+
+     masks_bbox, scores_bbox, logits_bbox = predictor.predict(
+         box=bbox, multimask_output=True
+     )
+
+     out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
+     out_image[:, :, :3] = image
+     out_image_bbox = out_image.copy()
+     out_image_bbox[:, :, 3] = (
+         masks_bbox[-1].astype(np.uint8) * 255
+     )  # np.argmax(scores_bbox)
+     torch.cuda.empty_cache()
+     return Image.fromarray(out_image_bbox, mode="RGBA")
+
+
+ # contrast correction, rescale and recenter
+ def image_preprocess(input_image, save_path, lower_contrast=True, rescale=True):
+     image_arr = np.array(input_image)
+     in_w, in_h = image_arr.shape[:2]  # note: numpy shape order is (height, width)
+
+     if lower_contrast:
+         alpha = 0.8  # Contrast control (1.0-3.0)
+         beta = 0  # Brightness control (0-100)
+         # Apply the contrast adjustment
+         image_arr = cv2.convertScaleAbs(image_arr, alpha=alpha, beta=beta)
+         image_arr[image_arr[..., -1] > 200, -1] = 255
+
+     ret, mask = cv2.threshold(
+         np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY
+     )
+     x, y, w, h = cv2.boundingRect(mask)
+     max_size = max(w, h)
+     ratio = 0.75
+     if rescale:
+         side_len = int(max_size / ratio)
+     else:
+         side_len = in_w
+     padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
+     center = side_len // 2
+     padded_image[
+         center - h // 2 : center - h // 2 + h, center - w // 2 : center - w // 2 + w
+     ] = image_arr[y : y + h, x : x + w]
+     rgba = Image.fromarray(padded_image).resize((256, 256), Image.LANCZOS)
+     rgba.save(save_path)
+
+     # rgba_arr = np.array(rgba) / 255.0
+     # rgb = rgba_arr[...,:3] * rgba_arr[...,-1:] + (1 - rgba_arr[...,-1:])
+     # return Image.fromarray((rgb * 255).astype(np.uint8))
+
+
+ def pred_bbox(image):
+     image_nobg = remove(image.convert("RGBA"), alpha_matting=True)
+     alpha = np.asarray(image_nobg)[:, :, -1]
+     x_nonzero = np.nonzero(alpha.sum(axis=0))
+     y_nonzero = np.nonzero(alpha.sum(axis=1))
+     x_min = int(x_nonzero[0].min())
+     y_min = int(y_nonzero[0].min())
+     x_max = int(x_nonzero[0].max())
+     y_max = int(y_nonzero[0].max())
+     return x_min, y_min, x_max, y_max
+
+ # convert a function into recursive style to handle nested dict/list/tuple variables
+ def make_recursive_func(func):
+     def wrapper(vars, *args, **kwargs):
+         if isinstance(vars, list):
+             return [wrapper(x, *args, **kwargs) for x in vars]
+         elif isinstance(vars, tuple):
+             return tuple([wrapper(x, *args, **kwargs) for x in vars])
+         elif isinstance(vars, dict):
+             return {k: wrapper(v, *args, **kwargs) for k, v in vars.items()}
+         else:
+             return func(vars, *args, **kwargs)
+
+     return wrapper
+
+ @make_recursive_func
+ def todevice(vars, device="cuda"):
+     if isinstance(vars, torch.Tensor):
+         return vars.to(device)
+     elif isinstance(vars, str):
+         return vars
+     elif isinstance(vars, bool):
+         return vars
+     elif isinstance(vars, float):
+         return vars
+     elif isinstance(vars, int):
+         return vars
+     else:
+         raise NotImplementedError("invalid input type {} for todevice".format(type(vars)))
+
+ def download_checkpoint(url, save_path):
+     try:
+         with urllib.request.urlopen(url) as response, open(save_path, 'wb') as file:
+             file_size = int(response.info().get('Content-Length', -1))
+             chunk_size = 8192
+             num_chunks = file_size // chunk_size if file_size > chunk_size else 1
+
+             with tqdm(total=file_size, unit='B', unit_scale=True, desc='Downloading', ncols=100) as pbar:
+                 for chunk in iter(lambda: response.read(chunk_size), b''):
+                     file.write(chunk)
+                     pbar.update(len(chunk))
+
+         print(f"Checkpoint downloaded and saved to: {save_path}")
+     except Exception as e:
+         print(f"Error downloading checkpoint: {e}")