youssef committed
Commit d200533
Parent(s): bd727fa
remove flash attn
src/video_processor/processor.py
CHANGED
@@ -34,7 +34,7 @@ class VideoAnalyzer:
         self.model = AutoModelForImageTextToText.from_pretrained(
             self.model_path,
             torch_dtype=torch.bfloat16,
-            _attn_implementation="flash_attention_2"
+            # _attn_implementation="flash_attention_2"
         ).to(DEVICE)
         logger.info(f"Model loaded on device: {self.model.device} using attention implementation: flash_attention_2")

@@ -70,6 +70,11 @@ class VideoAnalyzer:
             return_tensors="pt"
         ).to(self.model.device)

+        # Convert inputs to bfloat16 before moving to GPU
+        #for key in inputs:
+        #    if torch.is_tensor(inputs[key]):
+        #        inputs[key] = inputs[key].to(dtype=torch.bfloat16, device=self.model.device)
+
         # Generate description with increased token limit
         generated_ids = self.model.generate(
             **inputs,
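
For readers applying this change locally, a common follow-up (not part of this commit) is to pick the attention implementation at load time based on whether the optional flash-attn package is installed, falling back to PyTorch's SDPA kernels otherwise. A minimal sketch, assuming a recent transformers version; DEVICE stands in for the constant used in processor.py and load_model is a hypothetical helper, not code from this repository:

import importlib.util

import torch
from transformers import AutoModelForImageTextToText

# Assumption: mirrors the DEVICE constant defined elsewhere in processor.py.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def load_model(model_path: str):
    # Use FlashAttention-2 only when the optional flash-attn package is importable;
    # otherwise fall back to PyTorch's built-in SDPA attention.
    attn_impl = "flash_attention_2" if importlib.util.find_spec("flash_attn") else "sdpa"
    model = AutoModelForImageTextToText.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_impl,
    ).to(DEVICE)
    return model, attn_impl

Returning the chosen implementation also lets the logger.info call report the value actually in use, instead of the hard-coded "flash_attention_2" string that remains in the diff above.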
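The second hunk also leaves behind a commented-out cast of the processor outputs to bfloat16. If that path is ever re-enabled, a safer variant is to cast only floating-point tensors (e.g. pixel_values), since integer tensors such as input_ids must keep their dtype for generate() to work. A sketch under that assumption; cast_float_inputs is a hypothetical helper, not part of this commit:

import torch

def cast_float_inputs(inputs: dict, device, dtype=torch.bfloat16) -> dict:
    # Move every tensor to the target device, but change dtype only for
    # floating-point tensors; integer inputs (input_ids, attention masks) keep theirs.
    out = {}
    for key, value in inputs.items():
        if torch.is_tensor(value) and value.is_floating_point():
            out[key] = value.to(device=device, dtype=dtype)
        elif torch.is_tensor(value):
            out[key] = value.to(device=device)
        else:
            out[key] = value
    return out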