# Phoenix21's picture
# New version
# 189f4f3
# raw
# history blame
# 9.91 kB
import functools
import json
import logging
import os
import re

import chardet
import gradio as gr
import pandas as pd
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
# Logging: configure the root logger at INFO level, then take a logger named
# after this module for all messages emitted below.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# Function to clean the API key
def clean_api_key(key):
return ''.join(c for c in key if ord(c) < 128)
# Load the GROQ API key from environment variables (set as a secret in the Space).
api_key = os.getenv("GROQ_API_KEY")
if api_key:
    # Drop non-ASCII characters and surrounding whitespace from the secret.
    api_key = clean_api_key(api_key).strip()
if not api_key:
    # Validated *after* cleaning: a secret consisting only of whitespace or
    # non-ASCII characters would otherwise pass the check and become "".
    logger.error("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
    raise ValueError("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
def clean_text(text):
    """Return *text* with all non-ASCII characters dropped.

    ASCII characters are preserved in order; everything else is discarded
    (not transliterated).
    """
    ascii_only = text.encode("ascii", "ignore")
    return ascii_only.decode("ascii")
# Document loading: one helper per supported file format, plus a dispatcher.
def _load_csv(file_path):
    """Return one Document per CSV row, decoded with the encoding chardet detects."""
    with open(file_path, "rb") as f:
        # chardet may report None; pandas then falls back to its default encoding.
        encoding = chardet.detect(f.read())["encoding"]
    data = pd.read_csv(file_path, encoding=encoding)
    return [
        Document(page_content=clean_text(row.to_string()), metadata={"source": file_path})
        for _, row in data.iterrows()
    ]


def _load_json(file_path):
    """Return one Document per top-level list entry, or one Document for a top-level object."""
    # utf-8-sig also accepts a UTF-8 byte-order mark, which json.load rejects;
    # errors="replace" keeps one bad byte from discarding the file (the U+FFFD
    # replacement characters are non-ASCII and are removed by clean_text).
    with open(file_path, "r", encoding="utf-8-sig", errors="replace") as f:
        data = json.load(f)
    if isinstance(data, dict):
        entries = [data]
    elif isinstance(data, list):
        entries = data
    else:
        logger.warning(f"Unsupported JSON structure in {file_path}: expected a list or an object.")
        return []
    # ensure_ascii=False: with the default, json.dumps turns non-ASCII characters
    # into literal "\uXXXX" escape text, which makes clean_text a no-op and
    # indexes the escape sequences instead of stripping the characters.
    return [
        Document(
            page_content=clean_text(json.dumps(entry, ensure_ascii=False)),
            metadata={"source": file_path},
        )
        for entry in entries
    ]


def _load_txt(file_path):
    """Return the whole text file as a single Document."""
    # errors="replace": one undecodable byte no longer discards the whole file;
    # the replacement characters are non-ASCII, so clean_text removes them.
    with open(file_path, "r", encoding="utf-8", errors="replace") as f:
        content = clean_text(f.read())
    return [Document(page_content=content, metadata={"source": file_path})]


# Lower-cased file extension -> helper that turns such a file into Documents.
_DOCUMENT_LOADERS = {
    ".csv": _load_csv,
    ".json": _load_json,
    ".txt": _load_txt,
}


def load_documents(file_paths):
    """Load every supported file and return its cleaned content as Documents.

    Args:
        file_paths: iterable of paths (a single path is also accepted).
            Supported extensions are .csv (one Document per row), .json (one
            per list entry, or one for an object) and .txt (one per file).

    Returns:
        A list of Documents whose ``metadata["source"]`` is the originating path.
        Files with an unsupported extension are skipped with a warning. Files
        that fail to load are logged and skipped, and the remaining files are
        still processed.
    """
    if isinstance(file_paths, (str, os.PathLike)):
        # A lone path would otherwise be iterated character by character.
        file_paths = [file_paths]
    docs = []
    for file_path in file_paths:
        try:
            ext = os.path.splitext(file_path)[-1].lower()
            loader = _DOCUMENT_LOADERS.get(ext)
            if loader is None:
                logger.warning(f"Unsupported file format: {file_path}")
                continue
            docs.extend(loader(file_path))
        except Exception as e:
            # Best effort: one unreadable file must not prevent the others from loading.
            logger.error(f"Error processing file {file_path}: {e}")
            logger.debug("Exception details:", exc_info=True)
    return docs
# Matches from the start of the text through its LAST sentence terminator
# (greedy ".*" backtracks to it), plus any closing quotes/brackets right after it.
_COMPLETE_PREFIX_RE = re.compile(r"""^.*[.!?]["'\u201d\u2019)\]]*""", re.DOTALL)


def ensure_complete_sentences(text):
    """Trim *text* so it ends at its last complete sentence.

    The text is cut right after the final ``.``, ``!`` or ``?`` (keeping any
    closing quotes or brackets that follow it), and surrounding whitespace is
    stripped. Everything before the cut is returned verbatim, so decimals such
    as ``7.5``, spacing and line breaks are preserved. The previous
    split-and-rejoin approach turned "7.5" into "7. 5" and doubled the spaces
    between sentences.

    Returns *text* unchanged when it is empty or contains no terminator.
    """
    if not text:
        return text
    match = _COMPLETE_PREFIX_RE.match(text)
    if match is None:
        return text
    return match.group(0).strip()
def is_valid_input(text):
    """Return True when *text* looks like a meaningful question.

    Valid input is non-empty, has at least 5 characters once surrounding
    whitespace is removed, and contains at least one ASCII letter (A-Z, a-z).
    """
    if not text:
        return False
    trimmed = text.strip()
    has_letter = re.search(r"[A-Za-z]", text) is not None
    return has_letter and len(trimmed) >= 5
def initialize_llm(model, temperature, max_tokens):
    """Create a ChatGroq chat model for *model* at *temperature*.

    int(max_tokens * 0.2) tokens are held back as a reserve labelled for the
    prompt, and only the remainder is passed to ChatGroq as ``max_tokens``.
    NOTE(review): if ChatGroq's max_tokens bounds only the completion, this
    reserve is unnecessary. Confirm before changing it.

    Raises:
        ValueError: if the remaining budget is 50 tokens or fewer.
        Any exception from ChatGroq construction is logged and re-raised.
    """
    try:
        reserved_for_prompt = int(max_tokens * 0.2)
        completion_budget = max_tokens - reserved_for_prompt
        if not completion_budget > 50:
            raise ValueError("max_tokens is too small to allocate for the response.")
        chat_model = ChatGroq(
            model=model,
            temperature=temperature,
            max_tokens=completion_budget,
            api_key=api_key,
        )
        logger.info("LLM initialized successfully.")
        return chat_model
    except Exception as err:
        logger.error(f"Error initializing LLM: {err}")
        raise
# Prompt for the "stuff" QA chain: {context} receives the retrieved chunks and
# {question} the user's query.
_WELLNESS_PROMPT_TEMPLATE = """
You are an AI assistant with expertise in daily wellness. Your aim is to provide detailed and comprehensive solutions regarding daily wellness topics without unnecessary verbosity.
Context:
{context}
Question:
{question}
Provide a thorough and complete answer, including relevant examples and a suggested schedule. Ensure that the response does not end abruptly.
"""


def _build_retriever(docs):
    """Split *docs* into chunks, embed them into a fresh Chroma store and return its retriever."""
    import tempfile

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    # A fresh temporary directory for every build. Reusing a fixed path such as
    # /tmp/chroma_db reopens the existing collection, and from_documents appends
    # the same chunks again on each run, so retrieval returns duplicates.
    persist_directory = tempfile.mkdtemp(prefix="chroma_db_")
    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embedding_model,
        persist_directory=persist_directory,
    )
    vectorstore.persist()  # Save the database to disk
    logger.info("Vectorstore initialized and persisted successfully.")
    return vectorstore.as_retriever()


def create_rag_pipeline(file_paths, model, temperature, max_tokens):
    """Build a RetrievalQA ("stuff") chain over the documents in *file_paths*.

    Args:
        file_paths: paths passed to ``load_documents``.
        model, temperature, max_tokens: forwarded to ``initialize_llm``.

    Returns:
        ``(chain, "Pipeline created successfully.")`` on success. Returns
        ``(None, reason)`` when no documents load or any step fails; failures
        are logged rather than raised.
    """
    try:
        llm = initialize_llm(model, temperature, max_tokens)
        docs = load_documents(file_paths)
        if not docs:
            logger.warning("No documents were loaded. Please check your file paths and formats.")
            return None, "No documents were loaded. Please check your file paths and formats."
        retriever = _build_retriever(docs)
        prompt = PromptTemplate(
            input_variables=["context", "question"],
            template=_WELLNESS_PROMPT_TEMPLATE,
        )
        chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt},
        )
        logger.info("RAG pipeline created successfully.")
        return chain, "Pipeline created successfully."
    except Exception as e:
        logger.error(f"Error creating RAG pipeline: {e}")
        logger.debug("Exception details:", exc_info=True)
        return None, f"Error creating RAG pipeline: {e}"
