Spaces:
Runtime error
Runtime error
Commit
·
02448aa
1
Parent(s):
20a5e29
Update app.py
Browse files
app.py
CHANGED
@@ -16,7 +16,7 @@ CAPTION_MODELS = {
|
|
16 |
loaded_models = {}
|
17 |
|
18 |
# Simple caption creation
|
19 |
-
def caption_image(model_choice, image_input, url_inputs, load_in_8bit):
|
20 |
if image_input is not None:
|
21 |
input_data = [image_input]
|
22 |
else:
|
@@ -34,7 +34,7 @@ def caption_image(model_choice, image_input, url_inputs, load_in_8bit):
|
|
34 |
captioner = pipeline(task="image-to-text",
|
35 |
model=CAPTION_MODELS[model_choice],
|
36 |
max_new_tokens=30,
|
37 |
-
device=
|
38 |
model_kwargs=model_kwargs,
|
39 |
torch_dtype=dtype, # Set the floating point
|
40 |
use_fast=True
|
|
|
16 |
loaded_models = {}
|
17 |
|
18 |
# Simple caption creation
|
19 |
+
def caption_image(model_choice, image_input, url_inputs, load_in_8bit, device):
|
20 |
if image_input is not None:
|
21 |
input_data = [image_input]
|
22 |
else:
|
|
|
34 |
captioner = pipeline(task="image-to-text",
|
35 |
model=CAPTION_MODELS[model_choice],
|
36 |
max_new_tokens=30,
|
37 |
+
device=device, # Use selected device
|
38 |
model_kwargs=model_kwargs,
|
39 |
torch_dtype=dtype, # Set the floating point
|
40 |
use_fast=True
|