import gradio as gr
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from ctransformers import AutoModelForCausalLM
# Load the embedding model
encoder = SentenceTransformer('jinaai/jina-embedding-b-en-v1')
print("Embedding model loaded...")
# Load the LLM (ctransformers loads the GGUF file directly; it does not use
# LangChain callback managers, so none is constructed here)
llm = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-Chat-GGUF",
    model_file="llama-2-7b-chat.Q3_K_S.gguf",
    model_type="llama",
    temperature=0.2,
    repetition_penalty=1.5,
    max_new_tokens=300,
)
print("LLM loaded...")
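# temperature, repetition_penalty, and max_new_tokens are forwarded to
# ctransformers' generation config; the 3-bit Q3_K_S quantization is
# presumably chosen so the 7B model fits in the RAM of a CPU-only Space.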
# Initialize QdrantClient
client = QdrantClient(path="./db")
print("DB created...")
def get_chunks(text):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=250,
        chunk_overlap=50,
        length_function=len,
    )
    return text_splitter.split_text(text)
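# With chunk_size=250 and chunk_overlap=50, a ~600-character page yields
# roughly three chunks, and consecutive chunks share 50 characters so text
# cut at a boundary still appears intact in at least one chunk.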
def setup_database(files):
    all_chunks = []
    for file in files:
        reader = PdfReader(file)
        # extract_text() can return None for image-only pages
        text = "".join((page.extract_text() or "") for page in reader.pages)
        chunks = get_chunks(text)
        all_chunks.extend(chunks)
    print(f"Total chunks: {len(all_chunks)}")
    print("Chunks are ready...")
    client.recreate_collection(
        collection_name="my_facts",
        vectors_config=models.VectorParams(
            size=encoder.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )
    print("Collection created...")
    records = [
        models.Record(
            id=idx,
            vector=encoder.encode(chunk).tolist(),
            payload={"text": chunk},
        )
        for idx, chunk in enumerate(all_chunks)
    ]
    client.upload_records(
        collection_name="my_facts",
        records=records,
    )
    print("Records uploaded...")
def answer(question):
    # Retrieve the three chunks closest to the question
    hits = client.search(
        collection_name="my_facts",
        query_vector=encoder.encode(question).tolist(),
        limit=3,
    )
    context = " ".join(hit.payload["text"] for hit in hits)

    system_prompt = """You are a helpful co-worker. Use the provided context to answer user questions.
Read the given context before answering and think step by step. If you cannot answer a question based on
the provided context, say so. Do not use any other information to answer. Provide a detailed answer."""

    # Llama 2 chat prompt format
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    instruction = f"Context: {context}\nUser: {question}"
    prompt_template = f"{B_INST}{B_SYS}{system_prompt}{E_SYS}{instruction}{E_INST}"
    print(prompt_template)

    result = llm(prompt_template)
    return result
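# Example (assumes setup_database has already indexed at least one PDF;
# the question string is illustrative only):
#   print(answer("What does the document say about refunds?"))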
def chat(messages, files, question):
    if files:
        setup_database(files)
    messages = messages or []
    if not question:
        return messages
    messages.append({"role": "user", "content": question})
    if not client.collection_exists("my_facts"):
        messages.append({"role": "assistant",
                         "content": "Please upload PDF documents to initialize the database."})
        return messages
    response = answer(question)
    messages.append({"role": "assistant", "content": response})
    return messages
with gr.Blocks() as demo:
    # type="messages" makes the Chatbot accept the role/content dicts
    # produced by chat() (supported in Gradio 4.29+)
    chatbot = gr.Chatbot(type="messages")
    file_input = gr.File(label="Upload PDFs", file_count="multiple")
    with gr.Row():
        with gr.Column(scale=9):
            txt = gr.Textbox(show_label=False, container=False,
                             placeholder="Enter your question here...")
        with gr.Column(scale=1, min_width=0):
            send_btn = gr.Button("Send")

    def respond(messages, files, txt):
        messages = chat(messages, files, txt)
        # Clear the file input so uploads are not re-indexed on every turn
        return messages, None, ""

    send_btn.click(respond, [chatbot, file_input, txt], [chatbot, file_input, txt])
    txt.submit(respond, [chatbot, file_input, txt], [chatbot, file_input, txt])

demo.launch()
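# A minimal requirements.txt sketch for this Space (package list inferred
# from the imports above; exact versions are not specified in the original):
#   gradio
#   qdrant-client
#   sentence-transformers
#   PyPDF2
#   langchain
#   ctransformers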