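# Page-header residue from the file listing, kept as a comment so that the
# module parses as Python:
#   Phoenix21's picture | revised app.py | f44a0c0 | raw | history blame | 8.17 kB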
import os
import logging
from langchain.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
import chardet
import gradio as gr
import pandas as pd
import json
import re
# Enable logging for debugging
# Configure the root logger at DEBUG level so debug messages are emitted.
logging.basicConfig(level=logging.DEBUG)
# Module-level logger shared by the functions below.
logger = logging.getLogger(__name__)
# Function to clean the API key
def clean_api_key(key):
    """Return ``key`` with every non-ASCII character removed.

    Only characters whose code point is below 128 are kept, in their
    original order. Whitespace is not touched.
    """
    ascii_chars = [ch for ch in key if ord(ch) <= 127]
    return ''.join(ascii_chars)
# Load the GROQ API key from environment variables (set as a secret in the Space)
# Normalise the secret *before* validating it: a value made up only of
# whitespace or non-ASCII characters is empty once cleaned, and must be
# rejected here rather than handed to the client as an empty key.
api_key = clean_api_key(os.getenv("GROQ_API_KEY") or "").strip()
if not api_key:
    logger.error("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
    raise ValueError("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
# Function to clean text by removing non-ASCII characters
def clean_text(text):
    """Return ``text`` with every character that is not ASCII dropped."""
    # Encoding with the "ignore" handler silently skips unencodable characters.
    ascii_bytes = text.encode("ascii", "ignore")
    return ascii_bytes.decode()
# Function to load and clean documents from multiple file formats
def load_documents(file_paths):
    """Read every path in ``file_paths`` and convert its content to Documents.

    The file type is chosen from the lower-cased extension:

    * ``.csv``  - the encoding is detected with chardet, the file is read
      with pandas, and each row becomes one Document (``row.to_string()``).
    * ``.json`` - read as UTF-8; a top-level list yields one Document per
      entry, a top-level dict yields a single Document, anything else
      yields nothing.
    * ``.txt``  - read as UTF-8; the whole file becomes one Document.

    All text is passed through ``clean_text`` and tagged with
    ``{"source": file_path}``. Unsupported extensions are logged as a
    warning. Any exception raised while handling a file is logged and the
    remaining files are still processed; Documents collected before the
    failure are kept.

    Returns:
        list of Document objects, possibly empty.
    """
    documents = []

    def _collect(raw_text, origin):
        # Strip non-ASCII characters, then record the text with its source.
        documents.append(
            Document(page_content=clean_text(raw_text), metadata={"source": origin})
        )

    for file_path in file_paths:
        suffix = os.path.splitext(file_path)[-1].lower()
        try:
            if suffix == ".csv":
                # Sniff the encoding from the raw bytes before parsing.
                with open(file_path, 'rb') as raw_file:
                    detected = chardet.detect(raw_file.read())
                frame = pd.read_csv(file_path, encoding=detected['encoding'])
                for _, row in frame.iterrows():
                    _collect(row.to_string(), file_path)
            elif suffix == ".json":
                with open(file_path, 'r', encoding='utf-8') as json_file:
                    payload = json.load(json_file)
                if isinstance(payload, list):
                    for entry in payload:
                        _collect(json.dumps(entry), file_path)
                elif isinstance(payload, dict):
                    _collect(json.dumps(payload), file_path)
            elif suffix == ".txt":
                with open(file_path, 'r', encoding='utf-8') as text_file:
                    _collect(text_file.read(), file_path)
            else:
                logger.warning(f"Unsupported file format: {file_path}")
        except Exception as e:
            # Best effort: report the failure and carry on with the next file.
            logger.error(f"Error processing file {file_path}: {e}")
    return documents
# Function to ensure the response ends with complete sentences
def ensure_complete_sentences(text):
    """Trim ``text`` so that it ends on its last complete sentence.

    Everything up to and including the last sentence terminator
    (``.``, ``!`` or ``?``) is kept exactly as written - spacing, line
    breaks and in-word punctuation such as ``2.5`` are not altered - and
    the trailing unfinished fragment is dropped. The kept text is stripped
    of surrounding whitespace.

    Args:
        text: The raw answer string.

    Returns:
        The trimmed string, or ``text`` unchanged when it contains no
        sentence terminator at all.
    """
    # A terminator only counts when it is not glued to a following word
    # character, so the dot in "2.5" or "example.com" does not end a
    # sentence. Closing quotes/brackets directly after the terminator stay
    # attached to the sentence they close.
    sentence_end = re.compile(r'[.!?]["\')\]]*(?!\w)')
    cut = None
    for match in sentence_end.finditer(text):
        cut = match.end()
    if cut is None:
        return text
    return text[:cut].strip()
# Initialize the LLM using ChatGroq with GROQ's API
def initialize_llm(model, temperature, max_tokens):
    """Create a ChatGroq client with a reduced response-token limit.

    One fifth of ``max_tokens`` (rounded down) is held back and the
    remainder is passed to ChatGroq as its ``max_tokens``.

    Args:
        model: Model name forwarded to ChatGroq.
        temperature: Sampling temperature forwarded to ChatGroq.
        max_tokens: Overall token budget requested by the caller.

    Returns:
        The configured ChatGroq instance.

    Raises:
        ValueError: If the remaining response budget is 50 tokens or fewer.
        Exception: Any error raised while constructing the client is logged
            and re-raised unchanged.
    """
    try:
        reserved_for_prompt = int(max_tokens * 0.2)
        response_max_tokens = max_tokens - reserved_for_prompt
        if response_max_tokens <= 50:
            raise ValueError("max_tokens is too small to allocate for the response.")
        groq_settings = {
            "model": model,
            "temperature": temperature,
            "max_tokens": response_max_tokens,  # budget left after the hold-back
            "api_key": api_key,  # module-level key loaded at import time
        }
        llm = ChatGroq(**groq_settings)
        logger.debug("LLM initialized successfully.")
        return llm
    except Exception as e:
        logger.error(f"Error initializing LLM: {e}")
        raise
# Create the RAG pipeline
# Retrievers already built by this process, keyed by (path, mtime) pairs.
_RETRIEVER_CACHE = {}


def _fingerprint_sources(file_paths):
    """Return a hashable key that changes when any source file is modified."""
    stamps = []
    for path in file_paths:
        try:
            stamps.append((path, os.path.getmtime(path)))
        except OSError:
            # Missing/unreadable files are reported by load_documents; they
            # only need a stable placeholder in the key.
            stamps.append((path, None))
    return tuple(stamps)


def _get_retriever(file_paths):
    """Return a retriever over ``file_paths``, or None when nothing loads."""
    cache_key = _fingerprint_sources(file_paths)
    cached = _RETRIEVER_CACHE.get(cache_key)
    if cached is not None:
        return cached

    docs = load_documents(file_paths)
    if not docs:
        return None

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)
    # Initialize the embedding model
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    # Use a persistent database for Chroma
    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embedding_model,
        persist_directory="./chroma_db",  # Specify persistent storage directory
    )
    vectorstore.persist()  # Save the database to disk
    logger.debug("Vectorstore initialized and persisted successfully.")

    retriever = vectorstore.as_retriever()
    # Only successful builds are cached, so a failed load is retried next time.
    _RETRIEVER_CACHE[cache_key] = retriever
    return retriever


def create_rag_pipeline(file_paths, model, temperature, max_tokens):
    """Build a RetrievalQA chain over the documents in ``file_paths``.

    The LLM is created first, then a retriever over the (split and
    embedded) documents, and finally a "stuff" RetrievalQA chain using the
    daily-wellness prompt.

    The retriever is built once per set of source files (and rebuilt when a
    file's modification time changes) instead of on every call. Without
    this, each call re-loaded, re-embedded and re-inserted the entire
    corpus into the persistent ``./chroma_db`` store.
    NOTE(review): the persisted directory is still written to once per
    process start; confirm whether it should be cleared or given stable
    document ids to avoid duplicates across restarts.

    Args:
        file_paths: Paths of the source files to index.
        model: Model name forwarded to ``initialize_llm``.
        temperature: Sampling temperature forwarded to ``initialize_llm``.
        max_tokens: Token budget forwarded to ``initialize_llm``.

    Returns:
        ``(rag_chain, message)``. On success ``rag_chain`` is the chain and
        ``message`` is "Pipeline created successfully."; on failure
        ``rag_chain`` is None and ``message`` describes the problem. No
        exception escapes this function.
    """
    try:
        llm = initialize_llm(model, temperature, max_tokens)
        retriever = _get_retriever(file_paths)
        if retriever is None:
            logger.warning("No documents were loaded. Please check your file paths and formats.")
            return None, "No documents were loaded. Please check your file paths and formats."
        custom_prompt_template = PromptTemplate(
            input_variables=["context", "question"],
            template="""
You are an AI assistant with expertise in daily wellness. Your aim is to provide detailed and comprehensive solutions regarding daily wellness topics without unnecessary verbosity.
Context:
{context}
Question:
{question}
Provide a thorough and complete answer, including relevant examples and a suggested schedule. Ensure that the response does not end abruptly.
"""
        )
        rag_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": custom_prompt_template}
        )
        logger.debug("RAG pipeline created successfully.")
        return rag_chain, "Pipeline created successfully."
    except Exception as e:
        logger.error(f"Error creating RAG pipeline: {e}")
        return None, f"Error creating RAG pipeline: {e}"
# Function to answer questions with post-processing
def answer_question(file_paths, model, temperature, max_tokens, question):
    """Answer ``question`` using a RAG pipeline built from ``file_paths``.

    Returns:
        The chain's answer after it has been passed through
        ``ensure_complete_sentences``. If the pipeline cannot be created,
        the status message from ``create_rag_pipeline`` is returned instead;
        if running the chain or post-processing fails, an error string is
        returned. No exception is raised to the caller.
    """
    pipeline, status = create_rag_pipeline(file_paths, model, temperature, max_tokens)
    if pipeline is None:
        # Pipeline construction failed; ``status`` explains why.
        return status
    try:
        raw_answer = pipeline.run(question)
        logger.debug("Question answered successfully.")
        # Post-process so the reply does not stop mid-sentence.
        return ensure_complete_sentences(raw_answer)
    except Exception as e:
        logger.error(f"Error during RAG pipeline execution: {e}")
        return f"Error during RAG pipeline execution: {e}"
# Gradio Interface
def gradio_interface(model, temperature, max_tokens, question):
    """Gradio callback: answer ``question`` from the bundled CSV knowledge base.

    The UI values are forwarded unchanged to ``answer_question`` together
    with the fixed source file list.
    """
    # The knowledge base is fixed; this file must sit next to the app.
    knowledge_files = ['AIChatbot.csv']
    return answer_question(
        knowledge_files,
        model,
        temperature,
        max_tokens,
        question,
    )
# Define Gradio UI
# Gradio UI: four inputs feeding gradio_interface, plain-text output.
interface = gr.Interface(
    title="Daily Wellness AI",
    description="Ask questions about daily wellness and get detailed solutions.",
    fn=gradio_interface,
    # Input order must match gradio_interface's positional parameters:
    # model, temperature, max_tokens, question.
    inputs=[
        gr.Textbox(value="llama3-8b-8192", label="Model Name"),
        gr.Slider(minimum=0, maximum=1, step=0.01, value=0.7, label="Temperature"),
        gr.Slider(minimum=200, maximum=1024, step=1, value=500, label="Max Tokens"),
        gr.Textbox(label="Question"),
    ],
    outputs="text",
)
# Launch Gradio app without share=True (not supported on Hugging Face Spaces)
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860 with Gradio debug mode on.
    interface.launch(
        debug=True,
        server_name="0.0.0.0",
        server_port=7860,
    )