# RAG_AI_V2 / app.py
# Author: WebashalarForML
# Last change: "Update app.py" (commit dfd51c8, verified)
from flask import Flask, render_template, request, redirect, url_for, session, flash
import os
from werkzeug.utils import secure_filename
from retrival import generate_data_store,update_data_store,approximate_bpe_token_counter
from langchain_community.vectorstores import Chroma
import chromadb
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
from huggingface_hub import InferenceClient
from langchain.schema import Document
from langchain_core.documents import Document
from dotenv import load_dotenv
import re
import numpy as np
import glob
import shutil
from werkzeug.utils import secure_filename
import asyncio
import nltk
nltk.download('punkt_tab')
import nltk
nltk.download('averaged_perceptron_tagger_eng')
# Load variables such as HF_TOKEN from a .env file (if one is found) before the
# routes read os.environ; variables already set in the environment take priority.
load_dotenv()

app = Flask(__name__)
# Set the secret key for session management. FLASK_SECRET_KEY may provide a
# fixed key so sessions survive restarts and are valid across worker processes;
# without it a random per-process key is generated (sessions reset on restart).
app.secret_key = os.environ.get("FLASK_SECRET_KEY") or os.urandom(24)
# Configurations
UPLOAD_FOLDER = "uploads/"
VECTOR_DB_FOLDER = "VectorDB/"
TABLE_DB_FOLDER = "TableDB/"
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['DEBUG'] = True
# NOTE(review): the ENV config key is deprecated in recent Flask versions.
app.config['ENV'] = 'development'
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(VECTOR_DB_FOLDER, exist_ok=True)
os.makedirs(TABLE_DB_FOLDER, exist_ok=True)
# Paths of the currently selected document / table vector stores.
# NOTE(review): module-level state, shared by every client of this process.
CHROMA_PATH = None
TABLE_PATH = None
########################################################################################################################################################
####----------------------------------------------------------------- Prompt Templates ------------------------------------------------------------####
########################################################################################################################################################
# Building blocks shared by both Mistral-instruct prompt variants below.
_PROMPT_HEADER = (
    "\n<s>[INST] You are a retrieval-augmented generation (RAG) assistant. "
    "Your task is to generate a response strictly based on the given context. "
    "Follow these instructions:\n"
)
_PROMPT_SHARED_RULES = (
    "- Use only the provided context; do not add external information.\n"
    '- The context contains multiple retrieved chunks separated by "###". '
    "Choose only the most relevant chunks to answer the question and ignore unrelated ones.\n"
    "- If available, use the provided source information to support the response.\n"
)
_PROMPT_FOOTER = "Question:\n{question}\nResponse:\n[/INST]\n"

# Prompt for a plain document (placeholders: {context}, {question}).
PROMPT_TEMPLATE_DOC = (
    _PROMPT_HEADER
    + _PROMPT_SHARED_RULES
    + "- Answer concisely and factually.\n"
    + "Context:\n{context}\n---\n"
    + _PROMPT_FOOTER
)

# Prompt for a document that also has extracted tables
# (placeholders: {context}, {table}, {question}).
PROMPT_TEMPLATE_TAB = (
    _PROMPT_HEADER
    + _PROMPT_SHARED_RULES
    + "- If a table is provided as html, incorporate its relevant details into the response while maintaining a structured format.\n"
    + "- Answer concisely and factually.\n"
    + "Context:\n{context}\n---\n"
    + "Table:\n{table}\n---\n"
    + _PROMPT_FOOTER
)
########################################################################################################################################################
####--------------------------------------------------------------- Flask APP ROUTES --------------------------------------------------------------####
########################################################################################################################################################
@app.route('/', methods=['GET'])
def home():
    """Serve the landing page."""
    landing_template = 'home.html'
    return render_template(landing_template)
########################################################################################################################################################
####---------------------------------------------------------------- routes for chat --------------------------------------------------------------####
########################################################################################################################################################
# Answers matching any of these phrases are treated as "the retrieved context
# did not contain the answer" and replaced by _NO_INFO_ANSWER. Compiled once
# instead of on every request.
_NO_INFO_PATTERN = re.compile(
    r'\bmention\b|\bnot mention\b|\bnot mentioned\b|\bnot contain\b|\bnot include\b|\bnot provide\b|\bdoes not\b|\bnot explicitly\b|\bnot explicitly mentioned\b',
    re.IGNORECASE,
)
_NO_INFO_ANSWER = "We do not have information related to your query on our end."


