phase_diagram / app.py
msiron's picture
remove alert style and add emoji warning
de6e3ae
raw
history blame
7.59 kB
import os
import gradio as gr
import numpy as np
import pandas as pd
import periodictable
import plotly.graph_objs as go
import polars as pl
from datasets import concatenate_datasets, load_dataset
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
from pymatgen.core import Composition, Element, Structure
from pymatgen.core.composition import Composition
from pymatgen.entries.computed_entries import (
ComputedStructureEntry,
GibbsComputedStructureEntry,
)
# Hugging Face access token from the environment (None when unset); it is
# forwarded to load_dataset for every subset download.
HF_TOKEN = os.environ.get("HF_TOKEN")

# LeMat1 subsets the app can query; map_functional maps UI functional names
# onto these subset names.
subsets = [
    "compatible_pbe",
    "compatible_pbesol",
    "compatible_scan",
]

# Columns fetched from every subset: what is needed to rebuild structures and
# computed entries, plus the per-structure element list used for filtering.
_LEMAT_COLUMNS = [
    "lattice_vectors",
    "species_at_sites",
    "cartesian_site_positions",
    "energy",
    "energy_corrected",
    "immutable_id",
    "elements",
    "functional",
]

# subset name -> pandas DataFrame built from that subset's "train" split.
subsets_ds = {
    subset_name: load_dataset(
        "LeMaterial/leMat1",
        subset_name,
        token=HF_TOKEN,
        columns=_LEMAT_COLUMNS,
    )["train"].to_pandas()
    for subset_name in subsets
}
# subset name -> the "elements" column (one list of symbols per structure).
elements_df = {k: subset_df["elements"] for k, subset_df in subsets_ds.items()}

# Element symbol (as rendered by ``periodictable``) -> column index.
all_elements = {str(el): i for i, el in enumerate(periodictable.elements)}

# One presence matrix per subset: elements_indices[subset][i, j] is True when
# the structure in row i of subsets_ds[subset] contains the element whose
# column index (from ``all_elements``) is j. A bool matrix takes 1 byte per
# cell instead of the 8 bytes a float64 matrix would.
elements_indices = {}
for subset, element_lists in elements_df.items():
    print("Processing subset: ", subset)
    presence = np.zeros((len(element_lists), len(all_elements)), dtype=bool)
    # Fill by position: matrix row i corresponds to dataframe row i
    # (assumes the to_pandas() frames keep their default RangeIndex).
    for row_position, symbols in enumerate(element_lists):
        for symbol in symbols:
            presence[row_position, all_elements[symbol]] = True
    elements_indices[subset] = presence
# Functional name shown in the UI dropdown -> LeMat1 subset that holds the
# calculations performed with that functional.
map_functional = dict(
    PBE="compatible_pbe",
    PBESol="compatible_pbesol",
    SCAN="compatible_scan",
)
def _annotated_figure(text):
    """Return an empty plotly figure whose only content is the annotation ``text``."""
    return go.Figure().add_annotation(text=text)


def _energy_correction(energy_correction, row):
    """Return the energy correction to attach to one dataset row.

    pymatgen adds ``correction`` to the raw ``energy`` of a ComputedEntry, so
    the value returned here is the *difference* between the corrected and the
    raw energy, not the corrected energy itself.

    Args:
        energy_correction: Scheme name as offered by the UI dropdown.
        row: One row of a LeMat dataframe, with ``energy`` and
            ``energy_corrected`` fields.

    Returns:
        For "Database specific, or MP2020": ``energy_corrected - energy``, or
        0.0 when the row has no corrected energy.
        For "The 110 PBE Method": ``0.1 * energy``, so that the entry's
        corrected energy is ``1.1 * energy``. (The previous code returned
        ``1.1 * energy`` as the correction, i.e. a total of 2.1x the energy.)

    Raises:
        ValueError: If ``energy_correction`` names an unknown scheme.
    """
    if energy_correction == "Database specific, or MP2020":
        # pd.isna copes with both NaN and None for missing values.
        if pd.isna(row["energy_corrected"]):
            return 0.0
        return row["energy_corrected"] - row["energy"]
    if energy_correction == "The 110 PBE Method":
        return 0.1 * row["energy"]
    raise ValueError(f"Unknown energy correction scheme: {energy_correction!r}")


