"""Video analysis with SmolVLM2: generate a searchable, timestamped text
description of a video using HuggingFaceTB/SmolVLM2-2.2B-Instruct."""

import logging
from typing import Dict, List

import decord  # video-decoding backend; not referenced directly in this module
import numpy as np  # companion to decord for frame arrays; not referenced directly here
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

logger = logging.getLogger(__name__)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {DEVICE}")

class VideoAnalyzer:
    def __init__(self):
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is required but not available!")

        logger.info("Initializing VideoAnalyzer")
        self.model_path = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
        logger.info(f"Loading model from {self.model_path}")

        cache_dir = "/models"
        logger.info(f"Using cache directory: {cache_dir}")

        # Note: torch_dtype is not a processor argument (it applies to model
        # weights, which the processor has none of), so it is omitted here.
        self.processor = AutoProcessor.from_pretrained(
            self.model_path,
            cache_dir=cache_dir,
        )

        # Pin the whole model to GPU 0; flash_attention_2 requires the
        # flash-attn package to be installed.
        device_map = {"": 0}
        self.model = AutoModelForImageTextToText.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
            device_map=device_map,
            _attn_implementation="flash_attention_2",
            cache_dir=cache_dir,
        )
        logger.info(f"Model loaded on device: {self.model.device}")

    def process_video(self, video_path: str, frame_interval: int = 30) -> List[Dict]:
        # frame_interval is logged for traceability but not used below:
        # frame sampling is handled inside the processor's chat template.
        logger.info(f"Processing video: {video_path} with frame_interval={frame_interval}")
        try:
            messages = [{
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {
                        "type": "text",
                        "text": (
                            "Describe this video in detail, with timestamps and the "
                            "actions happening in the video. I should be able to "
                            "understand the video by reading the description, and "
                            "search for it later."
                        ),
                    },
                ],
            }]

            # Cast floating-point inputs (pixel values) to bfloat16 so they
            # match the model weights loaded above.
            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.model.device, dtype=torch.bfloat16)

            generated_ids = self.model.generate(
                **inputs,
                do_sample=False,
                max_new_tokens=100,  # raise this if descriptions come back truncated
            )
            # Decode only the newly generated tokens so the returned
            # description does not include the chat-template prompt.
            description = self.processor.batch_decode(
                generated_ids[:, inputs["input_ids"].shape[-1]:],
                skip_special_tokens=True,
            )[0].strip()

            return [{
                "description": description
            }]

        except Exception as e:
            logger.error(f"Error processing video: {str(e)}", exc_info=True)
            raise
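

# Usage sketch, not part of the original module: drives the analyzer end to
# end from the command line. The video path below is a placeholder; point it
# at a real file.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    analyzer = VideoAnalyzer()
    results = analyzer.process_video("/videos/example.mp4")  # placeholder path
    print(results[0]["description"])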