# VisionScout / clip_analyzer.py
# Author: DawnC
# Last commit c0fe80d (verified): Add new feature "Video Process" and fix format issue
# Original file size: 16.5 kB
import torch
import clip
import numpy as np
from PIL import Image
from typing import Dict, List, Tuple, Any, Optional, Union
from clip_prompts import (
SCENE_TYPE_PROMPTS,
CULTURAL_SCENE_PROMPTS,
COMPARATIVE_PROMPTS,
LIGHTING_CONDITION_PROMPTS,
SPECIALIZED_SCENE_PROMPTS,
VIEWPOINT_PROMPTS,
OBJECT_COMBINATION_PROMPTS,
ACTIVITY_PROMPTS
)
class CLIPAnalyzer:
    """
    Use CLIP to integrate scene-understanding functions.

    An image is compared against several families of text prompts imported
    from ``clip_prompts``:
      - scene types
      - lighting conditions
      - viewpoints
      - object combinations
      - activities
      - per-scene cultural and specialized descriptions

    Prompt families are tokenized once at construction time. Their encoded
    text features are cached after first use, so analyzing many images
    (e.g. successive video frames) does not re-encode the same prompts on
    every call.
    """

    def __init__(self, model_name: str = "ViT-B/32", device: Optional[str] = None):
        """
        Initialize the CLIP analyzer.

        Args:
            model_name: CLIP model name, e.g. "ViT-B/32", "ViT-B/16", "ViT-L/14".
            device: Torch device string. When None, "cuda" is used if it is
                available, otherwise "cpu".

        Raises:
            Exception: Whatever ``clip.load`` raised, re-raised after printing
                an error message.
        """
        # Pick a device automatically when the caller did not specify one.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        print(f"Loading CLIP model {model_name} on {self.device}...")
        try:
            self.model, self.preprocess = clip.load(model_name, device=self.device)
            print("CLIP model loaded successfully.")
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            raise

        self.scene_type_prompts = SCENE_TYPE_PROMPTS
        self.cultural_scene_prompts = CULTURAL_SCENE_PROMPTS
        self.comparative_prompts = COMPARATIVE_PROMPTS
        self.lighting_condition_prompts = LIGHTING_CONDITION_PROMPTS
        self.specialized_scene_prompts = SPECIALIZED_SCENE_PROMPTS
        self.viewpoint_prompts = VIEWPOINT_PROMPTS
        self.object_combination_prompts = OBJECT_COMBINATION_PROMPTS
        self.activity_prompts = ACTIVITY_PROMPTS

        # Convert all prompts into CLIP token tensors.
        self._prepare_text_prompts()

    def _prepare_text_prompts(self):
        """
        Tokenize every prompt family and move the tokens to ``self.device``.

        Also resets the text-feature cache. Calling this again after changing
        any ``*_prompts`` attribute therefore takes effect on the next analysis.
        """
        # Base scene types: one prompt per scene type, in dict order.
        self.scene_type_tokens = clip.tokenize(
            list(self.scene_type_prompts.values())
        ).to(self.device)

        # Cultural scenes: a list of prompts per scene type.
        self.cultural_tokens_dict = {
            scene_type: clip.tokenize(prompts).to(self.device)
            for scene_type, prompts in self.cultural_scene_prompts.items()
        }

        # Lighting conditions.
        self.lighting_tokens = clip.tokenize(
            list(self.lighting_condition_prompts.values())
        ).to(self.device)

        # Specialized scenes: a list of prompts per scene type.
        self.specialized_tokens_dict = {
            scene_type: clip.tokenize(prompts).to(self.device)
            for scene_type, prompts in self.specialized_scene_prompts.items()
        }

        # Viewpoints.
        self.viewpoint_tokens = clip.tokenize(
            list(self.viewpoint_prompts.values())
        ).to(self.device)

        # Object combinations.
        self.object_combination_tokens = clip.tokenize(
            list(self.object_combination_prompts.values())
        ).to(self.device)

        # Activities.
        self.activity_tokens = clip.tokenize(
            list(self.activity_prompts.values())
        ).to(self.device)

        # The tokens above may have changed, so drop any stale encoded features.
        self._text_feature_cache = {}

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Return *tensor* as a NumPy array, whatever device it lives on."""
        # .cpu() is a no-op for CPU tensors. Unlike the old
        # `self.device == "cuda"` check, it also works for "cuda:0", "mps", etc.
        return tensor.detach().cpu().numpy()

    @staticmethod
    def _to_pil(image) -> "Image.Image":
        """
        Coerce *image* to a PIL image.

        Raises:
            ValueError: If *image* is neither a PIL image nor a NumPy array.
        """
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, np.ndarray):
            return Image.fromarray(image)
        raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")

    def _encode_image(self, image) -> torch.Tensor:
        """Preprocess *image* and return its L2-normalized CLIP features as a batch of one."""
        image_input = self.preprocess(self._to_pil(image)).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features

    def _encode_text(self, tokens: torch.Tensor, cache_key: Any = None) -> torch.Tensor:
        """
        Return L2-normalized CLIP text features for *tokens*.

        If *cache_key* is not None, the result is memoized on the instance under
        that key until ``_prepare_text_prompts`` runs again. Only pass a key for
        token tensors that do not change between calls.
        """
        # setdefault keeps this usable even if _prepare_text_prompts never ran.
        cache = self.__dict__.setdefault("_text_feature_cache", {})
        if cache_key is not None and cache_key in cache:
            return cache[cache_key]
        with torch.no_grad():
            text_features = self.model.encode_text(tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        if cache_key is not None:
            cache[cache_key] = text_features
        return text_features

    def _softmax_scores(self, image_features: torch.Tensor, tokens: torch.Tensor,
                        labels, cache_key: Any) -> Dict[str, float]:
        """
        Score the image against one prompt per label.

        Returns:
            A dict mapping each label to the softmax over (100 * cosine
            similarity) across all prompts of the family. The values sum to 1.
        """
        text_features = self._encode_text(tokens, cache_key)
        with torch.no_grad():
            probs = (100 * image_features @ text_features.T).softmax(dim=-1)
        probs = self._to_numpy(probs)[0]
        return {label: float(p) for label, p in zip(labels, probs)}

    def _rank_prompts(self, image_features: torch.Tensor, tokens: torch.Tensor,
                      prompts, cache_key: Any) -> Dict[str, Any]:
        """
        Rank *prompts* by similarity to the image.

        Scores are the raw scaled cosine similarity (100 * cos), with no softmax.
        *prompts* must be non-empty and aligned with the rows of *tokens*.

        Returns:
            Dict with "best_description", "confidence" (the best score) and
            "all_matches", a list of (prompt, score) tuples sorted by descending
            score. Ties keep the original prompt order.
        """
        text_features = self._encode_text(tokens, cache_key)
        with torch.no_grad():
            similarity = 100 * image_features @ text_features.T
        similarity = self._to_numpy(similarity)[0]
        scores = sorted(
            ((prompt, float(score)) for prompt, score in zip(prompts, similarity)),
            key=lambda item: item[1],
            reverse=True,
        )
        return {
            "best_description": scores[0][0],
            "confidence": scores[0][1],
            "all_matches": scores,
        }

    # ------------------------------------------------------------------ #
    # Public analysis API
    # ------------------------------------------------------------------ #

    def analyze_image(self, image, include_cultural_analysis: bool = True,
                      scene_score_threshold: float = 0.2) -> Dict[str, Any]:
        """
        Analyze an image: scene type, lighting, viewpoint, objects and activities.

        Args:
            image: Input image (PIL Image or numpy array).
            include_cultural_analysis: Whether to add detailed cultural-scene
                analysis for sufficiently likely scene types.
            scene_score_threshold: A scene type gets cultural/specialized
                analysis only if its scene score is strictly greater than this
                value. The default of 0.2 matches the previous behavior.

        Returns:
            Dict with:
              - "scene_scores"
              - "top_scene", "lighting_condition", "viewpoint":
                (label, score) tuples
              - "embedding": the normalized image features as a list
              - "object_combinations", "activities": top-3 (label, score) lists
            Optionally also "cultural_analysis" and "specialized_analysis".
            Specialized analysis runs regardless of
            ``include_cultural_analysis``. On any failure, the traceback is
            printed and ``{"error": <message>}`` is returned instead of raising.
        """
        try:
            image_features = self._encode_image(image)

            scene_scores = self._analyze_scene_type(image_features)
            lighting_scores = self._analyze_lighting_condition(image_features)

            def _is_likely(scene_type: str) -> bool:
                # Only drill into scene types the base classifier finds plausible.
                return scene_type in scene_scores and scene_scores[scene_type] > scene_score_threshold

            cultural_analysis = {}
            if include_cultural_analysis:
                cultural_analysis = {
                    scene_type: self._analyze_cultural_scene(image_features, scene_type)
                    for scene_type in self.cultural_scene_prompts
                    if _is_likely(scene_type)
                }

            specialized_analysis = {
                scene_type: self._analyze_specialized_scene(image_features, scene_type)
                for scene_type in self.specialized_scene_prompts
                if _is_likely(scene_type)
            }

            viewpoint_scores = self._analyze_viewpoint(image_features)
            object_combination_scores = self._analyze_object_combinations(image_features)
            activity_scores = self._analyze_activities(image_features)

            by_score = lambda item: item[1]
            result = {
                "scene_scores": scene_scores,
                "top_scene": max(scene_scores.items(), key=by_score),
                "lighting_condition": max(lighting_scores.items(), key=by_score),
                "embedding": self._to_numpy(image_features)[0].tolist(),
                "viewpoint": max(viewpoint_scores.items(), key=by_score),
                "object_combinations": sorted(object_combination_scores.items(), key=by_score, reverse=True)[:3],
                "activities": sorted(activity_scores.items(), key=by_score, reverse=True)[:3],
            }
            if cultural_analysis:
                result["cultural_analysis"] = cultural_analysis
            if specialized_analysis:
                result["specialized_analysis"] = specialized_analysis
            return result

        except Exception as e:
            # Deliberate best-effort boundary: report the error instead of raising.
            print(f"Error analyzing image with CLIP: {e}")
            import traceback
            traceback.print_exc()
            return {"error": str(e)}

    def _analyze_scene_type(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Return softmax similarity of the image to each scene type."""
        return self._softmax_scores(
            image_features, self.scene_type_tokens,
            self.scene_type_prompts.keys(), cache_key="scene_type",
        )

    def _analyze_lighting_condition(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Return softmax similarity of the image to each lighting condition."""
        return self._softmax_scores(
            image_features, self.lighting_tokens,
            self.lighting_condition_prompts.keys(), cache_key="lighting",
        )

    def _analyze_cultural_scene(self, image_features: torch.Tensor, scene_type: str) -> Dict[str, Any]:
        """
        Find the cultural description that best matches the image for *scene_type*.

        Returns the dict produced by ``_rank_prompts``. Returns an
        ``{"error": ...}`` dict if no cultural prompts exist for *scene_type*.
        """
        prompts = self.cultural_scene_prompts.get(scene_type) if scene_type in self.cultural_tokens_dict else None
        if not prompts:
            return {"error": f"No cultural analysis available for {scene_type}"}
        return self._rank_prompts(
            image_features, self.cultural_tokens_dict[scene_type],
            prompts, cache_key=("cultural", scene_type),
        )

    def _analyze_specialized_scene(self, image_features: torch.Tensor, scene_type: str) -> Dict[str, Any]:
        """
        Find the specialized description that best matches the image for *scene_type*.

        Returns the dict produced by ``_rank_prompts``. Returns an
        ``{"error": ...}`` dict if no specialized prompts exist for *scene_type*.
        """
        prompts = self.specialized_scene_prompts.get(scene_type) if scene_type in self.specialized_tokens_dict else None
        if not prompts:
            return {"error": f"No specialized analysis available for {scene_type}"}
        return self._rank_prompts(
            image_features, self.specialized_tokens_dict[scene_type],
            prompts, cache_key=("specialized", scene_type),
        )

    def _analyze_viewpoint(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Return softmax similarity of the image to each camera viewpoint."""
        return self._softmax_scores(
            image_features, self.viewpoint_tokens,
            self.viewpoint_prompts.keys(), cache_key="viewpoint",
        )

    def _analyze_object_combinations(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Return softmax similarity of the image to each object-combination prompt."""
        return self._softmax_scores(
            image_features, self.object_combination_tokens,
            self.object_combination_prompts.keys(), cache_key="object_combination",
        )

    def _analyze_activities(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Return softmax similarity of the image to each activity prompt."""
        return self._softmax_scores(
            image_features, self.activity_tokens,
            self.activity_prompts.keys(), cache_key="activity",
        )

    def get_image_embedding(self, image) -> np.ndarray:
        """
        Return the L2-normalized CLIP embedding of an image.

        Args:
            image: PIL Image or numpy array.

        Returns:
            np.ndarray: 1-D CLIP image feature vector.

        Raises:
            ValueError: If *image* is neither a PIL image nor a numpy array.
        """
        return self._to_numpy(self._encode_image(image))[0]

    def text_to_embedding(self, text: str) -> np.ndarray:
        """
        Return the L2-normalized CLIP embedding of a text string.

        Args:
            text: Input text.

        Returns:
            np.ndarray: 1-D CLIP text feature vector.
        """
        text_token = clip.tokenize([text]).to(self.device)
        # Arbitrary user text: no caching.
        return self._to_numpy(self._encode_text(text_token))[0]

    def calculate_similarity(self, image, text_queries: List[str]) -> Dict[str, float]:
        """
        Compute softmax similarity between an image and several text queries.

        Args:
            image: PIL Image or numpy image array, or a precomputed 1-D
                embedding (numpy array with ``ndim == 1``). A precomputed
                embedding is used as-is, without renormalizing.
            text_queries: List of text queries.

        Returns:
            Dict mapping each query to its softmax score over all queries.
            Returns an empty dict when *text_queries* is empty.
        """
        if not text_queries:
            return {}

        if isinstance(image, np.ndarray) and image.ndim == 1:
            # Already an embedding vector.
            image_features = torch.tensor(image).unsqueeze(0).to(self.device)
        else:
            # A real image: extract its embedding.
            image_features = self._encode_image(image)

        text_tokens = clip.tokenize(text_queries).to(self.device)
        text_features = self._encode_text(text_tokens)
        # A caller-supplied embedding may be float64, and the model may run in
        # float16 on GPU. Match dtypes so the matmul does not fail.
        image_features = image_features.to(text_features.dtype)

        with torch.no_grad():
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        similarity = self._to_numpy(similarity)[0]

        return {query: float(score) for query, score in zip(text_queries, similarity)}