Blazgo committed on
Commit
bfc3ab5
·
verified ·
1 Parent(s): 2a94399

Update app.py to prevent users from running multiple jobs

Browse files

Attempts to fix/prevent #44, #40, #12, #11

Files changed (1) hide show
  1. app.py +36 -12
app.py CHANGED
@@ -43,6 +43,8 @@ has_gpu = torch.cuda.is_available()
43
  # write_model_card=True,
44
  # )
45
  # )
 
 
46
 
47
  cli = "mergekit-yaml config.yaml merge --copy-tokenizer" + (
48
  " --cuda --low-cpu-memory --allow-crimes" if has_gpu else " --allow-crimes --out-shard-size 1B --lazy-unpickle"
@@ -102,6 +104,7 @@ If you use it in your research, please cite the following paper:
102
  }
103
  ```
104
 
 
105
  This Space is heavily inspired by LazyMergeKit by Maxime Labonne (see [Colab](https://colab.research.google.com/drive/1obulZ1ROXHjYLn6PPZJwRR6GzgQogxxb)).
106
  """
107
 
@@ -113,16 +116,40 @@ examples = [[str(f)] for f in pathlib.Path("examples").glob("*.yaml")]
113
  COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
114
 
115
 
 
 
 
 
 
 
 
116
  def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  runner = LogsViewRunner()
118
 
119
  if not yaml_config:
120
  yield runner.log("Empty yaml, pick an example below", level="ERROR")
 
121
  return
122
  try:
123
  merge_config = MergeConfiguration.model_validate(yaml.safe_load(yaml_config))
124
  except Exception as e:
125
  yield runner.log(f"Invalid yaml {e}", level="ERROR")
 
126
  return
127
 
128
  is_community_model = False
@@ -132,6 +159,7 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
132
  f"Cannot upload merge model to namespace {repo_name.split('/')[0]}: you must provide a valid token.",
133
  level="ERROR",
134
  )
 
135
  return
136
  yield runner.log(
137
  "No HF token provided. Your merged model will be uploaded to the https://huggingface.co/mergekit-community organization."
@@ -167,6 +195,7 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
167
  yield runner.log(f"Repo created: {repo_url}")
168
  except Exception as e:
169
  yield runner.log(f"Error creating repo {e}", level="ERROR")
 
170
  return
171
 
172
  # Set tmp HF_HOME to avoid filling up disk Space
@@ -178,6 +207,7 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
178
  if runner.exit_code != 0:
179
  yield runner.log("Merge failed. Deleting repo as no model is uploaded.", level="ERROR")
180
  api.delete_repo(repo_url.repo_id)
 
181
  return
182
 
183
  yield runner.log("Model merged successfully. Uploading to HF.")
@@ -188,11 +218,10 @@ def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]
188
  )
189
  yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
190
 
191
- # This is a workaround, as the Space keeps getting stuck.
192
- def _restart_space():
193
- huggingface_hub.HfApi().restart_space(repo_id="arcee-ai/mergekit-gui", token=COMMUNITY_HF_TOKEN, factory_reboot=False)
194
  # Run garbage collection every hour to keep the community org clean.
195
- # Empty models might exist if the merge fails abruptly (e.g. if the user leaves the Space).
196
  def _garbage_remover():
197
  try:
198
  garbage_collect_empty_models(token=COMMUNITY_HF_TOKEN)
@@ -200,17 +229,13 @@ def _garbage_remover():
200
  print("Error running garbage collection", e)
201
 
202
  scheduler = BackgroundScheduler()
203
- restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=21600)
204
  garbage_remover_job = scheduler.add_job(_garbage_remover, "interval", seconds=3600)
205
  scheduler.start()
206
- next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
207
-
208
- NEXT_RESTART = f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC)"
209
 
210
  with gr.Blocks() as demo:
211
  gr.Markdown(MARKDOWN_DESCRIPTION)
212
- gr.Markdown(NEXT_RESTART)
213
-
214
  with gr.Row():
215
  filename = gr.Textbox(visible=False, label="filename")
216
  config = gr.Code(language="yaml", lines=10, label="config.yaml")
@@ -241,6 +266,5 @@ with gr.Blocks() as demo:
241
 
242
  button.click(fn=merge, inputs=[config, token, repo_name], outputs=[logs])
243
 
244
-
245
-
246
  demo.queue(default_concurrency_limit=1).launch()
 
 
43
  # write_model_card=True,
44
  # )
45
  # )
46
+ # A simple in-memory dictionary to track users' ongoing jobs by user ID (username)
47
+ active_jobs = {}
48
 
49
  cli = "mergekit-yaml config.yaml merge --copy-tokenizer" + (
50
  " --cuda --low-cpu-memory --allow-crimes" if has_gpu else " --allow-crimes --out-shard-size 1B --lazy-unpickle"
 
104
  }
105
  ```
