youssef committed
Commit d200533 · 1 Parent(s): bd727fa

remove flash attn

Files changed (1): src/video_processor/processor.py (+6, -1)
src/video_processor/processor.py CHANGED
@@ -34,7 +34,7 @@ class VideoAnalyzer:
         self.model = AutoModelForImageTextToText.from_pretrained(
             self.model_path,
             torch_dtype=torch.bfloat16,
-            _attn_implementation="flash_attention_2"
+            # _attn_implementation="flash_attention_2"
         ).to(DEVICE)
         logger.info(f"Model loaded on device: {self.model.device} using attention implementation: flash_attention_2")
 
@@ -70,6 +70,11 @@ class VideoAnalyzer:
             return_tensors="pt"
         ).to(self.model.device)
 
+        # Convert inputs to bfloat16 before moving to GPU
+        #for key in inputs:
+        #    if torch.is_tensor(inputs[key]):
+        #        inputs[key] = inputs[key].to(dtype=torch.bfloat16, device=self.model.device)
+
         # Generate description with increased token limit
         generated_ids = self.model.generate(
             **inputs,
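With the flag commented out, from_pretrained falls back to transformers' default attention backend, so the logger line above now over-reports: it still claims flash_attention_2. A minimal sketch of a guarded alternative (an assumption, not part of this commit; the model id is hypothetical, and attn_implementation is the documented public spelling of the private _attn_implementation argument):

    import importlib.util

    import torch
    from transformers import AutoModelForImageTextToText

    # Request FlashAttention-2 only when the flash-attn package is importable;
    # otherwise fall back to PyTorch's built-in scaled-dot-product attention.
    attn_impl = "flash_attention_2" if importlib.util.find_spec("flash_attn") else "sdpa"

    model = AutoModelForImageTextToText.from_pretrained(
        "HuggingFaceTB/SmolVLM2-2.2B-Instruct",  # hypothetical model id for illustration
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_impl,
    )

Logging attn_impl instead of a hard-coded string would keep the startup message accurate either way.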
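The new commented-out block, if re-enabled as written, would cast every tensor in inputs to bfloat16, including integer tensors such as input_ids, which embedding lookup requires to stay integral. A safer variant would guard on dtype; a minimal sketch (an assumption, not part of this commit):

    import torch

    # Cast only floating-point tensors (e.g. pixel values) to bfloat16;
    # leave input_ids / attention_mask in their integer dtypes.
    for key, value in inputs.items():
        if torch.is_tensor(value) and value.is_floating_point():
            inputs[key] = value.to(dtype=torch.bfloat16, device=model.device)

Since the processor output is already moved with .to(self.model.device) above, only the dtype conversion is strictly new here.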