File size: 4,693 Bytes
1a9eacc
 
baa7f8d
33df548
baa7f8d
70b543a
baa7f8d
0641fac
baa7f8d
137bc98
 
70b543a
a0bb336
 
70b543a
6f4e929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70b543a
 
 
fa131a7
1a9eacc
 
baa7f8d
 
 
1a9eacc
baa7f8d
1a9eacc
33df548
 
 
 
 
 
 
 
 
 
 
 
 
 
70b543a
baa7f8d
0641fac
 
33df548
0641fac
 
 
 
baa7f8d
33df548
 
 
baa7f8d
137bc98
baa7f8d
33df548
baa7f8d
70b543a
 
1a9eacc
0641fac
 
 
1a9eacc
 
 
 
 
 
 
 
 
 
 
 
 
baa7f8d
 
 
1a9eacc
 
baa7f8d
 
 
1a9eacc
 
 
 
 
 
 
 
 
baa7f8d
 
 
 
 
 
 
 
 
 
 
1a9eacc
137bc98
1a9eacc
 
 
 
 
 
 
 
 
 
 
baa7f8d
1a9eacc
 
 
baa7f8d
1a9eacc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os

import gradio as gr
import numpy as np
import plotly.graph_objs as go
from datasets import load_dataset
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
from pymatgen.core import Composition, Structure
from pymatgen.core.composition import Composition
from pymatgen.entries.computed_entries import (ComputedStructureEntry,
                                               GibbsComputedStructureEntry)

HF_TOKEN = os.environ.get("HF_TOKEN")

# Read the "train" split of LeMaterial/leDataset, restricted to the columns
# that create_phase_diagram needs to rebuild structures and computed entries,
# and keep only the pandas copy. The intermediate Dataset object is never bound
# to a module-level name, so it is not retained after the conversion.
train_df = load_dataset(
    "LeMaterial/leDataset",
    split="train",
    token=HF_TOKEN,
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
    ],
).to_pandas()


# UI label -> value in the dataset's "functional" column. "pbe" is the value the
# original filter used; "pbesol" and "scan" are assumed to follow the same
# lowercase convention — TODO confirm against LeMaterial/leDataset.
_FUNCTIONAL_VALUES = {"PBE": "pbe", "PBESol": "pbesol", "SCAN": "scan"}


def _parse_elements(elements):
    """Split a '-'-separated chemical system into unique, non-empty symbols.

    Order of first appearance is kept, so "Li-Fe-O" -> ["Li", "Fe", "O"] and
    stray separators or repeats ("Li--Li-O") are dropped.
    """
    tokens = (el.strip() for el in elements.split("-"))
    return list(dict.fromkeys(tok for tok in tokens if tok))


def _select_rows(df, element_list, dataset_functional):
    """Return rows computed with ``dataset_functional`` whose elements lie in ``element_list``.

    A row is kept only if its element set is non-empty and a subset of
    ``element_list`` (i.e. it belongs to the chemical system or a sub-system).
    """
    df = df[df["functional"] == dataset_functional]
    allowed = set(element_list)
    # An explicit bool array keeps boolean-mask semantics even when ``df`` is empty
    # (an empty Python list would be interpreted as a column selection).
    mask = np.array(
        [len(els) > 0 and set(els) <= allowed for els in df["elements"]],
        dtype=bool,
    )
    return df[mask]


