import os

# The rwkv library reads these flags when rwkv.model is imported,
# so they must be set before the imports below.
os.environ['RWKV_JIT_ON'] = '1'   # enable the JIT-compiled kernels
os.environ['RWKV_CUDA_ON'] = '0'  # set to '1' to compile the CUDA kernel (GPU only)
os.environ['RWKV_V7_ON'] = '1'    # enable the RWKV-7 architecture

import gradio as gr
import requests
from tqdm import tqdm
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
# Model options
MODELS = {
"0.1B (Smaller)": "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth",
"0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
}
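# Checkpoints are fetched from https://huggingface.co/BlinkDL/rwkv-7-world
# on first use and cached in the working directory.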
def download_model(model_name):
"""Download model if not present"""
if not os.path.exists(model_name):
print(f"Downloading {model_name}...")
url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()  # fail fast instead of saving an HTML error page as weights
total_size = int(response.headers.get('content-length', 0))
with open(model_name, 'wb') as file, tqdm(
desc=model_name,
total=total_size,
unit='iB',
unit_scale=True,
unit_divisor=1024,
) as pbar:
for data in response.iter_content(chunk_size=1024):
size = file.write(data)
pbar.update(size)
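# Cache a single loaded model so repeated prompts reuse it; weights are
# reloaded only when the user switches to a different checkpoint.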
class ModelManager:
def __init__(self):
self.current_model = None
self.current_model_name = None
self.pipeline = None
    def load_model(self, model_name):
        """Return a PIPELINE for model_name, loading weights only on a model switch."""
        if model_name != self.current_model_name:
            download_model(MODELS[model_name])
            # CPU fp32 keeps the demo runnable on CPU-only hardware;
            # 'cuda fp16' is the usual choice when a GPU is available
            self.current_model = RWKV(model=MODELS[model_name], strategy='cpu fp32')
self.pipeline = PIPELINE(self.current_model, "rwkv_vocab_v20230424")
self.current_model_name = model_name
return self.pipeline
model_manager = ModelManager()
def generate_response(
model_choice,
user_prompt,
system_prompt,
temperature,
top_p,
top_k,
alpha_frequency,
alpha_presence,
alpha_decay,
max_tokens
):
try:
# Get or load the model
pipeline = model_manager.load_model(model_choice)
# Prepare the context
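        # Note: RWKV "World" checkpoints are commonly prompted in a
        # "User: ...\n\nAssistant:" dialogue format; the short "A:" tag
        # used here works as a generic Q/A-style answer cue.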
if system_prompt.strip():
ctx = f"{system_prompt.strip()}\n\nUser: {user_prompt.strip()}\n\nA:"
else:
ctx = f"User: {user_prompt.strip()}\n\nA:"
# Prepare generation arguments
args = PIPELINE_ARGS(
temperature=temperature,
top_p=top_p,
top_k=top_k,
alpha_frequency=alpha_frequency,
alpha_presence=alpha_presence,
alpha_decay=alpha_decay,
            token_ban=[],    # token ids that may never be sampled
            token_stop=[],   # token ids that end generation early ([0] is commonly used for end-of-text)
            chunk_len=256    # feed the prompt to the model in chunks of this many tokens
)
        # Generate the response; PIPELINE.generate returns the full decoded
        # string, so no accumulating callback is needed
        response = pipeline.generate(ctx, token_count=int(max_tokens), args=args)
        return response
except Exception as e:
return f"Error: {str(e)}"
# Create the Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# RWKV-7 Language Model Demo")
with gr.Row():
with gr.Column():
model_choice = gr.Radio(
choices=list(MODELS.keys()),
value=list(MODELS.keys())[0],
label="Model Selection"
)
system_prompt = gr.Textbox(
label="System Prompt",
placeholder="Optional system prompt to set the context",
lines=3,
value="You are a helpful AI assistant. You provide detailed and accurate responses."
)
user_prompt = gr.Textbox(
label="User Prompt",
placeholder="Enter your prompt here",
lines=3
)
max_tokens = gr.Slider(
minimum=1,
maximum=1000,
value=200,
step=1,
label="Max Tokens"
)
with gr.Column():
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=1.0,
step=0.1,
label="Temperature"
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.7,
step=0.05,
label="Top P"
)
top_k = gr.Slider(
minimum=0,
maximum=200,
value=100,
step=1,
label="Top K"
)
alpha_frequency = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.25,
step=0.05,
label="Alpha Frequency"
)
alpha_presence = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.25,
step=0.05,
label="Alpha Presence"
)
alpha_decay = gr.Slider(
minimum=0.9,
maximum=1.0,
value=0.996,
step=0.001,
label="Alpha Decay"
)
generate_button = gr.Button("Generate")
output = gr.Textbox(label="Generated Response", lines=10)
generate_button.click(
fn=generate_response,
inputs=[
model_choice,
user_prompt,
system_prompt,
temperature,
top_p,
top_k,
alpha_frequency,
alpha_presence,
alpha_decay,
max_tokens
],
outputs=output
)
gr.Markdown("""
## Model Information
- **0.1B Model**: Smaller model, faster but less capable
- **0.4B Model**: Larger model, slower but more capable
## Parameter Descriptions
- **Temperature**: Controls randomness of sampling (higher = more varied, lower = more deterministic)
- **Top P**: Nucleus sampling; only the smallest set of tokens whose probabilities sum to P is considered (lower = more focused)
- **Top K**: Limits sampling to the K most likely tokens at each step (0 disables the limit)
- **Alpha Frequency**: Penalizes tokens in proportion to how often they have already appeared
- **Alpha Presence**: Flat penalty on any token that has appeared at least once
- **Alpha Decay**: Per-token multiplier applied to the accumulated penalties (closer to 1.0 = penalties persist longer)
- **Max Tokens**: Maximum number of tokens to generate
""")
# Launch the demo
if __name__ == "__main__":
demo.launch(ssr_mode=False)