Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import os | |
import polars as pl | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import plotly.graph_objs as go | |
from datasets import concatenate_datasets, load_dataset | |
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram | |
from pymatgen.core import Composition, Structure, Element | |
from pymatgen.core.composition import Composition | |
from pymatgen.entries.computed_entries import ( | |
ComputedStructureEntry, | |
GibbsComputedStructureEntry, | |
) | |
# Hugging Face Hub access token, taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Dataset configurations that get loaded at start-up.
subsets = [
    "compatible_pbe",
    "compatible_pbesol",
    "compatible_scan",
]

# NOTE: an earlier, now disabled, variant read the parquet shards directly
# with polars from "hf://datasets/LeMaterial/LeMat1/<subset>/train-*.parquet"
# (passing the token through storage_options). Loading goes through
# datasets.load_dataset instead.

# Only these columns are requested from the dataset.
_DATASET_COLUMNS = (
    "lattice_vectors",
    "species_at_sites",
    "cartesian_site_positions",
    "energy",
    "energy_corrected",
    "immutable_id",
    "elements",
    "functional",
)

# Maps each subset name to the "train" split of that subset.
subsets_ds = {}
for subset in subsets:
    dataset = load_dataset(
        "LeMaterial/leMat1",
        subset,
        token=HF_TOKEN,
        columns=list(_DATASET_COLUMNS),
    )
    subsets_ds[subset] = dataset["train"]
# Convert the train split to a pandas DataFrame | |
# df = pd.concat([x.to_pandas() for x in datasets]) | |
# train_df = dataset.to_pandas() | |
# del dataset | |
# dataset_element_combination_dict = {} | |
# isubset = lambda x: set(x).issubset(element_list) | |
# isintersection = lambda x: len(set(x).intersection(element_list)) > 0 | |
# for element_1 in Element: | |
# for element_2 in Element: | |
# for element_3 in Element: | |
# if element_1 != element_2 and element_2 != element_3 and element_3 != element_1: | |
# print("processing {},{},{}".format(*element_list)) | |
# element_list = [element_1.name, element_2.name, element_3.name] | |
# dataset_element_combination_dict(sorted(tuple(element_list))) = dataset.filter( | |
# lambda example: isintersection(example["elements"]) | |
# and isubset(example["elements"]) | |
# ) | |
# Maps the functional label received from the UI to the key of the dataset
# subset (in ``subsets_ds``) that holds the matching calculations.
_FUNCTIONAL_TO_SUBSET = {
    "PBE": "compatible_pbe",
    "PBESol": "compatible_pbesol",
    "SCAN": "compatible_scan",
}


def _message_figure(text):
    """Return an empty plotly figure that carries ``text`` as an annotation."""
    return go.Figure().add_annotation(text=text)


