Tonic committed on
Commit
fb15fc9
·
verified ·
1 Parent(s): 2a73516

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -14
app.py CHANGED
@@ -8,6 +8,8 @@ import requests
8
  import os.path
9
  from tqdm import tqdm
10
  import json
 
 
11
 
12
  # Set environment variables
13
  os.environ['RWKV_JIT_ON'] = '1'
@@ -20,9 +22,28 @@ MODELS = {
20
  "0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
21
  }
22
 
23
- # Tokenizer settings
24
- TOKENIZER_FILE = "20B_tokenizer.json"
25
- TOKENIZER_URL = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/20B_tokenizer.json"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def download_file(url, filename):
28
  """Generic file downloader with progress bar"""
@@ -48,26 +69,44 @@ def download_model(model_name):
48
  url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
49
  download_file(url, model_name)
50
 
51
- def ensure_tokenizer():
52
- """Ensure tokenizer is present"""
53
- download_file(TOKENIZER_URL, TOKENIZER_FILE)
 
 
 
 
54
 
55
  class ModelManager:
56
  def __init__(self):
57
  self.current_model = None
58
  self.current_model_name = None
59
  self.pipeline = None
60
- ensure_tokenizer()
61
 
62
- def load_model(self, model_name):
63
- if model_name != self.current_model_name:
64
- download_model(MODELS[model_name])
 
 
 
 
 
 
 
 
 
 
 
65
  self.current_model = RWKV(
66
- model=MODELS[model_name],
67
  strategy='cpu fp32'
68
  )
69
- self.pipeline = PIPELINE(self.current_model, TOKENIZER_FILE)
70
- self.current_model_name = model_name
 
 
 
 
71
  return self.pipeline
72
 
73
  model_manager = ModelManager()
@@ -104,7 +143,8 @@ def generate_response(
104
  alpha_decay=alpha_decay,
105
  token_ban=[],
106
  token_stop=[],
107
- chunk_len=256
 
108
  )
109
 
110
  # Generate response
 
8
  import os.path
9
  from tqdm import tqdm
10
  import json
11
+ from dataclasses import dataclass
12
+ from typing import Optional, List
13
 
14
  # Set environment variables
15
  os.environ['RWKV_JIT_ON'] = '1'
 
22
  "0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
23
  }
24
 
25
+ # Model configurations
26
+ MODEL_CONFIGS = {
27
+ "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth": {
28
+ "n_layer": 12,
29
+ "n_embd": 768,
30
+ "ctx_len": 4096
31
+ },
32
+ "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth": {
33
+ "n_layer": 24,
34
+ "n_embd": 1024,
35
+ "ctx_len": 4096
36
+ }
37
+ }
38
+
39
+ @dataclass
40
+ class ModelArgs:
41
+ n_layer: int
42
+ n_embd: int
43
+ ctx_len: int
44
+ vocab_size: int = 65536
45
+ n_head: int = 16 # Number of attention heads
46
+ n_att: int = 1024 # Attention dimension
47
 
48
  def download_file(url, filename):
49
  """Generic file downloader with progress bar"""
 
69
  url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
70
  download_file(url, model_name)
71
 
72
+ class CustomPipeline(PIPELINE):
73
+ def __init__(self, model, vocab_file):
74
+ super().__init__(model, vocab_file)
75
+ self.model_args = None
76
+
77
+ def set_model_args(self, args: ModelArgs):
78
+ self.model_args = args
79
 
80
  class ModelManager:
81
  def __init__(self):
82
  self.current_model = None
83
  self.current_model_name = None
84
  self.pipeline = None
 
85
 
86
+ def load_model(self, model_choice):
87
+ model_file = MODELS[model_choice]
88
+ if model_file != self.current_model_name:
89
+ download_model(model_file)
90
+
91
+ # Get model configuration
92
+ config = MODEL_CONFIGS[model_file]
93
+ model_args = ModelArgs(
94
+ n_layer=config['n_layer'],
95
+ n_embd=config['n_embd'],
96
+ ctx_len=config['ctx_len']
97
+ )
98
+
99
+ # Initialize model with args
100
  self.current_model = RWKV(
101
+ model=model_file,
102
  strategy='cpu fp32'
103
  )
104
+
105
+ # Initialize custom pipeline
106
+ self.pipeline = CustomPipeline(self.current_model, "20B_tokenizer.json")
107
+ self.pipeline.set_model_args(model_args)
108
+ self.current_model_name = model_file
109
+
110
  return self.pipeline
111
 
112
  model_manager = ModelManager()
 
143
  alpha_decay=alpha_decay,
144
  token_ban=[],
145
  token_stop=[],
146
+ chunk_len=256,
147
+ model_args=pipeline.model_args # Pass model args to pipeline
148
  )
149
 
150
  # Generate response