phase_diagram / app.py
Ramlaoui's picture
Use sparse matrix and don't store df
3de9b87
raw
history blame
10.4 kB
import os
import gradio as gr
import numpy as np
import pandas as pd
import periodictable
import plotly.graph_objs as go
import polars as pl
from datasets import concatenate_datasets, load_dataset
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
from pymatgen.core import Composition, Element, Structure
from pymatgen.core.composition import Composition
from pymatgen.entries.computed_entries import (
ComputedStructureEntry,
GibbsComputedStructureEntry,
)
# Hugging Face access token taken from the environment; None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Names of the LeMat-Bulk configurations that are loaded further down.
subsets = ["compatible_pbe", "compatible_pbesol", "compatible_scan"]
# polars_dfs = {
# subset: pl.read_parquet(
# "hf://datasets/LeMaterial/LeMat1/{}/train-*.parquet".format(subset),
# storage_options={
# "token": HF_TOKEN,
# },
# )
# for subset in subsets
# }
# Columns requested from the hub; these are the fields read elsewhere in
# this file (structure, energies, identifiers, element lists, functional).
_LOADED_COLUMNS = [
    "lattice_vectors",
    "species_at_sites",
    "cartesian_site_positions",
    "energy",
    "energy_corrected",
    "immutable_id",
    "elements",
    "functional",
]

# Keep only the "train" split of every configuration, keyed by its name.
subsets_ds = {}
for subset in subsets:
    subsets_ds[subset] = load_dataset(
        "LeMaterial/LeMat-Bulk",
        subset,
        token=HF_TOKEN,
        columns=_LOADED_COLUMNS,
    )["train"]

# One pandas frame per configuration holding just the "elements" column.
elements_df = {
    name: split.select_columns("elements").to_pandas()
    for name, split in subsets_ds.items()
}
from scipy.sparse import csr_matrix
# Column index of every element symbol known to ``periodictable``.
all_elements = {str(el): i for i, el in enumerate(periodictable.elements)}


def _build_element_matrix(element_lists, symbol_to_column):
    """Build a sparse 0/1 membership matrix from per-row element symbols.

    Args:
        element_lists: Sized iterable (e.g. a pandas Series) whose i-th item
            is the collection of element symbols of row ``i``.
        symbol_to_column: Mapping from element symbol to column index.

    Returns:
        A ``csr_matrix`` of shape ``(len(element_lists), len(symbol_to_column))``
        with ``1.0`` at ``[i, j]`` when row ``i`` contains the element mapped
        to column ``j``.

    Raises:
        KeyError: If a symbol is missing from ``symbol_to_column``.
    """
    row_ids = []
    col_ids = []
    for row, symbols in enumerate(element_lists):
        # A set keeps repeated symbols from being summed into values > 1
        # when the coordinate lists are converted to CSR.
        columns = {symbol_to_column[symbol] for symbol in symbols}
        row_ids.extend([row] * len(columns))
        col_ids.extend(columns)
    return csr_matrix(
        (
            np.ones(len(row_ids), dtype=np.float64),
            (
                np.asarray(row_ids, dtype=np.int64),
                np.asarray(col_ids, dtype=np.int64),
            ),
        ),
        shape=(len(element_lists), len(symbol_to_column)),
    )


# Sparse element-membership matrix for every subset. The matrix is assembled
# directly in sparse form, so no dense (rows x elements) array is allocated,
# and no debugger breakpoint interrupts start-up.
elements_indices = {}
for subset, df in elements_df.items():
    print("Processing subset: ", subset)
    elements_indices[subset] = _build_element_matrix(df["elements"], all_elements)
# Maps the label shown in the "Functional" dropdown to the name of the
# dataset subset that is queried for it.
map_functional = {
    "PBE": "compatible_pbe",
    "PBESol (No correction scheme)": "compatible_pbesol",
    "SCAN (No correction scheme)": "compatible_scan",
}


def _message_figure(text):
    """Return an empty plotly figure carrying ``text`` as its only annotation."""
    return go.Figure().add_annotation(text=text)


