Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Chatbot with TinyLlama
Browse files- app.py +120 -0
- imgs/TinyLlama_logo.png +0 -0
- imgs/user_logo.png +0 -0
- requirements.txt +86 -0
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,120 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            from ctransformers import AutoModelForCausalLM, AutoConfig, Config #import for GGUF/GGML models
         | 
| 3 | 
            +
            import datetime
         | 
| 4 | 
            +
             | 
# --- Model configuration -----------------------------------------------------
# Hugging Face repo id of the GGUF/GGML checkpoint loaded through ctransformers.
modelfile="TinyLlama/TinyLlama-1.1B-Chat-v0.6"

# Default generation parameters; they also seed the UI sliders defined below.
i_temperature = 0.30 
i_max_new_tokens=1100
i_repetitionpenalty = 1.2
# NOTE(review): 12048 looks like a typo — the page title advertises a
# "4K context window", so this was presumably meant to be 2048 or 4096; confirm.
i_contextlength=12048
# Plain-text chat transcript that writehistory() appends to.
logfile = 'TinyLlama.1B.txt'

print("loading model...")

# Time the model load so the startup cost is visible in the console.
stt = datetime.datetime.now()
conf = AutoConfig(Config(temperature=i_temperature, 
                         repetition_penalty=i_repetitionpenalty, 
                         batch_size=64,
                         max_new_tokens=i_max_new_tokens, 
                         context_length=i_contextlength))
llm = AutoModelForCausalLM.from_pretrained(modelfile,
                                           model_type="llama",
                                           config=conf)
dt = datetime.datetime.now() - stt
print(f"Model loaded in {dt}")
| 26 | 
            +
             | 
def writehistory(text, path=None):
    """Append one line to the chat log.

    Parameters
    ----------
    text : str
        The line to record (a trailing newline is added automatically).
    path : str, optional
        Log file to append to; defaults to the module-level ``logfile``.
    """
    # The with-statement closes the file on exit; the original additionally
    # called f.close() after the with block, which was redundant.
    with open(logfile if path is None else path, 'a', encoding='utf-8') as f:
        f.write(text)
        f.write('\n')
| 32 | 
            +
             | 
| 33 | 
            +
            with gr.Blocks(theme='ParityError/Interstellar') as demo: 
         | 
| 34 | 
            +
                # TITLE SECTION
         | 
| 35 | 
            +
                with gr.Row():
         | 
| 36 | 
            +
                    with gr.Column(scale=12):
         | 
| 37 | 
            +
                        gr.HTML("<center>"
         | 
| 38 | 
            +
                        + "<h1>🦙 TinyLlama 1.1B 🐋 4K context window</h2></center>")  
         | 
| 39 | 
            +
                        gr.Markdown("""
         | 
| 40 | 
            +
                        **Currently Running**: [TinyLlama/TinyLlama-1.1B-Chat-v0.6](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.6)         **Chat History Log File**: *TinyLlama.1B.txt*
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                        - **Base Model**: TinyLlama/TinyLlama-1.1B-Chat-v0.6, Fine tuned on OpenOrca GPT4 subset for 1 epoch, Using CHATML format. 
         | 
| 43 | 
            +
                        - **License**: Apache 2.0, following the TinyLlama base model. 
         | 
| 44 | 
            +
                                    The model output is not censored and the authors do not endorse the opinions in the generated content. Use at your own risk.
         | 
| 45 | 
            +
                        """)         
         | 
| 46 | 
            +
                    gr.Image(value='imgs/TinyLlama_logo.png', width=70)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
               # chat and parameters settings
         | 
| 49 | 
            +
                with gr.Row():
         | 
| 50 | 
            +
                    with gr.Column(scale=4):
         | 
| 51 | 
            +
                        chatbot = gr.Chatbot(height = 350, show_copy_button=True, avatar_images = ["imgs/user_logo.png","imgs/TinyLlama_logo.png"])
         | 
| 52 | 
            +
                        with gr.Row():
         | 
| 53 | 
            +
                            with gr.Column(scale=14):
         | 
| 54 | 
            +
                                msg = gr.Textbox(show_label=False, placeholder="Enter text", lines=2)
         | 
| 55 | 
            +
                            submitBtn = gr.Button("\n💬 Send\n", size="lg", variant="primary", min_width=140)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                    with gr.Column(min_width=50, scale=1):
         | 
| 58 | 
            +
                            with gr.Tab(label="Parameter Setting"):
         | 
| 59 | 
            +
                                gr.Markdown("# Parameters")
         | 
| 60 | 
            +
                                top_p = gr.Slider(minimum=-0, 
         | 
| 61 | 
            +
                                                  maximum=1.0,
         | 
| 62 | 
            +
                                                  value=0.95,
         | 
| 63 | 
            +
                                                  step=0.05,
         | 
| 64 | 
            +
                                                  interactive=True,
         | 
| 65 | 
            +
                                                  label="Top-p")
         | 
