awacke1 committed · Commit bdefc08 · verified · 1 Parent(s): 28280be

Update app.py

Files changed (1): app.py (+110, -245)

app.py CHANGED
@@ -5,284 +5,149 @@ from sentence_transformers import SentenceTransformer
  from sklearn.metrics.pairwise import cosine_similarity
  import os
  from datetime import datetime
- import requests
  from datasets import load_dataset
- from urllib.parse import quote
 
  # Initialize session state
  if 'search_history' not in st.session_state:
      st.session_state['search_history'] = []
  if 'search_columns' not in st.session_state:
      st.session_state['search_columns'] = []
- if 'initial_search_done' not in st.session_state:
-     st.session_state['initial_search_done'] = False
- if 'dataset' not in st.session_state:
-     st.session_state['dataset'] = None
 
- class DatasetSearcher:
      def __init__(self, dataset_id="tomg-group-umd/cinepile"):
          self.dataset_id = dataset_id
-         self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
          self.token = os.environ.get('DATASET_KEY')
          if not self.token:
              st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
              st.stop()
-         self.load_dataset()
 
-     def load_dataset(self):
-         """Load dataset using the datasets library"""
          try:
-             if st.session_state['dataset'] is None:
-                 with st.spinner("Loading dataset..."):
-                     st.session_state['dataset'] = load_dataset(
-                         self.dataset_id,
-                         token=self.token,
-                         streaming=False
-                     )
-
-             self.dataset = st.session_state['dataset']
-             # Convert first split to DataFrame for easier processing
-             first_split = next(iter(self.dataset.values()))
-             self.df = pd.DataFrame(first_split)
-
-             # Store column information
-             self.columns = list(self.df.columns)
-             # Identify searchable columns
-             self.text_columns = []
-             for col in self.columns:
-                 if col.lower() not in ['embed', 'vector', 'encoding']:
-                     sample_val = self.df[col].iloc[0] if not self.df.empty else None
-                     if isinstance(sample_val, (str, int, float, list, dict)) or sample_val is None:
-                         self.text_columns.append(col)
-
-             # Update session state columns
-             st.session_state['search_columns'] = self.text_columns
-
-             # Prepare text embeddings
-             self.prepare_features()
-
          except Exception as e:
              st.error(f"Error loading dataset: {str(e)}")
-             st.error("Please check your authentication token and internet connection.")
-             st.stop()
 
-     def prepare_features(self):
-         """Prepare text embeddings for semantic search"""
          try:
-             # Process text columns and handle different data types
-             processed_texts = []
-             for _, row in self.df.iterrows():
-                 row_texts = []
-                 for col in self.text_columns:
-                     value = row[col]
-                     if isinstance(value, (list, dict)):
-                         # Convert lists or dicts to string representation
-                         row_texts.append(str(value))
-                     elif isinstance(value, (int, float)):
-                         # Convert numbers to strings
-                         row_texts.append(str(value))
-                     elif value is None:
-                         row_texts.append('')
-                     else:
-                         # Handle string values
-                         row_texts.append(str(value))
-                 processed_texts.append(' '.join(row_texts))
-
-             # Create embeddings in batches to manage memory
-             batch_size = 32
-             all_embeddings = []
 
-             with st.spinner("Preparing search features..."):
-                 for i in range(0, len(processed_texts), batch_size):
-                     batch = processed_texts[i:i+batch_size]
-                     embeddings = self.text_model.encode(batch)
-                     all_embeddings.append(embeddings)
 
-             self.text_embeddings = np.vstack(all_embeddings)
 
-         except Exception as e:
-             st.warning(f"Error preparing features: {str(e)}")
-             self.text_embeddings = np.random.randn(len(self.df), 384)
-
-     def search(self, query, column=None, top_k=20):
-         """Search the dataset using semantic and keyword matching"""
-         if self.df.empty:
-             return []
-
-         # Get semantic similarity scores
-         query_embedding = self.text_model.encode([query])[0]
-         similarities = cosine_similarity([query_embedding], self.text_embeddings)[0]
-
-         # Get keyword match scores
-         search_columns = [column] if column and column != "All Fields" else self.text_columns
-         keyword_scores = np.zeros(len(self.df))
-
-         query_lower = query.lower()
-         for col in search_columns:
-             if col in self.df.columns:
-                 for idx, value in enumerate(self.df[col]):
-                     if isinstance(value, (list, dict)):
-                         # Search in string representation of lists or dicts
-                         text = str(value).lower()
-                     elif isinstance(value, (int, float)):
-                         # Convert numbers to strings for searching
-                         text = str(value).lower()
-                     elif value is None:
-                         text = ''
-                     else:
-                         # Handle string values
-                         text = str(value).lower()
-
-                     keyword_scores[idx] += text.count(query_lower)
-
-         # Combine scores
-         combined_scores = 0.5 * similarities + 0.5 * (keyword_scores / max(1, keyword_scores.max()))
 
          # Get top results
-         top_k = min(top_k, len(combined_scores))
-         top_indices = np.argsort(combined_scores)[-top_k:][::-1]
-
-         # Format results
-         results = []
-         for idx in top_indices:
-             result = {
-                 'relevance_score': float(combined_scores[idx]),
-                 'semantic_score': float(similarities[idx]),
-                 'keyword_score': float(keyword_scores[idx]),
-                 **self.df.iloc[idx].to_dict()
-             }
-             results.append(result)
-
-         return results
-
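The removed ranking step above normalized raw keyword counts by the page-wide maximum before averaging them with cosine similarity, so both terms sit in [0, 1]; the replacement quick_search further down instead divides each row's count by its own word count. A toy check of the removed formula, with invented scores purely for illustration:

import numpy as np

similarities = np.array([0.82, 0.40, 0.55])   # cosine scores per row
keyword_scores = np.array([4.0, 0.0, 2.0])    # raw substring-match counts
combined = 0.5 * similarities + 0.5 * (keyword_scores / max(1, keyword_scores.max()))
print(combined)   # -> [0.91, 0.2, 0.525]; max-normalizing keeps both terms on one scale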
-     def get_dataset_info(self):
-         """Get information about the dataset"""
-         if not self.dataset:
-             return {}
-
-         info = {
-             'splits': list(self.dataset.keys()),
-             'total_rows': sum(split.num_rows for split in self.dataset.values()),
-             'columns': self.columns,
-             'text_columns': self.text_columns,
-             'sample_rows': len(self.df),
-             'embeddings_shape': self.text_embeddings.shape
-         }
-
-         return info
-
- def render_video_result(result):
-     """Render a video result with enhanced display"""
-     col1, col2 = st.columns([2, 1])
-
-     with col1:
-         if 'title' in result:
-             st.markdown(f"**Title:** {result['title']}")
-         if 'description' in result:
-             st.markdown("**Description:**")
-             st.write(result['description'])
-
-         # Show timing information if available
-         if 'start_time' in result and 'end_time' in result:
-             st.markdown(f"**Time Range:** {result['start_time']}s - {result['end_time']}s")
-
-         # Show additional metadata
-         for key, value in result.items():
-             if key not in ['title', 'description', 'start_time', 'end_time', 'duration',
-                            'relevance_score', 'semantic_score', 'keyword_score',
-                            'video_id', 'youtube_id']:
-                 st.markdown(f"**{key.replace('_', ' ').title()}:** {value}")
-
-     with col2:
-         # Show search scores
-         st.markdown("**Search Scores:**")
-         cols = st.columns(3)
-         cols[0].metric("Overall", f"{result['relevance_score']:.2%}")
-         cols[1].metric("Semantic", f"{result['semantic_score']:.2%}")
-         cols[2].metric("Keyword", f"{result['keyword_score']:.0f} matches")
-
-         # Display video if available
-         if 'youtube_id' in result:
-             st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")
 
  def main():
-     st.title("🎥 Video Dataset Search")
 
      # Initialize search class
-     searcher = DatasetSearcher()
 
-     # Create tabs
-     tab1, tab2 = st.tabs(["🔍 Search", "📊 Dataset Info"])
 
-     # ---- Tab 1: Search ----
-     with tab1:
-         st.subheader("Search Videos")
-         col1, col2 = st.columns([3, 1])
-
-         with col1:
-             query = st.text_input("Search query:",
-                                   value="" if st.session_state['initial_search_done'] else "")
-         with col2:
-             search_column = st.selectbox("Search in field:",
-                                          ["All Fields"] + st.session_state['search_columns'])
-
-         col3, col4 = st.columns(2)
-         with col3:
-             num_results = st.slider("Number of results:", 1, 100, 20)
-         with col4:
-             search_button = st.button("🔍 Search")
-
-         if search_button and query:
-             st.session_state['initial_search_done'] = True
-             selected_column = None if search_column == "All Fields" else search_column
-
-             with st.spinner("Searching..."):
-                 results = searcher.search(query, selected_column, num_results)
-
-             st.session_state['search_history'].append({
-                 'query': query,
-                 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
-                 'results': results[:5]
-             })
-
-             for i, result in enumerate(results, 1):
-                 with st.expander(
-                     f"Result {i}: {result.get('title', result.get('description', 'No title'))[:100]}...",
-                     expanded=(i==1)
-                 ):
-                     render_video_result(result)
 
-     # ---- Tab 2: Dataset Info ----
-     with tab2:
-         st.subheader("Dataset Information")
-
-         info = searcher.get_dataset_info()
-         if info:
-             st.write(f"### Dataset: {searcher.dataset_id}")
-             st.write(f"- Total rows: {info['total_rows']:,}")
-             st.write(f"- Available splits: {', '.join(info['splits'])}")
-             st.write(f"- Number of columns: {len(info['columns'])}")
-             st.write(f"- Searchable text columns: {', '.join(info['text_columns'])}")
-
-             st.write("### Sample Data")
-             st.dataframe(searcher.df.head())
 
-             st.write("### Column Details")
-             for col in info['columns']:
-                 st.write(f"- **{col}**: {searcher.df[col].dtype}")
-
-     # Sidebar
-     with st.sidebar:
-         st.subheader("⚙️ Search History")
-         if st.button("🗑️ Clear History"):
-             st.session_state['search_history'] = []
-             st.experimental_rerun()
-
-         st.markdown("### Recent Searches")
-         for entry in reversed(st.session_state['search_history'][-5:]):
-             with st.expander(f"{entry['timestamp']}: {entry['query']}"):
-                 for i, result in enumerate(entry['results'], 1):
-                     st.write(f"{i}. {result.get('title', result.get('description', 'No title'))[:100]}...")
 
  if __name__ == "__main__":
      main()
  from sklearn.metrics.pairwise import cosine_similarity
  import os
  from datetime import datetime
  from datasets import load_dataset
 
  # Initialize session state
  if 'search_history' not in st.session_state:
      st.session_state['search_history'] = []
  if 'search_columns' not in st.session_state:
      st.session_state['search_columns'] = []
+ if 'dataset_loaded' not in st.session_state:
+     st.session_state['dataset_loaded'] = False
+ if 'current_page' not in st.session_state:
+     st.session_state['current_page'] = 0
+ if 'data_cache' not in st.session_state:
+     st.session_state['data_cache'] = None
 
+ ROWS_PER_PAGE = 100  # Number of rows to load at a time
+
+ @st.cache_resource
+ def get_model():
+     return SentenceTransformer('all-MiniLM-L6-v2')
+
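The @st.cache_resource decorator just added matters because Streamlit re-executes the whole script on every widget interaction; without it the SentenceTransformer would be rebuilt on each rerun, with it the model is constructed once per server process and shared across sessions. A minimal standalone sketch of the same pattern (the printed shape reflects all-MiniLM-L6-v2's 384-dimensional embeddings):

import streamlit as st
from sentence_transformers import SentenceTransformer

@st.cache_resource   # built once per process, reused across reruns and sessions
def get_model():
    return SentenceTransformer('all-MiniLM-L6-v2')

model = get_model()                         # cheap after the first call
st.write(model.encode(['hello']).shape)     # (1, 384)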
+ class FastDatasetSearcher:
      def __init__(self, dataset_id="tomg-group-umd/cinepile"):
          self.dataset_id = dataset_id
+         self.text_model = get_model()
          self.token = os.environ.get('DATASET_KEY')
          if not self.token:
              st.error("Please set the DATASET_KEY environment variable with your Hugging Face token.")
              st.stop()
+         self.load_dataset_info()
 
+     @st.cache_data
+     def load_dataset_info(self):
+         """Load dataset metadata only"""
          try:
+             dataset = load_dataset(
+                 self.dataset_id,
+                 token=self.token,
+                 streaming=True
+             )
+             self.dataset_info = dataset['train'].info
+             return True
          except Exception as e:
              st.error(f"Error loading dataset: {str(e)}")
+             return False
+
+     def load_page(self, page=0):
+         """Load a specific page of data"""
+         if st.session_state['data_cache'] is not None and st.session_state['current_page'] == page:
+             return st.session_state['data_cache']
 
          try:
+             dataset = load_dataset(
+                 self.dataset_id,
+                 token=self.token,
+                 streaming=False,
+                 split=f'train[{page*ROWS_PER_PAGE}:{(page+1)*ROWS_PER_PAGE}]'
+             )
+             df = pd.DataFrame(dataset)
+             st.session_state['data_cache'] = df
+             st.session_state['current_page'] = page
+             return df
+         except Exception as e:
+             st.error(f"Error loading page {page}: {str(e)}")
+             return pd.DataFrame()
+
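load_page leans on the datasets library's split-slicing syntax: passing split='train[N:M]' returns just that row range as a Dataset. A small sketch of the same call in isolation (page values are illustrative; the token is still needed for this gated dataset, and with streaming=False the full split is downloaded and cached on first use, after which slices are served from the local cache):

import os
import pandas as pd
from datasets import load_dataset

PAGE, ROWS_PER_PAGE = 2, 100   # illustrative values
ds = load_dataset(
    'tomg-group-umd/cinepile',
    token=os.environ['DATASET_KEY'],
    split=f'train[{PAGE * ROWS_PER_PAGE}:{(PAGE + 1) * ROWS_PER_PAGE}]',
)
print(pd.DataFrame(ds).shape)   # 100 rows, fewer on the last page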
+     def quick_search(self, query, df):
+         """Fast search on current page"""
+         scores = []
+         query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]
+
+         for _, row in df.iterrows():
+             # Combine all searchable text fields
+             text = ' '.join(str(v) for v in row.values() if isinstance(v, (str, int, float)))
 
+             # Quick keyword match
+             keyword_score = text.lower().count(query.lower()) / len(text.split())
 
+             # Semantic search on combined text
+             text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
+             semantic_score = cosine_similarity([query_embedding], [text_embedding])[0][0]
 
+             # Combine scores
+             combined_score = 0.5 * semantic_score + 0.5 * keyword_score
+             scores.append(combined_score)
 
          # Get top results
+         df['score'] = scores
+         return df.sort_values('score', ascending=False)
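Two caveats in quick_search as committed: pandas exposes row.values as an attribute, so calling it as row.values() raises a TypeError, and encoding inside the loop costs one model call per row. A batched variant of the same scoring logic, offered as a sketch rather than the committed code:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def quick_search_batched(model, query, df):
    # itertuples yields plain tuples of row values, sidestepping the .values pitfall
    texts = [' '.join(str(v) for v in row if isinstance(v, (str, int, float)))
             for row in df.itertuples(index=False)]
    query_emb = model.encode([query], show_progress_bar=False)
    text_embs = model.encode(texts, show_progress_bar=False)   # one batched call
    semantic = cosine_similarity(query_emb, text_embs)[0]
    q = query.lower()
    keyword = np.array([t.lower().count(q) / max(1, len(t.split())) for t in texts])
    out = df.copy()
    out['score'] = 0.5 * semantic + 0.5 * keyword
    return out.sort_values('score', ascending=False)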
 
  def main():
+     st.title("🎥 Fast Video Dataset Search")
 
      # Initialize search class
+     searcher = FastDatasetSearcher()
 
+     # Page navigation
+     page = st.number_input("Page", min_value=0, value=st.session_state['current_page'])
 
+     # Load current page
+     with st.spinner(f"Loading page {page}..."):
+         df = searcher.load_page(page)
 
+     if df.empty:
+         st.warning("No data available for this page.")
+         return
+
+     # Search interface
+     query = st.text_input("Search in current page:", help="Searches within currently loaded data")
+
+     if query:
+         with st.spinner("Searching..."):
+             results = searcher.quick_search(query, df)
 
+         # Display results
+         st.write(f"Found {len(results)} results on this page:")
+         for i, (_, result) in enumerate(results.iterrows(), 1):
+             score = result.pop('score')
+             with st.expander(f"Result {i} (Score: {score:.2%})", expanded=i==1):
+                 # Display video if available
+                 if 'youtube_id' in result:
+                     st.video(
+                         f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}"
+                     )
+
+                 # Display other fields
+                 for key, value in result.items():
+                     if isinstance(value, (str, int, float)):
+                         st.write(f"**{key}:** {value}")
+
+     # Show raw data
+     st.subheader("Raw Data")
+     st.dataframe(df)
+
+     # Navigation buttons
+     cols = st.columns(2)
+     with cols[0]:
+         if st.button("Previous Page") and page > 0:
+             st.session_state['current_page'] -= 1
+             st.rerun()
+     with cols[1]:
+         if st.button("Next Page"):
+             st.session_state['current_page'] += 1
+             st.rerun()
 
  if __name__ == "__main__":
      main()
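A loose end in the new pagination: "Next Page" increments the counter unconditionally, so paging past the end of the split just yields an empty DataFrame and the warning. If the row count is available from the metadata fetched in load_dataset_info (datasets exposes it as DatasetInfo.splits['train'].num_examples), the button can be clamped; a sketch under that assumption:

# assumes searcher.dataset_info was populated and the train split size is known
num_rows = searcher.dataset_info.splits['train'].num_examples
max_page = max(0, (num_rows - 1) // ROWS_PER_PAGE)
if st.button("Next Page") and page < max_page:
    st.session_state['current_page'] += 1
    st.rerun()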