# gaia-chat / app.py
# Author: baptiste.bernard
# Commit: add RAG and update README (d79eb5f)
import os
from dotenv import load_dotenv
import gradio as gr
from huggingface_hub import InferenceClient, login
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.schema import Document
# Load environment variables from a local .env file (expects HFTOKEN).
load_dotenv()
hftoken = os.getenv("HFTOKEN")
# Authenticate against the Hugging Face Hub so gated models can be used.
login(token=hftoken)
# Remote inference client for the chat model — no local weights needed.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=hftoken)
# Sentence-embedding model used to index uploaded documents for RAG retrieval.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Global FAISS index; populated lazily when a user uploads a document.
vector_store = None
def extract_text_from_file(file_path):
    """Load a PDF or TXT file and return its content as LangChain Documents.

    Returns a list of Document objects, or None for unsupported extensions
    or on any read/parse error (the error is printed, not raised).
    """
    try:
        extension = os.path.splitext(file_path)[1].lower()
        if extension == ".txt":
            with open(file_path, "r", encoding="utf-8") as handle:
                return [Document(page_content=handle.read())]
        if extension == ".pdf":
            pages = PyPDFLoader(file_path).load()
            return [Document(page_content=page.page_content) for page in pages]
        # Anything other than .pdf/.txt is unsupported.
        return None
    except Exception as e:
        print(f"Erreur extraction : {e}")
        return None
def embed_documents(file):
    """Index an uploaded document into the global FAISS vector store.

    Parameters
    ----------
    file : str | object
        Either a filesystem path (what ``gr.File(type="filepath")`` passes)
        or a file-like wrapper exposing ``.name`` (older Gradio versions).

    Returns
    -------
    str
        A human-readable status message (French, emoji-prefixed).
    """
    global vector_store
    # BUG FIX: with gr.File(type="filepath") Gradio passes a plain string,
    # which has no .name attribute; accept both the string form and the
    # legacy tempfile-wrapper form.
    file_path = file if isinstance(file, str) else file.name
    docs = extract_text_from_file(file_path)
    if not docs:
        return "❌ Erreur lors du chargement du fichier."
    # Overlapping chunks preserve context across chunk boundaries at retrieval time.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    documents = text_splitter.split_documents(docs)
    if not documents:
        return "❌ Aucun texte valide extrait du fichier."
    vector_store = FAISS.from_documents(documents, embeddings)
    return "📁 Document chargé et indexé ! Posez-moi une question."
def query_faiss(query):
    """Retrieve the most relevant indexed chunks for *query*.

    Searches the global FAISS store (top-3 neighbours) and returns the
    matching chunk texts joined by newlines, or None when no document has
    been indexed yet or nothing matches.
    """
    global vector_store
    if vector_store is None:
        return None
    retriever = vector_store.as_retriever(search_kwargs={"k": 3})
    hits = retriever.get_relevant_documents(query)
    if not hits:
        return None
    return "\n".join(doc.page_content for doc in hits)
def chatbot_response(message, history, system_message, max_tokens, temperature, top_p, file=None):
    """Stream a chatbot reply, optionally grounded in an uploaded document.

    Parameters
    ----------
    message : str
        The user's current question.
    history : list
        Prior turns; either ``(user, assistant)`` tuples or
        ``{"role": ..., "content": ...}`` dicts depending on Gradio version.
    system_message : str
        System prompt; a French-only instruction is appended to it.
    max_tokens, temperature, top_p
        Sampling parameters forwarded to the inference endpoint.
    file : optional
        A freshly uploaded document to (re-)index before answering.

    Yields
    ------
    str
        The growing partial response (streaming), or a status/error message.
    """
    global vector_store
    if file:
        status = embed_documents(file)
        yield status
        if "❌" in status:
            # Indexing failed — do not attempt to answer.
            return
    # Only query the store once something has been indexed.
    context = query_faiss(message) if vector_store else None
    messages = [{"role": "system", "content": system_message + " Réponds uniquement en français."}]
    for val in history:
        # ROBUSTNESS FIX: Gradio delivers history either as (user, assistant)
        # tuples (legacy) or as role/content dicts (messages format); the
        # original tuple-only indexing would crash on the dict form.
        if isinstance(val, dict):
            if val.get("content"):
                messages.append({"role": val["role"], "content": val["content"]})
        else:
            if val[0]:
                messages.append({"role": "user", "content": val[0]})
            if val[1]:
                messages.append({"role": "assistant", "content": val[1]})
    if context:
        # Prepend retrieved chunks so the model can ground its answer.
        messages.append({"role": "user", "content": f"Contexte : {context}\nQuestion : {message}"})
    else:
        messages.append({"role": "user", "content": message})
    response = ""
    try:
        response_stream = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        )
        # Re-yield the accumulated text on every delta so the UI streams.
        for msg in response_stream:
            if msg.choices and msg.choices[0].delta and msg.choices[0].delta.content:
                token = msg.choices[0].delta.content
                response += token
                yield response
        if not response.strip():
            yield "Je ne sais pas. Peux-tu reformuler ?"
    except Exception as e:
        print(f"Erreur réponse : {e}")
        yield "❌ Une erreur est survenue lors de la génération de la réponse."
# --- Gradio UI ----------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        # Project logo; the image file is expected alongside app.py.
        gr.Image("logo-gaia.png", width=300, height=300, show_label=False, show_download_button=False)
    with gr.Row():
        gr.Markdown("<h1 style='text-align: center;'>📚 Chatbot GAIA</h1>")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## ⚙️ Paramètres")
            with gr.Accordion(" Paramètres avancés", open=False):
                # Generation controls forwarded to chatbot_response.
                system_message = gr.Textbox(value="Réponds de façon simple et claire.", label="Message système")
                max_tokens = gr.Slider(1, 2048, value=800, step=1, label="Max tokens")
                temperature = gr.Slider(0.1, 4.0, value=0.3, step=0.1, label="Température")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
            gr.Markdown("## 📂 Télécharger un fichier")
            file_upload = gr.File(label="Téléchargez un PDF ou TXT", file_types=[".pdf", ".txt"], type="filepath")
            # Index the document as soon as it is uploaded; the returned
            # status string is discarded (outputs=[]).
            # NOTE(review): file_upload is also passed to chatbot_response via
            # additional_inputs below, so the document gets re-embedded on
            # every chat message — confirm this duplication is intended.
            file_upload.change(embed_documents, inputs=file_upload, outputs=[])
        with gr.Column():
            gr.Markdown("## 💬 Chat")
            chatbot = gr.ChatInterface(
                chatbot_response,
                additional_inputs=[system_message, max_tokens, temperature, top_p, file_upload],
            )

if __name__ == "__main__":
    # share=True exposes a temporary public Gradio link in addition to the
    # local server.
    demo.launch(share=True)