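# Streamlit app: index video frames with CLIP image embeddings in ChromaDB,
# then retrieve the frames that best match a free-text description.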
import streamlit as st
import chromadb
from chromadb.config import Settings
from transformers import CLIPProcessor, CLIPModel
import cv2
from PIL import Image
import torch
import logging
import uuid
import tempfile
import os
import requests
import json
from dotenv import load_dotenv
import shutil
load_dotenv()
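# Hugging Face access token, read from a local .env file under the key 'hf_token'.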
HF_TOKEN = os.getenv('hf_token')
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
try:
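    # st.cache_resource caches heavyweight objects across Streamlit reruns,
    # so the CLIP model and the ChromaDB client are built only once per process.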
    @st.cache_resource
    def load_model():
        device = 'cpu'
        processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-large-patch14", token=HF_TOKEN)
        model = CLIPModel.from_pretrained(
            "openai/clip-vit-large-patch14", token=HF_TOKEN)
        model.eval().to(device)
        return processor, model

    @st.cache_resource
    def load_chromadb():
        chroma_client = chromadb.PersistentClient(
            path='Data', settings=Settings(anonymized_telemetry=False))
        collection = chroma_client.get_or_create_collection(name='images')
        return chroma_client, collection
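
    # ChromaDB persists its index under ./Data, so stored embeddings survive restarts.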
    def resize_image(image_path, size=(224, 224)):
        # Image.open accepts both file paths and file-like objects,
        # so a single branch covers both cases.
        img = Image.open(image_path).convert("RGB")
        return img.resize(size, Image.LANCZOS)
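
    # CLIP ViT-L/14 yields a 768-dimensional vector per image; normalising to unit
    # length makes the stored vectors directly comparable by cosine similarity.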
    def get_image_embedding(image_path, model, preprocess, device='cpu'):
        image = Image.open(image_path).convert('RGB')
        input_tensor = preprocess(images=[image], return_tensors='pt')[
            'pixel_values'].to(device)
        with torch.no_grad():
            embedding = model.get_image_features(pixel_values=input_tensor)
        # L2-normalise the embedding before it is stored or compared.
        return torch.nn.functional.normalize(embedding, p=2, dim=1)
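
    # With the default frame_interval of 30, roughly one frame per second is kept
    # for a typical 30 fps video.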
    def extract_frames(v_path, frame_interval=30):
        # Save every `frame_interval`-th frame of the video into the global temp_dir.
        cap = cv2.VideoCapture(v_path)
        frame_idx = 0
        saved_frames = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx % frame_interval == 0:
                unique_image_id = str(uuid.uuid4())
                frame_name = f"{temp_dir}/frame_{unique_image_id}_{saved_frames}.jpg"
                cv2.imwrite(frame_name, frame)
                saved_frames += 1
            frame_idx += 1
        cap.release()
        logger.info("Frames extracted")
    def insert_into_db(collection, frames_dir):
        # Embed every frame in frames_dir and add it to the collection in one batch.
        embedding_list = []
        metadatas = []
        ids = []
        with st.status("Generating embedding... ⏳", expanded=True) as status:
            for file_name in os.listdir(frames_dir):
                embedding = get_image_embedding(
                    f"{frames_dir}/{file_name}", model, processor)
                embedding_list.append(embedding.squeeze(0).numpy().tolist())
                metadatas.append({'path': f"{frames_dir}/{file_name}", 'type': 'photo'})
                ids.append(str(uuid.uuid4()))
            status.update(label="Embedding generation complete",
                          state="complete")
        collection.add(
            embeddings=embedding_list,
            ids=ids,
            metadatas=metadatas
        )
        logger.info("Data inserted into DB")
    processor, model = load_model()
    logger.info("Model and processor loaded")
    client, collection = load_chromadb()
    logger.info("ChromaDB loaded")
    logger.info(
        f"Connected to ChromaDB collection images with {collection.count()} items")
    temp_dir = 'temp_folder'
    if 'cleaned_temp' not in st.session_state:
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        os.makedirs(temp_dir, exist_ok=True)
        st.session_state.cleaned_temp = True
        # Drop stale DB entries whose frames were deleted with the old temp folder;
        # doing this once per session avoids wiping freshly inserted frames on rerun.
        results = collection.get(include=["metadatas"])
        ids_to_delete = [
            _id for _id, metadata in zip(results["ids"], results["metadatas"])
            if metadata.get("path", "").startswith("temp")
        ]
        if ids_to_delete:
            collection.delete(ids=ids_to_delete)
    st.title("Extract frames from video using text")

    # Upload section
    st.sidebar.subheader("Upload video")
    video_file = st.sidebar.file_uploader(
        "Upload videos", type=["mp4", "webm", "avi", "mov"], accept_multiple_files=False
    )
    num_images = st.sidebar.slider(
        "Number of images to be shown", min_value=1, max_value=10, value=3)
    if video_file:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmpfile:
            tmpfile.write(video_file.read())
            video_path = tmpfile.name
        st.video(video_path)
        st.sidebar.subheader("Add uploaded videos to collection")
        if st.sidebar.button("Add uploaded video"):
            extract_frames(video_path)
            insert_into_db(collection, temp_dir)
    else:
        video_path = 'Videos/Video.mp4'
        st.video(video_path)
        st.write("Video credits: https://www.kaggle.com/datasets/icebearisin/raw-skates")

    st.write("Enter a description of the image to be extracted from the video")
    text_input = st.text_input("Description", "Flying Skater")
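
    # Search: the query text is embedded by an external API endpoint, then matched
    # against the stored frame embeddings via a nearest-neighbour query.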
    if st.button("Search"):
        if text_input.strip():
            params = {'text': text_input.strip()}
            # A timeout keeps a slow or unreachable API from hanging the app.
            response = requests.get(
                'https://ashish-001-text-embedding-api.hf.space/embedding',
                params=params, timeout=30)
            if response.status_code == 200:
                logger.info("Embedding returned by API successfully")
                data = response.json()
                embedding = data['embedding']
                results = collection.query(
                    query_embeddings=[embedding],
                    n_results=num_images
                )
                images = [meta['path'] for meta in results['metadatas'][0]]
                # Distances come back in ranked order; useful for debugging relevance.
                distances = results['distances'][0]
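                # Ceiling division sizes the grid: e.g. 7 images at 3 per row -> 3 rows.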
                if images:
                    cols_per_row = 3
                    rows = (len(images) + cols_per_row - 1) // cols_per_row
                    for row in range(rows):
                        cols = st.columns(cols_per_row)
                        for col_idx, col in enumerate(cols):
                            img_idx = row * cols_per_row + col_idx
                            if img_idx < len(images):
                                resized_img = resize_image(images[img_idx])
                                col.image(resized_img,
                                          caption=f"Image {img_idx+1}",
                                          use_container_width=True)
                else:
                    st.write("No image found")
            else:
                st.write("Please try again later")
                logger.error(f"Status code {response.status_code} returned")
        else:
            st.write("Please enter text in the text area")
except Exception as e:
    logger.exception(f"Exception occurred: {e}")
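
# Launch with `streamlit run <this file>`. The embedding endpoint above must be
# reachable for Search to work; failures surface via the error branch and logs.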