gaia-chat / app.py
baptiste.bernard
RAG in progress
7dbc57c
raw
history blame
5.83 kB
import os
import sys
from dotenv import load_dotenv
import gradio as gr
from huggingface_hub import InferenceClient
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.schema import Document
load_dotenv()
hftoken = os.environ.get("HF_TOKEN")
from huggingface_hub import login
login(token=hftoken)
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=hftoken)
vector_store = None
embeddings = HuggingFaceBgeEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
def extract_text_from_file(file_path):
"""Extrait le texte d'un fichier PDF ou TXT."""
try:
file_extension = os.path.splitext(file_path)[1].lower()
if file_extension == ".pdf":
loader = PyPDFLoader(file_path)
pages = loader.load()
docs = [Document(page_content=page.page_content) for page in pages]
elif file_extension == ".txt":
with open(file_path, "r", encoding="utf-8") as file:
text = file.read()
docs = [Document(page_content=text)]
else:
return None, "Format non pris en charge. Téléchargez un PDF ou TXT."
return docs, None
except Exception as e:
return None, f"Erreur lors de la lecture du fichier : {e}"
def embed_documents(file):
"""Convertit un document en vecteurs FAISS et génère un résumé."""
global vector_store
docs, error = extract_text_from_file(file.name)
if error:
return error
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
documents = text_splitter.split_documents(docs)
if documents:
vector_store = FAISS.from_documents(documents, embeddings)
full_text = "\n".join([doc.page_content for doc in documents])
summary = summarize_text(full_text)
return f"✅ Document indexé avec succès !\n\n📌 **Résumé du fichier** :\n{summary}"
else:
return "❌ Aucun texte trouvable dans le fichier."
def summarize_text(text):
"""Utilise le modèle HF pour générer un résumé du document."""
messages = [{"role": "system", "content": "Résume ce texte en quelques phrases :"}, {"role": "user", "content": text}]
response = client.chat_completion(messages, max_tokens=200, temperature=0.5)
return response.choices[0].message["content"]
def query_faiss(query):
"""Recherche les documents pertinents dans FAISS et retourne une réponse reformulée."""
if vector_store is None:
return "❌ Aucun document indexé. Téléchargez un fichier."
retriever = vector_store.as_retriever(search_kwargs={"k": 3})
results = retriever.get_relevant_documents(query)
if not results:
return "Je n'ai pas trouvé d'informations pertinentes dans les documents."
context = "\n".join([doc.page_content for doc in results])
messages = [
{"role": "system", "content": "Réponds à la question en utilisant les informations suivantes sans les copier mot pour mot."},
{"role": "user", "content": f"Contexte : {context}\nQuestion : {query}"}
]
response = client.chat_completion(messages, max_tokens=200, temperature=0.5)
return response.choices[0].message["content"]
def respond(message, history, system_message, max_tokens, temperature, top_p, file=None):
"""Gère la réponse du chatbot avec FAISS et Hugging Face."""
global vector_store
messages = [{"role": "system", "content": system_message}]
for val in history:
if val[0]:
messages.append({"role": "user", "content": val[0]})
if val[1]:
messages.append({"role": "assistant", "content": val[1]})
if file:
response = embed_documents(file)
yield response
return
context = query_faiss(message)
if "❌" not in context:
messages.append({"role": "user", "content": f"Contexte : {context}\nQuestion : {message}"})
else:
messages.append({"role": "user", "content": message})
response = ""
for msg in client.chat_completion(
messages,
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
):
token = msg.choices[0].delta.content
response += token
yield response
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# 📚 Chatbot avec intégration de documents")
gr.Image(value="logo-gaia.png", label="Logo")
with gr.Row():
with gr.Column():
gr.Markdown("## ⚙️ Paramètres")
with gr.Accordion("Réglages avancés", open=False):
system_message = gr.Textbox(value="You are a friendly Chatbot.", label="System message")
max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
file_upload = gr.File(label="📂 Télécharger un fichier PDF ou TXT", file_types=[".pdf", ".txt"], type="filepath")
file_output = gr.Textbox()
file_upload.change(embed_documents, inputs=file_upload, outputs=file_output)
with gr.Column():
gr.Markdown("## 💬 Chat")
chatbot = gr.ChatInterface(
respond,
additional_inputs=[system_message, max_tokens, temperature, top_p, file_upload],
)
if __name__ == "__main__":
demo.launch(share=True)