File size: 4,250 Bytes
1a9eacc
 
baa7f8d
 
70b543a
baa7f8d
0641fac
baa7f8d
0641fac
 
 
 
70b543a
 
 
 
 
 
1a9eacc
 
baa7f8d
 
 
1a9eacc
baa7f8d
1a9eacc
70b543a
baa7f8d
0641fac
 
 
 
 
 
 
baa7f8d
 
 
 
 
 
 
 
70b543a
 
1a9eacc
70b543a
 
baa7f8d
 
 
 
 
70b543a
baa7f8d
 
 
70b543a
 
1a9eacc
0641fac
 
 
1a9eacc
 
 
 
 
 
 
 
 
 
 
 
 
baa7f8d
 
 
1a9eacc
 
baa7f8d
 
 
1a9eacc
 
 
 
 
 
 
 
 
baa7f8d
 
 
 
 
 
 
 
 
 
 
1a9eacc
 
 
 
 
 
 
 
 
 
 
 
 
baa7f8d
1a9eacc
 
 
baa7f8d
1a9eacc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os

import gradio as gr
import plotly.graph_objs as go
from datasets import load_dataset
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
from pymatgen.core import Composition, Structure
from pymatgen.core.composition import Composition
from pymatgen.entries.computed_entries import (
    ComputedStructureEntry,
    GibbsComputedStructureEntry,
)

# Fetch just the "train" split of LeMaterial/leDataset from the Hugging Face
# Hub. This happens once, when the module is imported.
dataset = load_dataset(path="LeMaterial/leDataset", split="train")

# Materialise that split as a pandas DataFrame; create_phase_diagram scans its
# rows to build pymatgen entries.
train_df = dataset.to_pandas()


def _message_figure(message):
    """Return an empty plotly figure that shows *message* as a text annotation."""
    return go.Figure().add_annotation(text=message, showarrow=False)


def _parse_elements(elements):
    """Split a '-'-separated chemical system into unique, non-empty symbols, keeping order."""
    if not elements:
        return []
    stripped = (part.strip() for part in elements.split("-"))
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(symbol for symbol in stripped if symbol))


def _load_entries(element_list):
    """Build ComputedStructureEntry objects for train_df rows whose elements all lie in element_list."""
    allowed = set(element_list)
    # One pass over the "elements" column selects the matching rows; entries are
    # only constructed for those rows.
    mask = (
        train_df["elements"]
        .map(lambda els: els is not None and len(els) > 0 and set(els) <= allowed)
        .astype(bool)
    )
    return [
        ComputedStructureEntry(
            Structure(
                # Use *this* row's lattice; list() accepts both a 2-D array and
                # an object array of 1-D vectors.
                [list(vector) for vector in row["lattice_vectors"]],
                row["species_at_sites"],
                row["cartesian_site_positions"],
                coords_are_cartesian=True,
            ),
            energy=row["energy"],
            # Stored as a correction so entry.energy reports the corrected value.
            correction=row["energy_corrected"] - row["energy"],
            entry_id=row["immutable_id"],
        )
        for _, row in train_df.loc[mask].iterrows()
    ]


def _filter_by_functional(entries, functional):
    """Keep entries matching *functional* ("GGA" / "GGA+U"); any other value keeps everything."""
    # NOTE(review): _load_entries never sets entry.parameters, so "run_type" is
    # always missing here. As a result "GGA" keeps every entry and "GGA+U" keeps
    # only elemental ones. Confirm which dataset column should populate it.
    if functional == "GGA":
        return [
            e
            for e in entries
            if not e.parameters.get("run_type", "").startswith("GGA+U")
        ]
    if functional == "GGA+U":
        selected = [
            e for e in entries if e.parameters.get("run_type", "").startswith("GGA+U")
        ]
        # Elemental references are needed as hull terminals even when they were
        # not computed with +U.
        selected.extend(
            [e for e in entries if e.composition.is_element and e not in selected]
        )
        return selected
    return entries


def _drop_far_above_hull(phase_diagram, max_e_above_hull):
    """Rebuild *phase_diagram* without unstable entries more than *max_e_above_hull* above the hull."""
    stable = phase_diagram.stable_entries
    all_entries = phase_diagram.all_entries
    # Stable entries are always kept, so the hull (and its terminal elements)
    # is unchanged; only unstable points beyond the threshold disappear.
    kept = [
        e
        for e in all_entries
        if e in stable or phase_diagram.get_e_above_hull(e) <= max_e_above_hull
    ]
    if len(kept) == len(all_entries):
        return phase_diagram
    return PhaseDiagram(kept)


