"""Gradio demo for running the RWKV-7 "World" language models (0.1B / 0.4B) on CPU."""

import os
import os.path
import json
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, List

import requests
import torch
import gradio as gr
from tqdm import tqdm

# Set environment variables BEFORE importing the rwkv package,
# since the runtime reads them at import time
os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '0'
os.environ["RWKV_V7_ON"] = '1'

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

# Model options
MODELS = {
    "0.1B (Smaller)": "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth",
    "0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
}

# Model configurations
MODEL_CONFIGS = {
    "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth": {
        "n_layer": 12,
        "n_embd": 768,
        "ctx_len": 4096
    },
    "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth": {
        "n_layer": 24,
        "n_embd": 1024,
        "ctx_len": 4096
    }
}


@dataclass
class ModelArgs:
    n_layer: int
    n_embd: int
    ctx_len: int
    vocab_size: int = 65536
    n_head: int = 16    # Number of attention heads
    n_att: int = 1024   # Attention dimension


def download_file(url, filename):
    """Generic file downloader with progress bar"""
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        response = requests.get(url, stream=True)
        total_size = int(response.headers.get('content-length', 0))
        with open(filename, 'wb') as file, tqdm(
            desc=filename,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                pbar.update(size)


def download_model(model_name):
    """Download model if not present"""
    if not os.path.exists(model_name):
        url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
        download_file(url, model_name)


class CustomPipeline(PIPELINE):
    """PIPELINE subclass that also carries the ModelArgs for the loaded checkpoint."""

    def __init__(self, model, vocab_file):
        super().__init__(model, vocab_file)
        self.model_args = None

    def set_model_args(self, args: ModelArgs):
        self.model_args = args


class ModelManager:
    """Lazily downloads and caches the currently selected model."""

    def __init__(self):
        self.current_model = None
        self.current_model_name = None
        self.pipeline = None

    def load_model(self, model_choice):
        model_file = MODELS[model_choice]
        if model_file != self.current_model_name:
            download_model(model_file)

            # Get model configuration
            config = MODEL_CONFIGS[model_file]
            model_args = ModelArgs(
                n_layer=config['n_layer'],
                n_embd=config['n_embd'],
                ctx_len=config['ctx_len']
            )

            # Initialize model (CPU, fp32 strategy)
            self.current_model = RWKV(
                model=model_file,
                strategy='cpu fp32'
            )

            # Initialize custom pipeline with the World-model tokenizer
            # ("rwkv_vocab_v20230424"; the Pile-era 20B tokenizer does not match these checkpoints)
            self.pipeline = CustomPipeline(self.current_model, "rwkv_vocab_v20230424")
            self.pipeline.set_model_args(model_args)
            self.current_model_name = model_file

        return self.pipeline


model_manager = ModelManager()


def generate_response(
    model_choice,
    user_prompt,
    system_prompt,
    temperature,
    top_p,
    top_k,
    alpha_frequency,
    alpha_presence,
    alpha_decay,
    max_tokens
):
    try:
        # Get or load the model
        pipeline = model_manager.load_model(model_choice)

        # Prepare the context
        if system_prompt.strip():
            ctx = f"{system_prompt.strip()}\n\nUser: {user_prompt.strip()}\n\nA:"
        else:
            ctx = f"User: {user_prompt.strip()}\n\nA:"

        # Prepare generation arguments.
        # NOTE: PIPELINE_ARGS does not accept a model_args field; the ModelArgs
        # dataclass stays on the pipeline itself (CustomPipeline.model_args).
        args = PIPELINE_ARGS(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            alpha_frequency=alpha_frequency,
            alpha_presence=alpha_presence,
            alpha_decay=alpha_decay,
            token_ban=[],
            token_stop=[],
            chunk_len=256
        )

        # Generate the response, accumulating streamed text via the callback
        response = ""

        def callback(text):
            nonlocal response
            response += text

        pipeline.generate(ctx, token_count=int(max_tokens), args=args, callback=callback)
        return response
    except Exception as e:
        import traceback
        return f"Error: {str(e)}\nStack trace: {traceback.format_exc()}"


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# RWKV-7 Language Model Demo")

    with gr.Row():
        with gr.Column():
            model_choice = gr.Radio(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model Selection"
            )
            system_prompt = gr.Textbox(
                label="System Prompt",
                placeholder="Optional system prompt to set the context",
                lines=3,
                value="You are a helpful AI assistant. You provide detailed and accurate responses."
            )
            user_prompt = gr.Textbox(
                label="User Prompt",
                placeholder="Enter your prompt here",
                lines=3
            )
            max_tokens = gr.Slider(
                minimum=1, maximum=1000, value=200, step=1,
                label="Max Tokens"
            )

        with gr.Column():
            temperature = gr.Slider(
                minimum=0.1, maximum=2.0, value=1.0, step=0.1,
                label="Temperature"
            )
            top_p = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.7, step=0.05,
                label="Top P"
            )
            top_k = gr.Slider(
                minimum=0, maximum=200, value=100, step=1,
                label="Top K"
            )
            alpha_frequency = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.25, step=0.05,
                label="Alpha Frequency"
            )
            alpha_presence = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.25, step=0.05,
                label="Alpha Presence"
            )
            alpha_decay = gr.Slider(
                minimum=0.9, maximum=1.0, value=0.996, step=0.001,
                label="Alpha Decay"
            )

    generate_button = gr.Button("Generate")
    output = gr.Textbox(label="Generated Response", lines=10)

    generate_button.click(
        fn=generate_response,
        inputs=[
            model_choice, user_prompt, system_prompt,
            temperature, top_p, top_k,
            alpha_frequency, alpha_presence, alpha_decay,
            max_tokens
        ],
        outputs=output
    )

    gr.Markdown("""
    ## Model Information
    - **0.1B Model**: Smaller model, faster but less capable
    - **0.4B Model**: Larger model, slower but more capable

    ## Parameter Descriptions
    - **Temperature**: Controls randomness in the output (higher = more random)
    - **Top P**: Nucleus sampling threshold (lower = more focused)
    - **Top K**: Limits the number of tokens considered for each step
    - **Alpha Frequency**: Penalizes frequent tokens
    - **Alpha Presence**: Penalizes tokens that have appeared before
    - **Alpha Decay**: Rate at which penalties decay
    - **Max Tokens**: Maximum length of generated response
    """)

# Launch the demo
if __name__ == "__main__":
    demo.launch()
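
# A minimal way to run this demo locally, as a sketch: the script filename
# "app.py" and the bare package names below are assumptions, not taken from
# the original code.
#
#   pip install rwkv gradio torch requests tqdm
#   python app.py
#
# Gradio serves the UI at http://127.0.0.1:7860 by default.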