Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	New tools and filters for cheminfo
Browse files* Update Cheminformatrics use Cases
Add _cheminfo_tools.py with lipinksi filter , View Mol Image , View mol filter with smarts and smiles and highlights are done .
* update new filters  and chembl webapi
update new filters  and chembl webapi
veber, pains, muegge, brenk_aggregator_filter, egan , ghose , new qsar2.py code with matplotlib plots.
* update tools
update on chembl uniprot based search
* update the code
Delete the old files and folder
Put in example \ Cheminformatics folders
Chembl web service client with example
Plots with plot qsar and plot qsar2 with confidence intervals
* Update new code with new workspace
New workspace created deleted ex1 and ex2 .
Deleted the ecfp and maccs model .pkl file
- examples/.crdt/Image table.lynxkite.json.crdt +0 -0
- examples/.crdt/requirements.txt.crdt +0 -0
- examples/Cheminformatics/chem_utils.py +263 -0
- examples/Cheminformatics/chembl_api_uses.lynxkite.json +0 -0
- examples/Cheminformatics/chembl_tools.py +206 -0
- examples/Cheminformatics/cheminfo_tools.py +610 -0
- examples/Cheminformatics/qsar_example.lynxkite.json +0 -0
- examples/draw_molecules.py +0 -29
- examples/requirements.txt +3 -0
    	
        examples/.crdt/Image table.lynxkite.json.crdt
    ADDED
    
    | Binary file (31.8 kB). View file | 
|  | 
    	
        examples/.crdt/requirements.txt.crdt
    ADDED
    
    | Binary file (251 Bytes). View file | 
|  | 
    	
        examples/Cheminformatics/chem_utils.py
    ADDED
    
    | @@ -0,0 +1,263 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import base64
         | 
| 2 | 
            +
            import io
         | 
| 3 | 
            +
            import sys
         | 
| 4 | 
            +
            from io import StringIO
         | 
| 5 | 
            +
            from operator import itemgetter
         | 
| 6 | 
            +
            from typing import List
         | 
| 7 | 
            +
            from typing import Tuple
         | 
| 8 | 
            +
            import itertools
         | 
| 9 | 
            +
            import matplotlib.pyplot as plt
         | 
| 10 | 
            +
            import numpy as np
         | 
| 11 | 
            +
            import seaborn as sns
         | 
| 12 | 
            +
            from rdkit import Chem, DataStructs, RDLogger
         | 
| 13 | 
            +
            from rdkit.Chem.Draw import rdMolDraw2D
         | 
| 14 | 
            +
            from rdkit.Chem.rdchem import Mol
         | 
| 15 | 
            +
            from rdkit.ML.Cluster import Butina
         | 
| 16 | 
            +
            from rdkit.rdBase import BlockLogs
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            import pandas as pd
         | 
| 19 | 
            +
            from rdkit.Chem.rdMMPA import FragmentMol
         | 
| 20 | 
            +
            from rdkit.Chem.rdRGroupDecomposition import RGroupDecompose
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            def smi2mol_with_errors(smi: str) -> Tuple[Mol, str]:
         | 
| 24 | 
            +
                """Parse SMILES and return any associated errors or warnings
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                :param smi: input SMILES
         | 
| 27 | 
            +
                :return: tuple of RDKit molecule, warning or error
         | 
| 28 | 
            +
                """
         | 
| 29 | 
            +
                sio = sys.stderr = StringIO()
         | 
| 30 | 
            +
                mol = Chem.MolFromSmiles(smi)
         | 
| 31 | 
            +
                err = sio.getvalue()
         | 
| 32 | 
            +
                sio = sys.stderr = StringIO()
         | 
| 33 | 
            +
                sys.stderr = sys.__stderr__
         | 
| 34 | 
            +
                return mol, err
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def count_fragments(mol: Mol) -> int:
         | 
| 38 | 
            +
                """Count the number of fragments in a molecule
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                :param mol: RDKit molecule
         | 
| 41 | 
            +
                :return: number of fragments
         | 
| 42 | 
            +
                """
         | 
| 43 | 
            +
                return len(Chem.GetMolFrags(mol, asMols=True))
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
| 46 | 
            +
            def get_largest_fragment(mol: Mol) -> Mol:
         | 
| 47 | 
            +
                """Return the fragment with the largest number of atoms
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                :param mol: RDKit molecule
         | 
| 50 | 
            +
                :return: RDKit molecule with the largest number of atoms
         | 
| 51 | 
            +
                """
         | 
| 52 | 
            +
                frag_list = list(Chem.GetMolFrags(mol, asMols=True))
         | 
| 53 | 
            +
                frag_mw_list = [(x.GetNumAtoms(), x) for x in frag_list]
         | 
| 54 | 
            +
                frag_mw_list.sort(key=itemgetter(0), reverse=True)
         | 
| 55 | 
            +
                return frag_mw_list[0][1]
         | 
| 56 | 
            +
             | 
| 57 | 
            +
             | 
| 58 | 
            +
            # ----------- Clustering
         | 
| 59 | 
            +
            # https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GroupShuffleSplit.html
         | 
| 60 | 
            +
            def taylor_butina_clustering(
         | 
| 61 | 
            +
                fp_list: List[DataStructs.ExplicitBitVect], cutoff: float = 0.65
         | 
| 62 | 
            +
            ) -> List[int]:
         | 
| 63 | 
            +
                """Cluster a set of fingerprints using the RDKit Taylor-Butina implementation
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                :param fp_list: a list of fingerprints
         | 
| 66 | 
            +
                :param cutoff: distance cutoff (1 - Tanimoto similarity)
         | 
| 67 | 
            +
                :return: a list of cluster ids
         | 
| 68 | 
            +
                """
         | 
| 69 | 
            +
                dists = []
         | 
| 70 | 
            +
                nfps = len(fp_list)
         | 
| 71 | 
            +
                for i in range(1, nfps):
         | 
| 72 | 
            +
                    sims = DataStructs.BulkTanimotoSimilarity(fp_list[i], fp_list[:i])
         | 
| 73 | 
            +
                    dists.extend([1 - x for x in sims])
         | 
| 74 | 
            +
                cluster_res = Butina.ClusterData(dists, nfps, cutoff, isDistData=True)
         | 
| 75 | 
            +
                cluster_id_list = np.zeros(nfps, dtype=int)
         | 
| 76 | 
            +
                for cluster_num, cluster in enumerate(cluster_res):
         | 
| 77 | 
            +
                    for member in cluster:
         | 
| 78 | 
            +
                        cluster_id_list[member] = cluster_num
         | 
| 79 | 
            +
                return cluster_id_list.tolist()
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            # ----------- Atom tagging
         | 
| 83 | 
            +
            def label_atoms(mol: Mol, labels: List[str]) -> Mol:
         | 
| 84 | 
            +
                """Label atoms when depicting a molecule
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                :param mol: input molecule
         | 
| 87 | 
            +
                :param labels: labels, one for each atom
         | 
| 88 | 
            +
                :return: molecule with labels
         | 
| 89 | 
            +
                """
         | 
| 90 | 
            +
                [atm.SetProp("atomNote", "") for atm in mol.GetAtoms()]
         | 
| 91 | 
            +
                for atm in mol.GetAtoms():
         | 
| 92 | 
            +
                    idx = atm.GetIdx()
         | 
| 93 | 
            +
                    mol.GetAtomWithIdx(idx).SetProp("atomNote", f"{labels[idx]}")
         | 
| 94 | 
            +
                return mol
         | 
| 95 | 
            +
             | 
| 96 | 
            +
             | 
| 97 | 
            +
            def tag_atoms(mol: Mol, atoms_to_tag: List[int], tag: str = "x") -> Mol:
         | 
| 98 | 
            +
                """Tag atoms with a specified string
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                :param mol: input molecule
         | 
| 101 | 
            +
                :param atoms_to_tag: indices of atoms to tag
         | 
| 102 | 
            +
                :param tag: string to use for the tags
         | 
| 103 | 
            +
                :return: molecule with atoms tagged
         | 
| 104 | 
            +
                """
         | 
| 105 | 
            +
                [atm.SetProp("atomNote", "") for atm in mol.GetAtoms()]
         | 
| 106 | 
            +
                [mol.GetAtomWithIdx(idx).SetProp("atomNote", tag) for idx in atoms_to_tag]
         | 
| 107 | 
            +
                return mol
         | 
| 108 | 
            +
             | 
| 109 | 
            +
             | 
| 110 | 
            +
            # ----------- Logging
         | 
| 111 | 
            +
            def rd_shut_the_hell_up() -> None:
         | 
| 112 | 
            +
                """Make the RDKit be a bit more quiet
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                :return: None
         | 
| 115 | 
            +
                """
         | 
| 116 | 
            +
                lg = RDLogger.logger()
         | 
| 117 | 
            +
                lg.setLevel(RDLogger.CRITICAL)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
             | 
| 120 | 
            +
            def demo_block_logs() -> None:
         | 
| 121 | 
            +
                """An example of another way to turn off RDKit logging
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                :return: None
         | 
| 124 | 
            +
                """
         | 
| 125 | 
            +
                block = BlockLogs()
         | 
| 126 | 
            +
                # do stuff
         | 
| 127 | 
            +
                del block
         | 
| 128 | 
            +
             | 
| 129 | 
            +
             | 
| 130 | 
            +
            # ----------- Image generation
         | 
| 131 | 
            +
            def boxplot_base64_image(dist: np.ndarray, x_lim: list[int] = [0, 10]) -> str:
         | 
| 132 | 
            +
                """
         | 
| 133 | 
            +
                Plot a distribution as a seaborn boxplot and save the resulting image as a base64 image.
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                Parameters:
         | 
| 136 | 
            +
                dist (np.ndarray): The distribution data to plot.
         | 
| 137 | 
            +
                x_lim (list[int]): The x-axis limits for the boxplot.
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                Returns:
         | 
| 140 | 
            +
                str: The base64 encoded image string.
         | 
| 141 | 
            +
                """
         | 
