Yanisadel committed on
Commit
eafaef3
·
verified ·
1 Parent(s): c373f42

Upload inference_example.ipynb

Browse files
Files changed (1) hide show
  1. inference_example.ipynb +140 -0
inference_example.ipynb ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from typing import List, Union\n",
10
+ "\n",
11
+ "import torch\n",
12
+ "from transformers import AutoModel"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "markdown",
17
+ "metadata": {},
18
+ "source": [
19
+ "# Load model"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "model = AutoModel.from_pretrained(\"InstaDeepAI/segment_enformer\", trust_remote_code=True)"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "metadata": {},
34
+ "source": [
35
+ "# Define useful functions"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "def encode_sequences(sequences: Union[str, List[str]]) -> torch.Tensor:\n",
45
+ " \"\"\"\n",
46
+ " One-hot encode a DNA sequence or a batch of DNA sequences.\n",
47
+ "\n",
48
+ " Args:\n",
49
+ " sequences (Union[str, List[str]]): Either a DNA sequence or a list of DNA sequences\n",
50
+ "\n",
51
+ " Returns:\n",
52
+ " torch.Tensor: One-hot encoded sequence(s), with channels ordered A, C, G, T\n",
53
+ " - If `sequences` is just one sequence (str), output shape is (sequence_length, 4), e.g. (196608, 4)\n",
54
+ " - If `sequences` is a list of equal-length sequences, output shape is (num_sequences, sequence_length, 4), e.g. (num_sequences, 196608, 4)\n",
55
+ " \n",
56
+ " \"\"\"\n",
57
+ " one_hot_map = {\n",
58
+ " 'a': torch.tensor([1., 0., 0., 0.]),\n",
59
+ " 'c': torch.tensor([0., 1., 0., 0.]),\n",
60
+ " 'g': torch.tensor([0., 0., 1., 0.]),\n",
61
+ " 't': torch.tensor([0., 0., 0., 1.]),\n",
62
+ " 'n': torch.tensor([0., 0., 0., 0.]),\n",
63
+ " 'A': torch.tensor([1., 0., 0., 0.]),\n",
64
+ " 'C': torch.tensor([0., 1., 0., 0.]),\n",
65
+ " 'G': torch.tensor([0., 0., 1., 0.]),\n",
66
+ " 'T': torch.tensor([0., 0., 0., 1.]),\n",
67
+ " 'N': torch.tensor([0., 0., 0., 0.])\n",
68
+ " }\n",
69
+ "\n",
70
+ " def encode_sequence(seq_str):\n",
71
+ " one_hot_list = []\n",
72
+ " for char in seq_str:\n",
73
+ " one_hot_vector = one_hot_map.get(char, torch.tensor([0.25, 0.25, 0.25, 0.25]))\n",
74
+ " one_hot_list.append(one_hot_vector)\n",
75
+ " return torch.stack(one_hot_list)\n",
76
+ "\n",
77
+ " if isinstance(sequences, list):\n",
78
+ " return torch.stack([encode_sequence(seq) for seq in sequences])\n",
79
+ " else:\n",
80
+ " return encode_sequence(sequences)"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "markdown",
85
+ "metadata": {},
86
+ "source": [
87
+ "# Inference example"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "sequences = [\"A\"*196608, \"G\"*196608]\n",
97
+ "one_hot_encoding = encode_sequences(sequences)"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": null,
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": [
106
+ "preds = model(one_hot_encoding)"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "print(preds['logits'])"
116
+ ]
117
+ }
118
+ ],
119
+ "metadata": {
120
+ "kernelspec": {
121
+ "display_name": "genomics-research-env",
122
+ "language": "python",
123
+ "name": "python3"
124
+ },
125
+ "language_info": {
126
+ "codemirror_mode": {
127
+ "name": "ipython",
128
+ "version": 3
129
+ },
130
+ "file_extension": ".py",
131
+ "mimetype": "text/x-python",
132
+ "name": "python",
133
+ "nbconvert_exporter": "python",
134
+ "pygments_lexer": "ipython3",
135
+ "version": "3.11.10"
136
+ }
137
+ },
138
+ "nbformat": 4,
139
+ "nbformat_minor": 2
140
+ }