# Copyright 2022 Tristan Behrens.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3

from flask import Flask, render_template, request, send_file, jsonify, redirect, url_for
from PIL import Image
import os
import io
import random
import base64
import torch
import wave
from source.logging import create_logger
from source.tokensequence import token_sequence_to_audio, token_sequence_to_image
from source import constants
from transformers import AutoTokenizer, AutoModelForCausalLM

logger = create_logger(__name__)

# Auth token for the Hugging Face Hub, read from the "authtoken" environment
# variable (None when the variable is unset).
auth_token = os.getenv("authtoken")

# Hub id of the pretrained checkpoint. Tokenizer and model are loaded from the
# same repository so their vocabularies match.
MODEL_ID = "ai-guru/lakhclean_mmmtrack_4bars_d-2048"

# Loading the model and its tokenizer.
# NOTE(review): newer transformers releases deprecate `use_auth_token` in favor
# of `token`; kept here for compatibility with the installed version.
logger.info("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_auth_token=auth_token)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, use_auth_token=auth_token)
logger.info("Done.")

# Create the app.
logger.info("Creating app...")
app = Flask(__name__)
logger.info("Done.")

# Route for the loading page.
@app.route("/")
def index():
    """Render ``index.html`` with the selectable styles, densities and temperatures."""
    ui_choices = {
        "compose_styles": constants.get_compose_styles_for_ui(),
        "densities": constants.get_densities_for_ui(),
        "temperatures": constants.get_temperatures_for_ui(),
    }
    return render_template("index.html", **ui_choices)


@app.route("/compose", methods=["POST"])
def compose():
    """Generate a piece from the request parameters and return tokens, audio and image.

    Expects a JSON object with the keys ``music_style``, ``density`` and
    ``temperature``. Each value is mapped through the matching
    ``constants.get_*`` helper before generation.

    Returns:
        On success, a JSON response with ``tokens`` (the decoded token
        sequence), ``audio`` (a ``data:audio/wav;base64,...`` URL), ``image``
        (a ``data:image/png;base64,...`` URL) and ``status`` set to "OK".
        If the body is not a JSON object or a required key is missing, a JSON
        error with HTTP status 400 is returned instead.
    """
    # Get the parameters as JSON and validate them up front, so a bad request
    # gets a 400 rather than an unhandled KeyError (HTTP 500).
    params = request.get_json()
    if not isinstance(params, dict):
        return jsonify({"status": "ERROR", "message": "Expected a JSON object."}), 400
    missing = [key for key in ("music_style", "density", "temperature") if key not in params]
    if missing:
        return jsonify({"status": "ERROR", "message": f"Missing parameters: {', '.join(missing)}"}), 400

    instruments = constants.get_instruments(params["music_style"])
    density = constants.get_density(params["density"])
    temperature = constants.get_temperature(params["temperature"])
    logger.info(f"instruments: {instruments} density: {density} temperature: {temperature}")

    # Generate with the given parameters.
    logger.info("Generating token sequence...")
    generated_sequence = generate_sequence(instruments, density, temperature)
    logger.info(f"Generated token sequence: {generated_sequence}")

    # Get the audio data (documented upstream as an array of int16 samples).
    logger.info("Generating audio...")
    sample_rate, audio_data = token_sequence_to_audio(generated_sequence)
    logger.info(f"Done. Audio data: {len(audio_data)}")
    audio_data_base64 = _encode_wav_base64(sample_rate, audio_data)

    # Render the generated token sequence as a PIL image.
    image = token_sequence_to_image(generated_sequence)

    # Save PIL image to harddrive as PNG.
    # NOTE(review): every request overwrites the same file in the working
    # directory, so concurrent requests race on it; kept in case something
    # outside this view reads it.
    logger.debug(f"Saving image to harddrive... {type(image)}")
    image.save("compose.png", "PNG")
    image_data_base64 = _encode_png_base64(image)

    return jsonify({
        "tokens": generated_sequence,
        "audio": "data:audio/wav;base64," + audio_data_base64,
        "image": "data:image/png;base64," + image_data_base64,
        "status": "OK"
    })


def _encode_wav_base64(sample_rate, audio_data):
    """Write mono 16-bit samples into an in-memory WAV file and return it base64-encoded."""
    buffer = io.BytesIO()
    # The context manager finalizes the WAV header even if writing fails midway;
    # it does not close `buffer`, because wave did not open it.
    with wave.open(buffer, "wb") as wave_file:
        wave_file.setframerate(sample_rate)
        wave_file.setnchannels(1)
        wave_file.setsampwidth(2)
        wave_file.writeframes(audio_data)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _encode_png_base64(image):
    """Encode a PIL image as PNG in memory and return it base64-encoded."""
    buffer = io.BytesIO()
    # `quality` is a JPEG option that Pillow's PNG writer ignores, so it is omitted.
    image.save(buffer, "PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def generate_sequence(instruments, density, temperature, max_length=2048):
    """Sample a multi-track token sequence, one track per instrument.

    The instruments are visited in random order. For each one, a
    ``TRACK_START INST=<instrument> DENSITY=<density>`` prompt is appended to
    everything generated so far, and the model samples a continuation. Sampling
    stops at the ``TRACK_END`` token or when the sequence reaches
    ``max_length`` tokens.

    Args:
        instruments: Iterable of instrument identifiers. It is copied first, so
            the caller's object is never shuffled in place (tuples work too).
        density: Value inserted into every ``TRACK_START`` prompt.
        temperature: Sampling temperature forwarded to ``model.generate``.
        max_length: Maximum total sequence length in tokens, prompt included,
            forwarded to ``model.generate``. Defaults to 2048 as before.

    Returns:
        The decoded token sequence as a string.
    """
    # Copy into a list so shuffling works for any iterable and never mutates
    # the caller's object.
    instruments = list(instruments)
    random.shuffle(instruments)

    # Loop-invariant: look up the track terminator id once.
    # NOTE(review): assumes "TRACK_END" encodes to a single token; only the
    # first id is used.
    track_end_id = tokenizer.encode("TRACK_END")[0]

    generated_ids = tokenizer.encode("PIECE_START", return_tensors="pt")[0]

    for instrument in instruments:
        prompt_ids = tokenizer.encode(f"TRACK_START INST={instrument} DENSITY={density}", return_tensors="pt")[0]
        # Add a batch dimension of 1 for generate(); [0] below removes it again.
        input_ids = torch.cat((generated_ids, prompt_ids)).unsqueeze(0)
        generated_ids = model.generate(
            input_ids,
            max_length=max_length,
            do_sample=True,
            temperature=temperature,
            eos_token_id=track_end_id,
        )[0]

    return tokenizer.decode(generated_ids)


if __name__ == "__main__":
    # Start Flask's built-in server, listening on all interfaces at port 7860.
    server_host, server_port = "0.0.0.0", 7860
    app.run(host=server_host, port=server_port)