weiyi01191's picture
Update app.py
4bdf408 verified
raw
history blame
21.8 kB
#!/usr/bin/env python3
"""
🎥 Video Content Safety Analysis - MiniGPT4-Video + 巨量引擎规则集成版
基于MiniGPT4-Video的真实视频内容分析 + 巨量引擎299条禁投规则检测
"""
# ZeroGPU装饰器 - 必须在torch等包之前导入!
try:
    import spaces
    GPU_AVAILABLE = True
    print("✅ ZeroGPU spaces 可用")
except ImportError:
    print("⚠️ ZeroGPU spaces 不可用,使用CPU模式")
    GPU_AVAILABLE = False

    class spaces:  # noqa: N801 - stands in for the `spaces` module namespace
        """No-op stand-in for the ``spaces`` package when it is not installed.

        Only ``spaces.GPU`` is provided. It accepts both the bare form
        (``@spaces.GPU``) and the called form (``@spaces.GPU(duration=600)``).
        In both cases the decorated function is returned unchanged, so the
        code simply runs on whatever device is available.
        """

        @staticmethod
        def GPU(func=None, **kwargs):
            """Return *func* unchanged, or an identity decorator.

            Args:
                func: The decorated function when used bare. Any non-callable
                    value (e.g. a positional ``duration``) is ignored.
                **kwargs: Ignored (e.g. ``duration``).

            Returns:
                ``func`` itself if it is callable, otherwise a decorator that
                returns its argument unchanged.
            """
            if callable(func):
                return func

            def decorator(f):
                return f

            return decorator
import os
import gradio as gr
import torch
import gc
import whisper
import argparse
import yaml
import random
import numpy as np
import torch.backends.cudnn as cudnn
from minigpt4.common.eval_utils import init_model
from minigpt4.conversation.conversation import CONV_VISION
import tempfile
import shutil
import cv2
import webvtt
import moviepy.editor as mp
from torchvision import transforms
from datetime import timedelta
from moviepy.editor import VideoFileClip
# 导入巨量引擎禁投规则引擎
from prohibited_rules import ProhibitedRulesEngine
# Route Hugging Face downloads through the China mirror.
# NOTE(review): this runs after gradio/whisper/minigpt4 are imported; if any of
# them already imported huggingface_hub, that library may have captured the
# endpoint at import time and ignore this value — confirm, and consider moving
# this assignment above the third-party imports.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

# Lazily populated state shared by the inference helpers; filled in by
# load_minigpt4_model(). `seed` is a default that the YAML config overrides.
model = vis_processor = whisper_model = args = None
seed = 42

# Rules engine used to screen the generated description.
rules_engine = ProhibitedRulesEngine()
print("✅ 巨量引擎299条禁投规则引擎初始化完成")
# ======================== MiniGPT4-Video 核心函数 ========================
def format_timestamp(seconds):
    """Format a duration in seconds as a WebVTT ``HH:MM:SS.mmm`` timestamp.

    Milliseconds are truncated rather than rounded, and hours are not wrapped
    at 24 (a 25-hour offset prints as ``25:...``).
    """
    delta = timedelta(seconds=seconds)
    whole_seconds = int(delta.total_seconds())
    millis = delta.microseconds // 1000
    total_minutes, secs = divmod(whole_seconds, 60)
    hrs, mins = divmod(total_minutes, 60)
    return f"{hrs:02}:{mins:02}:{secs:02}.{millis:03}"
