File size: 5,555 Bytes
1a9eacc
 
baa7f8d
33df548
baa7f8d
70b543a
baa7f8d
0641fac
baa7f8d
2af7535
 
 
 
70b543a
a0bb336
 
70b543a
6f4e929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70b543a
 
 
fa131a7
1a9eacc
 
baa7f8d
2af7535
 
 
 
 
 
 
baa7f8d
1a9eacc
baa7f8d
1a9eacc
33df548
 
b88325e
33df548
b88325e
33df548
b88325e
33df548
 
 
 
 
 
 
70b543a
baa7f8d
0641fac
 
33df548
0641fac
 
 
 
baa7f8d
33df548
 
 
baa7f8d
137bc98
baa7f8d
33df548
baa7f8d
b88325e
 
1a9eacc
0641fac
 
 
1a9eacc
 
 
 
 
 
 
 
 
 
 
 
 
baa7f8d
 
 
1a9eacc
 
baa7f8d
 
 
1a9eacc
 
 
 
 
 
 
 
 
baa7f8d
 
 
 
 
 
 
 
 
 
 
1a9eacc
137bc98
1a9eacc
 
2af7535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a53814
1a9eacc
 
 
 
 
 
 
 
 
baa7f8d
2af7535
1a9eacc
c61d5b5
4610d52
 
1a9eacc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os

import gradio as gr
import numpy as np
import plotly.graph_objs as go
from datasets import load_dataset
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
from pymatgen.core import Composition, Structure
from pymatgen.core.composition import Composition
from pymatgen.entries.computed_entries import (
    ComputedStructureEntry,
    GibbsComputedStructureEntry,
)

# Optional Hugging Face access token taken from the environment (None if unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Fetch the "train" split restricted to the listed columns and keep only its
# pandas form. The intermediate datasets object is never bound to a module
# name, so only the DataFrame stays referenced once this statement finishes.
train_df = load_dataset(
    "LeMaterial/leDataset",
    split="train",
    token=HF_TOKEN,
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
    ],
).to_pandas()


# UI functional label -> value stored in the dataset's "functional" column.
_FUNCTIONAL_COLUMN_VALUES = {"PBE": "pbe", "PBESol": "pbesol", "SCAN": "scan"}