def _row_to_entry(row):
    """Build a ComputedStructureEntry from one ``train_df`` row."""
    structure = Structure(
        [x.tolist() for x in row["lattice_vectors"].tolist()],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    corrected = row["energy_corrected"]
    # When energy_corrected is NaN, no correction is applied.
    correction = 0 if np.isnan(corrected) else corrected - row["energy"]
    return ComputedStructureEntry(
        structure,
        energy=row["energy"],
        correction=correction,
        entry_id=row["immutable_id"],
        parameters={"run_type": row["functional"]},
    )


def create_phase_diagram(
    elements, max_e_above_hull, color_scheme, plot_style, functional, finite_temp
):
    """Build a plotly phase diagram for a chemical system from ``train_df``.

    Args:
        elements: Chemical system as '-'-separated symbols, e.g. "Li-Fe-O".
        max_e_above_hull: Unstable phases with an energy above hull below this
            value (eV/atom) are drawn; forwarded to PDPlotter's ``show_unstable``.
            0 (or None) shows only stable phases.
        color_scheme: Accepted for interface compatibility; currently unused.
        plot_style: "2D" for the standard plot; any other value requests a 3D
            ternary plot, which is only available for exactly three elements.
        functional: "PBE", "PBESol" or "SCAN"; None defaults to "PBE" so
            energies from different functionals are never mixed on one hull.
        finite_temp: If truthy, entries are converted with
            ``GibbsComputedStructureEntry.from_entries`` before building the hull.

    Returns:
        A plotly Figure. On invalid input or a failed hull construction, an
        empty Figure carrying an explanatory annotation is returned instead.
    """
    element_list = _parse_elements(elements or "")
    if not element_list:
        return go.Figure().add_annotation(text="Please enter at least one element.")

    dataset_functional = _FUNCTIONAL_VALUES.get(functional or "PBE")
    if dataset_functional is None:
        return go.Figure().add_annotation(text=f"Unknown functional: {functional!r}")

    # Reject an impossible 3D request before doing any expensive work.
    want_3d = plot_style != "2D"
    if want_3d and len(element_list) != 3:
        return go.Figure().add_annotation(
            text="3D plots are only available for ternary systems."
        )

    rows = _select_rows(train_df, element_list, dataset_functional)
    if rows.empty:
        return go.Figure().add_annotation(
            text=f"No {functional or 'PBE'} entries found for {'-'.join(element_list)}."
        )

    # Entries come from the LeMaterial dataset loaded into train_df.
    entries = [_row_to_entry(row) for _, row in rows.iterrows()]

    if finite_temp:
        entries = GibbsComputedStructureEntry.from_entries(entries)

    try:
        phase_diagram = PhaseDiagram(entries)
    except ValueError as e:
        return go.Figure().add_annotation(text=str(e))

    plot_kwargs = {"ternary_style": "3d"} if want_3d else {}
    plotter = PDPlotter(
        phase_diagram,
        # PDPlotter shows unstable phases with e_above_hull < show_unstable (eV/atom).
        show_unstable=float(max_e_above_hull or 0),
        backend="plotly",
        **plot_kwargs,
    )
    return plotter.get_plot()


# Define Gradio interface components. Every dropdown gets an explicit default
# value so the callback never receives None for a choice it branches on.
elements_input = gr.Textbox(
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
    value="Li-Fe-O",
)
# Energy above hull is a per-atom quantity, hence eV/atom.
max_e_above_hull_slider = gr.Slider(
    minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV/atom)"
)
color_scheme_dropdown = gr.Dropdown(
    choices=["Energy Above Hull", "Formation Energy"],
    value="Energy Above Hull",
    label="Color Scheme",
)
plot_style_dropdown = gr.Dropdown(choices=["2D", "3D"], value="2D", label="Plot Style")
functional_dropdown = gr.Dropdown(
    choices=["PBE", "PBESol", "SCAN"], value="PBE", label="Functional"
)
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")

# Input components, in the positional order create_phase_diagram expects:
# (elements, max_e_above_hull, color_scheme, plot_style, functional, finite_temp).
_interface_inputs = [
    elements_input,
    max_e_above_hull_slider,
    color_scheme_dropdown,
    plot_style_dropdown,
    functional_dropdown,
    finite_temp_toggle,
]

# Wire the inputs to the plotting callback; the figure it returns is shown in a
# single Plot output. NOTE(review): the data actually comes from the LeMaterial
# dataset loaded into train_df; the title/description still say Materials Project.
iface = gr.Interface(
    fn=create_phase_diagram,
    inputs=_interface_inputs,
    outputs=gr.Plot(label="Phase Diagram"),
    title="Materials Project Phase Diagram",
    description="Generate a phase diagram for a set of elements using Materials Project data.",
)

# Start the web app (runs at import time, as in a Hugging Face Space app.py).
iface.launch()