# RWKV-7 / app.py
import os

# The rwkv library reads these flags at import time, so they must be set
# before importing rwkv.model
os.environ['RWKV_JIT_ON'] = '1'    # enable TorchScript JIT
os.environ["RWKV_CUDA_ON"] = '0'   # CPU-only demo: skip the custom CUDA kernel
os.environ["RWKV_V7_ON"] = '1'     # enable the RWKV-7 architecture

import gradio as gr
import torch
import requests
from tqdm import tqdm
from dataclasses import dataclass
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
# Model options
MODELS = {
"0.1B (Smaller)": "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth",
"0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
}
# Model configurations
MODEL_CONFIGS = {
"RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth": {
"n_layer": 12,
"n_embd": 768,
"ctx_len": 4096
},
"RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth": {
"n_layer": 24,
"n_embd": 1024,
"ctx_len": 4096
}
}
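# Note: the values above are kept for reference only; the rwkv package infers
# n_layer and n_embd from the checkpoint weights when the model is loaded.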
@dataclass
class ModelArgs:
n_layer: int
n_embd: int
ctx_len: int
vocab_size: int = 65536
n_head: int = 16 # Number of attention heads
n_att: int = 1024 # Attention dimension
def download_file(url, filename):
"""Generic file downloader with progress bar"""
if not os.path.exists(filename):
print(f"Downloading {filename}...")
        response = requests.get(url, stream=True)
        response.raise_for_status()  # fail loudly instead of writing an error page to the .pth file
        total_size = int(response.headers.get('content-length', 0))
with open(filename, 'wb') as file, tqdm(
desc=filename,
total=total_size,
unit='iB',
unit_scale=True,
unit_divisor=1024,
) as pbar:
for data in response.iter_content(chunk_size=1024):
size = file.write(data)
pbar.update(size)
def download_model(model_name):
"""Download model if not present"""
if not os.path.exists(model_name):
url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
download_file(url, model_name)
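# download_model above resolves, e.g., the 0.1B checkpoint to
# https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth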
class CustomPipeline(PIPELINE):
def __init__(self, model, vocab_file):
super().__init__(model, vocab_file)
self.model_args = None
def set_model_args(self, args: ModelArgs):
self.model_args = args
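# CustomPipeline only attaches ModelArgs as metadata; tokenization and sampling
# behaviour is inherited unchanged from rwkv.utils.PIPELINE.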
class ModelManager:
def __init__(self):
self.current_model = None
self.current_model_name = None
self.pipeline = None
def load_model(self, model_choice):
model_file = MODELS[model_choice]
if model_file != self.current_model_name:
download_model(model_file)
# Get model configuration
config = MODEL_CONFIGS[model_file]
model_args = ModelArgs(
n_layer=config['n_layer'],
n_embd=config['n_embd'],
ctx_len=config['ctx_len']
)
            # Initialize the model on CPU in fp32
self.current_model = RWKV(
model=model_file,
strategy='cpu fp32'
)
            # Initialize custom pipeline with the World-model vocabulary bundled
            # in the rwkv package ("20B_tokenizer.json" is the Pile tokenizer and
            # is not downloaded by this script)
            self.pipeline = CustomPipeline(self.current_model, "rwkv_vocab_v20230424")
self.pipeline.set_model_args(model_args)
self.current_model_name = model_file
return self.pipeline
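# A single ModelManager instance caches the loaded model/pipeline across
# requests and only re-downloads/reloads when a different model is selected.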
model_manager = ModelManager()
def generate_response(
model_choice,
user_prompt,
system_prompt,
temperature,
top_p,
top_k,
alpha_frequency,
alpha_presence,
alpha_decay,
max_tokens
):
try:
# Get or load the model
pipeline = model_manager.load_model(model_choice)
# Prepare the context
if system_prompt.strip():
ctx = f"{system_prompt.strip()}\n\nUser: {user_prompt.strip()}\n\nA:"
else:
ctx = f"User: {user_prompt.strip()}\n\nA:"
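        # Note: RWKV "World" chat checkpoints are usually prompted with
        # "User: ...\n\nAssistant:"; the shorter "A:" tag is this demo's choice.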
        # Prepare generation arguments (PIPELINE_ARGS only takes sampling
        # parameters; model metadata is not part of its signature)
        args = PIPELINE_ARGS(
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
            alpha_frequency=alpha_frequency,
            alpha_presence=alpha_presence,
            alpha_decay=alpha_decay,
            token_ban=[],
            token_stop=[],
            chunk_len=256
        )
# Generate response
        response = ""

        def callback(text):
            # PIPELINE.generate streams decoded text chunks into this callback
            nonlocal response
            response += text

        pipeline.generate(ctx, token_count=int(max_tokens), args=args, callback=callback)
        return response
except Exception as e:
import traceback
return f"Error: {str(e)}\nStack trace: {traceback.format_exc()}"
# Create the Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# RWKV-7 Language Model Demo")
with gr.Row():
with gr.Column():
model_choice = gr.Radio(
choices=list(MODELS.keys()),
value=list(MODELS.keys())[0],
label="Model Selection"
)
system_prompt = gr.Textbox(
label="System Prompt",
placeholder="Optional system prompt to set the context",
lines=3,
value="You are a helpful AI assistant. You provide detailed and accurate responses."
)
user_prompt = gr.Textbox(
label="User Prompt",
placeholder="Enter your prompt here",
lines=3
)
max_tokens = gr.Slider(
minimum=1,
maximum=1000,
value=200,
step=1,
label="Max Tokens"
)
with gr.Column():
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=1.0,
step=0.1,
label="Temperature"
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.7,
step=0.05,
label="Top P"
)
top_k = gr.Slider(
minimum=0,
maximum=200,
value=100,
step=1,
label="Top K"
)
alpha_frequency = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.25,
step=0.05,
label="Alpha Frequency"
)
alpha_presence = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.25,
step=0.05,
label="Alpha Presence"
)
alpha_decay = gr.Slider(
minimum=0.9,
maximum=1.0,
value=0.996,
step=0.001,
label="Alpha Decay"
)
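            # How the alpha parameters act inside rwkv's PIPELINE.generate
            # (roughly): each step, the logit of a previously generated token t
            # is reduced by alpha_presence + occurrence[t] * alpha_frequency,
            # and every occurrence count is multiplied by alpha_decay, so the
            # penalty fades once a token stops reappearing.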
generate_button = gr.Button("Generate")
output = gr.Textbox(label="Generated Response", lines=10)
generate_button.click(
fn=generate_response,
inputs=[
model_choice,
user_prompt,
system_prompt,
temperature,
top_p,
top_k,
alpha_frequency,
alpha_presence,
alpha_decay,
max_tokens
],
outputs=output
)
gr.Markdown("""
## Model Information
- **0.1B Model**: Smaller model, faster but less capable
- **0.4B Model**: Larger model, slower but more capable
## Parameter Descriptions
- **Temperature**: Controls randomness in the output (higher = more random)
- **Top P**: Nucleus sampling threshold (lower = more focused)
- **Top K**: Limits the number of tokens considered for each step
- **Alpha Frequency**: Penalizes frequent tokens
- **Alpha Presence**: Penalizes tokens that have appeared before
- **Alpha Decay**: Rate at which penalties decay
- **Max Tokens**: Maximum length of generated response
""")
# Launch the demo
if __name__ == "__main__":
demo.launch()
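
# Minimal smoke test without the UI (hypothetical values; uncomment to try):
# print(generate_response(
#     "0.1B (Smaller)", "What is RWKV?", "",  # model, user prompt, system prompt
#     1.0, 0.7, 100, 0.25, 0.25, 0.996, 50))  # temperature, top_p, top_k, alphas, max tokens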