def _row_to_entry(row):
    """Build a ComputedStructureEntry from one row of ``train_df``.

    The energy correction is ``energy_corrected - energy``; rows whose
    ``energy_corrected`` is NaN get a correction of 0.
    """
    energy = row["energy"]
    energy_corrected = row["energy_corrected"]
    correction = 0 if np.isnan(energy_corrected) else energy_corrected - energy
    structure = Structure(
        [x.tolist() for x in row["lattice_vectors"].tolist()],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    return ComputedStructureEntry(
        structure,
        energy=energy,
        correction=correction,
        entry_id=row["immutable_id"],
        parameters={"run_type": row["functional"]},
    )


def create_phase_diagram(
    elements,
    max_e_above_hull,
    color_scheme,
    plot_style,
    functional,
    finite_temp,
    *args,
    **kwargs,
):
    """Build a plotly phase diagram for a chemical system from ``train_df``.

    Args:
        elements: Hyphen-separated element symbols, e.g. ``"Li-Fe-O"``.
            Blank pieces and repeated symbols are ignored.
        max_e_above_hull: Energy-above-hull cutoff (eV/atom) for the unstable
            phases that are drawn; forwarded to ``PDPlotter`` as
            ``show_unstable``. ``None`` falls back to ``True``.
        color_scheme: Currently unused.
        plot_style: ``"3D"`` requests the 3D ternary view (ternary systems
            only); any other value produces the default plot.
        functional: One of ``"PBE"``, ``"PBESol"``, ``"SCAN"``.
        finite_temp: If truthy, entries are converted with
            ``GibbsComputedStructureEntry.from_entries`` (called without
            temperature/model arguments, so library defaults apply).
        *args, **kwargs: Accepted and ignored, so extra UI inputs (e.g. a
            banner component wired in as an Interface input) do not break
            the call.

    Returns:
        A plotly figure: the phase diagram, or an empty figure carrying an
        annotation that explains why no diagram could be drawn.
    """
    # "Li-Fe-O" -> ["Li", "Fe", "O"]; drop blanks (from "Li--O" or a
    # trailing "-") and duplicates while keeping the input order.
    element_list = list(
        dict.fromkeys(
            piece.strip() for piece in (elements or "").split("-") if piece.strip()
        )
    )
    if not element_list:
        return go.Figure().add_annotation(text="Please enter at least one element.")

    functional_value = _FUNCTIONAL_COLUMN_VALUES.get(functional)
    if functional_value is None:
        # Previously an unknown/missing functional left `entries_df` unbound.
        return go.Figure().add_annotation(
            text=f"Unsupported functional: {functional!r}. Choose one of: "
            + ", ".join(_FUNCTIONAL_COLUMN_VALUES)
        )

    want_3d = plot_style == "3D"
    if want_3d and len(element_list) != 3:
        return go.Figure().add_annotation(
            text="3D plots are only available for ternary systems."
        )

    entries_df = train_df[train_df["functional"] == functional_value]

    # Keep materials whose (non-empty) element set lies inside the requested
    # chemical system, so every sub-system (e.g. Li-O, pure Fe) is included.
    # len() rather than truthiness: the per-row values may be numpy arrays.
    chemsys = set(element_list)
    in_chemsys = np.asarray(
        [
            len(row_elements) > 0 and set(row_elements).issubset(chemsys)
            for row_elements in entries_df.elements.values.tolist()
        ],
        dtype=bool,
    )
    entries_df = entries_df[in_chemsys]

    entries = [_row_to_entry(row) for _, row in entries_df.iterrows()]
    # TODO: Fetch elemental entries (they are usually GGA calculations)
    if not entries:
        return go.Figure().add_annotation(
            text=f"No {functional} entries found for the "
            f"{'-'.join(element_list)} system."
        )

    if finite_temp:
        entries = GibbsComputedStructureEntry.from_entries(entries)

    try:
        phase_diagram = PhaseDiagram(entries)
    except ValueError as e:
        return go.Figure().add_annotation(text=str(e))

    # PDPlotter's show_unstable doubles as the energy-above-hull cutoff
    # (eV/atom) for which unstable phases are drawn.
    show_unstable = True if max_e_above_hull is None else max_e_above_hull
    plot_kwargs = {"ternary_style": "3d"} if want_3d else {}

    try:
        plotter = PDPlotter(
            phase_diagram,
            show_unstable=show_unstable,
            backend="plotly",
            **plot_kwargs,
        )
        return plotter.get_plot()
    except ValueError as e:
        return go.Figure().add_annotation(text=str(e))


# Define Gradio interface components. Dropdowns get explicit defaults so the
# callback never receives None for a required choice on first submit.
elements_input = gr.Textbox(
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
    value="Li-Fe-O",
)
max_e_above_hull_slider = gr.Slider(
    minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV/atom)"
)
color_scheme_dropdown = gr.Dropdown(
    choices=["Energy Above Hull", "Formation Energy"],
    value="Energy Above Hull",
    label="Color Scheme",
)
plot_style_dropdown = gr.Dropdown(
    choices=["2D", "3D"], value="2D", label="Plot Style"
)
functional_dropdown = gr.Dropdown(
    choices=["PBE", "PBESol", "SCAN"], value="PBE", label="Functional"
)
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")

warning_message = (
    "This application uses energy correction schemes directly"
    " from the data providers (Alexandria, MP) and has the 2020 MP"
    " Compatibility scheme applied to OQMD. However, because we did"
    " not directly apply the compatibility schemes to Alexandria, MP"
    " we have noticed discrepancies in the data. While the correction"
    " scheme will be standardized in a soon to be released update, for"
    " now please take caution when analyzing the results of this"
    " application."
)

# Dismissible banner. The inline handler uses single quotes inside the
# double-quoted onclick attribute; nested double quotes would terminate the
# attribute early and leave the close button inert.
banner_html = (
    '<div class="alert"><span class="closebtn" '
    "onclick=\"this.parentElement.style.display='none';\">&times;</span>"
    + warning_message
    + "</div>"
)

# Create Gradio interface. Only the six parameters of create_phase_diagram are
# inputs: every Interface input is forwarded to the callback positionally.
iface = gr.Interface(
    fn=create_phase_diagram,
    inputs=[
        elements_input,
        max_e_above_hull_slider,
        color_scheme_dropdown,
        plot_style_dropdown,
        functional_dropdown,
        finite_temp_toggle,
    ],
    outputs=gr.Plot(label="Phase Diagram"),
    title="LeMaterial - Phase Diagram Viewer",
    description="Generate a phase diagram for a set of elements using LeMat-Bulk data.",
)

# Show the warning banner above the interface without making it an input.
with gr.Blocks(title="LeMaterial - Phase Diagram Viewer") as demo:
    message = gr.HTML(banner_html)
    iface.render()

# Launch the app
demo.launch()