phase_diagram / app.py
msiron's picture
add unused features
bddaa7b
raw
history blame
7.41 kB
import os
import polars as pl
import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objs as go
from datasets import concatenate_datasets, load_dataset
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
from pymatgen.core import Composition, Structure, Element
from pymatgen.core.composition import Composition
from pymatgen.entries.computed_entries import (
ComputedStructureEntry,
GibbsComputedStructureEntry,
)
# Hugging Face access token taken from the environment; None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Names of the dataset configurations loaded at start-up.
subsets = ["compatible_pbe", "compatible_pbesol", "compatible_scan"]
# polars_dfs = {
# subset: pl.read_parquet(
# "hf://datasets/LeMaterial/LeMat1/{}/train-*.parquet".format(subset),
# storage_options={
# "token": HF_TOKEN,
# },
# )
# for subset in subsets
# }
# # Load only the train split of the dataset
# Column names handed to load_dataset through its `columns=` argument.
_DATASET_COLUMNS = (
    "lattice_vectors",
    "species_at_sites",
    "cartesian_site_positions",
    "energy",
    "energy_corrected",
    "immutable_id",
    "elements",
    "functional",
)

# Maps each configuration name in `subsets` to the "train" split of the
# dataset loaded for it.
subsets_ds = {}
for subset in subsets:
    dataset = load_dataset(
        "LeMaterial/leMat1",
        subset,
        token=HF_TOKEN,
        # A fresh list per call, so every call receives its own object.
        columns=list(_DATASET_COLUMNS),
    )
    subsets_ds[subset] = dataset["train"]
# Convert the train split to a pandas DataFrame
# df = pd.concat([x.to_pandas() for x in datasets])
# train_df = dataset.to_pandas()
# del dataset
# dataset_element_combination_dict = {}
# isubset = lambda x: set(x).issubset(element_list)
# isintersection = lambda x: len(set(x).intersection(element_list)) > 0
# for element_1 in Element:
# for element_2 in Element:
# for element_3 in Element:
# if element_1 != element_2 and element_2 != element_3 and element_3 != element_1:
# print("processing {},{},{}".format(*element_list))
# element_list = [element_1.name, element_2.name, element_3.name]
# dataset_element_combination_dict(sorted(tuple(element_list))) = dataset.filter(
# lambda example: isintersection(example["elements"])
# and isubset(example["elements"])
# )
def create_phase_diagram(
    elements,
    plot_style,
    functional,
    finite_temp,
    **kwargs,
):
    """Build a plotly phase diagram for the requested chemical system.

    Args:
        elements: Element symbols separated by ``-`` (e.g. ``"Li-Fe-O"``).
            Surrounding whitespace, empty tokens and duplicates are ignored.
        plot_style: ``"2D"`` for the default plot; any other value requests
            the 3D ternary style, which needs exactly three elements.
        functional: ``"PBE"``, ``"PBESol"`` or ``"SCAN"``; selects which
            preloaded subset in ``subsets_ds`` is used.
        finite_temp: When truthy, the entries are converted with
            ``GibbsComputedStructureEntry.from_entries`` before the phase
            diagram is built.
        **kwargs: Accepted for backward compatibility; not used.

    Returns:
        The figure produced by ``PDPlotter.get_plot()``. When the request
        cannot be fulfilled (bad input, no matching entries, or a
        ``ValueError`` from ``PhaseDiagram``), an empty plotly figure carrying
        an annotation that explains the problem.
    """

    def _message_figure(text):
        # All user-facing failures are reported the same way: an empty figure
        # with the explanation attached as an annotation.
        return go.Figure().add_annotation(text=text)

    subset_by_functional = {
        "PBE": "compatible_pbe",
        "PBESol": "compatible_pbesol",
        "SCAN": "compatible_scan",
    }

    # Split elements, strip whitespace, drop empty tokens ("Li-") and
    # duplicates ("Li-Li-O") while preserving the order the user typed.
    tokens = (el.strip() for el in (elements or "").split("-"))
    element_list = list(dict.fromkeys(el for el in tokens if el))
    if not element_list:
        return _message_figure(
            "Please enter at least one element, e.g. 'Li-Fe-O'."
        )

    # Check the cheap 3D precondition before touching the dataset.
    wants_3d = plot_style != "2D"
    if wants_3d and len(element_list) != 3:
        return _message_figure(
            "3D plots are only available for ternary systems."
        )

    # An unknown or missing functional (e.g. nothing selected in the UI) is
    # reported instead of failing later on an unbound variable.
    subset_name = subset_by_functional.get(functional)
    if subset_name is None:
        return _message_figure(
            "Please select a functional (PBE, PBESol or SCAN)."
        )

    entries_df = subsets_ds[subset_name].to_pandas()
    entries_df = entries_df[entries_df["immutable_id"].notna()]

    # Keep rows whose element set is non-empty and fully contained in the
    # requested chemical system.
    allowed = set(element_list)
    in_system = [
        len(row_elements) > 0 and allowed.issuperset(row_elements)
        for row_elements in entries_df["elements"]
    ]
    entries_df = entries_df[in_system]
    if entries_df.empty:
        return _message_figure(
            "No {} entries found for the chemical system {}.".format(
                functional, "-".join(element_list)
            )
        )

    # Turn every remaining row into a pymatgen entry.
    entries = []
    for _, row in entries_df.iterrows():
        energy = row["energy"]
        corrected = row["energy_corrected"]
        # pd.isna covers both NaN and None; a missing corrected energy means
        # no correction is applied.
        correction = 0 if pd.isna(corrected) else corrected - energy
        structure = Structure(
            # Each lattice vector is converted to a plain list; this assumes
            # the column holds array-like rows that expose .tolist().
            [vector.tolist() for vector in row["lattice_vectors"].tolist()],
            row["species_at_sites"],
            row["cartesian_site_positions"],
            coords_are_cartesian=True,
        )
        entries.append(
            ComputedStructureEntry(
                structure,
                energy=energy,
                correction=correction,
                entry_id=row["immutable_id"],
                parameters={"run_type": row["functional"]},
            )
        )

    # TODO: Fetch elemental entries (they are usually GGA calculations)

    if finite_temp:
        entries = GibbsComputedStructureEntry.from_entries(entries)

    # Build the phase diagram; construction errors are shown to the user.
    try:
        phase_diagram = PhaseDiagram(entries)
    except ValueError as e:
        return _message_figure(str(e))

    # Generate the plotly figure; only the 3D request adds ternary_style.
    plotter_kwargs = {"show_unstable": True, "backend": "plotly"}
    if wants_3d:
        plotter_kwargs["ternary_style"] = "3d"
    return PDPlotter(phase_diagram, **plotter_kwargs).get_plot()
