File size: 5,238 Bytes
0ef141b
 
e3eae4d
 
0ef141b
 
 
 
 
e3eae4d
 
0ef141b
 
e3eae4d
 
 
0ef141b
 
 
 
 
e3eae4d
0ef141b
 
 
 
 
 
e3eae4d
0ef141b
 
3c3852b
e3eae4d
 
0ef141b
e3eae4d
0ef141b
 
e3eae4d
 
 
 
 
 
0ef141b
 
e3eae4d
 
 
0ef141b
e3eae4d
 
0ef141b
e3eae4d
0ef141b
 
 
 
 
 
 
 
 
 
 
e3eae4d
d96a023
 
b8f152e
 
0ef141b
 
b8f152e
 
 
 
0ef141b
b8f152e
0ef141b
 
 
 
10c6717
0ef141b
 
b8f152e
10c6717
 
 
0ef141b
 
 
10c6717
b8f152e
0ef141b
 
b8f152e
10c6717
 
 
 
0ef141b
10c6717
b8f152e
 
0ef141b
b8f152e
 
 
0ef141b
b8f152e
0ef141b
10c6717
0ef141b
e3eae4d
0ef141b
e3eae4d
 
0ef141b
 
 
 
 
e3eae4d
0ef141b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3eae4d
0ef141b
 
 
0714060
0ef141b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import pathlib
import sqlite3
from datetime import datetime
from io import StringIO

import pandas as pd
import streamlit as st
import torch
from rdkit import Chem
from rdkit.Chem import Draw
from torch_geometric.loader import DataLoader

from model import load_model
from utils import smiles_to_data

# ───────────────────────── Configuration ─────────────────────────
DEVICE       = "cpu"
RDKIT_DIM    = 6
MODEL_PATH   = "best_hybridgnn.pt"
MAX_DISPLAY  = 10

# ─────────────────────── Cached model & database ─────────────────
@st.cache_resource
def get_model():
    return load_model(rdkit_dim=RDKIT_DIM, path=MODEL_PATH, device=DEVICE)

model = get_model()

DB_DIR = pathlib.Path(os.getenv("DB_DIR", "/tmp"))
DB_DIR.mkdir(parents=True, exist_ok=True)

@st.cache_resource
def init_db():
    conn = sqlite3.connect(DB_DIR / "predictions.db", check_same_thread=False)
    c = conn.cursor()
    c.execute(
        """
        CREATE TABLE IF NOT EXISTS predictions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            smiles TEXT,
            prediction REAL,
            timestamp TEXT
        )
        """
    )
    conn.commit()
    return conn

conn   = init_db()
cursor = conn.cursor()

#  UI header 
st.title("HOMO-LUMO Gap Predictor")
st.markdown(
    """
    Paste SMILES **or** upload a one-column CSV, then click **Run Prediction**.
    The app draws each molecule and shows the predicted HOMO-LUMO gap (eV).
    """
)

#  Input widgets 
csv_file = st.file_uploader("Upload CSV (one SMILES column)", type=["csv"])
if csv_file is not None:
    st.session_state["uploaded_csv"] = csv_file  # persist across reruns

smiles_list = []

with st.form("smiles_or_csv"):
    smiles_text = st.text_area(
        "…or paste SMILES (comma or newline separated)",
        placeholder="CC(=O)Oc1ccccc1C(=O)O",
        height=120,
    )
    run = st.form_submit_button("Run Prediction")

#  Parse input after button 
if run:
    csv_obj = st.session_state.get("uploaded_csv", None)

    #  CSV branch 
    if csv_obj is not None:
        try:
            csv_obj.seek(0)
            df = pd.read_csv(StringIO(csv_obj.getvalue().decode("utf-8")), comment="#")

            if df.shape[1] == 1:
                smiles_col = df.iloc[:, 0]
            elif "smiles" in [c.lower() for c in df.columns]:
                smiles_col = df[
                    [c for c in df.columns if c.lower() == "smiles"][0]
                ]
            else:
                st.error(
                    "CSV must have one column **or** a column named 'SMILES'"
                    f"Found: {', '.join(df.columns)}"
                )
                smiles_col = None

            if smiles_col is not None:
                smiles_list = smiles_col.dropna().astype(str).tolist()
                st.success(f"{len(smiles_list)} SMILES loaded from CSV")
        except Exception as e:
            st.error(f"CSV read error: {e}")

    #  Textarea branch 
    elif smiles_text.strip():
        raw = smiles_text.replace("\n", ",")
        smiles_list = [s.strip() for s in raw.split(",") if s.strip()]
        st.success(f"{len(smiles_list)} SMILES parsed from textbox")
    else:
        st.warning("Paste SMILES or upload a CSV before pressing **Run**")

#  Inference 
if smiles_list:
    with st.spinner("Running model…"):
        data_list = smiles_to_data(smiles_list, device=DEVICE)

    valid_pairs = [
        (smi, data)
        for smi, data in zip(smiles_list, data_list)
        if data is not None
    ]

    if not valid_pairs:
        st.warning("No valid molecules found")
    else:
        valid_smiles, valid_data = zip(*valid_pairs)
        loader = DataLoader(valid_data, batch_size=64)
        preds  = []

        for batch in loader:
            batch = batch.to(DEVICE)
            with torch.no_grad():
                preds.extend(model(batch).view(-1).cpu().numpy().tolist())

        #  Display results 
        st.subheader(f"Predictions (showing up to {MAX_DISPLAY})")
        for i, (smi, pred) in enumerate(zip(valid_smiles, preds)):
            if i >= MAX_DISPLAY:
                st.info(f"…only first {MAX_DISPLAY} molecules shown")
                break
            mol = Chem.MolFromSmiles(smi)
            if mol:
                st.image(Draw.MolToImage(mol, size=(250, 250)))
            st.write(f"**SMILES:** `{smi}`")
            st.write(f"**Predicted Gap:** `{pred:.4f} eV`")

            cursor.execute(
                "INSERT INTO predictions (smiles, prediction, timestamp) VALUES (?, ?, ?)",
                (smi, float(pred), datetime.now().isoformat())
            )
        conn.commit()

        #  Download results 
        res_df = pd.DataFrame(
            {
                "SMILES": valid_smiles,
                "Predicted HOMO-LUMO Gap (eV)": [round(p, 4) for p in preds],
            }
        )
        st.download_button(
            "Download results as CSV",
            res_df.to_csv(index=False).encode("utf-8"),
            "homolumo_predictions.csv",
            "text/csv",
        )