File size: 3,436 Bytes
f3ea76e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00d2cb3
 
 
 
 
 
1ee0fa8
00d2cb3
1ee0fa8
 
 
00d2cb3
1ee0fa8
f3ea76e
 
 
00d2cb3
f3ea76e
 
 
 
 
1ee0fa8
 
00d2cb3
f3ea76e
00d2cb3
f3ea76e
 
 
 
00d2cb3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import functools
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoTokenizer, EsmForMaskedLM

@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_name):
    """Load the tokenizer and masked-LM for *model_name* once and reuse them."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EsmForMaskedLM.from_pretrained(model_name)
    return tokenizer, model


def generate_heatmap(protein_sequence, start_pos=1, end_pos=None):
    """Render a log-likelihood-ratio heatmap of single-residue substitutions.

    For every position in ``start_pos..end_pos`` (1-based, inclusive) the
    residue is replaced by the mask token, the model is run, and for each of
    the 20 amino acids the value ``log p(variant) - log p(wild type)`` is
    written into the heatmap (rows: amino acids, columns: positions).

    Args:
        protein_sequence: Amino-acid sequence as a string.
        start_pos: First position to score (1-based).
        end_pos: Last position to score (1-based, inclusive). ``None`` means
            the last residue of the tokenized sequence.

    Returns:
        The path of the PNG file the heatmap was saved to.

    Raises:
        ValueError: If the requested range is empty or does not lie within
            ``1..sequence_length`` of the tokenized sequence.
    """
    # Loading is cached so repeated calls do not re-read the weights.
    tokenizer, model = _load_model_and_tokenizer("facebook/esm2_t6_8M_UR50D")

    # Tokenize the input sequence
    input_ids = tokenizer.encode(protein_sequence, return_tensors="pt")
    sequence_length = input_ids.shape[1] - 2  # Excluding the special tokens

    # Adjust end position if not specified
    if end_pos is None:
        end_pos = sequence_length

    # Token indices are used directly as 1-based residue positions below
    # (assumes one leading and one trailing special token, per the "- 2"
    # above), so an out-of-range position would mask a special token or
    # index past the end of the tensor.
    if not 1 <= start_pos <= end_pos <= sequence_length:
        raise ValueError(
            f"Requested range {start_pos}-{end_pos} is outside the "
            f"tokenized sequence (1-{sequence_length})."
        )

    # List of amino acids and their vocabulary ids (looked up once, not per
    # position).
    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
    amino_acid_ids = tokenizer.convert_tokens_to_ids(amino_acids)

    # Initialize heatmap
    num_positions = end_pos - start_pos + 1
    heatmap = np.zeros((len(amino_acids), num_positions))

    # Calculate LLRs for each position and amino acid
    for column, position in enumerate(range(start_pos, end_pos + 1)):
        # Mask the target position
        masked_input_ids = input_ids.clone()
        masked_input_ids[0, position] = tokenizer.mask_token_id

        with torch.no_grad():
            logits = model(masked_input_ids).logits
            # log_softmax avoids the underflow to -inf that log(softmax(x))
            # can produce for very unlikely tokens.
            log_probabilities = torch.nn.functional.log_softmax(
                logits[0, position], dim=0
            )

        # LLR of every variant relative to the wild-type residue
        wt_residue = input_ids[0, position].item()
        llr = log_probabilities[amino_acid_ids] - log_probabilities[wt_residue]
        heatmap[:, column] = llr.numpy()

    # Visualize the heatmap on an explicit figure instead of pyplot's
    # implicit "current figure", and always release the figure afterwards.
    fig, ax = plt.subplots(figsize=(15, 5))
    try:
        image = ax.imshow(heatmap, cmap="viridis_r", aspect="auto")
        ax.set_xticks(range(num_positions))
        ax.set_xticklabels(list(protein_sequence[start_pos - 1:end_pos]))
        ax.set_yticks(range(len(amino_acids)))
        ax.set_yticklabels(amino_acids)
        ax.set_xlabel("Position in Protein Sequence")
        ax.set_ylabel("Amino Acid Mutations")
        ax.set_title("Predicted Effects of Mutations on Protein Sequence (LLR)")
        fig.colorbar(image, ax=ax, label="Log Likelihood Ratio (LLR)")

        # Save the plot and return the file path. No plt.show(): it is not
        # needed to write the file and can block or blank the figure.
        # NOTE(review): the fixed file name is shared by all calls; confirm
        # that concurrent requests cannot overwrite each other's output.
        temp_file = "temp_heatmap.png"
        fig.savefig(temp_file)
    finally:
        plt.close(fig)
    return temp_file

def heatmap_interface(sequence, start, end=None):
    """Validate the UI inputs and delegate to ``generate_heatmap``.

    Args:
        sequence: Protein sequence entered by the user.
        start: 1-based start position; numeric values are truncated with
            ``int()``. ``None`` (e.g. an emptied numeric field) is treated
            as position 1.
        end: 1-based inclusive end position. ``None``, a value past the end
            of the sequence, or a non-positive value all mean "up to the
            last residue".

    Returns:
        The path returned by ``generate_heatmap`` on success, otherwise a
        human-readable error message string.
    """
    sequence_length = len(sequence)

    # Convert start and end to integers; int(None) would raise TypeError,
    # so a missing start falls back to the first residue.
    start = 1 if start is None else int(start)
    if end is not None:
        end = int(end)

    # If end is None or outside the sequence, set it to the sequence length
    if end is None or end > sequence_length or end <= 0:
        end = sequence_length

    # NOTE(review): the interface declares an image output, so these message
    # strings are probably not shown to the user as text — consider raising
    # gr.Error instead once callers are confirmed not to rely on the string.

    # Ensure start is within bounds
    if start < 1 or start > sequence_length:
        return "Start position is out of bounds."

    # An inverted range would ask generate_heatmap for an empty or
    # negative-width heatmap.
    if end < start:
        return "End position must not be before the start position."

    # Generate heatmap
    return generate_heatmap(sequence, start, end)

# Input widgets, in the order heatmap_interface receives them:
# sequence text, start position, end position.
_input_components = [
    gr.Textbox(placeholder="Enter Protein Sequence Here...", lines=2),
    gr.Number(value=1, label="Start Position"),
    gr.Number(label="End Position"),  # deliberately left without a default
]

# Wire the widgets to heatmap_interface; its result goes to an image output.
iface = gr.Interface(
    live=True,
    outputs="image",
    inputs=_input_components,
    fn=heatmap_interface,
)

# Start serving the app
iface.launch()