# jam_worker.py
import threading
import time
from dataclasses import dataclass
from typing import Any

import numpy as np

# Pull in your helpers from app.py or refactor them into a shared utils module.
from utils import (
    match_loudness_to_reference, stitch_generated, hard_trim_seconds,
    apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
    resample_and_snap, wav_bytes_base64
)


@dataclass
class JamParams:
    bpm: float
    beats_per_bar: int
    bars_per_chunk: int
    target_sr: int
    loudness_mode: str = "auto"
    headroom_db: float = 1.0
    style_vec: np.ndarray | None = None
    ref_loop: Any = None       # au.Waveform at model SR, used for first-chunk loudness matching
    combined_loop: Any = None  # NEW: full combined audio used to set up fresh context
    guidance_weight: float = 1.1
    temperature: float = 1.1
    topk: int = 40


@dataclass
class JamChunk:
    index: int
    audio_base64: str
    metadata: dict


class JamWorker(threading.Thread):
    def __init__(self, mrt, params: JamParams):
        super().__init__(daemon=True)
        self.mrt = mrt
        self.params = params

        # Initialize fresh state
        self.state = mrt.init_state()

        # CRITICAL: set up fresh context from the new combined audio
        if params.combined_loop is not None:
            self._setup_context_from_combined_loop()

        self.idx = 0
        self.outbox: list[JamChunk] = []
        self._stop_event = threading.Event()
        self.last_chunk_started_at = None
        self.last_chunk_completed_at = None
        self._lock = threading.Lock()

    def _setup_context_from_combined_loop(self):
        """Set up MRT context tokens from the combined loop audio."""
        try:
            # Same logic as generate_loop_continuation_with_mrt: derive the
            # context window length (seconds) from the codec frame rate.
            codec_fps = float(self.mrt.codec.frame_rate)
            ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps

            # Take a bar-aligned tail of the loop for context (matches main generation).
            loop_for_context = take_bar_aligned_tail(
                self.params.combined_loop,
                self.params.bpm,
                self.params.beats_per_bar,
                ctx_seconds
            )

            # Encode to tokens, keeping only the decoder's RVQ depth.
            tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
            tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]

            # Create bar-aligned context.
            context_tokens = make_bar_aligned_context(
                tokens,
                bpm=self.params.bpm,
                fps=int(self.mrt.codec.frame_rate),
                ctx_frames=self.mrt.config.context_length_frames,
                beats_per_bar=self.params.beats_per_bar
            )

            # Set context on state - this is the key fix!
            self.state.context_tokens = context_tokens
            print("✅ JamWorker: Set up fresh context from combined loop")
            print(f"   Context shape: {context_tokens.shape if context_tokens is not None else None}")
        except Exception as e:
            print(f"❌ Failed to set up context from combined loop: {e}")
            # Continue without context rather than crashing.

    def stop(self):
        self._stop_event.set()

    def update_style(self, new_style_vec: np.ndarray | None):
        with self._lock:
            if new_style_vec is not None:
                self.params.style_vec = new_style_vec

    def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
        with self._lock:
            if guidance_weight is not None:
                self.params.guidance_weight = float(guidance_weight)
            if temperature is not None:
                self.params.temperature = float(temperature)
            if topk is not None:
                self.params.topk = int(topk)

    def _seconds_per_bar(self) -> float:
        return self.params.beats_per_bar * (60.0 / self.params.bpm)

    def _snap_and_encode(self, y, seconds, target_sr, bars):
        cur_sr = int(self.mrt.sample_rate)
        x = y.samples if y.samples.ndim == 2 else y.samples[:, None]
        x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
        b64, total_samples, channels = wav_bytes_base64(x, target_sr)
        meta = {
            "bpm": int(round(self.params.bpm)),
            "bars": int(bars),
            "beats_per_bar": int(self.params.beats_per_bar),
            "sample_rate": int(target_sr),
            "channels": channels,
            "total_samples": total_samples,
            "seconds_per_bar": self._seconds_per_bar(),
            "loop_duration_seconds": bars * self._seconds_per_bar(),
            "guidance_weight": self.params.guidance_weight,
            "temperature": self.params.temperature,
            "topk": self.params.topk,
        }
        return b64, meta

    def run(self):
        spb = self._seconds_per_bar()
        chunk_secs = self.params.bars_per_chunk * spb
        xfade = self.mrt.config.crossfade_length

        # Prime: initial context should already be set on the state (the caller
        # or __init__ handles this; it is safe to re-set).
        # NOTE: we assume the caller passed a style_vec computed from the tail,
        # the whole loop, or a blend of the two.
        while not self._stop_event.is_set():
            # Honor live knob updates atomically.
            with self._lock:
                style_vec = self.params.style_vec
                # Override the MRT sampling knobs for this iteration. Note that
                # this mutates the shared MRT object, not thread-local copies.
                self.mrt.guidance_weight = self.params.guidance_weight
                self.mrt.temperature = self.params.temperature
                self.mrt.topk = self.params.topk

            # 1) Generate enough model chunks to cover chunk_secs.
            need = chunk_secs
            chunks = []
            self.last_chunk_started_at = time.time()
            while need > 0 and not self._stop_event.is_set():
                wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
                chunks.append(wav)
                # Subtract the model chunk length (seconds at model SR).
                need -= wav.samples.shape[0] / float(self.mrt.sample_rate)

            if self._stop_event.is_set():
                break

            # 2) Stitch and trim to the exact duration at model SR.
            y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
            y = hard_trim_seconds(y, chunk_secs)

            # 3) Post-process: loudness-match the first chunk to the reference
            #    loop; otherwise just de-click the edges with micro-fades.
            if self.idx == 0 and self.params.ref_loop is not None:
                y, _ = match_loudness_to_reference(
                    self.params.ref_loop, y,
                    method=self.params.loudness_mode,
                    headroom_db=self.params.headroom_db
                )
            else:
                apply_micro_fades(y, 3)

            # 4) Resample, snap to the exact sample count, and base64-encode.
            b64, meta = self._snap_and_encode(
                y,
                seconds=chunk_secs,
                target_sr=self.params.target_sr,
                bars=self.params.bars_per_chunk
            )

            # 5) Enqueue the finished chunk.
            with self._lock:
                self.idx += 1
                self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
                self.last_chunk_completed_at = time.time()

        # Optional: cleanup here if needed.
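
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the worker). Assumes an `mrt`
# instance exposing init_state/generate_chunk/codec/config, plus a
# `style_vec`, `combined_loop`, `ref_loop`, and a `send_to_client` transport
# defined elsewhere in your app; none of those names exist in this module.
#
#   params = JamParams(
#       bpm=120.0, beats_per_bar=4, bars_per_chunk=8, target_sr=48_000,
#       style_vec=style_vec, combined_loop=combined_loop, ref_loop=ref_loop,
#   )
#   worker = JamWorker(mrt, params)
#   worker.start()
#
#   served = 0
#   while served < 4:  # stream the first four chunks
#       with worker._lock:          # snapshot under the worker's lock
#           fresh = worker.outbox[served:]
#       for chunk in fresh:
#           send_to_client(chunk.audio_base64, chunk.metadata)
#           served += 1
#       time.sleep(0.25)
#
#   worker.update_knobs(temperature=1.3)  # live knob changes apply next chunk
#   worker.stop()
#   worker.join()
# ---------------------------------------------------------------------------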