Spaces:
Runtime error
Runtime error
Commit
·
127d24d
1
Parent(s):
df9baae
localize
Browse files
dialoggen/dialoggen_demo.py
CHANGED
@@ -26,8 +26,11 @@ import requests
|
|
26 |
from PIL import Image
|
27 |
from io import BytesIO
|
28 |
import re
|
|
|
29 |
|
30 |
|
|
|
|
|
31 |
def image_parser(image_file, sep=','):
|
32 |
out = image_file.split(sep)
|
33 |
return out
|
@@ -107,7 +110,8 @@ def eval_model(models,
|
|
107 |
input_ids = (
|
108 |
tokenizer_image_token(prompt, models["tokenizer"], IMAGE_TOKEN_INDEX, return_tensors="pt")
|
109 |
.unsqueeze(0)
|
110 |
-
.
|
|
|
111 |
)
|
112 |
with torch.inference_mode():
|
113 |
output_ids = models["model"].generate(
|
|
|
26 |
from PIL import Image
|
27 |
from io import BytesIO
|
28 |
import re
|
29 |
+
import devicetorch
|
30 |
|
31 |
|
32 |
+
device = devictorch.get(torch)
|
33 |
+
|
34 |
def image_parser(image_file, sep=','):
|
35 |
out = image_file.split(sep)
|
36 |
return out
|
|
|
110 |
input_ids = (
|
111 |
tokenizer_image_token(prompt, models["tokenizer"], IMAGE_TOKEN_INDEX, return_tensors="pt")
|
112 |
.unsqueeze(0)
|
113 |
+
.to(device)
|
114 |
+
# .cuda()
|
115 |
)
|
116 |
with torch.inference_mode():
|
117 |
output_ids = models["model"].generate(
|