Spaces:
Sleeping
Sleeping
import os | |
from dotenv import load_dotenv | |
import gradio as gr | |
from huggingface_hub import InferenceClient, login | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.vectorstores import FAISS | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain.schema import Document | |
load_dotenv() | |
hftoken = os.getenv("HFTOKEN") | |
prompt_systeme = os.getenv("PROMPT_SYSTEM") | |
questions = os.getenv("QUESTIONS").split(",") | |
login(token=hftoken) | |
client = InferenceClient(model="meta-llama/Llama-3.3-70B-Instruct", token=hftoken) | |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") | |
vector_store = None | |
def extract_text_from_file(file_path): | |
"""Extrait le texte d'un fichier PDF ou TXT.""" | |
try: | |
ext = os.path.splitext(file_path)[1].lower() | |
if ext == ".pdf": | |
loader = PyPDFLoader(file_path) | |
pages = loader.load() | |
return [Document(page_content=page.page_content) for page in pages] | |
elif ext == ".txt": | |
with open(file_path, "r", encoding="utf-8") as file: | |
return [Document(page_content=file.read())] | |
else: | |
return None | |
except Exception as e: | |
print(f"Erreur extraction : {e}") | |
return None | |
def embed_documents(file): | |
"""Charge un document dans FAISS pour la recherche vectorielle.""" | |
global vector_store | |
docs = extract_text_from_file(file.name) | |
if not docs: | |
return "❌ Erreur lors du chargement du fichier." | |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100) | |
documents = text_splitter.split_documents(docs) | |
if documents: | |
vector_store = FAISS.from_documents(documents, embeddings) | |
return "📁 Document chargé et indexé ! Posez-moi une question." | |
else: | |
return "❌ Aucun texte valide extrait du fichier." | |
def query_faiss(query): | |
"""Effectue une recherche dans FAISS et génère une réponse.""" | |
global vector_store | |
if vector_store is None: | |
return None | |
retriever = vector_store.as_retriever(search_kwargs={"k": 3}) | |
results = retriever.get_relevant_documents(query) | |
if results: | |
context = "\n".join([doc.page_content for doc in results]) | |
return context | |
return None | |
def chatbot_response(message, history, system_message, max_tokens, temperature, top_p, file=None): | |
"""Gère les réponses du chatbot avec ou sans document.""" | |
global vector_store | |
if file: | |
status = embed_documents(file) | |
if "❌" in status: | |
yield status | |
return | |
yield status | |
context = query_faiss(message) if vector_store else None | |
messages = [{"role": "system", "content": system_message}] | |
for val in history: | |
if val[0]: messages.append({"role": "user", "content": val[0]}) | |
if val[1]: messages.append({"role": "assistant", "content": val[1]}) | |
if context: | |
messages.append({"role": "user", "content": f"Contexte : {context}\nQuestion : {message}"}) | |
else: | |
messages.append({"role": "user", "content": message}) | |
response = "" | |
try: | |
response_stream = client.chat_completion( | |
messages, | |
max_tokens=max_tokens, | |
stream=True, | |
temperature=temperature, | |
top_p=top_p, | |
) | |
for msg in response_stream: | |
if msg.choices and msg.choices[0].delta and msg.choices[0].delta.content: | |
token = msg.choices[0].delta.content | |
response += token | |
yield response | |
if not response.strip(): | |
yield "Je ne sais pas. Peux-tu reformuler ?" | |
except Exception as e: | |
print(f"Erreur réponse : {e}") | |
yield "❌ Une erreur est survenue lors de la génération de la réponse." | |
def handle_question(selected_question): | |
"""Gère la sélection d'une question pré-remplie.""" | |
return selected_question.strip() | |
with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
with gr.Row(): | |
gr.Image("logo-gaia.png", width=300, height=300, show_label=False, show_download_button=False) | |
with gr.Row(): | |
gr.Markdown("<h1 style='text-align: center;'>📚 Chatbot GAIA</h1>") | |
with gr.Row(): | |
gr.Markdown("<p style='text-align: center;'>Exemple : Allez sur <a href='https://www.societe.ninja/' target='_blank'>société.ninja</a> avec votre numéro de SIREN et utilisez un PDF.</p>") | |
with gr.Row(): | |
with gr.Column(): | |
gr.Markdown("## ⚙️ Paramètres") | |
with gr.Accordion(" Paramètres avancés", open=False): | |
system_message = gr.Textbox(value=prompt_systeme, label="Message système") | |
max_tokens = gr.Slider(1, 2048, value=800, step=1, label="Max tokens") | |
temperature = gr.Slider(0.1, 4.0, value=0.3, step=0.1, label="Température") | |
top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p") | |
gr.Markdown("## 📂 Télécharger un fichier") | |
file_upload = gr.File(label="Téléchargez un PDF ou TXT", file_types=[".pdf", ".txt"], type="filepath") | |
file_upload.change(embed_documents, inputs=file_upload, outputs=[]) | |
gr.Markdown("## ❓ Questions pré-remplies") | |
question_buttons = [] | |
for question in questions: | |
button = gr.Button(value=question.strip()) | |
question_buttons.append(button) | |
selected_question_output = gr.Textbox(label="Question sélectionnée", interactive=False) | |
for button in question_buttons: | |
button.click(handle_question, inputs=button, outputs=selected_question_output) | |
with gr.Column(): | |
gr.Markdown("## 💬 Chat") | |
chatbot = gr.ChatInterface( | |
chatbot_response, | |
additional_inputs=[system_message, max_tokens, temperature, top_p, file_upload], | |
) | |
if __name__ == "__main__": | |
demo.launch(share=True) | |