File size: 5,075 Bytes
b3237f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import functools
import os
import random
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import joblib
from transformers import AutoTokenizer, AutoModel, EsmModel
import torch
import numpy as np
import tensorflow as tf
from keras.layers import TFSMLayer

# Report which TensorFlow build is loaded (the "General" models are TF SavedModels).
print("TensorFlow Version: {}".format(tf.__version__))

# Base directory; "." is the current working directory.
base_dir = "."

# Seed every random number generator the app touches (Python, NumPy, PyTorch
# CPU and, when present, CUDA) so repeated runs produce the same numbers.
SEED = 42

_SEEDERS = [np.random.seed, random.seed, torch.manual_seed]
if torch.cuda.is_available():
    _SEEDERS += [torch.cuda.manual_seed, torch.cuda.manual_seed_all]
for _seeder in _SEEDERS:
    _seeder(SEED)
del _seeder, _SEEDERS

# Request deterministic cuDNN kernels and turn off cuDNN benchmark autotuning.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False


def load_model(model_path):
    """Load a TensorFlow SavedModel from ``model_path``.

    Used for the "General" (GPflow) predictors. The restored object is
    returned as-is; ``predict_with_gpflow`` later calls its
    ``predict_f_compiled`` function.

    Args:
        model_path: Path to a SavedModel directory.

    Returns:
        The object restored by ``tf.saved_model.load``.
    """
    print(f"Loading model from {model_path}...")
    return tf.saved_model.load(model_path)



# Load every predictor together with the ESM checkpoint and hidden layer its
# input embedding must be taken from. "Plant-Specific" predictors are
# joblib-pickled Random Forests; "General" predictors are GPflow SavedModels.
# Paths are resolved relative to ``base_dir``.
# NOTE: joblib.load unpickles its input, which can execute arbitrary code --
# only deploy .pkl files that come from a trusted source.
print("Loading models...")
plant_models = {
    "Specificity": {
        "model": joblib.load(os.path.join(base_dir, "Specificity.pkl")),
        "esm_model": "facebook/esm1b_t33_650M_UR50S",
        "layer": 6,
    },
    "kcatC": {
        "model": joblib.load(os.path.join(base_dir, "kcatC.pkl")),
        "esm_model": "facebook/esm2_t36_3B_UR50D",
        "layer": 11,
    },
    "KC": {
        "model": joblib.load(os.path.join(base_dir, "KC.pkl")),
        "esm_model": "facebook/esm1b_t33_650M_UR50S",
        "layer": 4,
    },
}

general_models = {
    "Specificity": {
        "model": load_model(os.path.join(base_dir, "Specificity")),
        "esm_model": "facebook/esm2_t33_650M_UR50D",
        "layer": 33,
    },
    "kcatC": {
        "model": load_model(os.path.join(base_dir, "kcatC")),
        "esm_model": "facebook/esm2_t12_35M_UR50D",
        "layer": 7,
    },
    "KC": {
        "model": load_model(os.path.join(base_dir, "KC")),
        "esm_model": "facebook/esm2_t30_150M_UR50D",
        "layer": 26,
    },
}


@functools.lru_cache(maxsize=None)
def _load_esm(esm_model_name):
    """Load the tokenizer and encoder for one ESM checkpoint, once per name.

    Returns a ``(tokenizer, model)`` pair. The encoder is built with
    ``output_hidden_states=True`` so its forward pass exposes every layer.

    Loaded checkpoints stay in memory for the life of the process. The model
    tables in this file currently name five distinct checkpoints, one of them
    ``facebook/esm2_t36_3B_UR50D``, so size the host's RAM accordingly.
    ``lru_cache`` holds no lock while loading: if two worker threads miss on
    the same name at once (e.g. the Plant-Specific "Specificity" and "KC"
    targets on the first request), the checkpoint may be loaded twice and one
    copy discarded.
    """
    tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
    model = EsmModel.from_pretrained(esm_model_name, output_hidden_states=True)
    model.eval()  # inference only; make eval mode explicit
    return tokenizer, model