# Build the RAG pipeline once, at import time. These module-level defaults are
# also what the Gradio controls are initialised with.
file_paths = ["AIChatbot.csv"]
model, temperature, max_tokens = "llama3-8b-8192", 0.7, 500
rag_chain, message = create_rag_pipeline(
    file_paths,
    model,
    temperature,
    max_tokens,
)
if rag_chain is None:
    # create_rag_pipeline has already logged the specific cause (also held in `message`).
    logger.error("Failed to initialize RAG pipeline at startup.")
# Question answering: honour the UI's LLM settings, validate input, post-process output.
@functools.lru_cache(maxsize=8)
def _chain_with_llm_settings(llm_model, llm_temperature, llm_max_tokens):
    """Return a RetrievalQA chain that reuses the startup retriever and prompt with a new LLM.

    Only the LLM is rebuilt, not the vectorstore. Results are cached for up to 8
    setting combinations; exceptions (e.g. from initialize_llm) are not cached.
    """
    llm = initialize_llm(llm_model, llm_temperature, llm_max_tokens)
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=rag_chain.retriever,
        chain_type_kwargs={"prompt": rag_chain.combine_documents_chain.llm_chain.prompt},
    )


def _chain_for_request(requested_model, requested_temperature, requested_max_tokens):
    """Return the startup chain when the requested settings equal the defaults, else a variant."""
    # An empty model box falls back to the startup model. The numeric settings
    # are normalised (sliders may deliver floats) before comparing and caching.
    settings = (
        (requested_model or "").strip() or model,
        float(requested_temperature),
        int(requested_max_tokens),
    )
    if settings == (model, temperature, max_tokens):
        return rag_chain
    return _chain_with_llm_settings(*settings)


