"""Gradio app that draws phase diagrams from LeMaterial parquet data.

At import time the module loads one polars DataFrame per functional subset,
builds the Gradio interface and launches it.
"""

import os

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import polars as pl
from datasets import concatenate_datasets, load_dataset
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
from pymatgen.core import Composition, Element, Structure
from pymatgen.core.composition import Composition
from pymatgen.entries.computed_entries import (
    ComputedStructureEntry,
    GibbsComputedStructureEntry,
)

HF_TOKEN = os.environ.get("HF_TOKEN")

subsets = [
    "compatible_pbe",
    "compatible_pbesol",
    "compatible_scan",
]

# One DataFrame per subset, read from the parquet shards of the train split.
polars_dfs = {
    subset: pl.read_parquet(
        "hf://datasets/LeMaterial/LeMat1/{}/train-*.parquet".format(subset),
        storage_options={
            "token": HF_TOKEN,
        },
    )
    for subset in subsets
}

# NOTE: an earlier `dataset = concatenate_datasets(datasets)` statement lived
# here. The `datasets` list it read was only ever built in commented-out code,
# so the statement raised NameError on import; its result was never used.

# Maps the label shown in the "Functional" dropdown to the subset key of
# `polars_dfs`.
FUNCTIONAL_TO_SUBSET = {
    "PBE": "compatible_pbe",
    "PBESol": "compatible_pbesol",
    "SCAN": "compatible_scan",
}


def _message_figure(text):
    """Return an empty plotly figure carrying *text* as an annotation."""
    return go.Figure().add_annotation(text=text)


def _row_to_entry(row):
    """Build a ComputedStructureEntry from one row of the entries table.

    The row must provide ``lattice_vectors``, ``species_at_sites``,
    ``cartesian_site_positions``, ``energy``, ``energy_corrected``,
    ``immutable_id`` and ``functional``.
    """
    structure = Structure(
        [x.tolist() for x in row["lattice_vectors"].tolist()],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    energy = row["energy"]
    corrected = row["energy_corrected"]
    # pd.isna also accepts None, where np.isnan would raise TypeError.
    correction = 0 if pd.isna(corrected) else corrected - energy
    return ComputedStructureEntry(
        structure,
        energy=energy,
        correction=correction,
        entry_id=row["immutable_id"],
        parameters={"run_type": row["functional"]},
    )


def create_phase_diagram(
    elements,
    max_e_above_hull,
    color_scheme,
    plot_style,
    functional,
    finite_temp,
    **kwargs,
):
    """Build a plotly phase diagram for a chemical system.

    Args:
        elements: Dash-separated element symbols, e.g. ``"Li-Fe-O"``.
        max_e_above_hull: Currently unused (placeholder; see the note near
            the end of the function).
        color_scheme: Currently unused.
        plot_style: ``"2D"`` for the default plot; any other value requests
            the 3D plot, which is only produced for three elements.
        functional: One of the keys of ``FUNCTIONAL_TO_SUBSET``.
        finite_temp: If truthy, entries are converted with
            ``GibbsComputedStructureEntry.from_entries`` before the diagram
            is built.
        **kwargs: Accepted and ignored.

    Returns:
        A plotly figure. When the request cannot be served (no elements, no
        functional selected, no matching entries, ``PhaseDiagram`` raising
        ``ValueError``, or a 3D request for a non-ternary system) the figure
        is empty apart from an annotation with the explanation.
    """
    # Split elements, strip whitespace and drop empty tokens ("Li--O", "").
    element_list = [el.strip() for el in (elements or "").split("-")]
    element_list = [el for el in element_list if el]
    if not element_list:
        return _message_figure(
            "Please enter at least one element, e.g. 'Li-Fe-O'."
        )

    # Pick the table matching the requested functional.
    subset = FUNCTIONAL_TO_SUBSET.get(functional)
    if subset is None:
        return _message_figure(
            "Please select a functional (PBE, PBESol or SCAN)."
        )

    # Keep rows whose elements are all among the requested ones, so that
    # sub-systems (binaries, pure elements) are part of the diagram too.
    # `filter` returns a new frame, so the cached table is left untouched.
    df = polars_dfs[subset].filter(
        pl.col("elements")
        .list.eval(pl.element().is_in(element_list))
        .list.all()
    )
    entries_df = df.to_pandas()

    entries = [_row_to_entry(row) for _, row in entries_df.iterrows()]
    if not entries:
        return _message_figure(
            "No entries found for {} with the {} functional.".format(
                "-".join(element_list), functional
            )
        )

    # TODO: Fetch elemental entries (they are usually GGA calculations)
    # entries.extend([e for e in entries if e.composition.is_element])

    if finite_temp:
        entries = GibbsComputedStructureEntry.from_entries(entries)

    # Build the phase diagram
    try:
        phase_diagram = PhaseDiagram(entries)
    except ValueError as e:
        return _message_figure(str(e))

    # Generate plotly figure
    if plot_style == "2D":
        plotter = PDPlotter(phase_diagram, show_unstable=True, backend="plotly")
        return plotter.get_plot()

    # For 3D plots, limit to ternary systems
    if len(element_list) != 3:
        return _message_figure(
            "3D plots are only available for ternary systems."
        )

    plotter = PDPlotter(
        phase_diagram,
        show_unstable=True,
        backend="plotly",
        ternary_style="3d",
    )
    # Adjust the maximum energy above hull
    # (This is a placeholder as PDPlotter does not support direct filtering)
    return plotter.get_plot()


# Define Gradio interface components
elements_input = gr.Textbox(
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
    value="Li-Fe-O",
)
max_e_above_hull_slider = gr.Slider(
    minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
)
color_scheme_dropdown = gr.Dropdown(
    choices=["Energy Above Hull", "Formation Energy"], label="Color Scheme"
)
plot_style_dropdown = gr.Dropdown(choices=["2D", "3D"], label="Plot Style")
functional_dropdown = gr.Dropdown(
    choices=["PBE", "PBESol", "SCAN"], label="Functional"
)
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")

warning_message = (
    "This application uses energy correction schemes directly"
    " from the data providers (Alexandria, MP) and has the 2020 MP"
    " Compatibility scheme applied to OQMD. However, because we did"
    " not directly apply the compatibility schemes to Alexandria, MP"
    " we have noticed discrepencies in the data. While the correction"
    " scheme will be standardized in a soon to be released update, for"
    " now please take caution when analyzing the results of this"
    " application."
)

# NOTE(review): this literal looks like it once carried markup around the
# warning text; only the characters below are present in the file, and they
# are kept as found. The line breaks are written as "\n" escapes because a
# single-quoted literal cannot span physical lines.
message = (
    "\n×{}\n"
    "Generate a phase diagram for a set of elements using LeMat-Bulk data."
).format(warning_message)

# Create Gradio interface
iface = gr.Interface(
    fn=create_phase_diagram,
    inputs=[
        elements_input,
        max_e_above_hull_slider,
        color_scheme_dropdown,
        plot_style_dropdown,
        functional_dropdown,
        finite_temp_toggle,
    ],
    outputs=gr.Plot(label="Phase Diagram"),
    title="LeMaterial - Phase Diagram Viewer",
    description=message,
)

# Launch the app
iface.launch()