import os
from dotenv import load_dotenv
import gradio as gr
from huggingface_hub import InferenceClient, login
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.schema import Document

load_dotenv()
hftoken = os.getenv("HFTOKEN")
prompt_systeme = os.getenv("PROMPT_SYSTEM")
# Comma-separated preset questions; tolerate a missing variable and skip empty entries.
questions = [q for q in os.getenv("QUESTIONS", "").split(",") if q.strip()]
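# Example .env (illustrative values only, not from the original repo):
#   HFTOKEN=hf_xxxxxxxxxxxxxxxx
#   PROMPT_SYSTEM=You are GAIA, an assistant that answers questions about the uploaded document.
#   QUESTIONS=What is the company's legal name?,Who is its legal representative?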


login(token=hftoken)
client = InferenceClient(model="meta-llama/Llama-3.3-70B-Instruct", token=hftoken)
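# all-MiniLM-L6-v2 is a compact sentence-embedding model producing 384-dimensional vectors.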
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

vector_store = None  # global FAISS index, populated once a document is uploaded

def extract_text_from_file(file_path):
    """Extract the text from a PDF or TXT file."""
    try:
        ext = os.path.splitext(file_path)[1].lower()
        if ext == ".pdf":
            loader = PyPDFLoader(file_path)
            pages = loader.load()
            return [Document(page_content=page.page_content) for page in pages]
        elif ext == ".txt":
            with open(file_path, "r", encoding="utf-8") as file:
                return [Document(page_content=file.read())]
        else:
            return None
    except Exception as e:
        print(f"Extraction error: {e}")
        return None

def embed_documents(file):
    """Load a document into FAISS for vector search."""
    global vector_store
    # gr.File(type="filepath") passes a plain path string; some Gradio versions
    # pass a tempfile-like object exposing .name, so accept both.
    path = file if isinstance(file, str) else file.name
    docs = extract_text_from_file(path)
    if not docs:
        return "❌ Error while loading the file."

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    documents = text_splitter.split_documents(docs)

    if documents:
        vector_store = FAISS.from_documents(documents, embeddings)
        return "📁 Document loaded and indexed! Ask me a question."
    else:
        return "❌ No valid text was extracted from the file."

def query_faiss(query):
    """Search FAISS for the chunks most relevant to the query and return them as context."""
    global vector_store
    if vector_store is None:
        return None

    retriever = vector_store.as_retriever(search_kwargs={"k": 3})
    results = retriever.invoke(query)

    if results:
        context = "\n".join(doc.page_content for doc in results)
        return context
    return None

def chatbot_response(message, history, system_message, max_tokens, temperature, top_p, file=None):
    """Generate the chatbot's replies, with or without an uploaded document."""
    global vector_store
    if file:
        status = embed_documents(file)
        yield status
        if "❌" in status:
            return

    context = query_faiss(message) if vector_store else None
    messages = [{"role": "system", "content": system_message}]
    
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    if context:
        messages.append({"role": "user", "content": f"Context: {context}\nQuestion: {message}"})
    else:
        messages.append({"role": "user", "content": message})

    response = ""
    try:
        response_stream = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        )

        for msg in response_stream:
            if msg.choices and msg.choices[0].delta and msg.choices[0].delta.content:
                token = msg.choices[0].delta.content
                response += token
                yield response

        if not response.strip():
            yield "I don't know. Could you rephrase?"
    except Exception as e:
        print(f"Response error: {e}")
        yield "❌ An error occurred while generating the response."

def handle_question(selected_question):
    """Handle selection of a preset question."""
    return selected_question.strip()



with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        gr.Image("logo-gaia.png", width=300, height=300, show_label=False, show_download_button=False)

    with gr.Row():
        gr.Markdown("<h1 style='text-align: center;'>📚 Chatbot GAIA</h1>")

    with gr.Row():
        gr.Markdown("<p style='text-align: center;'>Exemple : Allez sur <a href='https://www.societe.ninja/' target='_blank'>société.ninja</a> avec votre numéro de SIREN et utilisez un PDF.</p>")

    with gr.Row():
        with gr.Column():
            gr.Markdown("## ⚙️ Paramètres")
            with gr.Accordion(" Paramètres avancés", open=False):
                system_message = gr.Textbox(value=prompt_systeme, label="Message système")
                max_tokens = gr.Slider(1, 2048, value=800, step=1, label="Max tokens")  
                temperature = gr.Slider(0.1, 4.0, value=0.3, step=0.1, label="Température")
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")

            gr.Markdown("## 📂 Télécharger un fichier")
            file_upload = gr.File(label="Téléchargez un PDF ou TXT", file_types=[".pdf", ".txt"], type="filepath")
            file_upload.change(embed_documents, inputs=file_upload, outputs=[])
            gr.Markdown("## ❓ Questions pré-remplies")
            question_buttons = []
            for question in questions:
                button = gr.Button(value=question.strip())
                question_buttons.append(button)

            selected_question_output = gr.Textbox(label="Question sélectionnée", interactive=False)

            for button in question_buttons:
                button.click(handle_question, inputs=button, outputs=selected_question_output)

        with gr.Column():
            gr.Markdown("## 💬 Chat")
            chatbot = gr.ChatInterface(
                chatbot_response,
                additional_inputs=[system_message, max_tokens, temperature, top_p, file_upload],
            )

if __name__ == "__main__":
    demo.launch(share=True)
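
# To run locally, assuming the usual Hugging Face Space layout (this file saved as
# app.py next to a .env file and logo-gaia.png). The dependency list below is a
# best guess; exact packages and pins may differ:
#   pip install gradio python-dotenv huggingface_hub langchain langchain-community \
#       langchain-huggingface sentence-transformers faiss-cpu pypdf
#   python app.py
# share=True additionally exposes a temporary public *.gradio.live URL.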