import os

import gradio as gr
import plotly.graph_objs as go
from datasets import load_dataset
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
from pymatgen.core import Composition, Structure
from pymatgen.core.composition import Composition
from pymatgen.entries.computed_entries import (
    ComputedStructureEntry,
    GibbsComputedStructureEntry,
)

HF_TOKEN = os.environ.get("HF_TOKEN")

# Load only the train split of the dataset
dataset = load_dataset("LeMaterial/leDataset", token=HF_TOKEN, split="train")

# Convert the train split to a pandas DataFrame
train_df = dataset.to_pandas()

# "Functional" dropdown choice -> value of the dataset's ``functional`` column,
# which each entry stores as ``parameters["run_type"]``.
_RUN_TYPES = {"PBE": "pbe", "PBESol": "pbesol", "SCAN": "scan"}

# PDPlotter can only draw phase diagrams with 1 to 4 components.
_MAX_PLOT_COMPONENTS = 4


def _message_figure(text):
    """Return an otherwise empty plotly figure showing ``text`` as an annotation."""
    return go.Figure().add_annotation(text=text)


def _parse_elements(elements):
    """Split a '-'-separated chemical system such as 'Li-Fe-O' into symbols.

    Whitespace around each symbol is stripped, empty pieces (e.g. from a
    trailing '-') are dropped, and duplicates are removed keeping the order
    in which symbols first appear.
    """
    pieces = (piece.strip() for piece in (elements or "").split("-"))
    return list(dict.fromkeys(piece for piece in pieces if piece))


def _row_to_entry(row):
    """Build a ComputedStructureEntry from one row of ``train_df``."""
    structure = Structure(
        # Use this row's own lattice. The nested arrays are turned into plain
        # lists so Structure can build a 3x3 float matrix from them.
        [list(vector) for vector in row["lattice_vectors"]],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    return ComputedStructureEntry(
        structure,
        energy=row["energy"],
        # Correction is the corrected energy minus the raw energy.
        correction=row["energy_corrected"] - row["energy"],
        entry_id=row["immutable_id"],
        parameters={"run_type": row["functional"]},
    )


def _load_entries(element_list):
    """Return one entry per dataset row whose chemical system lies in ``element_list``.

    A row qualifies when its set of elements is non-empty and a subset of
    ``element_list``. This includes every sub-system (elements, binaries, ...).
    """
    allowed = set(element_list)

    def in_system(row_elements):
        row_set = set(row_elements)
        return bool(row_set) and row_set <= allowed

    # Run a cheap per-row set test first, so that the slow iterrows() and the
    # Structure construction only touch matching rows.
    mask = train_df["elements"].map(in_system).astype(bool)
    return [_row_to_entry(row) for _, row in train_df[mask].iterrows()]


def _filter_by_functional(entries, functional):
    """Keep entries computed with ``functional``, plus all remaining elemental entries.

    Elemental entries from any functional are kept so that every terminal of
    the diagram has a reference. They are presumably GGA calculations (TODO
    confirm); note that this can mix functionals at the terminals.
    An unrecognised ``functional`` (e.g. None) leaves ``entries`` unfiltered.
    """
    run_type = _RUN_TYPES.get(functional)
    if run_type is None:
        return entries
    matching, other_elements = [], []
    for entry in entries:
        if entry.parameters.get("run_type", "") == run_type:
            matching.append(entry)
        elif entry.composition.is_element:
            other_elements.append(entry)
    return matching + other_elements


def create_phase_diagram(
    elements, max_e_above_hull, color_scheme, plot_style, functional, finite_temp
):
    """Build a plotly phase diagram for a chemical system from the loaded dataset.

    Args:
        elements: '-'-separated element symbols, e.g. "Li-Fe-O".
        max_e_above_hull: Cutoff in eV/atom. Unstable phases closer to the hull
            than this are drawn; 0 shows stable phases only.
        color_scheme: Currently unused.
        plot_style: "3D" requests a 3D ternary plot; any other value plots in 2D.
        functional: "PBE", "PBESol" or "SCAN". Any other value (e.g. None)
            uses entries from all functionals.
        finite_temp: If true, entries are converted with
            GibbsComputedStructureEntry.from_entries (at its default
            temperature) before the diagram is built.

    Returns:
        A plotly Figure: either the phase diagram, or an empty figure whose
        annotation explains why no diagram could be drawn.
    """
    element_list = _parse_elements(elements)
    if not element_list:
        return _message_figure("Please enter at least one element, e.g. 'Li-Fe-O'.")

    entries = _filter_by_functional(_load_entries(element_list), functional)
    if not entries:
        return _message_figure(
            f"No entries found for the {'-'.join(element_list)} system."
        )

    # Build the phase diagram
    try:
        if finite_temp:
            # NOTE(review): from_entries appears to build a PhaseDiagram
            # internally, so it can raise the same ValueError (e.g. missing
            # terminal entries); keep it inside the try.
            entries = GibbsComputedStructureEntry.from_entries(entries)
        phase_diagram = PhaseDiagram(entries)
    except ValueError as e:
        return _message_figure(str(e))

    n_components = len(phase_diagram.elements)
    if n_components > _MAX_PLOT_COMPONENTS:
        return _message_figure(
            "Phase diagrams can only be plotted for up to 4 elements."
        )

    # PDPlotter's show_unstable accepts a float cutoff (eV/atom) on the energy
    # above hull, which is exactly what the slider provides.
    plot_options = {"show_unstable": max_e_above_hull, "backend": "plotly"}
    if plot_style == "3D":
        # 3D rendering is only offered for ternary systems.
        if n_components != 3:
            return _message_figure("3D plots are only available for ternary systems.")
        plot_options["ternary_style"] = "3d"
    return PDPlotter(phase_diagram, **plot_options).get_plot()


# Define Gradio interface components
elements_input = gr.Textbox(
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
    value="Li-Fe-O",
)
max_e_above_hull_slider = gr.Slider(
    minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV/atom)"
)
# Explicit default values so these inputs start with a selection rather than empty.
color_scheme_dropdown = gr.Dropdown(
    choices=["Energy Above Hull", "Formation Energy"],
    value="Energy Above Hull",
    label="Color Scheme",
)
plot_style_dropdown = gr.Dropdown(choices=["2D", "3D"], value="2D", label="Plot Style")
functional_dropdown = gr.Dropdown(
    choices=["PBE", "PBESol", "SCAN"], value="PBE", label="Functional"
)
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")

# Create Gradio interface
# NOTE(review): title/description mention the Materials Project, but the data is
# loaded from the LeMaterial/leDataset Hugging Face dataset above.
iface = gr.Interface(
    fn=create_phase_diagram,
    inputs=[
        elements_input,
        max_e_above_hull_slider,
        color_scheme_dropdown,
        plot_style_dropdown,
        functional_dropdown,
        finite_temp_toggle,
    ],
    outputs=gr.Plot(label="Phase Diagram"),
    title="Materials Project Phase Diagram",
    description="Generate a phase diagram for a set of elements using Materials Project data.",
)
# Launch the app only when this file is run as a script, so importing the
# module (e.g. to reuse create_phase_diagram) does not start a blocking server.
if __name__ == "__main__":
    iface.launch()