Daemontatox committed
Commit 6d7af6b · verified · 1 Parent(s): 7e59c62

Update app.py

Files changed (1)
  1. app.py +198 -76
app.py CHANGED
@@ -1,39 +1,138 @@
-
- import os, copy
- os.environ["RWKV_V7_ON"] = '1'
- os.environ["RWKV_JIT_ON"] = '1'
- os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
-
- from rwkv.model import RWKV
-
- import gc, re
- import gradio as gr
- import base64
- from io import BytesIO
+ import spaces
+ import os, copy, gc, re, sys
+ import traceback
  import torch
  import torch.nn.functional as F
  from datetime import datetime
+ import gradio as gr
  from huggingface_hub import hf_hub_download
- from pynvml import *
- nvmlInit()
- gpu_h = nvmlDeviceGetHandleByIndex(0)
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- ctx_limit = 4000
- gen_limit = 32000
+ # Check for CUDA availability
+ cuda_available = torch.cuda.is_available()
+ device = torch.device("cuda" if cuda_available else "cpu")
+ print(f"Using device: {device}")
+
+ # Function to set up CUDA environment if available
+ def setup_cuda_environment():
+     if not cuda_available:
+         print("CUDA not available, falling back to CPU")
+         os.environ["RWKV_V7_ON"] = '1'
+         os.environ["RWKV_JIT_ON"] = '1'
+         os.environ["RWKV_CUDA_ON"] = '0'
+         return False
+
+     print("CUDA is available, setting up environment")
+
+     # Try to detect CUDA location automatically
+     possible_cuda_paths = [
+         "/usr/local/cuda",
+         "/opt/cuda",
+         "/usr/lib/cuda",
+         "/usr/cuda",
+         "/usr/local/nvidia/cuda",
+         "/usr/lib/nvidia-cuda-toolkit",
+         "/usr/lib/x86_64-linux-gnu/cuda"
+     ]
+
+     cuda_found = False
+     for path in possible_cuda_paths:
+         if os.path.exists(path):
+             os.environ["CUDA_HOME"] = path
+             print(f"Found CUDA at: {path}")
+             cuda_found = True
+             break
+
+     if not cuda_found:
+         # If we can't find the CUDA path but CUDA is available,
+         # try looking for common libraries
+         try:
+             import ctypes
+             cuda_runtime = ctypes.cdll.LoadLibrary("libcudart.so")
+             print("Found CUDA runtime library, proceeding without explicit CUDA_HOME")
+             cuda_found = True
+         except:
+             print("Could not locate CUDA runtime library")
+
+     # Set RWKV environment variables
+     if cuda_found:
+         os.environ["RWKV_V7_ON"] = '1'
+         os.environ["RWKV_JIT_ON"] = '1'
+         os.environ["RWKV_CUDA_ON"] = '1'
+     else:
+         print("CUDA is available but environment couldn't be set up correctly, falling back to CPU")
+         os.environ["RWKV_V7_ON"] = '1'
+         os.environ["RWKV_JIT_ON"] = '1'
+         os.environ["RWKV_CUDA_ON"] = '0'
+         return False
+
+     return cuda_found

- ########################## text rwkv ################################################################
- from rwkv.utils import PIPELINE, PIPELINE_ARGS
+ # Initialize NVML for GPU monitoring if available
+ has_nvml = False
+ if cuda_available:
+     try:
+         from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
+         nvmlInit()
+         gpu_h = nvmlDeviceGetHandleByIndex(0)
+         has_nvml = True
+         print("NVML initialized for GPU monitoring")
+     except:
+         print("NVML not available, GPU monitoring disabled")

+ # Set up CUDA environment
+ use_cuda = setup_cuda_environment()
+
+ # Model parameters
+ ctx_limit = 4000
+ gen_limit = 32000
  title_v6 = "rwkv7-g1-0.1b-20250307-ctx4096"
- model_path_v6 = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title_v6}.pth")
- model_v6 = RWKV(model=model_path_v6.replace('.pth',''), strategy='cuda fp16')
- pipeline_v6 = PIPELINE(model_v6, "rwkv_vocab_v20230424")

- args = model_v6.args
+ # Load RWKV with fallback mechanisms
+ try:
+     # First try importing normally
+     from rwkv.model import RWKV
+     from rwkv.utils import PIPELINE, PIPELINE_ARGS
+     print("RWKV imported successfully")
+ except Exception as e:
+     print(f"Error importing RWKV: {e}")
+     print("Attempting fallback import method...")
+
+     # Fallback method - reinstall the package
+     try:
+         import subprocess
+         subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", "rwkv"])
+         from rwkv.model import RWKV
+         from rwkv.utils import PIPELINE, PIPELINE_ARGS
+         print("RWKV imported after reinstall")
+     except Exception as e:
+         print(f"Failed to import RWKV after reinstall: {e}")
+         raise

- penalty_decay = 0.996
+ # Download and initialize the model
+ try:
+     print(f"Downloading model {title_v6}...")
+     model_path_v6 = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title_v6}.pth")
+     print(f"Model downloaded to {model_path_v6}")
+
+     # Select strategy based on available hardware
+     strategy = 'cuda fp16' if use_cuda else 'cpu fp32'
+     print(f"Using strategy: {strategy}")
+
+     # Initialize model with appropriate strategy
+     model_v6 = RWKV(model=model_path_v6.replace('.pth',''), strategy=strategy)
+     pipeline_v6 = PIPELINE(model_v6, "rwkv_vocab_v20230424")
+     args = model_v6.args
+     print("Model initialized successfully")
+
+     model_loaded = True
+ except Exception as e:
+     print(f"Error loading model: {e}")
+     traceback.print_exc()
+     model_loaded = False

+ # Text generation parameters
+ penalty_decay = 0.996
+ @spaces.GPU
  def generate_prompt(instruction, input=""):
      instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
      input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
@@ -55,57 +154,74 @@ def evaluate(
      presencePenalty = 0.1,
      countPenalty = 0.1,
  ):
-     args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
-                      alpha_frequency = countPenalty,
-                      alpha_presence = presencePenalty,
-                      token_ban = [], # ban the generation of some tokens
-                      token_stop = [0]) # stop generation whenever you see any token here
-     ctx = ctx.strip()
-     all_tokens = []
-     out_last = 0
-     out_str = ''
-     occurrence = {}
-     state = None
-     for i in range(int(token_count)):
-
-         input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
-         out, state = model_v6.forward(input_ids, state)
-         for n in occurrence:
-             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
-
-         token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
-         if token in args.token_stop:
-             break
-         all_tokens += [token]
-         for xxx in occurrence:
-             occurrence[xxx] *= penalty_decay
-
-         ttt = pipeline_v6.decode([token])
-         www = 1
-         if ttt in ' \t0123456789':
-             www = 0
-         #elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
-         #    www = 0.5
-         if token not in occurrence:
-             occurrence[token] = www
-         else:
-             occurrence[token] += www
-
-         tmp = pipeline_v6.decode(all_tokens[out_last:])
-         if '\ufffd' not in tmp:
-             out_str += tmp
-             yield out_str.strip()
-             out_last = i + 1
-
-     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
-     timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-     print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
-     del out
-     del state
-     gc.collect()
-     torch.cuda.empty_cache()
-     yield out_str.strip()
+     if not model_loaded:
+         yield "Error: Model failed to load. Please check logs for details."
+         return
+
+     try:
+         args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
+                          alpha_frequency = countPenalty,
+                          alpha_presence = presencePenalty,
+                          token_ban = [], # ban the generation of some tokens
+                          token_stop = [0]) # stop generation whenever you see any token here
+         ctx = ctx.strip()
+         all_tokens = []
+         out_last = 0
+         out_str = ''
+         occurrence = {}
+         state = None
+         for i in range(int(token_count)):
+
+             input_ids = pipeline_v6.encode(ctx)[-ctx_limit:] if i == 0 else [token]
+             out, state = model_v6.forward(input_ids, state)
+             for n in occurrence:
+                 out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+             token = pipeline_v6.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+             if token in args.token_stop:
+                 break
+             all_tokens += [token]
+             for xxx in occurrence:
+                 occurrence[xxx] *= penalty_decay
+
+             ttt = pipeline_v6.decode([token])
+             www = 1
+             if ttt in ' \t0123456789':
+                 www = 0
+             if token not in occurrence:
+                 occurrence[token] = www
+             else:
+                 occurrence[token] += www
+
+             tmp = pipeline_v6.decode(all_tokens[out_last:])
+             if '\ufffd' not in tmp:
+                 out_str += tmp
+                 yield out_str.strip()
+                 out_last = i + 1
+
+         # Log GPU info if available
+         if use_cuda and has_nvml:
+             try:
+                 gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+                 timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+                 print(f'{timestamp} - vram total: {gpu_info.total/1024**2:.2f}MB, used: {gpu_info.used/1024**2:.2f}MB, free: {gpu_info.free/1024**2:.2f}MB')
+             except:
+                 print("Error getting GPU info")
+
+         # Clean up to free memory
+         del out
+         del state
+         gc.collect()
+         if use_cuda:
+             torch.cuda.empty_cache()
+
+         yield out_str.strip()
+     except Exception as e:
+         print(f"Error during generation: {e}")
+         traceback.print_exc()
+         yield f"Error during generation: {str(e)}"

+ # Example prompts
  examples = [
      ["User: simulate SpaceX mars landing using python\n\nAssistant: <think", gen_limit, 1, 0.3, 0.5, 0.5],
      [generate_prompt("Please give the pros and cons of hodl versus active trading."), gen_limit, 1, 0.3, 0.5, 0.5],
@@ -119,12 +235,16 @@ examples = [
      ["En una pequeña aldea escondida entre las montañas de Andalucía, donde las calles aún conservaban el eco de antiguas leyendas, vivía un joven llamado Alejandro.", gen_limit, 1, 0.3, 0.5, 0.5],
      ["Dans le cœur battant de Paris, sous le ciel teinté d'un crépuscule d'or et de pourpre, se tenait une petite librairie oubliée par le temps.", gen_limit, 1, 0.3, 0.5, 0.5],
      ["في تطور مذهل وغير مسبوق، أعلنت السلطات المحلية في العاصمة عن اكتشاف أثري قد يغير مجرى التاريخ كما نعرفه.", gen_limit, 1, 0.3, 0.5, 0.5],
-     ['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。\n但愿大宇宙能够忽略这个误差。\n程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。\n“放心,我在,你们就在!”智子对两位人类朋友说。\n聚变发动机启动了,推进器发出幽幽的蓝光,''', gen_limit, 1, 0.3, 0.5, 0.5],
+     ['''"当然可以,大宇宙不会因为这五公斤就不坍缩了。"关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。\n但愿大宇宙能够忽略这个误差。\n程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。\n"放心,我在,你们就在!"智子对两位人类朋友说。\n聚变发动机启动了,推进器发出幽幽的蓝光,''', gen_limit, 1, 0.3, 0.5, 0.5],
  ]

  ##################################################################################################################
+ # Create Gradio UI
  with gr.Blocks(title=title_v6) as demo:
-     gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title_v6}</h1>\n</div>")
+     model_status = "✅ Model loaded successfully" if model_loaded else "❌ Model failed to load"
+     device_status = f"Using {'CUDA' if use_cuda else 'CPU'}"
+
+     gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title_v6}</h1>\n<p>{model_status} - {device_status}</p>\n</div>")

      with gr.Tab("=== Base Model (Raw Generation) ==="):
          gr.Markdown(f'This is [RWKV7 G1](https://huggingface.co/BlinkDL/rwkv7-g1) 0.1B (!!!) L12-D768 reasoning base LM - an attention-free pure RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Supports 100+ world languages and code. Check [400+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Can try examples (bottom of page) *** (can edit them). Demo limited to ctxlen {ctx_limit}.')
@@ -146,5 +266,7 @@ with gr.Blocks(title=title_v6) as demo:
          clear.click(lambda: None, [], [output])
          data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])

+ # Launch the app
+ print("Starting Gradio app...")
  demo.queue(concurrency_count=1, max_size=10)
  demo.launch(share=False)
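
Context for reading the diff: `evaluate` is a Python generator, so every `yield out_str.strip()` streams a progressively longer completion to the UI, and the new `@spaces.GPU` decorator comes from the Hugging Face `spaces` package used on ZeroGPU Spaces, which attaches a GPU only for the duration of the decorated call. The Gradio wiring that consumes the generator is unchanged context and is not shown in this diff, so the sketch below is an illustration only; the component names and defaults in it are assumptions, not part of this commit.

```python
import gradio as gr

# Hypothetical stand-in for evaluate(): a generator that yields partial strings.
# Gradio streams each yielded value into the bound output component.
def stream_tokens(prompt, max_tokens):
    text = ""
    for i in range(int(max_tokens)):
        text += f" token{i}"   # placeholder for real sampled tokens
        yield text.strip()

with gr.Blocks() as demo:
    prompt = gr.Textbox(lines=4, label="Prompt")
    max_tokens = gr.Slider(1, 100, value=20, step=1, label="Max tokens")
    output = gr.Textbox(lines=10, label="Output")
    submit = gr.Button("Generate")
    # Binding a generator function makes the output update on every yield.
    submit.click(stream_tokens, [prompt, max_tokens], [output])

demo.queue(max_size=10)
demo.launch()
```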