| 66 | 
            +
                                temperature = gr.Slider(minimum=0.1,
         | 
| 67 | 
            +
                                                        maximum=1.0,
         | 
| 68 | 
            +
                                                        value=0.30,
         | 
| 69 | 
            +
                                                        step=0.01,
         | 
| 70 | 
            +
                                                        interactive=True,
         | 
| 71 | 
            +
                                                        label="Temperature")
         | 
| 72 | 
            +
                                max_length_tokens = gr.Slider(minimum=0,
         | 
| 73 | 
            +
                                                              maximum=4096,
         | 
| 74 | 
            +
                                                              value=1060,
         | 
| 75 | 
            +
                                                              step=4,
         | 
| 76 | 
            +
                                                              interactive=True,
         | 
| 77 | 
            +
                                                              label="Max Generation Tokens")
         | 
| 78 | 
            +
                                rep_pen = gr.Slider(minimum=0,
         | 
| 79 | 
            +
                                                    maximum=5,
         | 
| 80 | 
            +
                                                    value=1.2,
         | 
| 81 | 
            +
                                                    step=0.05,
         | 
| 82 | 
            +
                                                    interactive=True,
         | 
| 83 | 
            +
                                                    label="Repetition Penalty")
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                            clear = gr.Button("🗑️ Clear All Messages", variant='secondary')
         | 
| 86 | 
            +
                def user(user_message, history):
         | 
| 87 | 
            +
                    writehistory(f"USER: {user_message}")
         | 
| 88 | 
            +
                    return "", history + [[user_message, None]]
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                def bot(history, t, p, m, r):
         | 
| 91 | 
            +
                    SYSTEM_PROMPT = """<|im_start|>system
         | 
| 92 | 
            +
                    You are a helpful bot. Your answers are clear and concise.
         | 
| 93 | 
            +
                    <|im_end|>
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    """    
         | 
| 96 | 
            +
                    prompt = f"<|im_start|>system<|im_end|><|im_start|>user\n{history[-1][0]}<|im_end|>\n<|im_start|>assistant\n"  
         | 
| 97 | 
            +
                    print(f"history lenght: {len(history)}")
         | 
| 98 | 
            +
                    if len(history) == 1:
         | 
| 99 | 
            +
                        print("this is the first round")
         | 
| 100 | 
            +
                    else:
         | 
| 101 | 
            +
                        print("here we should pass more conversations")
         | 
| 102 | 
            +
                    history[-1][1] = ""
         | 
| 103 | 
            +
                    for character in llm(prompt,
         | 
| 104 | 
            +
                                         temperature = t,
         | 
| 105 | 
            +
                                         top_p = p, 
         | 
| 106 | 
            +
                                         repetition_penalty = r, 
         | 
| 107 | 
            +
                                         max_new_tokens=m,
         | 
| 108 | 
            +
                                         stop = ['<|im_end|>'],
         | 
| 109 | 
            +
                                         stream = True):
         | 
| 110 | 
            +
                        history[-1][1] += character
         | 
| 111 | 
            +
                        yield history
         | 
| 112 | 
            +
                    writehistory(f"temperature: {t}, top_p: {p}, maxNewTokens: {m}, repetitionPenalty: {r}\n---\nBOT: {history}\n\n")
         | 
| 113 | 
            +
                    # Log in the terminal the messages
         | 
| 114 | 
            +
                    print(f"USER: {history[-1][0]}\n---\ntemperature: {t}, top_p: {p}, maxNewTokens: {m}, repetitionPenalty: {r}\n---\nBOT: {history[-1][1]}\n\n")    
         | 
    # Clicking the submitBtn will call the generation with Parameters in the slides
    # user() runs first (queue=False: executed immediately) to echo the message
    # into the chatbot, then bot() streams the reply into the same component.
    submitBtn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, [chatbot,temperature,top_p,max_length_tokens,rep_pen], chatbot)
    # Resets the chatbot display to empty; the on-disk log file is untouched.
    clear.click(lambda: None, None, chatbot, queue=False)
    
demo.queue()  # required to yield the streams from the text generation
demo.launch(inbrowser=True, share=True)
    	
        imgs/TinyLlama_logo.png
    ADDED
    
    |   | 
    	
        imgs/user_logo.png
    ADDED
    
    |   | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,86 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            aiofiles==23.2.1
         | 
| 2 | 
            +
            aiohttp==3.9.3
         | 
| 3 | 
            +
            aiosignal==1.3.1
         | 
| 4 | 
            +
            altair==5.2.0
         | 
| 5 | 
            +
            annotated-types==0.6.0
         | 
| 6 | 
            +
            anyio==4.3.0
         | 
| 7 | 
            +
            async-timeout==4.0.3
         | 
| 8 | 
            +
            attrs==23.2.0
         | 
| 9 | 
            +
            certifi==2024.2.2
         | 
| 10 | 
            +
            charset-normalizer==3.3.2
         | 
| 11 | 
            +
            click==8.1.7
         | 