def create_phase_diagram(
    elements,
    energy_correction,
    plot_style,
    functional,
    finite_temp,
    **kwargs,
):
    """Build a plotly phase diagram for a chemical system.

    Args:
        elements: Element symbols separated by ``-`` (e.g. ``"Li-Fe-O"``).
        energy_correction: Selected correction scheme label. Corrections are
            only applied when ``functional`` is ``"PBE"``; any other
            combination uses a zero correction.
        plot_style: ``"2D"`` for the default plot; any other value requests
            the 3D ternary plot.
        functional: A key of ``map_functional``.
        finite_temp: When truthy, entries are converted with
            ``GibbsComputedStructureEntry.from_entries`` before the phase
            diagram is built.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        A plotly figure. Invalid input, an empty query result, or a
        ``ValueError`` raised while building the phase diagram yields an
        otherwise empty figure annotated with an explanatory message.
    """
    # Split elements and remove any whitespace
    element_list = [el.strip() for el in (elements or "").split("-")]

    # Report bad input through an annotated figure, the same way phase
    # diagram construction errors are reported below, instead of a KeyError.
    unknown = [el for el in element_list if el not in all_elements]
    if unknown:
        return _message_figure(
            "Unknown element symbol(s): " + ", ".join(repr(el) for el in unknown)
        )
    subset_name = map_functional.get(functional)
    if subset_name is None:
        return _message_figure(f"Unknown functional: {functional!r}")

    # For 3D plots, limit to ternary systems. Checked before touching the
    # dataset so an unsupported request does no expensive work.
    if plot_style != "2D" and len(element_list) != 3:
        return _message_figure("3D plots are only available for ternary systems.")

    # Keep the rows whose elements all belong to the requested system, i.e.
    # rows that contain zero elements outside of it.
    allowed = np.zeros(len(all_elements), dtype=bool)
    allowed[[all_elements[el] for el in element_list]] = True
    membership = elements_indices[subset_name]
    n_foreign = np.asarray(
        membership @ (~allowed).astype(membership.dtype)
    ).ravel()
    indices_with_only_elements = np.flatnonzero(n_foreign == 0)

    entries_df = subsets_ds[subset_name].select(indices_with_only_elements).to_pandas()
    entries_df = entries_df[entries_df["immutable_id"].notna()]
    if entries_df.empty:
        return _message_figure("No entries found for this chemical system.")

    # Resolve the correction scheme once per request rather than per row.
    is_pbe = functional == "PBE"
    use_database_correction = (
        is_pbe and energy_correction == "Database specific, or MP2020"
    )
    use_pbe110_correction = is_pbe and energy_correction == "The 110 PBE Method"
    if use_database_correction:
        print("applying MP corrections")
    elif use_pbe110_correction:
        print("applying PBE110 corrections")
    else:
        print("not applying any corrections")

    def get_energy_correction(row):
        """Return the energy correction for one dataframe row."""
        if use_database_correction:
            corrected = row["energy_corrected"]
            # pd.isna also covers None, which np.isnan rejects.
            return 0 if pd.isna(corrected) else corrected - row["energy"]
        if use_pbe110_correction:
            return row["energy"] * 1.1 - row["energy"]
        # Every other combination falls back to no correction.
        return 0

    entries = [
        ComputedStructureEntry(
            Structure(
                [x.tolist() for x in row["lattice_vectors"].tolist()],
                row["species_at_sites"],
                row["cartesian_site_positions"],
                coords_are_cartesian=True,
            ),
            energy=row["energy"],
            correction=get_energy_correction(row),
            entry_id=row["immutable_id"],
            parameters={"run_type": row["functional"]},
        )
        for _, row in entries_df.iterrows()
    ]
    # TODO: Fetch elemental entries (they are usually GGA calculations)
    # entries.extend([e for e in entries if e.composition.is_element])

    if finite_temp:
        entries = GibbsComputedStructureEntry.from_entries(entries)

    # Build the phase diagram
    try:
        phase_diagram = PhaseDiagram(entries)
    except ValueError as e:
        print(e)
        return _message_figure(str(e))

    # Generate plotly figure
    if plot_style == "2D":
        plotter = PDPlotter(phase_diagram, show_unstable=True, backend="plotly")
    else:
        plotter = PDPlotter(
            phase_diagram, show_unstable=True, backend="plotly", ternary_style="3d"
        )
    # NOTE: no maximum energy-above-hull filtering is applied here.
    return plotter.get_plot()
