ashish-001 committed (verified)
Commit 5c38ac9
1 Parent(s): caa2024

Upload 2 files

Files changed (2):
  1. app.py +175 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,175 @@
+ import streamlit as st
+ import chromadb
+ from chromadb.config import Settings
+ from transformers import CLIPProcessor, CLIPModel
+ import cv2
+ from PIL import Image
+ import torch
+ import logging
+ import uuid
+ import tempfile
+ import os
+ import requests
+ import json
+ from dotenv import load_dotenv
+ import shutil
+
+ load_dotenv()
+ HF_TOKEN = os.getenv('hf_token')
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+ try:
+     # Recreate the temporary frame folder once per session.
+     temp_dir = 'temp_folder'
+     if 'cleaned_temp' not in st.session_state:
+         if os.path.exists(temp_dir):
+             shutil.rmtree(temp_dir)
+         os.makedirs(temp_dir, exist_ok=True)
+         st.session_state.cleaned_temp = True
+
+     @st.cache_resource
+     def load_model():
+         """Load the CLIP model and processor once per process."""
+         device = 'cpu'
+         processor = CLIPProcessor.from_pretrained(
+             "openai/clip-vit-large-patch14", token=HF_TOKEN)
+         model = CLIPModel.from_pretrained(
+             "openai/clip-vit-large-patch14", token=HF_TOKEN)
+         model.eval().to(device)
+         return processor, model
+
+     @st.cache_resource
+     def load_chromadb():
+         """Open (or create) the persistent ChromaDB collection under ./Data."""
+         # chromadb.Client() does not accept a `path` argument;
+         # PersistentClient is the on-disk client.
+         chroma_client = chromadb.PersistentClient(
+             path='Data', settings=Settings(anonymized_telemetry=False))
+         collection = chroma_client.get_or_create_collection(name='images')
+         return chroma_client, collection
+
+     def resize_image(image_path, size=(224, 224)):
+         # PIL accepts both file paths and file-like objects here.
+         img = Image.open(image_path).convert("RGB")
+         return img.resize(size, Image.LANCZOS)
+
+     def get_image_embedding(image_path, model, preprocess, device='cpu'):
+         """Return an L2-normalized CLIP embedding for a single image."""
+         image = Image.open(image_path).convert('RGB')
+         input_tensor = preprocess(images=[image], return_tensors='pt')[
+             'pixel_values'].to(device)
+         with torch.no_grad():
+             embedding = model.get_image_features(pixel_values=input_tensor)
+         return torch.nn.functional.normalize(embedding, p=2, dim=1)
+
+     def extract_frames(v_path, frame_interval=30):
+         """Save every `frame_interval`-th frame of the video to temp_dir."""
+         cap = cv2.VideoCapture(v_path)
+         frame_idx = 0
+         saved_frames = 0
+         while cap.isOpened():
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             if frame_idx % frame_interval == 0:
+                 unique_image_id = str(uuid.uuid4())
+                 frame_name = f"{temp_dir}/frame_{unique_image_id}_{saved_frames}.jpg"
+                 cv2.imwrite(frame_name, frame)
+                 saved_frames += 1
+             frame_idx += 1
+         cap.release()
+
+     def insert_into_db(collection, directory):
+         """Embed every frame in `directory` and add the batch to ChromaDB."""
+         embedding_list = []
+         metadatas = []
+         ids = []
+         with st.status("Generating embedding... ⏳", expanded=True) as status:
+             for name in os.listdir(directory):
+                 embedding = get_image_embedding(
+                     f"{directory}/{name}", model, processor)
+                 embedding_list.append(embedding.squeeze(0).numpy().tolist())
+                 metadatas.append({'path': f"{directory}/{name}", 'type': 'photo'})
+                 ids.append(str(uuid.uuid4()))
+             status.update(label="Embedding generation complete",
+                           state="complete")
+         collection.add(
+             embeddings=embedding_list,
+             ids=ids,
+             metadatas=metadatas
+         )
+         logger.info("Data inserted into DB")
+
+     processor, model = load_model()
+     logger.info("Model and processor loaded")
+     client, collection = load_chromadb()
+     logger.info("ChromaDB loaded")
+     logger.info(
+         f"Connected to ChromaDB collection images with {collection.count()} items")
+
+     st.title("Extract frames from video using text")
+     # Upload section
+     st.sidebar.subheader("Upload video")
+     video_file = st.sidebar.file_uploader(
+         "Upload videos", type=["mp4", "webm", "avi", "mov"],
+         accept_multiple_files=False
+     )
+     num_images = st.sidebar.slider(
+         "Number of images to be shown", min_value=1, max_value=10, value=3)
+     if video_file:
+         with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmpfile:
+             tmpfile.write(video_file.read())
+             video_path = tmpfile.name
+         st.video(video_path)
+         st.sidebar.subheader("Add uploaded videos to collection")
+         if st.sidebar.button("Add uploaded video"):
+             extract_frames(video_path)
+             insert_into_db(collection, temp_dir)
+     else:
+         video_path = 'Videos/Video.mp4'
+         st.video(video_path)
+         st.write(
+             "Video credits: https://www.kaggle.com/datasets/icebearisin/raw-skates")
+
+     st.write("Enter a description of the image to extract from the video")
+     text_input = st.text_input("Description", "Flying Skater")
+     if st.button("Search"):
+         if text_input.strip():
+             # Text embeddings come from an external HF Space so the text
+             # encoder does not have to run inside this app.
+             params = {'text': text_input.strip()}
+             response = requests.get(
+                 'https://ashish-001-text-embedding-api.hf.space/embedding',
+                 params=params)
+             if response.status_code == 200:
+                 logger.info("Embedding returned by API successfully")
+                 data = json.loads(response.content)
+                 embedding = data['embedding']
+                 results = collection.query(
+                     query_embeddings=[embedding],
+                     n_results=num_images
+                 )
+                 images = [meta['path'] for meta in results['metadatas'][0]]
+                 if images:
+                     # Lay the matches out in a three-column grid.
+                     cols_per_row = 3
+                     rows = (len(images) + cols_per_row - 1) // cols_per_row
+                     for row in range(rows):
+                         cols = st.columns(cols_per_row)
+                         for col_idx, col in enumerate(cols):
+                             img_idx = row * cols_per_row + col_idx
+                             if img_idx < len(images):
+                                 resized_img = resize_image(images[img_idx])
+                                 col.image(resized_img,
+                                           caption=f"Image {img_idx + 1}",
+                                           use_container_width=True)
+                 else:
+                     st.write("No image found")
+             else:
+                 st.write("Please try again later")
+                 logger.info(f"status code {response.status_code} returned")
+         else:
+             st.write("Please enter text in the text area")
+
+ except Exception as e:
+     logger.exception(f"Exception occurred: {e}")
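
Note: app.py assumes the external Space at https://ashish-001-text-embedding-api.hf.space/embedding answers a GET request carrying a `text` query parameter with JSON of the form {"embedding": [...]}. That service's code is not part of this commit; the sketch below is only a guess at an equivalent local implementation, using the same CLIP checkpoint so text queries land in the same embedding space as the stored frames:

    # Hypothetical stand-in for the external /embedding Space (not this
    # commit's code): produce the L2-normalized CLIP text embedding that
    # app.py expects to find under the 'embedding' key.
    import torch
    from transformers import CLIPModel, CLIPProcessor

    processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").eval()

    def get_text_embedding(text: str) -> list:
        inputs = processor(text=[text], return_tensors="pt", padding=True)
        with torch.no_grad():
            features = model.get_text_features(**inputs)
        # Normalize to match the normalized image embeddings stored by
        # app.py, so distances from collection.query() are comparable.
        features = torch.nn.functional.normalize(features, p=2, dim=1)
        return features.squeeze(0).tolist()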
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers==4.50.3
+ streamlit==1.44.1
+ chromadb==0.6.3
+ requests==2.32.3
+ torch==2.6.0
+ python-dotenv==1.1.0
+ opencv-python==4.11.0.86
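
Assuming the layout the code expects (a `.env` file defining hf_token, and a fallback clip at Videos/Video.mp4 for the no-upload path), the app should run locally with `pip install -r requirements.txt` followed by `streamlit run app.py`.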