mattmdjaga commited on
Commit
13e3eda
·
verified ·
1 Parent(s): 2c8dc44

Added with torch no grad

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -94,7 +94,8 @@ def get_logits(caption: str, imgs: List[Image.Image]) -> torch.Tensor:
94
  inputs["pixel_values"] = (
95
  inputs["pixel_values"].half() if device == "cuda" else inputs["pixel_values"]
96
  )
97
- outputs = model(**inputs)
 
98
  logits_per_image = outputs.logits_per_image
99
 
100
  return logits_per_image
 
94
  inputs["pixel_values"] = (
95
  inputs["pixel_values"].half() if device == "cuda" else inputs["pixel_values"]
96
  )
97
+ with torch.no_grad():
98
+ outputs = model(**inputs)
99
  logits_per_image = outputs.logits_per_image
100
 
101
  return logits_per_image