from flask import Flask, render_template, request, redirect, url_for, session, flash
import os
import re
import glob
import shutil
import asyncio
import numpy as np
from werkzeug.utils import secure_filename
from retrival import generate_data_store  # , add_document_to_existing_db, delete_chunks_by_source
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from huggingface_hub import InferenceClient
from dotenv import load_dotenv
import nltk

# NLTK resources used during document parsing
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger_eng')

# Load environment variables (e.g. HF_TOKEN) from a .env file if present
load_dotenv()
app = Flask(__name__)

# Set the secret key for session management
app.secret_key = os.urandom(24)
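# Note: a key from os.urandom(24) changes on every restart, which invalidates
# existing sessions; for persistent sessions, load a fixed key from the
# environment instead (e.g. app.secret_key = os.environ.get("FLASK_SECRET_KEY")).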
# Configuration
UPLOAD_FOLDER = "uploads/"
VECTOR_DB_FOLDER = "VectorDB/"
TABLE_DB_FOLDER = "TableDB/"
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(VECTOR_DB_FOLDER, exist_ok=True)
os.makedirs(TABLE_DB_FOLDER, exist_ok=True)

# Global variables: paths of the currently selected document and table DBs
CHROMA_PATH = None
TABLE_PATH = None
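# Because CHROMA_PATH and TABLE_PATH are module-level globals, the selected
# database is shared by every user of this process, not scoped per session.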
PROMPT_TEMPLATE_DOC = """
<s>[INST] You are a retrieval-augmented generation (RAG) assistant. Your task is to generate a response strictly based on the given context. Follow these instructions:
- Use only the provided context; do not add external information.
- The context contains multiple retrieved chunks separated by "###". Choose only the most relevant chunks to answer the question and ignore unrelated ones.
- If available, use the provided source information to support the response.
- Answer concisely and factually.
Context:
{context}
---
Question:
{question}
Response:
[/INST]
"""
# Prompt used when the selected document database also contains tables
PROMPT_TEMPLATE_TAB = """
<s>[INST] You are a retrieval-augmented generation (RAG) assistant. Your task is to generate a response strictly based on the given context. Follow these instructions:
- Use only the provided context; do not add external information.
- The context contains multiple retrieved chunks separated by "###". Choose only the most relevant chunks to answer the question and ignore unrelated ones.
- If available, use the provided source information to support the response.
- If a table is provided as HTML, incorporate its relevant details into the response while maintaining a structured format.
- Answer concisely and factually.
Context:
{context}
---
Table:
{table}
---
Question:
{question}
Response:
[/INST]
"""
#HFT = os.getenv('HF_TOKEN')
#client = InferenceClient(api_key=HFT)

# Route paths below are assumptions; adjust them to match the links in the templates.
@app.route('/')
def home():
    return render_template('home.html')
@app.route('/chat', methods=['GET', 'POST'])
def chat():
    if 'history' not in session:
        session['history'] = []
    print("sessionhist1", session['history'])

    global CHROMA_PATH
    global TABLE_PATH
    old_db = session.get('old_db', None)
    print(f"Selected DB: {CHROMA_PATH}")
    # if old_db is not None and CHROMA_PATH != old_db:
    #     session['history'] = []

    if request.method == 'POST':
        query_text = request.form['query_text']
        if CHROMA_PATH is None:
            flash("Please select a database first!", "error")
            return redirect(url_for('list_dbs'))

        # Load the selected document database
        embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        #embedding_function = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
        db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)

        # Convert the query to its embedding vector (a list of floats)
        query_embedding = embedding_function.embed_query(query_text)
        results_document = db.similarity_search_by_vector_with_relevance_scores(
            embedding=query_embedding,
            k=3,
            # filter=filter_condition  # optional metadata filter
        )
        print("results------------------->", results_document)
        context_text_document = " \n\n###\n\n ".join(
            [f"Source: {doc.metadata.get('source', '')} Page_content:{doc.page_content}\n"
             for doc, _score in results_document]
        )
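        # Note: with Chroma, similarity_search_by_vector_with_relevance_scores
        # returns (Document, score) pairs where the score is the collection's
        # raw distance (lower = more similar for cosine/L2). An optional guard
        # against off-topic chunks could drop high-distance results, e.g.
        # (the threshold value is an assumption; tune it for your data):
        # results_document = [(d, s) for d, s in results_document if s < 1.0]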
        # Load the table database only if one exists for the selected DB
        if TABLE_PATH is not None:
            table_embedding_function = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
            tdb = Chroma(persist_directory=TABLE_PATH, embedding_function=table_embedding_function)
            # Re-embed the query with the table model: the table store uses a
            # different embedding model (and dimension) than the document store,
            # so the document-query embedding cannot be reused here.
            table_query_embedding = table_embedding_function.embed_query(query_text)
            results_table = tdb.similarity_search_by_vector_with_relevance_scores(
                embedding=table_query_embedding,
                k=2,
                # filter=filter_condition  # optional metadata filter
            )
            print("results------------------->", results_table)
            context_text_table = "\n\n---\n\n".join([doc.page_content for doc, _score in results_table])

            # Build the prompt with both document and table context
            prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE_TAB)
            prompt = prompt_template.format(context=context_text_document, table=context_text_table, question=query_text)
            print("results------------------->", prompt)
        else:
            # Build the prompt with document context only
            prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE_DOC)
            prompt = prompt_template.format(context=context_text_document, question=query_text)
            print("results------------------->", prompt)
        # Define the model and query it
        repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
        HFT = os.environ["HF_TOKEN"]
        llm = HuggingFaceEndpoint(
            repo_id=repo_id,
            max_new_tokens=2000,
            temperature=0.8,
            huggingfacehub_api_token=HFT,
        )
        data = llm.invoke(prompt)
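        # HuggingFaceEndpoint.invoke() sends the prompt to the Hugging Face
        # Inference API and returns the generated completion as a plain string.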
        # Filter out answers where the model says the context lacks the information
        if re.search(r'\bmention\b|\bnot mention\b|\bnot mentioned\b|\bnot contain\b|\bnot include\b|\bnot provide\b|\bdoes not\b|\bnot explicitly\b|\bnot explicitly mentioned\b', data, re.IGNORECASE):
            data = "We do not have information related to your query on our end."
        # Save the query and answer to the session history
        session['history'].append((query_text, data))
        # Mark the session as modified so Flask persists the mutated list
        session.modified = True
        print("sessionhist2", session['history'])
        return render_template('chat.html', query_text=query_text, answer=data, history=session['history'], old_db=CHROMA_PATH)

    return render_template('chat.html', history=session['history'], old_db=CHROMA_PATH)
@app.route('/create-db', methods=['GET', 'POST'])
def create_db():
    if request.method == 'POST':
        db_name = request.form.get('db_name', '').strip()
        if not db_name:
            return "Database name is required", 400

        # Get uploaded files
        files = request.files.getlist('folder')        # folder uploads (multiple files)
        single_files = request.files.getlist('file')   # single file uploads
        print("==================folder==>", files)
        print("==================single_files==>", single_files)

        # Ensure at least one valid file was uploaded
        if not any(file.filename.strip() for file in files) and not any(file.filename.strip() for file in single_files):
            return "No files uploaded", 400

        # Create the upload directory for this database
        upload_base_path = os.path.join(app.config['UPLOAD_FOLDER'], secure_filename(db_name))
        print(f"Base Upload Path: {upload_base_path}")
        os.makedirs(upload_base_path, exist_ok=True)

        # Process single-file uploads first (if any exist)
        if any(file.filename.strip() for file in single_files):
            for file in single_files:
                if file.filename.strip():  # skip empty file fields
                    file_name = secure_filename(file.filename)
                    file_path = os.path.join(upload_base_path, file_name)
                    print(f"Saving single file to: {file_path}")
                    file.save(file_path)
            # If a single file was uploaded, skip folder processing
            print("Single file uploaded, skipping folder processing.")
            asyncio.run(generate_data_store(upload_base_path, db_name))
            return redirect(url_for('list_dbs'))

        # Process folder files only if valid files exist
        if any(file.filename.strip() for file in files):
            for file in files:
                if file.filename.strip():  # skip empty file fields
                    file_name = secure_filename(file.filename)
                    file_path = os.path.join(upload_base_path, file_name)
                    print(f"Saving folder file to: {file_path}")
                    file.save(file_path)
            # Generate the datastore from the uploaded folder
            asyncio.run(generate_data_store(upload_base_path, db_name))
            return redirect(url_for('list_dbs'))

    return render_template('create_db.html')
@app.route('/list-dbs')
def list_dbs():
    vector_dbs = [name for name in os.listdir(VECTOR_DB_FOLDER) if os.path.isdir(os.path.join(VECTOR_DB_FOLDER, name))]
    return render_template('list_dbs.html', vector_dbs=vector_dbs)
@app.route('/select-db/<db_name>')
def select_db(db_name):
    flash(f"{db_name} Database has been selected", "table_selected")

    # Select the document vector DB
    global CHROMA_PATH
    global TABLE_PATH
    CHROMA_PATH = os.path.join(VECTOR_DB_FOLDER, db_name).replace("\\", "/")
    print(f"Selected DB: {CHROMA_PATH}")

    # Select the matching table vector DB, if one exists
    table_db_path = os.path.join(TABLE_DB_FOLDER, db_name).replace("\\", "/")
    TABLE_PATH = table_db_path if os.path.exists(table_db_path) else None
    print(f"Selected Table DB: {TABLE_PATH}")
    return redirect(url_for('chat'))
@app.route('/update-db/<db_name>', methods=['GET', 'POST'])
def update_db(db_name):
    if request.method == 'POST':
        db_name = request.form['db_name']

        # Get all files from the uploaded folder
        files = request.files.getlist('folder')
        if not any(file.filename.strip() for file in files):
            return "No files uploaded", 400
        print(f"Selected DB: {db_name}")

        DB_PATH = os.path.join(VECTOR_DB_FOLDER, db_name).replace("\\", "/")
        print(f"Selected DB: {DB_PATH}")
        # generate_data_store is awaited via asyncio.run in create_db(), so
        # run the coroutine the same way here instead of calling it bare
        asyncio.run(generate_data_store(DB_PATH, db_name))
        return redirect(url_for('list_dbs'))

    return render_template('update_db.html')
if __name__ == "__main__": | |
app.run(debug=False, use_reloader=False) | |