Spaces:
Runtime error
Runtime error
File size: 2,220 Bytes
e539b70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
from sentence_transformers import SentenceTransformer
from ._utils import FewDocumentsError
from ._utils import document_extraction, paragraph_extraction, semantic_search
from corpora import gen_corpus
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import string
from ._utils import device
def extract(query: str, n: int=3, extracted_documents: list=None) -> str:
    """Extract the n most relevant paragraphs from the corpus for a query.

    Parameters:
        query (str): Sentence used to search the corpus for relevant documents
        n (int): Number of paragraphs to return
        extracted_documents (list, optional): Pre-extracted documents to search.
            When omitted (or empty), documents are extracted from the corpus
            generated for this query.

    Returns:
        str: String containing the n most relevant paragraphs joined by line breaks
    """
    # Build the corpus for this query
    corpus = gen_corpus(query)

    # Reduce the query to keywords: lowercase, drop stopwords and punctuation
    stop_words = set(stopwords.words('english'))
    query_tokens = word_tokenize(query.lower())
    tokens_without_sw = [word for word in query_tokens if word not in stop_words]
    keywords = [keyword for keyword in tokens_without_sw if keyword not in string.punctuation]

    # Gross search: keyword-based document extraction (skipped when the caller
    # supplies pre-extracted documents)
    if not extracted_documents:
        extracted_documents, _documents_empty, _documents_sizes = document_extraction(
            dataset=corpus,
            query=query,
            keywords=keywords,
            min_document_size=0,
            min_just_one_paragraph_size=0
        )

    # Model shared by both semantic searches
    search_model = SentenceTransformer('msmarco-distilbert-base-v4', device=device)

    # First semantic search (over whole documents)
    selected_documents, _documents_distances = semantic_search(
        model=search_model,
        query=query,
        files=extracted_documents,
        number_of_similar_files=10
    )

    # Second semantic search (over paragraphs of the selected documents)
    paragraphs = paragraph_extraction(
        documents=selected_documents,
        min_paragraph_size=20,
    )
    selected_paragraphs, _paragraphs_distances = semantic_search(
        model=search_model,
        query=query,
        files=paragraphs,
        number_of_similar_files=10
    )

    # Join the top-n paragraphs into a single text block
    return '\n'.join(selected_paragraphs[:n])
|