Spaces:
Sleeping
Sleeping
File size: 5,075 Bytes
b3237f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import gradio as gr
import joblib
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoTokenizer, AutoModel, EsmModel
import torch
import numpy as np
import random
import tensorflow as tf
import os
from keras.layers import TFSMLayer
print(f"TensorFlow Version: {tf.__version__}")
# Not referenced by any code visible in this file.
base_dir = "."

# Set random seeds for every RNG library the app uses (Python, NumPy, PyTorch,
# TensorFlow) so that runs are reproducible.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
# TensorFlow backs the GPflow models but was previously left unseeded.
tf.random.set_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

# Ensure deterministic behavior: force deterministic cuDNN kernels and turn
# off the cuDNN autotuner, which may pick different algorithms between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def load_model(model_path):
    """Load and return a TensorFlow SavedModel from ``model_path``.

    Args:
        model_path: Directory containing the exported SavedModel.

    Returns:
        The object restored by ``tf.saved_model.load``. The GPflow
        prediction path in this file calls ``predict_f_compiled`` on it.
    """
    print(f"Loading model from {model_path}...")
    return tf.saved_model.load(model_path)
# Load Random Forest models and configurations.
# Each registry maps a target name to its fitted model, the ESM checkpoint
# used to embed the sequence, and the hidden-state index read from it.
# Insertion order (Specificity, kcatC, KC) is the order predictions are made.
print("Loading models...")
plant_models = {
    target: {"model": joblib.load(target + ".pkl"), "esm_model": checkpoint, "layer": layer_idx}
    for target, checkpoint, layer_idx in (
        ("Specificity", "facebook/esm1b_t33_650M_UR50S", 6),
        ("kcatC", "facebook/esm2_t36_3B_UR50D", 11),
        ("KC", "facebook/esm1b_t33_650M_UR50S", 4),
    )
}
# GPflow models are exported SavedModel directories named after the target.
general_models = {
    target: {"model": load_model(target), "esm_model": checkpoint, "layer": layer_idx}
    for target, checkpoint, layer_idx in (
        ("Specificity", "facebook/esm2_t33_650M_UR50D", 33),
        ("kcatC", "facebook/esm2_t12_35M_UR50D", 7),
        ("KC", "facebook/esm2_t30_150M_UR50D", 26),
    )
}
# Function to generate embeddings
def get_embedding(sequence, esm_model_name, layer):
    """Return one ESM hidden-state layer for ``sequence``, averaged over tokens.

    The tokenizer and model are fetched with ``from_pretrained`` on every call.
    Nothing is cached here, so each call pays the full load cost.

    Args:
        sequence: Protein sequence string to embed.
        esm_model_name: Hugging Face checkpoint id, e.g.
            "facebook/esm2_t33_650M_UR50D".
        layer: Index into the model's ``hidden_states`` tuple (by Hugging Face
            convention index 0 is the embedding output; confirm per model).

    Returns:
        NumPy array of the selected layer averaged over dimension 1 (tokens).
        No attention mask is applied, so any special tokens added by the
        tokenizer are included in the average. Tokenized input is truncated
        to 1024 tokens.
    """
    print(f"Generating embeddings using {esm_model_name}, Layer {layer}...")
    tok = AutoTokenizer.from_pretrained(esm_model_name)
    esm = EsmModel.from_pretrained(esm_model_name, output_hidden_states=True)

    encoded = tok(sequence, return_tensors="pt", truncation=True, max_length=1024)
    with torch.no_grad():
        all_layers = esm(**encoded).hidden_states
    pooled = all_layers[layer].mean(dim=1)
    return pooled.numpy()
def predict_with_gpflow(model, X):
    """Run a restored GPflow predictor and return flattened NumPy arrays.

    Args:
        model: Restored SavedModel exposing a ``predict_f_compiled`` callable
            that returns a ``(mean, variance)`` pair of tensors.
        X: Array-like feature matrix. It is cast to float64, presumably because
            the exported graph was traced with float64 inputs (confirm).

    Returns:
        Tuple ``(mean, variance)``, each a 1-D NumPy array.
    """
    features = tf.convert_to_tensor(X, dtype=tf.float64)
    mean, variance = model.predict_f_compiled(features)
    return tuple(arr.numpy().flatten() for arr in (mean, variance))
# Function to predict based on user choice
def predict(sequence, prediction_type):
    """Predict Rubisco kinetic properties for one protein sequence.

    Args:
        sequence: Amino-acid sequence from the UI. All whitespace (e.g. line
            breaks in a pasted sequence) is removed and letters are
            upper-cased before embedding. Presumably the ESM vocabularies
            contain only upper-case residue letters; confirm per tokenizer.
        prediction_type: "Plant-Specific" selects ``plant_models`` (Random
            Forest); any other value selects ``general_models`` (GPflow).

    Returns:
        Table rows in registry order, one per target, labelled Specificity,
        kcatᶜ, Kᶜ. Plant-Specific rows are ``[label, prediction]``. General
        rows are ``[label, mean, variance]``. Numbers are rounded to 2 decimals.

    Raises:
        gr.Error: If the sequence is empty after whitespace removal. The UI
            then shows a message instead of predictions from an empty input.
    """
    # Normalise user input: drop all whitespace and upper-case residue letters.
    sequence = "".join((sequence or "").split()).upper()
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")

    is_plant = prediction_type == "Plant-Specific"
    # Select the appropriate model set
    selected_models = plant_models if is_plant else general_models
    # Display labels keyed by registry target name ("\u1d9c" is a superscript c).
    display_labels = {"Specificity": "Specificity", "kcatC": "kcat\u1d9c", "KC": "K\u1d9c"}

    def process_target(target):
        """Embed the sequence for one target and run that target's model."""
        config = selected_models[target]
        embedding = get_embedding(sequence, config["esm_model"], config["layer"])
        if is_plant:
            # Random Forest prediction; keep the value for the single input row.
            prediction = config["model"].predict(embedding)[0]
            return target, round(prediction, 2)
        # GPflow prediction: mean and variance for the single input row.
        mean, variance = predict_with_gpflow(config["model"], embedding)
        return target, round(mean[0], 2), round(variance[0], 2)

    # Predict for all targets in parallel; executor.map preserves input order.
    with ThreadPoolExecutor() as executor:
        results = list(executor.map(process_target, selected_models))

    # Label each row by its target name rather than by position, so the table
    # stays correct even if the registry order changes.
    return [[display_labels.get(target, target), *values] for target, *values in results]
# Define Gradio interface
print("Creating Gradio interface...")

# Input: free-text protein sequence.
_sequence_input = gr.Textbox(label="Input Protein Sequence")
# Input: radio buttons choosing which model family predict() uses.
_mode_input = gr.Radio(
    choices=["Plant-Specific", "General"],
    label="Prediction Type",
    value="Plant-Specific",
)
# Output table. Plant-Specific rows only fill the first two columns.
_results_table = gr.Dataframe(
    headers=["Target", "Prediction", "Uncertainty (for General)"],
    type="array",
)
_app_description = (
    "Enter a protein sequence to predict Rubisco kinetics properties (Specificity, kcat\u1d9c, and K\u1d9c). "
    "Choose between 'Plant-Specific' (Random Forest) or 'General' (GPflow) predictions."
)

interface = gr.Interface(
    fn=predict,
    inputs=[_sequence_input, _mode_input],
    outputs=_results_table,
    title="Rubisco Kinetics Prediction",
    description=_app_description,
)

if __name__ == "__main__":
    interface.launch()
|