| 142 | 
            +
                sns.set(rc={"figure.figsize": (3, 1)})
         | 
| 143 | 
            +
                sns.set_style("whitegrid")
         | 
| 144 | 
            +
                ax = sns.boxplot(x=dist)
         | 
| 145 | 
            +
                ax.set_xlim(x_lim[0], x_lim[1])
         | 
| 146 | 
            +
                s = io.BytesIO()
         | 
| 147 | 
            +
                plt.savefig(s, format="png", bbox_inches="tight")
         | 
| 148 | 
            +
                plt.close()
         | 
| 149 | 
            +
                s = base64.b64encode(s.getvalue()).decode("utf-8").replace("\n", "")
         | 
| 150 | 
            +
                return '<img align="left" src="data:image/png;base64,%s">' % s
         | 
| 151 | 
            +
             | 
| 152 | 
            +
             | 
| 153 | 
            +
            def mol_to_base64_image(mol: Chem.Mol) -> str:
         | 
| 154 | 
            +
                """
         | 
| 155 | 
            +
                Convert an RDKit molecule to a base64 encoded image string.
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                Parameters:
         | 
| 158 | 
            +
                mol (Chem.Mol): The RDKit molecule to convert.
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                Returns:
         | 
| 161 | 
            +
                str: The base64 encoded image string.
         | 
| 162 | 
            +
                """
         | 
| 163 | 
            +
                drawer = rdMolDraw2D.MolDraw2DCairo(300, 150)
         | 
| 164 | 
            +
                drawer.DrawMolecule(mol)
         | 
| 165 | 
            +
                drawer.FinishDrawing()
         | 
| 166 | 
            +
                text = drawer.GetDrawingText()
         | 
| 167 | 
            +
                im_text64 = base64.b64encode(text).decode("utf8")
         | 
| 168 | 
            +
                img_str = f"<img src='data:image/png;base64, {im_text64}'/>"
         | 
| 169 | 
            +
                return img_str
         | 
| 170 | 
            +
             | 
| 171 | 
            +
             | 
| 172 | 
            +
            def cleanup_fragment(mol: Mol) -> Tuple[Mol, int]:
         | 
| 173 | 
            +
                """
         | 
| 174 | 
            +
                Replace atom map numbers with Hydrogens
         | 
| 175 | 
            +
                :param mol: input molecule
         | 
| 176 | 
            +
                :return: modified molecule, number of R-groups
         | 
| 177 | 
            +
                """
         | 
| 178 | 
            +
                rgroup_count = 0
         | 
| 179 | 
            +
                for atm in mol.GetAtoms():
         | 
| 180 | 
            +
                    atm.SetAtomMapNum(0)
         | 
| 181 | 
            +
                    if atm.GetAtomicNum() == 0:
         | 
| 182 | 
            +
                        rgroup_count += 1
         | 
| 183 | 
            +
                        atm.SetAtomicNum(1)
         | 
| 184 | 
            +
                mol = Chem.RemoveAllHs(mol)
         | 
| 185 | 
            +
                return mol, rgroup_count
         | 
| 186 | 
            +
             | 
| 187 | 
            +
             | 
| 188 | 
            +
            def generate_fragments(mol: Mol) -> pd.DataFrame:
         | 
| 189 | 
            +
                """
         | 
| 190 | 
            +
                Generate fragments using the RDKit
         | 
| 191 | 
            +
                :param mol: RDKit molecule
         | 
| 192 | 
            +
                :return: a Pandas dataframe with Scaffold SMILES, Number of Atoms, Number of R-Groups
         | 
| 193 | 
            +
                """
         | 
| 194 | 
            +
                # Generate molecule fragments
         | 
| 195 | 
            +
                frag_list = FragmentMol(mol)
         | 
| 196 | 
            +
                # Flatten the output into a single list
         | 
| 197 | 
            +
                flat_frag_list = [x for x in itertools.chain(*frag_list) if x]
         | 
| 198 | 
            +
                # The output of Fragment mol is contained in single molecules.  Extract the largest fragment from each molecule
         | 
| 199 | 
            +
                flat_frag_list = [get_largest_fragment(x) for x in flat_frag_list]
         | 
| 200 | 
            +
                # Keep fragments where the number of atoms in the fragment is at least 2/3 of the number fragments in
         | 
| 201 | 
            +
                # input molecule
         | 
| 202 | 
            +
                num_mol_atoms = mol.GetNumAtoms()
         | 
| 203 | 
            +
                flat_frag_list = [x for x in flat_frag_list if x.GetNumAtoms() / num_mol_atoms > 0.67]
         | 
| 204 | 
            +
                # remove atom map numbers from the fragments
         | 
| 205 | 
            +
                flat_frag_list = [cleanup_fragment(x) for x in flat_frag_list]
         | 
| 206 | 
            +
                # Convert fragments to SMILES
         | 
| 207 | 
            +
                frag_smiles_list = [[Chem.MolToSmiles(x), x.GetNumAtoms(), y] for (x, y) in flat_frag_list]
         | 
| 208 | 
            +
                # Add the input molecule to the fragment list
         | 
| 209 | 
            +
                frag_smiles_list.append([Chem.MolToSmiles(mol), mol.GetNumAtoms(), 1])
         | 
| 210 | 
            +
                # Put the results into a Pandas dataframe
         | 
| 211 | 
            +
                frag_df = pd.DataFrame(frag_smiles_list, columns=["Scaffold", "NumAtoms", "NumRgroupgs"])
         | 
| 212 | 
            +
                # Remove duplicate fragments
         | 
| 213 | 
            +
                frag_df = frag_df.drop_duplicates("Scaffold")
         | 
| 214 | 
            +
                return frag_df
         | 
| 215 | 
            +
             | 
| 216 | 
            +
             | 
