Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 7,432 Bytes
1a9eacc 7f3ef59 1a9eacc baa7f8d 33df548 8689fa0 baa7f8d fbd13ac baa7f8d 7f3ef59 baa7f8d 2af7535 70b543a a0bb336 193b388 fbd13ac 9ec01d1 7f3ef59 9ec01d1 2200e20 9ec01d1 70b543a 193b388 1a9eacc 7f3ef59 aba1804 baa7f8d 2af7535 baa7f8d 1a9eacc baa7f8d 1a9eacc 33df548 9ec01d1 193b388 33df548 9ec01d1 193b388 33df548 9ec01d1 193b388 33df548 7f3ef59 9ec01d1 7f3ef59 9ec01d1 193b388 9ec01d1 33df548 70b543a baa7f8d 0641fac 33df548 0641fac baa7f8d 8689fa0 baa7f8d 137bc98 baa7f8d 33df548 baa7f8d b88325e 1a9eacc 0641fac 1a9eacc baa7f8d 1a9eacc baa7f8d 1a9eacc baa7f8d 1a9eacc 137bc98 1a9eacc 2af7535 5b4c478 fbd13ac 8a53814 1a9eacc baa7f8d 1a9eacc c61d5b5 4610d52 bd292ac 1a9eacc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
import os
import polars as pl
import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objs as go
from datasets import concatenate_datasets, load_dataset
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
from pymatgen.core import Composition, Structure, Element
from pymatgen.core.composition import Composition
from pymatgen.entries.computed_entries import (
ComputedStructureEntry,
GibbsComputedStructureEntry,
)
HF_TOKEN = os.environ.get("HF_TOKEN")

# Functional-compatible LeMat1 subsets; these names are the keys of `subsets_ds`.
subsets = [
    "compatible_pbe",
    "compatible_pbesol",
    "compatible_scan",
]

# Only the columns create_phase_diagram reads to rebuild pymatgen entries.
_LEMAT1_COLUMNS = [
    "lattice_vectors",
    "species_at_sites",
    "cartesian_site_positions",
    "energy",
    "energy_corrected",
    "immutable_id",
    "elements",
    "functional",
]

# Train split of every subset, keyed by subset name.
subsets_ds = {
    subset_name: load_dataset(
        "LeMaterial/leMat1",
        subset_name,
        token=HF_TOKEN,
        columns=_LEMAT1_COLUMNS,
    )["train"]
    for subset_name in subsets
}
def _select_entries_df(subset, element_set):
    """Return the rows of ``subsets_ds[subset]`` that lie inside a chemical system.

    A row is kept when it has an ``immutable_id`` and its non-empty
    ``elements`` list is a subset of ``element_set``.
    """
    # NOTE(review): this converts the whole subset to pandas on every call,
    # which is expensive for large subsets; consider caching or filtering
    # before the conversion.
    entries_df = subsets_ds[subset].to_pandas()
    entries_df = entries_df[~entries_df["immutable_id"].isna()]
    # Boolean numpy mask (not a plain list) so an empty frame still indexes rows.
    in_system = np.fromiter(
        (
            len(row_elements) > 0 and set(row_elements) <= element_set
            for row_elements in entries_df["elements"].tolist()
        ),
        dtype=bool,
        count=len(entries_df),
    )
    return entries_df[in_system]


