arjunanand13's picture
Update app.py
22307f0 verified
raw
history blame
4.94 kB
import torch
import gradio as gr
import numpy as np
import cv2
from PIL import Image
from transformers import BitsAndBytesConfig, LlavaNextForConditionalGeneration, AutoProcessor
import gc
# Model identifier passed to both `from_pretrained` calls in `load_model`
# (the processor and the LLaVA-NeXT model are loaded from the same id).
MODEL_ID = "arjunanand13/gas_pipe_llava_finetunedv3"
def clear_memory():
    """Run the garbage collector and, if CUDA is available, flush its cache.

    ``gc.collect()`` is always executed. When ``torch.cuda.is_available()``
    reports true, ``torch.cuda.empty_cache()`` and then
    ``torch.cuda.synchronize()`` are called as well.
    """
    gc.collect()

    # Nothing GPU-related to do on CPU-only hosts.
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
def extract_frames_from_video(video_path, num_frames=4):
    """Sample evenly spaced frames from a video as 112x112 RGB PIL images.

    Frame indices are spread uniformly between the first and last frame
    (``np.linspace``). Each frame that is read successfully is converted
    from BGR to RGB and resized to 112x112 with Lanczos resampling.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        num_frames: Number of frames to return (default 4).

    Returns:
        A list of exactly ``num_frames`` PIL images. If fewer frames could
        be read, the list is padded with copies of the last good frame, or
        with black images when nothing could be read at all.

    Raises:
        ValueError: If the video cannot be opened.
    """
    frame_size = (112, 112)
    # The requested count drives both padding and truncation, so the
    # result length always matches what the caller asked for.
    wanted = max(int(num_frames), 0)

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Cannot open video: {video_path}")

        # Clamp a non-positive reported frame count to zero; a negative
        # sample count would otherwise make np.linspace raise.
        total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
        sample_count = min(total_frames, wanted)
        frame_indices = np.linspace(0, total_frames - 1, sample_count, dtype=int)

        frames = []
        for frame_idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
            ret, frame = cap.read()
            if not ret:
                # Skip unreadable frames; padding below fills the gap.
                continue
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_pil = Image.fromarray(frame_rgb)
            frames.append(frame_pil.resize(frame_size, Image.Resampling.LANCZOS))
    finally:
        # Release the capture handle even if decoding or resizing raised.
        cap.release()

    while len(frames) < wanted:
        if frames:
            frames.append(frames[-1].copy())
        else:
            frames.append(Image.new('RGB', frame_size, color='black'))
    return frames[:wanted]
def create_frame_grid(frames, grid_size=(2, 2), frame_size=112):
    """Paste frames in row-major order into one RGB grid image.

    Args:
        frames: Iterable of PIL images. Each one is pasted, without
            resizing, at the top-left corner of its grid cell.
        grid_size: ``(cols, rows)`` of the grid.
        frame_size: Edge length in pixels of each square cell. Defaults to
            112, the value that was previously hard-coded.

    Returns:
        A new RGB PIL image of size
        ``(frame_size * cols, frame_size * rows)``. Cells that receive no
        frame keep the canvas fill from ``Image.new``.
    """
    cols, rows = grid_size
    grid_image = Image.new('RGB', (frame_size * cols, frame_size * rows))

    capacity = cols * rows
    for index, frame in enumerate(frames):
        if index >= capacity:
            # Any further frame would be positioned below the canvas.
            break
        row, col = divmod(index, cols)
        grid_image.paste(frame, (col * frame_size, row * frame_size))
    return grid_image
