MooseML committed on
Commit
a1efe76
·
1 Parent(s): b12a328

Compact info expander and change permissions in Dockerfile

Browse files
Files changed (2) hide show
  1. Dockerfile +19 -70
  2. app.py +101 -223
Dockerfile CHANGED
@@ -1,90 +1,39 @@
1
- # Dockerfile for Hugging Face Space: Streamlit + RDKit + PyG
2
-
3
  FROM python:3.10-slim
4
 
5
- # System libraries (needed by RDKit / Pillow)
6
- RUN apt-get update && \
7
- apt-get install -y --no-install-recommends \
8
- build-essential \
9
- libxrender1 \
10
- libxext6 \
11
- libsm6 \
12
- libx11-6 \
13
- libglib2.0-0 \
14
- libfreetype6 \
15
- libpng-dev \
16
- wget && \
17
  rm -rf /var/lib/apt/lists/*
18
-
19
- # Create a non-root user to run the application
20
  RUN useradd -m appuser
21
 
22
- # Python packages
23
  RUN pip install --no-cache-dir --upgrade pip && \
24
  pip install --no-cache-dir \
25
- streamlit==1.45.0 \
26
- rdkit-pypi==2022.9.5 \
27
- pandas==2.2.3 \
28
- numpy==1.26.4 \
29
- torch==2.2.0 \
30
- torch-geometric==2.5.2 \
31
- ogb==1.3.6 \
32
- pillow==10.3.0
33
-
34
- # Set up working directory
35
- WORKDIR /app
36
 
37
- # Copy application files
 
38
  COPY . .
39
 
40
- # Fix permissions for temporary directories
41
- RUN mkdir -p /tmp/streamlit && \
42
- chmod -R 777 /tmp && \
43
- chmod -R 777 /tmp/streamlit && \
44
- # Also ensure the SQLite database directory is writable
45
- mkdir -p /data && \
46
- chmod -R 777 /data && \
47
- # Make sure the app files are readable
48
- chmod -R 755 /app
49
-
50
- # Ensure temp directories exist and are writable
51
- RUN mkdir -p /tmp/csv_uploads && \
52
- chmod -R 777 /tmp/csv_uploads
53
 
54
- # Set environment variables
55
  ENV DB_DIR=/data \
56
- TMPDIR=/tmp \
57
  STREAMLIT_SERVER_HEADLESS=true \
58
  STREAMLIT_SERVER_ADDRESS=0.0.0.0 \
59
  STREAMLIT_SERVER_PORT=7860 \
60
  STREAMLIT_TELEMETRY_DISABLED=true \
61
  STREAMLIT_BROWSER_GATHER_USAGE_STATS=false \
62
- # Increase file upload size limit to accommodate larger CSVs
63
- STREAMLIT_SERVER_MAX_UPLOAD_SIZE=50
64
-
65
- # Expose the port Streamlit will run on
66
- EXPOSE 7860
67
-
68
- # Set entrypoint script
69
- COPY <<EOF /app/entrypoint.sh
70
- #!/bin/bash
71
- echo "Starting Streamlit app with debug info"
72
- echo "Current directory: $(pwd)"
73
- echo "Files in current directory: $(ls -la)"
74
- echo "Python version: $(python --version)"
75
- echo "Temp directory: $TMPDIR"
76
- echo "Temp directory exists: $([ -d $TMPDIR ] && echo 'Yes' || echo 'No')"
77
- echo "Temp directory permissions: $(ls -ld $TMPDIR)"
78
-
79
- # Run the app
80
- streamlit run app.py
81
- EOF
82
 
83
- # Make the entrypoint script executable
84
- RUN chmod +x /app/entrypoint.sh
85
 
86
- # Switch to the non-root user for better security
87
  USER appuser
88
-
89
- # Launch using the entrypoint script
90
- CMD ["/app/entrypoint.sh"]
 
1
+ # Dockerfile: Streamlit/RDKit/PyG (Hugging Face Spaces)
 
2
  FROM python:3.10-slim
3
 
4
+ # OS libs for RDKit drawing
5
+ RUN apt-get update && apt-get install -y --no-install-recommends \
6
+ build-essential libxrender1 libxext6 libsm6 libx11-6 \
7
+ libglib2.0-0 libfreetype6 libpng-dev wget && \
 
 
 
 
 
 
 
 
8
  rm -rf /var/lib/apt/lists/*
9
+
10
+ # Non-root user
11
  RUN useradd -m appuser
12
 
13
+ # Python packages
14
  RUN pip install --no-cache-dir --upgrade pip && \
15
  pip install --no-cache-dir \
16
+ streamlit==1.45.0 rdkit-pypi==2022.9.5 pandas==2.2.3 \
17
+ numpy==1.26.4 torch==2.2.0 torch-geometric==2.5.2 \
18
+ ogb==1.3.6 pillow==10.3.0
 
 
 
 
 
 
 
 
19
 
20
+ # Workdir and code
21
+ WORKDIR /app
22
  COPY . .
23
 
24
+ # Writable dirs with 775 perms
25
+ RUN install -d -m 775 /tmp/streamlit /data
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ # Environment
28
  ENV DB_DIR=/data \
 
29
  STREAMLIT_SERVER_HEADLESS=true \
30
  STREAMLIT_SERVER_ADDRESS=0.0.0.0 \
31
  STREAMLIT_SERVER_PORT=7860 \
32
  STREAMLIT_TELEMETRY_DISABLED=true \
33
  STREAMLIT_BROWSER_GATHER_USAGE_STATS=false \
34
+ STREAMLIT_SERVER_MAX_UPLOAD_SIZE=50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ EXPOSE 7860
 
37
 
 
38
  USER appuser
39
+ CMD ["streamlit", "run", "app.py"]
 
 
app.py CHANGED
@@ -1,87 +1,53 @@
1
- import streamlit as st
 
 
 
2
  import pandas as pd
 
3
  import torch
4
- import sqlite3
5
- from datetime import datetime
6
  from rdkit import Chem
7
  from rdkit.Chem import Draw
8
- import os, pathlib, sys
9
- import tempfile
10
- from io import StringIO, BytesIO
11
  from model import load_model
12
  from utils import smiles_to_data
13
- from torch_geometric.loader import DataLoader
14
 
15
- # Config
16
- DEVICE = "cpu"
17
- RDKIT_DIM = 6
18
- MODEL_PATH = "best_hybridgnn.pt"
19
- MAX_DISPLAY = 10
20
-
21
- # Debug sidebar
22
- with st.sidebar:
23
- st.title("Debug Tools")
24
- if st.button("Show Environment Info"):
25
- st.write("### System Info")
26
- st.write(f"Python version: {sys.version}")
27
- st.write(f"Current working directory: {os.getcwd()}")
28
- st.write(f"Temp directory: {tempfile.gettempdir()}")
29
- st.write(f"Temp dir exists: {os.path.exists(tempfile.gettempdir())}")
30
- st.write(f"Temp dir writable: {os.access(tempfile.gettempdir(), os.W_OK)}")
31
- st.write(f"Current user: {os.getenv('USER', 'unknown')}")
32
-
33
- try:
34
- st.write("### Directory Contents")
35
- st.write(f"Files in current directory: {os.listdir('.')}")
36
- st.write(f"Files in /tmp: {os.listdir('/tmp')}")
37
- except Exception as e:
38
- st.error(f"Error listing directories: {e}")
39
-
40
- st.write("### Environment Variables")
41
- for key, value in os.environ.items():
42
- if not key.startswith(('AWS', 'SECRET')): # Skip sensitive vars
43
- st.write(f"{key}: {value}")
44
-
45
- # Load Model
46
  @st.cache_resource
47
- def load_cached_model():
48
- try:
49
- return load_model(rdkit_dim=RDKIT_DIM, path=MODEL_PATH, device=DEVICE)
50
- except Exception as e:
51
- st.error(f"Error loading model: {e}")
52
- return None
53
 
54
- model = load_cached_model()
55
 
56
- # SQLite Setup
57
- DB_DIR = os.getenv("DB_DIR", "/tmp")
58
- pathlib.Path(DB_DIR).mkdir(parents=True, exist_ok=True)
59
 
60
  @st.cache_resource
61
  def init_db():
62
- try:
63
- db_file = os.path.join(DB_DIR, "predictions.db")
64
- conn = sqlite3.connect(db_file, check_same_thread=False)
65
- c = conn.cursor()
66
- c.execute("""
67
- CREATE TABLE IF NOT EXISTS predictions (
68
- id INTEGER PRIMARY KEY AUTOINCREMENT,
69
- smiles TEXT,
70
- prediction REAL,
71
- timestamp TEXT
72
- )
73
- """)
74
- conn.commit()
75
- return conn
76
- except Exception as e:
77
- st.error(f"Database initialization error: {e}")
78
- return None
79
-
80
- conn = init_db()
81
- if conn:
82
- cursor = conn.cursor()
83
-
84
- # Streamlit UI
85
  st.title("HOMO-LUMO Gap Predictor")
86
  st.markdown("""
87
  This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph Neural Network (GNN).
@@ -94,160 +60,72 @@ This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph N
94
  - The app will display predictions and molecule images (up to 10 shown at once).
95
  """)
96
 
97
- # File handling with caching
98
- @st.cache_data
99
- def read_csv_file(file_content):
100
- """Cache the file reading operation"""
101
- try:
102
- # Try to read as string first
103
- if isinstance(file_content, str):
104
- df = pd.read_csv(StringIO(file_content), comment="#")
105
- else:
106
- # If it's bytes, decode it
107
- df = pd.read_csv(StringIO(file_content.decode('utf-8')), comment="#")
108
- return df, None
109
- except Exception as e:
110
- return None, str(e)
111
-
112
- # Debug container for file upload messages
113
- file_debug = st.container()
114
-
115
- # File uploader outside the form
116
- uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
117
- if uploaded_file is not None:
118
- with file_debug:
119
- st.write(f"File name: {uploaded_file.name}")
120
- st.write(f"File type: {uploaded_file.type}")
121
- st.write(f"File size: {uploaded_file.size} bytes")
122
-
123
- with st.form("input_form"):
124
- smiles_input = st.text_area(
125
- "Enter SMILES string(s)",
126
- placeholder="C1=CC=CC=C1, CC(=O)Oc1ccccc1C(=O)O",
127
- height=120
128
- )
129
- run_button = st.form_submit_button("Submit")
130
 
 
131
  smiles_list = []
132
-
133
- # Process after the button press
134
- if run_button:
135
- # CSV path
136
- if uploaded_file is not None:
137
- with file_debug:
138
- st.write("### Processing CSV file")
139
- try:
140
- # Save file temporarily for debugging
141
- temp_file = os.path.join(tempfile.gettempdir(), uploaded_file.name)
142
- with open(temp_file, 'wb') as f:
143
- f.write(uploaded_file.getvalue())
144
- st.write(f"Saved temporary file at: {temp_file}")
145
- st.write(f"File exists: {os.path.exists(temp_file)}")
146
- st.write(f"File size on disk: {os.path.getsize(temp_file)} bytes")
147
-
148
- # Read file content
149
- file_content = uploaded_file.getvalue()
150
- st.write(f"Read {len(file_content)} bytes from file")
151
-
152
- # Try to decode first few bytes
153
- preview = file_content[:100] if len(file_content) > 100 else file_content
154
- try:
155
- decoded_preview = preview.decode('utf-8')
156
- st.write(f"File preview (decoded): {decoded_preview}")
157
- except:
158
- st.write(f"File preview (hex): {preview.hex()}")
159
-
160
- # Use cached reading function
161
- df, error = read_csv_file(file_content)
162
-
163
- if error:
164
- st.error(f"CSV reading error: {error}")
165
- elif df is not None:
166
- st.write(f"CSV loaded with {df.shape[0]} rows and {df.shape[1]} columns")
167
- st.write("CSV columns:", df.columns.tolist())
168
- st.write("First few rows:", df.head())
169
-
170
- # choose the SMILES column
171
- if df.shape[1] == 1:
172
- smiles_col = df.iloc[:, 0]
173
- st.write("Using the only column for SMILES")
174
- elif "smiles" in [c.lower() for c in df.columns]:
175
- col_name = [c for c in df.columns if c.lower() == "smiles"][0]
176
- smiles_col = df[col_name]
177
- st.write(f"Using column '{col_name}' for SMILES")
178
- else:
179
- st.error(f"CSV must have a single column or a column named 'SMILES'. Found columns: {', '.join(df.columns)}")
180
- st.write("Using first column as fallback")
181
- smiles_col = df.iloc[:, 0]
182
-
183
- smiles_list = smiles_col.dropna().astype(str).tolist()
184
- st.success(f"{len(smiles_list)} SMILES loaded from CSV")
185
- if smiles_list:
186
- st.write("First few SMILES:", smiles_list[:5])
187
- else:
188
- st.error("Failed to process CSV: DataFrame is None")
189
- except Exception as e:
190
- st.error(f"Critical error processing CSV: {str(e)}")
191
- st.exception(e) # This shows the full traceback
192
-
193
- # Textarea path
194
- elif smiles_input.strip():
195
- raw_input = smiles_input.replace("\n", ",")
196
- smiles_list = [s.strip() for s in raw_input.split(",") if s.strip()]
197
- st.success(f"{len(smiles_list)} SMILES parsed from text")
198
- if smiles_list:
199
- st.write("First few SMILES:", smiles_list[:5])
200
-
201
- # Run Inference
202
- if smiles_list:
203
- with st.spinner("Processing molecules..."):
204
  try:
205
- data_list = smiles_to_data(smiles_list, device=DEVICE)
206
-
207
- # Filter only valid molecules and keep aligned SMILES
208
- valid_pairs = [(smi, data) for smi, data in zip(smiles_list, data_list) if data is not None]
209
-
210
- if not valid_pairs:
211
- st.warning("No valid molecules found")
212
  else:
213
- valid_smiles, valid_data = zip(*valid_pairs)
214
- loader = DataLoader(valid_data, batch_size=64)
215
- predictions = []
216
-
217
- for batch in loader:
218
- batch = batch.to(DEVICE)
219
- with torch.no_grad():
220
- pred = model(batch).view(-1).cpu().numpy()
221
- predictions.extend(pred.tolist())
222
-
223
- # Display Results
224
- st.subheader(f"Predictions (showing up to {MAX_DISPLAY} molecules):")
225
-
226
- for i, (smi, pred) in enumerate(zip(valid_smiles, predictions)):
227
- if i >= MAX_DISPLAY:
228
- st.info(f"...only showing the first {MAX_DISPLAY} molecules")
229
- break
230
-
231
- mol = Chem.MolFromSmiles(smi)
232
- if mol:
233
- st.image(Draw.MolToImage(mol, size=(250, 250)))
234
- st.write(f"**SMILES**: `{smi}`")
235
- st.write(f"**Predicted HOMO-LUMO Gap**: `{pred:.4f} eV`")
236
-
237
- # Log to SQLite if connection exists
238
- if conn:
239
- cursor.execute("INSERT INTO predictions (smiles, prediction, timestamp) VALUES (?, ?, ?)",
240
- (smi, pred, str(datetime.now())))
241
- conn.commit()
242
-
243
- # Download Results
244
- result_df = pd.DataFrame({"SMILES": valid_smiles,
245
- "Predicted HOMO-LUMO Gap (eV)": [round(p, 4) for p in predictions]})
246
-
247
- st.download_button(label="Download Predictions as CSV",
248
- data=result_df.to_csv(index=False).encode('utf-8'),
249
- file_name="homolumo_predictions.csv",
250
- mime="text/csv")
251
  except Exception as e:
252
- st.error(f"Error during inference: {str(e)}")
253
- st.exception(e) # This shows the full traceback
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, pathlib, sqlite3, sys, tempfile
2
+ from datetime import datetime
3
+ from io import StringIO
4
+
5
  import pandas as pd
6
+ import streamlit as st
7
  import torch
 
 
8
  from rdkit import Chem
9
  from rdkit.Chem import Draw
10
+ from torch_geometric.loader import DataLoader
11
+
 
12
  from model import load_model
13
  from utils import smiles_to_data
 
14
 
15
+ # Config
16
+ DEVICE, RDKIT_DIM, MODEL_PATH, MAX_DISPLAY = "cpu", 6, "best_hybridgnn.pt", 10
17
+
18
+ # Model & DB (cached)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  @st.cache_resource
20
+ def get_model():
21
+ return load_model(rdkit_dim=RDKIT_DIM, path=MODEL_PATH, device=DEVICE)
 
 
 
 
22
 
23
+ model = get_model()
24
 
25
+ DB_DIR = pathlib.Path(os.getenv("DB_DIR", "/tmp"))
26
+ DB_DIR.mkdir(parents=True, exist_ok=True)
 
27
 
28
  @st.cache_resource
29
  def init_db():
30
+ conn = sqlite3.connect(DB_DIR / "predictions.db", check_same_thread=False)
31
+ conn.execute(
32
+ """CREATE TABLE IF NOT EXISTS predictions(
33
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
34
+ smiles TEXT, prediction REAL, timestamp TEXT)"""
35
+ )
36
+ conn.commit()
37
+ return conn
38
+
39
+ conn = init_db()
40
+ cursor = conn.cursor()
41
+
42
+ # debug and info panel
43
+ with st.sidebar.expander("Info & Env", expanded=False):
44
+ st.write(f"Python {sys.version.split()[0]}")
45
+ st.write(f"Temp dir: `{tempfile.gettempdir()}` "
46
+ f"({'writable' if os.access(tempfile.gettempdir(), os.W_OK) else 'read-only'})")
47
+ if "csv_bytes" in st.session_state:
48
+ st.write(f"Last upload: **{len(st.session_state['csv_bytes'])/1024:.1f} KB**")
49
+
50
+ # Header
 
 
51
  st.title("HOMO-LUMO Gap Predictor")
52
  st.markdown("""
53
  This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph Neural Network (GNN).
 
60
  - The app will display predictions and molecule images (up to 10 shown at once).
61
  """)
62
 
63
+ # File uploader (outside form)
64
+ csv_file = st.file_uploader("CSV with SMILES", type=["csv"])
65
+ if csv_file is not None:
66
+ st.session_state["csv_bytes"] = csv_file.getvalue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ # Input form
69
  smiles_list = []
70
+ with st.form("main_form"):
71
+ smiles_text = st.text_area("…or paste SMILES (comma/newline separated)",
72
+ placeholder="CC(=O)Oc1ccccc1C(=O)O",
73
+ height=120)
74
+ run = st.form_submit_button("Run Prediction")
75
+
76
+ # Parse input
77
+ if run:
78
+ if "csv_bytes" in st.session_state:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  try:
80
+ df = pd.read_csv(StringIO(st.session_state["csv_bytes"].decode("utf-8")), comment="#")
81
+ col = df.columns[0] if df.shape[1] == 1 else next((c for c in df.columns if c.lower() == "smiles"), None)
82
+ if col is None:
83
+ st.error("CSV needs one column or a 'SMILES' column")
 
 
 
84
  else:
85
+ smiles_list = df[col].dropna().astype(str).tolist()
86
+ st.success(f"{len(smiles_list)} SMILES loaded from CSV")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  except Exception as e:
88
+ st.error(f"CSV error: {e}")
89
+
90
+ elif smiles_text.strip():
91
+ smiles_list = [s.strip() for s in smiles_text.replace("\n", ",").split(",") if s.strip()]
92
+ st.success(f"{len(smiles_list)} SMILES parsed from textbox")
93
+ else:
94
+ st.warning("No input provided")
95
+
96
+ # Inference & display
97
+ if smiles_list:
98
+ data_list = smiles_to_data(smiles_list, device=DEVICE)
99
+ valid = [(s, d) for s, d in zip(smiles_list, data_list) if d is not None]
100
+
101
+ if not valid:
102
+ st.warning("No valid molecules")
103
+ else:
104
+ vsmi, vdata = zip(*valid)
105
+ preds = []
106
+ for batch in DataLoader(vdata, batch_size=64):
107
+ with torch.no_grad():
108
+ preds.extend(get_model()(batch.to(DEVICE)).view(-1).cpu().numpy().tolist())
109
+
110
+ st.subheader(f"Results (first {MAX_DISPLAY})")
111
+ for i, (smi, pred) in enumerate(zip(vsmi, preds)):
112
+ if i >= MAX_DISPLAY:
113
+ st.info("…truncated")
114
+ break
115
+ mol = Chem.MolFromSmiles(smi)
116
+ if mol:
117
+ st.image(Draw.MolToImage(mol, size=(250, 250)))
118
+ st.write(f"`{smi}` → **{pred:.4f} eV**")
119
+
120
+ cursor.execute(
121
+ "INSERT INTO predictions(smiles, prediction, timestamp) VALUES (?,?,?)",
122
+ (smi, float(pred), datetime.now().isoformat()),
123
+ )
124
+ conn.commit()
125
+
126
+ st.download_button("Download CSV",
127
+ pd.DataFrame(
128
+ {"SMILES": vsmi, "Gap (eV)": [round(p, 4) for p in preds]}
129
+ ).to_csv(index=False).encode(),
130
+ "homolumo_predictions.csv",
131
+ "text/csv")