Spaces:
Running
Running
"""Functionality for splitting text.""" | |
from __future__ import annotations | |
from typing import Any, Optional, cast | |
from core.model_manager import ModelInstance | |
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer | |
from core.splitter.text_splitter import ( | |
TS, | |
Collection, | |
Literal, | |
RecursiveCharacterTextSplitter, | |
Set, | |
TokenTextSplitter, | |
Union, | |
) | |
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive character splitter whose chunk length is measured in tokens.

    Adds the ``from_encoder`` alternate constructor, which counts tokens with
    the supplied embedding model or, when none is given, with the GPT-2
    tokenizer, so that tiktoken is not required.
    """

    # Alternate constructor: it receives the class itself as ``cls`` and
    # instantiates it, so it has to be bound as a classmethod.
    @classmethod
    def from_encoder(
        cls: type[TS],
        embedding_model_instance: Optional[ModelInstance],
        # NOTE: the ``set()`` default is kept for interface compatibility. It is
        # only forwarded by this method, never mutated here.
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Build a splitter whose length function counts tokens.

        Args:
            embedding_model_instance: Model used to count tokens. When falsy,
                ``GPT2Tokenizer.get_num_tokens`` is used instead.
            allowed_special: Forwarded to the constructor only when ``cls`` is
                a ``TokenTextSplitter`` subclass.
            disallowed_special: Forwarded to the constructor only when ``cls``
                is a ``TokenTextSplitter`` subclass.
            **kwargs: Extra keyword arguments passed through to ``cls``.

        Returns:
            A new ``cls`` instance created with ``length_function`` set to the
            token counter defined below.
        """

        def _token_encoder(text: str) -> int:
            """Return the token count of *text*; empty text counts as 0."""
            if not text:
                return 0

            if not embedding_model_instance:
                return GPT2Tokenizer.get_num_tokens(text)

            # The cast only narrows the static type; nothing changes at runtime.
            text_embedding_model = cast(
                TextEmbeddingModel, embedding_model_instance.model_type_instance
            )
            return text_embedding_model.get_num_tokens(
                model=embedding_model_instance.model,
                credentials=embedding_model_instance.credentials,
                texts=[text],
            )

        if issubclass(cls, TokenTextSplitter):
            # Token-based splitters additionally receive the model name and the
            # special-token policy; these override same-named entries in kwargs.
            extra_kwargs = {
                "model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2',
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_token_encoder, **kwargs)
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
    """Split on a fixed separator first, then recursively split oversized pieces.

    ``split_text`` cuts the input on ``fixed_separator``; any resulting piece
    whose measured length exceeds the chunk size is handed to
    ``recursive_split_text``, which tries each entry of ``separators`` in
    order.
    """

    def __init__(
        self,
        fixed_separator: str = "\n\n",
        separators: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            fixed_separator: Separator used for the first, non-recursive cut.
                A falsy value makes ``split_text`` cut into single characters.
            separators: Ordered separators tried by ``recursive_split_text``.
                Defaults to paragraph break, newline, space, then the empty
                string (meaning per-character splitting).
            **kwargs: Passed unchanged to the parent initializer.
        """
        super().__init__(**kwargs)
        self._fixed_separator = fixed_separator
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks.

        Pieces produced by the fixed separator are kept as they are unless
        their length (per ``self._length_function``) is greater than
        ``self._chunk_size``; those are split further recursively.
        """
        if self._fixed_separator:
            chunks = text.split(self._fixed_separator)
        else:
            chunks = list(text)

        final_chunks = []
        for chunk in chunks:
            if self._length_function(chunk) > self._chunk_size:
                final_chunks.extend(self.recursive_split_text(chunk))
            else:
                final_chunks.append(chunk)
        return final_chunks

    def recursive_split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks.

        Picks the first configured separator that occurs in ``text`` (the
        empty string always matches and means per-character splitting),
        splits on it, merges consecutive short pieces with
        ``self._merge_splits`` (defined outside this class), and recurses into
        pieces that are still too long.

        A piece that is too long but identical to ``text`` cannot be reduced
        by any further call, so it is emitted unchanged instead of recursing
        without end.
        """
        final_chunks = []

        # Get appropriate separator to use; fall back to the last one.
        separator = self._separators[-1]
        for candidate in self._separators:
            if candidate == "" or candidate in text:
                separator = candidate
                break

        # Now that we have the separator, split the text.
        splits = text.split(separator) if separator else list(text)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
                continue

            # Flush the short pieces collected so far to keep output in order.
            if _good_splits:
                final_chunks.extend(self._merge_splits(_good_splits, separator))
                _good_splits = []

            if s == text:
                # Splitting made no progress (separator absent, or a single
                # character already reaches the chunk size). Recursing would
                # repeat this exact call forever, so keep the piece as-is.
                final_chunks.append(s)
            else:
                final_chunks.extend(self.recursive_split_text(s))

        if _good_splits:
            final_chunks.extend(self._merge_splits(_good_splits, separator))
        return final_chunks