	Update app.py
app.py CHANGED
@@ -1,7 +1,100 @@
 import gradio as gr
+import os
+from datetime import datetime
+from huggingface_hub import hf_hub_download
 
-
-
+title = "RWKV-4 14B fp16 ctx4096"
+desc = '''Links:
+<a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 1em">ChatRWKV</a>
+<a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 1em">RWKV-LM</a>
+<a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 1em">RWKV pip package</a>
+'''
 
-
-
+os.environ["RWKV_JIT_ON"] = '1'
+os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
+
+from rwkv.model import RWKV
+model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-169m", filename="RWKV-4-Pile-169M-20220807-8023.pth")
+model = RWKV(model=model_path, strategy='cuda fp16')
+from rwkv.utils import PIPELINE, PIPELINE_ARGS
+pipeline = PIPELINE(model, "20B_tokenizer.json")
+
+def infer(
+        ctx,
+        token_count=10,
+        temperature=1.0,
+        top_p=0.85,
+        presencePenalty = 0.1,
+        countPenalty = 0.1,
+):
+    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
+                     alpha_frequency = countPenalty,
+                     alpha_presence = presencePenalty,
+                     token_ban = [0], # ban the generation of some tokens
+                     token_stop = []) # stop generation whenever you see any token here
+
+    ctx = ctx.strip(' ')
+    if ctx.endswith('\n'):
+        ctx = f'\n{ctx.strip()}\n'
+    else:
+        ctx = f'\n{ctx.strip()}'
+
+    all_tokens = []
+    out_last = 0
+    out_str = ''
+    occurrence = {}
+    state = None
+    for i in range(int(token_count)):
+        out, state = model.forward(pipeline.encode(ctx) if i == 0 else [token], state)
+        for n in args.token_ban:
+            out[n] = -float('inf')
+        for n in occurrence:
+            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+        if token in args.token_stop:
+            break
+        all_tokens += [token]
+        if token not in occurrence:
+            occurrence[token] = 1
+        else:
+            occurrence[token] += 1
+
+        tmp = pipeline.decode(all_tokens[out_last:])
+        if '\ufffd' not in tmp:
+            out_str += tmp
+            yield out_str.strip()
+            out_last = i + 1
+    yield out_str.strip()
+
+examples = [
+    ["Ask Expert\n\nQuestion:\nWhat are some good plans for world peace?\n\nExpert Full Answer:\n", 100, 1.0, 0.85, 0.1, 0.1],
+    ["Q & A\n\nQuestion:\nWhy is the sky blue?\n\nDetailed Expert Answer:\n", 100, 1.0, 0.85, 0.1, 0.1],
+    ["Expert Questions & Helpful Answers\nAsk Research Experts\nQuestion:\nCan you write a short story about an elf maiden named Julia that meets a warrior named Rallio and they go on an adventure together?\n\nFull Answer:\n", 100, 1.0, 0.85, 0.1, 0.1],
+]
+
+
+iface = gr.Interface(
+    fn=infer,
+    description=f'''{desc}''',
+    allow_flagging="never",
+    inputs=[
+        gr.Textbox(lines=20, label="Prompt"),  # prompt
+        gr.Slider(10, 200, step=10, value=100),  # token_count
+        gr.Slider(0.2, 2.0, step=0.1, value=1.0),  # temperature
+        gr.Slider(0.0, 1.0, step=0.05, value=0.85),  # top_p
+        gr.Slider(0.0, 1.0, step=0.1, value=0.1),  # presencePenalty
+        gr.Slider(0.0, 1.0, step=0.1, value=0.1),  # countPenalty
+    ],
+    outputs=gr.Textbox(label="Generated Output", lines=35),
+    examples=examples,
+    cache_examples=False,
+).queue()
+
+demo = gr.TabbedInterface(
+    [iface], ["Generative"],
+    title=title,
+)
+
+demo.queue()
+demo.launch(share=False)
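A note on the penalty loop in infer() above: alpha_presence subtracts a flat amount from the logit of every token that has already been generated, and alpha_frequency adds a further discount proportional to its count, which discourages repetition. A minimal sketch of that arithmetic (not part of the commit; a plain dict stands in for the real logit tensor):

alpha_presence = 0.1    # presencePenalty in infer()
alpha_frequency = 0.1   # countPenalty in infer()

occurrence = {42: 3, 7: 1}             # token id -> times generated so far
logits = {42: 5.0, 7: 2.0, 99: 1.0}    # hypothetical raw scores

for n, count in occurrence.items():
    logits[n] -= alpha_presence + count * alpha_frequency

# token 42: 5.0 - (0.1 + 3*0.1) = 4.6
# token 7:  2.0 - (0.1 + 1*0.1) = 1.8
# token 99 has not been generated, so it is not penalized
print(logits)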
 
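The '\ufffd' check is what makes the streamed output safe: the tokenizer can split a multi-byte UTF-8 character across tokens, and decoding an incomplete tail yields the U+FFFD replacement character, so the loop holds those tokens back (via out_last) until they decode cleanly. A small self-contained illustration of the idea, assuming byte-level decoding:

# Not from the commit: why an incomplete multi-byte character is detectable.
two_bytes = "é".encode("utf-8")                         # b'\xc3\xa9'
print(repr(two_bytes[:1].decode("utf-8", "replace")))   # '\ufffd' -> hold back
print(repr(two_bytes.decode("utf-8", "replace")))       # 'é'      -> safe to emit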
			
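Since infer() is a generator that yields the full decoded string so far (not a delta), it can also be driven without the web UI. A hypothetical debugging snippet, not in the commit, that could stand in for demo.launch(share=False) at the bottom of app.py, assuming the rwkv and huggingface_hub packages are installed, a CUDA GPU is available, and 20B_tokenizer.json sits next to app.py:

prompt = "Q & A\n\nQuestion:\nWhy is the sky blue?\n\nDetailed Expert Answer:\n"
final = ""
for partial in infer(prompt, token_count=100, temperature=1.0, top_p=0.85):
    final = partial    # each yield is the whole output so far
print(final)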
