刘虹雨 commited on
Commit
42d9724
·
1 Parent(s): ba4fe8c

update code

Browse files
app.py CHANGED
@@ -4,6 +4,7 @@ import sys
4
  import warnings
5
  import logging
6
  import spaces
 
7
  # Configure logging settings
8
  logging.basicConfig(
9
  level=logging.INFO,
@@ -52,6 +53,19 @@ if _get_output(["nvcc", "--version"]) is None:
52
  logging.info("installCUDA: %s" % _get_output(["nvcc", "--version"]))
53
  else:
54
  logging.info("Detected CUDA: %s" % _get_output(["nvcc", "--version"]))
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  print("CUDA_HOME =", os.environ.get("CUDA_HOME"))
56
  from torch.utils.cpp_extension import CUDA_HOME
57
  print("CUDA_HOME from PyTorch:", CUDA_HOME)
 
4
  import warnings
5
  import logging
6
  import spaces
7
+
8
  # Configure logging settings
9
  logging.basicConfig(
10
  level=logging.INFO,
 
53
  logging.info("installCUDA: %s" % _get_output(["nvcc", "--version"]))
54
  else:
55
  logging.info("Detected CUDA: %s" % _get_output(["nvcc", "--version"]))
56
+
57
+ import torch
58
+
59
+ # 设置当前默认 GPU 设备(推荐在 CUDA 初始化前设置)
60
+ torch.cuda.set_device(0)
61
+
62
+ # 显式初始化 CUDA(通常是可选的,但在多线程中有助于避免问题)
63
+ torch.cuda.init()
64
+
65
+ # 测试
66
+ print("CUDA available:", torch.cuda.is_available())
67
+ print("Current device:", torch.cuda.current_device())
68
+ print("Device name:", torch.cuda.get_device_name(0))
69
  print("CUDA_HOME =", os.environ.get("CUDA_HOME"))
70
  from torch.utils.cpp_extension import CUDA_HOME
71
  print("CUDA_HOME from PyTorch:", CUDA_HOME)
data_process/lib/faceverse_process/core/FaceVerseModel_v3.py CHANGED
@@ -104,7 +104,7 @@ def get_renderer(img_size, device, R=None, T=None, K=None, orthoCam=False, raste
104
 
105
 
106
  class FaceVerseModel(nn.Module):
107
- def __init__(self, model_dict, batch_size=1, device='cuda:1', expr_52=True, **kargs):
108
  super(FaceVerseModel, self).__init__()
109
 
110
  self.batch_size = batch_size
 
104
 
105
 
106
  class FaceVerseModel(nn.Module):
107
+ def __init__(self, model_dict, batch_size=1, device='cuda', expr_52=True, **kargs):
108
  super(FaceVerseModel, self).__init__()
109
 
110
  self.batch_size = batch_size