import gradio as gr import os from langchain_community.document_loaders import PyPDFLoader from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.vectorstores import Chroma from langchain.chains import ConversationalRetrievalChain from langchain_community.embeddings import HuggingFaceEmbeddings from langchain_community.llms import HuggingFacePipeline from langchain.chains import ConversationChain from langchain.memory import ConversationBufferMemory from langchain_community.llms import HuggingFaceEndpoint from pathlib import Path import chromadb from unidecode import unidecode from transformers import AutoTokenizer import transformers import torch import tqdm import accelerate import re # LlamaParse import from llama_parse import LlamaParse import asyncio from llama_index.core.async_utils import DEFAULT_NUM_WORKERS, run_jobs from llama_index.core.base.response.schema import PydanticResponse from llama_index.core.bridge.pydantic import BaseModel, Field, ValidationError from llama_index.core.callbacks.base import CallbackManager from llama_index.core.llms.llm import LLM from llama_index.core.node_parser.interface import NodeParser from llama_index.core.schema import BaseNode, Document, IndexNode, TextNode from llama_index.core.utils import get_tqdm_iterable from io import StringIO from typing import Any, Callable, List, Optional import pandas as pd from llama_index.core.node_parser.relational.base_element import ( # BaseElementNodeParser, Element, ) from llama_index.core.schema import BaseNode, TextNode api_token = os.getenv("HF_TOKEN") # Implementations # default_persist_directory = './chroma_HF/' list_llm = ["mistralai/Mistral-7B-Instruct-v0.3", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", \ "google/gemma-7b-it","google/gemma-2b-it", \ "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1", \ "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", \ "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct", \ "google/flan-t5-xxl" ] list_llm_simple = [os.path.basename(llm) for llm in list_llm] # Load PDF document and create doc splits def load_doc(list_file_path, chunk_size, chunk_overlap): loaders = [PyPDFLoader(x) for x in list_file_path] pages = [] for loader in loaders: pages.extend(loader.load()) text_splitter = RecursiveCharacterTextSplitter(chunk_size = chunk_size, chunk_overlap = chunk_overlap) doc_splits = text_splitter.split_documents(pages) return doc_splits # Create vector database def create_db(splits, collection_name): embedding = HuggingFaceEmbeddings() new_client = chromadb.EphemeralClient() vectordb = Chroma.from_documents( documents=splits, embedding=embedding, client=new_client, collection_name=collection_name, ) return vectordb # Load vector database def load_db(): embedding = HuggingFaceEmbeddings() vectordb = Chroma( embedding_function=embedding) return vectordb # Initialize langchain LLM chain def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()): progress(0.1, desc="Initializing HF tokenizer...") progress(0.5, desc="Initializing HF Hub...") if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.3": llm = HuggingFaceEndpoint( repo_id=llm_model, huggingfacehub_api_token = api_token, temperature = temperature, max_new_tokens = max_tokens, top_k = top_k, load_in_8bit = True, ) elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1","mosaicml/mpt-7b-instruct"]: raise gr.Error("LLM model is too large to be loaded automatically on free inference endpoint") llm = HuggingFaceEndpoint( repo_id=llm_model, huggingfacehub_api_token = api_token, temperature = temperature, max_new_tokens = max_tokens, top_k = top_k, ) elif llm_model == "microsoft/phi-2": llm = HuggingFaceEndpoint( repo_id=llm_model, huggingfacehub_api_token = api_token, temperature = temperature, max_new_tokens = max_tokens, top_k = top_k, trust_remote_code = True, torch_dtype = "auto", ) elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0": llm = HuggingFaceEndpoint( repo_id=llm_model, huggingfacehub_api_token = api_token, temperature = temperature, max_new_tokens = 250, top_k = top_k, ) elif llm_model == "meta-llama/Llama-2-7b-chat-hf": raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...") llm = HuggingFaceEndpoint( repo_id=llm_model, huggingfacehub_api_token = api_token, temperature = temperature, max_new_tokens = max_tokens, top_k = top_k, ) else: llm = HuggingFaceEndpoint( repo_id=llm_model, huggingfacehub_api_token = api_token, temperature = temperature, max_new_tokens = max_tokens, top_k = top_k, ) progress(0.75, desc="Defining buffer memory...") memory = ConversationBufferMemory( memory_key="chat_history", output_key='answer', return_messages=True ) retriever=vector_db.as_retriever() progress(0.8, desc="Defining retrieval chain...") qa_chain = ConversationalRetrievalChain.from_llm( llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True, verbose=False, ) progress(0.9, desc="Done!") return qa_chain # Generate collection name for vector database def create_collection_name(filepath): collection_name = Path(filepath).stem collection_name = collection_name.replace(" ","-") collection_name = unidecode(collection_name) collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name) collection_name = collection_name[:50] if len(collection_name) < 3: collection_name = collection_name + 'xyz' if not collection_name[0].isalnum(): collection_name = 'A' + collection_name[1:] if not collection_name[-1].isalnum(): collection_name = collection_name[:-1] + 'Z' print('Filepath: ', filepath) print('Collection name: ', collection_name) return collection_name # Initialize database def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()): list_file_path = [x.name for x in list_file_obj if x is not None] progress(0.1, desc="Creating collection name...") collection_name = create_collection_name(list_file_path[0]) progress(0.25, desc="Loading document...") doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap) progress(0.5, desc="Generating vector database...") vector_db = create_db(doc_splits, collection_name) progress(0.9, desc="Done!") return vector_db, collection_name, "Complete!" def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()): llm_name = list_llm[llm_option] print("llm_name: ",llm_name) qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress) return qa_chain, "Complete!" def format_chat_history(message, chat_history): formatted_chat_history = [] for user_message, bot_message in chat_history: formatted_chat_history.append(f"User: {user_message}") formatted_chat_history.append(f"Assistant: {bot_message}") return formatted_chat_history def conversation(qa_chain, message, history): formatted_chat_history = format_chat_history(message, history) response = qa_chain({"question": message, "chat_history": formatted_chat_history}) response_answer = response["answer"] if response_answer.find("Helpful Answer:") != -1: response_answer = response_answer.split("Helpful Answer:")[-1] response_sources = response["source_documents"] response_source1 = response_sources[0].page_content.strip() response_source2 = response_sources[1].page_content.strip() response_source3 = response_sources[2].page_content.strip() response_source1_page = response_sources[0].metadata["page"] + 1 response_source2_page = response_sources[1].metadata["page"] + 1 response_source3_page = response_sources[2].metadata["page"] + 1 new_history = history + [(message, response_answer)] return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page def upload_file(file_obj): list_file_path = [] for idx, file in enumerate(file_obj): file_path = file_obj.name list_file_path.append(file_path) return list_file_path # Initialize LlamaIndex parsing def initialize_llama_index(file_obj): documents = LlamaParse(result_type="markdown",api_key=secret_value_0).load_data(file_obj.name) node_parser = MarkdownElementNodeParser(llm = None, num_workers=8) nodes = node_parser.get_nodes_from_documents(documents) base_nodes, objects = node_parser.get_nodes_and_objects(nodes) index_with_obj = VectorStoreIndex(nodes=base_nodes+objects) index_ret = index_with_obj.as_retriever(top_k=15) recursive_query_engine = RetrieverQueryEngine.from_args(index_ret, node_postprocessors=[FlagEmbeddingReranker( top_n=5, model="BAAI/bge-reranker-large", )], verbose=False) return recursive_query_engine, "LlamaIndex parsing complete" def demo(): with gr.Blocks(theme="base") as demo: vector_db = gr.State() qa_chain = gr.State() collection_name = gr.State() llama_index_engine = gr.State() gr.Markdown( """