# Gradio demo for RWKV-7 "World" models, running on CPU via the rwkv pip package.
import os

# The rwkv runtime reads these flags when it is imported, so set them before importing RWKV:
# JIT kernel on, custom CUDA kernel off (CPU-only inference), RWKV-7 ("x070") support on.
os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '0'
os.environ["RWKV_V7_ON"] = '1'

import gradio as gr
import requests
from dataclasses import dataclass
from tqdm import tqdm

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
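
# Checkpoint filenames as hosted in the BlinkDL/rwkv-7-world repository on Hugging Face;
# download_model() below fetches whichever one is selected on first use.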
MODELS = {
    "0.1B (Smaller)": "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth",
    "0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth",
}

MODEL_CONFIGS = {
    "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth": {
        "n_layer": 12,
        "n_embd": 768,
        "ctx_len": 4096,
    },
    "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth": {
        "n_layer": 24,
        "n_embd": 1024,
        "ctx_len": 4096,
    },
}
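
# Static per-checkpoint hyperparameters. They are attached to the pipeline for reference;
# the RWKV loader itself infers the actual shapes from the .pth weights.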
@dataclass
class ModelArgs:
    n_layer: int
    n_embd: int
    ctx_len: int
    vocab_size: int = 65536
    n_head: int = 16
    n_att: int = 1024

def download_file(url, filename):
    """Download a file with a tqdm progress bar, skipping files that already exist."""
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        response = requests.get(url, stream=True)
        response.raise_for_status()  # fail early on a bad URL instead of saving an error page
        total_size = int(response.headers.get('content-length', 0))

        with open(filename, 'wb') as file, tqdm(
            desc=filename,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                pbar.update(size)

def download_model(model_name):
    """Download model if not present"""
    if not os.path.exists(model_name):
        url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
        download_file(url, model_name)
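
# PIPELINE subclass that also carries the static ModelArgs for the currently loaded checkpoint.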
class CustomPipeline(PIPELINE):
    def __init__(self, model, vocab_file):
        super().__init__(model, vocab_file)
        self.model_args = None

    def set_model_args(self, args: ModelArgs):
        self.model_args = args
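
# Loads models lazily and caches the current one, so switching the radio button only
# reloads when a different checkpoint is selected.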
class ModelManager:
    def __init__(self):
        self.current_model = None
        self.current_model_name = None
        self.pipeline = None

    def load_model(self, model_choice):
        model_file = MODELS[model_choice]
        if model_file != self.current_model_name:
            download_model(model_file)

            config = MODEL_CONFIGS[model_file]
            model_args = ModelArgs(
                n_layer=config['n_layer'],
                n_embd=config['n_embd'],
                ctx_len=config['ctx_len']
            )

            self.current_model = RWKV(
                model=model_file,
                strategy='cpu fp32'
            )

            # World models use the built-in RWKV world tokenizer, not the 20B (Pile) tokenizer.
            self.pipeline = CustomPipeline(self.current_model, "rwkv_vocab_v20230424")
            self.pipeline.set_model_args(model_args)
            self.current_model_name = model_file

        return self.pipeline


model_manager = ModelManager()
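
# Gradio callback: load (or reuse) the selected model, build the prompt from the system
# and user fields, then sample up to max_tokens tokens with the chosen sampling settings.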
def generate_response(
    model_choice,
    user_prompt,
    system_prompt,
    temperature,
    top_p,
    top_k,
    alpha_frequency,
    alpha_presence,
    alpha_decay,
    max_tokens,
):
    try:
        pipeline = model_manager.load_model(model_choice)

        if system_prompt.strip():
            ctx = f"{system_prompt.strip()}\n\nUser: {user_prompt.strip()}\n\nA:"
        else:
            ctx = f"User: {user_prompt.strip()}\n\nA:"

        args = PIPELINE_ARGS(
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),  # sliders return floats; sampling expects an integer top-k
            alpha_frequency=alpha_frequency,
            alpha_presence=alpha_presence,
            alpha_decay=alpha_decay,
            token_ban=[],
            token_stop=[],
            chunk_len=256,
        )

        # Accumulate generated tokens into a single string via the streaming callback.
        response = ""

        def callback(text):
            nonlocal response
            response += text

        pipeline.generate(ctx, token_count=int(max_tokens), args=args, callback=callback)
        return response
    except Exception as e:
        import traceback
        return f"Error: {e}\nStack trace: {traceback.format_exc()}"
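
# UI layout: model choice, prompts, and length on the left; sampling controls on the right;
# the Generate button fills a single output textbox.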
with gr.Blocks() as demo:
    gr.Markdown("# RWKV-7 Language Model Demo")

    with gr.Row():
        with gr.Column():
            model_choice = gr.Radio(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model Selection"
            )
            system_prompt = gr.Textbox(
                label="System Prompt",
                placeholder="Optional system prompt to set the context",
                lines=3,
                value="You are a helpful AI assistant. You provide detailed and accurate responses."
            )
            user_prompt = gr.Textbox(
                label="User Prompt",
                placeholder="Enter your prompt here",
                lines=3
            )
            max_tokens = gr.Slider(
                minimum=1,
                maximum=1000,
                value=200,
                step=1,
                label="Max Tokens"
            )

        with gr.Column():
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Temperature"
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.7,
                step=0.05,
                label="Top P"
            )
            top_k = gr.Slider(
                minimum=0,
                maximum=200,
                value=100,
                step=1,
                label="Top K"
            )
            alpha_frequency = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                label="Alpha Frequency"
            )
            alpha_presence = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                label="Alpha Presence"
            )
            alpha_decay = gr.Slider(
                minimum=0.9,
                maximum=1.0,
                value=0.996,
                step=0.001,
                label="Alpha Decay"
            )

    generate_button = gr.Button("Generate")
    output = gr.Textbox(label="Generated Response", lines=10)

    generate_button.click(
        fn=generate_response,
        inputs=[
            model_choice,
            user_prompt,
            system_prompt,
            temperature,
            top_p,
            top_k,
            alpha_frequency,
            alpha_presence,
            alpha_decay,
            max_tokens
        ],
        outputs=output
    )
    gr.Markdown("""
    ## Model Information
    - **0.1B Model**: Smaller model, faster but less capable
    - **0.4B Model**: Larger model, slower but more capable

    ## Parameter Descriptions
    - **Temperature**: Controls randomness in the output (higher = more random)
    - **Top P**: Nucleus sampling threshold (lower = more focused)
    - **Top K**: Limits the number of tokens considered at each step
    - **Alpha Frequency**: Penalizes tokens in proportion to how often they have already appeared
    - **Alpha Presence**: Penalizes any token that has already appeared at least once
    - **Alpha Decay**: Rate at which the repetition penalties decay
    - **Max Tokens**: Maximum length of the generated response
    """)
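
# Launch the demo locally; pass share=True to demo.launch() if a public Gradio link is needed.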
if __name__ == "__main__":
    demo.launch()