def _retrieve_document_context(query_text):
    """Return the 3 best-matching chunks of the selected document DB as one string.

    Each chunk is rendered as "Source: <source> Page_content:<text>" and the
    chunks are joined with the "###" separator the prompt templates describe.
    """
    embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)
    query_embedding = embedding_function.embed_query(query_text)
    if isinstance(query_embedding, float):
        query_embedding = [query_embedding]
    results = db.similarity_search_by_vector_with_relevance_scores(
        embedding=query_embedding,
        k=3,
    )
    app.logger.debug("document results: %s", results)
    return " \n\n###\n\n ".join(
        f"Source: {doc.metadata.get('source', '')} Page_content:{doc.page_content}\n"
        for doc, _score in results
    )


def _retrieve_table_context(query_text):
    """Return the 2 best-matching chunks of the selected table DB, joined by "---"."""
    embedding_function = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
    tdb = Chroma(persist_directory=TABLE_PATH, embedding_function=embedding_function)
    # Let the table store embed the query with its own embedding function.
    # Reusing the MiniLM document embedding would query with a vector from a
    # different model (and a different vector size).
    # NOTE(review): assumes the table store was built with mxbai-embed-large-v1,
    # as the embedding function declared here implies - confirm in retrival.
    results = tdb.similarity_search_with_score(query_text, k=2)
    app.logger.debug("table results: %s", results)
    return "\n\n---\n\n".join(doc.page_content for doc, _score in results)


