Ramlaoui commited on
Commit
2dd66b7
·
1 Parent(s): a91474f

Faster search and Table

Browse files
Files changed (4) hide show
  1. Dockerfile +5 -0
  2. app.py +101 -42
  3. create_index.py +52 -0
  4. requirements.txt +1 -0
Dockerfile CHANGED
@@ -19,6 +19,9 @@ RUN pip install --no-cache-dir -r requirements.txt
19
  # Copy the application code
20
  COPY app.py .
21
 
 
 
 
22
  # Expose the port Dash will run on
23
  EXPOSE 7860
24
 
@@ -29,6 +32,8 @@ RUN --mount=type=secret,id=MATERIALS_PROJECT_API_KEY \
29
  # Create the cache directory and set permissions
30
  RUN mkdir -p /app/.cache && chmod -R 777 /app/.cache
31
 
 
 
32
 
33
  # Set an environment variable for Hugging Face cache
34
  ENV HF_HOME=/app/.cache
 
19
  # Copy the application code
20
  COPY app.py .
21
 
22
+ # Copy the preprocessing script
23
+ COPY create_index.py .
24
+
25
  # Expose the port Dash will run on
26
  EXPOSE 7860
27
 
 
32
  # Create the cache directory and set permissions
33
  RUN mkdir -p /app/.cache && chmod -R 777 /app/.cache
34
 
35
+ # Create the index
36
+ RUN python create_index.py
37
 
38
  # Set an environment variable for Hugging Face cache
39
  ENV HF_HOME=/app/.cache
app.py CHANGED
@@ -11,6 +11,7 @@ from pymatgen.core import Structure
11
  from pymatgen.ext.matproj import MPRester
12
 
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
14
 
15
  # Load only the train split of the dataset
16
  dataset = load_dataset(
@@ -40,9 +41,40 @@ dataset = load_dataset(
40
  ],
41
  )
42
 
43
- # Convert the train split to a pandas DataFrame
44
- train_df = dataset.to_pandas()
45
- del dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  # Initialize the Dash app
48
  app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
