import uuid

import chromadb
import gradio as gr
import torch
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import CharacterTextSplitter
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, T5ForConditionalGeneration
# Load the Flan-T5 generator model and its tokenizer.
# Note: device_map='auto' requires the accelerate package.
model_name = 'google/flan-t5-base'
model = T5ForConditionalGeneration.from_pretrained(model_name, device_map='auto', offload_folder="offload")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print('flan-t5 loaded')

# Load the sentence-transformer used to embed both document chunks and queries.
ST_name = 'sentence-transformers/sentence-t5-base'
st_model = SentenceTransformer(ST_name)
print('sentence transformer loaded')
def get_context(query_text, collection):
    # Embed the question and retrieve the closest chunks from the collection;
    # only the single best-matching chunk is used as context.
    query_emb = st_model.encode(query_text)
    query_response = collection.query(query_embeddings=query_emb.tolist(), n_results=4)
    context = query_response['documents'][0][0]
    context = context.replace('\n', ' ').replace('  ', ' ')
    return context
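# Optional sketch (not used above): get_context() keeps only the single best
# chunk, so an answer can miss information that lives in a neighbouring chunk.
# A variant that joins all retrieved chunks into one context string could look
# like this; the name get_context_topk is illustrative.
def get_context_topk(query_text, collection, n_results=4):
    query_emb = st_model.encode(query_text)
    query_response = collection.query(query_embeddings=query_emb.tolist(), n_results=n_results)
    chunks = query_response['documents'][0]
    return ' '.join(chunk.replace('\n', ' ') for chunk in chunks)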
def local_query(query, context):
    t5query = """Using the available context, please answer the question.
If you aren't sure, please say "I don't know".
Context: {}
Question: {}
""".format(context, query)

    print('t5 query is')
    print(t5query)

    inputs = tokenizer(t5query, return_tensors="pt")
    print('done with tokenizer')

    # Keep answers short: generation is capped at 20 new tokens.
    outputs = model.generate(**inputs, max_new_tokens=20)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
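# Quick sanity check for local_query (a sketch, not wired into the app; the
# question and context strings below are placeholders):
#
#     answer = local_query("What is the report about?", "The report describes ...")
#     print(answer[0])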
def run_query(file, history, query):
    file_name = file.name

    # Load the uploaded PDF and split it into ~1000-character chunks.
    loader = PDFMinerLoader(file_name)
    doc = loader.load()

    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(doc)
    texts = [i.page_content for i in texts]

    # Embed every chunk and store the chunks in an in-memory Chroma collection.
    doc_emb = st_model.encode(texts)
    doc_emb = doc_emb.tolist()

    ids = [str(uuid.uuid1()) for _ in doc_emb]

    client = chromadb.Client()
    # get_or_create avoids an error when a second question tries to re-create
    # the same collection.
    collection = client.get_or_create_collection("test_db")

    collection.add(
        embeddings=doc_emb,
        documents=texts,
        ids=ids
    )

    print('calling get_context function')
    print(collection)

    context = get_context(query, collection)
    print(context)

    print('calling local_query')
    result = local_query(query, context)
    print(result)

    # The chatbot history is a list of (user message, bot message) pairs;
    # list.append() returns None, so append in place and return the list itself.
    history.append((query, result[0] if result else "I don't know"))

    print(history)
    return history
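# Optional sketch (not used above): chromadb.Client() keeps everything in memory,
# so the PDF is re-embedded and re-indexed for every question. Assuming
# chromadb >= 0.4, a persistent store could be built once instead; the
# "chroma_db" path and "pdf_db" collection name are illustrative.
def build_persistent_collection(texts, doc_emb, ids, path="chroma_db"):
    client = chromadb.PersistentClient(path=path)
    collection = client.get_or_create_collection("pdf_db")
    collection.add(embeddings=doc_emb, documents=texts, ids=ids)
    return collection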
def upload_pdf(file):
    try:
        if file is not None:
            file_name = file.name
            return 'Successfully uploaded!'
        else:
            return "No file uploaded."
    except Exception as e:
        return f"An error occurred: {e}"
with gr.Blocks() as demo:

    btn = gr.UploadButton("Upload a PDF", file_types=[".pdf"])
    output = gr.Textbox(label="Output Box")
    chatbot = gr.Chatbot(value=[], elem_id="chatbot")

    with gr.Row():
        with gr.Column(scale=0.70):
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter a question",
            )
    # The UploadButton is also used as an input, so run_query receives the
    # uploaded file, the current chat history, and the question text.
    btn.upload(fn=upload_pdf, inputs=[btn], outputs=[output])
    txt.submit(run_query, [btn, chatbot, txt], [chatbot])
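# If answers take a while to generate, Gradio's request queue can be enabled
# before launching (a sketch; behaviour depends on the installed Gradio version):
#
#     demo.queue()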
demo.launch()