Programmes commited on
Commit
2e4368d
·
verified ·
1 Parent(s): 9965101

Update rag_utils.py

Browse files
Files changed (1) hide show
  1. rag_utils.py +19 -8
rag_utils.py CHANGED
@@ -7,39 +7,50 @@ from sentence_transformers import SentenceTransformer
7
  from transformers import AutoTokenizer, pipeline
8
  from huggingface_hub import InferenceClient
9
 
10
- # Choix du modèle
11
  HF_TOKEN = os.environ.get("edup2")
 
12
 
13
- if HF_TOKEN:
14
- MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
15
- client = InferenceClient(MODEL_NAME, token=HF_TOKEN)
16
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
17
- use_client = True
18
- else:
 
 
 
 
 
19
  MODEL_NAME = "google/flan-t5-base"
20
  generator = pipeline("text2text-generation", model=MODEL_NAME)
21
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
22
  use_client = False
23
 
 
24
  def load_faiss_index(index_path="faiss_index/faiss_index.faiss", doc_path="faiss_index/documents.pkl"):
25
  index = faiss.read_index(index_path)
26
  with open(doc_path, "rb") as f:
27
  documents = pickle.load(f)
28
  return index, documents
29
 
 
30
  def get_embedding_model():
31
  return SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
32
 
 
33
  def query_index(question, index, documents, model, k=3):
34
  question_embedding = model.encode([question])
35
  _, indices = index.search(np.array(question_embedding).astype("float32"), k)
36
  return [documents[i] for i in indices[0]]
37
 
 
38
  def nettoyer_context(context):
39
  context = re.sub(r"\[\'(.*?)\'\]", r"\1", context)
40
  context = context.replace("None", "")
41
  return context
42
 
 
43
  def generate_answer(question, context):
44
  prompt = f"""Voici des informations sur des établissements et formations :
45
 
@@ -55,4 +66,4 @@ Réponse :"""
55
  return response
56
  else:
57
  result = generator(prompt, max_new_tokens=256, do_sample=True)
58
- return result[0]["generated_text"]
 
7
  from transformers import AutoTokenizer, pipeline
8
  from huggingface_hub import InferenceClient
9
 
10
+ # Token Hugging Face depuis les secrets (Space)
11
  HF_TOKEN = os.environ.get("edup2")
12
+ use_client = False
13
 
14
+ # Tentative de chargement de Mistral
15
+ try:
16
+ if HF_TOKEN:
17
+ MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
18
+ client = InferenceClient(MODEL_NAME, token=HF_TOKEN)
19
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
20
+ use_client = True
21
+ else:
22
+ raise ValueError("Pas de token trouvé pour Mistral.")
23
+ except Exception as e:
24
+ print(f"⚠️ Impossible de charger Mistral : {e}")
25
  MODEL_NAME = "google/flan-t5-base"
26
  generator = pipeline("text2text-generation", model=MODEL_NAME)
27
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
28
  use_client = False
29
 
30
+ # Chargement de l’index FAISS et des documents
31
  def load_faiss_index(index_path="faiss_index/faiss_index.faiss", doc_path="faiss_index/documents.pkl"):
32
  index = faiss.read_index(index_path)
33
  with open(doc_path, "rb") as f:
34
  documents = pickle.load(f)
35
  return index, documents
36
 
37
+ # Modèle d’embedding
38
  def get_embedding_model():
39
  return SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
40
 
41
+ # Recherche dans l’index
42
  def query_index(question, index, documents, model, k=3):
43
  question_embedding = model.encode([question])
44
  _, indices = index.search(np.array(question_embedding).astype("float32"), k)
45
  return [documents[i] for i in indices[0]]
46
 
47
+ # Nettoyage du contexte
48
  def nettoyer_context(context):
49
  context = re.sub(r"\[\'(.*?)\'\]", r"\1", context)
50
  context = context.replace("None", "")
51
  return context
52
 
53
+ # Génération de la réponse
54
  def generate_answer(question, context):
55
  prompt = f"""Voici des informations sur des établissements et formations :
56
 
 
66
  return response
67
  else:
68
  result = generator(prompt, max_new_tokens=256, do_sample=True)
69
+ return result[0]["generated_text"]