# phase_diagram / app.py
# Author: msiron (commit 70b543a, 3.76 kB):
#   "adding HF database and parsing entries from it as computedentry,
#    logic for filtering from element list"
import gradio as gr
from pymatgen.ext.matproj import MPRester
from pymatgen.analysis.phase_diagram import PhaseDiagram, PDPlotter
from pymatgen.core.composition import Composition
import plotly.graph_objs as go
import os
from pymatgen.entries.computed_entries import ComputedEntry
from pymatgen.core import Composition
from datasets import load_dataset
# Fetch the "train" split of the LeMaterial/leDataset dataset once, at import
# time, and keep a pandas DataFrame copy of it (``train_df``).
# NOTE(review): the whole split is materialised in memory as a DataFrame;
# confirm the host has enough RAM for the dataset size.
dataset = load_dataset(
    "LeMaterial/leDataset",
    split="train",
)
train_df = dataset.to_pandas()
def _message_figure(message):
    """Return an empty plotly figure that only displays *message* as an annotation."""
    return go.Figure().add_annotation(text=message)


def _parse_elements(elements):
    """Split a dash-separated chemical system such as ``"Li-Fe-O"`` into symbols.

    Whitespace is stripped, empty pieces (from ``"Li--O"`` or a trailing dash)
    are dropped, symbols are normalised to ``"Fe"``-style capitalisation, and
    duplicates are removed while keeping first-seen order.
    """
    symbols = (piece.strip().capitalize() for piece in (elements or "").split("-"))
    return list(dict.fromkeys(symbol for symbol in symbols if symbol))


def _load_entries(element_list):
    """Build a ComputedEntry for every dataset row lying inside the chemical system."""
    allowed = set(element_list)

    def in_system(row_elements):
        # A row qualifies when it has at least one element and all of its
        # elements belong to the requested chemical system.
        found = set(row_elements)
        return bool(found) and found <= allowed

    # Filter on the ``elements`` column first so only matching rows pay the
    # comparatively slow per-row ``iterrows`` cost.
    mask = train_df["elements"].map(in_system).astype(bool)
    return [
        ComputedEntry(
            Composition(row["chemical_formula_descriptive"]),
            energy=row["energy"],
            # Store the dataset's correction so ComputedEntry.energy equals
            # the corrected energy.
            correction=row["energy_corrected"] - row["energy"],
            entry_id=row["immutable_id"],
        )
        for _, row in train_df.loc[mask].iterrows()
    ]


def _filter_by_functional(entries, functional):
    """Apply the GGA / GGA+U selection; elemental entries are always retained."""
    elemental_entries = [e for e in entries if e.composition.is_element]
    # NOTE(review): _load_entries passes no ``parameters`` to ComputedEntry, so
    # ``run_type`` presumably falls back to "" here: "GGA" then keeps every
    # entry and "GGA+U" keeps only elemental ones. TODO: derive run_type from
    # the dataset's functional information once its schema is confirmed.
    if functional == "GGA":
        entries = [
            e for e in entries
            if not e.parameters.get("run_type", "").startswith("GGA+U")
        ]
    elif functional == "GGA+U":
        entries = [
            e for e in entries
            if e.parameters.get("run_type", "").startswith("GGA+U")
        ]
    # Re-add elemental entries so the phase diagram keeps its terminal phases.
    missing = [e for e in elemental_entries if e not in entries]
    return entries + missing


def create_phase_diagram(elements, max_e_above_hull, color_scheme, plot_style, functional, finite_temp):
    """Build a plotly phase-diagram figure for a dash-separated chemical system.

    Args:
        elements: Chemical system such as ``"Li-Fe-O"``.
        max_e_above_hull: Energy-above-hull threshold (eV/atom) forwarded to
            PDPlotter's ``show_unstable``; unstable phases below it are drawn.
            ``None`` falls back to ``show_unstable=True``.
        color_scheme: Accepted for the Gradio interface; currently unused.
        plot_style: ``"3D"`` requests a 3D ternary plot; any other value
            (including ``None``) produces the 2D plot.
        functional: ``"GGA"``, ``"GGA+U"``, or anything else for no filtering.
        finite_temp: Accepted for the Gradio interface; currently unused.

    Returns:
        A plotly figure: the phase diagram, or an empty figure whose
        annotation explains why no diagram could be drawn.
    """
    element_list = _parse_elements(elements)
    if not element_list:
        return _message_figure("Please enter at least one element, e.g. 'Li-Fe-O'.")

    # Reject unsupported 3D requests before doing the expensive dataset scan.
    want_3d = plot_style == "3D"
    if want_3d and len(element_list) != 3:
        return _message_figure("3D plots are only available for ternary systems.")

    entries = _load_entries(element_list)
    if not entries:
        return _message_figure(f"No entries found for the {'-'.join(element_list)} system.")
    entries = _filter_by_functional(entries, functional)

    try:
        phase_diagram = PhaseDiagram(entries)
    except ValueError as e:
        return _message_figure(str(e))

    # PDPlotter's show_unstable doubles as an e-above-hull cutoff (eV/atom).
    threshold = True if max_e_above_hull is None else max_e_above_hull
    plot_kwargs = {"show_unstable": threshold, "backend": "plotly"}
    if want_3d:
        plot_kwargs["ternary_style"] = "3d"
    try:
        # PDPlotter rejects unsupported dimensionalities with ValueError.
        return PDPlotter(phase_diagram, **plot_kwargs).get_plot()
    except ValueError as e:
        return _message_figure(str(e))
# Gradio input components; their order in ``inputs`` below must match the
# parameter order of create_phase_diagram.
elements_input = gr.Textbox(
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
    value="Li-Fe-O",
)
# Energies above the hull are expressed per atom.
max_e_above_hull_slider = gr.Slider(
    minimum=0,
    maximum=1,
    value=0.1,
    label="Maximum Energy Above Hull (eV/atom)",
)
# Explicit defaults (the first choice of each dropdown): depending on the
# Gradio version, a dropdown without ``value`` may submit None, which the
# callback would not interpret as any of the listed choices.
color_scheme_dropdown = gr.Dropdown(
    choices=["Energy Above Hull", "Formation Energy"],
    value="Energy Above Hull",
    label="Color Scheme",
)
plot_style_dropdown = gr.Dropdown(choices=["2D", "3D"], value="2D", label="Plot Style")
functional_dropdown = gr.Dropdown(
    choices=["GGA", "GGA+U", "Both"],
    value="GGA",
    label="Functional",
)
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")

# Create the Gradio interface; entries come from the LeMaterial dataset.
iface = gr.Interface(
    fn=create_phase_diagram,
    inputs=[
        elements_input,
        max_e_above_hull_slider,
        color_scheme_dropdown,
        plot_style_dropdown,
        functional_dropdown,
        finite_temp_toggle,
    ],
    outputs=gr.Plot(label="Phase Diagram"),
    title="LeMaterial Phase Diagram",
    description="Generate a phase diagram for a set of elements using LeMaterial (LeMaterial/leDataset) data.",
)

# Launch the app
iface.launch()