@@ -58,11 +90,11 @@ layout = html.Div(
58
  [
59
  html.H3("Search for materials by elements (eg. 'Ac,Cd,Ge')"),
60
  dmp.MaterialsInput(
61
- allowedInputTypes=["elements"],
62
  hidePeriodicTable=False,
63
  periodicTableMode="toggle",
64
  showSubmitButton=True,
65
- submitButtonText="Submit",
66
  type="elements",
67
  id="materials-input",
68
  ),
@@ -79,10 +111,24 @@ layout = html.Div(
79
  html.Div(
80
  [
81
  html.Label("Select Material"),
82
- dcc.Dropdown(
83
- id="material-dropdown",
84
- options=[], # Empty options initially
85
- value=None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  ),
87
  ],
88
  style={"margin-bottom": "20px"},
@@ -118,40 +164,51 @@ layout = html.Div(
118
  )
119
 
120
 
121
- # Function to search for materials
122
  def search_materials(query):
123
- element_list = [el.strip() for el in query.split(",")]
124
- isubset = lambda x: set(x).issubset(element_list)
125
- isintersection = lambda x: len(set(x).intersection(element_list)) > 0
126
- entries_df = train_df[
127
- [isintersection(l) and isubset(l) for l in train_df.elements.values.tolist()]
128
- ]
129
-
130
- options = [
131
- {
132
- "label": f"{res.chemical_formula_reduced} ({res.immutable_id}) Calculated with {res.functional}",
133
- "value": n,
134
- }
135
- for n, res in entries_df.iterrows()
136
- ]
137
- del entries_df
 
 
 
 
 
 
 
 
 
 
 
 
138
  return options
139
 
140
 
141
- # Callback to update the material dropdown based on search
142
  @app.callback(
143
- Output("material-dropdown", "options"),
144
- Output("material-dropdown", "value"),
145
  Input("materials-input", "submitButtonClicks"),
146
  Input("materials-input", "value"),
147
  )
148
  def on_submit_materials_input(n_clicks, query):
149
  if n_clicks is None or not query:
150
- return [], None
151
- options = search_materials(query)
152
- if not options:
153
- return [], None
154
- return options, options[0]["value"]
 
155
 
156
 
157
  # Callback to display the selected material
@@ -161,12 +218,14 @@ def on_submit_materials_input(n_clicks, query):
161
  Output("properties-container", "children"),
162
  ],
163
  Input("display-button", "n_clicks"),
164
- State("material-dropdown", "value"),
165
  )
166
- def display_material(n_clicks, material_id):
167
- if n_clicks is None or not material_id:
168
  return "", ""
169
- row = train_df.iloc[material_id]
 
 
170
 
171
  structure = Structure(
172
  [x for y in row["lattice_vectors"] for x in y],
@@ -180,11 +239,11 @@ def display_material(n_clicks, material_id):
180
 
181
  # Extract key properties
182
  properties = {
183
- "Material ID": row.immutable_id,
184
- "Formula": row.chemical_formula_descriptive,
185
- "Energy per atom (eV/atom)": row.energy / len(row.species_at_sites),
186
- "Band Gap (eV)": row.band_gap_direct or row.band_gap_indirect,
187
- "Total Magnetization (μB/f.u.)": row.total_magnetization,
188
  }
189
 
190
  # Format properties as an HTML table
 
11
  from pymatgen.ext.matproj import MPRester
12
 
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
14
+ top_k = 100
15
 
16
  # Load only the train split of the dataset
17
  dataset = load_dataset(
 
41
  ],
42
  )
43
 
44
+ display_columns = [
45
+ "chemical_formula_descriptive",
46
+ "functional",
47
+ "immutable_id",
48
+ "energy",
49
+ ]
50
+ display_names = {
51
+ "chemical_formula_descriptive": "Formula",
52
+ "functional": "Functional",
53
+ "immutable_id": "Material ID",
54
+ "energy": "Energy (eV)",
55
+ }
56
+
57
+ mapping_table_idx_dataset_idx = {}
58
+
59
+ import numpy as np
60
+ import periodictable
61
+
62
+ map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
63
+
64
+ # import re
65
+ #
66
+ # dataset_index = np.zeros((len(dataset), 118))
67
+ # import tqdm
68
+ #
69
+ # for i, row in tqdm.tqdm(enumerate(dataset), total=len(dataset)):
70
+ # for el in row["chemical_formula_descriptive"].split(" "):
71
+ # matches = re.findall(r"([a-zA-Z]+)([0-9]*)", el)
72
+ # el = matches[0][0]
73
+ # numb = int(matches[0][1]) if matches[0][1] else 1
74
+ # dataset_index[i][map_periodic_table[el]] = numb
75
+
76
+
77
+ dataset_index = np.load("dataset_index.npy")
78
 
79
  # Initialize the Dash app
80
  app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
 
90
  [
91
  html.H3("Search for materials by elements (eg. 'Ac,Cd,Ge')"),
92
  dmp.MaterialsInput(
93
+ allowedInputTypes=["elements", "formula"],
94
  hidePeriodicTable=False,
95
  periodicTableMode="toggle",
96
  showSubmitButton=True,
97
+ submitButtonText="Search",
98
  type="elements",
99
  id="materials-input",
100
  ),
 
111
  html.Div(
112
  [
113
  html.Label("Select Material"),
114
+ # dcc.Dropdown(
115
+ # id="material-dropdown",
116
+ # options=[], # Empty options initially
117
+ # value=None,
118
+ # ),
119
+ dash.dash_table.DataTable(
120
+ id="table",
121
+ columns=[
122
+ {"name": display_names[col], "id": col}
123
+ for col in display_columns
124
+ ],
125
+ data=[{}],
126
+ style_table={
127
+ "overflowX": "auto",
128
+ "height": "400px",
129
+ "overflowY": "auto",
130
+ },
131
+ style_cell={"textAlign": "left"},
132
  ),
133
  ],
134
  style={"margin-bottom": "20px"},
 
164
  )
165
 
166
 
 
167
  def search_materials(query):
168
+ query_vector = np.zeros(118)
169
+
170
+ if "," in query:
171
+ element_list = [el.strip() for el in query.split(",")]
172
+ for el in element_list:
173
+ query_vector[map_periodic_table[el]] = 1
174
+ else:
175
+ # Formula
176
+ import re
177
+
178
+ matches = re.findall(r"([A-Z][a-z]{0,2})(\d*)", query)
179
+ for el, numb in matches:
180
+ numb = int(numb) if numb else 1
181
+ query_vector[map_periodic_table[el]] = numb
182
+
183
+ similarity = np.dot(dataset_index, query_vector) / (
184
+ np.linalg.norm(dataset_index) * np.linalg.norm(query_vector)
185
+ )
186
+ print(similarity[::-1][:top_k])
187
+ indices = np.argsort(similarity)[::-1][:top_k]
188
+
189
+ options = [dataset[int(i)] for i in indices]
190
+
191
+ mapping_table_idx_dataset_idx.clear()
192
+ for i, idx in enumerate(indices):
193
+ mapping_table_idx_dataset_idx[int(i)] = int(idx)
194
+
195
  return options
196
 
197
 
198
+ # Callback to update the table based on search
199
  @app.callback(
200
+ Output("table", "data"),
 
201
  Input("materials-input", "submitButtonClicks"),
202
  Input("materials-input", "value"),
203
  )
204
  def on_submit_materials_input(n_clicks, query):
205
  if n_clicks is None or not query:
206
+ return []
207
+
208
+ entries = search_materials(query)
209
+ print(len(entries))
210
+
211
+ return [{col: entry[col] for col in display_columns} for entry in entries]
212
 
213
 
214
  # Callback to display the selected material
 
218
  Output("properties-container", "children"),
219
  ],
220
  Input("display-button", "n_clicks"),
221
+ Input("table", "active_cell"),
222
  )
223
+ def display_material(n_clicks, active_cell):
224
+ if n_clicks is None or not active_cell:
225
  return "", ""
226
+
227
+ idx_active = active_cell["row"]
228
+ row = dataset[mapping_table_idx_dataset_idx[idx_active]]
229
 
230
  structure = Structure(
231
  [x for y in row["lattice_vectors"] for x in y],
 
239
 
240
  # Extract key properties
241
  properties = {
242
+ "Material ID": row["immutable_id"],
243
+ "Formula": row["chemical_formula_descriptive"],
244
+ "Energy per atom (eV/atom)": row["energy"] / len(row["species_at_sites"]),
245
+ "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"],
246
+ "Total Magnetization (μB/f.u.)": row["total_magnetization"],
247
  }
248
 
249
  # Format properties as an HTML table
create_index.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import numpy as np
5
+ import periodictable
6
+ from datasets import load_dataset
7
+
8
+ HF_TOKEN = os.environ.get("HF_TOKEN")
9
+
10
+ # Load only the train split of the dataset
11
+ dataset = load_dataset(
12
+ "LeMaterial/leDataset",
13
+ token=HF_TOKEN,
14
+ split="train",
15
+ columns=[
16
+ "lattice_vectors",
17
+ "species_at_sites",
18
+ "cartesian_site_positions",
19
+ "energy",
20
+ "energy_corrected",
21
+ "immutable_id",
22
+ "elements",
23
+ "functional",
24
+ "stress_tensor",
25
+ "magnetic_moments",
26
+ "forces",
27
+ "band_gap_direct",
28
+ "band_gap_indirect",
29
+ "dos_ef",
30
+ "charges",
31
+ "functional",
32
+ "chemical_formula_reduced",
33
+ "chemical_formula_descriptive",
34
+ "total_magnetization",
35
+ ],
36
+ )
37
+
38
+
39
+ map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
40
+
41
+
42
+ dataset_index = np.zeros((len(dataset), 118))
43
+ import tqdm
44
+
45
+ for i, row in tqdm.tqdm(enumerate(dataset), total=len(dataset)):
46
+ for el in row["chemical_formula_descriptive"].split(" "):
47
+ matches = re.findall(r"([a-zA-Z]+)([0-9]*)", el)
48
+ el = matches[0][0]
49
+ numb = int(matches[0][1]) if matches[0][1] else 1
50
+ dataset_index[i][map_periodic_table[el]] = numb
51
+
52
+ np.save("dataset_index.npy", dataset_index)
requirements.txt CHANGED
@@ -9,3 +9,4 @@ pandas
9
  dash-bootstrap-components
10
  datasets
11
  dash-mp-components
 
 
9
  dash-bootstrap-components
10
  datasets
11
  dash-mp-components
12
+ periodictable