msiron commited on
Commit
1e6e599
·
1 Parent(s): ce5365e

fix splits load all

Browse files
Files changed (1) hide show
  1. app.py +41 -30
app.py CHANGED
@@ -5,45 +5,52 @@ import crystal_toolkit.components as ctc
5
  import dash
6
  import dash_mp_components as dmp
7
  import numpy as np
 
8
  import periodictable
9
  from crystal_toolkit.settings import SETTINGS
10
  from dash import dcc, html
11
  from dash.dependencies import Input, Output, State
12
  from dash_breakpoints import WindowBreakpoints
13
- from datasets import load_dataset
14
  from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
15
  from pymatgen.core import Structure
16
 
17
  HF_TOKEN = os.environ.get("HF_TOKEN")
18
  top_k = 500
19
 
 
 
20
  # Load only the train split of the dataset
21
- dataset = load_dataset(
22
- "LeMaterial/leMat-Bulk",
23
- token=HF_TOKEN,
24
- split="train",
25
- columns=[
26
- "lattice_vectors",
27
- "species_at_sites",
28
- "cartesian_site_positions",
29
- "energy",
30
- # "energy_corrected", # not yet available in LeMat-Bulk
31
- "immutable_id",
32
- "elements",
33
- "functional",
34
- "stress_tensor",
35
- "magnetic_moments",
36
- "forces",
37
- # "band_gap_direct", #future release
38
- # "band_gap_indirect", #future release
39
- "dos_ef",
40
- # "charges", #future release
41
- "functional",
42
- "chemical_formula_reduced",
43
- "chemical_formula_descriptive",
44
- "total_magnetization",
45
- ],
46
- ).select(range(1000))
 
 
 
 
47
 
48
  display_columns = [
49
  "chemical_formula_descriptive",
@@ -64,6 +71,8 @@ map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
64
  n_elements = len(map_periodic_table)
65
 
66
  # Preprocessing step to create an index for the dataset
 
 
67
  train_df = dataset.select_columns(["chemical_formula_descriptive"]).to_pandas()
68
 
69
  pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
@@ -367,8 +376,8 @@ def display_material(active_cell, selected_rows):
367
  row["cartesian_site_positions"],
368
  coords_are_cartesian=True,
369
  )
370
- if row['magnetic_moments']:
371
- structure.add_site_property('magmom',row['magnetic_moments'])
372
 
373
  sga = SpacegroupAnalyzer(structure)
374
 
@@ -379,7 +388,9 @@ def display_material(active_cell, selected_rows):
379
  properties = {
380
  "Material ID": row["immutable_id"],
381
  "Formula": row["chemical_formula_descriptive"],
382
- "Energy per atom (eV/atom)": round(row["energy"] / len(row["species_at_sites"]), 3),
 
 
383
  # "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"], #future release
384
  "Total Magnetization (μB)": row["total_magnetization"],
385
  "Density (g/cm^3)": round(structure.density, 3),
 
5
  import dash
6
  import dash_mp_components as dmp
7
  import numpy as np
8
+ import pandas as pd
9
  import periodictable
10
  from crystal_toolkit.settings import SETTINGS
11
  from dash import dcc, html
12
  from dash.dependencies import Input, Output, State
13
  from dash_breakpoints import WindowBreakpoints
14
+ from datasets import concatenate_datasets, load_dataset
15
  from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
16
  from pymatgen.core import Structure
17
 
18
  HF_TOKEN = os.environ.get("HF_TOKEN")
19
  top_k = 500
20
 
21
+ splits = ["compatible_pbe", "compatible_pbesol", "compatible_scan", "non_compatible"]
22
+
23
  # Load only the train split of the dataset
24
+
25
+ datasets = []
26
+ for split in splits:
27
+ dataset = load_dataset(
28
+ "LeMaterial/leMat-Bulk",
29
+ token=HF_TOKEN,
30
+ split=split,
31
+ columns=[
32
+ "lattice_vectors",
33
+ "species_at_sites",
34
+ "cartesian_site_positions",
35
+ "energy",
36
+ # "energy_corrected", # not yet available in LeMat-Bulk
37
+ "immutable_id",
38
+ "elements",
39
+ "functional",
40
+ "stress_tensor",
41
+ "magnetic_moments",
42
+ "forces",
43
+ # "band_gap_direct", #future release
44
+ # "band_gap_indirect", #future release
45
+ "dos_ef",
46
+ # "charges", #future release
47
+ "functional",
48
+ "chemical_formula_reduced",
49
+ "chemical_formula_descriptive",
50
+ "total_magnetization",
51
+ ],
52
+ )
53
+ datasets.append(dataset)
54
 
55
  display_columns = [
56
  "chemical_formula_descriptive",
 
71
  n_elements = len(map_periodic_table)
72
 
73
  # Preprocessing step to create an index for the dataset
74
+ # df = pd.concat([x.to_pandas() for x in datasets])
75
+ dataset = concatenate_datasets(datasets)
76
  train_df = dataset.select_columns(["chemical_formula_descriptive"]).to_pandas()
77
 
78
  pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
 
376
  row["cartesian_site_positions"],
377
  coords_are_cartesian=True,
378
  )
379
+ if row["magnetic_moments"]:
380
+ structure.add_site_property("magmom", row["magnetic_moments"])
381
 
382
  sga = SpacegroupAnalyzer(structure)
383
 
 
388
  properties = {
389
  "Material ID": row["immutable_id"],
390
  "Formula": row["chemical_formula_descriptive"],
391
+ "Energy per atom (eV/atom)": round(
392
+ row["energy"] / len(row["species_at_sites"]), 3
393
+ ),
394
  # "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"], #future release
395
  "Total Magnetization (μB)": row["total_magnetization"],
396
  "Density (g/cm^3)": round(structure.density, 3),