def _row_to_entry(row):
    """Rebuild a pymatgen ``ComputedStructureEntry`` from one dataframe row."""
    structure = Structure(
        [vector.tolist() for vector in row["lattice_vectors"].tolist()],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    # The dataset stores the corrected total energy, while the entry takes the
    # correction as a delta on top of the raw energy. A missing corrected
    # value (NaN/None) means "no correction".
    energy_corrected = row["energy_corrected"]
    correction = 0 if pd.isna(energy_corrected) else energy_corrected - row["energy"]
    return ComputedStructureEntry(
        structure,
        energy=row["energy"],
        correction=correction,
        entry_id=row["immutable_id"],
        parameters={"run_type": row["functional"]},
    )


def create_phase_diagram(
    elements,
    max_e_above_hull,
    color_scheme,
    plot_style,
    functional,
    finite_temp,
    **kwargs,
):
    """Build a plotly phase diagram for a chemical system from LeMat1 data.

    Args:
        elements: Element symbols separated by ``-`` (e.g. ``"Li-Fe-O"``).
            Surrounding whitespace and empty parts are ignored.
        max_e_above_hull: Energy-above-hull threshold (eV/atom) for plotting
            unstable phases. It is forwarded to ``PDPlotter`` as
            ``show_unstable``; ``None`` falls back to the previous ``True``.
        color_scheme: Currently ignored; kept for the Gradio interface.
        plot_style: ``"3D"`` requests a 3D plot, which is only available for
            ternary systems. Any other value, including ``None``, gives the 2D plot.
        functional: ``"PBE"``, ``"PBESol"`` or ``"SCAN"``; selects the
            dataset subset.
        finite_temp: If truthy, entries are converted with
            ``GibbsComputedStructureEntry.from_entries`` first.
        **kwargs: Ignored.

    Returns:
        A plotly ``go.Figure``. This is either the phase diagram or an empty
        figure whose annotation explains why no diagram was built.
    """
    subset_by_functional = {
        "PBE": "compatible_pbe",
        "PBESol": "compatible_pbesol",
        "SCAN": "compatible_scan",
    }
    subset = subset_by_functional.get(functional)
    if subset is None:
        # An unset dropdown passes None; this used to hit an unbound local.
        return go.Figure().add_annotation(
            text="Please select a functional (PBE, PBESol or SCAN)."
        )

    # Split on "-" and drop whitespace-only / empty parts (e.g. "Li--O").
    element_list = [el.strip() for el in (elements or "").split("-") if el.strip()]
    element_set = set(element_list)
    if not element_set:
        return go.Figure().add_annotation(text="Please enter at least one element.")

    want_3d = plot_style == "3D"
    if want_3d and len(element_set) != 3:
        return go.Figure().add_annotation(
            text="3D plots are only available for ternary systems."
        )

    entries_df = _select_entries_df(subset, element_set)
    entries = [_row_to_entry(row) for _, row in entries_df.iterrows()]
    if not entries:
        return go.Figure().add_annotation(
            text="No entries found for the {} chemical system.".format(
                "-".join(element_list)
            )
        )

    # TODO: Fetch elemental entries (they are usually GGA calculations)
    try:
        if finite_temp:
            # Presumably builds a PhaseDiagram internally, so it may raise the
            # same ValueError as below. NOTE(review): confirm whether it keeps
            # elemental reference entries; if not, PhaseDiagram will report
            # missing terminal entries.
            entries = GibbsComputedStructureEntry.from_entries(entries)
        phase_diagram = PhaseDiagram(entries)
    except ValueError as e:
        return go.Figure().add_annotation(text=str(e))

    plotter_kwargs = {"ternary_style": "3d"} if want_3d else {}
    # show_unstable takes an e_above_hull cutoff in eV/atom; this is what the
    # "Maximum Energy Above Hull" slider controls.
    show_unstable = True if max_e_above_hull is None else max_e_above_hull
    plotter = PDPlotter(
        phase_diagram,
        show_unstable=show_unstable,
        backend="plotly",
        **plotter_kwargs,
    )
    return plotter.get_plot()
# Define Gradio interface components. Every dropdown gets a default value so
# that create_phase_diagram never receives None for an untouched control.
elements_input = gr.Textbox(
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
    value="Li-Fe-O",
)
# Energy above hull is a per-atom quantity, hence eV/atom.
max_e_above_hull_slider = gr.Slider(
    minimum=0,
    maximum=1,
    value=0.1,
    label="Maximum Energy Above Hull (eV/atom)",
)
color_scheme_dropdown = gr.Dropdown(
    choices=["Energy Above Hull", "Formation Energy"],
    value="Energy Above Hull",
    label="Color Scheme",
)
plot_style_dropdown = gr.Dropdown(
    choices=["2D", "3D"],
    value="2D",
    label="Plot Style",
)
functional_dropdown = gr.Dropdown(
    choices=["PBE", "PBESol", "SCAN"],
    value="PBE",
    label="Functional",
)
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")
# Disclaimer about the energy-correction schemes, shown above the app.
warning_message = (
    "This application uses energy correction schemes directly"
    " from the data providers (Alexandria, MP) and has the 2020 MP"
    " Compatibility scheme applied to OQMD. However, because we did"
    " not directly apply the compatibility schemes to Alexandria and MP,"
    " we have noticed discrepancies in the data. While the correction"
    " scheme will be standardized in a soon-to-be-released update, for"
    " now please exercise caution when analyzing the results of this"
    " application."
)
# Dismissible alert followed by the app description. The onclick handler
# quotes 'none' with single quotes so it does not terminate the
# double-quoted HTML attribute early (which made the close button inert).
message = (
    '<div class="alert"><span class="closebtn" '
    "onclick=\"this.parentElement.style.display='none';\">×</span>"
    "{}</div>Generate a phase diagram for a set of elements using LeMat-Bulk data."
).format(warning_message)
# Assemble the Gradio app. The input order must match the positional
# parameters of create_phase_diagram.
_PHASE_DIAGRAM_INPUTS = [
    elements_input,
    max_e_above_hull_slider,
    color_scheme_dropdown,
    plot_style_dropdown,
    functional_dropdown,
    finite_temp_toggle,
]
iface = gr.Interface(
    fn=create_phase_diagram,
    inputs=_PHASE_DIAGRAM_INPUTS,
    outputs=gr.Plot(label="Phase Diagram"),
    title="LeMaterial - Phase Diagram Viewer",
    description=message,
)

# Start serving the app.
iface.launch()
|