import os
import logging
from langchain.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
import chardet
import gradio as gr
import pandas as pd
import json
import re

# Enable logging for debugging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Function to clean the API key by stripping non-ASCII characters
def clean_api_key(key):
    return ''.join(c for c in key if ord(c) < 128)
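
# Illustrative example (hypothetical key): a non-breaking space pasted along
# with the key is outside the ASCII range and gets dropped:
#   clean_api_key("gsk_abc\xa0123")  # -> "gsk_abc123"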

# Load the GROQ API key from environment variables (set as a secret in the Space)
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
    logger.error("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
    raise ValueError("GROQ_API_KEY environment variable is not set. Please add it as a secret.")
api_key = clean_api_key(api_key).strip()  # Clean and strip whitespace

# Function to clean text by removing non-ASCII characters
def clean_text(text):
    return text.encode("ascii", errors="ignore").decode()
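
# Illustrative example: curly quotes and other non-ASCII characters are
# silently dropped, e.g. clean_text("Don\u2019t worry") -> "Dont worry".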

# Function to load and clean documents from multiple file formats
def load_documents(file_paths):
    docs = []
    for file_path in file_paths:
        ext = os.path.splitext(file_path)[-1].lower()
        try:
            if ext == ".csv":
                # CSV files: detect the encoding first, then one Document per row
                with open(file_path, 'rb') as f:
                    result = chardet.detect(f.read())
                encoding = result['encoding']
                data = pd.read_csv(file_path, encoding=encoding)
                for _, row in data.iterrows():
                    content = clean_text(row.to_string())
                    docs.append(Document(page_content=content, metadata={"source": file_path}))
            elif ext == ".json":
                # JSON files: one Document per list entry, or one for a top-level object
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if isinstance(data, list):
                    for entry in data:
                        content = clean_text(json.dumps(entry))
                        docs.append(Document(page_content=content, metadata={"source": file_path}))
                elif isinstance(data, dict):
                    content = clean_text(json.dumps(data))
                    docs.append(Document(page_content=content, metadata={"source": file_path}))
            elif ext == ".txt":
                # TXT files: the whole file becomes a single Document
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = clean_text(f.read())
                docs.append(Document(page_content=content, metadata={"source": file_path}))
            else:
                logger.warning(f"Unsupported file format: {file_path}")
        except Exception as e:
            logger.error(f"Error processing file {file_path}: {e}")
    return docs
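
# Illustrative usage ("notes.txt" is a hypothetical path): every CSV row,
# JSON entry, or text file becomes one Document tagged with its source:
#   docs = load_documents(["AIChatbot.csv", "notes.txt"])
#   docs[0].metadata  # -> {"source": "AIChatbot.csv"}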

# Function to ensure the response ends with complete sentences
def ensure_complete_sentences(text):
    # Use regex to find all complete sentences (ending in ., !, or ?)
    sentences = re.findall(r'[^.!?]*[.!?]', text)
    if sentences:
        # Join the complete sentences to form the final answer
        return ' '.join(sentences).strip()
    return text  # Return as-is if no complete sentence is found
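
# Illustrative example: ensure_complete_sentences("Drink water. Also try")
# keeps "Drink water." and drops the unfinished trailing fragment.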

# Initialize the LLM using ChatGroq with GROQ's API
def initialize_llm(model, temperature, max_tokens):
    try:
        # Reserve a portion of the token budget for the prompt, e.g. 20%
        prompt_allocation = int(max_tokens * 0.2)
        response_max_tokens = max_tokens - prompt_allocation
        if response_max_tokens <= 50:
            raise ValueError("max_tokens is too small to allocate for the response.")
        llm = ChatGroq(
            model=model,
            temperature=temperature,
            max_tokens=response_max_tokens,  # Cap the completion length
            api_key=api_key  # Ensure the API key is passed correctly
        )
        logger.debug("LLM initialized successfully.")
        return llm
    except Exception as e:
        logger.error(f"Error initializing LLM: {e}")
        raise
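
# Worked example of the token split (hypothetical values): with max_tokens=500,
# int(500 * 0.2) = 100 tokens are reserved for the prompt, so ChatGroq is
# capped at 500 - 100 = 400 completion tokens.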

# Create the RAG pipeline
def create_rag_pipeline(file_paths, model, temperature, max_tokens):
    try:
        llm = initialize_llm(model, temperature, max_tokens)
        docs = load_documents(file_paths)
        if not docs:
            logger.warning("No documents were loaded. Please check your file paths and formats.")
            return None, "No documents were loaded. Please check your file paths and formats."
        # Split documents into overlapping chunks for retrieval
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_documents(docs)
        # Initialize the embedding model
        embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        # Use a persistent database for Chroma
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embedding_model,
            persist_directory="./chroma_db"  # Persistent storage directory
        )
        vectorstore.persist()  # Save the database to disk
        logger.debug("Vectorstore initialized and persisted successfully.")
        retriever = vectorstore.as_retriever()
        custom_prompt_template = PromptTemplate(
            input_variables=["context", "question"],
            template="""
You are an AI assistant with expertise in daily wellness. Your aim is to provide detailed and comprehensive solutions regarding daily wellness topics without unnecessary verbosity.

Context:
{context}

Question:
{question}

Provide a thorough and complete answer, including relevant examples and a suggested schedule. Ensure that the response does not end abruptly.
"""
        )
        rag_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",  # "stuff" packs all retrieved chunks into one prompt
            retriever=retriever,
            chain_type_kwargs={"prompt": custom_prompt_template}
        )
        logger.debug("RAG pipeline created successfully.")
        return rag_chain, "Pipeline created successfully."
    except Exception as e:
        logger.error(f"Error creating RAG pipeline: {e}")
        return None, f"Error creating RAG pipeline: {e}"
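
# Illustrative usage (hypothetical arguments; note the index is rebuilt on
# every call, which answer_question below relies on):
#   chain, status = create_rag_pipeline(["AIChatbot.csv"], "llama3-8b-8192", 0.7, 500)
#   if chain is not None:
#       print(chain.run("How can I sleep better?"))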

# Function to answer questions with post-processing
def answer_question(file_paths, model, temperature, max_tokens, question):
    rag_chain, message = create_rag_pipeline(file_paths, model, temperature, max_tokens)
    if rag_chain is None:
        return message
    try:
        answer = rag_chain.run(question)
        logger.debug("Question answered successfully.")
        # Post-process so the answer ends with complete sentences
        complete_answer = ensure_complete_sentences(answer)
        return complete_answer
    except Exception as e:
        logger.error(f"Error during RAG pipeline execution: {e}")
        return f"Error during RAG pipeline execution: {e}"

# Gradio interface
def gradio_interface(model, temperature, max_tokens, question):
    file_paths = ['AIChatbot.csv']  # Ensure this file is present in the Space root directory
    return answer_question(file_paths, model, temperature, max_tokens, question)

# Define the Gradio UI
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Model Name", value="llama3-8b-8192"),
        gr.Slider(label="Temperature", minimum=0, maximum=1, step=0.01, value=0.7),
        gr.Slider(label="Max Tokens", minimum=200, maximum=1024, step=1, value=500),
        gr.Textbox(label="Question")
    ],
    outputs="text",
    title="Daily Wellness AI",
    description="Ask questions about daily wellness and get detailed solutions."
)

# Launch the Gradio app without share=True (not supported on Hugging Face Spaces)
if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0", server_port=7860, debug=True)