tomascufaro committed on
Commit
1fecdf1
·
1 Parent(s): 632d9b5

new app.py

Browse files
Files changed (2) hide show
  1. app.py +135 -20
  2. requirements.txt +223 -11
app.py CHANGED
@@ -8,14 +8,38 @@ from langchain import hub
8
  from langchain_openai import ChatOpenAI, OpenAIEmbeddings
9
  from langchain_core.runnables import RunnablePassthrough
10
  from langchain_core.output_parsers import StrOutputParser
 
11
  import os
12
  import gradio as gr
 
 
 
 
 
 
 
 
 
13
 
14
 
15
- def doc_to_embeddings(doc:Document, split_mode:str='tiktoken',
16
- chunk_size:int=1000, chunk_overlap:int=5, faiss_save_path:str=None, save_faiss:bool=None):
17
- # Load the PDF file (if the file is a URL, load the PDF file from the URL)
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # Split by separator and merge by character count
20
  if split_mode == "character":
21
  # Create a CharacterTextSplitter object
@@ -42,40 +66,131 @@ def doc_to_embeddings(doc:Document, split_mode:str='tiktoken',
42
  chunk_size=chunk_size,
43
  chunk_overlap=chunk_overlap,)
44
  else:
45
- raise ValueError("Please specify the split mode.")
46
  documents = text_splitter.split_documents(doc)
 
 
 
 
 
 
 
 
 
 
47
  embeddings = OpenAIEmbeddings(openai_api_key=os.environ['OpenAI_APIKEY'])
48
  faiss_db = FAISS.from_documents(documents, embeddings)
49
  if save_faiss:
50
  faiss_db.save_local(faiss_save_path)
51
  return faiss_db
52
 
53
- def format_docs(docs):
54
- return "\n\n".join(doc.page_content for doc in docs)
55
 
 
 
 
 
 
 
 
 
56
 
57
- def wrap_all(file, input_prompt:str):
58
- loader = Docx2txtLoader(file)
59
- data = loader.load()
60
- db = doc_to_embeddings(data)
61
- retriever = db.as_retriever()
62
- prompt = hub.pull("rlm/rag-prompt")
63
- llm = ChatOpenAI(model_name="gpt-4",openai_api_key=os.environ['OpenAI_APIKEY'], temperature=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  rag_chain = (
65
- {"context": retriever | format_docs, "question": RunnablePassthrough()}
66
- | prompt
67
- | llm
68
- | StrOutputParser()
69
- )
 
 
70
  return rag_chain.invoke(input_prompt)
71
 
72
 
73
  # Define the Gradio interface
74
  iface = gr.Interface(
75
  fn=wrap_all,
76
- inputs=[gr.File(type="filepath", label=".docx file of the interview"), gr.Textbox(label="Enter your inquiry")],
77
  outputs="text",
78
  title="Interviews: QA and summarization",
79
  description="Upload a .docx file with the interview and enter the question you have or ask for a summarization.")
80
-
81
  iface.launch()
 
8
  from langchain_openai import ChatOpenAI, OpenAIEmbeddings
9
  from langchain_core.runnables import RunnablePassthrough
10
  from langchain_core.output_parsers import StrOutputParser
11
+ from langchain_community.vectorstores import Chroma
12
  import os
13
  import gradio as gr
14
+ import os
15
+ from typing import List
16
+ from pydantic import BaseModel
17
+ from langchain_core.prompts import ChatPromptTemplate
18
+ from unstructured.partition.pdf import partition_pdf
19
+ import uuid
20
+ from langchain.retrievers.multi_vector import MultiVectorRetriever
21
+ from langchain.storage import InMemoryStore
22
+ from langchain_community.document_loaders import UnstructuredPDFLoader
23
 
24
 
 
 
 
25
 
26
# The vectorstore used to index the child chunks (the summaries).
# The OpenAI key is read from the OpenAI_APIKEY environment variable, the same
# variable the other OpenAI clients in this file use. Credentials must never be
# hard-coded in source control.
# NOTE(review): any key that was ever committed here must be revoked/rotated.
vectorstore = Chroma(
    collection_name="rag_app",
    embedding_function=OpenAIEmbeddings(openai_api_key=os.environ['OpenAI_APIKEY']),
)

# The storage layer for the parent documents
store = InMemoryStore()
id_key = "doc_id"

# The retriever (empty to start); save_documents() fills both stores.
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)
40
+
41
+ def split_text(doc:str, split_mode:str='tiktoken',
42
+ chunk_size:int=1000, chunk_overlap:int=5, faiss_save_path:str=None, save_faiss:bool=None):
43
  # Split by separator and merge by character count