def create_phase_diagram(
    elements,
    energy_correction,
    plot_style,
    functional,
    finite_temp,
    **kwargs,
):
    """Build a plotly phase diagram for a chemical system from LeMat data.

    Args:
        elements: Element symbols separated by "-", e.g. "Li-Fe-O". Symbols
            are case-normalised ("fe" -> "Fe"); empty tokens and duplicates
            are ignored.
        energy_correction: "Database specific, or MP2020" or
            "The 110 PBE Method".
        plot_style: "3D" requests a 3D ternary plot (only valid for exactly
            three elements); any other value produces the 2D plot.
        functional: "PBE", "PBESol" or "SCAN"; selects the dataset subset
            through ``map_functional``.
        finite_temp: When truthy, entries are converted with
            ``GibbsComputedStructureEntry.from_entries`` before building the
            diagram.
        **kwargs: Unused.

    Returns:
        The plotly figure produced by ``PDPlotter``, or an empty figure with an
        explanatory annotation when the inputs are invalid or no diagram can
        be built.
    """
    subset_name = map_functional.get(functional)
    if subset_name is None:
        return _annotated_figure("Please select a functional (PBE, PBESol or SCAN).")
    if energy_correction not in ("Database specific, or MP2020", "The 110 PBE Method"):
        return _annotated_figure("Please select an energy correction scheme.")

    # "li - Fe-O-" -> ["Li", "Fe", "O"]: normalise case, drop empty tokens and
    # duplicates while keeping the user's order.
    element_list = list(
        dict.fromkeys(
            token.strip().capitalize()
            for token in (elements or "").split("-")
            if token.strip()
        )
    )
    if not element_list:
        return _annotated_figure("Please enter elements separated by '-', e.g. 'Li-Fe-O'.")
    unknown = [el for el in element_list if el not in all_elements]
    if unknown:
        return _annotated_figure("Unknown element(s): " + ", ".join(unknown))
    # Check this before doing any expensive work.
    if plot_style == "3D" and len(element_list) != 3:
        return _annotated_figure("3D plots are only available for ternary systems.")

    # Boolean mask selecting the requested elements' columns.
    in_system = np.zeros(len(all_elements), dtype=bool)
    in_system[[all_elements[el] for el in element_list]] = True

    # A structure belongs to the system (or one of its sub-systems) when all
    # of its elements are requested ones: its in-system count equals its total.
    presence = elements_indices[subset_name]
    n_elements = presence.sum(axis=1)
    n_in_system = presence[:, in_system].sum(axis=1)
    row_positions = np.flatnonzero(n_in_system == n_elements)

    # Matrix rows are aligned with dataframe rows (assumes default RangeIndex).
    entries_df = subsets_ds[subset_name].iloc[row_positions]
    entries_df = entries_df[entries_df["immutable_id"].notna()]

    entries = [
        ComputedStructureEntry(
            Structure(
                [vector.tolist() for vector in row["lattice_vectors"].tolist()],
                row["species_at_sites"],
                row["cartesian_site_positions"],
                coords_are_cartesian=True,
            ),
            energy=row["energy"],
            correction=_energy_correction(energy_correction, row),
            entry_id=row["immutable_id"],
            parameters={"run_type": row["functional"]},
        )
        for _, row in entries_df.iterrows()
    ]
    # TODO: Fetch elemental entries (they are usually GGA calculations)
    if not entries:
        return _annotated_figure(
            f"No {functional} entries found for the {'-'.join(element_list)} system."
        )

    if finite_temp:
        entries = GibbsComputedStructureEntry.from_entries(entries)

    try:
        phase_diagram = PhaseDiagram(entries)
    except ValueError as e:
        # Surface PhaseDiagram's reason to the user instead of failing the request.
        print(e)
        return _annotated_figure(str(e))

    if plot_style == "3D":
        plotter = PDPlotter(
            phase_diagram, show_unstable=True, backend="plotly", ternary_style="3d"
        )
    else:
        plotter = PDPlotter(phase_diagram, show_unstable=True, backend="plotly")
    return plotter.get_plot()
# Define Gradio interface components.
# Free-text chemical system, e.g. "Li-Fe-O".
elements_input = gr.Textbox(
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
    value="Li-Fe-O",
)
# max_e_above_hull_slider = gr.Slider(
#     minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
# )

# Every dropdown gets an explicit default (its first choice) so the first
# submit always sends a valid selection; with no value set, some Gradio
# versions send None, which create_phase_diagram cannot map to a functional
# or a correction scheme.
energy_correction_dropdown = gr.Dropdown(
    choices=[
        "The 110 PBE Method",
        "Database specific, or MP2020",
    ],
    value="The 110 PBE Method",
    label="Energy correction",
)
plot_style_dropdown = gr.Dropdown(
    choices=["2D", "3D"],
    value="2D",
    label="Plot Style",
)
functional_dropdown = gr.Dropdown(
    choices=["PBE", "PBESol", "SCAN"],
    value="PBE",
    label="Functional",
)
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")
# HTML caveat about the energy-correction schemes, shown at the top of the
# interface description.
warning_message = (
    "⚠️ This application uses energy correction schemes directly"
    " from the data providers (Alexandria, MP) and has the 2020 MP"
    " Compatibility scheme applied to OQMD. However, because we did"
    " not directly apply the compatibility schemes to Alexandria and MP,"
    " we have noticed discrepancies in the data. While the correction"
    " scheme will be standardized in a soon-to-be-released update, for"
    " now please take caution when analyzing the results of this"
    " application."
    "<br> Additionally, we have provided the 110 PBE correction method"
    " from <a href='https://chemrxiv.org/engage/api-gateway/chemrxiv/assets/orp/resource/item/67252d617be152b1d0b2c1ef/original/a-simple-linear-relation-solves-unphysical-dft-energy-corrections.pdf' target='_blank'>Rohr et al. (2024)</a>.<br>"
)

# Full interface description: the warning (which already ends in <br>), a
# blank line, a one-line summary, then credits.
message = (
    f"{warning_message}<br>"
    "Generate a phase diagram for a set of elements using LeMat-Bulk data."
    "<br>Built with <a href='https://pymatgen.org/' target='_blank'>Pymatgen</a>"
    " and <a href='https://docs.crystaltoolkit.org/' target='_blank'>Crystal Toolkit</a>.<br>"
)
# Inputs for the app, listed in the same order as create_phase_diagram's
# parameters.
_interface_inputs = [
    elements_input,
    # max_e_above_hull_slider,
    energy_correction_dropdown,
    plot_style_dropdown,
    functional_dropdown,
    finite_temp_toggle,
]

# Single-page Gradio app rendering the returned figure in a Plot output.
iface = gr.Interface(
    fn=create_phase_diagram,
    inputs=_interface_inputs,
    outputs=gr.Plot(label="Phase Diagram"),
    title="LeMaterial - Phase Diagram Viewer",
    description=message,
)

# Launch the app.
iface.launch()