| 217 | 
            +
            def find_scaffolds(df_in: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
         | 
| 218 | 
            +
                """
         | 
| 219 | 
            +
                Generate scaffolds for a set of molecules
         | 
| 220 | 
            +
                :param df_in: Pandas dataframe with [SMILES, Name, RDKit molecule] columns
         | 
| 221 | 
            +
                :return: dataframe with molecules and scaffolds, dataframe with unique scaffolds
         | 
| 222 | 
            +
                """
         | 
| 223 | 
            +
                # Loop over molecules and generate fragments, fragments for each molecule are returned as a Pandas dataframe
         | 
| 224 | 
            +
                df_list = []
         | 
| 225 | 
            +
                for smiles, name, mol in df_in[["SMILES", "Name", "mol"]].values:
         | 
| 226 | 
            +
                    tmp_df = generate_fragments(mol).copy()
         | 
| 227 | 
            +
                    tmp_df["Name"] = name
         | 
| 228 | 
            +
                    tmp_df["SMILES"] = smiles
         | 
| 229 | 
            +
                    df_list.append(tmp_df)
         | 
| 230 | 
            +
                # Combine the list of dataframes into a single dataframe
         | 
| 231 | 
            +
                mol_df = pd.concat(df_list)
         | 
| 232 | 
            +
                # Collect scaffolds
         | 
| 233 | 
            +
                scaffold_list = []
         | 
| 234 | 
            +
                for k, v in mol_df.groupby("Scaffold"):
         | 
| 235 | 
            +
                    scaffold_list.append([k, len(v.Name.unique()), v.NumAtoms.values[0]])
         | 
| 236 | 
            +
                scaffold_df = pd.DataFrame(scaffold_list, columns=["Scaffold", "Count", "NumAtoms"])
         | 
| 237 | 
            +
                # Any fragment that occurs more times than the number of fragments can't be a scaffold
         | 
| 238 | 
            +
                num_df_rows = len(df_in)  # noqa: F841
         | 
| 239 | 
            +
                scaffold_df = scaffold_df.query(f"Count <= {num_df_rows}")
         | 
| 240 | 
            +
                # Sort scaffolds by frequency
         | 
| 241 | 
            +
                scaffold_df = scaffold_df.sort_values(["Count", "NumAtoms"], ascending=[False, False])
         | 
| 242 | 
            +
                return mol_df, scaffold_df
         | 
| 243 | 
            +
             | 
| 244 | 
            +
             | 
| 245 | 
            +
            def get_molecules_with_scaffold(
         | 
| 246 | 
            +
                scaffold: str, mol_df: pd.DataFrame, activity_df: pd.DataFrame
         | 
| 247 | 
            +
            ) -> Tuple[List[str], pd.DataFrame]:
         | 
| 248 | 
            +
                """
         | 
| 249 | 
            +
                Associate molecules with scaffolds
         | 
| 250 | 
            +
                :param scaffold: scaffold SMILES
         | 
| 251 | 
            +
                :param mol_df: dataframe with molecules and scaffolds, returned by find_scaffolds()
         | 
| 252 | 
            +
                :param activity_df: dataframe with [SMILES, Name, pIC50] columns
         | 
| 253 | 
            +
                :return: list of core(s) with R-groups labeled, dataframe with [SMILES, Name, pIC50]
         | 
| 254 | 
            +
                """
         | 
| 255 | 
            +
                match_df = mol_df.query("Scaffold == @scaffold")
         | 
| 256 | 
            +
                merge_df = match_df.merge(activity_df, on=["SMILES", "Name"])
         | 
| 257 | 
            +
                scaffold_mol = Chem.MolFromSmiles(scaffold)
         | 
| 258 | 
            +
                rgroup_match, rgroup_miss = RGroupDecompose(scaffold_mol, merge_df.mol, asSmiles=True)
         | 
| 259 | 
            +
                if len(rgroup_match):
         | 
| 260 | 
            +
                    rgroup_df = pd.DataFrame(rgroup_match)
         | 
| 261 | 
            +
                    return rgroup_df.Core.unique(), merge_df[["SMILES", "Name", "pIC50"]]
         | 
| 262 | 
            +
                else:
         | 
| 263 | 
            +
                    return [], merge_df[["SMILES", "Name", "pIC50"]]
         | 
    	
        examples/Cheminformatics/chembl_api_uses.lynxkite.json
    ADDED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        examples/Cheminformatics/chembl_tools.py
    ADDED
    
    | @@ -0,0 +1,206 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from lynxkite.core.ops import op
         | 
| 2 | 
            +
            import pandas as pd
         | 
| 3 | 
            +
            from chembl_webresource_client.new_client import new_client
         | 
| 4 | 
            +
            from rdkit import Chem
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            @op("LynxKite Graph Analytics", "chembl sim search")
         | 
| 8 | 
            +
            def similarity_to_dataframe(*, smiles: str, cutoff: int = 70) -> pd.DataFrame:
         | 
| 9 | 
            +
                """
         | 
| 10 | 
            +
                Run a ChEMBL similarity search and return the hits as a pandas DataFrame.
         | 
| 11 | 
            +
                If the SMILES is invalid or an error occurs, prints a message and returns
         | 
| 12 | 
            +
                an empty DataFrame with the expected columns.
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                Parameters
         | 
| 15 | 
            +
                ----------
         | 
| 16 | 
            +
                smiles : str
         | 
| 17 | 
            +
                    The SMILES string to search on.
         | 
| 18 | 
            +
                cutoff : int
         | 
| 19 | 
            +
                    The minimum Tanimoto similarity (0–100).
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                Returns
         | 
| 22 | 
            +
                -------
         | 
| 23 | 
            +
                pd.DataFrame
         | 
| 24 | 
            +
                    Columns: 'molecule_chembl_id', 'similarity'
         | 
| 25 | 
            +
                """
         | 
| 26 | 
            +
                # Prepare empty frame to return on error
         | 
| 27 | 
            +
                cols = ["molecule_chembl_id", "similarity"]
         | 
| 28 | 
            +
                empty_df = pd.DataFrame(columns=cols)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                # 1) Quick SMILES validation
         | 
| 31 | 
            +
                if Chem.MolFromSmiles(smiles) is None:
         | 
| 32 | 
            +
                    print("Please input a correct SMILES string.")
         | 
| 33 | 
            +
                    return empty_df
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                try:
         | 
| 36 | 
            +
                    # 2) Do the ChEMBL API call
         | 
| 37 | 
            +
                    similarity = new_client.similarity
         | 
| 38 | 
            +
                    results = similarity.filter(smiles=smiles, similarity=cutoff).only(cols)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    # 3) Build DataFrame
         | 
| 41 | 
            +
                    data = list(results)
         | 
| 42 | 
            +
                    df = pd.DataFrame.from_records(data, columns=cols)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    # 4) Inform if no hits
         | 
| 45 | 
            +
                    if df.empty:
         | 
| 46 | 
            +
                        print("No hits found for that SMILES at the given cutoff.")
         | 
| 47 | 
            +
                    return df
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                except Exception as e:
         | 
| 50 | 
            +
                    # Catch network errors, unexpected API replies, etc.
         | 
| 51 | 
            +
                    print("An error occurred during the similarity search.")
         | 
| 52 | 
            +
                    print("  Details:", str(e))
         | 
| 53 | 
            +
                    return empty_df
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            @op("LynxKite Graph Analytics", "chembl structure")
         | 
| 57 | 
            +
            def _chembl_structures(
         | 
| 58 | 
            +
                df: pd.DataFrame, *, id_col: str = "molecule_chembl_id", timeout: int = 5
         | 
| 59 | 
            +
            ) -> pd.DataFrame:
         | 
| 60 | 
            +
                """
         | 
| 61 | 
            +
                Given a DataFrame with a column of ChEMBL molecule IDs, append
         | 
| 62 | 
            +
                canonical SMILES, standard InChI, and standard InChIKey.
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                Parameters
         | 
| 65 | 
            +
                ----------
         | 
| 66 | 
            +
                df : pd.DataFrame
         | 
| 67 | 
            +
                    Input DataFrame; must contain `id_col`.
         | 
| 68 | 
            +
                id_col : str
         | 
| 69 | 
            +
                    Name of the column in `df` that holds ChEMBL IDs (e.g. 'CHEMBL1234').
         | 
| 70 | 
            +
                timeout : int
         | 
| 71 | 
            +
                    How many seconds to wait for the API (not currently used by chembl client,
         | 
| 72 | 
            +
                    but reserved for future enhancements or custom wrappers).
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                Returns
         | 
| 75 | 
            +
                -------
         | 
| 76 | 
            +
                pd.DataFrame
         | 
| 77 | 
            +
                    A new DataFrame with three additional columns:
         | 
| 78 | 
            +
                      - smiles
         | 
| 79 | 
            +
                      - standard_inchi
         | 
| 80 | 
            +
                      - standard_inchi_key
         | 
| 81 | 
            +
                """
         | 
| 82 | 
            +
                # make a copy so we don’t modify in-place
         | 
| 83 | 
            +
                out = df.copy()
         | 
| 84 | 
            +
                # prepare new columns
         | 
| 85 | 
            +
                out["smiles"] = None
         | 
| 86 | 
            +
                out["standard_inchi"] = None
         | 
| 87 | 
            +
                out["standard_inchi_key"] = None
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                mol_client = new_client.molecule
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                for idx, chembl_id in out[id_col].items():
         | 
| 92 | 
            +
                    try:
         | 
| 93 | 
            +
                        # query ChEMBL for this molecule
         | 
| 94 | 
            +
                        res = mol_client.filter(chembl_id=chembl_id).only(
         | 
| 95 | 
            +
                            ["molecule_chembl_id", "molecule_structures"]
         | 
| 96 | 
            +
                        )
         | 
| 97 | 
            +
                        # filter() returns an iterable; grab first record if exists
         | 
| 98 | 
            +
                        rec = next(iter(res), None)
         | 
| 99 | 
            +
                        if rec and rec.get("molecule_structures"):
         | 
| 100 | 
            +
                            struct = rec["molecule_structures"]
         | 
| 101 | 
            +
                            out.at[idx, "smiles"] = struct.get("canonical_smiles")
         | 
| 102 | 
            +
                            out.at[idx, "standard_inchi"] = struct.get("standard_inchi")
         | 
| 103 | 
            +
                            out.at[idx, "standard_inchi_key"] = struct.get("standard_inchi_key")
         | 
| 104 | 
            +
                        else:
         | 
| 105 | 
            +
                            print(f"[Warning] No structure found for {chembl_id}")
         | 
| 106 | 
            +
                    except Exception as e:
         | 
| 107 | 
            +
                        print(f"[Error] Lookup failed for {chembl_id}: {e!s}")
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                return out
         | 
| 110 | 
            +
             | 
| 111 | 
            +
             | 
| 112 | 
            +
            @op("LynxKite Graph Analytics", "get chembl drugs")
         | 
| 113 | 
            +
            def fetch_chembl_drugs(
         | 
| 114 | 
            +
                *, first_approval: int = 2000, development_phase: int = None
         | 
| 115 | 
            +
            ) -> pd.DataFrame:
         | 
| 116 | 
            +
                """
         | 
| 117 | 
            +
                Fetch drugs from ChEMBL matching the given USAN stem, approval year,
         | 
| 118 | 
            +
                and development phase, returning key fields as a DataFrame.
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                Parameters
         | 
| 121 | 
            +
                ----------
         | 
| 122 | 
            +
                first_approval : int, optional
         | 
| 123 | 
            +
                    Only include drugs first approved in or after this year (default=1980).
         | 
| 124 | 
            +
                development_phase : int, optional
         | 
| 125 | 
            +
                    Only include drugs in this development phase (e.g. 2, 3, 4).
         | 
| 126 | 
            +
                    If None, do not filter by phase.
         | 
| 127 | 
            +
                usan_stem : str, optional
         | 
| 128 | 
            +
                    USAN stem to filter on (default="-azosin").
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                Returns
         | 
| 131 | 
            +
                -------
         | 
| 132 | 
            +
                pd.DataFrame
         | 
| 133 | 
            +
                    Columns:
         | 
| 134 | 
            +
                      - development_phase
         | 
| 135 | 
            +
                      - first_approval
         | 
| 136 | 
            +
                      - molecule_chembl_id
         | 
| 137 | 
            +
                      - synonyms
         | 
| 138 | 
            +
                      - usan_stem
         | 
| 139 | 
            +
                      - usan_stem_definition
         | 
| 140 | 
            +
                      - usan_year
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    If no results (or on error), returns an empty DataFrame with these columns.
         | 
| 143 | 
            +
                """
         | 
| 144 | 
            +
                cols = [
         | 
| 145 | 
            +
                    "development_phase",
         | 
| 146 | 
            +
                    "first_approval",
         | 
| 147 | 
            +
                    "molecule_chembl_id",
         | 
| 148 | 
            +
                    "synonyms",
         | 
| 149 | 
            +
                    "usan_stem",
         | 
| 150 | 
            +
                    "usan_stem_definition",
         | 
| 151 | 
            +
                    "usan_year",
         | 
| 152 | 
            +
                ]
         | 
| 153 | 
            +
                empty_df = pd.DataFrame(columns=cols)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                # Validate inputs
         | 
| 156 | 
            +
                if first_approval is not None and not isinstance(first_approval, int):
         | 
| 157 | 
            +
                    print("Error: first_approval must be an integer year.")
         | 
| 158 | 
            +
                    return empty_df
         | 
| 159 | 
            +
                if development_phase is not None and not isinstance(development_phase, int):
         | 
| 160 | 
            +
                    print("Error: development_phase must be an integer.")
         | 
| 161 | 
            +
                    return empty_df
         | 
| 162 | 
            +
                # if not isinstance(usan_stem, str):
         | 
| 163 | 
            +
                #     print("Error: usan_stem must be a string.")
         | 
| 164 | 
            +
                #     return empty_df
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                try:
         | 
| 167 | 
            +
                    drug = new_client.drug
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    # apply approval-year filter
         | 
| 170 | 
            +
                    if first_approval is not None:
         | 
| 171 | 
            +
                        drug = drug.filter(first_approval__gte=first_approval)
         | 
| 172 | 
            +
                    # apply development-phase filter
         | 
| 173 | 
            +
                    if development_phase is not None:
         | 
| 174 | 
            +
                        drug = drug.filter(development_phase=development_phase)
         | 
| 175 | 
            +
                    # apply USAN stem filter
         | 
| 176 | 
            +
                    # drug = drug.filter(usan_stem=usan_stem)
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                    res = drug.only(cols)
         | 
| 179 | 
            +
                    df = pd.DataFrame(res, columns=cols)
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                    if df.empty:
         | 
| 182 | 
            +
                        print("No drugs found for those filters.")
         | 
| 183 | 
            +
                    return df
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                except Exception as e:
         | 
| 186 | 
            +
                    print("An error occurred during the ChEMBL query:")
         | 
| 187 | 
            +
                    print(" ", str(e))
         | 
| 188 | 
            +
                    return empty_df
         | 
| 189 | 
            +
             | 
| 190 | 
            +
             | 
| 191 | 
            +
            @op("LynxKite Graph Analytics", "get bioactivity from uniprot")
         | 
| 192 | 
            +
            def fetch_chembl_bioactivity(*, uniprot_id: str = "Q9NZQ7"):
         | 
| 193 | 
            +
                """
         | 
| 194 | 
            +
                Fetch bioactivity data from ChEMBL for a given UniProt ID.
         | 
| 195 | 
            +
                """
         | 
| 196 | 
            +
                target = new_client.target.filter(target_components__accession=uniprot_id)
         | 
| 197 | 
            +
                targets = list(target)
         | 
| 198 | 
            +
                if not targets:
         | 
| 199 | 
            +
                    return []
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                target_chembl_id = targets[0]["target_chembl_id"]
         | 
| 202 | 
            +
                activities = new_client.activity.filter(
         | 
| 203 | 
            +
                    target_chembl_id=target_chembl_id, standard_type__in=["IC50", "Ki", "Kd"]
         | 
| 204 | 
            +
                )
         | 
| 205 | 
            +
                df = pd.DataFrame(activities)
         | 
| 206 | 
            +
                return df
         | 
    	
        examples/Cheminformatics/cheminfo_tools.py
    CHANGED
    
    | @@ -16,6 +16,7 @@ from sklearn.ensemble import RandomForestRegressor | |
| 16 | 
             
            from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
         | 
| 17 | 
             
            from sklearn.model_selection import train_test_split
         | 
| 18 | 
             
            import numpy as np
         | 
|  | |
| 19 |  | 
| 20 |  | 
| 21 | 
             
            @op("LynxKite Graph Analytics", "View mol filter", view="matplotlib", slow=True)
         | 
| @@ -303,3 +304,612 @@ def build_qsar_model( | |
| 303 |  | 
| 304 | 
             
                print(f"Trained & saved QSAR model for '{fp_type}' → {model_file}")
         | 
| 305 | 
             
                return metrics_df
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 16 | 
             
            from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
         | 
| 17 | 
             
            from sklearn.model_selection import train_test_split
         | 
| 18 | 
             
            import numpy as np
         | 
| 19 | 
            +
            from rdkit.Chem import MACCSkeys
         | 
| 20 |  | 
| 21 |  | 
| 22 | 
             
            @op("LynxKite Graph Analytics", "View mol filter", view="matplotlib", slow=True)
         | 
|  | |
| 304 |  | 
| 305 | 
             
                print(f"Trained & saved QSAR model for '{fp_type}' → {model_file}")
         | 
| 306 | 
             
                return metrics_df
         | 
| 307 | 
            +
             | 
| 308 | 
            +
             | 
| 309 | 
            +
            def predict_with_ci(model, X, confidence=0.95):
         | 
| 310 | 
            +
                """
         | 
| 311 | 
            +
                Calculates predictions and confidence intervals for a RandomForestRegressor.
         | 
| 312 | 
            +
                (Implementation is the same as in the previous answer)
         | 
| 313 | 
            +
                """
         | 
| 314 | 
            +
                # Get predictions from each individual tree
         | 
| 315 | 
            +
                tree_preds = np.array([tree.predict(X) for tree in model.estimators_])
         | 
| 316 | 
            +
                # Calculate mean prediction
         | 
| 317 | 
            +
                y_pred_mean = np.mean(tree_preds, axis=0)
         | 
| 318 | 
            +
                # Calculate percentiles for confidence interval
         | 
| 319 | 
            +
                alpha = (1.0 - confidence) / 2.0
         | 
| 320 | 
            +
                lower_percentile = alpha * 100
         | 
| 321 | 
            +
                upper_percentile = (1.0 - alpha) * 100
         | 
| 322 | 
            +
                y_pred_lower = np.percentile(tree_preds, lower_percentile, axis=0)
         | 
| 323 | 
            +
                y_pred_upper = np.percentile(tree_preds, upper_percentile, axis=0)
         | 
| 324 | 
            +
                return y_pred_mean, y_pred_lower, y_pred_upper
         | 
| 325 | 
            +
             | 
| 326 | 
            +
             | 
| 327 | 
            +
            # --- End of predict_with_ci definition ---
         | 
| 328 | 
            +
             | 
| 329 | 
            +
             | 
| 330 | 
            +
            @op("LynxKite Graph Analytics", "Train QSAR2")
         | 
| 331 | 
            +
            def build_qsar_model2(
         | 
| 332 | 
            +
                df: pd.DataFrame,
         | 
| 333 | 
            +
                *,
         | 
| 334 | 
            +
                smiles_col: str,
         | 
| 335 | 
            +
                target_col: str,
         | 
| 336 | 
            +
                fp_type: str,
         | 
| 337 | 
            +
                radius: int = 2,
         | 
| 338 | 
            +
                n_bits: int = 2048,
         | 
| 339 | 
            +
                test_size: float = 0.2,
         | 
| 340 | 
            +
                random_state: int = 42,
         | 
| 341 | 
            +
                out_dir: str = "Models",
         | 
| 342 | 
            +
                confidence: float = 0.95,
         | 
| 343 | 
            +
            ):
         | 
| 344 | 
            +
                """
         | 
| 345 | 
            +
                Train/save RandomForest QSAR model, returning the model and a results DataFrame.
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                The results DataFrame contains per-point data ('actual', 'predicted',
         | 
| 348 | 
            +
                'lower_ci', 'upper_ci', 'split') AND repeated summary metrics for each
         | 
| 349 | 
            +
                split ('split_R2', 'split_MAE', 'split_RMSE').
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                Parameters
         | 
| 352 | 
            +
                ----------
         | 
| 353 | 
            +
                (Parameters are the same as before)
         | 
| 354 | 
            +
                bundle : any
         | 
| 355 | 
            +
                table_name : str
         | 
| 356 | 
            +
                smiles_col : str
         | 
| 357 | 
            +
                target_col : str
         | 
| 358 | 
            +
                fp_type : str
         | 
| 359 | 
            +
                radius : int
         | 
| 360 | 
            +
                n_bits : int
         | 
| 361 | 
            +
                test_size : float
         | 
| 362 | 
            +
                random_state : int
         | 
| 363 | 
            +
                out_dir : str
         | 
| 364 | 
            +
                confidence : float, optional
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                Returns
         | 
| 367 | 
            +
                -------
         | 
| 368 | 
            +
                model : RandomForestRegressor
         | 
| 369 | 
            +
                    The trained QSAR model.
         | 
| 370 | 
            +
                results_df : pandas.DataFrame
         | 
| 371 | 
            +
                    DataFrame containing columns: 'actual', 'predicted', 'lower_ci',
         | 
| 372 | 
            +
                    'upper_ci', 'split', 'split_R2', 'split_MAE', 'split_RMSE'.
         | 
| 373 | 
            +
                    The metric columns repeat the overall metric for the corresponding split.
         | 
| 374 | 
            +
                """
         | 
| 375 | 
            +
                # Steps 1-5: Load data, split, featurize, split features, train model
         | 
| 376 | 
            +
                # (Code is identical to previous versions up to model training)
         | 
| 377 | 
            +
                # ... (load data, sanitize, split indices) ...
         | 
| 378 | 
            +
                # df = bundle.dfs.get(table_name)
         | 
| 379 | 
            +
                df = df.copy()
         | 
| 380 | 
            +
                if df is None:
         | 
| 381 | 
            +
                    raise KeyError("Table not found")
         | 
| 382 | 
            +
                df[target_col] = pd.to_numeric(df[target_col], errors="coerce")
         | 
| 383 | 
            +
                df.dropna(subset=[target_col, smiles_col], inplace=True)
         | 
| 384 | 
            +
                df["mol"] = df[smiles_col].apply(Chem.MolFromSmiles)
         | 
| 385 | 
            +
                df = df[df["mol"].notnull()].reset_index(drop=True)
         | 
| 386 | 
            +
                if df.empty:
         | 
| 387 | 
            +
                    raise ValueError("No valid molecules or targets")
         | 
| 388 | 
            +
             | 
| 389 | 
            +
                indices = np.arange(len(df))
         | 
| 390 | 
            +
                train_idx, test_idx = train_test_split(indices, test_size=test_size, random_state=random_state)
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                print(f"Featurizing using {fp_type}...")
         | 
| 393 | 
            +
                fps = []
         | 
| 394 | 
            +
                valid_indices = []
         | 
| 395 | 
            +
                for i, mol in enumerate(df["mol"]):
         | 
| 396 | 
            +
                    try:
         | 
| 397 | 
            +
                        # ... (fp generation logic as before) ...
         | 
| 398 | 
            +
                        if fp_type == "ecfp":
         | 
| 399 | 
            +
                            bv = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
         | 
| 400 | 
            +
                            current_n_bits = n_bits
         | 
| 401 | 
            +
                        elif fp_type == "rdkit":
         | 
| 402 | 
            +
                            bv = Chem.RDKFingerprint(mol, fpSize=n_bits)
         | 
| 403 | 
            +
                            current_n_bits = n_bits
         | 
| 404 | 
            +
                        elif fp_type == "torsion":
         | 
| 405 | 
            +
                            bv = AllChem.GetHashedTopologicalTorsionFingerprintAsBitVect(mol, nBits=n_bits)
         | 
| 406 | 
            +
                            current_n_bits = n_bits
         | 
| 407 | 
            +
                        elif fp_type == "atompair":
         | 
| 408 | 
            +
                            bv = AllChem.GetHashedAtomPairFingerprintAsBitVect(mol, nBits=n_bits)
         | 
| 409 | 
            +
                            current_n_bits = n_bits
         | 
| 410 | 
            +
                        elif fp_type == "maccs":
         | 
| 411 | 
            +
                            bv = MACCSkeys.GenMACCSKeys(mol)  # 167 bits
         | 
| 412 | 
            +
                            current_n_bits = 167
         | 
| 413 | 
            +
                        else:
         | 
| 414 | 
            +
                            raise ValueError(f"Unsupported fp type: '{fp_type}'")
         | 
| 415 | 
            +
             | 
| 416 | 
            +
                        arr = np.zeros((current_n_bits,), dtype=np.int8)
         | 
| 417 | 
            +
                        DataStructs.ConvertToNumpyArray(bv, arr)
         | 
| 418 | 
            +
                        fps.append(arr)
         | 
| 419 | 
            +
                        valid_indices.append(i)
         | 
| 420 | 
            +
                    except Exception as e:
         | 
| 421 | 
            +
                        print(f"Warning: Featurization failed index {i}. Skipping. Error: {e}")
         | 
| 422 | 
            +
                        continue
         | 
| 423 | 
            +
                if not fps:
         | 
| 424 | 
            +
                    raise ValueError("No molecules featurized.")
         | 
| 425 | 
            +
                X = np.vstack(fps)
         | 
| 426 | 
            +
                df_filtered = df.iloc[valid_indices].reset_index(drop=True)
         | 
| 427 | 
            +
                y = df_filtered[target_col].values
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                # original_indices_set = set(valid_indices)
         | 
| 430 | 
            +
             | 
| 431 | 
            +
                train_idx_filtered = [
         | 
| 432 | 
            +
                    i for i, original_idx in enumerate(valid_indices) if original_idx in train_idx
         | 
| 433 | 
            +
                ]
         | 
| 434 | 
            +
                test_idx_filtered = [
         | 
| 435 | 
            +
                    i for i, original_idx in enumerate(valid_indices) if original_idx in test_idx
         | 
| 436 | 
            +
                ]
         | 
| 437 | 
            +
             | 
| 438 | 
            +
                X_train, y_train = X[train_idx_filtered], y[train_idx_filtered]
         | 
| 439 | 
            +
                X_test, y_test = X[test_idx_filtered], y[test_idx_filtered]
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                if X_train.shape[0] == 0 or X_test.shape[0] == 0:
         | 
| 442 | 
            +
                    raise ValueError("Train or test split empty after filtering.")
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                print("Training RandomForestRegressor...")
         | 
| 445 | 
            +
                model = RandomForestRegressor(random_state=random_state, n_jobs=-1)
         | 
| 446 | 
            +
                model.fit(X_train, y_train)
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                # 6) Compute predictions and *summary* performance metrics
         | 
| 449 | 
            +
                print("Calculating predictions and metrics...")
         | 
| 450 | 
            +
                y_pred_train, lower_ci_train, upper_ci_train = predict_with_ci(model, X_train, confidence)
         | 
| 451 | 
            +
                y_pred_test, lower_ci_test, upper_ci_test = predict_with_ci(model, X_test, confidence)
         | 
| 452 | 
            +
             | 
| 453 | 
            +
                def _metrics(y_true, y_pred_mean):
         | 
| 454 | 
            +
                    # (Same helper function as before)
         | 
| 455 | 
            +
                    y_true = np.ravel(y_true)
         | 
| 456 | 
            +
                    y_pred_mean = np.ravel(y_pred_mean)
         | 
| 457 | 
            +
                    if len(y_true) == 0:
         | 
| 458 | 
            +
                        return {"R2": np.nan, "MAE": np.nan, "RMSE": np.nan}
         | 
| 459 | 
            +
                    mse = mean_squared_error(y_true, y_pred_mean)
         | 
| 460 | 
            +
                    return {
         | 
| 461 | 
            +
                        "R2": r2_score(y_true, y_pred_mean),
         | 
| 462 | 
            +
                        "MAE": mean_absolute_error(y_true, y_pred_mean),
         | 
| 463 | 
            +
                        "RMSE": np.sqrt(mse),
         | 
| 464 | 
            +
                    }
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                train_metrics_dict = _metrics(y_train, y_pred_train)
         | 
| 467 | 
            +
                test_metrics_dict = _metrics(y_test, y_pred_test)
         | 
| 468 | 
            +
             | 
| 469 | 
            +
                # 7) Create results DataFrames and ADD metrics columns
         | 
| 470 | 
            +
                train_results = pd.DataFrame(
         | 
| 471 | 
            +
                    {
         | 
| 472 | 
            +
                        "actual": y_train,
         | 
| 473 | 
            +
                        "predicted": y_pred_train,
         | 
| 474 | 
            +
                        "lower_ci": lower_ci_train,
         | 
| 475 | 
            +
                        "upper_ci": upper_ci_train,
         | 
| 476 | 
            +
                        "split": "train",
         | 
| 477 | 
            +
                    }
         | 
| 478 | 
            +
                )
         | 
| 479 | 
            +
                # Add repeated metrics
         | 
| 480 | 
            +
                for metric, value in train_metrics_dict.items():
         | 
| 481 | 
            +
                    train_results[f"split_{metric}"] = value
         | 
| 482 | 
            +
             | 
| 483 | 
            +
                test_results = pd.DataFrame(
         | 
| 484 | 
            +
                    {
         | 
| 485 | 
            +
                        "actual": y_test,
         | 
| 486 | 
            +
                        "predicted": y_pred_test,
         | 
| 487 | 
            +
                        "lower_ci": lower_ci_test,
         | 
| 488 | 
            +
                        "upper_ci": upper_ci_test,
         | 
| 489 | 
            +
                        "split": "test",
         | 
| 490 | 
            +
                    }
         | 
| 491 | 
            +
                )
         | 
| 492 | 
            +
                # Add repeated metrics
         | 
| 493 | 
            +
                for metric, value in test_metrics_dict.items():
         | 
| 494 | 
            +
                    test_results[f"split_{metric}"] = value
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                # Concatenate into the final DataFrame
         | 
| 497 | 
            +
                results_df = pd.concat([train_results, test_results], ignore_index=True)
         | 
| 498 | 
            +
             | 
| 499 | 
            +
                # 8) Save the model (same as before)
         | 
| 500 | 
            +
                os.makedirs(out_dir, exist_ok=True)
         | 
| 501 | 
            +
                model_file = os.path.join(out_dir, f"qsar_model_{fp_type}.pkl")
         | 
| 502 | 
            +
                try:
         | 
| 503 | 
            +
                    with open(model_file, "wb") as fout:
         | 
| 504 | 
            +
                        pickle.dump(model, fout)
         | 
| 505 | 
            +
                    print(f"Trained & saved QSAR model for '{fp_type}' -> {model_file}")
         | 
| 506 | 
            +
                except Exception as e:
         | 
| 507 | 
            +
                    print(f"Error saving model to {model_file}: {e}")
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                return results_df
         | 
| 510 | 
            +
             | 
| 511 | 
            +
             | 
| 512 | 
            +
            @op("LynxKite Graph Analytics", "plot qsar", view="matplotlib")
         | 
| 513 | 
            +
            def plot_qsar(results_df: pd.DataFrame):
         | 
| 514 | 
            +
                """
         | 
| 515 | 
            +
                Plots actual vs. predicted values from a QSAR results DataFrame.
         | 
| 516 | 
            +
             | 
| 517 | 
            +
                Requires a single positional argument: the results DataFrame. All other
         | 
| 518 | 
            +
                parameters are optional keyword arguments. It extracts summary metrics
         | 
| 519 | 
            +
                directly from columns ('split_R2', 'split_MAE', 'split_RMSE')
         | 
| 520 | 
            +
                expected within the results_df.
         | 
| 521 | 
            +
                """
         | 
| 522 | 
            +
                title = "QSAR Model Performance: Actual vs. Predicted"
         | 
| 523 | 
            +
                xlabel = "Actual Values"
         | 
| 524 | 
            +
                ylabel = "Predicted Values"
         | 
| 525 | 
            +
                show_metrics = True
         | 
| 526 | 
            +
             | 
| 527 | 
            +
                if not isinstance(results_df, pd.DataFrame):
         | 
| 528 | 
            +
                    raise TypeError(
         | 
| 529 | 
            +
                        "plot_qsar() missing 1 required positional argument: 'results_df' or the provided argument is not a pandas DataFrame."
         | 
| 530 | 
            +
                    )
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                required_cols = ["actual", "predicted", "lower_ci", "upper_ci", "split"]
         | 
| 533 | 
            +
                if not all(col in results_df.columns for col in required_cols):
         | 
| 534 | 
            +
                    raise ValueError(f"Invalid 'results_df'. Must contain columns: {required_cols}")
         | 
| 535 | 
            +
             | 
| 536 | 
            +
                metric_cols = ["split_R2", "split_MAE", "split_RMSE"]
         | 
| 537 | 
            +
                metrics_available = all(col in results_df.columns for col in metric_cols)
         | 
| 538 | 
            +
                if show_metrics and not metrics_available:
         | 
| 539 | 
            +
                    print(
         | 
| 540 | 
            +
                        f"Warning: Metrics display requested, but one or more metric columns ({metric_cols}) are missing in results_df."
         | 
| 541 | 
            +
                    )
         | 
| 542 | 
            +
             | 
| 543 | 
            +
                # --- Prepare Data ---
         | 
| 544 | 
            +
                train_data = results_df[results_df["split"] == "train"]
         | 
| 545 | 
            +
                test_data = results_df[results_df["split"] == "test"]
         | 
| 546 | 
            +
                can_plot_train = not train_data.empty
         | 
| 547 | 
            +
                can_plot_test = not test_data.empty
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                if not can_plot_train and not can_plot_test:
         | 
| 550 | 
            +
                    print("Warning: Both training and test data subsets are empty. Cannot generate plot.")
         | 
| 551 | 
            +
                    return  # Exit function early if no data
         | 
| 552 | 
            +
             | 
| 553 | 
            +
                # --- Create Plot (Internal Figure/Axes) ---
         | 
| 554 | 
            +
                fig, ax = plt.subplots(figsize=(8, 8))
         | 
| 555 | 
            +
             | 
| 556 | 
            +
                # --- Plotting Logic ---
         | 
| 557 | 
            +
                # (Draws scatter, error bars, line, grid, labels, title, legend on 'ax')
         | 
| 558 | 
            +
                if can_plot_train:
         | 
| 559 | 
            +
                    train_error = [
         | 
| 560 | 
            +
                        train_data["predicted"] - train_data["lower_ci"],
         | 
| 561 | 
            +
                        train_data["upper_ci"] - train_data["predicted"],
         | 
| 562 | 
            +
                    ]
         | 
| 563 | 
            +
                    ax.scatter(
         | 
| 564 | 
            +
                        train_data["actual"],
         | 
| 565 | 
            +
                        train_data["predicted"],
         | 
| 566 | 
            +
                        label="Train",
         | 
| 567 | 
            +
                        alpha=0.6,
         | 
| 568 | 
            +
                        s=30,
         | 
| 569 | 
            +
                        edgecolors="w",
         | 
| 570 | 
            +
                        linewidth=0.5,
         | 
| 571 | 
            +
                    )
         | 
| 572 | 
            +
                    ax.errorbar(
         | 
| 573 | 
            +
                        train_data["actual"],
         | 
| 574 | 
            +
                        train_data["predicted"],
         | 
| 575 | 
            +
                        yerr=train_error,
         | 
| 576 | 
            +
                        fmt="none",
         | 
| 577 | 
            +
                        ecolor="tab:blue",
         | 
| 578 | 
            +
                        label="_nolegend_",
         | 
| 579 | 
            +
                        capsize=0,
         | 
| 580 | 
            +
                        elinewidth=1,
         | 
| 581 | 
            +
                    )
         | 
| 582 | 
            +
             | 
| 583 | 
            +
                if can_plot_test:
         | 
| 584 | 
            +
                    test_error = [
         | 
| 585 | 
            +
                        test_data["predicted"] - test_data["lower_ci"],
         | 
| 586 | 
            +
                        test_data["upper_ci"] - test_data["predicted"],
         | 
| 587 | 
            +
                    ]
         | 
| 588 | 
            +
                    ax.scatter(
         | 
| 589 | 
            +
                        test_data["actual"],
         | 
| 590 | 
            +
                        test_data["predicted"],
         | 
| 591 | 
            +
                        label="Test",
         | 
| 592 | 
            +
                        alpha=0.8,
         | 
| 593 | 
            +
                        s=40,
         | 
| 594 | 
            +
                        edgecolors="w",
         | 
| 595 | 
            +
                        linewidth=0.5,
         | 
| 596 | 
            +
                    )
         | 
| 597 | 
            +
                    ax.errorbar(
         | 
| 598 | 
            +
                        test_data["actual"],
         | 
| 599 | 
            +
                        test_data["predicted"],
         | 
| 600 | 
            +
                        yerr=test_error,
         | 
| 601 | 
            +
                        fmt="none",
         | 
| 602 | 
            +
                        ecolor="tab:orange",
         | 
| 603 | 
            +
                        label="_nolegend_",
         | 
| 604 | 
            +
                        capsize=0,
         | 
| 605 | 
            +
                        elinewidth=1,
         | 
| 606 | 
            +
                    )
         | 
| 607 | 
            +
             | 
| 608 | 
            +
                all_actual = results_df["actual"].dropna()
         | 
| 609 | 
            +
                all_pred_ci = pd.concat(
         | 
| 610 | 
            +
                    [results_df["predicted"], results_df["lower_ci"], results_df["upper_ci"]]
         | 
| 611 | 
            +
                ).dropna()
         | 
| 612 | 
            +
                all_values = pd.concat([all_actual, all_pred_ci]).dropna()
         | 
| 613 | 
            +
                if all_values.empty:
         | 
| 614 | 
            +
                    min_val, max_val = 0, 1
         | 
| 615 | 
            +
                else:
         | 
| 616 | 
            +
                    min_val, max_val = all_values.min(), all_values.max()
         | 
| 617 | 
            +
                    if min_val == max_val:
         | 
| 618 | 
            +
                        min_val -= 0.5
         | 
| 619 | 
            +
                        max_val += 0.5
         | 
| 620 | 
            +
                    padding = (max_val - min_val) * 0.05
         | 
| 621 | 
            +
                    min_val -= padding
         | 
| 622 | 
            +
                    max_val += padding
         | 
| 623 | 
            +
                ax.plot([min_val, max_val], [min_val, max_val], "k--", alpha=0.7, lw=1, label="y=x")
         | 
| 624 | 
            +
                ax.set_xlim(min_val, max_val)
         | 
| 625 | 
            +
                ax.set_ylim(min_val, max_val)
         | 
| 626 | 
            +
                ax.set_aspect("equal", adjustable="box")
         | 
| 627 | 
            +
                ax.grid(True, linestyle=":", alpha=0.6)
         | 
| 628 | 
            +
                ax.set_xlabel(xlabel)
         | 
| 629 | 
            +
                ax.set_ylabel(ylabel)
         | 
| 630 | 
            +
                ax.set_title(title)
         | 
| 631 | 
            +
                ax.legend(loc="lower right")
         | 
| 632 | 
            +
             | 
| 633 | 
            +
                # --- Display Metrics Text ---
         | 
| 634 | 
            +
                if show_metrics and metrics_available:
         | 
| 635 | 
            +
                    # (Logic for extracting and formatting metrics text remains the same)
         | 
| 636 | 
            +
                    metrics_text = ""
         | 
| 637 | 
            +
                    try:
         | 
| 638 | 
            +
                        if can_plot_train:
         | 
| 639 | 
            +
                            train_metrics = train_data[metric_cols].iloc[0]
         | 
| 640 | 
            +
                            r2_tr = (
         | 
| 641 | 
            +
                                f"{train_metrics['split_R2']:.3f}"
         | 
| 642 | 
            +
                                if pd.notna(train_metrics["split_R2"])
         | 
| 643 | 
            +
                                else "N/A"
         | 
| 644 | 
            +
                            )
         | 
| 645 | 
            +
                            mae_tr = (
         | 
| 646 | 
            +
                                f"{train_metrics['split_MAE']:.3f}"
         | 
| 647 | 
            +
                                if pd.notna(train_metrics["split_MAE"])
         | 
| 648 | 
            +
                                else "N/A"
         | 
| 649 | 
            +
                            )
         | 
| 650 | 
            +
                            rmse_tr = (
         | 
| 651 | 
            +
                                f"{train_metrics['split_RMSE']:.3f}"
         | 
| 652 | 
            +
                                if pd.notna(train_metrics["split_RMSE"])
         | 
| 653 | 
            +
                                else "N/A"
         | 
| 654 | 
            +
                            )
         | 
| 655 | 
            +
                            metrics_text += f"Train: $R^2$={r2_tr}, MAE={mae_tr}, RMSE={rmse_tr}\n"
         | 
| 656 | 
            +
                        else:
         | 
| 657 | 
            +
                            metrics_text += "Train: N/A (No Data)\n"
         | 
| 658 | 
            +
                        if can_plot_test:
         | 
| 659 | 
            +
                            test_metrics = test_data[metric_cols].iloc[0]
         | 
| 660 | 
            +
                            r2_te = (
         | 
| 661 | 
            +
                                f"{test_metrics['split_R2']:.3f}"
         | 
| 662 | 
            +
                                if pd.notna(test_metrics["split_R2"])
         | 
| 663 | 
            +
                                else "N/A"
         | 
| 664 | 
            +
                            )
         | 
| 665 | 
            +
                            mae_te = (
         | 
| 666 | 
            +
                                f"{test_metrics['split_MAE']:.3f}"
         | 
| 667 | 
            +
                                if pd.notna(test_metrics["split_MAE"])
         | 
| 668 | 
            +
                                else "N/A"
         | 
| 669 | 
            +
                            )
         | 
| 670 | 
            +
                            rmse_te = (
         | 
| 671 | 
            +
                                f"{test_metrics['split_RMSE']:.3f}"
         | 
| 672 | 
            +
                                if pd.notna(test_metrics["split_RMSE"])
         | 
| 673 | 
            +
                                else "N/A"
         | 
| 674 | 
            +
                            )
         | 
| 675 | 
            +
                            metrics_text += f"Test:  $R^2$={r2_te}, MAE={mae_te}, RMSE={rmse_te}"
         | 
| 676 | 
            +
                        else:
         | 
| 677 | 
            +
                            metrics_text += "Test:  N/A (No Data)"
         | 
| 678 | 
            +
                        if metrics_text:
         | 
| 679 | 
            +
                            ax.text(
         | 
| 680 | 
            +
                                0.05,
         | 
| 681 | 
            +
                                0.95,
         | 
| 682 | 
            +
                                metrics_text.strip(),
         | 
| 683 | 
            +
                                transform=ax.transAxes,
         | 
| 684 | 
            +
                                fontsize=9,
         | 
| 685 | 
            +
                                verticalalignment="top",
         | 
| 686 | 
            +
                                bbox=dict(boxstyle="round,pad=0.5", fc="white", alpha=0.8),
         | 
| 687 | 
            +
                            )
         | 
| 688 | 
            +
                    except Exception as e:
         | 
| 689 | 
            +
                        print(f"An error occurred during metrics display: {e}")
         | 
| 690 | 
            +
                        ax.text(
         | 
| 691 | 
            +
                            0.05,
         | 
| 692 | 
            +
                            0.95,
         | 
| 693 | 
            +
                            "Error displaying metrics",
         | 
| 694 | 
            +
                            transform=ax.transAxes,
         | 
| 695 | 
            +
                            fontsize=9,
         | 
| 696 | 
            +
                            color="red",
         | 
| 697 | 
            +
                            verticalalignment="top",
         | 
| 698 | 
            +
                            bbox=dict(boxstyle="round,pad=0.5", fc="white", alpha=0.8),
         | 
| 699 | 
            +
                        )
         | 
| 700 | 
            +
             | 
| 701 | 
            +
             | 
| 702 | 
            +
            @op("LynxKite Graph Analytics", "plot qsar2", view="matplotlib")
         | 
| 703 | 
            +
            def plot_qsar2(results_df: pd.DataFrame):
         | 
| 704 | 
            +
                """
         | 
| 705 | 
            +
                Plots actual vs. predicted values resembling the example image.
         | 
| 706 | 
            +
             | 
| 707 | 
            +
                Includes separate markers for train/test, y=x line, and parallel dashed
         | 
| 708 | 
            +
                error bands based on test set RMSE (optional). Does NOT use per-point CIs.
         | 
| 709 | 
            +
             | 
| 710 | 
            +
                Handles displaying the plot via plt.show() or saving it to a file
         | 
| 711 | 
            +
                based on the `save_path` parameter. THIS FUNCTION DOES NOT RETURN ANY VALUE.
         | 
| 712 | 
            +
             | 
| 713 | 
            +
                Parameters
         | 
| 714 | 
            +
                ----------
         | 
| 715 | 
            +
                results_df : pd.DataFrame
         | 
| 716 | 
            +
                    Mandatory input DataFrame. Must contain: 'actual', 'predicted', 'split'.
         | 
| 717 | 
            +
                    Should also contain 'split_RMSE' column for error bands and metrics display.
         | 
| 718 | 
            +
                title : str, optional
         | 
| 719 | 
            +
                xlabel : str, optional
         | 
| 720 | 
            +
                ylabel : str, optional
         | 
| 721 | 
            +
                rmse_multiplier_for_bands : float or None, optional
         | 
| 722 | 
            +
                    Determines the width of the dashed error bands (multiplier * test_RMSE).
         | 
| 723 | 
            +
                    Set to None to disable bands. Default is 1.0.
         | 
| 724 | 
            +
                show_metrics : bool, optional
         | 
| 725 | 
            +
                    Whether to display R2/MAE/RMSE text (requires metric columns). Default is True.
         | 
| 726 | 
            +
                save_path : str, optional
         | 
| 727 | 
            +
                    If provided, saves plot to this path. If None (default), displays plot.
         | 
| 728 | 
            +
             | 
| 729 | 
            +
                Raises
         | 
| 730 | 
            +
                ------
         | 
| 731 | 
            +
                ValueError / TypeError : For invalid inputs.
         | 
| 732 | 
            +
                """
         | 
| 733 | 
            +
                COLOR_TRAIN = "royalblue"
         | 
| 734 | 
            +
                COLOR_TEST = "darkorange"  # Changed from red for potentially better contrast/appeal
         | 
| 735 | 
            +
                COLOR_PERFECT = "black"
         | 
| 736 | 
            +
                COLOR_BANDS = "dimgrey"  # Less prominent than the perfect line
         | 
| 737 | 
            +
                COLOR_GRID = "lightgrey"
         | 
| 738 | 
            +
                title = "QSAR Model Performance: Actual vs. Predicted"
         | 
| 739 | 
            +
                xlabel = "Actual Values"
         | 
| 740 | 
            +
                ylabel = "Predicted Values"
         | 
| 741 | 
            +
                # ci_alpha = 0.2
         | 
| 742 | 
            +
                show_metrics = True
         | 
| 743 | 
            +
                rmse_multiplier_for_bands = 1.0
         | 
| 744 | 
            +
                # --- Input Validation ---
         | 
| 745 | 
            +
                if not isinstance(results_df, pd.DataFrame):
         | 
| 746 | 
            +
                    raise TypeError("Input must be a pandas DataFrame.")
         | 
| 747 | 
            +
             | 
| 748 | 
            +
                required_cols = ["actual", "predicted", "split"]
         | 
| 749 | 
            +
                if not all(col in results_df.columns for col in required_cols):
         | 
| 750 | 
            +
                    raise ValueError(f"DataFrame must contain columns: {required_cols}")
         | 
| 751 | 
            +
             | 
| 752 | 
            +
                metric_cols = ["split_R2", "split_MAE", "split_RMSE"]
         | 
| 753 | 
            +
                metrics_available = all(col in results_df.columns for col in metric_cols)
         | 
| 754 | 
            +
                bands_possible = rmse_multiplier_for_bands is not None and "split_RMSE" in results_df.columns
         | 
| 755 | 
            +
             | 
| 756 | 
            +
                if show_metrics and not metrics_available:
         | 
| 757 | 
            +
                    print(
         | 
| 758 | 
            +
                        f"Warning: Metrics display requested, but one or more metric columns ({metric_cols}) are missing."
         | 
| 759 | 
            +
                    )
         | 
| 760 | 
            +
                if rmse_multiplier_for_bands is not None and "split_RMSE" not in results_df.columns:
         | 
| 761 | 
            +
                    print("Warning: Error bands requested, but 'split_RMSE' column is missing.")
         | 
| 762 | 
            +
                    bands_possible = False
         | 
| 763 | 
            +
             | 
| 764 | 
            +
                # --- Prepare Data ---
         | 
| 765 | 
            +
                train_data = results_df[results_df["split"] == "train"].copy()
         | 
| 766 | 
            +
                test_data = results_df[results_df["split"] == "test"].copy()
         | 
| 767 | 
            +
                can_plot_train = not train_data.empty
         | 
| 768 | 
            +
                can_plot_test = not test_data.empty
         | 
| 769 | 
            +
             | 
| 770 | 
            +
                if not can_plot_train and not can_plot_test:
         | 
| 771 | 
            +
                    print("Warning: Both training and test data subsets are empty. Cannot generate plot.")
         | 
| 772 | 
            +
                    return
         | 
| 773 | 
            +
             | 
| 774 | 
            +
                # --- Create Plot with Style ---
         | 
| 775 | 
            +
                plt.style.use("seaborn-v0_8-whitegrid")  # Use a cleaner base style
         | 
| 776 | 
            +
                fig, ax = plt.subplots(figsize=(8, 8))  # Slightly larger figure
         | 
| 777 | 
            +
             | 
| 778 | 
            +
                # --- Plotting Logic ---
         | 
| 779 | 
            +
                # Scatter plots with enhanced style
         | 
| 780 | 
            +
                common_scatter_kws = {"s": 45, "alpha": 0.75, "edgecolor": "black", "linewidth": 0.5}
         | 
| 781 | 
            +
                if can_plot_train:
         | 
| 782 | 
            +
                    ax.scatter(
         | 
| 783 | 
            +
                        train_data["actual"],
         | 
| 784 | 
            +
                        train_data["predicted"],
         | 
| 785 | 
            +
                        label="Training set",
         | 
| 786 | 
            +
                        marker="o",
         | 
| 787 | 
            +
                        color=COLOR_TRAIN,
         | 
| 788 | 
            +
                        **common_scatter_kws,
         | 
| 789 | 
            +
                    )  # Blue circles
         | 
| 790 | 
            +
             | 
| 791 | 
            +
                if can_plot_test:
         | 
| 792 | 
            +
                    ax.scatter(
         | 
| 793 | 
            +
                        test_data["actual"],
         | 
| 794 | 
            +
                        test_data["predicted"],
         | 
| 795 | 
            +
                        label="Test set",
         | 
| 796 | 
            +
                        marker="o",
         | 
| 797 | 
            +
                        color=COLOR_TEST,
         | 
| 798 | 
            +
                        **common_scatter_kws,
         | 
| 799 | 
            +
                    )  # Orange circles
         | 
| 800 | 
            +
             | 
| 801 | 
            +
                # Determine plot limits
         | 
| 802 | 
            +
                # (Using the same logic as before to calculate min_val, max_val)
         | 
| 803 | 
            +
                all_actual = results_df["actual"].dropna()
         | 
| 804 | 
            +
                all_pred = results_df["predicted"].dropna()
         | 
| 805 | 
            +
                all_values = pd.concat([all_actual, all_pred]).dropna()
         | 
| 806 | 
            +
                if all_values.empty:
         | 
| 807 | 
            +
                    min_val, max_val = 0, 1
         | 
| 808 | 
            +
                else:
         | 
| 809 | 
            +
                    min_val, max_val = all_values.min(), all_values.max()
         | 
| 810 | 
            +
                    if min_val == max_val:
         | 
| 811 | 
            +
                        min_val -= 0.5
         | 
| 812 | 
            +
                        max_val += 0.5
         | 
| 813 | 
            +
                    data_range = max_val - min_val
         | 
| 814 | 
            +
                    if data_range == 0:
         | 
| 815 | 
            +
                        data_range = 1.0
         | 
| 816 | 
            +
                    padding = data_range * 0.10
         | 
| 817 | 
            +
                    min_val -= padding
         | 
| 818 | 
            +
                    max_val += padding
         | 
| 819 | 
            +
             | 
| 820 | 
            +
                # Plot y=x line (Solid Black, slightly thicker)
         | 
| 821 | 
            +
                ax.plot(
         | 
| 822 | 
            +
                    [min_val, max_val],
         | 
| 823 | 
            +
                    [min_val, max_val],
         | 
| 824 | 
            +
                    color=COLOR_PERFECT,
         | 
| 825 | 
            +
                    linestyle="-",
         | 
| 826 | 
            +
                    linewidth=1.5,
         | 
| 827 | 
            +
                    alpha=0.9,
         | 
| 828 | 
            +
                    label="_nolegend_",
         | 
| 829 | 
            +
                )
         | 
| 830 | 
            +
             | 
| 831 | 
            +
                # Plot Error Bands based on Test RMSE (subtler style)
         | 
| 832 | 
            +
                rmse_test = np.nan
         | 
| 833 | 
            +
                if bands_possible and can_plot_test:
         | 
| 834 | 
            +
                    try:
         | 
| 835 | 
            +
                        rmse_test = test_data["split_RMSE"].dropna().iloc[0]
         | 
| 836 | 
            +
                        if pd.notna(rmse_test) and rmse_test >= 0:
         | 
| 837 | 
            +
                            margin = rmse_multiplier_for_bands * rmse_test
         | 
| 838 | 
            +
                            band_label = (
         | 
| 839 | 
            +
                                f"$\pm {rmse_multiplier_for_bands}\,$RMSE"
         | 
| 840 | 
            +
                                if rmse_multiplier_for_bands == 1
         | 
| 841 | 
            +
                                else f"$\pm {rmse_multiplier_for_bands}\,$RMSE"
         | 
| 842 | 
            +
                            )
         | 
| 843 | 
            +
                            ax.plot(
         | 
| 844 | 
            +
                                [min_val, max_val],
         | 
| 845 | 
            +
                                [min_val + margin, max_val + margin],
         | 
| 846 | 
            +
                                color=COLOR_BANDS,
         | 
| 847 | 
            +
                                linestyle="--",
         | 
| 848 | 
            +
                                linewidth=1.0,
         | 
| 849 | 
            +
                                alpha=0.7,
         | 
| 850 | 
            +
                                label=band_label,
         | 
| 851 | 
            +
                            )  # Grey dashed
         | 
| 852 | 
            +
                            ax.plot(
         | 
| 853 | 
            +
                                [min_val, max_val],
         | 
| 854 | 
            +
                                [min_val - margin, max_val - margin],
         | 
| 855 | 
            +
                                color=COLOR_BANDS,
         | 
| 856 | 
            +
                                linestyle="--",
         | 
| 857 | 
            +
                                linewidth=1.0,
         | 
| 858 | 
            +
                                alpha=0.7,
         | 
| 859 | 
            +
                                label="_nolegend_",
         | 
| 860 | 
            +
                            )  # Grey dashed
         | 
| 861 | 
            +
                        # else: print("Warning: Could not plot error bands (Invalid Test RMSE).") # Optionally silent
         | 
| 862 | 
            +
                    except Exception as e:
         | 
| 863 | 
            +
                        print(f"Warning: Could not plot error bands: {e}")
         | 
| 864 | 
            +
             | 
| 865 | 
            +
                # Set limits and aspect ratio
         | 
| 866 | 
            +
                ax.set_xlim(min_val, max_val)
         | 
| 867 | 
            +
                ax.set_ylim(min_val, max_val)
         | 
| 868 | 
            +
                ax.set_aspect("equal", adjustable="box")
         | 
| 869 | 
            +
             | 
| 870 | 
            +
                # ADD BACK Grid (Subtle Style)
         | 
| 871 | 
            +
                ax.grid(True, which="both", linestyle=":", linewidth=0.7, color=COLOR_GRID, alpha=0.7)
         | 
| 872 | 
            +
                # Ensure grid is behind data points
         | 
| 873 | 
            +
                ax.set_axisbelow(True)
         | 
| 874 | 
            +
             | 
| 875 | 
            +
                # Set Labels and Title (using specified arguments)
         | 
| 876 | 
            +
                ax.set_xlabel(xlabel, fontsize=12)
         | 
| 877 | 
            +
                ax.set_ylabel(ylabel, fontsize=12)
         | 
| 878 | 
            +
                ax.set_title(title, fontsize=15, pad=15, weight="semibold")  # Slightly larger title
         | 
| 879 | 
            +
             | 
| 880 | 
            +
                # Enhance Legend
         | 
| 881 | 
            +
                ax.legend(loc="best", frameon=True, framealpha=0.85, fontsize=10, shadow=False)
         | 
| 882 | 
            +
             | 
| 883 | 
            +
                # --- Display Metrics Text (Optional) ---
         | 
| 884 | 
            +
                if show_metrics and metrics_available:
         | 
| 885 | 
            +
                    # (Logic for extracting and formatting metrics text remains the same)
         | 
| 886 | 
            +
                    metrics_text = ""
         | 
| 887 | 
            +
                    try:
         | 
| 888 | 
            +
                        if can_plot_train:
         | 
| 889 | 
            +
                            train_metrics = train_data[metric_cols].dropna().iloc[0]  # Ensure using valid row
         | 
| 890 | 
            +
                            r2_tr = f"{train_metrics['split_R2']:.3f}"
         | 
| 891 | 
            +
                            mae_tr = f"{train_metrics['split_MAE']:.3f}"
         | 
| 892 | 
            +
                            rmse_tr = f"{train_metrics['split_RMSE']:.3f}"
         | 
| 893 | 
            +
                            metrics_text += f"Train: $R^2$={r2_tr}, MAE={mae_tr}, RMSE={rmse_tr}\n"
         | 
| 894 | 
            +
                        else:
         | 
| 895 | 
            +
                            metrics_text += "Train: N/A\n"
         | 
| 896 | 
            +
                        if can_plot_test:
         | 
| 897 | 
            +
                            test_metrics = test_data[metric_cols].dropna().iloc[0]  # Ensure using valid row
         | 
| 898 | 
            +
                            r2_te = f"{test_metrics['split_R2']:.3f}"
         | 
| 899 | 
            +
                            mae_te = f"{test_metrics['split_MAE']:.3f}"
         | 
| 900 | 
            +
                            rmse_te = f"{test_metrics['split_RMSE']:.3f}"
         | 
| 901 | 
            +
                            metrics_text += f"Test:  $R^2$={r2_te}, MAE={mae_te}, RMSE={rmse_te}"
         | 
| 902 | 
            +
                        else:
         | 
| 903 | 
            +
                            metrics_text += "Test:  N/A"
         | 
| 904 | 
            +
                        if metrics_text:
         | 
| 905 | 
            +
                            ax.text(
         | 
| 906 | 
            +
                                0.05,
         | 
| 907 | 
            +
                                0.95,
         | 
| 908 | 
            +
                                metrics_text.strip(),
         | 
| 909 | 
            +
                                transform=ax.transAxes,
         | 
| 910 | 
            +
                                fontsize=9,
         | 
| 911 | 
            +
                                verticalalignment="top",
         | 
| 912 | 
            +
                                bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.7),
         | 
| 913 | 
            +
                            )  # Adjusted box slightly
         | 
| 914 | 
            +
                    except Exception as e:
         | 
| 915 | 
            +
                        print(f"An error occurred during metrics display: {e}")
         | 
    	
        examples/Cheminformatics/qsar_example.lynxkite.json
    ADDED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        examples/draw_molecules.py
    DELETED
    
    | @@ -1,29 +0,0 @@ | |
| 1 | 
            -
            from lynxkite.core.ops import op
         | 
| 2 | 
            -
            import pandas as pd
         | 
| 3 | 
            -
            import base64
         | 
| 4 | 
            -
            import io
         | 
| 5 | 
            -
             | 
| 6 | 
            -
             | 
| 7 | 
            -
            def pil_to_data(image):
         | 
| 8 | 
            -
                buffer = io.BytesIO()
         | 
| 9 | 
            -
                image.save(buffer, format="png")
         | 
| 10 | 
            -
                b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
         | 
| 11 | 
            -
                return "data:image/png;base64," + b64
         | 
| 12 | 
            -
             | 
| 13 | 
            -
             | 
| 14 | 
            -
            def smiles_to_data(smiles):
         | 
| 15 | 
            -
                import rdkit
         | 
| 16 | 
            -
             | 
| 17 | 
            -
                m = rdkit.Chem.MolFromSmiles(smiles)
         | 
| 18 | 
            -
                if m is None:
         | 
| 19 | 
            -
                    return None
         | 
| 20 | 
            -
                img = rdkit.Chem.Draw.MolToImage(m)
         | 
| 21 | 
            -
                data = pil_to_data(img)
         | 
| 22 | 
            -
                return data
         | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
| 25 | 
            -
            @op("LynxKite Graph Analytics", "Draw molecules")
         | 
| 26 | 
            -
            def draw_molecules(df: pd.DataFrame, *, smiles_column: str, image_column: str = "image"):
         | 
| 27 | 
            -
                df = df.copy()
         | 
| 28 | 
            -
                df[image_column] = df[smiles_column].apply(smiles_to_data)
         | 
| 29 | 
            -
                return df
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        examples/requirements.txt
    CHANGED
    
    | @@ -1,3 +1,6 @@ | |
| 1 | 
             
            # Example of a requirements.txt file. LynxKite will automatically install anything you put here.
         | 
| 2 | 
             
            faker
         | 
| 3 | 
             
            matplotlib
         | 
|  | |
|  | |
|  | 
|  | |
| 1 | 
             
            # Example of a requirements.txt file. LynxKite will automatically install anything you put here.
         | 
| 2 | 
             
            faker
         | 
| 3 | 
             
            matplotlib
         | 
| 4 | 
            +
            chembl_webresource_client
         | 
| 5 | 
            +
            rcsb-api
         | 
| 6 | 
            +
            itertools
         |