# Hugging Face Space app module (runs on CPU) — hybrid Qdrant embedding utilities.
import streamlit as st
from torch import cuda

from langchain_community.vectorstores import Qdrant
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import FastEmbedSparse, QdrantVectorStore, RetrievalMode
from qdrant_client import QdrantClient, models

from appStore.prep_utils import get_client
# Select the compute device: use the GPU when CUDA is available, else fall back to CPU.
device = 'cuda' if cuda.is_available() else 'cpu'
def hybrid_embed_chunks(docs, collection_name, del_if_exists=False):
    """
    Embed a list of document chunks into a Qdrant collection using hybrid
    (dense + sparse) vectors and upload them.

    Parameters
    ----------
    docs : list
        Document chunks to embed and add to the vector store.
    collection_name : str
        Name of the Qdrant collection to create and populate.
    del_if_exists : bool, optional
        When True, delete any existing collection with this name before
        creating it. When False and the collection already exists,
        ``create_collection`` will raise (it is called unconditionally).
    """
    # Dense embedding function (BGE-M3; vectors are cosine-normalised).
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3',
    )
    # Sparse (BM25) embedding function for the lexical side of hybrid search.
    sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25")

    # Reuse the existing Qdrant client.
    client = get_client()

    # Optionally drop a stale collection, then create it with named dense
    # and sparse vector configs, both stored on disk.
    if del_if_exists:
        client.delete_collection(collection_name=f"{collection_name}")
    client.create_collection(
        collection_name=collection_name,
        vectors_config={
            "text-dense": models.VectorParams(
                size=1024,  # BGE-M3 embedding dimension
                distance=models.Distance.COSINE,
                on_disk=True,
            )
        },
        sparse_vectors_config={
            "text-sparse": models.SparseVectorParams(
                index=models.SparseIndexParams(on_disk=True)
            )
        },
    )

    # Hybrid vector store wired to the two named vectors created above.
    vector_store = QdrantVectorStore(
        client=client,
        collection_name=collection_name,
        embedding=embeddings,
        vector_name="text-dense",
        sparse_embedding=sparse_embeddings,
        sparse_vector_name="text-sparse",
        retrieval_mode=RetrievalMode.HYBRID,
    )
    print("starting embedding")
    vector_store.add_documents(docs)
    print("vector embeddings done")
def get_local_qdrant(collection_name):
    """
    Connect to the existing local on-disk Qdrant store and return a mapping
    of ``{collection_name: Qdrant vector store}`` for the given collection.

    Parameters
    ----------
    collection_name : str
        Name of the existing collection to wrap in a vector store.

    Returns
    -------
    dict
        Single-entry dict mapping ``collection_name`` to a ``Qdrant``
        vector store backed by the local client.
    """
    qdrant_collections = {}
    # Dense embedding function must match the one used at indexing time.
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3',
    )
    # Local (embedded) Qdrant persisted under /data/local_qdrant.
    client = QdrantClient(path="/data/local_qdrant")
    # NOTE: the original also built a FastEmbedSparse model here but never
    # used it; that dead (and slow) initialisation has been removed.
    print("Collections in local Qdrant:", client.get_collections())
    qdrant_collections[collection_name] = Qdrant(
        client=client,
        collection_name=collection_name,
        embeddings=embeddings,
    )
    return qdrant_collections