from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore, FastEmbedSparse, RetrievalMode
from qdrant_client import QdrantClient
from torch import cuda
import streamlit as st

# get the device to be used: GPU if available, otherwise CPU
device = 'cuda' if cuda.is_available() else 'cpu'
def hybrid_embed_chunks(chunks):
    """
    Embed the given list of chunks with both dense (BAAI/bge-m3) and sparse (BM25)
    embeddings and persist them to the local Qdrant collection.
    """
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3'
    )
    sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25")
    print("starting embedding")
    # build the local collection with both dense and sparse vectors
    QdrantVectorStore.from_documents(
        chunks,
        embedding=embeddings,
        sparse_embedding=sparse_embeddings,
        path="/data/local_qdrant",
        collection_name='giz_worldwide',
        retrieval_mode=RetrievalMode.HYBRID,
    )
print("vector embeddings done")
@st.cache_resource
def get_local_qdrant(collection_name):
    """Once the local Qdrant store has been created, connect to the existing collection."""
    qdrant_collections = {}
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3'
    )
    client = QdrantClient(path="/data/local_qdrant")
    sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25")
    print("Collections in local Qdrant:", client.get_collections())
    # reconnect in hybrid mode so retrieval matches how the collection was embedded
    qdrant_collections[collection_name] = QdrantVectorStore(
        client=client,
        collection_name=collection_name,
        embedding=embeddings,
        sparse_embedding=sparse_embeddings,
        retrieval_mode=RetrievalMode.HYBRID,
    )
    return qdrant_collections
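
# Illustrative only: a minimal sketch of hybrid retrieval against the cached
# store, assuming the 'giz_worldwide' collection was built by hybrid_embed_chunks
# above; the query string is hypothetical:
#
#     vectorstore = get_local_qdrant('giz_worldwide')['giz_worldwide']
#     results = vectorstore.similarity_search("climate adaptation projects", k=5)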