neelimapreeti297 committed on
Commit
f394198
·
verified ·
1 Parent(s): 96da389

Upload germanToEnglish.ipynb

Browse files
Files changed (1) hide show
  1. germanToEnglish.ipynb +1205 -0
germanToEnglish.ipynb ADDED
@@ -0,0 +1,1205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "id": "wsIPzMNfW3QH"
8
+ },
9
+ "outputs": [],
10
+ "source": [
11
+ "from torchtext.data.utils import get_tokenizer\n",
12
+ "from torchtext.vocab import build_vocab_from_iterator\n",
13
+ "from torchtext.datasets import multi30k, Multi30k\n",
14
+ "from typing import Iterable, List\n",
15
+ "\n",
16
+ "\n",
17
+ "# We need to modify the URLs for the dataset since the links to the original dataset are broken\n",
18
+ "# Refer to https://github.com/pytorch/text/issues/1756#issuecomment-1163664163 for more info\n",
19
+ "multi30k.URL[\"train\"] = \"https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz\"\n",
20
+ "multi30k.URL[\"valid\"] = \"https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz\"\n",
21
+ "\n",
22
+ "SRC_LANGUAGE = 'de'\n",
23
+ "TGT_LANGUAGE = 'en'\n",
24
+ "\n",
25
+ "# Place-holders\n",
26
+ "token_transform = {}\n",
27
+ "vocab_transform = {}"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "metadata": {
34
+ "colab": {
35
+ "base_uri": "https://localhost:8080/"
36
+ },
37
+ "id": "T8LEEOd2r-PV",
38
+ "outputId": "33e10bf6-dd1f-4760-ae2a-5fffd2996edb"
39
+ },
40
+ "outputs": [
41
+ {
42
+ "name": "stdout",
43
+ "output_type": "stream",
44
+ "text": [
45
+ "Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount(\"/gdrive\", force_remount=True).\n"
46
+ ]
47
+ }
48
+ ],
49
+ "source": [
50
+ "from google.colab import drive\n",
51
+ "drive.mount('/gdrive')"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {
58
+ "colab": {
59
+ "base_uri": "https://localhost:8080/"
60
+ },
61
+ "id": "mRx_hiQnLGjV",
62
+ "outputId": "90fe1bbb-76b7-489b-e864-1b41ffbbeeef"
63
+ },
64
+ "outputs": [
65
+ {
66
+ "name": "stdout",
67
+ "output_type": "stream",
68
+ "text": [
69
+ "Requirement already satisfied: torchdata in /usr/local/lib/python3.10/dist-packages (0.7.1)\n",
70
+ "Requirement already satisfied: urllib3>=1.25 in /usr/local/lib/python3.10/dist-packages (from torchdata) (2.0.7)\n",
71
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torchdata) (2.31.0)\n",
72
+ "Requirement already satisfied: torch>=2 in /usr/local/lib/python3.10/dist-packages (from torchdata) (2.2.1+cu121)\n",
73
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (3.13.3)\n",
74
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (4.10.0)\n",
75
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (1.12)\n",
76
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (3.2.1)\n",
77
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (3.1.3)\n",
78
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (2023.6.0)\n",
79
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (12.1.105)\n",
80
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (12.1.105)\n",
81
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (12.1.105)\n",
82
+ "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (8.9.2.26)\n",
83
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (12.1.3.1)\n",
84
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (11.0.2.54)\n",
85
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (10.3.2.106)\n",
86
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (11.4.5.107)\n",
87
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (12.1.0.106)\n",
88
+ "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (2.19.3)\n",
89
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (12.1.105)\n",
90
+ "Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2->torchdata) (2.2.0)\n",
91
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=2->torchdata) (12.4.127)\n",
92
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torchdata) (3.3.2)\n",
93
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torchdata) (3.6)\n",
94
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torchdata) (2024.2.2)\n",
95
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=2->torchdata) (2.1.5)\n",
96
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2->torchdata) (1.3.0)\n",
97
+ "Requirement already satisfied: spacy in /usr/local/lib/python3.10/dist-packages (3.7.4)\n",
98
+ "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.11 in /usr/local/lib/python3.10/dist-packages (from spacy) (3.0.12)\n",
99
+ "Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (1.0.5)\n",
100
+ "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (1.0.10)\n",
101
+ "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.10/dist-packages (from spacy) (2.0.8)\n",
102
+ "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.10/dist-packages (from spacy) (3.0.9)\n",
103
+ "Requirement already satisfied: thinc<8.3.0,>=8.2.2 in /usr/local/lib/python3.10/dist-packages (from spacy) (8.2.3)\n",
104
+ "Requirement already satisfied: wasabi<1.2.0,>=0.9.1 in /usr/local/lib/python3.10/dist-packages (from spacy) (1.1.2)\n",
105
+ "Requirement already satisfied: srsly<3.0.0,>=2.4.3 in /usr/local/lib/python3.10/dist-packages (from spacy) (2.4.8)\n",
106
+ "Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /usr/local/lib/python3.10/dist-packages (from spacy) (2.0.10)\n",
107
+ "Requirement already satisfied: weasel<0.4.0,>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (0.3.4)\n",
108
+ "Requirement already satisfied: typer<0.10.0,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (0.9.4)\n",
109
+ "Requirement already satisfied: smart-open<7.0.0,>=5.2.1 in /usr/local/lib/python3.10/dist-packages (from spacy) (6.4.0)\n",
110
+ "Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (4.66.2)\n",
111
+ "Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (2.31.0)\n",
112
+ "Requirement already satisfied: pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4 in /usr/local/lib/python3.10/dist-packages (from spacy) (2.6.4)\n",
113
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from spacy) (3.1.3)\n",
114
+ "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from spacy) (67.7.2)\n",
115
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (24.0)\n",
116
+ "Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (3.3.0)\n",
117
+ "Requirement already satisfied: numpy>=1.19.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (1.25.2)\n",
118
+ "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy) (0.6.0)\n",
119
+ "Requirement already satisfied: pydantic-core==2.16.3 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy) (2.16.3)\n",
120
+ "Requirement already satisfied: typing-extensions>=4.6.1 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy) (4.10.0)\n",
121
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (3.3.2)\n",
122
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (3.6)\n",
123
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (2.0.7)\n",
124
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (2024.2.2)\n",
125
+ "Requirement already satisfied: blis<0.8.0,>=0.7.8 in /usr/local/lib/python3.10/dist-packages (from thinc<8.3.0,>=8.2.2->spacy) (0.7.11)\n",
126
+ "Requirement already satisfied: confection<1.0.0,>=0.0.1 in /usr/local/lib/python3.10/dist-packages (from thinc<8.3.0,>=8.2.2->spacy) (0.1.4)\n",
127
+ "Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.10/dist-packages (from typer<0.10.0,>=0.3.0->spacy) (8.1.7)\n",
128
+ "Requirement already satisfied: cloudpathlib<0.17.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from weasel<0.4.0,>=0.1.0->spacy) (0.16.0)\n",
129
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->spacy) (2.1.5)\n"
130
+ ]
131
+ }
132
+ ],
133
+ "source": [
134
+ "!pip install -U torchdata\n",
135
+ "!pip install -U spacy"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "metadata": {
142
+ "colab": {
143
+ "base_uri": "https://localhost:8080/"
144
+ },
145
+ "id": "WdqsXpFuzGrH",
146
+ "outputId": "f5402068-ed10-445e-82a6-9db4d11d310c"
147
+ },
148
+ "outputs": [
149
+ {
150
+ "name": "stdout",
151
+ "output_type": "stream",
152
+ "text": [
153
+ "Collecting en-core-web-sm==3.7.1\n",
154
+ " Using cached https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)\n",
155
+ "Requirement already satisfied: spacy<3.8.0,>=3.7.2 in /usr/local/lib/python3.10/dist-packages (from en-core-web-sm==3.7.1) (3.7.4)\n",
156
+ "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.11 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.0.12)\n",
157
+ "Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.0.5)\n",
158
+ "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.0.10)\n",
159
+ "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.0.8)\n",
160
+ "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.0.9)\n",
161
+ "Requirement already satisfied: thinc<8.3.0,>=8.2.2 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (8.2.3)\n",
162
+ "Requirement already satisfied: wasabi<1.2.0,>=0.9.1 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.1.2)\n",
163
+ "Requirement already satisfied: srsly<3.0.0,>=2.4.3 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.4.8)\n",
164
+ "Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.0.10)\n",
165
+ "Requirement already satisfied: weasel<0.4.0,>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.3.4)\n",
166
+ "Requirement already satisfied: typer<0.10.0,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.9.4)\n",
167
+ "Requirement already satisfied: smart-open<7.0.0,>=5.2.1 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (6.4.0)\n",
168
+ "Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (4.66.2)\n",
169
+ "Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.31.0)\n",
170
+ "Requirement already satisfied: pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.6.4)\n",
171
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.1.3)\n",
172
+ "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (67.7.2)\n",
173
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (24.0)\n",
174
+ "Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.3.0)\n",
175
+ "Requirement already satisfied: numpy>=1.19.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (1.25.2)\n",
176
+ "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.6.0)\n",
177
+ "Requirement already satisfied: pydantic-core==2.16.3 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.16.3)\n",
178
+ "Requirement already satisfied: typing-extensions>=4.6.1 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (4.10.0)\n",
179
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.3.2)\n",
180
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (3.6)\n",
181
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.0.7)\n",
182
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2024.2.2)\n",
183
+ "Requirement already satisfied: blis<0.8.0,>=0.7.8 in /usr/local/lib/python3.10/dist-packages (from thinc<8.3.0,>=8.2.2->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.7.11)\n",
184
+ "Requirement already satisfied: confection<1.0.0,>=0.0.1 in /usr/local/lib/python3.10/dist-packages (from thinc<8.3.0,>=8.2.2->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.1.4)\n",
185
+ "Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.10/dist-packages (from typer<0.10.0,>=0.3.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (8.1.7)\n",
186
+ "Requirement already satisfied: cloudpathlib<0.17.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from weasel<0.4.0,>=0.1.0->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (0.16.0)\n",
187
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->spacy<3.8.0,>=3.7.2->en-core-web-sm==3.7.1) (2.1.5)\n",
188
+ "\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
189
+ "You can now load the package via spacy.load('en_core_web_sm')\n",
190
+ "\u001b[38;5;3m⚠ Restart to reload dependencies\u001b[0m\n",
191
+ "If you are in a Jupyter or Colab notebook, you may need to restart Python in\n",
192
+ "order to load all the package's dependencies. You can do this by selecting the\n",
193
+ "'Restart kernel' or 'Restart runtime' option.\n",
194
+ "Collecting de-core-news-sm==3.7.0\n",
195
+ " Using cached https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.7.0/de_core_news_sm-3.7.0-py3-none-any.whl (14.6 MB)\n",
196
+ "Requirement already satisfied: spacy<3.8.0,>=3.7.0 in /usr/local/lib/python3.10/dist-packages (from de-core-news-sm==3.7.0) (3.7.4)\n",
197
+ "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.11 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (3.0.12)\n",
198
+ "Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (1.0.5)\n",
199
+ "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (1.0.10)\n",
200
+ "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (2.0.8)\n",
201
+ "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (3.0.9)\n",
202
+ "Requirement already satisfied: thinc<8.3.0,>=8.2.2 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (8.2.3)\n",
203
+ "Requirement already satisfied: wasabi<1.2.0,>=0.9.1 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (1.1.2)\n",
204
+ "Requirement already satisfied: srsly<3.0.0,>=2.4.3 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (2.4.8)\n",
205
+ "Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (2.0.10)\n",
206
+ "Requirement already satisfied: weasel<0.4.0,>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (0.3.4)\n",
207
+ "Requirement already satisfied: typer<0.10.0,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (0.9.4)\n",
208
+ "Requirement already satisfied: smart-open<7.0.0,>=5.2.1 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (6.4.0)\n",
209
+ "Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (4.66.2)\n",
210
+ "Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (2.31.0)\n",
211
+ "Requirement already satisfied: pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (2.6.4)\n",
212
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (3.1.3)\n",
213
+ "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (67.7.2)\n",
214
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (24.0)\n",
215
+ "Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (3.3.0)\n",
216
+ "Requirement already satisfied: numpy>=1.19.0 in /usr/local/lib/python3.10/dist-packages (from spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (1.25.2)\n",
217
+ "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (0.6.0)\n",
218
+ "Requirement already satisfied: pydantic-core==2.16.3 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (2.16.3)\n",
219
+ "Requirement already satisfied: typing-extensions>=4.6.1 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4->spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (4.10.0)\n",
220
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (3.3.2)\n",
221
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (3.6)\n",
222
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (2.0.7)\n",
223
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (2024.2.2)\n",
224
+ "Requirement already satisfied: blis<0.8.0,>=0.7.8 in /usr/local/lib/python3.10/dist-packages (from thinc<8.3.0,>=8.2.2->spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (0.7.11)\n",
225
+ "Requirement already satisfied: confection<1.0.0,>=0.0.1 in /usr/local/lib/python3.10/dist-packages (from thinc<8.3.0,>=8.2.2->spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (0.1.4)\n",
226
+ "Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.10/dist-packages (from typer<0.10.0,>=0.3.0->spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (8.1.7)\n",
227
+ "Requirement already satisfied: cloudpathlib<0.17.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from weasel<0.4.0,>=0.1.0->spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (0.16.0)\n",
228
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->spacy<3.8.0,>=3.7.0->de-core-news-sm==3.7.0) (2.1.5)\n",
229
+ "\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
230
+ "You can now load the package via spacy.load('de_core_news_sm')\n",
231
+ "\u001b[38;5;3m⚠ Restart to reload dependencies\u001b[0m\n",
232
+ "If you are in a Jupyter or Colab notebook, you may need to restart Python in\n",
233
+ "order to load all the package's dependencies. You can do this by selecting the\n",
234
+ "'Restart kernel' or 'Restart runtime' option.\n"
235
+ ]
236
+ }
237
+ ],
238
+ "source": [
239
+ "!python -m spacy download en_core_web_sm\n",
240
+ "!python -m spacy download de_core_news_sm"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": null,
246
+ "metadata": {
247
+ "id": "Vmir-6Ppki3_"
248
+ },
249
+ "outputs": [],
250
+ "source": [
251
+ "!pip install 'portalocker>=2.0.0'"
252
+ ]
253
+ },
254
+ {
255
+ "cell_type": "code",
256
+ "execution_count": 92,
257
+ "metadata": {
258
+ "colab": {
259
+ "base_uri": "https://localhost:8080/"
260
+ },
261
+ "id": "nzh92t5UW9bu",
262
+ "outputId": "4db35419-1b6d-413f-89b8-791214a07826"
263
+ },
264
+ "outputs": [
265
+ {
266
+ "output_type": "stream",
267
+ "name": "stderr",
268
+ "text": [
269
+ "/usr/local/lib/python3.10/dist-packages/spacy/util.py:1740: UserWarning: [W111] Jupyter notebook detected: if using `prefer_gpu()` or `require_gpu()`, include it in the same cell right before `spacy.load()` to ensure that the model is loaded on the correct device. More information: http://spacy.io/usage/v3#jupyter-notebook-gpu\n",
270
+ " warnings.warn(Warnings.W111)\n"
271
+ ]
272
+ }
273
+ ],
274
+ "source": [
275
+ "token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')\n",
276
+ "token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')\n",
277
+ "\n",
278
+ "\n",
279
+ "# helper function to yield list of tokens\n",
280
+ "def yield_tokens(data_iter: Iterable, language: str) -> List[str]:\n",
281
+ " language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}\n",
282
+ "\n",
283
+ " for data_sample in data_iter:\n",
284
+ " yield token_transform[language](data_sample[language_index[language]])\n",
285
+ "\n",
286
+ "# Define special symbols and indices\n",
287
+ "UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3\n",
288
+ "# Make sure the tokens are in order of their indices to properly insert them in vocab\n",
289
+ "special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']\n",
290
+ "\n",
291
+ "for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:\n",
292
+ " # Training data Iterator\n",
293
+ " train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))\n",
294
+ " # Create torchtext's Vocab object\n",
295
+ " vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),\n",
296
+ " min_freq=1,\n",
297
+ " specials=special_symbols,\n",
298
+ " special_first=True)\n",
299
+ "\n",
300
+ "# Set ``UNK_IDX`` as the default index. This index is returned when the token is not found.\n",
301
+ "# If not set, it throws ``RuntimeError`` when the queried token is not found in the Vocabulary.\n",
302
+ "for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:\n",
303
+ " vocab_transform[ln].set_default_index(UNK_IDX)"
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "execution_count": 93,
309
+ "metadata": {
310
+ "id": "OB_yiHCaXKv8"
311
+ },
312
+ "outputs": [],
313
+ "source": [
314
+ "from torch import Tensor\n",
315
+ "import torch\n",
316
+ "import torch.nn as nn\n",
317
+ "from torch.nn import Transformer\n",
318
+ "import math\n",
319
+ "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
320
+ "\n",
321
+ "# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.\n",
322
+ "class PositionalEncoding(nn.Module):\n",
323
+ " def __init__(self,\n",
324
+ " emb_size: int,\n",
325
+ " dropout: float,\n",
326
+ " maxlen: int = 5000):\n",
327
+ " super(PositionalEncoding, self).__init__()\n",
328
+ " den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)\n",
329
+ " pos = torch.arange(0, maxlen).reshape(maxlen, 1)\n",
330
+ " pos_embedding = torch.zeros((maxlen, emb_size))\n",
331
+ " pos_embedding[:, 0::2] = torch.sin(pos * den)\n",
332
+ " pos_embedding[:, 1::2] = torch.cos(pos * den)\n",
333
+ " pos_embedding = pos_embedding.unsqueeze(-2)\n",
334
+ "\n",
335
+ " self.dropout = nn.Dropout(dropout)\n",
336
+ " self.register_buffer('pos_embedding', pos_embedding)\n",
337
+ "\n",
338
+ " def forward(self, token_embedding: Tensor):\n",
339
+ " return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])\n",
340
+ "\n",
341
+ "# helper Module to convert tensor of input indices into corresponding tensor of token embeddings\n",
342
+ "class TokenEmbedding(nn.Module):\n",
343
+ " def __init__(self, vocab_size: int, emb_size):\n",
344
+ " super(TokenEmbedding, self).__init__()\n",
345
+ " self.embedding = nn.Embedding(vocab_size, emb_size)\n",
346
+ " self.emb_size = emb_size\n",
347
+ "\n",
348
+ " def forward(self, tokens: Tensor):\n",
349
+ " return self.embedding(tokens.long()) * math.sqrt(self.emb_size)\n",
350
+ "\n",
351
+ "# Seq2Seq Network\n",
352
+ "class Seq2SeqTransformer(nn.Module):\n",
353
+ " def __init__(self,\n",
354
+ " num_encoder_layers: int,\n",
355
+ " num_decoder_layers: int,\n",
356
+ " emb_size: int,\n",
357
+ " nhead: int,\n",
358
+ " src_vocab_size: int,\n",
359
+ " tgt_vocab_size: int,\n",
360
+ " dim_feedforward: int = 512,\n",
361
+ " dropout: float = 0.1):\n",
362
+ " super(Seq2SeqTransformer, self).__init__()\n",
363
+ " self.transformer = Transformer(d_model=emb_size,\n",
364
+ " nhead=nhead,\n",
365
+ " num_encoder_layers=num_encoder_layers,\n",
366
+ " num_decoder_layers=num_decoder_layers,\n",
367
+ " dim_feedforward=dim_feedforward,\n",
368
+ " dropout=dropout)\n",
369
+ " self.generator = nn.Linear(emb_size, tgt_vocab_size)\n",
370
+ " self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)\n",
371
+ " self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)\n",
372
+ " self.positional_encoding = PositionalEncoding(\n",
373
+ " emb_size, dropout=dropout)\n",
374
+ "\n",
375
+ " def forward(self,\n",
376
+ " src: Tensor,\n",
377
+ " trg: Tensor,\n",
378
+ " src_mask: Tensor,\n",
379
+ " tgt_mask: Tensor,\n",
380
+ " src_padding_mask: Tensor,\n",
381
+ " tgt_padding_mask: Tensor,\n",
382
+ " memory_key_padding_mask: Tensor):\n",
383
+ " src_emb = self.positional_encoding(self.src_tok_emb(src))\n",
384
+ " tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))\n",
385
+ " outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,\n",
386
+ " src_padding_mask, tgt_padding_mask, memory_key_padding_mask)\n",
387
+ " return self.generator(outs)\n",
388
+ "\n",
389
+ " def encode(self, src: Tensor, src_mask: Tensor):\n",
390
+ " return self.transformer.encoder(self.positional_encoding(\n",
391
+ " self.src_tok_emb(src)), src_mask)\n",
392
+ "\n",
393
+ " def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):\n",
394
+ " return self.transformer.decoder(self.positional_encoding(\n",
395
+ " self.tgt_tok_emb(tgt)), memory,\n",
396
+ " tgt_mask)"
397
+ ]
398
+ },
399
+ {
400
+ "cell_type": "code",
401
+ "execution_count": 94,
402
+ "metadata": {
403
+ "id": "ECpJWZp2r_xa"
404
+ },
405
+ "outputs": [],
406
+ "source": [
407
+ "from torch import Tensor\n",
408
+ "import torch\n",
409
+ "import torch.nn as nn\n",
410
+ "from torch.nn import Transformer\n",
411
+ "import math\n",
412
+ "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
413
+ "\n",
414
+ "# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.\n",
415
+ "class PositionalEncoding(nn.Module):\n",
416
+ " def __init__(self,\n",
417
+ " emb_size: int,\n",
418
+ " dropout: float,\n",
419
+ " maxlen: int = 5000):\n",
420
+ " super(PositionalEncoding, self).__init__()\n",
421
+ " den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)\n",
422
+ " pos = torch.arange(0, maxlen).reshape(maxlen, 1)\n",
423
+ " pos_embedding = torch.zeros((maxlen, emb_size))\n",
424
+ " pos_embedding[:, 0::2] = torch.sin(pos * den)\n",
425
+ " pos_embedding[:, 1::2] = torch.cos(pos * den)\n",
426
+ " pos_embedding = pos_embedding.unsqueeze(-2)\n",
427
+ "\n",
428
+ " self.dropout = nn.Dropout(dropout)\n",
429
+ " self.register_buffer('pos_embedding', pos_embedding)\n",
430
+ "\n",
431
+ " def forward(self, token_embedding: Tensor):\n",
432
+ " return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])\n",
433
+ "\n",
434
+ "# helper Module to convert tensor of input indices into corresponding tensor of token embeddings\n",
435
+ "class TokenEmbedding(nn.Module):\n",
436
+ " def __init__(self, vocab_size: int, emb_size):\n",
437
+ " super(TokenEmbedding, self).__init__()\n",
438
+ " self.embedding = nn.Embedding(vocab_size, emb_size)\n",
439
+ " self.emb_size = emb_size\n",
440
+ "\n",
441
+ " def forward(self, tokens: Tensor):\n",
442
+ " return self.embedding(tokens.long()) * math.sqrt(self.emb_size)\n",
443
+ "\n",
444
+ "# Seq2Seq Network\n",
445
+ "class Seq2SeqTransformer(nn.Module):\n",
446
+ " def __init__(self,\n",
447
+ " num_encoder_layers: int,\n",
448
+ " num_decoder_layers: int,\n",
449
+ " emb_size: int,\n",
450
+ " nhead: int,\n",
451
+ " src_vocab_size: int,\n",
452
+ " tgt_vocab_size: int,\n",
453
+ " dim_feedforward: int = 512,\n",
454
+ " dropout: float = 0.1):\n",
455
+ " super(Seq2SeqTransformer, self).__init__()\n",
456
+ " self.transformer = Transformer(d_model=emb_size,\n",
457
+ " nhead=nhead,\n",
458
+ " num_encoder_layers=num_encoder_layers,\n",
459
+ " num_decoder_layers=num_decoder_layers,\n",
460
+ " dim_feedforward=dim_feedforward,\n",
461
+ " dropout=dropout)\n",
462
+ " self.generator = nn.Linear(emb_size, tgt_vocab_size)\n",
463
+ " self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)\n",
464
+ " self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)\n",
465
+ " self.positional_encoding = PositionalEncoding(\n",
466
+ " emb_size, dropout=dropout)\n",
467
+ "\n",
468
+ " def forward(self,\n",
469
+ " src: Tensor,\n",
470
+ " trg: Tensor,\n",
471
+ " src_mask: Tensor,\n",
472
+ " tgt_mask: Tensor,\n",
473
+ " src_padding_mask: Tensor,\n",
474
+ " tgt_padding_mask: Tensor,\n",
475
+ " memory_key_padding_mask: Tensor):\n",
476
+ " src_emb = self.positional_encoding(self.src_tok_emb(src))\n",
477
+ " tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))\n",
478
+ " outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,\n",
479
+ " src_padding_mask, tgt_padding_mask, memory_key_padding_mask)\n",
480
+ " return self.generator(outs)\n",
481
+ "\n",
482
+ " def encode(self, src: Tensor, src_mask: Tensor):\n",
483
+ " return self.transformer.encoder(self.positional_encoding(\n",
484
+ " self.src_tok_emb(src)), src_mask)\n",
485
+ "\n",
486
+ " def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):\n",
487
+ " return self.transformer.decoder(self.positional_encoding(\n",
488
+ " self.tgt_tok_emb(tgt)), memory,\n",
489
+ " tgt_mask)"
490
+ ]
491
+ },
492
+ {
493
+ "cell_type": "code",
494
+ "execution_count": 95,
495
+ "metadata": {
496
+ "id": "PUIS0MWUZCKc"
497
+ },
498
+ "outputs": [],
499
+ "source": [
500
+ "def generate_square_subsequent_mask(sz):\n",
501
+ " mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)\n",
502
+ " mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))\n",
503
+ " return mask\n",
504
+ "\n",
505
+ "\n",
506
+ "def create_mask(src, tgt):\n",
507
+ " src_seq_len = src.shape[0]\n",
508
+ " tgt_seq_len = tgt.shape[0]\n",
509
+ "\n",
510
+ " tgt_mask = generate_square_subsequent_mask(tgt_seq_len)\n",
511
+ " src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)\n",
512
+ "\n",
513
+ " src_padding_mask = (src == PAD_IDX).transpose(0, 1)\n",
514
+ " tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)\n",
515
+ " return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "code",
520
+ "execution_count": 96,
521
+ "metadata": {
522
+ "colab": {
523
+ "base_uri": "https://localhost:8080/"
524
+ },
525
+ "id": "DA3eAj9GZFus",
526
+ "outputId": "8132fcb6-84c1-44c9-a150-616467d36052"
527
+ },
528
+ "outputs": [
529
+ {
530
+ "output_type": "stream",
531
+ "name": "stderr",
532
+ "text": [
533
+ "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/transformer.py:286: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
534
+ " warnings.warn(f\"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}\")\n"
535
+ ]
536
+ }
537
+ ],
538
+ "source": [
539
+ "torch.manual_seed(0)\n",
540
+ "\n",
541
+ "SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])\n",
542
+ "TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])\n",
543
+ "EMB_SIZE = 512\n",
544
+ "NHEAD = 8\n",
545
+ "FFN_HID_DIM = 512\n",
546
+ "BATCH_SIZE = 128\n",
547
+ "NUM_ENCODER_LAYERS = 3\n",
548
+ "NUM_DECODER_LAYERS = 3\n",
549
+ "\n",
550
+ "transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,\n",
551
+ " NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)\n",
552
+ "\n",
553
+ "for p in transformer.parameters():\n",
554
+ " if p.dim() > 1:\n",
555
+ " nn.init.xavier_uniform_(p)\n",
556
+ "\n",
557
+ "transformer = transformer.to(DEVICE)\n",
558
+ "\n",
559
+ "loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)\n",
560
+ "\n",
561
+ "optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)"
562
+ ]
563
+ },
564
+ {
565
+ "cell_type": "code",
566
+ "execution_count": 97,
567
+ "metadata": {
568
+ "id": "IO9Y95SnZKys"
569
+ },
570
+ "outputs": [],
571
+ "source": [
572
+ "from torch.nn.utils.rnn import pad_sequence\n",
573
+ "\n",
574
+ "# helper function to club together sequential operations\n",
575
+ "def sequential_transforms(*transforms):\n",
576
+ " def func(txt_input):\n",
577
+ " for transform in transforms:\n",
578
+ " txt_input = transform(txt_input)\n",
579
+ " return txt_input\n",
580
+ " return func\n",
581
+ "\n",
582
+ "# function to add BOS/EOS and create tensor for input sequence indices\n",
583
+ "def tensor_transform(token_ids: List[int]):\n",
584
+ " return torch.cat((torch.tensor([BOS_IDX]),\n",
585
+ " torch.tensor(token_ids),\n",
586
+ " torch.tensor([EOS_IDX])))\n",
587
+ "\n",
588
+ "# ``src`` and ``tgt`` language text transforms to convert raw strings into tensors indices\n",
589
+ "text_transform = {}\n",
590
+ "for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:\n",
591
+ " text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization\n",
592
+ " vocab_transform[ln], #Numericalization\n",
593
+ " tensor_transform) # Add BOS/EOS and create tensor\n",
594
+ "\n",
595
+ "\n",
596
+ "# function to collate data samples into batch tensors\n",
597
+ "def collate_fn(batch):\n",
598
+ " src_batch, tgt_batch = [], []\n",
599
+ " for src_sample, tgt_sample in batch:\n",
600
+ " src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip(\"\\n\")))\n",
601
+ " tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip(\"\\n\")))\n",
602
+ "\n",
603
+ " src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)\n",
604
+ " tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)\n",
605
+ " return src_batch, tgt_batch"
606
+ ]
607
+ },
608
+ {
609
+ "cell_type": "code",
610
+ "execution_count": 98,
611
+ "metadata": {
612
+ "id": "qw9lO5xvZSjb"
613
+ },
614
+ "outputs": [],
615
+ "source": [
616
+ "from torch.utils.data import DataLoader\n",
617
+ "\n",
618
+ "def train_epoch(model, optimizer):\n",
619
+ " model.train()\n",
620
+ " losses = 0\n",
621
+ " train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))\n",
622
+ " train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)\n",
623
+ "\n",
624
+ " for src, tgt in train_dataloader:\n",
625
+ " src = src.to(DEVICE)\n",
626
+ " tgt = tgt.to(DEVICE)\n",
627
+ "\n",
628
+ " tgt_input = tgt[:-1, :]\n",
629
+ "\n",
630
+ " src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)\n",
631
+ "\n",
632
+ " logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)\n",
633
+ "\n",
634
+ " optimizer.zero_grad()\n",
635
+ "\n",
636
+ " tgt_out = tgt[1:, :]\n",
637
+ " loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))\n",
638
+ " loss.backward()\n",
639
+ "\n",
640
+ " optimizer.step()\n",
641
+ " losses += loss.item()\n",
642
+ "\n",
643
+ " return losses / len(list(train_dataloader))"
644
+ ]
645
+ },
646
+ {
647
+ "cell_type": "code",
648
+ "execution_count": 99,
649
+ "metadata": {
650
+ "id": "frdDbhZ_ZZ9d"
651
+ },
652
+ "outputs": [],
653
+ "source": [
654
+ "def evaluate(model):\n",
655
+ " model.eval()\n",
656
+ " losses = 0\n",
657
+ "\n",
658
+ " val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))\n",
659
+ " val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)\n",
660
+ "\n",
661
+ " for src, tgt in val_dataloader:\n",
662
+ " src = src.to(DEVICE)\n",
663
+ " tgt = tgt.to(DEVICE)\n",
664
+ "\n",
665
+ " tgt_input = tgt[:-1, :]\n",
666
+ "\n",
667
+ " src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)\n",
668
+ "\n",
669
+ " logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)\n",
670
+ "\n",
671
+ " tgt_out = tgt[1:, :]\n",
672
+ " loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))\n",
673
+ " losses += loss.item()\n",
674
+ "\n",
675
+ " return losses / len(list(val_dataloader))"
676
+ ]
677
+ },
678
+ {
679
+ "cell_type": "code",
680
+ "execution_count": null,
681
+ "metadata": {
682
+ "colab": {
683
+ "base_uri": "https://localhost:8080/"
684
+ },
685
+ "id": "xjLl776lZfJc",
686
+ "outputId": "6f0965d6-6e53-40b7-fe19-69096b68c3f8"
687
+ },
688
+ "outputs": [
689
+ {
690
+ "metadata": {
691
+ "tags": null
692
+ },
693
+ "name": "stderr",
694
+ "output_type": "stream",
695
+ "text": [
696
+ "/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py:5109: UserWarning: Support for mismatched key_padding_mask and attn_mask is deprecated. Use same type for both instead.\n",
697
+ " warnings.warn(\n",
698
+ "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/iter/combining.py:337: UserWarning: Some child DataPipes are not exhausted when __iter__ is called. We are resetting the buffer and each child DataPipe will read from the start again.\n",
699
+ " warnings.warn(\"Some child DataPipes are not exhausted when __iter__ is called. We are resetting \"\n"
700
+ ]
701
+ },
702
+ {
703
+ "output_type": "stream",
704
+ "name": "stdout",
705
+ "text": [
706
+ "Epoch: 1, Train loss: 5.344, Val loss: 4.106, Epoch time = 43.253s\n",
707
+ "Epoch: 2, Train loss: 3.761, Val loss: 3.309, Epoch time = 43.216s\n",
708
+ "Epoch: 3, Train loss: 3.157, Val loss: 2.887, Epoch time = 43.028s\n",
709
+ "Epoch: 4, Train loss: 2.767, Val loss: 2.640, Epoch time = 43.509s\n",
710
+ "Epoch: 5, Train loss: 2.477, Val loss: 2.442, Epoch time = 44.192s\n",
711
+ "Epoch: 6, Train loss: 2.247, Val loss: 2.306, Epoch time = 44.518s\n",
712
+ "Epoch: 7, Train loss: 2.055, Val loss: 2.207, Epoch time = 43.989s\n"
713
+ ]
714
+ }
715
+ ],
716
+ "source": [
717
+ "from timeit import default_timer as timer\n",
718
+ "NUM_EPOCHS = 10\n",
719
+ "\n",
720
+ "for epoch in range(1, NUM_EPOCHS+1):\n",
721
+ " start_time = timer()\n",
722
+ " train_loss = train_epoch(transformer, optimizer)\n",
723
+ " end_time = timer()\n",
724
+ " val_loss = evaluate(transformer)\n",
725
+ " print((f\"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, \"f\"Epoch time = {(end_time - start_time):.3f}s\"))\n"
726
+ ]
727
+ },
728
+ {
729
+ "cell_type": "code",
730
+ "execution_count": 20,
731
+ "metadata": {
732
+ "id": "ebEhLx-3slOE"
733
+ },
734
+ "outputs": [],
735
+ "source": [
736
+ "torch.save(transformer.state_dict(), '/gdrive/My Drive/transformer_model.pth')"
737
+ ]
738
+ },
739
+ {
740
+ "cell_type": "code",
741
+ "execution_count": 58,
742
+ "metadata": {
743
+ "id": "OW8D2ALUtBQq"
744
+ },
745
+ "outputs": [],
746
+ "source": [
747
+ "def greedy_decode(model, src, src_mask, max_len, start_symbol):\n",
748
+ " src = src.to(DEVICE)\n",
749
+ " src_mask = src_mask.to(DEVICE)\n",
750
+ "\n",
751
+ " memory = model.encode(src, src_mask)\n",
752
+ " ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)\n",
753
+ " for i in range(max_len-1):\n",
754
+ " memory = memory.to(DEVICE)\n",
755
+ " tgt_mask = (generate_square_subsequent_mask(ys.size(0))\n",
756
+ " .type(torch.bool)).to(DEVICE)\n",
757
+ " out = model.decode(ys, memory, tgt_mask)\n",
758
+ " out = out.transpose(0, 1)\n",
759
+ " prob = model.generator(out[:, -1])\n",
760
+ " _, next_word = torch.max(prob, dim=1)\n",
761
+ " next_word = next_word.item()\n",
762
+ "\n",
763
+ " ys = torch.cat([ys,\n",
764
+ " torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)\n",
765
+ " if next_word == EOS_IDX:\n",
766
+ " break\n",
767
+ " return ys"
768
+ ]
769
+ },
770
+ {
771
+ "cell_type": "code",
772
+ "execution_count": 59,
773
+ "metadata": {
774
+ "id": "exM3fCaBtFk2",
775
+ "colab": {
776
+ "base_uri": "https://localhost:8080/"
777
+ },
778
+ "outputId": "726a1bab-c145-4861-f4d3-6cb5122a567c"
779
+ },
780
+ "outputs": [
781
+ {
782
+ "output_type": "stream",
783
+ "name": "stdout",
784
+ "text": [
785
+ "3\n",
786
+ "3\n",
787
+ "512\n",
788
+ "8\n",
789
+ "19214\n",
790
+ "10837\n",
791
+ "512\n"
792
+ ]
793
+ }
794
+ ],
795
+ "source": [
796
+ "# Load the saved model\n",
797
+ "loaded_model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,\n",
798
+ " NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)\n",
799
+ "print(NUM_ENCODER_LAYERS)\n",
800
+ "print(NUM_DECODER_LAYERS)\n",
801
+ "print(EMB_SIZE)\n",
802
+ "print(NHEAD)\n",
803
+ "print(SRC_VOCAB_SIZE)\n",
804
+ "print(TGT_VOCAB_SIZE)\n",
805
+ "print(FFN_HID_DIM)\n",
806
+ "loaded_model.load_state_dict(torch.load('/gdrive/My Drive/transformer_model.pth'))\n",
807
+ "loaded_model.eval() # Make sure to set the model in evaluation mode\n",
808
+ "\n",
809
+ "# Incorporate the loaded model into the remaining portion of your code\n",
810
+ "def translate(model: torch.nn.Module, src_sentence: str):\n",
811
+ " model.eval()\n",
812
+ " src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)\n",
813
+ " num_tokens = src.shape[0]\n",
814
+ " src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)\n",
815
+ " tgt_tokens = greedy_decode(\n",
816
+ " model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()\n",
817
+ " return \" \".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace(\"<bos>\", \"\").replace(\"<eos>\", \"\")\n"
818
+ ]
819
+ },
820
+ {
821
+ "cell_type": "code",
822
+ "execution_count": 60,
823
+ "metadata": {
824
+ "id": "85yPR0zBtOsZ",
825
+ "colab": {
826
+ "base_uri": "https://localhost:8080/"
827
+ },
828
+ "outputId": "44efc93c-5d86-4084-fc21-bb3bc5bae207"
829
+ },
830
+ "outputs": [
831
+ {
832
+ "output_type": "stream",
833
+ "name": "stdout",
834
+ "text": [
835
+ " Russia cloth spoof Russia sewing Madrid Madrid Russia silhouetted Madrid Russia Madrid Madrid Russia cloth\n"
836
+ ]
837
+ }
838
+ ],
839
+ "source": [
840
+ "print(translate(transformer, \"Eine Gruppe von Menschen steht vor einem Iglu .\"))"
841
+ ]
842
+ },
843
+ {
844
+ "cell_type": "code",
845
+ "execution_count": 24,
846
+ "metadata": {
847
+ "id": "HJF7lXj0tPjO",
848
+ "colab": {
849
+ "base_uri": "https://localhost:8080/"
850
+ },
851
+ "outputId": "0237b57f-29cf-4c75-a060-fb928dbd2ced"
852
+ },
853
+ "outputs": [
854
+ {
855
+ "output_type": "stream",
856
+ "name": "stdout",
857
+ "text": [
858
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.38.2)\n",
859
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.3)\n",
860
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.20.3)\n",
861
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.25.2)\n",
862
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.0)\n",
863
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
864
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.12.25)\n",
865
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
866
+ "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.2)\n",
867
+ "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.2)\n",
868
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.2)\n",
869
+ "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (2023.6.0)\n",
870
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (4.10.0)\n",
871
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n",
872
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.6)\n",
873
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n",
874
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.2.2)\n"
875
+ ]
876
+ }
877
+ ],
878
+ "source": [
879
+ "!pip install transformers"
880
+ ]
881
+ },
882
+ {
883
+ "cell_type": "code",
884
+ "execution_count": 25,
885
+ "metadata": {
886
+ "id": "TMLnV5aMtSco"
887
+ },
888
+ "outputs": [],
889
+ "source": [
890
+ "from transformers.modeling_utils import PreTrainedModel ,PretrainedConfig"
891
+ ]
892
+ },
893
+ {
894
+ "cell_type": "code",
895
+ "execution_count": 27,
896
+ "metadata": {
897
+ "id": "oP9ODxPxtWPI"
898
+ },
899
+ "outputs": [],
900
+ "source": [
901
+ "class Seq2SeqTransformer(PreTrainedModel):\n",
902
+ " def __init__(self,config):\n",
903
+ " super(Seq2SeqTransformer, self).__init__(config)\n",
904
+ " self.transformer = Transformer(d_model=config.emb_size,\n",
905
+ " nhead=config.nhead,\n",
906
+ " num_encoder_layers=config.num_encoder_layers,\n",
907
+ " num_decoder_layers=config.num_decoder_layers,\n",
908
+ " dim_feedforward=config.dim_feedforward,\n",
909
+ " dropout=config.dropout)\n",
910
+ " self.generator = nn.Linear(config.emb_size, config.tgt_vocab_size)\n",
911
+ " self.src_tok_emb = TokenEmbedding(config.src_vocab_size, config.emb_size)\n",
912
+ " self.tgt_tok_emb = TokenEmbedding(config.tgt_vocab_size, config.emb_size)\n",
913
+ " self.positional_encoding = PositionalEncoding(\n",
914
+ " config.emb_size, dropout=config.dropout)"
915
+ ]
916
+ },
917
+ {
918
+ "cell_type": "code",
919
+ "execution_count": 30,
920
+ "metadata": {
921
+ "id": "_uOmJ7oQtdVF"
922
+ },
923
+ "outputs": [],
924
+ "source": [
925
+ "config = PretrainedConfig(\n",
926
+ " # Specify your vocabulary size\n",
927
+ " dim_feedforward =512,\n",
928
+ " dropout= 0.1,\n",
929
+ " emb_size= 512,\n",
930
+ " num_decoder_layers= 3,\n",
931
+ " num_encoder_layers= 3,\n",
932
+ " nhead= 8,\n",
933
+ " src_vocab_size= 19214,\n",
934
+ " tgt_vocab_size= 10837\n",
935
+ ")"
936
+ ]
937
+ },
938
+ {
939
+ "cell_type": "code",
940
+ "execution_count": 33,
941
+ "metadata": {
942
+ "id": "DO15AHGZtjwA"
943
+ },
944
+ "outputs": [],
945
+ "source": [
946
+ "model = Seq2SeqTransformer(config)\n",
947
+ "model.to(DEVICE)\n",
948
+ "\n",
949
+ "\n",
950
+ "model.save_pretrained('/gdrive/My Drive')"
951
+ ]
952
+ },
953
+ {
954
+ "cell_type": "code",
955
+ "source": [
956
+ "!pip install -q gradio==3.48.0"
957
+ ],
958
+ "metadata": {
959
+ "colab": {
960
+ "base_uri": "https://localhost:8080/"
961
+ },
962
+ "id": "vJicfSC62R86",
963
+ "outputId": "ddb7f709-daff-4376-e15a-d936397e8ec3"
964
+ },
965
+ "execution_count": 35,
966
+ "outputs": [
967
+ {
968
+ "output_type": "stream",
969
+ "name": "stdout",
970
+ "text": [
971
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20.3/20.3 MB\u001b[0m \u001b[31m57.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
972
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m91.9/91.9 kB\u001b[0m \u001b[31m11.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
973
+ "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
974
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m299.2/299.2 kB\u001b[0m \u001b[31m34.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
975
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.6/75.6 kB\u001b[0m \u001b[31m10.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
976
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m144.8/144.8 kB\u001b[0m \u001b[31m17.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
977
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m60.8/60.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
978
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m129.9/129.9 kB\u001b[0m \u001b[31m15.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
979
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
980
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.9/71.9 kB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
981
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
982
+ "\u001b[?25h Building wheel for ffmpy (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
983
+ ]
984
+ }
985
+ ]
986
+ },
987
+ {
988
+ "cell_type": "code",
989
+ "source": [
990
+ "import gradio as gr\n",
991
+ "import torch\n",
992
+ "from torchtext.data.utils import get_tokenizer\n",
993
+ "from torchtext.vocab import build_vocab_from_iterator\n",
994
+ "from torchtext.datasets import Multi30k\n",
995
+ "from torch import Tensor\n",
996
+ "from typing import Iterable, List\n",
997
+ "\n",
998
+ "# Define your model, tokenizer, and other necessary components here\n",
999
+ "# Ensure you have imported all necessary libraries\n",
1000
+ "\n",
1001
+ "# Load your transformer model\n",
1002
+ "model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,\n",
1003
+ " NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)\n",
1004
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
1005
+ "model.load_state_dict(torch.load('/gdrive/My Drive/transformer_model.pth', map_location=device))\n",
1006
+ "model.eval()\n",
1007
+ "\n",
1008
+ "\n",
1009
+ "def translate(model: torch.nn.Module, src_sentence: str):\n",
1010
+ " model.eval()\n",
1011
+ " src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)\n",
1012
+ " num_tokens = src.shape[0]\n",
1013
+ " src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)\n",
1014
+ " tgt_tokens = greedy_decode(\n",
1015
+ " model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()\n",
1016
+ " return \" \".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace(\"<bos>\", \"\").replace(\"<eos>\", \"\")\n",
1017
+ "\n",
1018
+ "\n"
1019
+ ],
1020
+ "metadata": {
1021
+ "colab": {
1022
+ "base_uri": "https://localhost:8080/"
1023
+ },
1024
+ "id": "wgBhx0w7-EUa",
1025
+ "outputId": "170f3d83-5c56-4cc6-da52-273b8f63e885"
1026
+ },
1027
+ "execution_count": 90,
1028
+ "outputs": [
1029
+ {
1030
+ "output_type": "stream",
1031
+ "name": "stderr",
1032
+ "text": [
1033
+ "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/transformer.py:286: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
1034
+ " warnings.warn(f\"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}\")\n"
1035
+ ]
1036
+ }
1037
+ ]
1038
+ },
1039
+ {
1040
+ "cell_type": "code",
1041
+ "source": [
1042
+ "if __name__ == \"__main__\":\n",
1043
+ " # Create the Gradio interface\n",
1044
+ " iface = gr.Interface(\n",
1045
+ " fn=translate, # Specify the translation function as the main function\n",
1046
+ " inputs=[\n",
1047
+ " gr.inputs.Textbox(label=\"Text\"),\n",
1048
+ " gr.inputs.Textbox(label=\"Text\")\n",
1049
+ "\n",
1050
+ " ],\n",
1051
+ " outputs=[\"text\"], # Define the output type as text\n",
1052
+ " #examples=[[\"I'm ready\", \"english\", \"arabic\"]], # Provide an example input for demonstration\n",
1053
+ " cache_examples=False, # Disable caching of examples\n",
1054
+ " title=\"germanToenglish\", # Set the title of the interface\n",
1055
+ " #description=\"This is a translator app for arabic and english. Currently supports only english to arabic.\" # Add a description of the interface\n",
1056
+ " )\n",
1057
+ "\n",
1058
+ " # Launch the interface\n",
1059
+ " iface.launch(share=True)"
1060
+ ],
1061
+ "metadata": {
1062
+ "colab": {
1063
+ "base_uri": "https://localhost:8080/",
1064
+ "height": 819
1065
+ },
1066
+ "id": "y9CN022m-hGQ",
1067
+ "outputId": "34971409-3c9d-46a8-a741-e785d597d18c"
1068
+ },
1069
+ "execution_count": 91,
1070
+ "outputs": [
1071
+ {
1072
+ "output_type": "stream",
1073
+ "name": "stderr",
1074
+ "text": [
1075
+ "<ipython-input-91-b142228ac367>:6: GradioDeprecationWarning: Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components\n",
1076
+ " gr.inputs.Textbox(label=\"Text\"),\n",
1077
+ "<ipython-input-91-b142228ac367>:6: GradioDeprecationWarning: `optional` parameter is deprecated, and it has no effect\n",
1078
+ " gr.inputs.Textbox(label=\"Text\"),\n",
1079
+ "<ipython-input-91-b142228ac367>:6: GradioDeprecationWarning: `numeric` parameter is deprecated, and it has no effect\n",
1080
+ " gr.inputs.Textbox(label=\"Text\"),\n",
1081
+ "<ipython-input-91-b142228ac367>:7: GradioDeprecationWarning: Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your component from gradio.components\n",
1082
+ " gr.inputs.Textbox(label=\"Text\")\n",
1083
+ "<ipython-input-91-b142228ac367>:7: GradioDeprecationWarning: `optional` parameter is deprecated, and it has no effect\n",
1084
+ " gr.inputs.Textbox(label=\"Text\")\n",
1085
+ "<ipython-input-91-b142228ac367>:7: GradioDeprecationWarning: `numeric` parameter is deprecated, and it has no effect\n",
1086
+ " gr.inputs.Textbox(label=\"Text\")\n"
1087
+ ]
1088
+ },
1089
+ {
1090
+ "output_type": "stream",
1091
+ "name": "stdout",
1092
+ "text": [
1093
+ "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n",
1094
+ "Running on public URL: https://05da874e546ecf0271.gradio.live\n",
1095
+ "\n",
1096
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
1097
+ ]
1098
+ },
1099
+ {
1100
+ "output_type": "display_data",
1101
+ "data": {
1102
+ "text/plain": [
1103
+ "<IPython.core.display.HTML object>"
1104
+ ],
1105
+ "text/html": [
1106
+ "<div><iframe src=\"https://05da874e546ecf0271.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
1107
+ ]
1108
+ },
1109
+ "metadata": {}
1110
+ }
1111
+ ]
1112
+ },
1113
+ {
1114
+ "cell_type": "code",
1115
+ "source": [
1116
+ "if __name__ == \"__main__\":\n",
1117
+ " # Create the Gradio interface\n",
1118
+ " iface = gr.Interface(\n",
1119
+ " fn=translate, # Specify the translation function as the main function\n",
1120
+ " inputs=[\n",
1121
+ " gr.components.Textbox(label=\"Text\"), # Add a textbox input for entering text\n",
1122
+ " gr.components.Dropdown(label=\"Source Language\", choices=language), # Add a dropdown for selecting source language\n",
1123
+ " gr.components.Dropdown(label=\"Target Language\", choices=language), # Add a dropdown for selecting target language\n",
1124
+ " ],\n",
1125
+ " outputs=[\"text\"], # Define the output type as text\n",
1126
+ " #examples=[[\"I'm ready\", \"english\", \"arabic\"]], # Provide an example input for demonstration\n",
1127
+ " cache_examples=False, # Disable caching of examples\n",
1128
+ " title=\"germanToenglish\", # Set the title of the interface\n",
1129
+ " #description=\"This is a translator app for arabic and english. Currently supports only english to arabic.\" # Add a description of the interface\n",
1130
+ " )\n",
1131
+ "\n",
1132
+ " # Launch the interface\n",
1133
+ " iface.launch(share=True)"
1134
+ ],
1135
+ "metadata": {
1136
+ "colab": {
1137
+ "base_uri": "https://localhost:8080/",
1138
+ "height": 680
1139
+ },
1140
+ "id": "NRTdTJ8E72LQ",
1141
+ "outputId": "6d76e9c7-8f46-498b-e0a6-b6aa74b48fc6"
1142
+ },
1143
+ "execution_count": 45,
1144
+ "outputs": [
1145
+ {
1146
+ "output_type": "stream",
1147
+ "name": "stderr",
1148
+ "text": [
1149
+ "/usr/local/lib/python3.10/dist-packages/gradio/utils.py:812: UserWarning: Expected 2 arguments for function <function translate at 0x7d1bb879fc70>, received 3.\n",
1150
+ " warnings.warn(\n",
1151
+ "/usr/local/lib/python3.10/dist-packages/gradio/utils.py:820: UserWarning: Expected maximum 2 arguments for function <function translate at 0x7d1bb879fc70>, received 3.\n",
1152
+ " warnings.warn(\n"
1153
+ ]
1154
+ },
1155
+ {
1156
+ "output_type": "stream",
1157
+ "name": "stdout",
1158
+ "text": [
1159
+ "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n",
1160
+ "Running on public URL: https://652be12920500f856f.gradio.live\n",
1161
+ "\n",
1162
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
1163
+ ]
1164
+ },
1165
+ {
1166
+ "output_type": "display_data",
1167
+ "data": {
1168
+ "text/plain": [
1169
+ "<IPython.core.display.HTML object>"
1170
+ ],
1171
+ "text/html": [
1172
+ "<div><iframe src=\"https://652be12920500f856f.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
1173
+ ]
1174
+ },
1175
+ "metadata": {}
1176
+ }
1177
+ ]
1178
+ },
1179
+ {
1180
+ "cell_type": "code",
1181
+ "source": [],
1182
+ "metadata": {
1183
+ "id": "5RuYPqUT3M3M"
1184
+ },
1185
+ "execution_count": null,
1186
+ "outputs": []
1187
+ }
1188
+ ],
1189
+ "metadata": {
1190
+ "accelerator": "GPU",
1191
+ "colab": {
1192
+ "gpuType": "T4",
1193
+ "provenance": []
1194
+ },
1195
+ "kernelspec": {
1196
+ "display_name": "Python 3",
1197
+ "name": "python3"
1198
+ },
1199
+ "language_info": {
1200
+ "name": "python"
1201
+ }
1202
+ },
1203
+ "nbformat": 4,
1204
+ "nbformat_minor": 0
1205
+ }