44
  if split_mode == "character":
45
  # Create a CharacterTextSplitter object
 
66
  chunk_size=chunk_size,
67
  chunk_overlap=chunk_overlap,)
68
  else:
69
+ raise ValueError("Please specify the split mode.")
70
  documents = text_splitter.split_documents(doc)
71
+ return documents
72
+
73
def format_docs(docs):
    """Concatenate the ``page_content`` of every document in *docs*.

    Consecutive documents are separated by one blank line; an empty iterable
    yields an empty string.
    """
    page_texts = [entry.page_content for entry in docs]
    return "\n\n".join(page_texts)
75
+
76
class Element(BaseModel):
    """A categorized chunk of text extracted from a parsed document."""

    # Category label for the chunk (this file uses "table" and "text").
    type: str
    # String content of the chunk.
    text: str
79
+
80
def save_documents(texts, text_summaries, tables, table_summaries):
    """Index summaries in the vector store and keep the originals in the docstore.

    Each original chunk gets a fresh UUID. Its summary is embedded into
    ``retriever.vectorstore`` with that UUID stored under ``id_key`` in the
    metadata, and the original is stored in ``retriever.docstore`` under the
    same UUID.

    Args:
        texts: Original text chunks.
        text_summaries: One summary string per entry of ``texts``.
        tables: Original table chunks.
        table_summaries: One summary string per entry of ``tables``.

    Raises:
        ValueError: If a list of originals and its summaries differ in length.
    """

    def _index(originals, summaries, label):
        # Pair every original with its summary under one shared id.
        if len(originals) != len(summaries):
            raise ValueError(
                f"Got {len(originals)} {label} but {len(summaries)} {label} summaries."
            )
        # Nothing of this kind was extracted (e.g. an upload without tables):
        # skip instead of sending empty batches to the stores.
        if not originals:
            return
        ids = [str(uuid.uuid4()) for _ in originals]
        summary_docs = [
            Document(page_content=summary, metadata={id_key: doc_id})
            for doc_id, summary in zip(ids, summaries)
        ]
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(ids, originals)))

    # Add texts
    _index(texts, text_summaries, "texts")

    # Add tables
    _index(tables, table_summaries, "tables")
