cocktailpeanut commited on
Commit
127d24d
·
1 Parent(s): df9baae
Files changed (1) hide show
  1. dialoggen/dialoggen_demo.py +5 -1
dialoggen/dialoggen_demo.py CHANGED
@@ -26,8 +26,11 @@ import requests
26
  from PIL import Image
27
  from io import BytesIO
28
  import re
 
29
 
30
 
 
 
31
  def image_parser(image_file, sep=','):
32
  out = image_file.split(sep)
33
  return out
@@ -107,7 +110,8 @@ def eval_model(models,
107
  input_ids = (
108
  tokenizer_image_token(prompt, models["tokenizer"], IMAGE_TOKEN_INDEX, return_tensors="pt")
109
  .unsqueeze(0)
110
- .cuda()
 
111
  )
112
  with torch.inference_mode():
113
  output_ids = models["model"].generate(
 
26
  from PIL import Image
27
  from io import BytesIO
28
  import re
29
+ import devicetorch
30
 
31
 
32
+ device = devictorch.get(torch)
33
+
34
  def image_parser(image_file, sep=','):
35
  out = image_file.split(sep)
36
  return out
 
110
  input_ids = (
111
  tokenizer_image_token(prompt, models["tokenizer"], IMAGE_TOKEN_INDEX, return_tensors="pt")
112
  .unsqueeze(0)
113
+ .to(device)
114
+ # .cuda()
115
  )
116
  with torch.inference_mode():
117
  output_ids = models["model"].generate(