ashish-001 committed
Commit ade9ea5 · verified · 1 parent: 5c38ac9

Update app.py

Files changed (1): app.py +182 -178
app.py CHANGED
@@ -1,178 +1,182 @@
- import streamlit as st
- import chromadb
- from chromadb.config import Settings
- from transformers import CLIPProcessor, CLIPModel
- import cv2
- from PIL import Image
- import torch
- import logging
- import uuid
- import tempfile
- import os
- import requests
- import json
- from dotenv import load_dotenv
- import shutil
-
- load_dotenv()
- HF_TOKEN = os.getenv('hf_token')
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
- try:
-     temp_dir = 'temp_folder'
-     if 'cleaned_temp' not in st.session_state:
-         if os.path.exists(temp_dir):
-             shutil.rmtree(temp_dir)
-         os.makedirs(temp_dir, exist_ok=True)
-         st.session_state.cleaned_temp = True
-
-     @st.cache_resource
-     def load_model():
-         device = 'cpu'
-         processor = CLIPProcessor.from_pretrained(
-             "openai/clip-vit-large-patch14", token=HF_TOKEN)
-         model = CLIPModel.from_pretrained(
-             "openai/clip-vit-large-patch14", token=HF_TOKEN)
-         model.eval().to(device)
-         return processor, model
-
-     @st.cache_resource
-     def load_chromadb():
-         chroma_client = chromadb.Client(
-             path='Data', settings=Settings(anonymized_telemetry=False))
-         collection = chroma_client.get_or_create_collection(name='images')
-         return chroma_client, collection
-
-     def resize_image(image_path, size=(224, 224)):
-         if isinstance(image_path, str):
-             img = Image.open(image_path).convert("RGB")
-         else:
-             img = Image.open(image_path).convert("RGB")
-         img_resized = img.resize(size, Image.LANCZOS)
-         return img_resized
-
-     def get_image_embedding(image, model, preprocess, device='cpu'):
-         image = Image.open(f'{image}').convert('RGB')
-         input_tensor = preprocess(images=[image], return_tensors='pt')[
-             'pixel_values'].to(device)
-         with torch.no_grad():
-             embedding = model.get_image_features(
-                 pixel_values=input_tensor)
-
-         return torch.nn.functional.normalize(embedding, p=2, dim=1)
-
-     def extract_frames(v_path, frame_interval=30):
-         cap = cv2.VideoCapture(v_path)
-         frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-         frame_rate = int(cap.get(cv2.CAP_PROP_FPS))
-         total_seconds = frame_count//frame_rate
-         frame_idx = 0
-         saved_frames = 0
-         while cap.isOpened():
-             ret, frame = cap.read()
-             if not ret:
-                 break
-             if frame_idx % frame_interval == 0:
-                 unique_image_id = str(uuid.uuid4())
-                 frame_name = f"{temp_dir}/frame_{unique_image_id}_{saved_frames}.jpg"
-                 cv2.imwrite(frame_name, frame)
-
-                 saved_frames += 1
-             frame_idx += 1
-         cap.release()
-
-     def insert_into_db(collection, dir):
-         embedding_list = []
-         file_names = []
-         ids = []
-         with st.status("Generating embedding... ⏳", expanded=True) as status:
-             for i in os.listdir(dir):
-                 embedding = get_image_embedding(
-                     f"{dir}/{i}", model, processor)
-                 embedding_list.append(
-                     embedding.squeeze(0).numpy().tolist())
-                 file_names.append(
-                     {'path': f"{dir}/{i}", 'type': 'photo'})
-                 unique_id = str(uuid.uuid4())
-                 ids.append(unique_id)
-             status.update(label="Embedding generation complete",
-                           state="complete")
-
-         collection.add(
-             embeddings=embedding_list,
-             ids=ids,
-             metadatas=file_names
-         )
-         logger.info("Data inserted into DB")
-
-     processor, model = load_model()
-     logger.info("Model and processor loaded")
-     client, collection = load_chromadb()
-     logger.info("ChromaDB loaded")
-     logger.info(
-         f"Connected to ChromaDB collection images with {collection.count()} items")
-
-     st.title("Extract frames from video using text")
-     # Upload section
-     st.sidebar.subheader("Upload video")
-     video_file = st.sidebar.file_uploader(
-         "Upload videos", type=["mp4", "webm", "avi", "mov"], accept_multiple_files=False
-     )
-     num_images = st.sidebar.slider(
-         "Number of images to be shown", min_value=1, max_value=10, value=3)
-     if video_file:
-         with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmpfile:
-             tmpfile.write(video_file.read())
-             video_path = tmpfile.name
-         st.video(video_path)
-         st.sidebar.subheader("Add uploaded videos to collection")
-         if st.sidebar.button("Add uploaded video"):
-             extract_frames(video_path)
-             insert_into_db(collection, temp_dir)
-     else:
-         video_path = 'Videos/Video.mp4'
-         st.video(video_path)
-         st.write(
-             f"Video credits: https://www.kaggle.com/datasets/icebearisin/raw-skates")
-
-     st.write("Enter the description of image to be extracted from the video")
-     text_input = st.text_input("Description", "Flying Skater")
-     if st.button("Search"):
-         if text_input.strip():
-             params = {'text': text_input.strip()}
-             response = requests.get(
-                 'https://ashish-001-text-embedding-api.hf.space/embedding', params=params)
-             if response.status_code == 200:
-                 logger.info("Embedding returned by API successfully")
-                 data = json.loads(response.content)
-                 embedding = data['embedding']
-                 results = collection.query(
-                     query_embeddings=[embedding],
-                     n_results=num_images
-                 )
-                 images = [results['metadatas'][0][i]['path']
-                           for i in range(len(results['metadatas'][0]))]
-                 distances = [results['distances'][0][i]
-                              for i in range(len(results['metadatas'][0]))]
-                 if images:
-                     cols_per_row = 3
-                     rows = (len(images)+cols_per_row-1)//cols_per_row
-                     for row in range(rows):
-                         cols = st.columns(cols_per_row)
-                         for col_idx, col in enumerate(cols):
-                             img_idx = row*cols_per_row+col_idx
-                             if img_idx < len(images):
-                                 resized_img = resize_image(
-                                     images[img_idx])
-                                 col.image(resized_img,
-                                           caption=f"Image {img_idx+1}", use_container_width=True)
-                             else:
-                                 st.write("No image found")
-             else:
-                 st.write("Please try again later")
-                 logger.info(f"status code {response.status_code} returned")
-         else:
-             st.write("Please enter text in the text area")
-
- except Exception as e:
-     logger.exception(f"Exception occured, {e}")
 
 
 
 
 
+ import streamlit as st
+ import chromadb
+ from chromadb.config import Settings
+ from transformers import CLIPProcessor, CLIPModel
+ import cv2
+ from PIL import Image
+ import torch
+ import logging
+ import uuid
+ import tempfile
+ import os
+ import requests
+ import json
+ from dotenv import load_dotenv
+ import shutil
+
+ load_dotenv()
+ HF_TOKEN = os.getenv('hf_token')
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+ try:
+
+
+
+     @st.cache_resource
+     def load_model():
+         device = 'cpu'
+         processor = CLIPProcessor.from_pretrained(
+             "openai/clip-vit-large-patch14", token=HF_TOKEN)
+         model = CLIPModel.from_pretrained(
+             "openai/clip-vit-large-patch14", token=HF_TOKEN)
+         model.eval().to(device)
+         return processor, model
+
+     @st.cache_resource
+     def load_chromadb():
+         chroma_client = chromadb.Client(
+             path='Data', settings=Settings(anonymized_telemetry=False))
+         collection = chroma_client.get_or_create_collection(name='images')
+         return chroma_client, collection
+
+     def resize_image(image_path, size=(224, 224)):
+         if isinstance(image_path, str):
+             img = Image.open(image_path).convert("RGB")
+         else:
+             img = Image.open(image_path).convert("RGB")
+         img_resized = img.resize(size, Image.LANCZOS)
+         return img_resized
+
+     def get_image_embedding(image, model, preprocess, device='cpu'):
+         image = Image.open(f'{image}').convert('RGB')
+         input_tensor = preprocess(images=[image], return_tensors='pt')[
+             'pixel_values'].to(device)
+         with torch.no_grad():
+             embedding = model.get_image_features(
+                 pixel_values=input_tensor)
+
+         return torch.nn.functional.normalize(embedding, p=2, dim=1)
+
+     def extract_frames(v_path, frame_interval=30):
+         cap = cv2.VideoCapture(v_path)
+         frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+         frame_rate = int(cap.get(cv2.CAP_PROP_FPS))
+         total_seconds = frame_count//frame_rate
+         frame_idx = 0
+         saved_frames = 0
+         while cap.isOpened():
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             if frame_idx % frame_interval == 0:
+                 unique_image_id = str(uuid.uuid4())
+                 frame_name = f"{temp_dir}/frame_{unique_image_id}_{saved_frames}.jpg"
+                 cv2.imwrite(frame_name, frame)
+
+                 saved_frames += 1
+             frame_idx += 1
+         cap.release()
+
+     def insert_into_db(collection, dir):
+         embedding_list = []
+         file_names = []
+         ids = []
+         with st.status("Generating embedding... ⏳", expanded=True) as status:
+             for i in os.listdir(dir):
+                 embedding = get_image_embedding(
+                     f"{dir}/{i}", model, processor)
+                 embedding_list.append(
+                     embedding.squeeze(0).numpy().tolist())
+                 file_names.append(
+                     {'path': f"{dir}/{i}", 'type': 'photo'})
+                 unique_id = str(uuid.uuid4())
+                 ids.append(unique_id)
+             status.update(label="Embedding generation complete",
+                           state="complete")
+
+         collection.add(
+             embeddings=embedding_list,
+             ids=ids,
+             metadatas=file_names
+         )
+         logger.info("Data inserted into DB")
+
+     processor, model = load_model()
+     logger.info("Model and processor loaded")
+     client, collection = load_chromadb()
+     logger.info("ChromaDB loaded")
+     logger.info(
+         f"Connected to ChromaDB collection images with {collection.count()} items")
+
+     temp_dir = 'temp_folder'
+     if 'cleaned_temp' not in st.session_state:
+         if os.path.exists(temp_dir):
+             shutil.rmtree(temp_dir)
+         os.makedirs(temp_dir, exist_ok=True)
+         st.session_state.cleaned_temp = True
+         collection.delete(where={"path": {"$contains": "temp_folder"}})
+
+     st.title("Extract frames from video using text")
+     # Upload section
+     st.sidebar.subheader("Upload video")
+     video_file = st.sidebar.file_uploader(
+         "Upload videos", type=["mp4", "webm", "avi", "mov"], accept_multiple_files=False
+     )
+     num_images = st.sidebar.slider(
+         "Number of images to be shown", min_value=1, max_value=10, value=3)
+     if video_file:
+         with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmpfile:
+             tmpfile.write(video_file.read())
+             video_path = tmpfile.name
+         st.video(video_path)
+         st.sidebar.subheader("Add uploaded videos to collection")
+         if st.sidebar.button("Add uploaded video"):
+             extract_frames(video_path)
+             insert_into_db(collection, temp_dir)
+     else:
+         video_path = 'Videos/Video.mp4'
+         st.video(video_path)
+         st.write(
+             f"Video credits: https://www.kaggle.com/datasets/icebearisin/raw-skates")
+
+     st.write("Enter the description of image to be extracted from the video")
+     text_input = st.text_input("Description", "Flying Skater")
+     if st.button("Search"):
+         if text_input.strip():
+             params = {'text': text_input.strip()}
+             response = requests.get(
+                 'https://ashish-001-text-embedding-api.hf.space/embedding', params=params)
+             if response.status_code == 200:
+                 logger.info("Embedding returned by API successfully")
+                 data = json.loads(response.content)
+                 embedding = data['embedding']
+                 results = collection.query(
+                     query_embeddings=[embedding],
+                     n_results=num_images
+                 )
+                 images = [results['metadatas'][0][i]['path']
+                           for i in range(len(results['metadatas'][0]))]
+                 distances = [results['distances'][0][i]
+                              for i in range(len(results['metadatas'][0]))]
+                 if images:
+                     cols_per_row = 3
+                     rows = (len(images)+cols_per_row-1)//cols_per_row
+                     for row in range(rows):
+                         cols = st.columns(cols_per_row)
+                         for col_idx, col in enumerate(cols):
+                             img_idx = row*cols_per_row+col_idx
+                             if img_idx < len(images):
+                                 resized_img = resize_image(
+                                     images[img_idx])
+                                 col.image(resized_img,
+                                           caption=f"Image {img_idx+1}", use_container_width=True)
+                             else:
+                                 st.write("No image found")
+             else:
+                 st.write("Please try again later")
+                 logger.info(f"status code {response.status_code} returned")
+         else:
+             st.write("Please enter text in the text area")
+
+ except Exception as e:
+     logger.exception(f"Exception occured, {e}")
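
The functional change in this commit is small: the temp_folder cleanup moved from the top of the try block to after load_chromadb(), where the images collection handle already exists, and it gained a collection.delete(...) call that purges embeddings whose stored metadata path points at the wiped temp frames. Below is a minimal standalone sketch of that cleanup step, assuming (as the committed code does) a chromadb build that accepts $contains in a metadata where filter; the in-memory client here is illustrative only, not the app's persistent setup.

import os
import shutil

import chromadb
from chromadb.config import Settings

TEMP_DIR = "temp_folder"

# Connect and fetch the same 'images' collection app.py uses.
client = chromadb.Client(settings=Settings(anonymized_telemetry=False))
collection = client.get_or_create_collection(name="images")

# Wipe the previous session's extracted frames and recreate the folder.
if os.path.exists(TEMP_DIR):
    shutil.rmtree(TEMP_DIR)
os.makedirs(TEMP_DIR, exist_ok=True)

# Drop the now-dangling embeddings whose metadata path points into
# temp_folder. This is the call the commit adds; it can only run after
# the collection handle exists, which is why the block was relocated.
collection.delete(where={"path": {"$contains": TEMP_DIR}})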