106
+
107
+
108
def doc_processing(files: List[bytes]):
    """Parse uploaded PDF/DOCX files and summarize their text and tables.

    PDFs are partitioned with ``partition_pdf``; table elements and composite
    text elements are wrapped as ``Document`` objects so that everything
    returned exposes ``page_content``. DOCX files are loaded with
    ``Docx2txtLoader``. Files with any other extension are ignored.

    Args:
        files: The uploaded files. Entries may be plain path strings or
            objects exposing the path as ``.name``.
            NOTE(review): the ``List[bytes]`` annotation is kept for interface
            stability but the entries are used as paths, not bytes.

    Returns:
        A tuple ``(docs, tables, text_summaries, table_summaries)``.
    """
    docs = []
    tables = []
    for file in files:
        # gr.File(type="filepath") supplies str paths; fall back to the object
        # itself when there is no ``.name`` attribute to read the path from.
        path = str(getattr(file, "name", file))
        extension = path.lower()
        if extension.endswith(".pdf"):
            raw_pdf_elements = partition_pdf(
                filename=path,
                extract_images_in_pdf=False,
                infer_table_structure=True,
                chunking_strategy="by_title",
                max_characters=4000,
                new_after_n_chars=3800,
                combine_text_under_n_chars=2000,
                image_output_dir_path='/tmp',  # Change this to your desired path
            )
            for element in raw_pdf_elements:
                element_type = str(type(element))
                if "unstructured.documents.elements.Table" in element_type:
                    tables.append(
                        Document(page_content=str(element), metadata={"type": "table"})
                    )
                elif "unstructured.documents.elements.CompositeElement" in element_type:
                    docs.append(
                        Document(page_content=str(element), metadata={"type": "text"})
                    )
        elif extension.endswith(".docx"):
            # Process DOCX file using LangChain Docx2txtLoader
            docs.extend(Docx2txtLoader(path).load())

    # Prompt
    prompt_text = """You are an assistant tasked with summarizing tables and text.
    Give a concise summary of the table or text. Table or text chunk: {element} """
    prompt = ChatPromptTemplate.from_template(prompt_text)

    # Summary chain; the key comes from the same environment variable the
    # other OpenAI clients in this file read.
    model = ChatOpenAI(
        temperature=0,
        model="gpt-4",
        openai_api_key=os.environ['OpenAI_APIKEY'],
    )
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

    # Summarize the chunk text itself, not the repr of the wrapping object.
    table_summaries = summarize_chain.batch(
        [table.page_content for table in tables], {"max_concurrency": 5}
    )
    text_summaries = summarize_chain.batch(
        [doc.page_content for doc in docs], {"max_concurrency": 5}
    )

    return docs, tables, text_summaries, table_summaries