# Define Gradio interface components
elements_input = gr.Textbox(
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
    value="Li-Fe-O",
)
# max_e_above_hull_slider = gr.Slider(
# minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
# )
# color_scheme_dropdown = gr.Dropdown(
# choices=["Energy Above Hull", "Formation Energy"], label="Color Scheme"
# )

# Both dropdowns get an explicit initial value so the callback never receives
# None when the user submits without touching them.
plot_style_dropdown = gr.Dropdown(
    choices=["2D", "3D"],
    value="2D",
    label="Plot Style",
)
functional_dropdown = gr.Dropdown(
    choices=["PBE", "PBESol", "SCAN"],
    value="PBE",
    label="Functional",
)
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")
# Caveat shown to users about the mixed provenance of the energy corrections.
warning_message = (
    "This application uses energy correction schemes directly"
    " from the data providers (Alexandria, MP) and has the 2020 MP"
    " Compatibility scheme applied to OQMD. However, because we did"
    " not directly apply the compatibility schemes to Alexandria, MP"
    " we have noticed discrepancies in the data. While the correction"
    " scheme will be standardized in a soon to be released update, for"
    " now please take caution when analyzing the results of this"
    " application."
)

# HTML description: a dismissible alert box followed by a one-line summary.
# The JavaScript inside onclick uses single quotes so that it does not
# terminate the surrounding double-quoted HTML attribute.
_MESSAGE_TEMPLATE = (
    '<div class="alert">'
    '<span class="closebtn"'
    " onclick=\"this.parentElement.style.display='none';\">&times;</span>"
    "{}</div>"
    "Generate a phase diagram for a set of elements using LeMat-Bulk data."
)
message = _MESSAGE_TEMPLATE.format(warning_message)
# Assemble the Gradio interface: four inputs feed create_phase_diagram and the
# returned figure is rendered in a single plot component.
_phase_diagram_output = gr.Plot(label="Phase Diagram")
iface = gr.Interface(
    title="LeMaterial - Phase Diagram Viewer",
    description=message,
    fn=create_phase_diagram,
    inputs=[
        elements_input,
        # max_e_above_hull_slider,
        # color_scheme_dropdown,
        plot_style_dropdown,
        functional_dropdown,
        finite_temp_toggle,
    ],
    outputs=_phase_diagram_output,
)

# Start serving the app.
iface.launch()