Spaces:
Sleeping
Sleeping
File size: 1,186 Bytes
45c901d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from qdrant_client.http import models
import pickle as pickle
import torch
import io
# Target device for loaded tensors: the first CUDA GPU when one is visible to
# torch, otherwise the CPU.
if torch.cuda.is_available():
    device_str = "cuda:0"
else:
    device_str = "cpu"
device = torch.device(device_str)
class Device_Unpickler(pickle.Unpickler):
    """Unpickler that deserializes pickled torch storages onto ``device_str``.

    Pickle streams that contain torch storages reference the global
    ``torch.storage._load_from_bytes``. This subclass swaps that global for a
    loader that calls ``torch.load`` with ``map_location=device_str``. Every
    other global is resolved by the standard ``pickle.Unpickler`` lookup.

    SECURITY: unpickling can execute arbitrary code embedded in the stream;
    only feed this class trusted files.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, redirecting the torch storage loader."""
        wants_storage_loader = (module, name) == ("torch.storage", "_load_from_bytes")
        if not wants_storage_loader:
            return super().find_class(module, name)

        def _load_storage_on_device(raw_bytes):
            # Reads the module-level ``device_str`` at call time.
            return torch.load(io.BytesIO(raw_bytes), map_location=device_str)

        return _load_storage_on_device
def pickle_to_document_store(path):
    """Unpickle a document store from ``path`` and point its embedder at ``device_str``.

    The file is read in binary mode through ``Device_Unpickler``, so torch
    storages inside it are materialized on ``device_str``. The loaded object's
    ``embeddings.encode_kwargs["device"]`` is then overwritten with
    ``device_str``. Presumably the embedding model reads that key when it
    encodes; confirm against the embeddings class in use.

    SECURITY: this unpickles the file, which can execute arbitrary code; only
    load files from a trusted source.

    Returns:
        The unpickled document store object.
    """
    with open(path, "rb") as handle:
        store = Device_Unpickler(handle).load()
    embed_options = store.embeddings.encode_kwargs
    embed_options["device"] = device_str
    return store
def get_qdrant_filters(filter_dict: dict):
    """Build a Qdrant filter with one ``MatchAny`` condition per metadata field.

    ``filter_dict`` maps metadata field names to the values to accept, and must
    be formatted like::

        filter_dict = {'file_name': ['file1', 'file2'], 'sub_type': ['text']}

    Each entry becomes a ``FieldCondition`` on the payload key
    ``metadata.<field>``. All conditions are placed in the filter's ``must``
    clause, which Qdrant requires every condition to satisfy.
    """
    conditions = []
    for field, accepted_values in filter_dict.items():
        conditions.append(
            models.FieldCondition(
                key=f"metadata.{field}",
                match=models.MatchAny(any=accepted_values),
            )
        )
    return models.Filter(must=conditions)
|