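"""Daily Wellness AI: a retrieval-augmented (RAG) question-answering app.

Loads wellness documents (CSV/JSON/TXT) into a Chroma vector store, answers
questions with a Groq-hosted LLM through LangChain's RetrievalQA chain, and
serves the result via a Gradio interface.
"""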
import os
import logging
import re
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
import chardet
import gradio as gr
import pandas as pd
import json
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def clean_api_key(key):
    """Strip non-ASCII characters that can corrupt HTTP headers."""
    return ''.join(c for c in key if ord(c) < 128)
# Load the GROQ API key from the environment
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
    raise ValueError("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
api_key = clean_api_key(api_key).strip()
def clean_text(text):
    """Drop non-ASCII characters so downstream tokenization stays clean."""
    return text.encode("ascii", errors="ignore").decode()
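# Illustrative example: clean_text("café") returns "caf" — the accented
# character is silently dropped rather than raising an encoding error.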
def load_documents(file_paths):
    """Load CSV, JSON, and TXT files into LangChain Document objects."""
    docs = []
    for file_path in file_paths:
        ext = os.path.splitext(file_path)[-1].lower()
        try:
            if ext == ".csv":
                # Sniff the encoding first; chardet may return None, so fall back to UTF-8
                with open(file_path, 'rb') as f:
                    result = chardet.detect(f.read())
                encoding = result['encoding'] or 'utf-8'
                data = pd.read_csv(file_path, encoding=encoding)
                for _, row in data.iterrows():
                    content = clean_text(row.to_string())
                    docs.append(Document(page_content=content, metadata={"source": file_path}))
            elif ext == ".json":
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if isinstance(data, list):
                    for entry in data:
                        content = clean_text(json.dumps(entry))
                        docs.append(Document(page_content=content, metadata={"source": file_path}))
                elif isinstance(data, dict):
                    content = clean_text(json.dumps(data))
                    docs.append(Document(page_content=content, metadata={"source": file_path}))
            elif ext == ".txt":
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = clean_text(f.read())
                docs.append(Document(page_content=content, metadata={"source": file_path}))
            else:
                logger.warning(f"Unsupported file format: {file_path}")
        except Exception as e:
            logger.error(f"Error processing file {file_path}: {e}")
    return docs
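# For reference: load_documents(["AIChatbot.csv"]) yields one Document per CSV
# row (the row rendered via pandas' to_string()), each tagged with
# metadata={"source": "AIChatbot.csv"} so answers can be traced back to a file.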
# Simple input validation: reject empty or unintelligible input while still
# accepting short topical queries such as "sleep anxiety"
def is_valid_input(text):
    """Validate the user's input question."""
    if not text or text.strip() == "":
        return False, "Input cannot be empty. Please provide a meaningful question."
    if len(text.strip()) < 2:
        return False, "Input is too short. Please provide more context or details."
    # Require at least one recognizable word
    words = re.findall(r'\b\w+\b', text)
    if len(words) < 1:
        return False, "Input appears incomplete. Please provide a meaningful question."
    return True, "Valid input."
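# Quick sanity checks (illustrative only):
#   is_valid_input("sleep anxiety")  -> (True, "Valid input.")
#   is_valid_input("   ")            -> (False, "Input cannot be empty. ...")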
def initialize_llm(model, temperature, max_tokens):
    """Build a ChatGroq client, reserving ~20% of max_tokens for the prompt."""
    prompt_allocation = int(max_tokens * 0.2)
    response_max_tokens = max_tokens - prompt_allocation
    if response_max_tokens <= 50:
        raise ValueError("max_tokens is too small; the response budget left after the 20% prompt allocation must exceed 50 tokens.")
    llm = ChatGroq(
        model=model,
        temperature=temperature,
        max_tokens=response_max_tokens,
        api_key=api_key
    )
    return llm
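# Worked example: initialize_llm("llama3-8b-8192", 0.7, 500) reserves
# int(500 * 0.2) = 100 tokens for the prompt and caps the response at
# 500 - 100 = 400 tokens.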
def create_rag_pipeline(file_paths, model, temperature, max_tokens):
    llm = initialize_llm(model, temperature, max_tokens)
    docs = load_documents(file_paths)
    if not docs:
        return None, "No documents were loaded."
    # Split documents into overlapping chunks so retrieval stays within context limits
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embedding_model,
        persist_directory="/tmp/chroma_db"
    )
    retriever = vectorstore.as_retriever()
    custom_prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
You are an AI assistant specialized in daily wellness. Provide a concise, thorough, and stand-alone answer to the user's question based on the given context. Include relevant examples or schedules where beneficial. **When listing steps or guidelines, format them as a numbered list with appropriate markdown formatting.** The final answer should be coherent, self-contained, and end with a complete sentence.

Context:
{context}

Question:
{question}

Final Answer:
"""
    )
    # "stuff" chain type: all retrieved chunks are placed into a single prompt
    rag_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": custom_prompt_template}
    )
    return rag_chain, "Pipeline created successfully."
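# Note: as_retriever() with no arguments uses similarity search with LangChain's
# default chunk count (k=4 at the time of writing); pass search_kwargs={"k": ...}
# to tune how much context reaches the prompt.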
# Build the pipeline once at startup with the default configuration
file_paths = ['AIChatbot.csv']
model = "llama3-8b-8192"
temperature = 0.7
max_tokens = 500

rag_chain, message = create_rag_pipeline(file_paths, model, temperature, max_tokens)
logger.info(message)
def answer_question(model, temperature, max_tokens, question):
    # Note: model/temperature/max_tokens mirror the UI controls, but the chain
    # is built once at startup, so changing them does not affect this call.
    is_valid, validation_message = is_valid_input(question)
    if not is_valid:
        return validation_message
    if rag_chain is None:
        return "The system is currently unavailable. Please try again later."
    try:
        answer = rag_chain.run(question)
        return answer.strip()
    except Exception as e_inner:
        logger.error(f"Error while answering question: {e_inner}")
        return "An error occurred while processing your request."
def gradio_interface(model, temperature, max_tokens, question):
    return answer_question(model, temperature, max_tokens, question)
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Model Name", value=model),
        gr.Slider(label="Temperature", minimum=0, maximum=1, step=0.01, value=temperature),
        gr.Slider(label="Max Tokens", minimum=200, maximum=2048, step=1, value=max_tokens),
        gr.Textbox(label="Question", placeholder="e.g., What is box breathing and how does it help reduce anxiety?")
    ],
    outputs=gr.Markdown(label="Answer"),
    title="Daily Wellness AI",
    description="Ask questions about daily wellness and receive a concise, complete answer.",
    examples=[
        ["llama3-8b-8192", 0.7, 500, "What is box breathing and how does it help reduce anxiety?"],
        ["llama3-8b-8192", 0.6, 600, "Give me a weekly fitness schedule incorporating mindfulness exercises."]
    ],
    allow_flagging="never"
)
if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0", server_port=7860, debug=True)