import spaces # Standard Libraries import os import io import csv import json import glob import random import tempfile import atexit from datetime import datetime # Third-Party Libraries import numpy as np import pandas as pd import torch import imageio from rdkit import Chem from rdkit.Chem import Draw import gradio as gr # Local Modules from evaluator import Evaluator from loader import load_graph_decoder # --------------------------- Configuration Constants --------------------------- # DATA_DIR = 'data' EVALUATORS_DIR = 'evaluators' FLAGGED_FOLDER = "flagged" KNOWN_LABELS_FILE = os.path.join(DATA_DIR, 'known_labels.csv') KNOWN_SMILES_FILE = os.path.join(DATA_DIR, 'known_polymers.csv') ALL_PROPERTIES = ['CH4', 'CO2', 'H2', 'N2', 'O2'] MODEL_NAME_MAPPING = { "model_all": "Graph DiT (trained on labeled + unlabeled)", "model_labeled": "Graph DiT (trained on labeled)" } GIF_TEMP_PREFIX = "polymer_gifs_" # --------------------------- Data Loading --------------------------- # def load_known_data(): """Load known labels and SMILES data from CSV files.""" try: known_labels = pd.read_csv(KNOWN_LABELS_FILE) known_smiles = pd.read_csv(KNOWN_SMILES_FILE) return known_labels, known_smiles except Exception as e: raise FileNotFoundError(f"Error loading data files: {e}") # Load data known_labels, known_smiles = load_known_data() # --------------------------- Evaluator Setup --------------------------- # def initialize_evaluators(properties, evaluators_dir): """Initialize evaluators for each property.""" evaluators = {} for prop in properties: evaluator_path = os.path.join(evaluators_dir, f'{prop}.joblib') evaluators[prop] = Evaluator(evaluator_path, prop) return evaluators evaluators = initialize_evaluators(ALL_PROPERTIES, EVALUATORS_DIR) # --------------------------- Property Ranges --------------------------- # def get_property_ranges(labels, properties): """Get min and max values for each property.""" return {prop: (labels[prop].min(), labels[prop].max()) for prop in properties} property_ranges = get_property_ranges(known_labels, ALL_PROPERTIES) # --------------------------- Temporary Directory Setup --------------------------- # temp_dir = tempfile.mkdtemp(prefix=GIF_TEMP_PREFIX) def cleanup_temp_files(): """Clean up temporary GIF files on exit.""" try: for file in glob.glob(os.path.join(temp_dir, "*.gif")): os.remove(file) os.rmdir(temp_dir) except Exception as e: print(f"Error during cleanup: {e}") atexit.register(cleanup_temp_files) # --------------------------- Utility Functions --------------------------- # def random_properties(): """Select a random set of properties from known labels.""" return known_labels[ALL_PROPERTIES].sample(1).values.tolist()[0] def load_model(model_choice): """Load the graph decoder model based on the choice.""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = load_graph_decoder(path=model_choice) return model, device def save_interesting_log(smiles, properties, suggested_properties): """Save interesting polymer data to a CSV log file.""" log_file = os.path.join(FLAGGED_FOLDER, "log.csv") os.makedirs(FLAGGED_FOLDER, exist_ok=True) file_exists = os.path.isfile(log_file) fieldnames = ['timestamp', 'smiles'] + ALL_PROPERTIES + [f'suggested_{prop}' for prop in ALL_PROPERTIES] try: with open(log_file, 'a', newline='') as csvfile: writer = csv.DictWriter(csvfile, fieldnames=fieldnames) if not file_exists: writer.writeheader() log_data = { 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"), 'smiles': smiles, **{prop: value for prop, value in zip(ALL_PROPERTIES, properties)}, **{f'suggested_{prop}': value for prop, value in suggested_properties.items()} } writer.writerow(log_data) except Exception as e: print(f"Error saving log: {e}") def is_nan_like(x): """Check if a value should be treated as NaN.""" return x == 0 or x == '' or (isinstance(x, float) and np.isnan(x)) def numpy_to_python(obj): """Convert NumPy objects to native Python types.""" if isinstance(obj, np.integer): return int(obj) elif isinstance(obj, np.floating): return float(obj) elif isinstance(obj, np.ndarray): return obj.tolist() elif isinstance(obj, list): return [numpy_to_python(item) for item in obj] elif isinstance(obj, dict): return {k: numpy_to_python(v) for k, v in obj.items()} else: return obj # --------------------------- Graph Generation Function --------------------------- # @spaces.GPU def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps): """ Generate a polymer graph based on the input properties and model. Returns generation results including SMILES, images, and properties. """ print('Generating graph...') model, device = model_state properties = [CH4, CO2, H2, N2, O2] # Handle NaN-like values properties = [None if is_nan_like(prop) else prop for prop in properties] nan_gases = [gas for gas, prop in zip(ALL_PROPERTIES, properties) if prop is None] nan_message = "The following gas properties were treated as NaN: " + (", ".join(nan_gases) if nan_gases else "None") num_nodes = None if num_nodes == 0 else num_nodes for attempt in range(repeating_time): try: generated_molecule, img_list = model.generate( properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps ) gif_path = None if img_list: imgs = [np.array(pil_img) for pil_img in img_list] imgs.extend([imgs[-1]] * 10) # Extend the last image for GIF gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif") imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0) if generated_molecule: mol = Chem.MolFromSmiles(generated_molecule) if mol: standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True) is_novel = standardized_smiles not in known_smiles['SMILES'].values novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)" img = Draw.MolToImage(mol) # Evaluate the generated molecule suggested_properties = {prop: evaluator([standardized_smiles])[0] for prop, evaluator in evaluators.items()} suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()]) return ( f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n" f"**{nan_message}**\n\n" f"**{novelty_status}**\n\n" f"**Suggested Properties:**\n{suggested_properties_text}", img, gif_path, standardized_smiles, properties, suggested_properties ) except Exception as e: print(f"Attempt {attempt + 1} failed: {e}") continue # If all attempts fail return ( f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None, "", [], {} ) # --------------------------- Feedback Processing --------------------------- # def process_feedback(checkbox_value, smiles, properties, suggested_properties): """ Process user feedback. If the user finds the polymer interesting, log it accordingly. """ if checkbox_value and smiles: save_interesting_log(smiles, properties, suggested_properties) return "Thank you for your feedback! This polymer has been saved to our interesting polymers log." return "Thank you for your feedback!" # --------------------------- Model Switching --------------------------- # def switch_model(choice): """Switch the model based on user selection.""" internal_name = next(key for key, value in MODEL_NAME_MAPPING.items() if value == choice) return load_model(internal_name) # --------------------------- Gradio Interface Setup --------------------------- # def create_gradio_interface(): """Create and return the Gradio Blocks interface.""" with gr.Blocks(title="Polymer Design with GraphDiT") as iface: # Navigation Bar with gr.Row(elem_id="navbar"): gr.Markdown("""