thechaiexperiment committed on
Commit 69e8e11 · verified · 1 Parent(s): d64968f

Update app.py

Files changed (1):
  1. app.py +26 -32
app.py CHANGED
@@ -350,47 +350,41 @@ from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
 import nltk
 
-# Load a pre-trained embedding model
-embedding_model = SentenceTransformer('all-MiniLM-L6-v2')  # Use a lightweight model for speed
-
-from sentence_transformers import SentenceTransformer
-from sklearn.metrics.pairwise import cosine_similarity
-
-# Load the embedding model globally for efficiency
-embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
-
-def extract_relevant_portions(document_texts, query, max_portions=3, chunk_size=500):
+def extract_relevant_portions(query_embedding, top_documents, embeddings_data, max_portions=3):
     try:
-        # Embed the query once
-        query_embedding = embedding_model.encode([query])
-
         relevant_portions = {}
-        for doc_id, doc_text in enumerate(document_texts):
-            # Split document into chunks (e.g., 500 characters per chunk)
-            chunks = [doc_text[i:i + chunk_size] for i in range(0, len(doc_text), chunk_size)]
-
-            # Embed all chunks in a single batch
-            chunk_embeddings = embedding_model.encode(chunks)
+
+        for _, doc_id, doc_text in top_documents:
+            if doc_id not in embeddings_data:
+                print(f"Warning: No embedding available for Document ID {doc_id}. Skipping...")
+                continue
 
-            # Compute cosine similarity between query and all chunks
-            similarities = cosine_similarity(query_embedding, chunk_embeddings)[0]
+            # Retrieve the precomputed embedding for this document
+            doc_embedding = np.array(embeddings_data[doc_id])
 
-            # Rank chunks by similarity
-            ranked_chunks = sorted(
-                enumerate(chunks),
-                key=lambda x: similarities[x[0]],
-                reverse=True
-            )
+            # Compute similarity between the query embedding and the document embedding
+            similarity = cosine_similarity(query_embedding, [doc_embedding]).flatten()[0]
+
+            # Split the document into sentences
+            sentences = nltk.sent_tokenize(doc_text)
+
+            # Rank sentences based on their length (proxy for importance) or other heuristic
+            # Since we're using document-level embeddings, we assume all sentences are equally relevant.
+            sorted_sentences = sorted(sentences, key=lambda x: len(x), reverse=True)[:max_portions]
+
+            relevant_portions[doc_id] = sorted_sentences
 
-            # Select top chunks based on similarity
-            doc_relevant_portions = [chunk for _, chunk in ranked_chunks[:max_portions]]
-            relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions
+            print(f"Extracted relevant portions for Document ID {doc_id} (Similarity: {similarity:.4f}):")
+            for i, sentence in enumerate(sorted_sentences, start=1):
+                print(f" Portion {i}: {sentence[:100]}...")  # Print first 100 characters for preview
 
         return relevant_portions
+
     except Exception as e:
-        print(f"Error in extracting relevant portions: {e}")
+        print(f"Error in extract_relevant_portions: {e}")
         return {}
 
+
 def remove_duplicates(selected_parts):
     unique_sentences = set()
     unique_selected_parts = []
@@ -532,7 +526,7 @@ async def chat_endpoint(chat_query: ChatQuery):
     scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
     scored_documents = list(zip(scores, document_ids, document_texts))
    scored_documents.sort(key=lambda x: x[0], reverse=True)
-    relevant_portions = extract_relevant_portions(document_texts, query_text, max_portions=3, chunk_size=500)
+    relevant_portions = extract_relevant_portions(query_embedding, scored_documents, embeddings_data, max_portions=3)
     #flattened_relevant_portions = []
     #for doc_id, portions in relevant_portions.items():
     #flattened_relevant_portions.extend(portions)
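
A minimal sketch of how the refactored extract_relevant_portions might be exercised on its own. The module import, the sample documents, and the layout of embeddings_data (document ID mapped to a precomputed document-level vector) are illustrative assumptions, not part of this commit; in app.py the function is called from the chat endpoint with the cross-encoder-ranked scored_documents, as shown in the second hunk above.

# Hypothetical usage sketch -- the module name "app", the sample documents, and the
# embeddings_data layout below are assumptions for illustration only.
import nltk
from sentence_transformers import SentenceTransformer

from app import extract_relevant_portions  # assumes app.py imports cleanly as a module

nltk.download('punkt')  # sent_tokenize inside extract_relevant_portions needs this tokenizer

embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# scored_documents: (score, doc_id, doc_text) tuples, the shape produced by the
# cross-encoder ranking step in the endpoint.
scored_documents = [
    (0.92, "doc_1", "Aspirin is used to reduce fever. It can also thin the blood."),
    (0.75, "doc_2", "Ibuprofen relieves pain. Take it with food to avoid stomach upset."),
]

# embeddings_data: doc_id -> precomputed document-level embedding, stored as a list
# since the function wraps it back into np.array.
embeddings_data = {
    doc_id: embedding_model.encode(doc_text).tolist()
    for _, doc_id, doc_text in scored_documents
}

# query_embedding must be 2-D (one row) to satisfy cosine_similarity.
query_embedding = embedding_model.encode(["What is aspirin used for?"])

portions = extract_relevant_portions(query_embedding, scored_documents, embeddings_data, max_portions=2)
print(portions)

In the endpoint itself, the query embedding and embeddings_data are expected to already be in scope, since the new call site passes them to the function directly.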