# Hugging Face Space - status: Running on CPU Upgrade
import gradio as gr | |
from pymatgen.ext.matproj import MPRester | |
from pymatgen.analysis.phase_diagram import PhaseDiagram, PDPlotter | |
from pymatgen.core.composition import Composition | |
import plotly.graph_objs as go | |
import os | |
from pymatgen.entries.computed_entries import ComputedEntry | |
from pymatgen.core import Composition | |
from datasets import load_dataset | |
# Fetch just the "train" split of the LeMaterial/leDataset dataset.
dataset = load_dataset(path="LeMaterial/leDataset", split="train")
# Turn that split into a pandas DataFrame for row-wise filtering.
train_df = dataset.to_pandas()
def create_phase_diagram(elements, max_e_above_hull, color_scheme, plot_style, functional, finite_temp):
    """Build a plotly phase diagram for a chemical system from ``train_df``.

    Args:
        elements: Element symbols joined by '-', e.g. "Li-Fe-O". Empty
            fragments (from "Li--O" or a trailing '-') are ignored.
        max_e_above_hull: Energy-above-hull cut-off (eV/atom) for unstable
            phases; passed to ``PDPlotter`` as ``show_unstable`` (0 hides them).
        color_scheme: Selected color scheme. NOTE(review): not used yet.
        plot_style: "2D" for the default plot; any other value requests a 3D
            plot, which is only produced for ternary systems.
        functional: "GGA" drops entries whose ``parameters["run_type"]`` starts
            with "GGA+U"; "GGA+U" keeps only those; any other value keeps all.
            Elemental entries are always kept.
        finite_temp: Finite-temperature toggle. NOTE(review): not used yet.

    Returns:
        A plotly ``go.Figure`` with the phase diagram, or an empty figure whose
        annotation explains why no diagram could be built.
    """
    element_list = [el.strip() for el in (elements or "").split("-") if el.strip()]
    if not element_list:
        return go.Figure().add_annotation(text="Please enter at least one element.")

    is_3d = plot_style != "2D"
    # Reject unsupported 3D requests before the (slow) scan of the whole dataset.
    if is_3d and len(element_list) != 3:
        return go.Figure().add_annotation(text="3D plots are only available for ternary systems.")

    allowed = set(element_list)

    def _in_system(row_elements):
        # A row belongs to the system when its element set is a non-empty subset.
        row_set = set(row_elements)
        return bool(row_set) and row_set <= allowed

    # Filter column-wise first so iterrows() only visits matching rows;
    # astype(bool) keeps the mask boolean even if the frame is empty.
    mask = train_df["elements"].map(_in_system).astype(bool)
    entries = [
        ComputedEntry(
            Composition(row["chemical_formula_descriptive"]),
            energy=row["energy"],
            correction=row["energy_corrected"] - row["energy"],
            entry_id=row["immutable_id"],
        )
        for _, row in train_df.loc[mask].iterrows()
    ]

    # Remember the elemental entries so the functional filter cannot remove
    # the end-members of the hull.
    elemental_entries = [entry for entry in entries if entry.composition.is_element]

    # NOTE(review): the entries above are built without ``parameters``, so
    # ``run_type`` is always "" here: "GGA" keeps everything and "GGA+U" keeps
    # only the elemental entries re-added below. Confirm which dataset column
    # should drive this filter.
    if functional == "GGA":
        entries = [e for e in entries if not e.parameters.get("run_type", "").startswith("GGA+U")]
    elif functional == "GGA+U":
        entries = [e for e in entries if e.parameters.get("run_type", "").startswith("GGA+U")]

    # Re-add elemental entries dropped by the filter. They are the same objects
    # as in ``entries``, so an identity set avoids an O(n*m) ``in`` scan.
    kept_ids = {id(entry) for entry in entries}
    entries.extend(entry for entry in elemental_entries if id(entry) not in kept_ids)

    try:
        phase_diagram = PhaseDiagram(entries)
    except ValueError as exc:
        return go.Figure().add_annotation(text=str(exc))

    # ``show_unstable`` is a float threshold: unstable phases with an energy
    # above hull below it are drawn, so the slider value maps onto it directly.
    plot_kwargs = {"show_unstable": max_e_above_hull or 0, "backend": "plotly"}
    if is_3d:
        plot_kwargs["ternary_style"] = "3d"
    try:
        # PDPlotter raises ValueError for unsupported systems (e.g. > 4 elements).
        fig = PDPlotter(phase_diagram, **plot_kwargs).get_plot()
    except ValueError as exc:
        return go.Figure().add_annotation(text=str(exc))
    return fig
# Define Gradio interface components (in the order create_phase_diagram takes them).
# Each dropdown gets an explicit default (its first choice). Without one, the
# initial selection depends on the Gradio version; older releases start with
# nothing selected (None), which create_phase_diagram routes to the 3D branch.
elements_input = gr.Textbox(
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
    value="Li-Fe-O",
)
# Energy above hull is a per-atom quantity, hence eV/atom.
max_e_above_hull_slider = gr.Slider(
    minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV/atom)"
)
color_scheme_dropdown = gr.Dropdown(
    choices=["Energy Above Hull", "Formation Energy"],
    value="Energy Above Hull",
    label="Color Scheme",
)
plot_style_dropdown = gr.Dropdown(choices=["2D", "3D"], value="2D", label="Plot Style")
functional_dropdown = gr.Dropdown(choices=["GGA", "GGA+U", "Both"], value="GGA", label="Functional")
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")
# Assemble the web UI: the six inputs feed create_phase_diagram positionally,
# and its returned figure is rendered in a single plot output.
iface = gr.Interface(
    title="Materials Project Phase Diagram",
    description="Generate a phase diagram for a set of elements using Materials Project data.",
    fn=create_phase_diagram,
    inputs=[
        elements_input,
        max_e_above_hull_slider,
        color_scheme_dropdown,
        plot_style_dropdown,
        functional_dropdown,
        finite_temp_toggle,
    ],
    outputs=gr.Plot(label="Phase Diagram"),
)

# Start serving the app.
iface.launch()