def get_embedding(sequence, esm_model_name, layer):
    """Return the mean-pooled hidden state of one ESM layer for ``sequence``.

    The tokenizer/model pair is loaded on first use of ``esm_model_name`` and
    reused afterwards instead of being re-read from disk on every call.

    Args:
        sequence: Protein sequence string, passed straight to the tokenizer.
        esm_model_name: Hugging Face checkpoint id of an ESM model.
        layer: Index into ``outputs.hidden_states`` (entry 0 is the embedding
            output; entry ``i`` is the output of transformer layer ``i``).

    Returns:
        NumPy array of shape ``(1, hidden_size)``: the chosen layer averaged
        over all token positions, including any special tokens the tokenizer
        inserts.
    """
    print(f"Generating embeddings using {esm_model_name}, Layer {layer}...")
    tokenizer, model = _load_esm(esm_model_name)

    # Inputs longer than 1024 tokens (special tokens included) are truncated.
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024)

    with torch.no_grad():
        outputs = model(**inputs)
        # hidden_states[layer] is (batch, seq_len, hidden); average over seq_len.
        embedding = outputs.hidden_states[layer].mean(dim=1).numpy()

    return embedding


def predict_with_gpflow(model, X):
    """Evaluate a restored GPflow model's compiled ``predict_f`` on ``X``.

    ``X`` is first converted to a float64 TensorFlow tensor. Returns the
    predictive mean and variance, each flattened to a 1-D NumPy array.
    """
    features = tf.convert_to_tensor(X, dtype=tf.float64)
    mean_t, var_t = model.predict_f_compiled(features)
    return tuple(t.numpy().flatten() for t in (mean_t, var_t))
# Display label for each target key in the results table.
_TARGET_LABELS = {"Specificity": "Specificity", "kcatC": "kcat\u1d9c", "KC": "K\u1d9c"}


def predict(sequence, prediction_type):
    """Predict Rubisco kinetic properties for one protein sequence.

    Args:
        sequence: Protein sequence from the UI. All whitespace (e.g. line
            breaks in a pasted sequence) is removed and letters are
            upper-cased before embedding.
        prediction_type: ``"Plant-Specific"`` selects the Random Forest
            models; any other value selects the GPflow ("General") models.

    Returns:
        Table rows: ``[label, prediction]`` for Plant-Specific, otherwise
        ``[label, mean, variance]``, one row per target.

    Raises:
        gr.Error: if the sequence is empty after cleaning.
    """
    # Normalize user input; the ESM residue vocabulary is upper-case, so
    # lower-case letters would not match it.
    sequence = "".join((sequence or "").split()).upper()
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")

    is_plant = prediction_type == "Plant-Specific"
    selected_models = plant_models if is_plant else general_models

    def process_target(target):
        config = selected_models[target]
        embedding = get_embedding(sequence, config["esm_model"], config["layer"])
        if is_plant:
            # Random Forest: a single point prediction.
            prediction = config["model"].predict(embedding)[0]
            return target, round(prediction, 2)
        # GPflow: predictive mean and variance.
        mean, variance = predict_with_gpflow(config["model"], embedding)
        return target, round(mean[0], 2), round(variance[0], 2)

    # One ESM forward pass per target, run concurrently; executor.map returns
    # results in the iteration order of selected_models.
    with ThreadPoolExecutor(max_workers=len(selected_models)) as executor:
        results = list(executor.map(process_target, selected_models))

    # Label rows by target key rather than by position so the table stays
    # correct if the model dicts are reordered or extended.
    return [[_TARGET_LABELS.get(target, target), *values] for target, *values in results]

# Build the Gradio UI: a sequence textbox and a radio button choosing the
# model family feed `predict`, whose rows are shown in a table.
# Plant-Specific rows only fill the first two columns.
print("Creating Gradio interface...")
_sequence_box = gr.Textbox(label="Input Protein Sequence")
_mode_radio = gr.Radio(
    choices=["Plant-Specific", "General"],
    label="Prediction Type",
    value="Plant-Specific",
)
_results_table = gr.Dataframe(
    headers=["Target", "Prediction", "Uncertainty (for General)"],
    type="array",
)
_app_description = (
    "Enter a protein sequence to predict Rubisco kinetics properties (Specificity, kcat\u1d9c, and K\u1d9c). "
    "Choose between 'Plant-Specific' (Random Forest) or 'General' (GPflow) predictions."
)
interface = gr.Interface(
    fn=predict,
    inputs=[_sequence_box, _mode_radio],
    outputs=_results_table,
    title="Rubisco Kinetics Prediction",
    description=_app_description,
)

def _main():
    """Start the Gradio server for the app's `interface`."""
    interface.launch()


if __name__ == "__main__":
    _main()