Tonic commited on
Commit
2a73516
·
verified ·
1 Parent(s): bcc02ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -21
app.py CHANGED
@@ -7,7 +7,7 @@ from copy import deepcopy
7
  import requests
8
  import os.path
9
  from tqdm import tqdm
10
-
11
 
12
  # Set environment variables
13
  os.environ['RWKV_JIT_ON'] = '1'
@@ -20,27 +20,19 @@ MODELS = {
20
  "0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
21
  }
22
 
23
- # Download tokenizer if not present
24
- TOKENIZER_FILE = "rwkv_vocab_v20230424.txt"
25
- TOKENIZER_URL = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/v2/rwkv_vocab_v20230424.txt"
26
-
27
- def download_tokenizer():
28
- if not os.path.exists(TOKENIZER_FILE):
29
- print("Downloading tokenizer...")
30
- response = requests.get(TOKENIZER_URL)
31
- with open(TOKENIZER_FILE, 'wb') as f:
32
- f.write(response.content)
33
 
34
- def download_model(model_name):
35
- """Download model if not present"""
36
- if not os.path.exists(model_name):
37
- print(f"Downloading {model_name}...")
38
- url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
39
  response = requests.get(url, stream=True)
40
  total_size = int(response.headers.get('content-length', 0))
41
 
42
- with open(model_name, 'wb') as file, tqdm(
43
- desc=model_name,
44
  total=total_size,
45
  unit='iB',
46
  unit_scale=True,
@@ -50,11 +42,22 @@ def download_model(model_name):
50
  size = file.write(data)
51
  pbar.update(size)
52
 
 
 
 
 
 
 
 
 
 
 
53
  class ModelManager:
54
  def __init__(self):
55
  self.current_model = None
56
  self.current_model_name = None
57
  self.pipeline = None
 
58
 
59
  def load_model(self, model_name):
60
  if model_name != self.current_model_name:
@@ -67,7 +70,6 @@ class ModelManager:
67
  self.current_model_name = model_name
68
  return self.pipeline
69
 
70
-
71
  model_manager = ModelManager()
72
 
73
  def generate_response(
@@ -115,7 +117,8 @@ def generate_response(
115
  pipeline.generate(ctx, token_count=max_tokens, args=args, callback=callback)
116
  return response
117
  except Exception as e:
118
- return f"Error: {str(e)}"
 
119
 
120
  # Create the Gradio interface
121
  with gr.Blocks() as demo:
@@ -228,4 +231,4 @@ with gr.Blocks() as demo:
228
 
229
  # Launch the demo
230
  if __name__ == "__main__":
231
- demo.launch(ssr_mode=False)
 
7
  import requests
8
  import os.path
9
  from tqdm import tqdm
10
+ import json
11
 
12
  # Set environment variables
13
  os.environ['RWKV_JIT_ON'] = '1'
 
20
  "0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
21
  }
22
 
23
+ # Tokenizer settings
24
+ TOKENIZER_FILE = "20B_tokenizer.json"
25
+ TOKENIZER_URL = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/20B_tokenizer.json"
 
 
 
 
 
 
 
26
 
27
+ def download_file(url, filename):
28
+ """Generic file downloader with progress bar"""
29
+ if not os.path.exists(filename):
30
+ print(f"Downloading {filename}...")
 
31
  response = requests.get(url, stream=True)
32
  total_size = int(response.headers.get('content-length', 0))
33
 
34
+ with open(filename, 'wb') as file, tqdm(
35
+ desc=filename,
36
  total=total_size,
37
  unit='iB',
38
  unit_scale=True,
 
42
  size = file.write(data)
43
  pbar.update(size)
44
 
45
+ def download_model(model_name):
46
+ """Download model if not present"""
47
+ if not os.path.exists(model_name):
48
+ url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
49
+ download_file(url, model_name)
50
+
51
+ def ensure_tokenizer():
52
+ """Ensure tokenizer is present"""
53
+ download_file(TOKENIZER_URL, TOKENIZER_FILE)
54
+
55
  class ModelManager:
56
  def __init__(self):
57
  self.current_model = None
58
  self.current_model_name = None
59
  self.pipeline = None
60
+ ensure_tokenizer()
61
 
62
  def load_model(self, model_name):
63
  if model_name != self.current_model_name:
 
70
  self.current_model_name = model_name
71
  return self.pipeline
72
 
 
73
  model_manager = ModelManager()
74
 
75
  def generate_response(
 
117
  pipeline.generate(ctx, token_count=max_tokens, args=args, callback=callback)
118
  return response
119
  except Exception as e:
120
+ import traceback
121
+ return f"Error: {str(e)}\nStack trace: {traceback.format_exc()}"
122
 
123
  # Create the Gradio interface
124
  with gr.Blocks() as demo:
 
231
 
232
  # Launch the demo
233
  if __name__ == "__main__":
234
+ demo.launch()