StreamLight committed
Commit 2ca61f3 · verified · 1 parent: a4705d3

Upload 4 files

Files changed (3)
  1. app3.py +110 -0
  2. chapes-fluides.xlsx +0 -0
  3. requirements.txt +128 -0
app3.py ADDED
@@ -0,0 +1,110 @@
+ import os
+ import re
+
+ import faiss
+ import numpy as np
+ import pandas as pd
+ import streamlit as st
+ from huggingface_hub import InferenceClient
+ from sentence_transformers import SentenceTransformer
+
+
+ # The Hugging Face token is assumed to come from the environment (e.g. a Space secret)
+ API_TOKEN = os.environ.get("API_TOKEN")
+ headers = {"Authorization": f"Bearer {API_TOKEN}"}
+ API_URL = "https://api-inference.huggingface.co/models/"
+ df = pd.read_excel('chapes-fluides.xlsx')
+ inference_client = InferenceClient(token=API_TOKEN)
+
+ # Build a FAISS index over the embeddings of a text column
+ def create_index(data, text_column, model):
+     # Encode the text column to generate embeddings
+     embeddings = model.encode(data[text_column].tolist())
+
+     # Dimension of the embeddings
+     dimension = embeddings.shape[1]
+
+     # Prepare the embeddings and their IDs for FAISS
+     db_vectors = embeddings.astype(np.float32)
+     db_ids = np.arange(len(data)).astype(np.int64)
+
+     # Normalize the embeddings so that inner product equals cosine similarity
+     faiss.normalize_L2(db_vectors)
+
+     # Create and configure the FAISS index
+     index = faiss.IndexFlatIP(dimension)
+     index = faiss.IndexIDMap(index)
+     index.add_with_ids(db_vectors, db_ids)
+
+     return index, embeddings
+
+ # Vectorize a text with model.encode, normalized for the FAISS search
+ def vectorize_text(model, text):
+     # Encode the question to generate its embedding
+     question_embedding = model.encode([text])
+
+     # Convert to float32 for compatibility with FAISS
+     question_embedding = question_embedding.astype(np.float32)
+
+     # Normalize the embedding
+     faiss.normalize_L2(question_embedding)
+
+     return question_embedding
+
+ def extract_context(indices, df, i):
+     # indices has shape (n_queries, k); take the i-th hit of the first query
+     index_i = indices[0][i]
+     context = df.iloc[index_i]['text_segment']
+     return context
+
+ def generate_answer_from_context(context, client, model, prompt):
+     try:
+         # Call the serverless Inference API through InferenceClient
+         answer = client.text_generation(prompt=prompt, model=model, max_new_tokens=250)
+
+         # Drop any echo of the prompt up to the "Reponse:" marker
+         answer_cleaned = re.sub(r'^.*Reponse:', '', answer, flags=re.DOTALL).strip()
+         return answer_cleaned
+     except Exception as e:
+         print(f"Error encountered: {e}")
+         return None
+
+
+ # Load the embedding model and pick the generation model
+ model_sentence_transformers = SentenceTransformer('intfloat/multilingual-e5-base')
+ model_reponse_mixtral_instruct = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+ # Load the prebuilt index
+ index_reloaded = faiss.read_index("./index/chapes_fluides_e5.index")
+
+ # Number of nearest segments to retrieve
+ K = 2
+
+ # Streamlit app interface
+ st.title("CSTB App")
+
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ if user_question := st.chat_input("Votre question : "):
+     # Vectorize the user question and search the FAISS index
+     st.session_state.messages.append({"role": "user", "content": user_question})
+     question_embedding = vectorize_text(model_sentence_transformers, user_question)
+     D, I = index_reloaded.search(question_embedding, K)  # question_embedding is already 2D
+
+     # Concatenate the context of the top K = 2 results
+     context = extract_context(I, df, 0) + ' ' + extract_context(I, df, 1)
+     prompts = [
+         f"Répondre à cette question : {user_question} en utilisant le contexte suivant {context}. Etre le plus précis possible et ne pas faire de phrase qui ne se finit pas \nReponse:"
+         # Alternative prompt:
+         # f"Contexte: {context}\nQuestion: {user_question}\nReponse:",
+     ]
+
+     # Generate an answer for each prompt
+     answers = [generate_answer_from_context(context, inference_client, model_reponse_mixtral_instruct, prompt) for prompt in prompts]
+
+     # Append the answers to the chat history
+     for answer in answers:
+         if answer:
+             st.session_state.messages.append({"role": "assistant", "content": answer})
+         else:
+             st.session_state.messages.append({"role": "assistant", "content": "Failed to generate an answer."})
+
+ # Render the chat history
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
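
Note: app3.py reads a prebuilt index from ./index/chapes_fluides_e5.index, which is not part of this commit. The following one-off script is a minimal sketch of how that file could be produced; the name build_index.py is hypothetical, and it simply mirrors the create_index() recipe from app3.py and the text_segment column that extract_context() expects:

# build_index.py -- hypothetical one-off script; mirrors create_index() in app3.py
import os

import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

df = pd.read_excel("chapes-fluides.xlsx")
model = SentenceTransformer("intfloat/multilingual-e5-base")

# Normalized embeddings in an inner-product index, so scores are cosine similarities
embeddings = model.encode(df["text_segment"].tolist()).astype(np.float32)
faiss.normalize_L2(embeddings)
index = faiss.IndexIDMap(faiss.IndexFlatIP(embeddings.shape[1]))
index.add_with_ids(embeddings, np.arange(len(df)).astype(np.int64))

# Write the index where app3.py expects to find it
os.makedirs("index", exist_ok=True)
faiss.write_index(index, "./index/chapes_fluides_e5.index")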
chapes-fluides.xlsx ADDED
Binary file (61.7 kB).
 
requirements.txt ADDED
@@ -0,0 +1,128 @@
+ absl-py==2.0.0
+ altair==5.2.0
+ asttokens==2.4.1
+ astunparse==1.6.3
+ attrs==23.2.0
+ blinker==1.7.0
+ cachetools==5.3.2
+ certifi==2023.11.17
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ comm==0.2.0
+ debugpy==1.8.0
+ decorator==5.1.1
+ distlib==0.3.8
+ et-xmlfile==1.1.0
+ exceptiongroup==1.2.0
+ executing==2.0.1
+ faiss-cpu==1.7.4
+ filelock==3.13.1
+ flatbuffers==23.5.26
+ fsspec==2023.12.2
+ gast==0.4.0
+ gitdb==4.0.11
+ GitPython==3.1.40
+ google-auth==2.25.2
+ google-auth-oauthlib==1.2.0
+ google-pasta==0.2.0
+ grpcio==1.60.0
+ h5py==3.10.0
+ huggingface-hub==0.19.4
+ idna==3.6
+ importlib-metadata==6.11.0
+ ipykernel==6.27.1
+ ipython==8.18.1
+ jedi==0.19.1
+ Jinja2==3.1.2
+ joblib==1.3.2
+ jsonschema==4.20.0
+ jsonschema-specifications==2023.12.1
+ jupyter_client==8.6.0
+ jupyter_core==5.5.0
+ keras==2.15.0
+ Keras-Preprocessing==1.1.2
+ libclang==16.0.6
+ Markdown==3.5.1
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ matplotlib-inline==0.1.6
+ mdurl==0.1.2
+ ml-dtypes==0.2.0
+ mpmath==1.3.0
+ nest-asyncio==1.5.8
+ networkx==3.2.1
+ nltk==3.8.1
+ numpy==1.26.2
+ oauthlib==3.2.2
+ openpyxl==3.1.2
+ opt-einsum==3.3.0
+ packaging==23.2
+ pandas==2.1.4
+ parso==0.8.3
+ pillow==10.2.0
+ platformdirs==4.1.0
+ prompt-toolkit==3.0.43
+ protobuf==4.23.4
+ psutil==5.9.7
+ pure-eval==0.2.2
+ pyarrow==14.0.2
+ pyasn1==0.5.1
+ pyasn1-modules==0.3.0
+ pydeck==0.8.1b0
+ Pygments==2.17.2
+ PyMuPDF==1.23.7
+ PyMuPDFb==1.23.7
+ python-dateutil==2.8.2
+ pytz==2023.3.post1
+ pywin32==306
+ PyYAML==6.0.1
+ pyzmq==25.1.2
+ referencing==0.32.0
+ regex==2023.10.3
+ requests==2.31.0
+ requests-oauthlib==1.3.1
+ rich==13.7.0
+ rpds-py==0.16.2
+ rsa==4.9
+ safetensors==0.4.1
+ scikit-learn==1.3.2
+ scipy==1.11.4
+ sentence-transformers==2.2.2
+ sentencepiece==0.1.99
+ six==1.16.0
+ smmap==5.0.1
+ stack-data==0.6.3
+ streamlit==1.29.0
+ sympy==1.12
+ tenacity==8.2.3
+ tensorboard==2.15.1
+ tensorboard-data-server==0.7.2
+ tensorboard-plugin-wit==1.8.1
+ tensorflow==2.10.1
+ tensorflow-estimator==2.15.0
+ tensorflow-hub==0.15.0
+ tensorflow-io-gcs-filesystem==0.31.0
+ tensorflow-text==2.10.0
+ termcolor==2.4.0
+ threadpoolctl==3.2.0
+ tokenizers==0.15.0
+ toml==0.10.2
+ toolz==0.12.0
+ torch==2.1.2
+ torchvision==0.16.2
+ tornado==6.4
+ tqdm==4.66.1
+ traitlets==5.14.0
+ transformers==4.36.2
+ typing_extensions==4.9.0
+ tzdata==2023.3
+ tzlocal==5.2
+ urllib3==2.1.0
+ validators==0.22.0
+ virtualenv==20.25.0
+ watchdog==3.0.0
+ wcwidth==0.2.12
+ Werkzeug==3.0.1
+ wrapt==1.14.1
+ zipp==3.17.0
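
With these requirements installed (pip install -r requirements.txt) and the index file in place, the app would be started the usual Streamlit way, streamlit run app3.py, assuming the Hugging Face token is exposed to the process as the API_TOKEN environment variable that app3.py reads.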