Spaces: Running on Zero
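The Space below wraps openbmb/MiniCPM-o-2_6 in a small captioning class for images and videos. For the Space to build, the imported libraries have to be declared in its requirements.txt; a minimal, unpinned sketch follows (an assumption, not from the source, and the model's trust_remote_code path may pull in further dependencies not listed here):

torch
transformers
decord
Pillow
spaces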
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from decord import VideoReader, cpu
import spaces


class VLMCaptioning:
    def __init__(self):
        print("Loading MiniCPM-o model...")
        self.model = AutoModel.from_pretrained(
            'openbmb/MiniCPM-o-2_6',
            trust_remote_code=True,
            attn_implementation='sdpa',
            torch_dtype=torch.bfloat16,
            init_vision=True,
        )
        self.model = self.model.eval().cuda()
        self.tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-o-2_6', trust_remote_code=True)
    # On ZeroGPU, a GPU is attached only while a @spaces.GPU-decorated
    # function runs; without the decorator the spaces import is unused
    @spaces.GPU
    def describe_image(
        self,
        image: str,
        question: str = "Describe this image in detail.",
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 40,
        max_new_tokens: int = 512,
        stream: bool = False,
        sampling: bool = False,
    ) -> str:
        """
        Generate a description for a single image.

        Args:
            image (str): Path to the image file
            question (str): Question to ask about the image
            temperature (float): Sampling temperature (takes effect when sampling is True)
            top_p (float): Nucleus sampling parameter (takes effect when sampling is True)
            top_k (int): Top-k sampling parameter (takes effect when sampling is True)
            max_new_tokens (int): Maximum number of new tokens to generate
            stream (bool): Whether to stream the response instead of returning it whole
            sampling (bool): Whether to sample; greedy decoding is used when False

        Returns:
            str: Generated description
        """
        try:
            if not image:
                return "Please provide an image."

            # Load the image and convert it to RGB
            image = Image.open(image).convert('RGB')

            # Prepare the chat message: content is a list mixing images and text
            msgs = [{'role': 'user', 'content': [image, question]}]

            # Generate the response
            response = self.model.chat(
                image=None,
                msgs=msgs,
                tokenizer=self.tokenizer,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                max_new_tokens=max_new_tokens,
                stream=stream,
                sampling=sampling,
            )
            return response
        except Exception as e:
            return f"Error analyzing image: {str(e)}"
    # On ZeroGPU, a GPU is attached only while this method runs
    @spaces.GPU
    def describe_video(
        self,
        video_path: str,
        frame_interval: int = 30,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 40,
        max_new_tokens: int = 512,
        stream: bool = False,
        sampling: bool = False,
    ) -> str:
        """
        Generate a description for a video by sampling frames.

        Args:
            video_path (str): Path to the video file
            frame_interval (int): Sample every `frame_interval`-th frame
                (30 keeps roughly one frame per second of 30 fps footage)
            temperature (float): Sampling temperature (takes effect when sampling is True)
            top_p (float): Nucleus sampling parameter (takes effect when sampling is True)
            top_k (int): Top-k sampling parameter (takes effect when sampling is True)
            max_new_tokens (int): Maximum number of new tokens to generate
            stream (bool): Whether to stream the response instead of returning it whole
            sampling (bool): Whether to sample; greedy decoding is used when False

        Returns:
            str: Generated description
        """
        try:
            # Load the video and pick every `frame_interval`-th frame
            vr = VideoReader(video_path, ctx=cpu(0))
            total_frames = len(vr)
            frame_indices = list(range(0, total_frames, frame_interval))
            frames = vr.get_batch(frame_indices).asnumpy()

            # Convert the sampled frames to PIL images
            frame_images = [Image.fromarray(frame) for frame in frames]

            # Put all frames into a single user turn so the model sees the
            # whole clip at once; a list of separate per-frame messages would
            # be treated as a multi-turn conversation instead
            msgs = [{
                'role': 'user',
                'content': frame_images + ["Describe the main action in this scene."],
            }]

            # Generate one response covering all sampled frames
            response = self.model.chat(
                image=None,
                msgs=msgs,
                tokenizer=self.tokenizer,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                max_new_tokens=max_new_tokens,
                stream=stream,
                sampling=sampling,
            )
            return response
        except Exception as e:
            return f"Error processing video: {str(e)}"