# Input widgets of the Gradio form, in the order the callback expects them.
elements_input = gr.Textbox(
    value="Li-Fe-O",
    placeholder="Enter elements separated by '-'",
    label="Elements (e.g., 'Li-Fe-O')",
)
# max_e_above_hull_slider = gr.Slider(
# minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
# )
energy_correction_dropdown = gr.Dropdown(
    label="Energy correction",
    choices=["The 110 PBE Method", "Database specific, or MP2020"],
)
plot_style_dropdown = gr.Dropdown(label="Plot Style", choices=["2D", "3D"])
# The functional labels are exactly the keys of map_functional, in order.
functional_dropdown = gr.Dropdown(
    label="Functional",
    choices=list(map_functional),
)
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")
# Caution banner shown at the top of the page description. Built from
# adjacent string literals so the text is assembled once.
warning_message = (
    "⚠️ This application uses energy correction schemes directly"
    " from the data providers (Alexandria, MP) and has the 2020 MP"
    " Compatibility scheme applied to OQMD. However, because we did"
    " not directly apply the compatibility schemes to Alexandria, MP"
    " we have noticed discrepancies in the data. While the correction"
    " scheme will be standardized in a soon to be released update, for"
    " now please take caution when analyzing the results of this"
    " application."
    "<br> Additionally, we have provided the 110 PBE correction method"
    " from <a href='https://chemrxiv.org/engage/api-gateway/chemrxiv/assets/orp/resource/item/67252d617be152b1d0b2c1ef/original/a-simple-linear-relation-solves-unphysical-dft-energy-corrections.pdf' target='_blank'>Rohr et al (2024)</a>."
)

# Page description: the banner followed by a one-line summary of the app.
message = (
    f"{warning_message} <br><br> Generate a phase diagram for a set of elements"
    " using LeMat-Bulk data."
)
# Append an HTML paragraph to the description that credits the tools named
# in it and states that the app is not affiliated with the Materials Project.
message += """
<div style="font-size: 14px; line-height: 1.6; padding: 20px 0">
<p>
This web app is powered by
<a href="https://github.com/materialsproject/crystaltoolkit" target="_blank" style="text-decoration: none;">Crystal Toolkit</a>,
<a href="https://github.com/materialsproject/dash-mp-components" target="_blank" style="text-decoration: none;">MP Dash Components</a>,
and
<a href="https://pymatgen.org/" target="_blank" style="text-decoration: none;">Pymatgen</a>.
All tools are developed by the
<a href="https://next-gen.materialsproject.org/" target="_blank" style="text-decoration: none;">Materials Project</a>.
We are grateful for their open-source software packages. This app is intended for data exploration in LeMat-Bulk and is not affiliated with or endorsed by the Materials Project.
</p>
</div>
"""
# HTML footer with citation guidance keyed on the immutable_id prefix
# (mp-, agm-, oqmd-); passed to gr.Interface as `article`.
footer_content = """
<div style="font-size: 14px; line-height: 1.6; padding: 20px 0; text-align: center;">
<hr style="border-top: 1px solid #ddd; margin: 10px 0;">
<p>
<strong>CC-BY-4.0</strong> requires proper acknowledgement. If you use materials data with an immutable_id starting with <code>mp-</code>, please cite the
<a href="https://pubs.aip.org/aip/apm/article/1/1/011002/119685/Commentary-The-Materials-Project-A-materials" target="_blank" style="text-decoration: none;">Materials Project</a>.
If you use materials data with an immutable_id starting with <code>agm-</code>, please cite
<a href="https://www.science.org/doi/10.1126/sciadv.abi7948" target="_blank" style="text-decoration: none;">Alexandria, PBE</a>
or
<a href="https://hdl.handle.net/10.1038/s41597-022-01177-w" target="_blank" style="text-decoration: none;">Alexandria PBESol, SCAN</a>.
If you use materials data with an immutable_id starting with <code>oqmd-</code>, please cite
<a href="https://link.springer.com/article/10.1007/s11837-013-0755-4" target="_blank" style="text-decoration: none;">OQMD</a>.
</p>
<p>
If you use the Phase Diagram or Crystal Viewer, please acknowledge
<a href="https://github.com/materialsproject/crystaltoolkit" target="_blank" style="text-decoration: none;">Crystal Toolkit</a>.
</p>
</div>
"""
# Wire the form widgets and the plot output to create_phase_diagram.
# The order of `inputs` matches the callback's positional parameters.
iface = gr.Interface(
    title="MP Phase Diagram Viewer for LeMat-Bulk",
    description=message,
    article=footer_content,
    css=".plot-out {background-color: #ffffff;}",
    fn=create_phase_diagram,
    inputs=[
        elements_input,
        # max_e_above_hull_slider,
        energy_correction_dropdown,
        plot_style_dropdown,
        functional_dropdown,
        finite_temp_toggle,
    ],
    outputs=gr.Plot(label="Phase Diagram", elem_classes="plot-out"),
)

# Start serving the app.
iface.launch()