import threading
from typing import List

import av
import numpy as np
import pydub
import ray

from webrtc_av_queue_actor import WebRtcAVQueueActor
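
# The Ray actor imported above lives in webrtc_av_queue_actor.py; the calls in
# this class assume roughly the following contract (a sketch inferred from the
# call sites below, not the actual implementation):
#   enqueue_in_video_frame(ref) / get_in_video_frames() -> list of queued frames
#   enqueue_in_audio_frame(ref) / get_in_audio_frames() -> list of queued buffers
#   get_out_audio_frame() -> bytes | None  (one frame of mono s16 PCM)
#   get_out_audio_queue() -> handle to the output audio queue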

class StreamlitAVQueue:
    def __init__(self, audio_bit_rate=16000):
        # Despite the name, audio_bit_rate is used as the target sample
        # rate (Hz) when incoming audio is resampled.
        self._output_channels = 2
        self._audio_bit_rate = audio_bit_rate
        self._listening = True
        self._lock = threading.Lock()
        # Look up (or create) a single named actor so every session shares
        # the same queues.
        self.queue_actor = WebRtcAVQueueActor.options(
            name="WebRtcAVQueueActor",
            get_if_exists=True,
        ).remote()
        
    def set_listening(self, listening: bool):
        with self._lock:
            self._listening = listening
            
    async def queued_video_frames_callback(
                self,
                frames: List[av.VideoFrame],
            ) -> List[av.VideoFrame]:
        try:
            for frame in frames:
                # Put the RGB pixel data into the Ray object store and hand
                # only the reference to the queue actor.
                shared_tensor = frame.to_ndarray(format="rgb24")
                shared_tensor_ref = ray.put(shared_tensor)
                await self.queue_actor.enqueue_in_video_frame.remote(shared_tensor_ref)
        except Exception as e:
            print(e)
        return frames

    async def queued_audio_frames_callback(
                self,
                frames: List[av.AudioFrame],
            ) -> List[av.AudioFrame]:
        try:
            with self._lock:
                should_listen = self._listening
            sound_chunk = pydub.AudioSegment.empty()
            if len(frames) > 0 and should_listen:
                for frame in frames:
                    sound = pydub.AudioSegment(
                        data=frame.to_ndarray().tobytes(),
                        sample_width=frame.format.bytes,
                        frame_rate=frame.sample_rate,
                        channels=len(frame.layout.channels),
                    )
                    # Downmix to mono and resample to the target rate before
                    # handing the chunk to the actor.
                    sound = sound.set_channels(1)
                    sound = sound.set_frame_rate(self._audio_bit_rate)
                    sound_chunk += sound
                shared_buffer = np.array(sound_chunk.get_array_of_samples())
                shared_buffer_ref = ray.put(shared_buffer)
                await self.queue_actor.enqueue_in_audio_frame.remote(shared_buffer_ref)
        except Exception as e:
            print(e)
            
        # Return fresh frames (actor output audio, or silence when none is
        # queued) instead of echoing the input back to the browser.
        new_frames = []
        try:
            for frame in frames:
                required_samples = frame.samples
                assert frame.format.bytes == 2
                assert frame.format.name == 's16'
                frame_as_bytes = await self.queue_actor.get_out_audio_frame.remote()
                if frame_as_bytes:
                    assert len(frame_as_bytes) == frame.samples * frame.format.bytes
                    samples = np.frombuffer(frame_as_bytes, dtype=np.int16)
                else:
                    # No output pending: emit one frame of mono silence, sized
                    # to match the populated branch above (the original
                    # required_samples * 2 * 1 produced a frame twice as long).
                    samples = np.zeros(required_samples, dtype=np.int16)
                if self._output_channels == 2:
                    # Interleave the mono signal into packed L/R pairs:
                    # [a, b, c] -> [a, a, b, b, c, c].
                    samples = np.vstack((samples, samples)).reshape((-1,), order='F')
                samples = samples.reshape(1, -1)
                layout = 'stereo' if self._output_channels == 2 else 'mono'
                new_frame = av.AudioFrame.from_ndarray(samples, format='s16', layout=layout)
                new_frame.sample_rate = frame.sample_rate
                new_frames.append(new_frame)
        except Exception as e:
            print(e)
        return new_frames

    async def get_in_audio_frames_async(self) -> List[np.ndarray]:
        shared_buffers = await self.queue_actor.get_in_audio_frames.remote()
        return shared_buffers

    async def get_video_frames_async(self) -> List[np.ndarray]:
        shared_tensors = await self.queue_actor.get_in_video_frames.remote()
        return shared_tensors
    
    def get_out_audio_queue(self):
        # Note: returns a Ray ObjectRef; callers must await/ray.get() it to
        # obtain the actual queue handle.
        return self.queue_actor.get_out_audio_queue.remote()
    
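

# --- Example wiring (a minimal sketch, not part of the original module) ---
# Assumes streamlit-webrtc's webrtc_streamer with its async
# queued_*_frames_callback hooks (which the two callbacks above are shaped
# for) and that Ray is initialized here; adjust to your app's entry point.
# Run with: streamlit run <this_file>.py
if __name__ == "__main__":
    import streamlit as st
    from streamlit_webrtc import WebRtcMode, webrtc_streamer

    ray.init(ignore_reinit_error=True)
    av_queue = StreamlitAVQueue()
    av_queue.set_listening(st.checkbox("Listen", value=True))
    webrtc_streamer(
        key="streamlit-av-queue",
        mode=WebRtcMode.SENDRECV,
        queued_video_frames_callback=av_queue.queued_video_frames_callback,
        queued_audio_frames_callback=av_queue.queued_audio_frames_callback,
        async_processing=True,
    )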