ClownRat committed
Commit ee906b7 · 1 Parent(s): fc5df74

update demo.

Files changed (1)
  1. app.py +27 -0
app.py CHANGED
@@ -219,8 +219,35 @@ if __name__ == '__main__':
     conv_mode = "llama_2"
     model_path = 'DAMO-NLP-SG/VideoLLaMA2-7B'

+    def find_cuda():
+        # Check if CUDA_HOME or CUDA_PATH environment variables are set
+        cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
+
+        if cuda_home and os.path.exists(cuda_home):
+            return cuda_home
+
+        # Search for the nvcc executable in the system's PATH
+        nvcc_path = shutil.which('nvcc')
+
+        if nvcc_path:
+            # Remove the 'bin/nvcc' part to get the CUDA installation path
+            cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
+            return cuda_path
+
+        return None
+
+    cuda_path = find_cuda()
+
+    if cuda_path:
+        print(f"CUDA installation found at: {cuda_path}")
+    else:
+        print("CUDA installation not found")
+
+    device = torch.device("cuda")
+
     handler = Chat(model_path, conv_mode=conv_mode, load_8bit=False, load_4bit=True)
     # handler.model.to(dtype=torch.float16)
+    handler = handler.model.to(device)

     if not os.path.exists("temp"):
         os.makedirs("temp")
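
For reference, the sketch below is a hypothetical, standalone version of the CUDA-detection flow this commit adds. The diff itself does not touch the import section of app.py, so the explicit os, shutil, and torch imports here are an assumption about what the file already provides; the __main__ guard and print statements are illustrative only.

import os
import shutil

import torch


def find_cuda():
    # Prefer the CUDA_HOME / CUDA_PATH environment variables when they
    # point to an existing directory.
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home and os.path.exists(cuda_home):
        return cuda_home

    # Otherwise derive the installation root from the nvcc executable,
    # e.g. /usr/local/cuda/bin/nvcc -> /usr/local/cuda.
    nvcc_path = shutil.which('nvcc')
    if nvcc_path:
        return os.path.dirname(os.path.dirname(nvcc_path))

    return None


if __name__ == '__main__':
    cuda_path = find_cuda()
    if cuda_path:
        print(f"CUDA installation found at: {cuda_path}")
    else:
        print("CUDA installation not found")

    # As in the commit, a CUDA device is then created; actually moving the
    # model onto it (handler.model.to(device)) requires a CUDA-capable host.
    device = torch.device("cuda")

Note that the diff rebinds handler to handler.model.to(device), so after this change the name handler refers to the model placed on the CUDA device rather than to the original Chat wrapper.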