| 12 | 
            +
            cmake==3.28.3
         | 
| 13 | 
            +
            colorama==0.4.6
         | 
| 14 | 
            +
            contourpy==1.2.0
         | 
| 15 | 
            +
            cycler==0.12.1
         | 
| 16 | 
            +
            exceptiongroup==1.2.0
         | 
| 17 | 
            +
            fastapi==0.110.0
         | 
| 18 | 
            +
            ffmpy==0.3.2
         | 
| 19 | 
            +
            filelock==3.13.1
         | 
| 20 | 
            +
            fonttools==4.50.0
         | 
| 21 | 
            +
            frozenlist==1.4.1
         | 
| 22 | 
            +
            fsspec==2024.3.0
         | 
| 23 | 
            +
            gradio==4.21.0
         | 
| 24 | 
            +
            gradio_client==0.12.0
         | 
| 25 | 
            +
            h11==0.14.0
         | 
| 26 | 
            +
            httpcore==1.0.4
         | 
| 27 | 
            +
            httpx==0.27.0
         | 
| 28 | 
            +
            huggingface-hub==0.21.4
         | 
| 29 | 
            +
            idna==3.6
         | 
| 30 | 
            +
            importlib_resources==6.3.0
         | 
| 31 | 
            +
            Jinja2==3.1.3
         | 
| 32 | 
            +
            jsonschema==4.21.1
         | 
| 33 | 
            +
            jsonschema-specifications==2023.12.1
         | 
| 34 | 
            +
            kiwisolver==1.4.5
         | 
| 35 | 
            +
            linkify-it-py==2.0.3
         | 
| 36 | 
            +
            lit==18.1.1
         | 
| 37 | 
            +
            markdown-it-py==2.2.0
         | 
| 38 | 
            +
            MarkupSafe==2.1.5
         | 
| 39 | 
            +
            matplotlib==3.8.3
         | 
| 40 | 
            +
            mdit-py-plugins==0.3.3
         | 
| 41 | 
            +
            mdurl==0.1.2
         | 
| 42 | 
            +
            mpmath==1.3.0
         | 
| 43 | 
            +
            multidict==6.0.5
         | 
| 44 | 
            +
            networkx==3.2.1
         | 
| 45 | 
            +
            numpy==1.26.4
         | 
| 46 | 
            +
            orjson==3.9.15
         | 
| 47 | 
            +
            packaging==24.0
         | 
| 48 | 
            +
            pandas==2.2.1
         | 
| 49 | 
            +
            pillow==10.2.0
         | 
| 50 | 
            +
            pydantic==2.6.4
         | 
| 51 | 
            +
            pydantic_core==2.16.3
         | 
| 52 | 
            +
            pydub==0.25.1
         | 
| 53 | 
            +
            Pygments==2.17.2
         | 
| 54 | 
            +
            pyparsing==3.1.2
         | 
| 55 | 
            +
            python-dateutil==2.9.0.post0
         | 
| 56 | 
            +
            python-multipart==0.0.9
         | 
| 57 | 
            +
            pytz==2024.1
         | 
| 58 | 
            +
            PyYAML==6.0.1
         | 
| 59 | 
            +
            referencing==0.33.0
         | 
| 60 | 
            +
            regex==2023.12.25
         | 
| 61 | 
            +
            requests==2.31.0
         | 
| 62 | 
            +
            rich==13.7.1
         | 
| 63 | 
            +
            rpds-py==0.18.0
         | 
| 64 | 
            +
            ruff==0.3.3
         | 
| 65 | 
            +
            safetensors==0.4.2
         | 
| 66 | 
            +
            semantic-version==2.10.0
         | 
| 67 | 
            +
            shellingham==1.5.4
         | 
| 68 | 
            +
            six==1.16.0
         | 
| 69 | 
            +
            sniffio==1.3.1
         | 
| 70 | 
            +
            starlette==0.36.3
         | 
| 71 | 
            +
            sympy==1.12
         | 
| 72 | 
            +
            tokenizers==0.13.3
         | 
| 73 | 
            +
            tomlkit==0.12.0
         | 
| 74 | 
            +
            toolz==0.12.1
         | 
| 75 | 
            +
            torch==2.0.1
         | 
| 76 | 
            +
            tqdm==4.66.2
         | 
| 77 | 
            +
            transformers==4.31.0
         | 
| 78 | 
            +
            triton==2.0.0
         | 
| 79 | 
            +
            typer==0.9.0
         | 
| 80 | 
            +
            typing_extensions==4.10.0
         | 
| 81 | 
            +
            tzdata==2024.1
         | 
| 82 | 
            +
            uc-micro-py==1.0.3
         | 
| 83 | 
            +
            urllib3==2.2.1
         | 
| 84 | 
            +
            uvicorn==0.28.0
         | 
| 85 | 
            +
            websockets==11.0.3
         | 
| 86 | 
            +
            yarl==1.9.4
         | 
