Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,7 @@ from copy import deepcopy
|
|
7 |
import requests
|
8 |
import os.path
|
9 |
from tqdm import tqdm
|
10 |
-
|
11 |
|
12 |
# Set environment variables
|
13 |
os.environ['RWKV_JIT_ON'] = '1'
|
@@ -20,27 +20,19 @@ MODELS = {
|
|
20 |
"0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
|
21 |
}
|
22 |
|
23 |
-
#
|
24 |
-
TOKENIZER_FILE = "
|
25 |
-
TOKENIZER_URL = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/
|
26 |
-
|
27 |
-
def download_tokenizer():
|
28 |
-
if not os.path.exists(TOKENIZER_FILE):
|
29 |
-
print("Downloading tokenizer...")
|
30 |
-
response = requests.get(TOKENIZER_URL)
|
31 |
-
with open(TOKENIZER_FILE, 'wb') as f:
|
32 |
-
f.write(response.content)
|
33 |
|
34 |
-
def download_model(model_name):
|
35 |
-
"""
|
36 |
-
if not os.path.exists(model_name):
|
37 |
-
print(f"Downloading {model_name}...")
|
38 |
-
url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
|
39 |
response = requests.get(url, stream=True)
|
40 |
total_size = int(response.headers.get('content-length', 0))
|
41 |
|
42 |
-
with open(model_name, 'wb') as file, tqdm(
|
43 |
-
desc=model_name,
|
44 |
total=total_size,
|
45 |
unit='iB',
|
46 |
unit_scale=True,
|
@@ -50,11 +42,22 @@ def download_model(model_name):
|
|
50 |
size = file.write(data)
|
51 |
pbar.update(size)
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
class ModelManager:
|
54 |
def __init__(self):
|
55 |
self.current_model = None
|
56 |
self.current_model_name = None
|
57 |
self.pipeline = None
|
|
|
58 |
|
59 |
def load_model(self, model_name):
|
60 |
if model_name != self.current_model_name:
|
@@ -67,7 +70,6 @@ class ModelManager:
|
|
67 |
self.current_model_name = model_name
|
68 |
return self.pipeline
|
69 |
|
70 |
-
|
71 |
model_manager = ModelManager()
|
72 |
|
73 |
def generate_response(
|
@@ -115,7 +117,8 @@ def generate_response(
|
|
115 |
pipeline.generate(ctx, token_count=max_tokens, args=args, callback=callback)
|
116 |
return response
|
117 |
except Exception as e:
|
118 |
-
|
|
|
119 |
|
120 |
# Create the Gradio interface
|
121 |
with gr.Blocks() as demo:
|
@@ -228,4 +231,4 @@ with gr.Blocks() as demo:
|
|
228 |
|
229 |
# Launch the demo
|
230 |
if __name__ == "__main__":
|
231 |
-
demo.launch(
|
|
|
7 |
import requests
|
8 |
import os.path
|
9 |
from tqdm import tqdm
|
10 |
+
import json
|
11 |
|
12 |
# Set environment variables
|
13 |
os.environ['RWKV_JIT_ON'] = '1'
|
|
|
20 |
"0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
|
21 |
}
|
22 |
|
23 |
+
# Tokenizer settings
|
24 |
+
TOKENIZER_FILE = "20B_tokenizer.json"
|
25 |
+
TOKENIZER_URL = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/20B_tokenizer.json"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
+
def download_file(url, filename):
|
28 |
+
"""Generic file downloader with progress bar"""
|
29 |
+
if not os.path.exists(filename):
|
30 |
+
print(f"Downloading {filename}...")
|
|
|
31 |
response = requests.get(url, stream=True)
|
32 |
total_size = int(response.headers.get('content-length', 0))
|
33 |
|
34 |
+
with open(filename, 'wb') as file, tqdm(
|
35 |
+
desc=filename,
|
36 |
total=total_size,
|
37 |
unit='iB',
|
38 |
unit_scale=True,
|
|
|
42 |
size = file.write(data)
|
43 |
pbar.update(size)
|
44 |
|
45 |
+
def download_model(model_name):
|
46 |
+
"""Download model if not present"""
|
47 |
+
if not os.path.exists(model_name):
|
48 |
+
url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
|
49 |
+
download_file(url, model_name)
|
50 |
+
|
51 |
+
def ensure_tokenizer():
|
52 |
+
"""Ensure tokenizer is present"""
|
53 |
+
download_file(TOKENIZER_URL, TOKENIZER_FILE)
|
54 |
+
|
55 |
class ModelManager:
|
56 |
def __init__(self):
|
57 |
self.current_model = None
|
58 |
self.current_model_name = None
|
59 |
self.pipeline = None
|
60 |
+
ensure_tokenizer()
|
61 |
|
62 |
def load_model(self, model_name):
|
63 |
if model_name != self.current_model_name:
|
|
|
70 |
self.current_model_name = model_name
|
71 |
return self.pipeline
|
72 |
|
|
|
73 |
model_manager = ModelManager()
|
74 |
|
75 |
def generate_response(
|
|
|
117 |
pipeline.generate(ctx, token_count=max_tokens, args=args, callback=callback)
|
118 |
return response
|
119 |
except Exception as e:
|
120 |
+
import traceback
|
121 |
+
return f"Error: {str(e)}\nStack trace: {traceback.format_exc()}"
|
122 |
|
123 |
# Create the Gradio interface
|
124 |
with gr.Blocks() as demo:
|
|
|
231 |
|
232 |
# Launch the demo
|
233 |
if __name__ == "__main__":
|
234 |
+
demo.launch()
|