Tonic committed
Commit 67d3ae5 · 1 Parent(s): 9b90d69

added langchain app for simplicity

Files changed (3)
  1. README.md +1 -12
  2. langchainapp.py +227 -0
  3. requirements.txt +1 -0
README.md CHANGED
@@ -5,19 +5,8 @@ colorFrom: yellow
  colorTo: purple
  sdk: gradio
  sdk_version: 4.36.1
- app_file: app.py
+ app_file: langchainapp.py
  pinned: true
  license: mit
  ---
 
- **run chroma first:**
-
- ```sh
- chroma run --host localhost --port 8000
- ```
-
- **then**
-
- ```sh
- python3 app.py
- ```
langchainapp.py ADDED
@@ -0,0 +1,227 @@
+ # langchainapp.py
+ import spaces  # kept for Hugging Face Spaces GPU support
+ from openai import OpenAI
+ from langchain_community.embeddings import HuggingFaceInstructEmbeddings
+ from langchain_community.document_loaders import UnstructuredFileLoader
+ from langchain_chroma import Chroma
+ from chromadb import Documents, EmbeddingFunction, Embeddings
+ from chromadb.config import Settings
+ import chromadb
+ from typing import List, Tuple, Dict, Any
+ import os
+ import re
+ import uuid
+ import gradio as gr
+ import torch
+ from dotenv import load_dotenv
+ from utils import load_env_variables, parse_and_route, escape_special_characters
+ from globalvars import API_BASE, intention_prompt, tasks, system_message, model_name, metadata_prompt
+ from langchain.chat_models import ChatOpenAI
+
+ from langchain.retrievers.document_compressors import LLMChainExtractor
+ from langchain.retrievers.multi_query import MultiQueryRetriever
+ from langchain.retrievers import ContextualCompressionRetriever
+ from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
+
+
+ load_dotenv()
+
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:30'
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+ os.environ['CUDA_CACHE_DISABLE'] = '1'
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ ### Utils
+ hf_token, yi_token = load_env_variables()
+
+ def clear_cuda_cache():
+     torch.cuda.empty_cache()
+
+ # In-process chroma client; no separate `chroma run` server is needed
+ chroma_client = chromadb.Client(Settings())
+
+ # Create a collection
+ chroma_collection = chroma_client.create_collection("all-my-documents")
+
+ class MyEmbeddingFunction(EmbeddingFunction):
+     def __init__(self, model_name: str, token: str, intention_client):
+         self.model_name = model_name
+         self.token = token
+         self.intention_client = intention_client
+         self.hf_embeddings = HuggingFaceInstructEmbeddings(
+             model_name=model_name,
+             model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
+             encode_kwargs={'normalize_embeddings': True}
+         )
+
+     def create_embedding_generator(self):
+         return self.hf_embeddings
+
+     def __call__(self, input: Documents) -> Tuple[List[List[float]], List[Dict[str, Any]]]:
+         # chromadb passes Documents as a list of raw strings; embed each one,
+         # collecting one vector and one metadata dict per document
+         embeddings_with_metadata = [self.compute_embeddings(text) for text in input]
+         embeddings = [item[0] for item in embeddings_with_metadata]
+         metadata = [item[1] for item in embeddings_with_metadata]
+         return embeddings, metadata
+
+     def compute_embeddings(self, input_text: str):
+         escaped_input_text = escape_special_characters(input_text)
+
+         # Ask the LLM for the intention behind the text
+         intention_completion = self.intention_client.chat.completions.create(
+             model="yi-large",
+             messages=[
+                 {"role": "system", "content": escape_special_characters(intention_prompt)},
+                 {"role": "user", "content": escaped_input_text}
+             ]
+         )
+         intention_output = intention_completion.choices[0].message.content
+         parsed_task = parse_and_route(intention_output)
+         selected_task = parsed_task if parsed_task in tasks else "DEFAULT"
+         task_description = tasks[selected_task]
+
+         # Construct the embed_instruction and query_instruction dynamically
+         embed_instruction = f"Represent the document for retrieval: {task_description}"
+         query_instruction = f"Represent the query for retrieval: {task_description}"
+
+         # Update the hf_embeddings object with the new instructions
+         self.hf_embeddings.embed_instruction = embed_instruction
+         self.hf_embeddings.query_instruction = query_instruction
+
+         # Ask the LLM for document metadata
+         metadata_completion = self.intention_client.chat.completions.create(
+             model="yi-large",
+             messages=[
+                 {"role": "system", "content": escape_special_characters(metadata_prompt)},
+                 {"role": "user", "content": escaped_input_text}
+             ]
+         )
+         metadata_output = metadata_completion.choices[0].message.content
+         metadata = self.extract_metadata(metadata_output)
+
+         # Embed the single text and return its vector together with the metadata
+         embeddings = self.hf_embeddings.embed_documents([escaped_input_text])
+         return embeddings[0], metadata
+
+     def extract_metadata(self, metadata_output: str) -> Dict[str, str]:
+         # Pull "key": "value" pairs out of the LLM's JSON-like reply
+         pattern = re.compile(r'\"(\w+)\": \"([^\"]+)\"')
+         matches = pattern.findall(metadata_output)
+         return {key: value for key, value in matches}
+
+ def load_documents(file_path: str, mode: str = "elements"):
+     loader = UnstructuredFileLoader(file_path, mode=mode)
+     docs = loader.load()
+     return [doc.page_content for doc in docs]
+
+ def initialize_chroma(collection_name: str, embedding_function: MyEmbeddingFunction):
+     # LangChain's Chroma wrapper needs a LangChain Embeddings object,
+     # so hand it the underlying instructor embeddings
+     db = Chroma(client=chroma_client, collection_name=collection_name,
+                 embedding_function=embedding_function.hf_embeddings)
+     return db
+
+ def add_documents_to_chroma(documents: list, embedding_function: MyEmbeddingFunction):
+     # `documents` is a list of raw text strings; each yields one vector and one metadata dict
+     for doc in documents:
+         embedding, metadata = embedding_function.compute_embeddings(doc)
+         chroma_collection.add(
+             ids=[str(uuid.uuid1())],
+             documents=[doc],
+             embeddings=[embedding],
+             metadatas=[metadata]
+         )
+
+ def query_chroma(query_text: str, embedding_function: MyEmbeddingFunction):
+     # Query with the precomputed embedding so results match the stored vectors
+     query_embedding, _ = embedding_function.compute_embeddings(query_text)
+     result_docs = chroma_collection.query(
+         query_embeddings=[query_embedding],
+         n_results=3
+     )
+     return result_docs
+
+
+ def answer_query(message: str, chat_history: List[Tuple[str, str]],
+                  system_message: str = "", max_new_tokens: int = 512,
+                  temperature: float = 0.7, top_p: float = 0.95):
+     # gr.ChatInterface passes the additional inputs (system message, max new
+     # tokens, temperature, top-p) after the message and history, so accept them here
+     # The retrievers below need a LangChain chat model, not the raw OpenAI client
+     chat_llm = ChatOpenAI(model_name="yi-large", openai_api_key=yi_token,
+                           openai_api_base=API_BASE, temperature=temperature,
+                           max_tokens=max_new_tokens, model_kwargs={"top_p": top_p})
+     base_compressor = LLMChainExtractor.from_llm(chat_llm)
+     db = Chroma(persist_directory="output/general_knowledge",
+                 embedding_function=embedding_function.hf_embeddings)
+     base_retriever = db.as_retriever()
+     mq_retriever = MultiQueryRetriever.from_llm(retriever=base_retriever, llm=chat_llm)
+     compression_retriever = ContextualCompressionRetriever(base_compressor=base_compressor, base_retriever=mq_retriever)
+
+     matched_docs = compression_retriever.get_relevant_documents(query=message)
+     context = "\n\n".join(doc.page_content for doc in matched_docs)
+
+     template = """
+     Answer the following question only by using the context given below in the triple backticks, do not use any other information to answer the question.
+     If you can't answer the given question with the given context, you can return an empty string ('')
+     Context: ```{context}```
+     ----------------------------
+     Question: {query}
+     ----------------------------
+     Answer: """
+
+     human_message_prompt = HumanMessagePromptTemplate.from_template(template=template)
+     chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
+     prompt = chat_prompt.format_prompt(query=message, context=context)
+     # gr.ChatInterface expects the reply string back and manages the history itself
+     response = chat_llm.invoke(prompt.to_messages()).content
+     return response
+
+
+ # Initialize clients
+ intention_client = OpenAI(api_key=yi_token, base_url=API_BASE)
+ embedding_function = MyEmbeddingFunction(model_name=model_name, token=hf_token, intention_client=intention_client)
+ chroma_db = initialize_chroma(collection_name="Tonic-instruct", embedding_function=embedding_function)
+
+ def upload_documents(files):
+     for file in files:
+         loader = UnstructuredFileLoader(file.name)
+         documents = loader.load()
+         # Hand the raw text of each loaded element to the embedding pipeline
+         add_documents_to_chroma([doc.page_content for doc in documents], embedding_function)
+     return "Documents uploaded and processed successfully!"
+
+ def query_documents(query):
+     results = query_chroma(query, embedding_function)
+     # chroma_collection.query returns a dict; "documents" holds one list of matches per query
+     return "\n\n".join(results["documents"][0])
+
+ with gr.Blocks() as demo:
+     with gr.Tab("Upload Documents"):
+         document_upload = gr.File(file_count="multiple", file_types=["document"])
+         upload_button = gr.Button("Upload and Process")
+         upload_button.click(upload_documents, inputs=document_upload, outputs=gr.Text())
+
+     with gr.Tab("Ask Questions"):
+         with gr.Row():
+             chat_interface = gr.ChatInterface(
+                 answer_query,
+                 additional_inputs=[
+                     gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+                     gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+                     gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+                     gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
+                 ],
+             )
+         query_input = gr.Textbox(label="Query")
+         query_button = gr.Button("Query")
+         query_output = gr.Textbox()
+         query_button.click(query_documents, inputs=query_input, outputs=query_output)
+
+ if __name__ == "__main__":
+     demo.launch()
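
Taken together, the new module indexes each uploaded document with an LLM-chosen embedding instruction plus extracted metadata, then answers queries against the stored vectors. A minimal sketch of that flow as a standalone driver (hypothetical; it assumes `langchainapp` imports cleanly, i.e. the Space's `utils` and `globalvars` modules resolve and the API tokens are set in the environment):

```python
# Hypothetical driver for the flow above; assumes langchainapp imports cleanly.
from langchainapp import add_documents_to_chroma, query_chroma, embedding_function

# Index two raw text snippets: each gets an LLM-derived embedding instruction
# and metadata before its vector lands in the "all-my-documents" collection.
add_documents_to_chroma(
    ["Chroma is an open-source embedding database.",
     "Gradio builds web UIs for Python functions."],
    embedding_function,
)

# Query with the same instruction-aware embeddings; the result dict's
# "documents" entry holds one list of matching texts per query.
results = query_chroma("What is Chroma?", embedding_function)
print(results["documents"][0])
```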
requirements.txt CHANGED
@@ -15,3 +15,4 @@ gradio
  # tesseract
  # libxml2
  # libxslt
+ InstructorEmbedding
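
Note: `HuggingFaceInstructEmbeddings` imports the `INSTRUCTOR` class from the `InstructorEmbedding` package at runtime, so without this new requirement the app fails at startup.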