Merge branch 'longgen' into our_hf2

Files changed:
- CHANGELOG.md +2 -0
- app.py +1 -1
- app_batched.py +158 -68
- audiocraft/models/musicgen.py +7 -2
- audiocraft/modules/transformer.py +11 -8
    	
CHANGELOG.md
CHANGED

@@ -13,6 +13,8 @@ Now repeating the conditioning periodically if it is too short.
 
 More options when launching Gradio app locally (thanks @ashleykleynhans).
 
+Testing out PyTorch 2.0 memory efficient attention.
+
 ## [0.0.1] - 2023-06-09
 
 Initial release, with model evaluation only.
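The new entry refers to PyTorch 2.0's fused attention API, which `audiocraft/modules/transformer.py` adopts below. As a rough standalone illustration (not code from this commit), the API can be feature-detected and exercised like this:

```python
import torch
import torch.nn.functional as F

# Illustrative only: the PyTorch 2.0 fused attention this commit tests.
# Shapes follow the (batch, heads, time, head_dim) convention it expects.
q = k = v = torch.randn(1, 8, 16, 64)

if hasattr(F, "scaled_dot_product_attention"):
    # Dispatches to flash / memory-efficient kernels when available.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
else:
    # Pre-2.0 fallback: explicit causal softmax attention.
    w = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    w = w.masked_fill(torch.ones(16, 16, dtype=torch.bool).triu(1), float("-inf"))
    out = w.softmax(dim=-1) @ v
print(out.shape)  # torch.Size([1, 8, 16, 64])
```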
app.py
CHANGED

@@ -15,7 +15,7 @@ from audiocraft.models import MusicGen
 from audiocraft.data.audio import audio_write
 
 MODEL = None
-IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ
+IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ.get('SPACE_ID', '')
 
 
 def load_model(version):
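The replaced line was a real bug: `in os.environ` tests environment variable *names*, so `"musicgen/MusicGen"` could never match. Hugging Face Spaces stores the owner/name of the running Space in the `SPACE_ID` variable, and the fixed line checks its *value* instead. A small sketch of the difference (the assignment is illustrative; Spaces sets the variable itself):

```python
import os

# Hypothetical value for illustration; on Spaces, SPACE_ID is set to
# "<owner>/<space-name>" by the platform.
os.environ["SPACE_ID"] = "musicgen/MusicGen"

# Old check: membership tests variable names, never their values.
print("musicgen/MusicGen" in os.environ)                      # False

# Fixed check: substring match against the SPACE_ID value.
print("musicgen/MusicGen" in os.environ.get("SPACE_ID", ""))  # True
```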
    	
app_batched.py
CHANGED

@@ -6,7 +6,12 @@ This source code is licensed under the license found in the
 LICENSE file in the root directory of this source tree.
 """
 
+import argparse
+from concurrent.futures import ProcessPoolExecutor
+import subprocess as sp
 from tempfile import NamedTemporaryFile
+import time
+import warnings
 import torch
 import gradio as gr
 from audiocraft.data.audio_utils import convert_audio
@@ -16,6 +21,29 @@ from audiocraft.models import MusicGen
 
 MODEL = None
 
+_old_call = sp.call
+
+
+def _call_nostderr(*args, **kwargs):
+    # Avoid ffmpeg vomiting on the logs.
+    kwargs['stderr'] = sp.DEVNULL
+    kwargs['stdout'] = sp.DEVNULL
+    _old_call(*args, **kwargs)
+
+
+sp.call = _call_nostderr
+pool = ProcessPoolExecutor(3)
+pool.__enter__()
+
+
+def make_waveform(*args, **kwargs):
+    be = time.time()
+    with warnings.catch_warnings():
+        warnings.simplefilter('ignore')
+        out = gr.make_waveform(*args, **kwargs)
+        print("Make a video took", time.time() - be)
+        return out
+
 
 def load_model():
     print("Loading model")
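The two hunks above silence ffmpeg by monkey-patching `subprocess.call` and pre-open a three-worker `ProcessPoolExecutor` so that `gr.make_waveform` video rendering happens off the main process while the GPU keeps generating. A condensed sketch of the same pattern, with a stand-in for the real renderer:

```python
import subprocess as sp
import time
from concurrent.futures import ProcessPoolExecutor

# Silence noisy subprocesses globally, as the app does for ffmpeg.
_old_call = sp.call

def _call_nostderr(*args, **kwargs):
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    _old_call(*args, **kwargs)

sp.call = _call_nostderr

def render_video(name: str) -> str:
    # Stand-in for gr.make_waveform: pretend to encode a video.
    time.sleep(0.1)
    return name + ".mp4"

if __name__ == '__main__':
    # Workers encode videos while the main process stays free.
    with ProcessPoolExecutor(3) as pool:
        futures = [pool.submit(render_video, f"sample_{i}") for i in range(4)]
        print([f.result() for f in futures])
```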
@@ -28,11 +56,13 @@ def predict(texts, melodies):
         MODEL = load_model()
 
     duration = 12
+    max_text_length = 512
+    texts = [text[:max_text_length] for text in texts]
     MODEL.set_generation_params(duration=duration)
 
-    print(texts, melodies)
+    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
+    be = time.time()
     processed_melodies = []
-
     target_sr = 32000
     target_ac = 1
     for melody in melodies:
@@ -60,73 +90,133 @@ def predict(texts, melodies):
             audio_write(
                 file.name, output, MODEL.sample_rate, strategy="loudness",
                 loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
+            out_files.append(pool.submit(make_waveform, file.name))
+    res = [[out_file.result() for out_file in out_files]]
+    print("batch finished", len(texts), time.time() - be)
+    return res
+
+
+def ui(**kwargs):
+    with gr.Blocks() as demo:
+        gr.Markdown(
+            """
+            # MusicGen
+
+            This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
+            presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
+            <br/>
+            <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
+            <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+            for longer sequences, more control and no queue.</p>
+            """
+        )
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    text = gr.Text(label="Describe your music", lines=2, interactive=True)
+                    melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
+                with gr.Row():
+                    submit = gr.Button("Generate")
+            with gr.Column():
+                output = gr.Video(label="Generated Music")
+        submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=8)
+        gr.Examples(
+            fn=predict,
+            examples=[
+                [
+                    "An 80s driving pop song with heavy drums and synth pads in the background",
+                    "./assets/bach.mp3",
+                ],
+                [
+                    "A cheerful country song with acoustic guitars",
+                    "./assets/bolero_ravel.mp3",
+                ],
+                [
+                    "90s rock song with electric guitar and heavy drums",
+                    None,
+                ],
+                [
+                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
+                    "./assets/bach.mp3",
+                ],
+                [
+                    "lofi slow bpm electro chill with organic samples",
+                    None,
+                ],
             ],
+            inputs=[text, melody],
+            outputs=[output]
+        )
+        gr.Markdown("""
+        ### More details
+
+        The model will generate 12 seconds of audio based on the description you provided.
+        You can optionally provide a reference audio from which a broad melody will be extracted.
+        The model will then try to follow both the description and melody provided.
+        All samples are generated with the `melody` model.
+
+        You can also use your own GPU or a Google Colab by following the instructions on our repo.
+
+        See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
+        for more details.
+        """)
+
+        # Show the interface
+        launch_kwargs = {}
+        username = kwargs.get('username')
+        password = kwargs.get('password')
+        server_port = kwargs.get('server_port', 0)
+        inbrowser = kwargs.get('inbrowser', False)
+        share = kwargs.get('share', False)
+        server_name = kwargs.get('listen')
+
+        launch_kwargs['server_name'] = server_name
+
+        if username and password:
+            launch_kwargs['auth'] = (username, password)
+        if server_port > 0:
+            launch_kwargs['server_port'] = server_port
+        if inbrowser:
+            launch_kwargs['inbrowser'] = inbrowser
+        if share:
+            launch_kwargs['share'] = share
+        demo.queue(max_size=8 * 4).launch(**launch_kwargs)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--listen',
+        type=str,
+        default='127.0.0.1',
+        help='IP to listen on for connections to Gradio',
+    )
+    parser.add_argument(
+        '--username', type=str, default='', help='Username for authentication'
+    )
+    parser.add_argument(
+        '--password', type=str, default='', help='Password for authentication'
+    )
+    parser.add_argument(
+        '--server_port',
+        type=int,
+        default=0,
+        help='Port to run the server listener on',
+    )
+    parser.add_argument(
+        '--inbrowser', action='store_true', help='Open in browser'
+    )
+    parser.add_argument(
+        '--share', action='store_true', help='Share the gradio UI'
     )
-    gr.Markdown("""
-    ### More details
-
-    The model will generate 12 seconds of audio based on the description you provided.
-    You can optionally provide a reference audio from which a broad melody will be extracted.
-    The model will then try to follow both the description and melody provided.
-    All samples are generated with the `melody` model.
-
-    You can also use your own GPU or a Google Colab by following the instructions on our repo.
 
-    See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
-    for more details.
-    """)
+    args = parser.parse_args()
 
-
+    ui(
+        username=args.username,
+        password=args.password,
+        inbrowser=args.inbrowser,
+        server_port=args.server_port,
+        share=args.share,
+        listen=args.listen
+    )
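The rewritten UI leans on Gradio's batched-inference contract: with `batch=True`, the handler receives a list per input component and must return a list per output component (hence `res = [[...]]` in `predict`), and batching only engages through the queue enabled by `demo.queue(...)`. A minimal sketch of that contract (the components and handler here are placeholders, not the app's):

```python
import gradio as gr

def predict(texts):
    # With batch=True, `texts` is a list of up to max_batch_size strings;
    # return one list per output component, so a single output nests twice.
    return [[t.upper() for t in texts]]

with gr.Blocks() as demo:
    text = gr.Text(label="Input")
    out = gr.Text(label="Output")
    gr.Button("Run").click(predict, inputs=[text], outputs=[out],
                           batch=True, max_batch_size=8)

# Batching requires the queue; max_size bounds waiting requests.
demo.queue(max_size=8 * 4)
# demo.launch()
```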
    	
audiocraft/models/musicgen.py
CHANGED

@@ -96,7 +96,7 @@ class MusicGen:
     def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                               top_p: float = 0.0, temperature: float = 1.0,
                               duration: float = 30.0, cfg_coef: float = 3.0,
-                              two_step_cfg: bool = False):
+                              two_step_cfg: bool = False, extend_stride: float = 15):
         """Set the generation parameters for MusicGen.
 
         Args:
@@ -109,8 +109,13 @@ class MusicGen:
             two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
                 instead of batching together the two. This has some impact on how things
                 are padded but seems to have little impact in practice.
+            extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
+                should we extend the audio each time. Larger values will mean less context is
+                preserved, and shorter values will require extra computation.
         """
-        assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
+        # assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
+        assert extend_stride <= 25, "Keep at least 5 seconds of overlap!"
+        self.extend_stride = extend_stride
         self.generation_params = {
             'max_gen_len': int(duration * self.frame_rate),
             'use_sampling': use_sampling,
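The new `extend_stride` parameter drives windowed long-form generation: each extension step advances by `extend_stride` seconds and keeps the remaining `30 - extend_stride` seconds of the previous window as context, so the `extend_stride <= 25` assert guarantees at least 5 seconds of overlap. A sketch of the resulting window schedule (illustrative; not audiocraft's internal loop):

```python
WINDOW = 30.0  # maximum length the model generates in one pass

def extension_schedule(duration: float, extend_stride: float = 15.0):
    # Mirrors the assert added above: overlap = WINDOW - extend_stride >= 5.
    assert extend_stride <= 25, "Keep at least 5 seconds of overlap!"
    starts = [0.0]
    while starts[-1] + WINDOW < duration:
        starts.append(starts[-1] + extend_stride)
    return [(s, min(s + WINDOW, duration)) for s in starts]

print(extension_schedule(60))
# [(0.0, 30.0), (15.0, 45.0), (30.0, 60.0)] -> 15 s reused per step
```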
    	
audiocraft/modules/transformer.py
CHANGED

@@ -247,20 +247,20 @@ class StreamingMultiheadAttention(StreamingModule):
         # Complete the key/value pair using the streaming state.
         if self._streaming_state:
             pk = self._streaming_state['past_keys']
-            nk = torch.cat([pk, k], dim=1)
+            nk = torch.cat([pk, k], dim=2)
             if v is k:
                 nv = nk
             else:
                 pv = self._streaming_state['past_values']
-                nv = torch.cat([pv, v], dim=1)
+                nv = torch.cat([pv, v], dim=2)
         else:
             nk = k
             nv = v
 
-        assert nk.shape[1] == nv.shape[1]
+        assert nk.shape[2] == nv.shape[2]
         offset = 0
         if self.past_context is not None:
-            offset = max(0, nk.shape[1] - self.past_context)
+            offset = max(0, nk.shape[2] - self.past_context)
         if self._is_streaming:
             self._streaming_state['past_keys'] = nk[:, offset:]
             if v is not k:
@@ -271,6 +271,7 @@ class StreamingMultiheadAttention(StreamingModule):
                 self._streaming_state['offset'] = torch.tensor(0)
         return nk, nv
 
+
     def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
         # Apply rope embeddings to query and key tensors.
         assert self.rope is not None
@@ -325,7 +326,7 @@ class StreamingMultiheadAttention(StreamingModule):
                     q = self.q_layer_norm(q)
                     k = self.k_layer_norm(k)
                 # q, k, v = [rearrange(x, "b t (h d) -> (b h) t d", h=self.num_heads) for x in [q, k, v]]
-                q, k, v = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k, v]]
+                q, k, v = [rearrange(x, "b t (h d) -> b h t d", h=self.num_heads) for x in [q, k, v]]
             else:
                 if not _is_profiled():
                     # profiling breaks that property somehow.
@@ -333,7 +334,7 @@ class StreamingMultiheadAttention(StreamingModule):
                     assert value is key, "specialized implementation"
                 projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                 if self.kv_repeat == 1:
-                    packed = rearrange(projected, "b t (p h d) -> b t p h d", p=3, h=self.num_heads)
+                    packed = rearrange(projected, "b t (p h d) -> b h p t d", p=3, h=self.num_heads)
                     q, k, v = ops.unbind(packed, dim=2)
                 else:
                     embed_dim = self.embed_dim
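The `dim=1` to `dim=2` changes in `_complete_kv` above follow the layout switch from `b t h d` to `b h t d`: the streaming cache concatenates new keys/values along the time axis and trims anything beyond `past_context`. An illustrative sketch in the new layout (simplified; the real code keeps the cache in `self._streaming_state`):

```python
import torch

past_k = torch.randn(1, 8, 10, 64)  # cached keys: (batch, heads, time, head_dim)
k = torch.randn(1, 8, 2, 64)        # keys for the current step(s)

nk = torch.cat([past_k, k], dim=2)  # append along the time axis
past_context = 8
offset = max(0, nk.shape[2] - past_context)
nk = nk[:, :, offset:]              # keep only the last past_context steps
print(nk.shape)  # torch.Size([1, 8, 8, 64])
```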
@@ -355,6 +356,7 @@ class StreamingMultiheadAttention(StreamingModule):
                     k = self.k_layer_norm(k)
                     q, k = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k]]
                 if self.rope:
+                    assert False, "Not supported for now"
                     q, k = self._apply_rope(q, k)
                 k, v = self._complete_kv(k, v)
                 if self.kv_repeat > 1:
@@ -364,7 +366,8 @@ class StreamingMultiheadAttention(StreamingModule):
                 q, k, v = [x.float() for x in [q, k, v]]
             if self.memory_efficient:
                 p = self.dropout if self.training else 0
-                x = ops.memory_efficient_attention(q, k, v, attn_mask, p)
+                x = torch.nn.functional.scaled_dot_product_attention(
+                    q, k, v, is_causal=attn_mask is not None, dropout_p=p)
             else:
                 # We include the dot product as float32, for consistency
                 # with the other implementations that include that step
@@ -385,7 +388,7 @@ class StreamingMultiheadAttention(StreamingModule):
                 w = F.dropout(w, self.dropout, training=self.training).to(v)
                 x = torch.einsum("bhqk,bkhc->bqhc", w, v)
             x = x.to(dtype)
-            x = rearrange(x, "b t h d -> b t (h d)", h=self.num_heads)
+            x = rearrange(x, "b h t d -> b t (h d)", h=self.num_heads)
             x = self.out_proj(x)
         else:
             key, value = self._complete_kv(key, value)
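The memory-efficient path now calls `torch.nn.functional.scaled_dot_product_attention`, which expects `(batch, heads, time, head_dim)`; that is what forces every rearrange pattern from `b t h d` to `b h t d`. A standalone check of the layout round-trip and of agreement with explicit attention (assumes PyTorch >= 2.0 and einops; not part of the commit):

```python
import torch
from einops import rearrange

b, t, h, d = 2, 16, 8, 64
x = torch.randn(b, t, h * d)

# The new projection layout: heads become a leading axis.
q = k = v = rearrange(x, "b t (h d) -> b h t d", h=h)
fused = torch.nn.functional.scaled_dot_product_attention(q, k, v)

# Reference: explicit softmax attention in the same layout.
w = torch.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim=-1)
ref = w @ v
print(torch.allclose(fused, ref, atol=1e-5))  # True, up to kernel precision

# Back to the flat layout consumed by out_proj.
out = rearrange(fused, "b h t d -> b t (h d)")
print(out.shape)  # torch.Size([2, 16, 512])
```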
 
			
