# tools/load_spacy_model_custom_recognisers.py
from typing import List
import spacy
from presidio_analyzer import (
AnalyzerEngine,
EntityRecognizer,
Pattern,
PatternRecognizer,
RecognizerResult,
)
from presidio_analyzer.nlp_engine import (
NerModelConfiguration,
NlpArtifacts,
SpacyNlpEngine,
)
from spacy.matcher import Matcher
from spaczz.matcher import FuzzyMatcher
spacy.prefer_gpu()
import os
import re
import gradio as gr
import Levenshtein
import requests
from spacy.cli.download import download
from tools.config import (
CUSTOM_ENTITIES,
DEFAULT_LANGUAGE,
SPACY_MODEL_PATH,
TESSERACT_DATA_FOLDER,
)
score_threshold = 0.001
custom_entities = CUSTOM_ENTITIES
# Create a class inheriting from SpacyNlpEngine
class LoadedSpacyNlpEngine(SpacyNlpEngine):
def __init__(self, loaded_spacy_model, language_code: str):
super().__init__(
ner_model_configuration=NerModelConfiguration(
labels_to_ignore=["CARDINAL", "ORDINAL"]
)
) # Ignore non-relevant labels
self.nlp = {language_code: loaded_spacy_model}
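# Illustrative sketch (not executed at import time): wrapping a pipeline that has
# already been loaded so Presidio reuses it rather than loading its own copy.
# The model name below is an example.
#
#   engine = LoadedSpacyNlpEngine(
#       loaded_spacy_model=spacy.load("en_core_web_lg"), language_code="en"
#   )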
def _base_language_code(language: str) -> str:
lang = _normalize_language_input(language)
if "_" in lang:
return lang.split("_")[0]
return lang
def load_spacy_model(language: str = DEFAULT_LANGUAGE):
"""
Load a spaCy model for the requested language and return it as `nlp`.
Accepts common inputs like: "en", "en_lg", "en_sm", "de", "fr", "es", "it", "nl", "pt", "zh", "ja", "xx".
Falls back through sensible candidates and will download if missing.
"""
    # Set spaCy data path for custom model storage (only if specified)
if SPACY_MODEL_PATH and SPACY_MODEL_PATH.strip():
os.environ["SPACY_DATA"] = SPACY_MODEL_PATH
print(f"Setting spaCy model path to: {SPACY_MODEL_PATH}")
else:
print("Using default spaCy model storage location")
synonyms = {
"english": "en",
"catalan": "ca",
"danish": "da",
"german": "de",
"french": "fr",
"greek": "el",
"finnish": "fi",
"croatian": "hr",
"lithuanian": "lt",
"macedonian": "mk",
"norwegian_bokmaal": "nb",
"polish": "pl",
"russian": "ru",
"slovenian": "sl",
"swedish": "sv",
"dutch": "nl",
"portuguese": "pt",
"chinese": "zh",
"japanese": "ja",
"multilingual": "xx",
}
lang_norm = _normalize_language_input(language)
lang_norm = synonyms.get(lang_norm, lang_norm)
base_lang = _base_language_code(lang_norm)
candidates_by_lang = {
# English - prioritize lg, then trf, then md, then sm
"en": [
"en_core_web_lg",
"en_core_web_trf",
"en_core_web_md",
"en_core_web_sm",
],
"en_lg": ["en_core_web_lg"],
"en_trf": ["en_core_web_trf"],
"en_md": ["en_core_web_md"],
"en_sm": ["en_core_web_sm"],
# Major languages (news pipelines) - prioritize lg, then md, then sm
"ca": ["ca_core_news_lg", "ca_core_news_md", "ca_core_news_sm"], # Catalan
"da": ["da_core_news_lg", "da_core_news_md", "da_core_news_sm"], # Danish
"de": ["de_core_news_lg", "de_core_news_md", "de_core_news_sm"], # German
"el": ["el_core_news_lg", "el_core_news_md", "el_core_news_sm"], # Greek
"es": ["es_core_news_lg", "es_core_news_md", "es_core_news_sm"], # Spanish
"fi": ["fi_core_news_lg", "fi_core_news_md", "fi_core_news_sm"], # Finnish
"fr": ["fr_core_news_lg", "fr_core_news_md", "fr_core_news_sm"], # French
"hr": ["hr_core_news_lg", "hr_core_news_md", "hr_core_news_sm"], # Croatian
"it": ["it_core_news_lg", "it_core_news_md", "it_core_news_sm"], # Italian
"ja": ["ja_core_news_lg", "ja_core_news_md", "ja_core_news_sm"], # Japanese
"ko": ["ko_core_news_lg", "ko_core_news_md", "ko_core_news_sm"], # Korean
"lt": ["lt_core_news_lg", "lt_core_news_md", "lt_core_news_sm"], # Lithuanian
"mk": ["mk_core_news_lg", "mk_core_news_md", "mk_core_news_sm"], # Macedonian
"nb": [
"nb_core_news_lg",
"nb_core_news_md",
"nb_core_news_sm",
], # Norwegian Bokmål
"nl": ["nl_core_news_lg", "nl_core_news_md", "nl_core_news_sm"], # Dutch
"pl": ["pl_core_news_lg", "pl_core_news_md", "pl_core_news_sm"], # Polish
"pt": ["pt_core_news_lg", "pt_core_news_md", "pt_core_news_sm"], # Portuguese
"ro": ["ro_core_news_lg", "ro_core_news_md", "ro_core_news_sm"], # Romanian
"ru": ["ru_core_news_lg", "ru_core_news_md", "ru_core_news_sm"], # Russian
"sl": ["sl_core_news_lg", "sl_core_news_md", "sl_core_news_sm"], # Slovenian
"sv": ["sv_core_news_lg", "sv_core_news_md", "sv_core_news_sm"], # Swedish
"uk": ["uk_core_news_lg", "uk_core_news_md", "uk_core_news_sm"], # Ukrainian
"zh": [
"zh_core_web_lg",
"zh_core_web_mod",
"zh_core_web_sm",
"zh_core_web_trf",
], # Chinese
# Multilingual NER
"xx": ["xx_ent_wiki_sm"],
}
if lang_norm in candidates_by_lang:
candidates = candidates_by_lang[lang_norm]
elif base_lang in candidates_by_lang:
candidates = candidates_by_lang[base_lang]
else:
# Fallback to multilingual if unknown
candidates = candidates_by_lang["xx"]
last_error = None
if language != "en":
print(
f"Attempting to load spaCy model for language '{language}' with candidates: {candidates}"
)
print(
"Note: Models are prioritized by size (lg > md > sm) - will stop after first successful load"
)
for i, candidate in enumerate(candidates):
if language != "en":
print(f"Trying candidate {i+1}/{len(candidates)}: {candidate}")
# Try importable package first (fast-path when installed as a package)
try:
module = __import__(candidate)
print(f"✓ Successfully imported spaCy model: {candidate}")
return module.load()
except Exception as e:
last_error = e
# Try spacy.load if package is linked/installed
try:
nlp = spacy.load(candidate)
print(f"✓ Successfully loaded spaCy model via spacy.load: {candidate}")
return nlp
except OSError:
# Model not found, proceed with download
print(f"Model {candidate} not found, attempting to download...")
try:
download(candidate)
print(f"✓ Successfully downloaded spaCy model: {candidate}")
# Refresh spaCy's model registry after download
import importlib
import sys
importlib.reload(spacy)
# Clear any cached imports that might interfere
if candidate in sys.modules:
del sys.modules[candidate]
# Small delay to ensure model is fully registered
import time
time.sleep(0.5)
# Try to load the downloaded model
nlp = spacy.load(candidate)
print(f"✓ Successfully loaded downloaded spaCy model: {candidate}")
return nlp
except Exception as download_error:
print(f"✗ Failed to download or load {candidate}: {download_error}")
# Try alternative loading methods
try:
# Try importing the module directly after download
module = __import__(candidate)
print(
f"✓ Successfully loaded {candidate} via direct import after download"
)
return module.load()
except Exception as import_error:
print(f"✗ Direct import also failed: {import_error}")
# Try one more approach - force spaCy to refresh its model registry
try:
                        from spacy.util import get_package_path
                        model_path = get_package_path(candidate)
if model_path and os.path.exists(model_path):
print(f"Found model at path: {model_path}")
nlp = spacy.load(model_path)
print(
f"✓ Successfully loaded {candidate} from path: {model_path}"
)
return nlp
except Exception as path_error:
print(f"✗ Path-based loading also failed: {path_error}")
last_error = download_error
continue
except Exception as e:
print(f"✗ Failed to load {candidate}: {e}")
last_error = e
continue
# Provide more helpful error message
error_msg = f"Failed to load spaCy model for language '{language}'"
if last_error:
error_msg += f". Last error: {last_error}"
error_msg += f". Tried candidates: {candidates}"
raise RuntimeError(error_msg)
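# Illustrative usage (commented out so nothing is downloaded at import time);
# the language inputs are examples only:
#
#   nlp_en = load_spacy_model("en")      # largest available English pipeline
#   nlp_de = load_spacy_model("german")  # full names are normalised via `synonyms`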
# Normalise language inputs such as "EN-GB" or "en-gb" to "en_gb"
def _normalize_language_input(language: str) -> str:
    return language.strip().lower().replace("-", "_")
# Active base language code derived from the configured default language
ACTIVE_LANGUAGE_CODE = _base_language_code(DEFAULT_LANGUAGE)
# Placeholder; the model itself is loaded by create_nlp_analyser() further below
nlp = None
def get_tesseract_lang_code(short_code: str):
"""
Maps a two-letter language code to the corresponding Tesseract OCR code.
Args:
short_code (str): The two-letter language code (e.g., "en", "de").
Returns:
str or None: The Tesseract language code (e.g., "eng", "deu"),
or None if no mapping is found.
"""
# Mapping from 2-letter codes to Tesseract 3-letter codes
# Based on ISO 639-2/T codes.
lang_map = {
"en": "eng",
"de": "deu",
"fr": "fra",
"es": "spa",
"it": "ita",
"nl": "nld",
"pt": "por",
"zh": "chi_sim", # Mapping to Simplified Chinese by default
"ja": "jpn",
"ko": "kor",
"lt": "lit",
"mk": "mkd",
"nb": "nor",
"pl": "pol",
"ro": "ron",
"ru": "rus",
"sl": "slv",
"sv": "swe",
"uk": "ukr",
}
return lang_map.get(short_code)
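# Illustrative mappings from the table above: get_tesseract_lang_code("en") -> "eng",
# get_tesseract_lang_code("zh") -> "chi_sim", and an unmapped code returns None.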
def download_tesseract_lang_pack(
short_lang_code: str, tessdata_dir=TESSERACT_DATA_FOLDER
):
"""
Downloads a Tesseract language pack to a local directory.
Args:
lang_code (str): The short code for the language (e.g., "eng", "fra").
tessdata_dir (str, optional): The directory to save the language pack.
Defaults to "tessdata".
"""
# Create the directory if it doesn't exist
if not os.path.exists(tessdata_dir):
os.makedirs(tessdata_dir)
# Get the Tesseract language code
lang_code = get_tesseract_lang_code(short_lang_code)
if lang_code is None:
raise ValueError(
f"Language code {short_lang_code} not found in Tesseract language map"
)
# Set the local file path
file_path = os.path.join(tessdata_dir, f"{lang_code}.traineddata")
# Check if the file already exists
if os.path.exists(file_path):
print(f"Language pack {lang_code}.traineddata already exists at {file_path}")
return file_path
# Construct the URL for the language pack
url = f"https://raw.githubusercontent.com/tesseract-ocr/tessdata/main/{lang_code}.traineddata"
# Download the file
try:
response = requests.get(url, stream=True, timeout=60)
response.raise_for_status() # Raise an exception for bad status codes
with open(file_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"Successfully downloaded {lang_code}.traineddata to {file_path}")
return file_path
except requests.exceptions.RequestException as e:
print(f"Error downloading {lang_code}.traineddata: {e}")
return None
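# Illustrative usage (commented out to avoid network access at import time);
# the language code is an example:
#
#   fra_path = download_tesseract_lang_pack("fr", tessdata_dir=TESSERACT_DATA_FOLDER)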
#### Custom recognisers
def _is_regex_pattern(term: str) -> bool:
"""
Detect if a term is intended to be a regex pattern or a literal string.
Args:
term: The term to check
Returns:
True if the term appears to be a regex pattern, False if it's a literal string
"""
term = term.strip()
if not term:
return False
# First, try to compile as regex to validate it
# This catches patterns like \d\d\d-\d\d\d that use regex escape sequences
try:
re.compile(term)
is_valid_regex = True
except re.error:
# If it doesn't compile as regex, treat as literal
return False
# If it compiles, check if it contains regex-like features
# Regex metacharacters that suggest a pattern (excluding escaped literals)
regex_metacharacters = [
"+",
"*",
"?",
"{",
"}",
"[",
"]",
"(",
")",
"|",
"^",
"$",
".",
]
# Common regex escape sequences that indicate regex intent
regex_escape_sequences = [
"\\d",
"\\w",
"\\s",
"\\D",
"\\W",
"\\S",
"\\b",
"\\B",
"\\n",
"\\t",
"\\r",
]
# Check if term contains regex metacharacters or escape sequences
has_metacharacters = False
has_escape_sequences = False
i = 0
while i < len(term):
if term[i] == "\\" and i + 1 < len(term):
# Check if it's a regex escape sequence
escape_seq = term[i : i + 2]
if escape_seq in regex_escape_sequences:
has_escape_sequences = True
# Skip the escape sequence (backslash + next char)
i += 2
continue
if term[i] in regex_metacharacters:
has_metacharacters = True
i += 1
# If it's a valid regex and contains regex features, treat as regex pattern
if is_valid_regex and (has_metacharacters or has_escape_sequences):
return True
# If it compiles but has no regex features, it might be a literal that happens to compile
# (e.g., "test" compiles as regex but is just literal text)
# In this case, if it has escape sequences, it's definitely regex
if has_escape_sequences:
return True
# Otherwise, treat as literal
return False
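# Illustrative behaviour of the heuristic above (hand-picked examples):
#
#   _is_regex_pattern(r"\d{3}-\d{4}")  -> True   (escape sequences and quantifiers)
#   _is_regex_pattern("John Smith")    -> False  (plain literal text)
#   _is_regex_pattern("A+E")           -> True   (contains the "+" metacharacter)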
def custom_word_list_recogniser(custom_list: List[str] = list()):
# Create regex pattern, handling quotes carefully
# Supports both literal strings and regex patterns
quote_str = '"'
replace_str = '(?:"|"|")'
regex_patterns = []
literal_patterns = []
# Separate regex patterns from literal strings
for term in custom_list:
term = term.strip()
if not term:
continue
if _is_regex_pattern(term):
# Use regex pattern as-is (but wrap with word boundaries if appropriate)
# Note: Word boundaries might not be appropriate for all regex patterns
# (e.g., email patterns), so we'll add them conditionally
regex_patterns.append(term)
else:
# Escape literal strings and add word boundaries
escaped_term = re.escape(term).replace(quote_str, replace_str)
literal_patterns.append(rf"(?<!\w){escaped_term}(?!\w)")
# Combine patterns: regex patterns first, then literal patterns
all_patterns = []
# Add regex patterns (without word boundaries, as they may have their own)
for pattern in regex_patterns:
all_patterns.append(f"({pattern})")
# Add literal patterns (with word boundaries)
all_patterns.extend(literal_patterns)
if not all_patterns:
# Return empty recognizer if no patterns
custom_pattern = Pattern(
name="custom_pattern", regex="(?!)", score=1
) # Never matches
else:
custom_regex = "|".join(all_patterns)
# print(custom_regex)
custom_pattern = Pattern(name="custom_pattern", regex=custom_regex, score=1)
custom_recogniser = PatternRecognizer(
supported_entity="CUSTOM",
name="CUSTOM",
patterns=[custom_pattern],
global_regex_flags=re.DOTALL | re.MULTILINE | re.IGNORECASE,
)
return custom_recogniser
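# Minimal sketch of exercising the recogniser on its own (sample terms and text
# are hypothetical); literal terms and regex patterns can be mixed freely:
#
#   example_recogniser = custom_word_list_recogniser(["Jane Doe", r"\d{3}-\d{4}"])
#   example_results = example_recogniser.analyze(
#       text="Call Jane Doe on 555-0199", entities=["CUSTOM"]
#   )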
# Initialise custom recogniser that will be overwritten later
custom_recogniser = custom_word_list_recogniser()
# Custom title recogniser
titles_list = [
"Sir",
"Ma'am",
"Madam",
"Mr",
"Mr.",
"Mrs",
"Mrs.",
"Ms",
"Ms.",
"Miss",
"Dr",
"Dr.",
"Professor",
]
titles_regex = (
"\\b" + "\\b|\\b".join(rf"{re.escape(title)}" for title in titles_list) + "\\b"
)
titles_pattern = Pattern(name="titles_pattern", regex=titles_regex, score=1)
titles_recogniser = PatternRecognizer(
supported_entity="TITLES",
name="TITLES",
patterns=[titles_pattern],
global_regex_flags=re.DOTALL | re.MULTILINE,
)
# %%
# Custom postcode recogniser
# Define the regex pattern in a Presidio `Pattern` object:
ukpostcode_pattern = Pattern(
name="ukpostcode_pattern",
regex=r"\b([A-Z]{1,2}\d[A-Z\d]? ?\d[A-Z]{2}|GIR ?0AA)\b",
score=1,
)
# Define the recognizer with one or more patterns
ukpostcode_recogniser = PatternRecognizer(
supported_entity="UKPOSTCODE", name="UKPOSTCODE", patterns=[ukpostcode_pattern]
)
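# Minimal sketch (hypothetical text): the pattern recognisers above can be tested
# directly before being registered with an AnalyzerEngine:
#
#   postcode_hits = ukpostcode_recogniser.analyze(
#       text="Send it to 10 Downing Street, SW1A 2AA.", entities=["UKPOSTCODE"]
#   )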
### Street name
def extract_street_name(text: str) -> tuple:
    """
    Extracts the start and end character positions of street names (and the preceding
    word, which should contain at least one number) from the given text.
    """
street_types = [
"Street",
"St",
"Boulevard",
"Blvd",
"Highway",
"Hwy",
"Broadway",
"Freeway",
"Causeway",
"Cswy",
"Expressway",
"Way",
"Walk",
"Lane",
"Ln",
"Road",
"Rd",
"Avenue",
"Ave",
"Circle",
"Cir",
"Cove",
"Cv",
"Drive",
"Dr",
"Parkway",
"Pkwy",
"Park",
"Court",
"Ct",
"Square",
"Sq",
"Loop",
"Place",
"Pl",
"Parade",
"Estate",
"Alley",
"Arcade",
"Avenue",
"Ave",
"Bay",
"Bend",
"Brae",
"Byway",
"Close",
"Corner",
"Cove",
"Crescent",
"Cres",
"Cul-de-sac",
"Dell",
"Drive",
"Dr",
"Esplanade",
"Glen",
"Green",
"Grove",
"Heights",
"Hts",
"Mews",
"Parade",
"Path",
"Piazza",
"Promenade",
"Quay",
"Ridge",
"Row",
"Terrace",
"Ter",
"Track",
"Trail",
"View",
"Villas",
"Marsh",
"Embankment",
"Cut",
"Hill",
"Passage",
"Rise",
"Vale",
"Side",
]
# Dynamically construct the regex pattern with all possible street types
street_types_pattern = "|".join(
rf"{re.escape(street_type)}" for street_type in street_types
)
# The overall regex pattern to capture the street name and preceding word(s)
pattern = r"(?P<preceding_word>\w*\d\w*)\s*"
pattern += rf"(?P<street_name>\w+\s*\b(?:{street_types_pattern})\b)"
# Find all matches in text
matches = re.finditer(pattern, text, re.DOTALL | re.MULTILINE | re.IGNORECASE)
start_positions = list()
end_positions = list()
    for match in matches:
        start_positions.append(match.start())
        end_positions.append(match.end())
return start_positions, end_positions
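# Illustrative call (hypothetical address text): returns parallel lists of start and
# end character offsets, here spanning "10 Downing Street":
#
#   starts, ends = extract_street_name("He lives at 10 Downing Street in London.")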
class StreetNameRecognizer(EntityRecognizer):
def load(self) -> None:
"""No loading is required."""
pass
def analyze(
self, text: str, entities: List[str], nlp_artifacts: NlpArtifacts
) -> List[RecognizerResult]:
"""
Logic for detecting a specific PII
"""
start_pos, end_pos = extract_street_name(text)
results = list()
for i in range(0, len(start_pos)):
result = RecognizerResult(
entity_type="STREETNAME", start=start_pos[i], end=end_pos[i], score=1
)
results.append(result)
return results
street_recogniser = StreetNameRecognizer(supported_entities=["STREETNAME"])
## Custom fuzzy match recogniser for list of strings
def custom_fuzzy_word_list_regex(text: str, custom_list: List[str] = list()):
    # Create regex pattern, handling straight and curly quotes carefully
    quote_str = '"'
    replace_str = '(?:"|“|”)'  # match straight or curly double quotes
    if not custom_list:
        # An empty list would otherwise produce an empty pattern that matches everywhere
        return list(), list()
    custom_regex_pattern = "|".join(
        rf"(?<!\w){re.escape(term.strip()).replace(quote_str, replace_str)}(?!\w)"
        for term in custom_list
    )
    # Find all matches in text
    matches = re.finditer(
        custom_regex_pattern, text, re.DOTALL | re.MULTILINE | re.IGNORECASE
    )
start_positions = list()
end_positions = list()
for match in matches:
start_pos = match.start()
end_pos = match.end()
start_positions.append(start_pos)
end_positions.append(end_pos)
return start_positions, end_positions
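# Illustrative call (hypothetical inputs): despite the name, this helper does exact,
# word-bounded matching only, so "Jane" is found but the misspelling "Jnae" is not:
#
#   starts, ends = custom_fuzzy_word_list_regex("Jane met Jnae.", ["Jane"])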
class CustomWordFuzzyRecognizer(EntityRecognizer):
def __init__(
self,
supported_entities: List[str],
custom_list: List[str] = list(),
spelling_mistakes_max: int = 1,
search_whole_phrase: bool = True,
):
super().__init__(supported_entities=supported_entities)
self.custom_list = custom_list # Store the custom_list as an instance attribute
self.spelling_mistakes_max = (
spelling_mistakes_max # Store the max spelling mistakes
)
self.search_whole_phrase = (
search_whole_phrase # Store the search whole phrase flag
)
def load(self) -> None:
"""No loading is required."""
pass
def analyze(
self, text: str, entities: List[str], nlp_artifacts: NlpArtifacts
) -> List[RecognizerResult]:
"""
Logic for detecting a specific PII
"""
start_pos, end_pos = spacy_fuzzy_search(
text, self.custom_list, self.spelling_mistakes_max, self.search_whole_phrase
) # Pass new parameters
results = list()
for i in range(0, len(start_pos)):
result = RecognizerResult(
entity_type="CUSTOM_FUZZY", start=start_pos[i], end=end_pos[i], score=1
)
results.append(result)
return results
custom_list_default = list()
custom_word_fuzzy_recognizer = CustomWordFuzzyRecognizer(
supported_entities=["CUSTOM_FUZZY"], custom_list=custom_list_default
)
# Module-level engine created with the placeholder model (still None at this point);
# the engine actually used for analysis is rebuilt inside create_nlp_analyser()
loaded_nlp_engine = LoadedSpacyNlpEngine(
    loaded_spacy_model=nlp, language_code=ACTIVE_LANGUAGE_CODE
)
def create_nlp_analyser(
language: str = DEFAULT_LANGUAGE,
custom_list: List[str] = None,
spelling_mistakes_max: int = 1,
search_whole_phrase: bool = True,
existing_nlp_analyser: AnalyzerEngine = None,
return_also_model: bool = False,
):
"""
Create an nlp_analyser object based on the specified language input.
Args:
language (str): Language code (e.g., "en", "de", "fr", "es", etc.)
custom_list (List[str], optional): List of custom words to recognize. Defaults to None.
spelling_mistakes_max (int, optional): Maximum number of spelling mistakes for fuzzy matching. Defaults to 1.
search_whole_phrase (bool, optional): Whether to search for whole phrases or individual words. Defaults to True.
existing_nlp_analyser (AnalyzerEngine, optional): Existing nlp_analyser object to use. Defaults to None.
return_also_model (bool, optional): Whether to return the nlp_model object as well. Defaults to False.
Returns:
AnalyzerEngine: Configured nlp_analyser object with custom recognizers
"""
    if existing_nlp_analyser is not None:
        if existing_nlp_analyser.supported_languages[0] == language:
            nlp_analyser = existing_nlp_analyser
            print(f"Using existing nlp_analyser for {language}")
            return nlp_analyser
# Load spaCy model for the specified language
nlp_model = load_spacy_model(language)
# Get base language code
base_lang_code = _base_language_code(language)
# Create custom recognizers
if custom_list is None:
custom_list = list()
custom_recogniser = custom_word_list_recogniser(custom_list)
custom_word_fuzzy_recognizer = CustomWordFuzzyRecognizer(
supported_entities=["CUSTOM_FUZZY"],
custom_list=custom_list,
spelling_mistakes_max=spelling_mistakes_max,
search_whole_phrase=search_whole_phrase,
)
# Create NLP engine with loaded model
loaded_nlp_engine = LoadedSpacyNlpEngine(
loaded_spacy_model=nlp_model, language_code=base_lang_code
)
# Create analyzer engine
nlp_analyser = AnalyzerEngine(
nlp_engine=loaded_nlp_engine,
default_score_threshold=score_threshold,
supported_languages=[base_lang_code],
log_decision_process=False,
)
# Add custom recognizers to nlp_analyser
nlp_analyser.registry.add_recognizer(custom_recogniser)
nlp_analyser.registry.add_recognizer(custom_word_fuzzy_recognizer)
# Add language-specific recognizers for English
if base_lang_code == "en":
nlp_analyser.registry.add_recognizer(street_recogniser)
nlp_analyser.registry.add_recognizer(ukpostcode_recogniser)
nlp_analyser.registry.add_recognizer(titles_recogniser)
if return_also_model:
return nlp_analyser, nlp_model
return nlp_analyser
# Create the default nlp_analyser using the new function
nlp_analyser, nlp = create_nlp_analyser(DEFAULT_LANGUAGE, return_also_model=True)
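# Minimal sketch of the intended entry point (language, custom terms and text are
# examples; requesting a non-English analyser may trigger a model download):
#
#   analyser = create_nlp_analyser("en", custom_list=["Project Nightingale"])
#   findings = analyser.analyze(
#       text="Contact Jane Doe about Project Nightingale at SW1A 2AA.",
#       language="en",
#   )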
def spacy_fuzzy_search(
text: str,
custom_query_list: List[str] = list(),
spelling_mistakes_max: int = 1,
search_whole_phrase: bool = True,
nlp=nlp,
progress=gr.Progress(track_tqdm=True),
):
"""Conduct fuzzy match on a list of text data."""
all_matches = list()
all_start_positions = list()
all_end_positions = list()
all_ratios = list()
# print("custom_query_list:", custom_query_list)
if not text:
out_message = "No text data found. Skipping page."
print(out_message)
return all_start_positions, all_end_positions
for string_query in custom_query_list:
query = nlp(string_query)
if search_whole_phrase is False:
# Keep only words that are not stop words
token_query = [
token.text
for token in query
if not token.is_space and not token.is_stop and not token.is_punct
]
spelling_mistakes_fuzzy_pattern = "FUZZY" + str(spelling_mistakes_max)
if len(token_query) > 1:
# pattern_lemma = [{"LEMMA": {"IN": query}}]
pattern_fuzz = [
{"TEXT": {spelling_mistakes_fuzzy_pattern: {"IN": token_query}}}
]
else:
# pattern_lemma = [{"LEMMA": query[0]}]
pattern_fuzz = [
{"TEXT": {spelling_mistakes_fuzzy_pattern: token_query[0]}}
]
matcher = Matcher(nlp.vocab)
matcher.add(string_query, [pattern_fuzz])
# matcher.add(string_query, [pattern_lemma])
else:
            # When matching a whole phrase, use the spaczz FuzzyMatcher, then filter
            # the matches by Levenshtein distance against the full query further below.
matcher = FuzzyMatcher(nlp.vocab)
patterns = [nlp.make_doc(string_query)] # Convert query into a Doc object
matcher.add("PHRASE", patterns, [{"ignore_case": True}])
batch_size = 256
docs = nlp.pipe([text], batch_size=batch_size)
# Get number of matches per doc
for doc in docs: # progress.tqdm(docs, desc = "Searching text", unit = "rows"):
matches = matcher(doc)
match_count = len(matches)
# If considering each sub term individually, append match. If considering together, consider weight of the relevance to that of the whole phrase.
if search_whole_phrase is False:
all_matches.append(match_count)
for match_id, start, end in matches:
span = str(doc[start:end]).strip()
query_search = str(query).strip()
# Convert word positions to character positions
start_char = doc[start].idx # Start character position
end_char = doc[end - 1].idx + len(
doc[end - 1]
) # End character position
                    # start_char/end_char are character offsets converted from token positions
all_matches.append(match_count)
all_start_positions.append(start_char)
all_end_positions.append(end_char)
else:
for match_id, start, end, ratio, pattern in matches:
span = str(doc[start:end]).strip()
query_search = str(query).strip()
# Calculate Levenshtein distance. Only keep matches with less than specified number of spelling mistakes
distance = Levenshtein.distance(query_search.lower(), span.lower())
# print("Levenshtein distance:", distance)
if distance > spelling_mistakes_max:
match_count = match_count - 1
else:
# Convert word positions to character positions
start_char = doc[start].idx # Start character position
end_char = doc[end - 1].idx + len(
doc[end - 1]
) # End character position
all_matches.append(match_count)
all_start_positions.append(start_char)
all_end_positions.append(end_char)
all_ratios.append(ratio)
return all_start_positions, all_end_positions
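if __name__ == "__main__":
    # Minimal usage sketch / smoke test; the sample text and fuzzy query are
    # hypothetical and simply exercise the default analyser created above.
    example_text = "Please contact Dr Jane Doe at 14 Sample Road, SW1A 2AA."
    example_results = nlp_analyser.analyze(
        text=example_text, language=ACTIVE_LANGUAGE_CODE
    )
    print("Presidio results:", example_results)
    # "Jane Does" is one edit away from "Jane Doe", so it should survive the
    # Levenshtein filter with spelling_mistakes_max=1.
    fuzzy_starts, fuzzy_ends = spacy_fuzzy_search(
        example_text, ["Jane Does"], spelling_mistakes_max=1, search_whole_phrase=True
    )
    print("Fuzzy match character spans:", list(zip(fuzzy_starts, fuzzy_ends)))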