def _row_to_entry(row):
    """Build a ``ComputedStructureEntry`` from one row of the entries frame."""
    structure = Structure(
        [vector.tolist() for vector in row["lattice_vectors"].tolist()],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    energy = row["energy"]
    corrected = row["energy_corrected"]
    # pd.isna treats both NaN and None as missing; np.isnan raises on None.
    correction = 0 if pd.isna(corrected) else corrected - energy
    return ComputedStructureEntry(
        structure,
        energy=energy,
        correction=correction,
        entry_id=row["immutable_id"],
        parameters={"run_type": row["functional"]},
    )


def create_phase_diagram(
    elements,
    max_e_above_hull,
    color_scheme,
    plot_style,
    functional,
    finite_temp,
    **kwargs,
):
    """Build a plotly phase diagram for a chemical system.

    Args:
        elements: Dash-separated element symbols, e.g. ``"Li-Fe-O"``.
            Surrounding whitespace and empty tokens are ignored.
        max_e_above_hull: Threshold handed to ``PDPlotter`` as
            ``show_unstable``; phases above the hull by less than this value
            are drawn. ``None`` falls back to ``show_unstable=True``.
        color_scheme: Currently unused.
        plot_style: ``"2D"`` for the default plot; any other value requests
            the 3D ternary style, which requires exactly three elements.
        functional: One of ``"PBE"``, ``"PBESol"`` or ``"SCAN"``; selects the
            dataset subset to read entries from.
        finite_temp: When truthy, entries are converted with
            ``GibbsComputedStructureEntry.from_entries`` before the phase
            diagram is built.
        **kwargs: Accepted and ignored.

    Returns:
        A plotly figure. Invalid input, an empty selection, or a
        ``ValueError`` raised by ``PhaseDiagram`` yields an otherwise empty
        figure with an explanatory annotation instead of raising.
    """
    # Split elements, strip whitespace and drop empty tokens ("Li-Fe-").
    element_list = [el.strip() for el in (elements or "").split("-")]
    element_list = [el for el in element_list if el]
    if not element_list:
        return _message_figure("Please enter at least one element.")

    # Validate the cheap conditions first so the (large) dataset is only
    # converted when a plot can actually be produced.
    wants_3d = plot_style != "2D"
    if wants_3d and len(element_list) != 3:
        return _message_figure(
            "3D plots are only available for ternary systems."
        )

    subset_name = _FUNCTIONAL_TO_SUBSET.get(functional)
    if subset_name is None:
        return _message_figure(
            "Unsupported functional: {!r}. Choose one of {}.".format(
                functional, ", ".join(_FUNCTIONAL_TO_SUBSET)
            )
        )

    # NOTE: this converts the whole subset to pandas on every call.
    entries_df = subsets_ds[subset_name].to_pandas()
    entries_df = entries_df[~entries_df["immutable_id"].isna()]

    # Keep rows whose element set is non-empty and contained in the request.
    allowed = set(element_list)
    keep = np.asarray(
        [
            len(row_elements) > 0 and allowed.issuperset(row_elements)
            for row_elements in entries_df["elements"].values.tolist()
        ],
        dtype=bool,
    )
    entries_df = entries_df[keep]
    if entries_df.empty:
        return _message_figure(
            "No entries found for {} with the {} functional.".format(
                "-".join(element_list), functional
            )
        )

    # Turn every remaining dataset row into a pymatgen entry.
    entries = [_row_to_entry(row) for _, row in entries_df.iterrows()]
    # TODO: Fetch elemental entries (they are usually GGA calculations)
    if finite_temp:
        entries = GibbsComputedStructureEntry.from_entries(entries)

    # Build the phase diagram
    try:
        phase_diagram = PhaseDiagram(entries)
    except ValueError as e:
        return _message_figure(str(e))

    # The original call passed show_unstable=True; keep that when no
    # threshold is supplied, otherwise honour the requested limit.
    show_unstable = True if max_e_above_hull is None else float(max_e_above_hull)
    plotter_kwargs = {"show_unstable": show_unstable, "backend": "plotly"}
    if wants_3d:
        plotter_kwargs["ternary_style"] = "3d"

    # Generate and return the plotly figure
    return PDPlotter(phase_diagram, **plotter_kwargs).get_plot()
# Define Gradio interface components.
# Every dropdown gets an explicit default so the callback never receives
# None for a choice the user has not touched.
elements_input = gr.Textbox(
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
    value="Li-Fe-O",
)
max_e_above_hull_slider = gr.Slider(
    minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
)
color_scheme_dropdown = gr.Dropdown(
    choices=["Energy Above Hull", "Formation Energy"],
    value="Energy Above Hull",
    label="Color Scheme",
)
plot_style_dropdown = gr.Dropdown(
    choices=["2D", "3D"], value="2D", label="Plot Style"
)
functional_dropdown = gr.Dropdown(
    choices=["PBE", "PBESol", "SCAN"], value="PBE", label="Functional"
)
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")

# Caveat shown above the app description.
warning_message = (
    "This application uses energy correction schemes directly"
    " from the data providers (Alexandria, MP) and has the 2020 MP"
    " Compatibility scheme applied to OQMD. However, because we did"
    " not directly apply the compatibility schemes to Alexandria, MP"
    " we have noticed discrepancies in the data. While the correction"
    " scheme will be standardized in a soon to be released update, for"
    " now please take caution when analyzing the results of this"
    " application."
)

# The onclick handler uses single quotes inside the double-quoted HTML
# attribute; nested double quotes would terminate the attribute early.
message = (
    '<div class="alert">'
    '<span class="closebtn" '
    "onclick=\"this.parentElement.style.display='none';\">×</span>"
    f"{warning_message}</div>"
    "Generate a phase diagram for a set of elements using LeMat-Bulk data."
)

# Create Gradio interface; the input order matches the positional
# parameters of create_phase_diagram.
iface = gr.Interface(
    fn=create_phase_diagram,
    inputs=[
        elements_input,
        max_e_above_hull_slider,
        color_scheme_dropdown,
        plot_style_dropdown,
        functional_dropdown,
        finite_temp_toggle,
    ],
    outputs=gr.Plot(label="Phase Diagram"),
    title="LeMaterial - Phase Diagram Viewer",
    description=message,
)

# Launch the app
iface.launch()