import argparse
import uvicorn
import sys
import json
import string
import random
import base64


from fastapi import FastAPI, Response
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from utils.logger import logger
from networks.message_streamer import MessageStreamer
from messagers.message_composer import MessageComposer
from googletrans import Translator
from io import BytesIO
from gtts import gTTS
from fastapi.middleware.cors import CORSMiddleware


class ChatAPIApp:
    """FastAPI application for translation, language detection and TTS.

    Every endpoint is registered twice, at the root and under ``/v1``
    (see ``setup_routes``); the Swagger UI is served at ``/``.
    """

    # Language catalogue read by /models and /translate. The path is relative,
    # so it is resolved against the process working directory.
    LANGUAGES_PATH = "apis/lang_name.json"

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="HuggingFace LLM API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()

    def _load_languages(self):
        # Parse the language catalogue JSON. `with` closes the file handle (the
        # previous bare open() calls leaked one handle per request), and UTF-8
        # is stated explicitly rather than relying on the platform default.
        with open(self.LANGUAGES_PATH, "r", encoding="utf-8") as f:
            return json.load(f)

    def get_available_models(self):
        """Return the catalogue of supported languages."""
        # Also stored on the instance, as before, for any caller that reads
        # self.available_models afterwards.
        self.available_models = self._load_languages()
        return self.available_models

    # Request body for /translate. (No class docstring on purpose: FastAPI
    # would publish it as the schema description.)
    class ChatCompletionsPostItem(BaseModel):
        from_language: str = Field(
            default="auto",
            description="(str) `Detect`",
        )
        to_language: str = Field(
            default="en",
            description="(str) `en`",
        )
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for translate`",
        )

    def chat_completions(self, item: ChatCompletionsPostItem):
        """Translate `input_text` into `to_language`.

        An unknown `to_language` code falls back to "en". `from_language` is
        used as the source language when it is a known code; otherwise the
        source language is auto-detected.
        """
        # Entries are dicts with a "code" key (as the original lookup assumed).
        known_codes = {lang_item["code"] for lang_item in self._load_languages()}
        to_lang = item.to_language if item.to_language in known_codes else "en"
        # from_language used to be accepted but silently ignored (the source was
        # always auto-detected). Honour it when valid; its default "auto" keeps
        # the previous behaviour.
        from_lang = (
            item.from_language if item.from_language in known_codes else "auto"
        )

        translated = Translator().translate(
            item.input_text, src=from_lang, dest=to_lang
        )
        item_response = {
            "from_language": translated.src,
            "to_language": translated.dest,
            "text": item.input_text,
            "translate": translated.text,
        }
        return JSONResponse(content=jsonable_encoder(item_response))

    # Request body for /detect.
    class DetectLanguagePostItem(BaseModel):
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for detection`",
        )

    def detect_language(self, item: DetectLanguagePostItem):
        """Detect the language of `input_text`."""
        detected = Translator().detect(item.input_text)
        item_response = {
            "lang": detected.lang,
            "confidence": detected.confidence,
        }
        return JSONResponse(content=jsonable_encoder(item_response))

    # Request body for /tts.
    class TTSPostItem(BaseModel):
        input_text: str = Field(
            default="Hello",
            description="(str) `Text for TTS`",
        )
        from_language: str = Field(
            default="en",
            description="(str) `TTS language`",
        )

    def text_to_speech(self, item: TTSPostItem):
        """Synthesize `input_text` as MP3 audio in `from_language`.

        On failure, responds with HTTP 400 and the body {"status": 400}.
        """
        # gTTSError is re-exported by the gtts package. The previous handler
        # named `gtts.tts.gTTSError`, but `gtts` itself was never imported, so
        # every failure surfaced as a NameError instead of reaching the handler.
        from gtts import gTTSError

        mp3_fp = BytesIO()
        try:
            audioobj = gTTS(text=item.input_text, lang=item.from_language, slow=False)
            audioobj.write_to_fp(mp3_fp)
        except (gTTSError, ValueError, AssertionError):
            # Per the gTTS docs: ValueError = unsupported language,
            # AssertionError = nothing to speak (empty text),
            # gTTSError = the upstream TTS request failed.
            item_response = {"status": 400}
            return JSONResponse(
                content=jsonable_encoder(item_response), status_code=400
            )
        # getvalue() returns the whole buffer regardless of stream position.
        # The previous code passed mp3_fp.tell() (an int) as the body, which
        # Starlette cannot render.
        return Response(content=mp3_fp.getvalue(), media_type="audio/mpeg")

    def setup_routes(self):
        # Register every handler both at the root and under the "/v1" prefix.
        for prefix in ["", "/v1"]:
            self.app.get(
                prefix + "/models",
                summary="Get available languages",
            )(self.get_available_models)

            self.app.post(
                prefix + "/translate",
                summary="translate text",
            )(self.chat_completions)

            self.app.post(
                prefix + "/detect",
                summary="detect language",
            )(self.detect_language)

            self.app.post(
                prefix + "/tts",
                summary="text to speech",
            )(self.text_to_speech)


class ArgParser(argparse.ArgumentParser):
    """Command-line options for launching the API server.

    The process command line (``sys.argv[1:]``) is parsed as soon as the
    parser is constructed; the resulting namespace is exposed as ``self.args``.
    Any positional/keyword arguments are forwarded to ``ArgumentParser``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Host address handed to uvicorn.
        self.add_argument(
            "-s", "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for HF LLM Chat API",
        )
        # Port handed to uvicorn.
        self.add_argument(
            "-p", "--port",
            type=int,
            default=23333,
            help="Server Port for HF LLM Chat API",
        )
        # Boolean flag; the entry point turns on uvicorn auto-reload with it.
        self.add_argument(
            "-d", "--dev",
            action="store_true",
            default=False,
            help="Run in dev mode",
        )

        cli_tokens = sys.argv[1:]
        self.args = self.parse_args(cli_tokens)


app = ChatAPIApp().app

# Allow cross-origin browser requests from any origin, with any method/header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    args = ArgParser().args
    # `--dev` is a store_true flag (always a bool), so it maps directly onto
    # uvicorn's auto-reload switch; all other settings are the same.
    uvicorn.run(
        "__main__:app",
        host=args.server,
        port=args.port,
        reload=args.dev,
    )

    # python -m apis.chat_api      # [Docker] on product mode
    # python -m apis.chat_api -d   # [Dev]    on develop mode