def extract_video_info(video_path, max_images_length):
    """Compute the frame-sampling stride for a video.

    Args:
        video_path: Path to the video file.
        max_images_length: Target number of sampled frames; must be > 0.

    Returns:
        ``(sampling_interval, fps)``. ``sampling_interval`` is the number of
        decoded frames between samples (at least 1), chosen so that about
        ``max_images_length`` samples span the clip. ``fps`` is the frame
        rate reported by moviepy.

    Raises:
        ValueError: if ``max_images_length`` is not positive.
    """
    if max_images_length <= 0:
        raise ValueError(f"max_images_length must be positive, got {max_images_length}")
    clip = VideoFileClip(video_path)
    try:
        fps = clip.fps
        total_num_frames = int(clip.duration * fps)
    finally:
        # Release the underlying reader even if the metadata is unusable.
        clip.close()
    sampling_interval = max(1, total_num_frames // max_images_length)
    return sampling_interval, fps
def time_to_milliseconds(time_str):
    """Convert a WebVTT timestamp to integer milliseconds.

    Accepts both ``HH:MM:SS.mmm`` and the hour-less ``MM:SS.mmm`` form. The
    result is rounded rather than truncated, because products such as
    ``1.001 * 1000`` can land just below the intended integer in floating
    point.

    Raises:
        ValueError: if the string does not have 2 or 3 colon-separated fields,
            or a field is not numeric.
    """
    fields = time_str.strip().split(':')
    if len(fields) not in (2, 3):
        raise ValueError(f"invalid timestamp: {time_str!r}")
    total_seconds = 0.0
    for field in fields:
        total_seconds = total_seconds * 60 + float(field)
    return int(round(total_seconds * 1000))
def extract_subtitles(subtitle_path):
    """Parse a WebVTT file into ``(start_ms, end_ms, text)`` tuples.

    Multi-line cue text is joined with spaces. An empty list is returned when
    no path is given, when the file does not exist, or when parsing fails. In
    the last case a warning is printed, and the prompt is simply built without
    subtitles.
    """
    if not subtitle_path or not os.path.exists(subtitle_path):
        return []
    subtitles = []
    try:
        for caption in webvtt.read(subtitle_path):
            start_ms = time_to_milliseconds(caption.start)
            end_ms = time_to_milliseconds(caption.end)
            text = caption.text.strip().replace('\n', ' ')
            subtitles.append((start_ms, end_ms, text))
    except Exception as e:  # best-effort: subtitles are optional input
        print(f"字幕解析错误: {e}")
        return []
    return subtitles
def find_subtitle(subtitles, frame_count, fps):
    """Return the subtitle text covering a frame, or ``None``.

    Args:
        subtitles: ``(start_ms, end_ms, text)`` tuples, assumed ordered by
            start time (as read from a chronological VTT file); the binary
            search depends on that ordering.
        frame_count: Zero-based frame index.
        fps: Frame rate used to convert the index to milliseconds.
    """
    if not subtitles:
        return None
    frame_time = (frame_count / fps) * 1000
    lo, hi = 0, len(subtitles) - 1
    while hi >= lo:
        probe = (lo + hi) // 2
        cue_start, cue_end, cue_text = subtitles[probe]
        if cue_start <= frame_time <= cue_end:
            return cue_text
        if frame_time < cue_start:
            hi = probe - 1
        else:
            lo = probe + 1
    return None
def match_frames_and_subtitles(video_path, subtitles, sampling_interval, max_sub_len, fps, max_frames):
    """Sample video frames and build the interleaved image/subtitle prompt.

    Every ``sampling_interval``-th decoded frame is processed and collected:
    it is converted from OpenCV's BGR channel order to RGB, wrapped as a PIL
    image, and passed through the global ``vis_processor``.

    Each sampled frame adds an ``<Img><ImageHere>`` placeholder. If new
    subtitle lines were seen since the previous sample and the word budget is
    not exhausted, a ``<Cap>{text}`` segment follows the placeholder. Each
    distinct subtitle string is emitted at most once.

    Args:
        video_path: Path to the video file.
        subtitles: ``(start_ms, end_ms, text)`` tuples (may be empty).
        sampling_interval: Keep one frame out of every ``sampling_interval``.
        max_sub_len: Approximate word budget for all ``<Cap>`` segments.
        fps: Frame rate used to map frame indices to subtitle times.
        max_frames: Stop once this many frames have been sampled.

    Returns:
        ``(images, placeholder)``, where ``images`` is ``torch.stack`` of the
        processed frames. Returns ``(None, None)`` if no frame was sampled.
    """
    global vis_processor
    to_pil = transforms.Compose([
        transforms.ToPILImage(),
    ])
    images = []
    img_placeholder = ""
    subtitle_text_in_interval = ""
    seen_subtitles = set()
    number_of_words = 0
    frame_count = 0

    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if subtitles:
                frame_subtitle = find_subtitle(subtitles, frame_count, fps)
                if frame_subtitle and frame_subtitle not in seen_subtitles:
                    subtitle_text_in_interval += frame_subtitle + " "
                    seen_subtitles.add(frame_subtitle)
            if frame_count % sampling_interval == 0:
                rgb_image = to_pil(frame[:, :, ::-1])  # BGR -> RGB
                images.append(vis_processor(rgb_image))
                img_placeholder += '<Img><ImageHere>'
                if subtitle_text_in_interval != "" and number_of_words < max_sub_len:
                    img_placeholder += f'<Cap>{subtitle_text_in_interval}'
                    # NOTE: split(' ') also counts the empty piece after the
                    # trailing space, so the budget is slightly conservative;
                    # kept as-is so generated prompts do not change.
                    number_of_words += len(subtitle_text_in_interval.split(' '))
                    subtitle_text_in_interval = ""
            frame_count += 1
            if len(images) >= max_frames:
                break
    finally:
        # Always release the decoder, even if preprocessing raises. No GUI
        # window is ever opened, so cv2.destroyAllWindows() is unnecessary
        # (and may be unimplemented in headless OpenCV builds).
        cap.release()

    if not images:
        return None, None
    return torch.stack(images), img_placeholder
def extract_audio(video_path, audio_path):
    """Write the audio track of *video_path* to *audio_path* as a 320 kbps MP3.

    Raises:
        ValueError: if the video has no audio track.
    """
    video_clip = mp.VideoFileClip(video_path)
    try:
        audio_clip = video_clip.audio
        if audio_clip is None:
            # Without this check the failure surfaces as an opaque
            # AttributeError on None.write_audiofile.
            raise ValueError(f"视频没有音轨: {video_path}")
        audio_clip.write_audiofile(audio_path, codec="libmp3lame", bitrate="320k", verbose=False, logger=None)
    finally:
        video_clip.close()
def get_subtitles(video_path, language="en"):
    """Transcribe a video's audio with Whisper and cache the result as WebVTT.

    Args:
        video_path: Path to the video file.
        language: Language code passed to ``whisper_model.transcribe``.
            ``None`` lets Whisper detect the language. Defaults to ``"en"``.

    Returns:
        Path to the ``.vtt`` file. Returns ``None`` if Whisper is not loaded
        or if extraction/transcription fails (the error is printed).
    """
    global whisper_model
    if whisper_model is None:
        return None
    audio_dir = "workspace/inference_subtitles/mp3"
    subtitle_dir = "workspace/inference_subtitles"
    os.makedirs(subtitle_dir, exist_ok=True)
    os.makedirs(audio_dir, exist_ok=True)

    video_id = _subtitle_cache_id(video_path, language)
    audio_path = os.path.join(audio_dir, f"{video_id}.mp3")
    subtitle_path = os.path.join(subtitle_dir, f"{video_id}.vtt")

    # Reuse subtitles already generated for this exact file.
    if os.path.exists(subtitle_path):
        return subtitle_path

    try:
        extract_audio(video_path, audio_path)
        result = whisper_model.transcribe(audio_path, language=language)
        # Write to a temp file and rename, so an interrupted write never
        # leaves a truncated VTT that would be reused as a cache hit.
        tmp_path = subtitle_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as vtt_file:
            vtt_file.write("WEBVTT\n\n")
            for segment in result['segments']:
                text = segment['text'].strip()
                if not text:
                    continue  # an empty cue body would break the VTT layout
                start = format_timestamp(segment['start'])
                end = format_timestamp(segment['end'])
                vtt_file.write(f"{start} --> {end}\n{text}\n\n")
        os.replace(tmp_path, subtitle_path)
        return subtitle_path
    except Exception as e:
        print(f"字幕生成错误: {e}")
        return None


def _subtitle_cache_id(video_path, language):
    """Build a filesystem-safe cache key for a video's subtitles.

    The readable file stem is kept for debugging. A short hash of the absolute
    path, file size, mtime and language is appended. This prevents different
    uploads sharing a file name, or the same video transcribed in another
    language, from reusing one cached ``.vtt``.
    """
    import hashlib

    abs_path = os.path.abspath(video_path)
    try:
        st = os.stat(abs_path)
        fingerprint = f"{abs_path}|{st.st_size}|{st.st_mtime_ns}|{language}"
    except OSError:
        fingerprint = f"{abs_path}|{language}"
    digest = hashlib.sha1(fingerprint.encode("utf-8")).hexdigest()[:12]
    stem = os.path.splitext(os.path.basename(video_path))[0] or "video"
    return f"{stem}_{digest}"
def prepare_input(video_path, subtitle_path, instruction):
    """Turn a video (plus optional subtitles) into model-ready frames and prompt.

    The frame and subtitle-word budgets depend on the backbone. They are
    90 frames / 800 words when ``args.ckpt`` names a Mistral checkpoint, and
    45 / 400 otherwise.

    Returns:
        ``(frames, prompt)``. ``frames`` is the stacked tensor from
        ``match_frames_and_subtitles`` (``None`` if no frame was read).
        ``prompt`` is the interleaved placeholder text, a newline and
        ``instruction``; it is just ``instruction`` when there are no
        placeholders.
    """
    global args
    is_mistral = args and "mistral" in args.ckpt
    max_frames, max_sub_len = (90, 800) if is_mistral else (45, 400)

    stride, video_fps = extract_video_info(video_path, max_frames)
    frames, placeholder = match_frames_and_subtitles(
        video_path,
        extract_subtitles(subtitle_path),
        stride,
        max_sub_len,
        video_fps,
        max_frames,
    )
    prompt = placeholder + ("\n" + instruction) if placeholder else instruction
    return frames, prompt
def model_generate(*model_args, **kwargs):
    """Call the wrapped language model's ``generate`` inside ``model.maybe_autocast()``.

    All positional and keyword arguments are forwarded unchanged.
    NOTE(review): this helper is not called anywhere in this file;
    ``generate_prediction`` calls ``model.generate`` directly.
    """
    global model
    with model.maybe_autocast():
        return model.llama_model.generate(*model_args, **kwargs)
def generate_prediction(video_path, instruction, gen_subtitles=True, stream=False):
    """Run MiniGPT4-Video on a video and return the generated answer.

    Args:
        video_path: Path to the video to analyse.
        instruction: User prompt appended after the frame/subtitle placeholders.
        gen_subtitles: If True, transcribe the audio via ``get_subtitles`` and
            interleave the subtitles with the frames.
        stream: Accepted for interface compatibility; currently ignored.

    Returns:
        The generated text. If no frames could be read, or if generation
        raised, a Chinese error-message string is returned instead.
    """
    global model, args, seed
    subtitle_path = get_subtitles(video_path) if gen_subtitles else None

    frames, full_prompt_text = prepare_input(video_path, subtitle_path, instruction)
    if frames is None:
        return "视频无法打开,请检查视频路径"

    num_frames = len(frames)
    batch = frames.unsqueeze(0)  # add the batch dimension

    conv = CONV_VISION.copy()
    conv.system = ""
    for role, message in ((conv.roles[0], full_prompt_text), (conv.roles[1], None)):
        conv.append_message(role, message)
    prompt = [conv.get_prompt()]

    # Re-seed before every request so sampling is reproducible.
    setup_seeds(seed)

    try:
        answers = model.generate(
            batch,
            prompt,
            max_new_tokens=args.max_new_tokens if args else 512,
            do_sample=True,
            lengths=[num_frames],
            num_beams=1,
        )
        return answers[0]
    except Exception as e:
        return f"生成预测时出错: {str(e)}"
# ======================== 巨量引擎规则检测函数 ========================
def format_violations_report(violations_result):
    """Render the rules-engine result as a Markdown-style report section.

    Args:
        violations_result: Dict from the rules engine. ``has_violations`` is
            always read. When it is truthy, the function also reads:
            ``total_violations``; ``high_risk`` / ``medium_risk`` /
            ``low_risk``, each with a ``count``; and ``all_violations``. Each
            item in ``all_violations`` carries ``risk_level``, ``category``,
            ``description``, ``matched_keyword`` and ``rule_id``.
            ``risk_level`` must be "P1", "P2" or "P3"; any other value
            raises KeyError.

    Returns:
        The report text. Violations are listed from most to least severe
        (P3, P2, P1), keeping the engine's order within each level.
    """
    if not violations_result["has_violations"]:
        return """
🛡️ **巨量引擎规则检测结果**: ✅ 无违规内容
- 已检测规则: 299条巨量引擎禁投规则
- 检测维度: 低危(P1) + 中危(P2) + 高危(P3)
- 检测结果: 内容符合平台规范
"""
    sections = [f"""
🚨 **巨量引擎规则检测结果**: ⚠️ 发现 {violations_result["total_violations"]} 项违规
📊 **违规统计**:
- 🔴 高危违规(P3): {violations_result["high_risk"]["count"]}
- 🟡 中危违规(P2): {violations_result["medium_risk"]["count"]}
- 🟠 低危违规(P1): {violations_result["low_risk"]["count"]}
📋 **详细违规列表**:
"""]
    severity = {"P3": 3, "P2": 2, "P1": 1}
    icons = {"P3": "🚨", "P2": "⚠️", "P1": "💭"}
    # sorted(..., reverse=True) is stable, so equal levels keep engine order.
    ordered = sorted(
        violations_result["all_violations"],
        key=lambda item: severity[item["risk_level"]],
        reverse=True,
    )
    for item in ordered:
        risk_icon = icons[item["risk_level"]]
        sections.append(f"""
{risk_icon} **{item["risk_level"]} - {item["category"]}**
规则: {item["description"]}
匹配词: "{item["matched_keyword"]}"
规则ID: {item["rule_id"]}
""")
    return "".join(sections)
def get_overall_risk_level(violations_result):
    """Summarise the rules-engine result as a single overall risk rating.

    Precedence: any high-risk hit, then more than two medium-risk hits, then
    any medium-risk hit; otherwise the rating falls back to the low-risk
    count. Counts are only read when needed.

    NOTE(review): this rating uses P0 for most severe and P3 for safe, while
    the per-rule levels in the report use P3 for *high* risk. Also, both
    medium-risk branches return "P1". Confirm which scale is intended before
    changing any labels.
    """
    if not violations_result["has_violations"]:
        return "✅ P3 (安全) - 内容健康,符合平台规范"

    high = violations_result["high_risk"]["count"]
    if high > 0:
        return f"🚨 P0 (极高危) - 发现 {high} 项高危违规,禁止投放"

    medium = violations_result["medium_risk"]["count"]
    if medium > 2:
        return f"⚠️ P1 (高危) - 发现 {medium} 项中危违规,需严格审核"
    if medium > 0:
        return f"⚠️ P1 (中危) - 发现 {medium} 项中危违规,需要审核"

    return f"⚡ P2 (低危) - 发现 {violations_result['low_risk']['count']} 项低危违规,建议关注"
# ======================== 应用主要函数 ========================
def setup_seeds(seed):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # Give up cuDNN autotuning in exchange for reproducible kernel choice.
    cudnn.benchmark = False
    cudnn.deterministic = True
def optimize_gpu_memory():
    """Configure the CUDA caching allocator and release cached GPU memory.

    Sets ``PYTORCH_CUDA_ALLOC_CONF`` to limit fragmentation, unless the user
    already set it. When CUDA is available, it then:

    * prints device info;
    * empties the CUDA cache, collects IPC handles and runs the Python GC;
    * forces deterministic cuDNN;
    * reports the driver-reported free memory.

    ``CUDA_LAUNCH_BLOCKING`` is deliberately not set. It makes every kernel
    launch synchronous, which is a debugging aid and slows inference.
    """
    print("🔍 开始GPU内存优化...")
    # NOTE: PyTorch reads this when the CUDA caching allocator initialises, so
    # it likely has no effect if CUDA memory was already allocated — confirm.
    os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'max_split_size_mb:256,garbage_collection_threshold:0.6')

    if not torch.cuda.is_available():
        return

    print(f"🔍 GPU: {torch.cuda.get_device_name(0)}")
    print(f"💾 总显存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    # Release cached blocks held by the allocator.
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    gc.collect()
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # mem_get_info accounts for reserved/cached memory and other processes,
    # unlike total_memory - memory_allocated.
    free_bytes, _total_bytes = torch.cuda.mem_get_info(0)
    print(f"💾 清理后可用显存: {free_bytes / 1024**3:.1f} GB")
def get_arguments(argv=None):
    """Parse MiniGPT4-Video options.

    Unrecognised arguments are ignored (``parse_known_args``). This function
    is called lazily from the model loader inside the Gradio request handler,
    where the process's argv belongs to whatever launched the app. A stray
    flag would otherwise make argparse call ``sys.exit`` mid-request.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``cfg_path``, ``ckpt``,
        ``max_new_tokens``, ``lora_r``, ``lora_alpha`` and ``options``.
    """
    parser = argparse.ArgumentParser(description="MiniGPT4-Video参数")
    parser.add_argument("--cfg-path", help="配置文件路径",
                        default="test_configs/mistral_test_config.yaml")  # Mistral config
    parser.add_argument("--ckpt", type=str,
                        default='checkpoints/video_mistral_checkpoint_last.pth',  # Mistral checkpoint
                        help="模型检查点路径")
    parser.add_argument("--max_new_tokens", type=int, default=512,
                        help="最大生成token数")
    # LoRA hyper-parameters chosen to match the checkpoint.
    parser.add_argument("--lora_r", type=int, default=64, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha")
    parser.add_argument("--options", nargs="+", help="覆盖配置选项")
    parsed, _unrecognised = parser.parse_known_args(argv)
    return parsed
def load_minigpt4_model():
    """Load MiniGPT4-Video and Whisper once and cache them in module globals.

    Steps:
    1. Read CLI options via ``get_arguments``.
    2. Take the random seed from the YAML config's ``run.seed``.
    3. Prepare the GPU allocator.
    4. Build the video model with ``init_model``.
    5. Load a Whisper ``base`` model for transcription.

    Returns:
        ``(model, vis_processor, whisper_model)``.

        If the video model cannot be loaded, ``(None, None, None)`` is
        returned so the caller can switch to simulation mode.

        A Whisper failure is not fatal. The video model is still returned,
        ``whisper_model`` stays ``None``, and ``get_subtitles`` then skips
        transcription.
    """
    global model, vis_processor, whisper_model, args, seed
    if model is not None:
        return model, vis_processor, whisper_model

    try:
        print("🔄 正在加载MiniGPT4-Video模型...")
        args = get_arguments()

        config_path = args.cfg_path
        if not os.path.exists(config_path):
            config_path = "test_configs/llama2_test_config.yaml"  # fallback default config
            # init_model receives `args` and presumably reads args.cfg_path
            # itself, so the fallback has to be written back.
            # NOTE(review): the default ckpt is a Mistral checkpoint; confirm
            # it is compatible with the llama2 config before relying on this.
            args.cfg_path = config_path
        with open(config_path, encoding="utf-8") as file:
            # This file only reads run.seed, so plain YAML types suffice.
            config = yaml.safe_load(file)
        seed = config['run']['seed']
        setup_seeds(seed)

        optimize_gpu_memory()

        print("🚀 开始初始化MiniGPT4-Video模型...")
        model, vis_processor, whisper_gpu_id, _minigpt4_gpu_id, _answer_module_gpu_id = init_model(args)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"💾 模型加载后显存使用: {torch.cuda.memory_allocated(0) / 1024**3:.1f} GB")
    except Exception as e:
        print(f"❌ 模型加载失败: {e}")
        print("🔄 回退到模拟模式...")
        return None, None, None

    print("🚀 开始初始化Whisper模型...")
    try:
        whisper_device = f"cuda:{whisper_gpu_id}" if torch.cuda.is_available() else "cpu"
        whisper_model = whisper.load_model("base").to(whisper_device)
    except Exception as e:
        # Subtitles are optional; keep the already-loaded video model usable.
        whisper_model = None
        print(f"⚠️ Whisper模型加载失败,将跳过字幕生成: {e}")

    if torch.cuda.is_available():
        print(f"💾 全部加载后显存使用: {torch.cuda.memory_allocated(0) / 1024**3:.1f} GB")
    print("✅ 所有模型加载完成!")
    return model, vis_processor, whisper_model
@spaces.GPU(duration=600)  # 10 minutes, leaving headroom for the first-run model download
def analyze_video_with_minigpt4(video_file, instruction):
    """Describe a video with MiniGPT4-Video and screen it with the rules engine.

    This is the Gradio callback behind the two output textboxes.

    Args:
        video_file: Path of the uploaded video, or ``None`` if nothing was
            uploaded.
        instruction: Free-text analysis prompt. A default Chinese prompt is
            used when it is empty or whitespace-only.

    Returns:
        ``(report_text, risk_rating)`` as two strings. Errors are reported
        in-band rather than raised, so the UI always receives both outputs.
    """
    if video_file is None:
        return "❌ 请上传视频文件", "无法评估"

    try:
        model_loaded, _vis_proc, _whisper_loaded = load_minigpt4_model()
        if model_loaded is None:
            # Simulation mode: the real model failed to load.
            return f"""
🎬 **视频内容分析结果 (模拟模式)**
📋 **基本信息**:
- 视频文件: {video_file}
- 分析指令: {instruction}
⚠️ **注意**: 当前运行在模拟模式,真实模型加载失败
请检查模型文件和配置是否正确
🛡️ **巨量引擎规则检测**: 仅在真实模式下可用
""", "⚠️ 模拟模式"

        print(f"🔄 开始分析视频: {video_file}")
        print(f"📝 分析指令: {instruction}")

        # The upload is analysed in place; a missing path is reported clearly
        # instead of failing later inside the video readers.
        if not os.path.isfile(video_file):
            return f"❌ 找不到视频文件: {video_file}", "无法评估"

        if not instruction or not instruction.strip():
            instruction = "请详细分析这个视频的内容,包括场景、人物、动作、对话等,并描述所有可见和可听的元素。"

        prediction = generate_prediction(
            video_path=video_file,
            instruction=instruction,
            gen_subtitles=True,  # transcribe audio and interleave subtitles
            stream=False,
        )

        print("🔍 开始巨量引擎299条规则检测...")
        # NOTE(review): the user's instruction is passed to the rules engine
        # alongside the model output. If the engine scans it as content, words
        # in the prompt itself could be reported as violations. Confirm the
        # intended role of this argument.
        violations_result = rules_engine.check_all_content(prediction, instruction)

        report = _build_analysis_report(video_file, instruction, prediction, violations_result)
        return report, get_overall_risk_level(violations_result)
    except Exception as e:
        import traceback
        traceback.print_exc()  # keep the full stack trace in the server log
        error_msg = f"""
❌ **分析过程中出错**
错误信息: {str(e)}
🔄 **可能的解决方案**:
1. 检查视频文件格式 (建议MP4)
2. 确认模型文件是否正确加载
3. 检查GPU内存是否充足
4. 验证配置文件路径
💡 **提示**: 如果问题持续,请检查模型和依赖项安装
"""
        return error_msg, "⚠️ 错误"


def _build_analysis_report(video_file, instruction, prediction, violations_result):
    """Assemble the full Markdown-style report shown in the main output box."""
    return f"""
🎬 **MiniGPT4-Video 视频内容分析 + 巨量引擎规则检测报告**
📋 **基本信息**:
- 视频文件: {os.path.basename(video_file)}
- 分析设备: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU模式'}
- 分析指令: {instruction}
🔍 **视频内容描述**:
{prediction}
{format_violations_report(violations_result)}
📊 **技术信息**:
- 内容理解: MiniGPT4-Video + Whisper
- 规则引擎: 巨量引擎299条禁投规则
- 检测等级: P1(低危) + P2(中危) + P3(高危)
- 分析模式: 多模态理解 (视觉+语音+文本)
💡 **说明**:
基于MiniGPT4-Video的深度内容理解,结合巨量引擎完整禁投规则库进行专业违规检测。
"""
def create_app():
    """Build the Gradio ``Interface`` wired to ``analyze_video_with_minigpt4``.

    Inputs are a video upload and a free-text analysis instruction. Outputs
    are the full analysis report and the overall risk rating. Example rows
    carry only prompts (no video), and examples are not cached.

    Returns:
        The ``gr.Interface`` instance; it is not launched here.
    """
    default_instruction = "请详细分析这个视频的内容,包括场景、人物、动作、对话等,并描述所有可见和可听的元素。"
    example_prompts = (
        "分析这个视频是否包含禁投内容",
        "检测视频中是否有巨量引擎禁止的产品或服务",
        "评估视频内容的投放风险等级",
        "详细描述视频内容并进行合规检查",
    )
    description = """
## 🎬 基于MiniGPT4-Video + 巨量引擎299条禁投规则的专业视频安全检测系统
⚡ **ZeroGPU加速** | 🎬 **MiniGPT4-Video** | 🎙️ **Whisper语音** | 🛡️ **巨量引擎299条规则**
**🔥 核心功能:**
- 🎞️ **深度视频理解**: MiniGPT4-Video多模态分析
- 🎙️ **语音转文字**: Whisper自动生成字幕
- 🛡️ **专业违规检测**: 巨量引擎完整禁投规则库
- 📊 **智能风险评级**: P0-P3四级风险等级
**🎯 检测维度:**
- **高危(P3)**: 违法出版物、烟草、医疗等严重违规
- **中危(P2)**: 赌博周边、房地产、金融等中等风险
- **低危(P1)**: 化妆品、汽车、游戏等轻微风险
**📋 规则覆盖:**
涵盖化妆品类、汽车类、游戏类、赌博类、房地产类、工具软件类、教育培训类、
金融类、医疗类、烟草类等全部299条巨量引擎禁投规则
"""
    input_components = [
        gr.Video(label="上传视频文件"),
        gr.Textbox(
            label="分析指令",
            value=default_instruction,
            placeholder="输入您希望AI如何分析这个视频...",
            lines=3,
        ),
    ]
    output_components = [
        gr.Textbox(label="MiniGPT4-Video 内容分析 + 巨量引擎规则检测", lines=20),
        gr.Textbox(label="巨量引擎风险评级"),
    ]
    return gr.Interface(
        fn=analyze_video_with_minigpt4,
        inputs=input_components,
        outputs=output_components,
        title="🎥 智能视频内容安全分析 - MiniGPT4-Video + 巨量引擎",
        description=description,
        examples=[[None, prompt] for prompt in example_prompts],
        cache_examples=False,
    )
def main():
    """Print startup banners, prepare working directories and launch Gradio."""
    for banner in (
        "🚀 启动MiniGPT4-Video + 巨量引擎视频安全分析应用",
        "🎬 MiniGPT4-Video: 深度视频内容理解",
        "🛡️ 巨量引擎: 299条禁投规则检测",
    ):
        print(banner)

    if torch.cuda.is_available():
        device_message = f"✅ GPU可用: {torch.cuda.get_device_name(0)}"
    else:
        device_message = "⚠️ 使用CPU模式"
    print(device_message)

    # Scratch space for uploaded copies, subtitles and extracted audio.
    for workdir in ("workspace/tmp", "workspace/inference_subtitles", "workspace/inference_subtitles/mp3"):
        os.makedirs(workdir, exist_ok=True)
    print("📁 工作目录准备完成")

    print("🚀 正在启动Gradio应用...")
    # NOTE(review): share=True asks Gradio for a public share link when run
    # outside Hugging Face Spaces — confirm a public URL is intended.
    create_app().launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )


if __name__ == "__main__":
    main()