Spaces:
Sleeping
Sleeping
import streamlit as st | |
import chromadb | |
from chromadb.config import Settings | |
from transformers import CLIPProcessor, CLIPModel | |
import cv2 | |
from PIL import Image | |
import torch | |
import logging | |
import uuid | |
import tempfile | |
import os | |
import requests | |
import json | |
from dotenv import load_dotenv | |
import shutil | |
load_dotenv() | |
HF_TOKEN = os.getenv('hf_token') | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
try: | |
def load_model(): | |
device = 'cpu' | |
processor = CLIPProcessor.from_pretrained( | |
"openai/clip-vit-large-patch14", token=HF_TOKEN) | |
model = CLIPModel.from_pretrained( | |
"openai/clip-vit-large-patch14", token=HF_TOKEN) | |
model.eval().to(device) | |
return processor, model | |
def load_chromadb(): | |
chroma_client = chromadb.PersistentClient( | |
path='Data', settings=Settings(anonymized_telemetry=False)) | |
collection = chroma_client.get_or_create_collection(name='images') | |
return chroma_client, collection | |
def resize_image(image_path, size=(224, 224)): | |
if isinstance(image_path, str): | |
img = Image.open(image_path).convert("RGB") | |
else: | |
img = Image.open(image_path).convert("RGB") | |
img_resized = img.resize(size, Image.LANCZOS) | |
return img_resized | |
def get_image_embedding(image, model, preprocess, device='cpu'): | |
image = Image.open(f'{image}').convert('RGB') | |
input_tensor = preprocess(images=[image], return_tensors='pt')[ | |
'pixel_values'].to(device) | |
with torch.no_grad(): | |
embedding = model.get_image_features( | |
pixel_values=input_tensor) | |
return torch.nn.functional.normalize(embedding, p=2, dim=1) | |
def extract_frames(v_path, frame_interval=30): | |
cap = cv2.VideoCapture(v_path) | |
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
frame_rate = int(cap.get(cv2.CAP_PROP_FPS)) | |
total_seconds = frame_count//frame_rate | |
frame_idx = 0 | |
saved_frames = 0 | |
while cap.isOpened(): | |
ret, frame = cap.read() | |
if not ret: | |
break | |
if frame_idx % frame_interval == 0: | |
unique_image_id = str(uuid.uuid4()) | |
frame_name = f"{temp_dir}/frame_{unique_image_id}_{saved_frames}.jpg" | |
cv2.imwrite(frame_name, frame) | |
saved_frames += 1 | |
frame_idx += 1 | |
cap.release() | |
logger.info("Frames extracted") | |
def insert_into_db(collection, dir): | |
embedding_list = [] | |
file_names = [] | |
ids = [] | |
with st.status("Generating embedding... ⏳", expanded=True) as status: | |
for i in os.listdir(dir): | |
embedding = get_image_embedding( | |
f"{dir}/{i}", model, processor) | |
embedding_list.append( | |
embedding.squeeze(0).numpy().tolist()) | |
file_names.append( | |
{'path': f"{dir}/{i}", 'type': 'photo'}) | |
unique_id = str(uuid.uuid4()) | |
ids.append(unique_id) | |
status.update(label="Embedding generation complete", | |
state="complete") | |
collection.add( | |
embeddings=embedding_list, | |
ids=ids, | |
metadatas=file_names | |
) | |
logger.info("Data inserted into DB") | |
processor, model = load_model() | |
logger.info("Model and processor loaded") | |
client, collection = load_chromadb() | |
logger.info("ChromaDB loaded") | |
logger.info( | |
f"Connected to ChromaDB collection images with {collection.count()} items") | |
temp_dir = 'temp_folder' | |
if 'cleaned_temp' not in st.session_state: | |
if os.path.exists(temp_dir): | |
shutil.rmtree(temp_dir) | |
os.makedirs(temp_dir, exist_ok=True) | |
st.session_state.cleaned_temp = True | |
results=collection.get(include=["metadatas"]) | |
ids_to_delete=[ | |
_id for _id,metadata in zip(results["ids"],results['metadatas']) if metadata.get("path","").startswith("temp") | |
] | |
if ids_to_delete: | |
collection.delete(ids=ids_to_delete) | |
st.title("Extract frames from video using text") | |
# Upload section | |
st.sidebar.subheader("Upload video") | |
video_file = st.sidebar.file_uploader( | |
"Upload videos", type=["mp4", "webm", "avi", "mov"], accept_multiple_files=False | |
) | |
num_images = st.sidebar.slider( | |
"Number of images to be shown", min_value=1, max_value=10, value=3) | |
if video_file: | |
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmpfile: | |
tmpfile.write(video_file.read()) | |
video_path = tmpfile.name | |
st.video(video_path) | |
st.sidebar.subheader("Add uploaded videos to collection") | |
if st.sidebar.button("Add uploaded video"): | |
extract_frames(video_path) | |
insert_into_db(collection, temp_dir) | |
else: | |
video_path = 'Videos/Video.mp4' | |
st.video(video_path) | |
st.write( | |
f"Video credits: https://www.kaggle.com/datasets/icebearisin/raw-skates") | |
st.write("Enter the description of image to be extracted from the video") | |
text_input = st.text_input("Description", "Flying Skater") | |
if st.button("Search"): | |
if text_input.strip(): | |
params = {'text': text_input.strip()} | |
response = requests.get( | |
'https://ashish-001-text-embedding-api.hf.space/embedding', params=params) | |
if response.status_code == 200: | |
logger.info("Embedding returned by API successfully") | |
data = json.loads(response.content) | |
embedding = data['embedding'] | |
results = collection.query( | |
query_embeddings=[embedding], | |
n_results=num_images | |
) | |
images = [results['metadatas'][0][i]['path'] | |
for i in range(len(results['metadatas'][0]))] | |
distances = [results['distances'][0][i] | |
for i in range(len(results['metadatas'][0]))] | |
if images: | |
cols_per_row = 3 | |
rows = (len(images)+cols_per_row-1)//cols_per_row | |
for row in range(rows): | |
cols = st.columns(cols_per_row) | |
for col_idx, col in enumerate(cols): | |
img_idx = row*cols_per_row+col_idx | |
if img_idx < len(images): | |
resized_img = resize_image( | |
images[img_idx]) | |
col.image(resized_img, | |
caption=f"Image {img_idx+1}", use_container_width=True) | |
else: | |
st.write("No image found") | |
else: | |
st.write("Please try again later") | |
logger.info(f"status code {response.status_code} returned") | |
else: | |
st.write("Please enter text in the text area") | |
except Exception as e: | |
logger.exception(f"Exception occured, {e}") | |