Update spaces/zero/torch.py
spaces/zero/torch.py  (CHANGED, +14 -0)
@@ -6,6 +6,7 @@ from __future__ import annotations
 
 import multiprocessing
 import os
+import time
 from concurrent.futures import ProcessPoolExecutor
 from contextlib import suppress
 from functools import partial
@@ -241,8 +242,12 @@ if (torch := maybe_import_torch()):
         bitsandbytes.unpatch()
 
     def _move(nvidia_uuid: str):
+        t0 = time.perf_counter()
         os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
         torch.Tensor([0]).cuda() # CUDA init
+        t1 = time.perf_counter()
+        print("CUDA init", t1 - t0)
+        t0 = t1
         for op in to_ops.items():
             tensor, parsed_args = op
             _, dtype, _, memory_format = parsed_args
@@ -251,8 +256,17 @@ if (torch := maybe_import_torch()):
                 dtype=dtype,
                 memory_format=memory_format,
             ) # type: ignore
+        t1 = time.perf_counter()
+        print("CUDA move", t1 - t0)
+        t0 = t1
         bitsandbytes.move()
+        t1 = time.perf_counter()
+        print("BNB move", t1 - t0)
+        t0 = t1
         torch.cuda.synchronize()
+        t1 = time.perf_counter()
+        print("CUDA synchronize", t1 - t0)
+        t0 = t1
 
     def _is_in_bad_fork():
         with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
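
The commit instruments the four stages of _move() with wall-clock checkpoints: CUDA initialization, replaying the tensor moves recorded in to_ops, bitsandbytes.move(), and the final torch.cuda.synchronize(). For reference, below is a minimal, self-contained sketch of the same rolling time.perf_counter() checkpoint pattern. The stage() helper and the time.sleep() workloads are illustrative stand-ins only; they are not part of spaces/zero/torch.py.

    import time

    def stage(label: str, t0: float) -> float:
        # Print the elapsed time since t0 and return a fresh checkpoint.
        t1 = time.perf_counter()
        print(label, t1 - t0)
        return t1

    def run() -> None:
        t0 = time.perf_counter()
        time.sleep(0.05)   # stand-in for CUDA init (torch.Tensor([0]).cuda())
        t0 = stage("CUDA init", t0)
        time.sleep(0.10)   # stand-in for replaying the recorded .to() moves
        t0 = stage("CUDA move", t0)
        time.sleep(0.02)   # stand-in for bitsandbytes.move()
        t0 = stage("BNB move", t0)
        time.sleep(0.01)   # stand-in for torch.cuda.synchronize()
        t0 = stage("CUDA synchronize", t0)

    if __name__ == "__main__":
        run()

Carrying the checkpoint forward (t0 = t1) keeps each stage's measurement independent of the previous ones, and perf_counter() is monotonic, so the printed deltas are unaffected by system clock adjustments.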