PiKaHa commited on
Commit
e372ce4
·
1 Parent(s): b3237f2

Update app.py with transformer embeddings and prediction pipeline

Browse files
Files changed (1) hide show
  1. app.py +32 -24
app.py CHANGED
@@ -29,14 +29,9 @@ torch.backends.cudnn.benchmark = False
29
 
30
  def load_model(model_path):
31
  print(f"Loading model from {model_path}...")
32
- #print(f"Loading model from {model_path} using TFSMLayer...")
33
- #return TFSMLayer(model_path, call_endpoint="serving_default")
34
- #return tf.keras.models.load_model(model_path)
35
  return tf.saved_model.load(model_path)
36
 
37
 
38
-
39
- # Load Random Forest models and configurations
40
  print("Loading models...")
41
  plant_models = {
42
  "Specificity": {"model": joblib.load("Specificity.pkl"), "esm_model": "facebook/esm1b_t33_650M_UR50S", "layer": 6},
@@ -66,7 +61,11 @@ def get_embedding(sequence, esm_model_name, layer):
66
  hidden_states = outputs.hidden_states # Retrieve all hidden states
67
  embedding = hidden_states[layer].mean(dim=1).numpy() # Average pooling
68
 
69
- return embedding
 
 
 
 
70
 
71
 
72
  def predict_with_gpflow(model, X):
@@ -79,27 +78,35 @@ def predict_with_gpflow(model, X):
79
 
80
  # Return mean and variance as numpy arrays
81
  return mean.numpy().flatten(), variance.numpy().flatten()
82
- # Function to predict based on user choice
83
- def predict(sequence, prediction_type):
84
- # Select the appropriate model set
85
- selected_models = plant_models if prediction_type == "Plant-Specific" else general_models
86
 
87
- def process_target(target):
88
- esm_model_name = selected_models[target]["esm_model"]
89
- layer = selected_models[target]["layer"]
90
- model = selected_models[target]["model"]
91
 
92
- # Generate embedding
93
- embedding = get_embedding(sequence, esm_model_name, layer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- if prediction_type == "Plant-Specific":
96
- # Random Forest prediction
97
- prediction = model.predict(embedding)[0]
98
- return target, round(prediction, 2)
99
- else:
100
- # GPflow prediction
101
- mean, variance = predict_with_gpflow(model, embedding)
102
- return target, round(mean[0], 2), round(variance[0], 2)
103
 
104
  # Predict for all targets in parallel
105
  with ThreadPoolExecutor() as executor:
@@ -121,6 +128,7 @@ def predict(sequence, prediction_type):
121
 
122
  return formatted_results
123
 
 
124
  # Define Gradio interface
125
  print("Creating Gradio interface...")
126
  interface = gr.Interface(
 
29
 
30
  def load_model(model_path):
31
  print(f"Loading model from {model_path}...")
 
 
 
32
  return tf.saved_model.load(model_path)
33
 
34
 
 
 
35
  print("Loading models...")
36
  plant_models = {
37
  "Specificity": {"model": joblib.load("Specificity.pkl"), "esm_model": "facebook/esm1b_t33_650M_UR50S", "layer": 6},
 
61
  hidden_states = outputs.hidden_states # Retrieve all hidden states
62
  embedding = hidden_states[layer].mean(dim=1).numpy() # Average pooling
63
 
64
+ # Convert to DataFrame with named columns
65
+ feature_columns = {f"D{i+1}": embedding[0, i] for i in range(embedding.shape[1])}
66
+ embedding_df = pd.DataFrame([feature_columns])
67
+
68
+ return embedding_df.values, embedding_df
69
 
70
 
71
  def predict_with_gpflow(model, X):
 
78
 
79
  # Return mean and variance as numpy arrays
80
  return mean.numpy().flatten(), variance.numpy().flatten()
 
 
 
 
81
 
 
 
 
 
82
 
83
+ def process_target(target):
84
+ """
85
+ Process a single target for prediction using transformer embeddings and the specified model.
86
+ """
87
+ # Get model and embedding details
88
+ esm_model_name = selected_models[target]["esm_model"]
89
+ layer = selected_models[target]["layer"]
90
+ model = selected_models[target]["model"]
91
+
92
+ # Generate embeddings in the required format
93
+ embedding, _ = get_embedding(sequence, esm_model_name, layer)
94
+
95
+ if prediction_type == "Plant-Specific":
96
+ # Random Forest prediction
97
+ y_pred = model.predict(embedding)[0]
98
+ return target, round(y_pred, 2)
99
+ else:
100
+ # GPflow prediction
101
+ y_pred, y_uncertainty = predict_with_gpflow(model, embedding)
102
+ return target, round(y_pred[0], 2), round(y_uncertainty[0], 2)
103
 
104
+ def predict(sequence, prediction_type):
105
+ """
106
+ Predicts Specificity, kcatC, and KC for the given sequence and prediction type.
107
+ """
108
+ # Select the appropriate model set
109
+ selected_models = plant_models if prediction_type == "Plant-Specific" else general_models
 
 
110
 
111
  # Predict for all targets in parallel
112
  with ThreadPoolExecutor() as executor:
 
128
 
129
  return formatted_results
130
 
131
+
132
  # Define Gradio interface
133
  print("Creating Gradio interface...")
134
  interface = gr.Interface(