import os

# Set RWKV runtime options before importing the library — rwkv reads these
# environment variables at import time.
os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '0'
os.environ["RWKV_V7_ON"] = '1'

import gradio as gr
import requests
from tqdm import tqdm
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

# Model options
MODELS = {
    "0.1B (Smaller)": "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth",
    "0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
}

# World-model vocabulary. Recent rwkv packages bundle this vocab, so the
# download below is only an optional helper for keeping a local copy.
TOKENIZER_FILE = "rwkv_vocab_v20230424.txt"
TOKENIZER_URL = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/v2/rwkv_vocab_v20230424.txt"


def download_tokenizer():
    """Download a local copy of the vocab file if not present."""
    if not os.path.exists(TOKENIZER_FILE):
        print("Downloading tokenizer...")
        response = requests.get(TOKENIZER_URL)
        with open(TOKENIZER_FILE, 'wb') as f:
            f.write(response.content)


def download_model(model_name):
    """Download model weights from Hugging Face if not present."""
    if not os.path.exists(model_name):
        print(f"Downloading {model_name}...")
        url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
        response = requests.get(url, stream=True)
        total_size = int(response.headers.get('content-length', 0))
        with open(model_name, 'wb') as file, tqdm(
            desc=model_name,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                pbar.update(size)


class ModelManager:
    """Lazily loads the selected model and caches it between requests."""

    def __init__(self):
        self.current_model = None
        self.current_model_name = None
        self.pipeline = None

    def load_model(self, model_name):
        if model_name != self.current_model_name:
            download_model(MODELS[model_name])
            self.current_model = RWKV(
                model=MODELS[model_name],
                strategy='cpu fp32'
            )
            # Passing the vocab name (not a file path) makes PIPELINE use the
            # trie tokenizer bundled with the rwkv package.
            self.pipeline = PIPELINE(self.current_model, "rwkv_vocab_v20230424")
            self.current_model_name = model_name
        return self.pipeline


model_manager = ModelManager()


def generate_response(
    model_choice,
    user_prompt,
    system_prompt,
    temperature,
    top_p,
    top_k,
    alpha_frequency,
    alpha_presence,
    alpha_decay,
    max_tokens
):
    try:
        # Get or load the model
        pipeline = model_manager.load_model(model_choice)

        # Prepare the context
        if system_prompt.strip():
            ctx = f"{system_prompt.strip()}\n\nUser: {user_prompt.strip()}\n\nA:"
        else:
            ctx = f"User: {user_prompt.strip()}\n\nA:"

        # Prepare generation arguments
        args = PIPELINE_ARGS(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            alpha_frequency=alpha_frequency,
            alpha_presence=alpha_presence,
            alpha_decay=alpha_decay,
            token_ban=[],
            token_stop=[],
            chunk_len=256
        )

        # Generate response, accumulating streamed chunks via the callback
        response = ""

        def callback(text):
            nonlocal response
            response += text

        pipeline.generate(ctx, token_count=max_tokens, args=args, callback=callback)
        return response

    except Exception as e:
        return f"Error: {str(e)}"


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# RWKV-7 Language Model Demo")

    with gr.Row():
        with gr.Column():
            model_choice = gr.Radio(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model Selection"
            )
            system_prompt = gr.Textbox(
                label="System Prompt",
                placeholder="Optional system prompt to set the context",
                lines=3,
                value="You are a helpful AI assistant. You provide detailed and accurate responses."
            )
            user_prompt = gr.Textbox(
                label="User Prompt",
                placeholder="Enter your prompt here",
                lines=3
            )
            max_tokens = gr.Slider(
                minimum=1,
                maximum=1000,
                value=200,
                step=1,
                label="Max Tokens"
            )

        with gr.Column():
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="Temperature"
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.7,
                step=0.05,
                label="Top P"
            )
            top_k = gr.Slider(
                minimum=0,
                maximum=200,
                value=100,
                step=1,
                label="Top K"
            )
            alpha_frequency = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                label="Alpha Frequency"
            )
            alpha_presence = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                label="Alpha Presence"
            )
            alpha_decay = gr.Slider(
                minimum=0.9,
                maximum=1.0,
                value=0.996,
                step=0.001,
                label="Alpha Decay"
            )

    generate_button = gr.Button("Generate")
    output = gr.Textbox(label="Generated Response", lines=10)

    generate_button.click(
        fn=generate_response,
        inputs=[
            model_choice,
            user_prompt,
            system_prompt,
            temperature,
            top_p,
            top_k,
            alpha_frequency,
            alpha_presence,
            alpha_decay,
            max_tokens
        ],
        outputs=output
    )

    gr.Markdown("""
    ## Model Information
    - **0.1B Model**: Smaller model, faster but less capable
    - **0.4B Model**: Larger model, slower but more capable

    ## Parameter Descriptions
    - **Temperature**: Controls randomness in the output (higher = more random)
    - **Top P**: Nucleus sampling threshold (lower = more focused)
    - **Top K**: Limits the number of tokens considered for each step
    - **Alpha Frequency**: Penalizes frequent tokens
    - **Alpha Presence**: Penalizes tokens that have appeared before
    - **Alpha Decay**: Rate at which penalties decay
    - **Max Tokens**: Maximum length of generated response
    """)

# Launch the demo
if __name__ == "__main__":
    demo.launch(ssr_mode=False)