106
 
107
+
108
  This Space is heavily inspired by LazyMergeKit by Maxime Labonne (see [Colab](https://colab.research.google.com/drive/1obulZ1ROXHjYLn6PPZJwRR6GzgQogxxb)).
109
  """
110
 
 
116
  COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
117
 
118
 
119
+ def get_user_from_token(hf_token: str) -> str:
120
+ """Fetch the username associated with the Hugging Face token."""
121
+ api = huggingface_hub.HfApi(token=hf_token)
122
+ user_info = api.whoami()
123
+ return user_info["name"] # Returns the username of the token owner
124
+
125
+
126
  def merge(yaml_config: str, hf_token: str, repo_name: str) -> Iterable[List[Log]]:
127
+ # Get the user associated with the token
128
+ try:
129
+ username = get_user_from_token(hf_token)
130
+ except Exception as e:
131
+ yield Log(f"Error fetching user info: {e}", level="ERROR")
132
+ return
133
+
134
+ # Check if the user already has a job running
135
+ if username in active_jobs and active_jobs[username]:
136
+ yield Log(f"You already have a job running, {username}. Please wait until it's complete.", level="ERROR")
137
+ return
138
+
139
+ # Mark the job as active for the current user
140
+ active_jobs[username] = True
141
+
142
  runner = LogsViewRunner()
143
 
144
  if not yaml_config:
145
  yield runner.log("Empty yaml, pick an example below", level="ERROR")
146
+ active_jobs[username] = False
147
  return
148
  try:
149
  merge_config = MergeConfiguration.model_validate(yaml.safe_load(yaml_config))
150
  except Exception as e:
151
  yield runner.log(f"Invalid yaml {e}", level="ERROR")
152
+ active_jobs[username] = False
153
  return
154
 
155
  is_community_model = False
 
159
  f"Cannot upload merge model to namespace {repo_name.split('/')[0]}: you must provide a valid token.",
160
  level="ERROR",
161
  )
162
+ active_jobs[username] = False
163
  return
164
  yield runner.log(
165
  "No HF token provided. Your merged model will be uploaded to the https://huggingface.co/mergekit-community organization."
 
195
  yield runner.log(f"Repo created: {repo_url}")
196
  except Exception as e:
197
  yield runner.log(f"Error creating repo {e}", level="ERROR")
198
+ active_jobs[username] = False
199
  return
200
 
201
  # Set tmp HF_HOME to avoid filling up disk Space
 
207
  if runner.exit_code != 0:
208
  yield runner.log("Merge failed. Deleting repo as no model is uploaded.", level="ERROR")
209
  api.delete_repo(repo_url.repo_id)
210
+ active_jobs[username] = False
211
  return
212
 
213
  yield runner.log("Model merged successfully. Uploading to HF.")
 
218
  )
219
  yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
220
 
221
+ # Mark the job as complete for the current user
222
+ active_jobs[username] = False
223
+
224
  # Run garbage collection every hour to keep the community org clean.
 
225
  def _garbage_remover():
226
  try:
227
  garbage_collect_empty_models(token=COMMUNITY_HF_TOKEN)
 
229
  print("Error running garbage collection", e)
230
 
231
  scheduler = BackgroundScheduler()
 
232
  garbage_remover_job = scheduler.add_job(_garbage_remover, "interval", seconds=3600)
233
  scheduler.start()
 
 
 
234
 
235
  with gr.Blocks() as demo:
236
  gr.Markdown(MARKDOWN_DESCRIPTION)
237
+ gr.Markdown(f"Next Restart: {restart_space_job.next_run_time.astimezone(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')} (UTC)")
238
+
239
  with gr.Row():
240
  filename = gr.Textbox(visible=False, label="filename")
241
  config = gr.Code(language="yaml", lines=10, label="config.yaml")
 
266
 
267
  button.click(fn=merge, inputs=[config, token, repo_name], outputs=[logs])
268
 
 
 
269
  demo.queue(default_concurrency_limit=1).launch()
270
+