def _generate_answer(prompt):
    """Send *prompt* to Mistral-7B-Instruct via HuggingFaceEndpoint and post-filter the reply.

    Raises:
        RuntimeError: if the HF_TOKEN environment variable is not set.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError("HF_TOKEN environment variable is not set")
    llm = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.3",
        max_new_tokens=2000,
        task="text-generation",
        temperature=0.8,
        huggingfacehub_api_token=hf_token,
    )
    answer = llm.invoke(prompt)
    # Replace hedging answers (e.g. "... is not mentioned ...") with a fixed reply.
    # NOTE(review): this also discards valid answers containing e.g. "does not".
    if _NO_INFO_PATTERN.search(answer):
        return _NO_INFO_ANSWER
    return answer


@app.route('/chat', methods=['GET', 'POST'])
def chat():
    """Show the chat page (GET) or answer a question against the selected DB (POST).

    Question/answer pairs are appended to ``session['history']``. Any error is
    flashed and the user is redirected to the database list.
    """
    try:
        if 'history' not in session:
            session['history'] = []
        app.logger.debug("Selected DB: %s", CHROMA_PATH)
        if request.method == 'POST':
            query_text = request.form['query_text']
            if CHROMA_PATH is None:
                flash("Please select a database first!", "error")
                return redirect(url_for('list_dbs'))
            context_text_document = _retrieve_document_context(query_text)
            # The templates already contain Mistral's <s>[INST] ... [/INST]
            # markup, so format them as plain text: ChatPromptTemplate.format()
            # would prepend "Human: " and break that format.
            if TABLE_PATH is not None:
                prompt = PromptTemplate.from_template(PROMPT_TEMPLATE_TAB).format(
                    context=context_text_document,
                    table=_retrieve_table_context(query_text),
                    question=query_text,
                )
            else:
                prompt = PromptTemplate.from_template(PROMPT_TEMPLATE_DOC).format(
                    context=context_text_document,
                    question=query_text,
                )
            app.logger.debug("prompt: %s", prompt)
            data = _generate_answer(prompt)
            # NOTE(review): with Flask's default cookie-based session, a long
            # history can exceed the ~4 KB browser cookie limit and be dropped.
            session['history'].append((query_text, data))
            # Mark the session as modified because the list was mutated in place.
            session.modified = True
            return render_template(
                'chat.html',
                query_text=query_text,
                answer=data,
                token_count=approximate_bpe_token_counter(data),
                history=session['history'],
                old_db=CHROMA_PATH,
            )
    except Exception as e:
        app.logger.exception("chat request failed")
        flash(f"Error while answering the query: {e}", "error")
        return redirect(url_for('list_dbs'))
    return render_template('chat.html', history=session['history'], old_db=CHROMA_PATH)
########################################################################################################################################################
####---------------------------------------------------------------- routes for create-db ---------------------------------------------------------####
########################################################################################################################################################
@app.route('/create-db', methods=['GET', 'POST'])
def create_db():
    """Show the create-DB form (GET) or build a new DB from uploaded files (POST).

    Form fields: ``db_name`` plus files under ``file`` (individual files) and/or
    ``folder`` (a folder upload). When individual files are present, the folder
    upload is ignored. Files are saved under UPLOAD_FOLDER/<db_name> and handed
    to ``generate_data_store``.
    """
    try:
        if request.method == 'POST':
            requested_name = request.form.get('db_name', '').strip()
            if not requested_name:
                return "Database name is required", 400
            # Use one sanitized name for both the upload folder and the data
            # store so the two stay in sync and user input such as "../x" is
            # neutralized (generate_data_store presumably derives the
            # VectorDB/TableDB paths from this name - confirm in retrival).
            db_name = secure_filename(requested_name)
            if not db_name:
                return "Database name is invalid", 400
            # Browsers send an entry with an empty filename for an unused file input.
            folder_files = [f for f in request.files.getlist('folder') if (f.filename or '').strip()]
            single_files = [f for f in request.files.getlist('file') if (f.filename or '').strip()]
            if not folder_files and not single_files:
                return "No files uploaded", 400
            upload_base_path = os.path.join(app.config['UPLOAD_FOLDER'], db_name)
            os.makedirs(upload_base_path, exist_ok=True)
            # Individually selected files take precedence over a folder upload.
            for file in single_files or folder_files:
                file_name = secure_filename(file.filename)
                if not file_name:
                    # Nothing safe is left of the name; saving would target the directory itself.
                    continue
                file_path = os.path.join(upload_base_path, file_name)
                app.logger.info("Saving uploaded file to: %s", file_path)
                file.save(file_path)
            asyncio.run(generate_data_store(upload_base_path, db_name))
            flash(f"{db_name} created successfully!", "success")
            return redirect(url_for('list_dbs'))
    except Exception as e:
        flash(f"Error in Creating DB: {e}", "error")
        return redirect(url_for('list_dbs'))
    return render_template('create_db.html')
########################################################################################################################################################
####------------------------------------------------------- routes for list-dbs and documents -----------------------------------------------------####
########################################################################################################################################################
@app.route('/list-dbs', methods=['GET'])
def list_dbs():
    """List the document databases (sub-directories of VECTOR_DB_FOLDER) alphabetically.

    Redirects to the create-DB page when no database exists yet.
    """
    # os.listdir order is arbitrary; sort for a stable listing.
    vector_dbs = sorted(
        name for name in os.listdir(VECTOR_DB_FOLDER)
        if os.path.isdir(os.path.join(VECTOR_DB_FOLDER, name))
    )
    if not vector_dbs:
        flash("No databases available! Please create a new one.", "error")
        return redirect(url_for('create_db'))
    return render_template('list_dbs.html', vector_dbs=vector_dbs)
@app.route('/select-db/<db_name>', methods=['POST'])
def select_db(db_name):
    """Make ``db_name`` the active document DB (and its table DB, if any), then open the chat.

    Sets the module-level CHROMA_PATH; TABLE_PATH is set only when a matching
    directory exists under TABLE_DB_FOLDER and is None otherwise.
    """
    global CHROMA_PATH
    global TABLE_PATH
    vector_db_path = os.path.join(VECTOR_DB_FOLDER, db_name)
    # Accept only real entries of VECTOR_DB_FOLDER (this also rejects ".."),
    # otherwise the chat would open - and Chroma may create - a store at a bogus path.
    if db_name not in os.listdir(VECTOR_DB_FOLDER) or not os.path.isdir(vector_db_path):
        flash(f"{db_name} Database does not exist", "error")
        return redirect(url_for('list_dbs'))
    # Selecting the Document Vector DB (forward slashes for consistency on Windows).
    CHROMA_PATH = vector_db_path.replace("\\", "/")
    # Selecting the Table Vector DB
    table_db_path = os.path.join(TABLE_DB_FOLDER, db_name).replace("\\", "/")
    TABLE_PATH = table_db_path if os.path.exists(table_db_path) else None
    app.logger.info("Selected DB: %s, Table DB: %s", CHROMA_PATH, TABLE_PATH)
    flash(f"{db_name} Database has been selected", "success")
    return redirect(url_for('chat'))
########################################################################################################################################################
####---------------------------------------------------------- routes for modification of dbs -----------------------------------------------------####
########################################################################################################################################################
@app.route('/modify-dbs/<db_name>', methods=['GET','POST'])
def modify_db(db_name):
    """Render the management page for one database.

    Flashes a confirmation that ``db_name`` is selected and passes the name to
    ``modify_dbs.html``.
    """
    notice = f"{db_name} Database is selected"
    flash(notice, "success")
    print(db_name)
    template_vars = {"db_name": db_name}
    return render_template('modify_dbs.html', **template_vars)
########################################################################################################################################################
####--------------------------------------------------------- routes for updating existing dbs ----------------------------------------------------####
########################################################################################################################################################
@app.route('/update-dbs/<db_name>', methods=['GET','POST'])
def update_db(db_name):
    """Show the update form (GET) or add uploaded files to the existing DB ``db_name`` (POST).

    Files come from the ``file`` (individual files) and/or ``folder`` form
    fields; when individual files are present, the folder upload is ignored.
    Files are saved under UPLOAD_FOLDER/<db_name> and passed to
    ``update_data_store``. On error a message is flashed and the form is shown again.
    """
    try:
        if request.method == 'POST':
            # Only update a DB that exists (this also rejects names like "..").
            if db_name not in os.listdir(VECTOR_DB_FOLDER) or not os.path.isdir(os.path.join(VECTOR_DB_FOLDER, db_name)):
                flash(f"{db_name} Database does not exist", "error")
                return redirect(url_for('list_dbs'))
            # Browsers send an entry with an empty filename for an unused file input.
            folder_files = [f for f in request.files.getlist('folder') if (f.filename or '').strip()]
            single_files = [f for f in request.files.getlist('file') if (f.filename or '').strip()]
            if not folder_files and not single_files:
                return "No files uploaded", 400
            upload_base_path = os.path.join(app.config['UPLOAD_FOLDER'], secure_filename(db_name))
            os.makedirs(upload_base_path, exist_ok=True)
            # Individually selected files take precedence over a folder upload.
            for file in single_files or folder_files:
                file_name = secure_filename(file.filename)
                if not file_name:
                    # Nothing safe is left of the name; saving would target the directory itself.
                    continue
                file_path = os.path.join(upload_base_path, file_name)
                app.logger.info("Saving uploaded file to: %s", file_path)
                file.save(file_path)
            asyncio.run(update_data_store(upload_base_path, db_name))
            # Announce success only after the store update has actually completed.
            flash(f"{db_name} updated successfully!", "success")
            return redirect(url_for('modify_db', db_name=db_name))
    except Exception as e:
        app.logger.exception("updating database %s failed", db_name)
        flash(f"got unexpected error while updating: {e}", "error")
    return render_template('update_db.html', db_name=db_name)
########################################################################################################################################################
####--------------------------------------------------------- routes for removing dbs -------------------------------------------------------------####
########################################################################################################################################################
@app.route('/remove-dbs/<db_name>', methods=['GET','POST'])
def remove_db(db_name):
    """Delete the document DB ``db_name`` and its table DB (if any), then return to the list.

    If the removed DB is the currently selected one, the selection is cleared.

    NOTE(review): this destructive route also answers GET, so a link prefetch
    or crawler can trigger it; consider restricting it to POST once the
    templates have been checked.
    """
    global CHROMA_PATH
    global TABLE_PATH
    # Accept only names that are real entries of the DB folders. Without this,
    # a name like ".." would make rmtree delete the parent of the DB folder.
    if db_name not in os.listdir(VECTOR_DB_FOLDER) and db_name not in os.listdir(TABLE_DB_FOLDER):
        flash(f"{db_name} Database does not exist", "error")
        return redirect(url_for('list_dbs'))
    # Same path form select_db stores in CHROMA_PATH, so the two can be compared.
    vector_db_path = os.path.join(VECTOR_DB_FOLDER, db_name).replace("\\", "/")
    table_db_path = os.path.join(TABLE_DB_FOLDER, db_name).replace("\\", "/")
    try:
        if os.path.exists(vector_db_path):
            shutil.rmtree(vector_db_path)
        if os.path.exists(table_db_path):
            shutil.rmtree(table_db_path)
    except Exception as e:
        app.logger.exception("removing database %s failed", db_name)
        flash(f"Error while removing {db_name} Database: {e}", "error")
        return redirect(url_for('list_dbs'))
    # Forget a selection that pointed at the removed DB; otherwise the next chat
    # request would open (and likely recreate) the deleted directory.
    if CHROMA_PATH == vector_db_path:
        CHROMA_PATH = None
        TABLE_PATH = None
    flash(f"{db_name} Database Removed successfully", "success")
    return redirect(url_for('list_dbs'))
########################################################################################################################################################
####--------------------------------------------------------- routes for deleting specific documents ----------------------------------------------####
########################################################################################################################################################
def _collection_filenames(collection):
    """Return the distinct ``filename`` metadata values of all entries in a Chroma collection."""
    results = collection.get(include=["metadatas"])
    # Entries stored without metadata come back as None, so guard before the key lookup.
    return {meta["filename"] for meta in results["metadatas"] if meta and "filename" in meta}


@app.route('/delete-doc/<db_name>', methods=['GET', 'POST'])
def delete_doc(db_name):
    """List the documents of ``db_name`` (GET) or delete one of them (POST).

    Documents are identified by the ``filename`` metadata of their chunks in the
    "langchain" collection. On POST, the form field ``list_doc`` names the
    document; its chunks are removed from the VectorDB and, when present there,
    from the TableDB. Errors are flashed and redirect to the modify page.
    """
    try:
        # Only open stores that exist: PersistentClient may otherwise create a
        # new store at the given path (including paths such as "VectorDB/..").
        if db_name not in os.listdir(VECTOR_DB_FOLDER):
            flash(f"{db_name} Database does not exist", "error")
            return redirect(url_for('list_dbs'))
        db_path = os.path.join(VECTOR_DB_FOLDER, db_name)
        tab_path = os.path.join(TABLE_DB_FOLDER, db_name)
        collection = chromadb.PersistentClient(path=db_path).get_collection("langchain")
        file_list = _collection_filenames(collection)
        app.logger.debug("file_list: %s", file_list)
        if request.method == 'POST':
            list_doc = request.form.get('list_doc')
            # Refuse unknown/missing names instead of running a no-op delete and reporting success.
            if not list_doc or list_doc not in file_list:
                flash(f"The document '{list_doc}' was not found in VectorDB.", "error")
                return redirect(url_for('modify_db', db_name=db_name))
            # Delete from the VectorDB collection
            collection.delete(where={"filename": list_doc})
            flash(f"The document '{list_doc}' has been removed from VectorDB.", "success")
            # Delete from the TableDB as well, when this DB has one
            if os.path.exists(tab_path):
                collect_tab = chromadb.PersistentClient(path=tab_path).get_collection("langchain")
                if list_doc in _collection_filenames(collect_tab):
                    collect_tab.delete(where={"filename": list_doc})
                    flash(f"The document '{list_doc}' has also been removed from TableDB.", "success")
                else:
                    flash(f"The document '{list_doc}' was not found in TableDB.", "warning")
            else:
                flash(f"TableDB path '{tab_path}' does not exist.", "warning")
            return redirect(url_for('modify_db', db_name=db_name))
        # Sorted so the template shows the documents in a stable order.
        return render_template('delete_doc.html', db_name=db_name, file_list=sorted(file_list, key=str))
    except Exception as e:
        flash(f"Error while deleting documents: {e}", "error")
        return redirect(url_for('modify_db', db_name=db_name))
########################################################################################################################################################
####---------------------------------------------------------------------- App MAIN ---------------------------------------------------------------####
########################################################################################################################################################
if __name__ == "__main__":
    # Start Flask's built-in server with the debugger and the auto-reloader disabled.
    run_options = {"debug": False, "use_reloader": False}
    app.run(**run_options)