import asyncio
import json
import logging
import os
import pickle

import chromadb
import logfire
from custom_retriever import CustomRetriever
from dotenv import load_dotenv
from llama_index.core import Document, SimpleKeywordTableIndex, VectorStoreIndex
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import (
    KeywordTableSimpleRetriever,
    VectorIndexRetriever,
)
from llama_index.core.schema import NodeWithScore, QueryBundle
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore
from utils import init_mongo_db
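
# load_dotenv() pulls secrets such as COHERE_API_KEY and MONGODB_URI from a
# local .env file if one exists; logfire.configure() likewise picks up its
# token (LOGFIRE_TOKEN) from the environment.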
load_dotenv()
logfire.configure()

if not os.path.exists("data/chroma-db-all_sources"):
    # Download the vector database from the Hugging Face Hub if it doesn't exist locally.
    # https://huggingface.co/datasets/towardsai-buster/ai-tutor-vector-db/tree/main
    logfire.warn(
        "Vector database does not exist at 'data/chroma-db-all_sources', downloading from Hugging Face Hub"
    )
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="towardsai-tutors/ai-tutor-vector-db",
        local_dir="data",
        repo_type="dataset",
    )
    logfire.info("Downloaded vector database to 'data/chroma-db-all_sources'")


def create_docs(input_file: str) -> list[Document]:
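    """Build LlamaIndex Documents from a JSONL export.

    Each line is expected to be one JSON record of the form (illustrative
    values, not taken from the real dataset):
    {"doc_id": "...", "content": "...", "url": "https://...", "name": "...",
     "tokens": 812, "retrieve_doc": true, "source": "transformers"}
    """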
    with open(input_file, "r") as f:
        documents = []
        for line in f:
            data = json.loads(line)
            documents.append(
                Document(
                    doc_id=data["doc_id"],
                    text=data["content"],
                    metadata={  # type: ignore
                        "url": data["url"],
                        "title": data["name"],
                        "tokens": data["tokens"],
                        "retrieve_doc": data["retrieve_doc"],
                        "source": data["source"],
                    },
                    excluded_llm_metadata_keys=[
                        "title",
                        "tokens",
                        "retrieve_doc",
                        "source",
                    ],
                    excluded_embed_metadata_keys=[
                        "url",
                        "tokens",
                        "retrieve_doc",
                        "source",
                    ],
                )
            )
    return documents


def setup_database(db_collection, dict_file_name) -> CustomRetriever:
    db = chromadb.PersistentClient(path=f"data/{db_collection}")
    chroma_collection = db.get_or_create_collection(db_collection)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
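
    # Cohere's embed-english-v3.0 is an asymmetric model: queries are embedded
    # with input_type="search_query", while the stored vectors are assumed to
    # have been created with input_type="search_document" at indexing time.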
    embed_model = CohereEmbedding(
        api_key=os.environ["COHERE_API_KEY"],
        model_name="embed-english-v3.0",
        input_type="search_query",
    )
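
    # The Chroma collection is already populated, so from_vector_store() only
    # wraps the existing embeddings; the SentenceSplitter transformation would
    # apply to documents inserted through this index later.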
    index = VectorStoreIndex.from_vector_store(
        vector_store=vector_store,
        transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)],
        show_progress=True,
        # use_async=True,
    )
    vector_retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=15,
        embed_model=embed_model,
        # use_async=True,
    )
    with open(f"data/{db_collection}/{dict_file_name}", "rb") as f:
        document_dict = pickle.load(f)
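
    # document_dict presumably maps doc_id -> full Document, letting
    # CustomRetriever swap a matched chunk for its whole source document when
    # that document's retrieve_doc flag is set.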
    # with open("data/keyword_retriever_sync.pkl", "rb") as f:
    #     keyword_retriever: KeywordTableSimpleRetriever = pickle.load(f)
    #     keyword_retriever.num_chunks_per_query = 15

    # # Creating the keyword index and retriever
    # logfire.info("Creating nodes from documents")
    # documents = create_docs("data/all_sources_data.jsonl")
    # pipeline = IngestionPipeline(
    #     transformations=[SentenceSplitter(chunk_size=800, chunk_overlap=0)]
    # )
    # all_nodes = pipeline.run(documents=documents, show_progress=True)
    # # with open("data/all_nodes.pkl", "wb") as f:
    # #     pickle.dump(all_nodes, f)
    # all_nodes = pickle.load(open("data/nodes_with_added_context.pkl", "rb"))
    # logfire.info(f"Number of nodes: {len(all_nodes)}")
    # keyword_index = SimpleKeywordTableIndex(
    #     nodes=all_nodes, max_keywords_per_chunk=10, show_progress=True, use_async=False
    # )
    # # with open("data/keyword_index.pkl", "wb") as f:
    # #     pickle.dump(keyword_index, f)
    # # keyword_index = pickle.load(open("data/keyword_index.pkl", "rb"))
    # logfire.info("Creating keyword retriever")
    # keyword_retriever = KeywordTableSimpleRetriever(index=keyword_index)
    # with open("data/keyword_retriever_sync.pkl", "wb") as f:
    #     pickle.dump(keyword_retriever, f)

    # 'OR' means returning both the vector nodes and the keyword nodes
    # return CustomRetriever(vector_retriever, document_dict, keyword_retriever, "OR")
    return CustomRetriever(vector_retriever, document_dict)


# Setup retrievers
# custom_retriever_transformers: CustomRetriever = setup_database(
#     "chroma-db-transformers",
#     "document_dict_transformers.pkl",
# )
# custom_retriever_peft: CustomRetriever = setup_database(
#     "chroma-db-peft", "document_dict_peft.pkl"
# )
# custom_retriever_trl: CustomRetriever = setup_database(
#     "chroma-db-trl", "document_dict_trl.pkl"
# )
# custom_retriever_llama_index: CustomRetriever = setup_database(
#     "chroma-db-llama_index",
#     "document_dict_llama_index.pkl",
# )
# custom_retriever_openai_cookbooks: CustomRetriever = setup_database(
#     "chroma-db-openai_cookbooks",
#     "document_dict_openai_cookbooks.pkl",
# )
# custom_retriever_langchain: CustomRetriever = setup_database(
#     "chroma-db-langchain",
#     "document_dict_langchain.pkl",
# )
custom_retriever_all_sources: CustomRetriever = setup_database(
    "chroma-db-all_sources",
    "document_dict_all_sources.pkl",
)
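
# Example query (assuming CustomRetriever implements LlamaIndex's standard
# BaseRetriever interface):
# nodes = custom_retriever_all_sources.retrieve("What does LoRA fine-tuning change?")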

# Constants
CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))
MONGODB_URI = os.getenv("MONGODB_URI")

AVAILABLE_SOURCES_UI = [
    "Transformers Docs",
    "PEFT Docs",
    "TRL Docs",
    "LlamaIndex Docs",
    "LangChain Docs",
    "OpenAI Cookbooks",
    "Towards AI Blog",
    "8 Hour Primer",
    "Advanced LLM Developer",
    # "All Sources",
]
AVAILABLE_SOURCES = [
    "transformers",
    "peft",
    "trl",
    "llama_index",
    "langchain",
    "openai_cookbooks",
    "tai_blog",
    "8-hour_primer",
    "llm_developer",
    # "all_sources",
]
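
# The two lists above are index-aligned: AVAILABLE_SOURCES_UI[i] is the display
# name shown for the internal source id AVAILABLE_SOURCES[i].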

mongo_db = (
    init_mongo_db(uri=MONGODB_URI, db_name="towardsai-buster")
    if MONGODB_URI
    else logfire.warn("No MongoDB URI found, you will not be able to save data.")
)
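
# logfire.warn() returns None, so mongo_db is None whenever MONGODB_URI is
# unset; callers should guard on that before writing.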

__all__ = [
    # "custom_retriever_transformers",
    # "custom_retriever_peft",
    # "custom_retriever_trl",
    # "custom_retriever_llama_index",
    # "custom_retriever_openai_cookbooks",
    # "custom_retriever_langchain",
    "custom_retriever_all_sources",
    "mongo_db",
    "CONCURRENCY_COUNT",
    "AVAILABLE_SOURCES_UI",
    "AVAILABLE_SOURCES",
]