Spaces:
Sleeping
Sleeping
File size: 5,549 Bytes
e3eae4d 3c3852b 10c6717 e3eae4d 3c3852b e3eae4d 3c3852b e3eae4d 10c6717 e3eae4d 10c6717 e3eae4d 10c6717 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import streamlit as st
import pandas as pd
import torch
import sqlite3
from datetime import datetime
from rdkit import Chem
from rdkit.Chem import Draw
import os, pathlib
from io import StringIO
from model import load_model
from utils import smiles_to_data
from torch_geometric.loader import DataLoader
# Config
DEVICE = "cpu"
RDKIT_DIM = 6
MODEL_PATH = "best_hybridgnn.pt"
MAX_DISPLAY = 10
# Load Model
model = load_model(rdkit_dim=RDKIT_DIM, path=MODEL_PATH, device=DEVICE)
# SQLite Setup
DB_DIR = os.getenv("DB_DIR", "/tmp") # /data if you add a volume later
pathlib.Path(DB_DIR).mkdir(parents=True, exist_ok=True)
@st.cache_resource
def init_db():
db_file = os.path.join(DB_DIR, "predictions.db")
conn = sqlite3.connect(db_file, check_same_thread=False)
c = conn.cursor()
c.execute("""
CREATE TABLE IF NOT EXISTS predictions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
smiles TEXT,
prediction REAL,
timestamp TEXT
)
""")
conn.commit()
return conn
conn = init_db()
cursor = conn.cursor()
# Streamlit UI
st.title("HOMO-LUMO Gap Predictor")
st.markdown("""
This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph Neural Network (GNN).
**Instructions:**
- Enter a **single SMILES** string or **comma-separated list** in the box below.
- Or **upload a CSV file** containing a single column of SMILES strings.
- **Note**: If you've uploaded a CSV and want to switch to typing SMILES, please click the “X” next to the uploaded file to clear it.
- SMILES format should look like: `CC(=O)Oc1ccccc1C(=O)O` (for aspirin).
- The app will display predictions and molecule images (up to 10 shown at once).
""")
# Text Input
smiles_input = st.text_area("Enter SMILES string(s)", placeholder="C1=CC=CC=C1, CC(=O)Oc1ccccc1C(=O)O")
# File Upload
uploaded_file = st.file_uploader("...or upload a CSV file", type=["csv"])
smiles_list = []
with st.form("input_form"):
smiles_input = st.text_area("Enter SMILES string(s)", placeholder="C1=CC=CC=C1, CC(=O)Oc1ccccc1C(=O)O", height=120)
uploaded_file = st.file_uploader("…or upload a CSV file", type=["csv"])
run_button = st.form_submit_button("Run Prediction")
# Process only after the user presses the button
if run_button:
# CSV path
if uploaded_file is not None:
try:
data = uploaded_file.getvalue() # read bytes
df = pd.read_csv(StringIO(data.decode("utf-8")), comment="#")
# choose the SMILES column
if df.shape[1] == 1:
smiles_col = df.iloc[:, 0]
elif "smiles" in [c.lower() for c in df.columns]:
smiles_col = df[[c for c in df.columns if c.lower() == "smiles"][0]]
else:
st.error("CSV must have a single column or a column named 'SMILES'" f"Found columns: {', '.join(df.columns)}")
smiles_col = None
if smiles_col is not None:
smiles_list = smiles_col.dropna().astype(str).tolist()
st.success(f"{len(smiles_list)} SMILES loaded from CSV")
except Exception as e:
st.error(f"Could not read CSV: {e}")
# Textarea path
elif smiles_input.strip():
raw_input = smiles_input.replace("\n", ",")
smiles_list = [s.strip() for s in raw_input.split(",") if s.strip()]
st.success(f"{len(smiles_list)} SMILES parsed from text")
# Run Inference
if smiles_list:
with st.spinner("Processing molecules..."):
data_list = smiles_to_data(smiles_list, device=DEVICE)
# Filter only valid molecules and keep aligned SMILES
valid_pairs = [(smi, data) for smi, data in zip(smiles_list, data_list) if data is not None]
if not valid_pairs:
st.warning("No valid molecules found")
else:
valid_smiles, valid_data = zip(*valid_pairs)
loader = DataLoader(valid_data, batch_size=64)
predictions = []
for batch in loader:
batch = batch.to(DEVICE)
with torch.no_grad():
pred = model(batch).view(-1).cpu().numpy()
predictions.extend(pred.tolist())
# Display Results
st.subheader(f"Predictions (showing up to {MAX_DISPLAY} molecules):")
for i, (smi, pred) in enumerate(zip(valid_smiles, predictions)):
if i >= MAX_DISPLAY:
st.info(f"...only showing the first {MAX_DISPLAY} molecules")
break
mol = Chem.MolFromSmiles(smi)
if mol:
st.image(Draw.MolToImage(mol, size=(250, 250)))
st.write(f"**SMILES**: `{smi}`")
st.write(f"**Predicted HOMO-LUMO Gap**: `{pred:.4f} eV`")
# Log to SQLite
cursor.execute("INSERT INTO predictions (smiles, prediction, timestamp) VALUES (?, ?, ?)",
(smi, pred, str(datetime.now())))
conn.commit()
# Download Results
result_df = pd.DataFrame({"SMILES": valid_smiles,
"Predicted HOMO-LUMO Gap (eV)": [round(p, 4) for p in predictions]})
st.download_button(label="Download Predictions as CSV",
data=result_df.to_csv(index=False).encode('utf-8'),
file_name="homolumo_predictions.csv",
mime="text/csv")
|