@torch.no_grad()
def load_model():
    """Load the processor and the 4-bit quantized LLaVA-NeXT model.

    Both are loaded from ``MODEL_ID``. The tokenizer is set to right
    padding and reuses its EOS token as the pad token. The model is loaded
    with an NF4 double-quantized ``BitsAndBytesConfig`` (float16 compute),
    ``device_map="auto"``, and is put in eval mode before being returned.

    Returns:
        A ``(model, processor)`` tuple.
    """
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.uint8,
    )

    processor = AutoProcessor.from_pretrained(MODEL_ID)
    tokenizer = processor.tokenizer
    tokenizer.padding_side = "right"
    tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): trust_remote_code=True allows code shipped with the
    # model repository to run locally -- confirm this is intended.
    load_kwargs = {
        "torch_dtype": torch.float16,
        "quantization_config": quantization,
        "device_map": "auto",
        "low_cpu_mem_usage": True,
        "trust_remote_code": True,
    }
    model = LlavaNextForConditionalGeneration.from_pretrained(MODEL_ID, **load_kwargs)

    # NOTE(review): use_cache=False looks like a training-time setting;
    # verify whether it is wanted for inference.
    model.config.use_cache = False
    model.eval()
    return model, processor
# Loaded once at module import time; `predict_gas_pipe_quality` reads these
# module-level globals on every call.
model, processor = load_model()
def predict_gas_pipe_quality(video_path):
    """Run the model on a 2x2 grid of frames sampled from a video.

    Four frames are extracted from ``video_path``, tiled into one grid
    image, and passed with a fixed instruction prompt to the module-level
    ``model`` / ``processor``. Up to 16 new tokens are generated greedily
    and decoded without special tokens.

    Args:
        video_path: Path of the video to analyze. May be ``None`` or empty
            when no video is selected.

    Returns:
        ``(grid_image, prediction)`` on success; ``(None, message)`` when
        no video was supplied or when any step raised. Failures are
        reported as ``"Error: ..."`` text rather than propagated.
    """
    # This handler is also wired to the video component's change event,
    # which presumably fires with None when the video is cleared. Answer
    # that case directly instead of surfacing an opaque OpenCV error.
    if not video_path:
        return None, "Please upload a video to analyze."

    try:
        frames = extract_frames_from_video(video_path, num_frames=4)
        grid_image = create_frame_grid(frames, grid_size=(2, 2))

        prompt = "[INST] <image>\nGas pipe test result? [/INST]"
        inputs = processor(text=prompt, images=grid_image, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {
                k: v.to(model.device) if isinstance(v, torch.Tensor) else v
                for k, v in inputs.items()
            }

        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                pixel_values=inputs["pixel_values"],
                image_sizes=inputs["image_sizes"],
                max_new_tokens=16,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
            )

        # Drop the prompt tokens so only the newly generated text is decoded.
        prompt_length = inputs["input_ids"].size(1)
        prediction = processor.batch_decode(
            generated_ids[:, prompt_length:],
            skip_special_tokens=True,
        )[0].strip()
        return grid_image, prediction
    except Exception as e:
        # UI boundary: report the failure in the output textbox.
        return None, f"Error: {str(e)}"
    finally:
        # Free memory on both the success and the failure path.
        clear_memory()
def create_interface():
    """Build the Gradio Blocks UI for the gas pipe analysis demo.

    The layout is a row with two columns: a video upload plus an "Analyze"
    button on one side, and the extracted frame grid plus the model output
    textbox on the other. Both the button's click event and the video
    component's change event call ``predict_gas_pipe_quality``.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="Gas Pipe Quality Control") as iface:
        gr.Markdown("# Gas Pipe Quality Control")

        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Video")
                analyze_btn = gr.Button("Analyze", variant="primary")
            with gr.Column():
                frame_grid = gr.Image(label="Extracted Frames")
                result_output = gr.Textbox(label="Model Output", lines=3)

        # Same handler and wiring for both triggers: click first, then change.
        for register in (analyze_btn.click, video_input.change):
            register(
                fn=predict_gas_pipe_quality,
                inputs=video_input,
                outputs=[frame_grid, result_output],
            )

    return iface
if __name__ == "__main__":
    # Script entry point: build the UI, then start the Gradio server with
    # share=True passed to launch().
    iface = create_interface()
    iface.launch(share=True)