def create_phase_diagram(
    elements, max_e_above_hull, color_scheme, plot_style, functional, finite_temp
):
    """Build a plotly phase-diagram figure for a chemical system.

    Entries come from ``train_df`` (the LeMaterial ``leDataset`` train split
    loaded at import time). Only materials whose elements all belong to the
    requested system are used.

    Args:
        elements: Chemical system as a '-'-separated string, e.g. ``"Li-Fe-O"``.
            Whitespace, empty segments and repeated symbols are ignored.
        max_e_above_hull: Unstable entries more than this far above the hull
            (eV/atom, as returned by ``PhaseDiagram.get_e_above_hull``) are
            dropped. ``None`` disables the filter.
        color_scheme: Currently unused; kept so the signature matches the
            Gradio inputs.
        plot_style: ``"2D"`` for the default plotly diagram. Any other value
            requests the 3D ternary style, which is only offered for 3-element
            systems.
        functional: ``"GGA"``, ``"GGA+U"``, or anything else for no filtering.
        finite_temp: If truthy, entries are converted with
            ``GibbsComputedStructureEntry.from_entries`` at its default
            temperature.

    Returns:
        A plotly ``Figure``. If the input is empty, no entries match,
        ``PhaseDiagram`` raises ``ValueError``, or a 3D plot is requested for a
        non-ternary system, the figure is empty and carries an explanatory
        annotation.
    """
    element_list = _parse_elements(elements)
    if not element_list:
        return _message_figure("Please enter at least one element, e.g. 'Li-Fe-O'.")

    entries = _load_entries(element_list)
    if not entries:
        return _message_figure(
            f"No entries found for the {'-'.join(element_list)} system."
        )

    entries = _filter_by_functional(entries, functional)

    if finite_temp:
        entries = GibbsComputedStructureEntry.from_entries(entries)

    try:
        phase_diagram = PhaseDiagram(entries)
        if max_e_above_hull is not None:
            phase_diagram = _drop_far_above_hull(phase_diagram, max_e_above_hull)
    except ValueError as e:
        return _message_figure(str(e))

    if plot_style == "2D":
        plotter = PDPlotter(phase_diagram, show_unstable=True, backend="plotly")
    elif len(element_list) == 3:
        plotter = PDPlotter(
            phase_diagram, show_unstable=True, backend="plotly", ternary_style="3d"
        )
    else:
        return _message_figure("3D plots are only available for ternary systems.")

    return plotter.get_plot()


# Define Gradio interface components.
# Each dropdown gets an explicit default (its first choice). Without one, some
# Gradio versions leave the dropdown unselected and submit None, which
# create_phase_diagram would route to the 3D branch.
elements_input = gr.Textbox(
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
    value="Li-Fe-O",
)
# Energy above hull is a per-atom quantity, hence the eV/atom unit.
max_e_above_hull_slider = gr.Slider(
    minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV/atom)"
)
color_scheme_dropdown = gr.Dropdown(
    choices=["Energy Above Hull", "Formation Energy"],
    value="Energy Above Hull",
    label="Color Scheme",
)
plot_style_dropdown = gr.Dropdown(choices=["2D", "3D"], value="2D", label="Plot Style")
functional_dropdown = gr.Dropdown(
    choices=["GGA", "GGA+U", "Both"], value="GGA", label="Functional"
)
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")

# Create the Gradio interface. The input order must match the positional
# parameters of create_phase_diagram.
# The data comes from the LeMaterial/leDataset train split (see load_dataset
# above), so the title and description name that source rather than
# Materials Project.
iface = gr.Interface(
    fn=create_phase_diagram,
    inputs=[
        elements_input,
        max_e_above_hull_slider,
        color_scheme_dropdown,
        plot_style_dropdown,
        functional_dropdown,
        finite_temp_toggle,
    ],
    outputs=gr.Plot(label="Phase Diagram"),
    title="LeMaterial Phase Diagram",
    description="Generate a phase diagram for a set of elements using LeMaterial (leDataset) data.",
)

# Launch the app (runs at import time, as the script entry point).
iface.launch()