Tonic committed
Commit 612202d · verified · 1 Parent(s): 50a3242

Create app.py

Files changed (1)
  1. app.py +213 -0
app.py ADDED
@@ -0,0 +1,213 @@
import os
import requests
import torch
import gradio as gr
from tqdm import tqdm  # progress bar for the model download
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

# Set environment variables before the model is loaded
os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '0'  # CPU-only inference

# Model options
MODELS = {
    "0.1B (Smaller)": "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth",
    "0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
}

def download_model(model_name):
    """Download the model weights from Hugging Face if not already present."""
    if not os.path.exists(model_name):
        print(f"Downloading {model_name}...")
        url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
        response = requests.get(url, stream=True)
        response.raise_for_status()  # fail early rather than writing an error page to disk
        total_size = int(response.headers.get('content-length', 0))

        with open(model_name, 'wb') as file, tqdm(
            desc=model_name,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                pbar.update(size)

class ModelManager:
    """Caches the currently loaded model so switching in the UI is cheap."""
    def __init__(self):
        self.current_model = None
        self.current_model_name = None
        self.pipeline = None

    def load_model(self, model_name):
        # Only download and reload when a different model is selected
        if model_name != self.current_model_name:
            download_model(MODELS[model_name])
            self.current_model = RWKV(model=MODELS[model_name], strategy='cpu fp32')
            self.pipeline = PIPELINE(self.current_model, "rwkv_vocab_v20230424")
            self.current_model_name = model_name
        return self.pipeline

model_manager = ModelManager()

def generate_response(
    model_choice,
    user_prompt,
    system_prompt,
    temperature,
    top_p,
    top_k,
    alpha_frequency,
    alpha_presence,
    alpha_decay,
    max_tokens
):
    try:
        # Get or load the model
        pipeline = model_manager.load_model(model_choice)

        # Prepare the context
        if system_prompt.strip():
            ctx = f"{system_prompt.strip()}\n\nUser: {user_prompt.strip()}\n\nA:"
        else:
            ctx = f"User: {user_prompt.strip()}\n\nA:"

        # Prepare generation arguments
        args = PIPELINE_ARGS(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            alpha_frequency=alpha_frequency,
            alpha_presence=alpha_presence,
            alpha_decay=alpha_decay,
            token_ban=[],
            token_stop=[],
            chunk_len=256
        )

        # Generate the response, accumulating streamed chunks via the callback
        response = ""
        def callback(text):
            nonlocal response
            response += text

        pipeline.generate(ctx, token_count=max_tokens, args=args, callback=callback)
        return response
    except Exception as e:
        return f"Error: {str(e)}"

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# RWKV-7 Language Model Demo")

    with gr.Row():
        with gr.Column():
            model_choice = gr.Radio(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model Selection"
            )
            system_prompt = gr.Textbox(
                label="System Prompt",
                placeholder="Optional system prompt to set the context",
                lines=3,
                value="You are a helpful AI assistant. You provide detailed and accurate responses."
            )
            user_prompt = gr.Textbox(
                label="User Prompt",
                placeholder="Enter your prompt here",
                lines=3
            )
            max_tokens = gr.Slider(
                minimum=1,
                maximum=1000,
                value=200,
                step=1,
                label="Max Tokens"
            )

        with gr.Column():
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Temperature"
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.7,
                step=0.05,
                label="Top P"
            )
            top_k = gr.Slider(
                minimum=0,
                maximum=200,
                value=100,
                step=1,
                label="Top K"
            )
            alpha_frequency = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                label="Alpha Frequency"
            )
            alpha_presence = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                label="Alpha Presence"
            )
            alpha_decay = gr.Slider(
                minimum=0.9,
                maximum=1.0,
                value=0.996,
                step=0.001,
                label="Alpha Decay"
            )

    generate_button = gr.Button("Generate")
    output = gr.Textbox(label="Generated Response", lines=10)

    generate_button.click(
        fn=generate_response,
        inputs=[
            model_choice,
            user_prompt,
            system_prompt,
            temperature,
            top_p,
            top_k,
            alpha_frequency,
            alpha_presence,
            alpha_decay,
            max_tokens
        ],
        outputs=output
    )

    gr.Markdown("""
    ## Model Information
    - **0.1B Model**: Smaller model, faster but less capable
    - **0.4B Model**: Larger model, slower but more capable

    ## Parameter Descriptions
    - **Temperature**: Controls randomness in the output (higher = more random)
    - **Top P**: Nucleus sampling threshold (lower = more focused)
    - **Top K**: Limits the number of tokens considered at each step
    - **Alpha Frequency**: Penalizes tokens in proportion to how often they have already appeared
    - **Alpha Presence**: Penalizes any token that has appeared at least once
    - **Alpha Decay**: Rate at which the repetition penalties decay
    - **Max Tokens**: Maximum length of the generated response
    """)

# Launch the demo
if __name__ == "__main__":
    demo.launch(ssr_mode=False)
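
The click handler above can also be exercised without the UI. A minimal sketch, assuming the file is saved as app.py and is importable from the working directory (importing it builds the Blocks but does not start the server, thanks to the __main__ guard); the prompt is a made-up example, the other values mirror the UI defaults, and the first call downloads the selected checkpoint:

from app import generate_response

reply = generate_response(
    model_choice="0.1B (Smaller)",  # a key of the MODELS dict
    user_prompt="What is the capital of France?",  # hypothetical example prompt
    system_prompt="",  # empty string skips the system preamble
    temperature=1.0,
    top_p=0.7,
    top_k=100,
    alpha_frequency=0.25,
    alpha_presence=0.25,
    alpha_decay=0.996,
    max_tokens=200,
)
print(reply)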