155
+
156
+ # Convert the list of document texts to embeddings
157
+
158
+
159
def wrap_all(files: List[bytes], input_prompt: str):
    """Index the uploaded files and answer *input_prompt* from their content.

    Args:
        files: Uploaded files, forwarded to ``doc_processing``.
        input_prompt: The user's question or summarization request.

    Returns:
        The model's answer as a string.
    """
    # doc_processing returns (docs, tables, text_summaries, table_summaries),
    # while save_documents expects (texts, text_summaries, tables,
    # table_summaries): unpack and reorder explicitly.
    docs, tables, text_summaries, table_summaries = doc_processing(files)
    save_documents(docs, text_summaries, tables, table_summaries)

    # Prompt template
    template = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question.
    If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise. Please cite the text that you are using to base your arguments when it is possible.

    Question: {question}

    Context: {context}

    Answer:
    """
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatOpenAI(model_name="gpt-4o", openai_api_key=os.environ['OpenAI_APIKEY'], temperature=0)

    # Create the RAG chain: retrieved documents are flattened to text for the
    # context slot, and the raw question is passed through unchanged.
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    # Invoke the chain with the input prompt
    return rag_chain.invoke(input_prompt)
187
 
188
 
189
# Gradio UI: a multi-file upload plus a free-text question, answered by wrap_all.
_file_input = gr.File(type="filepath", label=".docx file of the interview", file_count='multiple')
_question_input = gr.Textbox(label="Enter your inquiry")

iface = gr.Interface(
    fn=wrap_all,
    inputs=[_file_input, _question_input],
    outputs="text",
    title="Interviews: QA and summarization",
    description="Upload a .docx file with the interview and enter the question you have or ask for a summarization.",
)

iface.launch()
requirements.txt CHANGED
@@ -1,11 +1,223 @@
1
- langchain
2
- numpy
3
- pandas
4
- openai
5
- openpyxl
6
- langchain_community
7
- langchain_openai
8
- langchain_core
9
- docx2txt
10
- faiss-cpu
11
- langchainhub
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohttp==3.9.5
3
+ aiosignal==1.3.1
4
+ altair==5.3.0
5
+ annotated-types==0.7.0
6
+ anyio==4.4.0
7
+ asgiref==3.8.1
8
+ asttokens==2.4.1
9
+ attrs==23.2.0
10
+ backoff==2.2.1
11
+ bcrypt==4.1.3
12
+ beautifulsoup4==4.12.3
13
+ build==1.2.1
14
+ cachetools==5.3.3
15
+ certifi==2024.6.2
16
+ cffi==1.16.0
17
+ chardet==5.2.0
18
+ charset-normalizer==3.3.2
19
+ chroma-hnswlib==0.7.3
20
+ chromadb==0.5.3
21
+ click==8.1.7
22
+ cmake==3.29.6
23
+ colorama==0.4.6
24
+ coloredlogs==15.0.1
25
+ comm==0.2.2
26
+ contourpy==1.2.1
27
+ cryptography==42.0.8
28
+ cycler==0.12.1
29
+ dataclasses-json==0.6.7
30
+ debugpy==1.8.2
31
+ decorator==5.1.1
32
+ deepdiff==7.0.1
33
+ Deprecated==1.2.14
34
+ distro==1.9.0
35
+ dnspython==2.6.1
36
+ docx2txt==0.8
37
+ email_validator==2.2.0
38
+ emoji==2.12.1
39
+ et-xmlfile==1.1.0
40
+ executing==2.0.1
41
+ faiss-cpu==1.8.0.post1
42
+ fastapi==0.111.0
43
+ fastapi-cli==0.0.4
44
+ ffmpy==0.3.2
45
+ filelock==3.15.4
46
+ filetype==1.2.0
47
+ flatbuffers==24.3.25
48
+ fonttools==4.53.0
49
+ frozenlist==1.4.1
50
+ fsspec==2024.6.1
51
+ google-auth==2.31.0
52
+ googleapis-common-protos==1.63.2
53
+ gradio==4.37.2
54
+ gradio_client==1.0.2
55
+ greenlet==3.0.3
56
+ grpcio==1.64.1
57
+ h11==0.14.0
58
+ httpcore==1.0.5
59
+ httptools==0.6.1
60
+ httpx==0.27.0
61
+ huggingface-hub==0.23.4
62
+ humanfriendly==10.0
63
+ idna==3.7
64
+ importlib_metadata==7.1.0
65
+ importlib_resources==6.4.0
66
+ intel-openmp==2021.4.0
67
+ iopath==0.1.10
68
+ ipykernel==6.29.5
69
+ ipython==8.26.0
70
+ jedi==0.19.1
71
+ Jinja2==3.1.4
72
+ joblib==1.4.2
73
+ jsonpatch==1.33
74
+ jsonpath-python==1.0.6
75
+ jsonpointer==3.0.0
76
+ jsonschema==4.22.0
77
+ jsonschema-specifications==2023.12.1
78
+ jupyter_client==8.6.2
79
+ jupyter_core==5.7.2
80
+ kiwisolver==1.4.5
81
+ kubernetes==30.1.0
82
+ langchain==0.2.6
83
+ langchain-community==0.2.6
84
+ langchain-core==0.2.10
85
+ langchain-openai==0.1.13
86
+ langchain-text-splitters==0.2.2
87
+ langchainhub==0.1.20
88
+ langdetect==1.0.9
89
+ langsmith==0.1.83
90
+ layoutparser==0.3.4
91
+ lxml==5.2.2
92
+ markdown-it-py==3.0.0
93
+ MarkupSafe==2.1.5
94
+ marshmallow==3.21.3
95
+ matplotlib==3.9.0
96
+ matplotlib-inline==0.1.7
97
+ mdurl==0.1.2
98
+ mkl==2021.4.0
99
+ mmh3==4.1.0
100
+ monotonic==1.6
101
+ mpmath==1.3.0
102
+ multidict==6.0.5
103
+ mypy-extensions==1.0.0
104
+ nest-asyncio==1.6.0
105
+ networkx==3.3
106
+ nltk==3.8.1
107
+ numpy==1.26.4
108
+ oauthlib==3.2.2
109
+ onnx==1.16.1
110
+ onnxruntime==1.18.1
111
+ openai==1.35.7
112
+ opencv-python==4.10.0.84
113
+ openpyxl==3.1.5
114
+ opentelemetry-api==1.25.0
115
+ opentelemetry-exporter-otlp-proto-common==1.25.0
116
+ opentelemetry-exporter-otlp-proto-grpc==1.25.0
117
+ opentelemetry-instrumentation==0.46b0
118
+ opentelemetry-instrumentation-asgi==0.46b0
119
+ opentelemetry-instrumentation-fastapi==0.46b0
120
+ opentelemetry-proto==1.25.0
121
+ opentelemetry-sdk==1.25.0
122
+ opentelemetry-semantic-conventions==0.46b0
123
+ opentelemetry-util-http==0.46b0
124
+ ordered-set==4.1.0
125
+ orjson==3.10.5
126
+ overrides==7.7.0
127
+ packaging==24.1
128
+ pandas==2.2.2
129
+ parso==0.8.4
130
+ pdf2image==1.17.0
131
+ pdfminer.six==20231228
132
+ pdfplumber==0.11.1
133
+ pillow==10.4.0
134
+ pillow_heif==0.17.0
135
+ pkgconfig==1.5.5
136
+ platformdirs==4.2.2
137
+ poppler-utils==0.1.0
138
+ portalocker==2.10.0
139
+ posthog==3.5.0
140
+ prompt_toolkit==3.0.47
141
+ protobuf==4.25.3
142
+ psutil==6.0.0
143
+ pure-eval==0.2.2
144
+ pyasn1==0.6.0
145
+ pyasn1_modules==0.4.0
146
+ pycparser==2.22
147
+ pycryptodome==3.20.0
148
+ pydantic==2.8.0
149
+ pydantic_core==2.20.0
150
+ pydub==0.25.1
151
+ Pygments==2.18.0
152
+ pykg-config==1.3.0
153
+ pyparsing==3.1.2
154
+ pypdf==4.2.0
155
+ pypdfium2==4.30.0
156
+ PyPika==0.48.9
157
+ pyproject_hooks==1.1.0
158
+ pyreadline3==3.4.1
159
+ python-dateutil==2.9.0.post0
160
+ python-dotenv==1.0.1
161
+ python-iso639==2024.4.27
162
+ python-magic==0.4.27
163
+ python-multipart==0.0.9
164
+ pytz==2024.1
165
+ pywin32==306
166
+ PyYAML==6.0.1
167
+ pyzmq==26.0.3
168
+ rapidfuzz==3.9.3
169
+ referencing==0.35.1
170
+ regex==2024.5.15
171
+ requests==2.32.3
172
+ requests-oauthlib==2.0.0
173
+ requests-toolbelt==1.0.0
174
+ rich==13.7.1
175
+ rpds-py==0.18.1
176
+ rsa==4.9
177
+ ruff==0.5.0
178
+ safetensors==0.4.3
179
+ scipy==1.14.0
180
+ semantic-version==2.10.0
181
+ setuptools==70.1.1
182
+ shellingham==1.5.4
183
+ six==1.16.0
184
+ sniffio==1.3.1
185
+ soupsieve==2.5
186
+ SQLAlchemy==2.0.31
187
+ stack-data==0.6.3
188
+ starlette==0.37.2
189
+ sympy==1.12.1
190
+ tabulate==0.9.0
191
+ tbb==2021.13.0
192
+ tenacity==8.4.2
193
+ tiktoken==0.7.0
194
+ timm==1.0.7
195
+ tokenizers==0.19.1
196
+ tomlkit==0.12.0
197
+ toolz==0.12.1
198
+ torch==2.3.1
199
+ torchvision==0.18.1
200
+ tornado==6.4.1
201
+ tqdm==4.66.4
202
+ traitlets==5.14.3
203
+ transformers==4.42.3
204
+ typer==0.12.3
205
+ types-requests==2.32.0.20240622
206
+ typing-inspect==0.9.0
207
+ typing_extensions==4.12.2
208
+ tzdata==2024.1
209
+ ujson==5.10.0
210
+ unstructured==0.14.9
211
+ unstructured-client==0.23.8
212
+ unstructured-inference==0.7.36
213
+ unstructured.pytesseract==0.3.12
214
+ urllib3==2.2.2
215
+ uvicorn==0.30.1
216
+ watchfiles==0.22.0
217
+ wcwidth==0.2.13
218
+ websocket-client==1.8.0
219
+ websockets==11.0.3
220
+ wheel==0.43.0
221
+ wrapt==1.16.0
222
+ yarl==1.9.4
223
+ zipp==3.19.2