def answer_question(model, temperature, max_tokens, question):
    """Answer *question* with the RAG pipeline using the requested LLM settings.

    Args:
        model: Groq model name for this request (blank falls back to the default).
        temperature: sampling temperature for this request.
        max_tokens: token budget forwarded to ``initialize_llm``.
        question: the user's question.

    Returns:
        The answer trimmed to complete sentences, or a user-facing message when
        the input is invalid, the pipeline is unavailable, or the call fails.
    """
    if not is_valid_input(question):
        logger.info("Received invalid input from user.")
        return "Please provide a valid question or input containing meaningful text."
    if rag_chain is None:
        logger.error("RAG pipeline is not initialized.")
        return "The system is currently unavailable. Please try again later."
    try:
        chain = _chain_for_request(model, temperature, max_tokens)
        answer = chain.run(question)
        logger.info("Question answered successfully.")
        # Post-process so the answer ends with a complete sentence.
        return ensure_complete_sentences(answer)
    except Exception as e_inner:
        logger.error(f"Error during RAG pipeline execution: {e_inner}")
        logger.debug("Exception details:", exc_info=True)
        return f"Error during RAG pipeline execution: {e_inner}"
def gradio_interface(model, temperature, max_tokens, question):
    """Gradio callback: pass the four UI values straight through to answer_question."""
    answer = answer_question(
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        question=question,
    )
    return answer
# Gradio UI. The four input components map positionally onto
# gradio_interface(model, temperature, max_tokens, question).
_model_box = gr.Textbox(
    label="Model Name",
    value=model,
    placeholder="e.g., llama3-8b-8192",
)
_temperature_slider = gr.Slider(
    label="Temperature",
    minimum=0,
    maximum=1,
    step=0.01,
    value=temperature,
    info="Controls the randomness of the response. Higher values make output more random.",
)
_max_tokens_slider = gr.Slider(
    label="Max Tokens",
    minimum=200,
    maximum=2048,
    step=1,
    value=max_tokens,
    info="Determines the maximum number of tokens in the response.",
)
_question_box = gr.Textbox(
    label="Question",
    placeholder="e.g., What is box breathing and how does it help reduce anxiety?",
)
# Each example row supplies (model, temperature, max_tokens, question).
_examples = [
    ["llama3-8b-8192", 0.7, 500, "What is box breathing and how does it help reduce anxiety?"],
    ["llama3-8b-8192", 0.6, 600, "Provide a daily wellness schedule incorporating box breathing techniques."],
]
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[_model_box, _temperature_slider, _max_tokens_slider, _question_box],
    outputs="text",
    title="Daily Wellness AI",
    description="Ask questions about daily wellness and get detailed solutions.",
    examples=_examples,
    allow_flagging="never",
)
if __name__ == "__main__":
    # Serve on every network interface at port 7860.
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
    )