Hjgugugjhuhjggg commited on
Commit
43f235c
·
verified ·
1 Parent(s): 2469ee5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -32
app.py CHANGED
@@ -14,6 +14,7 @@ import yaml
14
  from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
15
  from mergekit.config import MergeConfiguration
16
  from clean_community_org import garbage_collect_empty_models
 
17
 
18
  has_gpu = torch.cuda.is_available()
19
 
@@ -54,21 +55,20 @@ A quick overview of the currently supported merge methods:
54
  ## Citation
55
  This GUI is powered by [Arcee's MergeKit](https://arxiv.org/abs/2403.13257).
56
  If you use it in your research, please cite the following paper:
57
- ```
58
  @article{goddard2024arcee,
59
  title={Arcee's MergeKit: A Toolkit for Merging Large Language Models},
60
  author={Goddard, Charles and Siriwardhana, Shamane and Ehghaghi, Malikeh and Meyers, Luke and Karpukhin, Vlad and Benedict, Brian and McQuade, Mark and Solawetz, Jacob},
61
  journal={arXiv preprint arXiv:2403.13257},
62
  year={2024}
63
  }
64
- ```
65
  This Space is heavily inspired by LazyMergeKit by Maxime Labonne (see [Colab](https://colab.research.google.com/drive/1obulZ1ROXHjYLn6PPZJwRR6GzgQogxxb)).
66
  """
67
 
68
  examples = [[str(f)] for f in pathlib.Path("examples").glob("*.yaml")]
69
  COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
70
 
71
- def merge(yaml_config: str, hf_token: str, repo_name: str, profile_name: str) -> Iterable[List[Log]]:
 
72
  runner = LogsViewRunner()
73
 
74
  if not yaml_config:
@@ -88,9 +88,7 @@ def merge(yaml_config: str, hf_token: str, repo_name: str, profile_name: str) ->
88
  level="ERROR",
89
  )
90
  return
91
- yield runner.log(
92
- "No HF token provided. Your merged model will be uploaded to the https://huggingface.co/mergekit-community organization."
93
- )
94
  is_community_model = True
95
  if not COMMUNITY_HF_TOKEN:
96
  raise gr.Error("Cannot upload to community org: community token not set by Space owner.")
@@ -141,6 +139,7 @@ def merge(yaml_config: str, hf_token: str, repo_name: str, profile_name: str) ->
141
  )
142
  yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
143
 
 
144
  with gr.Blocks() as demo:
145
  gr.Markdown(MARKDOWN_DESCRIPTION)
146
 
@@ -148,37 +147,17 @@ with gr.Blocks() as demo:
148
  filename = gr.Textbox(visible=False, label="filename")
149
  config = gr.Code(language="yaml", lines=10, label="config.yaml")
150
  with gr.Column():
151
- token = gr.Textbox(
152
- lines=1,
153
- label="HF Write Token",
154
- info="https://hf.co/settings/token",
155
- type="password",
156
- placeholder="Optional. Will upload merged model to MergeKit Community if empty.",
157
- )
158
- repo_name = gr.Textbox(
159
- lines=1,
160
- label="Repo name",
161
- placeholder="Optional. Will create a random name if empty.",
162
- )
163
- profile_name = gr.Textbox(
164
- lines=1,
165
- label="Hugging Face Profile Name",
166
- placeholder="Enter your Hugging Face profile name.",
167
- )
168
  button = gr.Button("Merge", variant="primary")
169
  logs = LogsView(label="Terminal output")
170
- gr.Examples(
171
- examples,
172
- fn=lambda s: (s,),
173
- run_on_click=True,
174
- label="Examples",
175
- inputs=[filename],
176
- outputs=[config],
177
- )
178
  gr.Markdown(MARKDOWN_ARTICLE)
179
 
180
  button.click(fn=merge, inputs=[config, token, repo_name, profile_name], outputs=[logs])
181
 
 
182
  def _garbage_collect_every_hour():
183
  while True:
184
  try:
@@ -187,7 +166,15 @@ def _garbage_collect_every_hour():
187
  print("Error running garbage collection", e)
188
  time.sleep(3600)
189
 
 
190
  pool = ThreadPoolExecutor()
191
  pool.submit(_garbage_collect_every_hour)
192
 
193
- demo.queue(default_concurrency_limit=2).launch()
 
 
 
 
 
 
 
 
14
  from gradio_logsview.logsview import Log, LogsView, LogsViewRunner
15
  from mergekit.config import MergeConfiguration
16
  from clean_community_org import garbage_collect_empty_models
17
+ import spaces
18
 
19
  has_gpu = torch.cuda.is_available()
20
 
 
55
  ## Citation
56
  This GUI is powered by [Arcee's MergeKit](https://arxiv.org/abs/2403.13257).
57
  If you use it in your research, please cite the following paper:
 
58
  @article{goddard2024arcee,
59
  title={Arcee's MergeKit: A Toolkit for Merging Large Language Models},
60
  author={Goddard, Charles and Siriwardhana, Shamane and Ehghaghi, Malikeh and Meyers, Luke and Karpukhin, Vlad and Benedict, Brian and McQuade, Mark and Solawetz, Jacob},
61
  journal={arXiv preprint arXiv:2403.13257},
62
  year={2024}
63
  }
 
64
  This Space is heavily inspired by LazyMergeKit by Maxime Labonne (see [Colab](https://colab.research.google.com/drive/1obulZ1ROXHjYLn6PPZJwRR6GzgQogxxb)).
65
  """
66
 
67
  examples = [[str(f)] for f in pathlib.Path("examples").glob("*.yaml")]
68
  COMMUNITY_HF_TOKEN = os.getenv("COMMUNITY_HF_TOKEN")
69
 
70
+
71
+ def merge(yaml_config, hf_token, repo_name, profile_name):
72
  runner = LogsViewRunner()
73
 
74
  if not yaml_config:
 
88
  level="ERROR",
89
  )
90
  return
91
+ yield runner.log("No HF token provided. Your merged model will be uploaded to the https://huggingface.co/mergekit-community organization.")
 
 
92
  is_community_model = True
93
  if not COMMUNITY_HF_TOKEN:
94
  raise gr.Error("Cannot upload to community org: community token not set by Space owner.")
 
139
  )
140
  yield runner.log(f"Model successfully uploaded to HF: {repo_url.repo_id}")
141
 
142
+
143
  with gr.Blocks() as demo:
144
  gr.Markdown(MARKDOWN_DESCRIPTION)
145
 
 
147
  filename = gr.Textbox(visible=False, label="filename")
148
  config = gr.Code(language="yaml", lines=10, label="config.yaml")
149
  with gr.Column():
150
+ token = gr.Textbox(lines=1, label="HF Write Token", info="https://hf.co/settings/token", type="password", placeholder="Optional. Will upload merged model to MergeKit Community if empty.")
151
+ repo_name = gr.Textbox(lines=1, label="Repo name", placeholder="Optional. Will create a random name if empty.")
152
+ profile_name = gr.Textbox(lines=1, label="Hugging Face Profile Name", placeholder="Enter your Hugging Face profile name.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  button = gr.Button("Merge", variant="primary")
154
  logs = LogsView(label="Terminal output")
155
+ gr.Examples(examples, fn=lambda s: (s,), run_on_click=True, label="Examples", inputs=[filename], outputs=[config])
 
 
 
 
 
 
 
156
  gr.Markdown(MARKDOWN_ARTICLE)
157
 
158
  button.click(fn=merge, inputs=[config, token, repo_name, profile_name], outputs=[logs])
159
 
160
+
161
  def _garbage_collect_every_hour():
162
  while True:
163
  try:
 
166
  print("Error running garbage collection", e)
167
  time.sleep(3600)
168
 
169
+
170
  pool = ThreadPoolExecutor()
171
  pool.submit(_garbage_collect_every_hour)
172
 
173
+
174
+ @spaces.GPU
175
+ def launch():
176
+ demo.queue(default_concurrency_limit=2).launch(share=True)
177
+
178
+
179
+ if __name__ == "__main__":
180
+ launch()