Mediocreatmybest commited on
Commit
593f239
·
1 Parent(s): d867e19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -28,12 +28,13 @@ def caption_image(model_choice, image_input, url_input, load_in_8bit):
28
  captioner = loaded_models[model_key]
29
  else:
30
  model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
 
31
  captioner = pipeline(task="image-to-text",
32
  model=CAPTION_MODELS[model_choice],
33
  max_new_tokens=30,
34
- device=0, # Set the device as CPU
35
  model_kwargs=model_kwargs,
36
- torch_dtype=torch.float16, # Set the floating point to Float16
37
  use_fast=True
38
  )
39
  # Store the loaded model
 
28
  captioner = loaded_models[model_key]
29
  else:
30
  model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
31
+ dtype = torch.float16 if load_in_8bit else torch.float32 # Set dtype based on the value of load_in_8bit
32
  captioner = pipeline(task="image-to-text",
33
  model=CAPTION_MODELS[model_choice],
34
  max_new_tokens=30,
35
+ device='cpu', # Set the device as CPU
36
  model_kwargs=model_kwargs,
37
+ torch_dtype=dtype, # Set the floating point
38
  use_fast=True
39
  )
40
  # Store the loaded model