nschenone commited on
Commit
d95efe1
·
1 Parent(s): 3474b25

Added model config and dynamic pipeline loading

Browse files
Files changed (6) hide show
  1. .gitignore +166 -0
  2. Untitled.ipynb +162 -0
  3. app.py +7 -14
  4. model_config.yaml +18 -0
  5. src/__init__.py +0 -0
  6. src/utils.py +16 -0
.gitignore ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Created by https://www.toptal.com/developers/gitignore/api/python
2
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
3
+
4
+ ### Python ###
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
165
+
166
+ # End of https://www.toptal.com/developers/gitignore/api/python
Untitled.ipynb ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 24,
6
+ "id": "9654bdb9-79ea-49b0-ab6c-da9ddbdb7ee2",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from src.utils import load_pipelines_from_config"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 25,
16
+ "id": "b3b42ff6-dea6-4700-ab96-6bf1706221b7",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "pipelines = load_pipelines_from_config(\"model_config.yaml\")"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 26,
26
+ "id": "667ae085-587a-427a-87dd-bfc7c08553a5",
27
+ "metadata": {},
28
+ "outputs": [
29
+ {
30
+ "data": {
31
+ "text/plain": [
32
+ "{'Rap': <transformers.pipelines.text_generation.TextGenerationPipeline at 0x7f5650745420>,\n",
33
+ " 'Metal': <transformers.pipelines.text_generation.TextGenerationPipeline at 0x7f5595d74700>}"
34
+ ]
35
+ },
36
+ "execution_count": 26,
37
+ "metadata": {},
38
+ "output_type": "execute_result"
39
+ }
40
+ ],
41
+ "source": [
42
+ "pipelines"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 5,
48
+ "id": "0a90af2f-2f8f-4c0a-a2ef-ce8d7bf5e4be",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "with open(\"model_config.yaml\", \"r\") as f:\n",
53
+ " model_config = yaml.safe_load(f.read())"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": 6,
59
+ "id": "e743aa51-6d6b-43c6-b138-d9b7e9b70a36",
60
+ "metadata": {},
61
+ "outputs": [
62
+ {
63
+ "data": {
64
+ "text/plain": [
65
+ "{'Rap': {'model_name': 'nschenone/rap-distil',\n",
66
+ " 'artist_names': ['Eminem', 'Hopsin', 'Kentrick Lamar'],\n",
67
+ " 'mlflow_run_id': '16c4ff05d92a45d79d89572a58b6424b',\n",
68
+ " 'hf_commit_hash': 'ca066f322213fbeac8d036fafd32112e23837722',\n",
69
+ " 'task': 'text-generation'},\n",
70
+ " 'Metal': {'model_name': 'nschenone/metal-distil',\n",
71
+ " 'artist_names': ['Slipknot', 'Parkway Drive', 'Periphery'],\n",
72
+ " 'mlflow_run_id': 'f30f57e3d8c440a09e1738f07db0b211',\n",
73
+ " 'hf_commit_hash': 'ed0657933ac3eb11a554dbe153363ff3e457f5ab',\n",
74
+ " 'task': 'text-generation'}}"
75
+ ]
76
+ },
77
+ "execution_count": 6,
78
+ "metadata": {},
79
+ "output_type": "execute_result"
80
+ }
81
+ ],
82
+ "source": [
83
+ "model_config"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": 13,
89
+ "id": "8608031c-c0a2-44ac-95fb-3acfde35c125",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": []
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 20,
97
+ "id": "a25764a5-e0db-45a5-96ed-34efde27046a",
98
+ "metadata": {},
99
+ "outputs": [],
100
+ "source": [
101
+ "a = pipeline(\n",
102
+ " task=\"text-generation\",\n",
103
+ " model=\"nschenone/rap-distil\",\n",
104
+ " revision=\"753f2768e0a9d5b21b5009bec4855ed2c2ddef16\"\n",
105
+ ")"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": 23,
111
+ "id": "632a4847-75e4-459b-a4aa-248ff741f879",
112
+ "metadata": {},
113
+ "outputs": [
114
+ {
115
+ "ename": "ValueError",
116
+ "evalue": "You need to specify a `repo_path_or_name` or a `repo_url`.",
117
+ "output_type": "error",
118
+ "traceback": [
119
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
120
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
121
+ "Input \u001b[0;32mIn [23]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43ma\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpush_to_hub\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
122
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py:2620\u001b[0m, in \u001b[0;36mPreTrainedModel.push_to_hub\u001b[0;34m(self, repo_path_or_name, repo_url, use_temp_dir, commit_message, organization, private, use_auth_token, max_shard_size, **model_card_kwargs)\u001b[0m\n\u001b[1;32m 2617\u001b[0m repo_path_or_name \u001b[38;5;241m=\u001b[39m tempfile\u001b[38;5;241m.\u001b[39mmkdtemp()\n\u001b[1;32m 2619\u001b[0m \u001b[38;5;66;03m# Create or clone the repo. If the repo is already cloned, this just retrieves the path to the repo.\u001b[39;00m\n\u001b[0;32m-> 2620\u001b[0m repo \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_create_or_get_repo\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2621\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_path_or_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_path_or_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2622\u001b[0m \u001b[43m \u001b[49m\u001b[43mrepo_url\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_url\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2623\u001b[0m \u001b[43m \u001b[49m\u001b[43morganization\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morganization\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2624\u001b[0m \u001b[43m \u001b[49m\u001b[43mprivate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprivate\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2625\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_auth_token\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_auth_token\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2626\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2627\u001b[0m \u001b[38;5;66;03m# Save the files in the cloned repo\u001b[39;00m\n\u001b[1;32m 2628\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msave_pretrained(repo_path_or_name, max_shard_size\u001b[38;5;241m=\u001b[39mmax_shard_size)\n",
123
+ "File \u001b[0;32m/opt/conda/lib/python3.10/site-packages/transformers/utils/hub.py:1017\u001b[0m, in \u001b[0;36mPushToHubMixin._create_or_get_repo\u001b[0;34m(cls, repo_path_or_name, repo_url, organization, private, use_auth_token)\u001b[0m\n\u001b[1;32m 1007\u001b[0m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[1;32m 1008\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_create_or_get_repo\u001b[39m(\n\u001b[1;32m 1009\u001b[0m \u001b[38;5;28mcls\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1014\u001b[0m use_auth_token: Optional[Union[\u001b[38;5;28mbool\u001b[39m, \u001b[38;5;28mstr\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1015\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Repository:\n\u001b[1;32m 1016\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m repo_path_or_name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m repo_url \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1017\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou need to specify a `repo_path_or_name` or a `repo_url`.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1019\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_auth_token \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m repo_url \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1020\u001b[0m use_auth_token \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
124
+ "\u001b[0;31mValueError\u001b[0m: You need to specify a `repo_path_or_name` or a `repo_url`."
125
+ ]
126
+ }
127
+ ],
128
+ "source": [
129
+ "a.model.push_to_hub()"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": null,
135
+ "id": "2a0c4d95-8ea7-4b40-98f1-3ad16c31eeba",
136
+ "metadata": {},
137
+ "outputs": [],
138
+ "source": []
139
+ }
140
+ ],
141
+ "metadata": {
142
+ "kernelspec": {
143
+ "display_name": "Python 3 (ipykernel)",
144
+ "language": "python",
145
+ "name": "python3"
146
+ },
147
+ "language_info": {
148
+ "codemirror_mode": {
149
+ "name": "ipython",
150
+ "version": 3
151
+ },
152
+ "file_extension": ".py",
153
+ "mimetype": "text/x-python",
154
+ "name": "python",
155
+ "nbconvert_exporter": "python",
156
+ "pygments_lexer": "ipython3",
157
+ "version": "3.10.6"
158
+ }
159
+ },
160
+ "nbformat": 4,
161
+ "nbformat_minor": 5
162
+ }
app.py CHANGED
@@ -1,16 +1,9 @@
1
  import gradio as gr
2
- from transformers import pipeline, set_seed
3
 
4
- models = {
5
- "Rap" : pipeline(
6
- task="text-generation",
7
- model="nschenone/rap-distil"
8
- ),
9
- "Metal" : pipeline(
10
- task="text-generation",
11
- model="nschenone/metal-distil"
12
- )
13
- }
14
 
15
  def generate(
16
  text: str,
@@ -27,7 +20,7 @@ def generate(
27
 
28
  set_seed(seed)
29
 
30
- generated = models[model](
31
  text_inputs=text,
32
  max_length=max_length,
33
  num_return_sequences=num_return_sequences,
@@ -50,8 +43,8 @@ iface = gr.Interface(
50
  label="Input Text"
51
  ),
52
  gr.Dropdown(
53
- choices=list(models.keys()),
54
- value=list(models.keys())[0],
55
  label="Model"
56
  ),
57
  gr.Slider(
 
1
  import gradio as gr
2
+ from transformers import set_seed
3
 
4
+ from src.utils import load_pipelines_from_config
5
+
6
+ pipelines = load_pipelines_from_config(config_path="model_config.yaml")
 
 
 
 
 
 
 
7
 
8
  def generate(
9
  text: str,
 
20
 
21
  set_seed(seed)
22
 
23
+ generated = pipelines[model](
24
  text_inputs=text,
25
  max_length=max_length,
26
  num_return_sequences=num_return_sequences,
 
43
  label="Input Text"
44
  ),
45
  gr.Dropdown(
46
+ choices=list(pipelines.keys()),
47
+ value=list(pipelines.keys())[0],
48
  label="Model"
49
  ),
50
  gr.Slider(
model_config.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Rap:
2
+ model_name: nschenone/rap-distil
3
+ artist_names:
4
+ - Eminem
5
+ - Hopsin
6
+ - Kentrick Lamar
7
+ mlflow_run_id: 16c4ff05d92a45d79d89572a58b6424b
8
+ hf_commit_hash: ca066f322213fbeac8d036fafd32112e23837722
9
+ task: text-generation
10
+ Metal:
11
+ model_name: nschenone/metal-distil
12
+ artist_names:
13
+ - Slipknot
14
+ - Parkway Drive
15
+ - Periphery
16
+ mlflow_run_id: f30f57e3d8c440a09e1738f07db0b211
17
+ hf_commit_hash: ed0657933ac3eb11a554dbe153363ff3e457f5ab
18
+ task: text-generation
src/__init__.py ADDED
File without changes
src/utils.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ from transformers import pipeline
3
+
4
+ def load_pipelines_from_config(config_path: str):
5
+ with open(config_path, "r") as f:
6
+ model_config = yaml.safe_load(f.read())
7
+
8
+ models = {}
9
+ for model, config in model_config.items():
10
+ models[model] = pipeline(
11
+ task=config["task"],
12
+ model=config["model_name"],
13
+ revision=config["hf_commit_hash"]
14
+ )
15
+
16
+ return models