Daemontatox committed on
Commit
1bc3aec
·
verified ·
1 Parent(s): 6d7af6b

Update app.py

Files changed (1):
1. app.py (+15, -90)
app.py CHANGED
@@ -7,80 +7,15 @@ from datetime import datetime
 import gradio as gr
 from huggingface_hub import hf_hub_download
 
-# Check for CUDA availability
-cuda_available = torch.cuda.is_available()
-device = torch.device("cuda" if cuda_available else "cpu")
-print(f"Using device: {device}")
-
-# Function to set up CUDA environment if available
-def setup_cuda_environment():
-    if not cuda_available:
-        print("CUDA not available, falling back to CPU")
-        os.environ["RWKV_V7_ON"] = '1'
-        os.environ["RWKV_JIT_ON"] = '1'
-        os.environ["RWKV_CUDA_ON"] = '0'
-        return False
-
-    print("CUDA is available, setting up environment")
-
-    # Try to detect CUDA location automatically
-    possible_cuda_paths = [
-        "/usr/local/cuda",
-        "/opt/cuda",
-        "/usr/lib/cuda",
-        "/usr/cuda",
-        "/usr/local/nvidia/cuda",
-        "/usr/lib/nvidia-cuda-toolkit",
-        "/usr/lib/x86_64-linux-gnu/cuda"
-    ]
-
-    cuda_found = False
-    for path in possible_cuda_paths:
-        if os.path.exists(path):
-            os.environ["CUDA_HOME"] = path
-            print(f"Found CUDA at: {path}")
-            cuda_found = True
-            break
-
-    if not cuda_found:
-        # If we can't find the CUDA path but CUDA is available,
-        # try looking for common libraries
-        try:
-            import ctypes
-            cuda_runtime = ctypes.cdll.LoadLibrary("libcudart.so")
-            print("Found CUDA runtime library, proceeding without explicit CUDA_HOME")
-            cuda_found = True
-        except:
-            print("Could not locate CUDA runtime library")
-
-    # Set RWKV environment variables
-    if cuda_found:
-        os.environ["RWKV_V7_ON"] = '1'
-        os.environ["RWKV_JIT_ON"] = '1'
-        os.environ["RWKV_CUDA_ON"] = '1'
-    else:
-        print("CUDA is available but environment couldn't be set up correctly, falling back to CPU")
-        os.environ["RWKV_V7_ON"] = '1'
-        os.environ["RWKV_JIT_ON"] = '1'
-        os.environ["RWKV_CUDA_ON"] = '0'
-        return False
-
-    return cuda_found
-
-# Initialize NVML for GPU monitoring if available
-has_nvml = False
-if cuda_available:
-    try:
-        from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
-        nvmlInit()
-        gpu_h = nvmlDeviceGetHandleByIndex(0)
-        has_nvml = True
-        print("NVML initialized for GPU monitoring")
-    except:
-        print("NVML not available, GPU monitoring disabled")
-
-# Set up CUDA environment
-use_cuda = setup_cuda_environment()
+# Force CPU mode as requested
+use_cuda = False
+device = torch.device("cpu")
+print(f"Using device: {device} (forced CPU mode)")
+
+# Set RWKV environment variables for CPU
+os.environ["RWKV_V7_ON"] = '1'
+os.environ["RWKV_JIT_ON"] = '1'
+os.environ["RWKV_CUDA_ON"] = '0'
 
 # Model parameters
 ctx_limit = 4000
@@ -114,11 +49,11 @@ try:
     model_path_v6 = hf_hub_download(repo_id="BlinkDL/rwkv7-g1", filename=f"{title_v6}.pth")
     print(f"Model downloaded to {model_path_v6}")
 
-    # Select strategy based on available hardware
-    strategy = 'cuda fp16' if use_cuda else 'cpu fp32'
+    # Use CPU strategy
+    strategy = 'cpu fp32'
     print(f"Using strategy: {strategy}")
 
-    # Initialize model with appropriate strategy
+    # Initialize model with CPU strategy
     model_v6 = RWKV(model=model_path_v6.replace('.pth',''), strategy=strategy)
     pipeline_v6 = PIPELINE(model_v6, "rwkv_vocab_v20230424")
     args = model_v6.args
@@ -132,7 +67,7 @@ except Exception as e:
 
 # Text generation parameters
 penalty_decay = 0.996
-@spaces.GPU
+
 def generate_prompt(instruction, input=""):
     instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
     input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
@@ -199,21 +134,10 @@ def evaluate(
                 yield out_str.strip()
                 out_last = i + 1
 
-        # Log GPU info if available
-        if use_cuda and has_nvml:
-            try:
-                gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
-                timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-                print(f'{timestamp} - vram total: {gpu_info.total/1024**2:.2f}MB, used: {gpu_info.used/1024**2:.2f}MB, free: {gpu_info.free/1024**2:.2f}MB')
-            except:
-                print("Error getting GPU info")
-
         # Clean up to free memory
         del out
         del state
         gc.collect()
-        if use_cuda:
-            torch.cuda.empty_cache()
 
         yield out_str.strip()
     except Exception as e:
@@ -242,7 +166,7 @@ examples = [
 # Create Gradio UI
 with gr.Blocks(title=title_v6) as demo:
     model_status = "✅ Model loaded successfully" if model_loaded else "❌ Model failed to load"
-    device_status = f"Using {'CUDA' if use_cuda else 'CPU'}"
+    device_status = "Using CPU mode"
 
     gr.HTML(f"<div style=\"text-align: center;\">\n<h1>{title_v6}</h1>\n<p>{model_status} - {device_status}</p>\n</div>")
 
@@ -268,5 +192,6 @@ with gr.Blocks(title=title_v6) as demo:
 
 # Launch the app
 print("Starting Gradio app...")
-demo.queue(concurrency_count=1, max_size=10)
+# Fix the queue method call by removing the incorrect parameter
+demo.queue(max_size=10)
 demo.launch(share=False)
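
A note on the new CPU-only block in the first hunk: the standard rwkv pip package reads these flags from the environment when rwkv.model is first imported, so they must be set before that import runs (the same ordering the old setup_cuda_environment() relied on). A minimal sketch of the intended ordering, assuming the standard rwkv package; the model path below is a placeholder, not the Space's actual value:

    import os

    # The rwkv package reads these flags at import time, so set them first.
    os.environ["RWKV_V7_ON"] = '1'    # enable the RWKV-7 ("Goose") architecture path
    os.environ["RWKV_JIT_ON"] = '1'   # use the TorchScript JIT kernels
    os.environ["RWKV_CUDA_ON"] = '0'  # '0' = skip the custom CUDA kernel; safe on CPU

    from rwkv.model import RWKV
    from rwkv.utils import PIPELINE

    # Placeholder path; the real app resolves the weights via hf_hub_download.
    model = RWKV(model="rwkv7-g1", strategy="cpu fp32")
    pipeline = PIPELINE(model, "rwkv_vocab_v20230424")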
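
On the demo.queue() change in the last hunk: concurrency_count is a Gradio 3.x keyword that was removed in Gradio 4.x, which is presumably why the call needed fixing. In 4.x, per-event concurrency is governed by default_concurrency_limit, which already defaults to 1, so demo.queue(max_size=10) keeps the old one-request-at-a-time behavior. A minimal sketch assuming Gradio 4.x; the stand-in UI below is not the Space's actual interface:

    import gradio as gr

    with gr.Blocks() as demo:
        gr.Markdown("stand-in UI")  # the real app builds the full generation interface

    # concurrency_count was removed from queue() in Gradio 4.x; the explicit
    # equivalent of the old concurrency_count=1 is default_concurrency_limit=1.
    demo.queue(max_size=10, default_concurrency_limit=1)
    demo.launch(share=False)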