import sys
from typing import List
import traceback
import os
import base64
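# A `use_normal_tokenizers` marker file in the working directory selects the
# 1B model on CPU with the stock tokenizers library; otherwise the patched
# tokenizers module and the 6B model on GPU are used.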
# needs to be imported *before* transformers
if os.path.exists('use_normal_tokenizers'):
    import tokenizers
    BIG_MODEL = False
    CUDA = False
else:
    import tokenizers_patch
    BIG_MODEL = True
    CUDA = True
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import pprint

# Previous Flask setup, superseded by the FastAPI app below:
# from flask import Flask, request, render_template
# from flask_cors import CORS
# app = Flask(__name__, static_folder='static')
# app.config['TEMPLATES_AUTO_RELOAD'] = True
# origins = [f"http://localhost:{PORT}", "https://huggingface.co", "https://hf.space"]
# CORS(app, resources={
#     r"/generate": {"origins": origins},
#     r"/infill": {"origins": origins},
# })

PORT = 7860
VERBOSE = False

MAX_LENGTH = 256+64
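# MAX_LENGTH caps the total token count (prompt plus generation) per request.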
TRUNCATION_MESSAGE = f'Warning: for efficiency, this demo limits the document to {MAX_LENGTH} tokens.'

if BIG_MODEL:
    model_name = "facebook/incoder-6B"
    kwargs = dict(
        revision="float16",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
else:
    model_name = "facebook/incoder-1B"
    kwargs = dict()

from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse
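# interactive API docs (Swagger/ReDoc) are disabled; the demo UI is served from ./static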
app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")


print("loading model")
model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
print("loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("loading complete")

if CUDA:
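    # the 6B checkpoint is already float16 (torch_dtype above), so .half() is a
    # no-op here; move the model to the GPU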
    model = model.half().cuda()

BOS = "<|endoftext|>"
EOM = "<|endofmask|>"
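# BOS marks the start of a document; EOM terminates an infilled span.
# make_sentinel(i) produces the <|mask:i|> token marking the i-th masked span.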

def make_sentinel(i):
    return f"<|mask:{i}|>"

SPECIAL_TOKENS = [make_sentinel(i) for i in range(256)] + [EOM]

def generate(input, length_limit=None, temperature=None):
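    """Sample a continuation of `input` with nucleus sampling (top_p=0.95).

    The combined prompt-plus-generation length is capped at MAX_LENGTH tokens.
    Returns the decoded text (including the prompt, with a leading BOS stripped)
    and a flag indicating whether the request had to be truncated.
    """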
    input_ids = tokenizer(input, return_tensors="pt").input_ids
    if CUDA:
        input_ids = input_ids.cuda()
    current_length = input_ids.flatten().size(0)
    max_length = length_limit + current_length
    truncated = False
    if max_length > MAX_LENGTH:
        max_length = MAX_LENGTH
        truncated = True
    if max_length == current_length:
        return input, True
    output = model.generate(input_ids=input_ids, do_sample=True, top_p=0.95, temperature=temperature, max_length=max_length)
    detok_hypo_str = tokenizer.decode(output.flatten())
    if detok_hypo_str.startswith(BOS):
        detok_hypo_str = detok_hypo_str[len(BOS):]
    return detok_hypo_str, truncated

def infill(parts: List[str], length_limit=None, temperature=None, extra_sentinel=False, max_retries=1):
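    """Infill the gaps between `parts` using InCoder's sentinel scheme.

    The parts are joined with <|mask:k|> sentinels; each sentinel is then
    re-appended to the prompt to generate the corresponding span, which should
    end with EOM. If any span fails to produce EOM, the whole attempt is
    retried up to `max_retries` times. Returns a dict with the restitched
    text, the input parts, the generated infills, the retry count, and a
    truncation flag.
    """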
    assert isinstance(parts, list)
    retries_attempted = 0
    done = False


    while (not done) and (retries_attempted < max_retries):
        any_truncated = False
        retries_attempted += 1
        if VERBOSE:
            print(f"retry {retries_attempted}")
        if len(parts) == 1:
            prompt = parts[0]
        else:
            prompt = ""
            # encode parts separated by sentinel
            for sentinel_ix, part in enumerate(parts):
                prompt += part
                if extra_sentinel or (sentinel_ix < len(parts) - 1):
                    prompt += make_sentinel(sentinel_ix)
            
            # prompt += TokenizerWrapper.make_sentinel(0)
        
        infills = []
        complete = []

        done = True
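        # Generate each span left to right: append its sentinel, sample until
        # EOM, then splice the infill between the surrounding parts.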

        for sentinel_ix, part in enumerate(parts[:-1]):
            complete.append(part)
            prompt += make_sentinel(sentinel_ix)
            completion, this_truncated = generate(prompt, length_limit, temperature)
            any_truncated |= this_truncated
            completion = completion[len(prompt):]
            if EOM not in completion:
                if VERBOSE:
                    print(f"warning: {EOM} not found")
                completion += EOM
                # TODO: break inner loop here
                done = False
            completion = completion[:completion.index(EOM) + len(EOM)]
            infilled = completion[:-len(EOM)]
            infills.append(infilled)
            complete.append(infilled)
            prompt += completion
        complete.append(parts[-1])
        text = ''.join(complete)

    if VERBOSE:
        print("generated text:")
        print(prompt)
        print()
        print("parts:")
        print(parts)
        print()
        print("infills:")
        print(infills)
        print()
        print("restitched text:")
        print(text)
        print()
    
    return {
        'text': text,
        'parts': parts,
        'infills': infills,
        'retries_attempted': retries_attempted,
        'truncated': any_truncated,
    } 


@app.head("/")
@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="static/index.html", media_type="text/html")

@app.get('/generate')
# async def generate_maybe(request: Request):
async def generate_maybe(info: str):
    # form = await info.json()
    # form = await request.json()
    # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
    # fix padding, following https://stackoverflow.com/a/9956217/1319683
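    # Illustrative client-side encoding (assumed usage, not exercised here):
    #   payload = json.dumps({'prompt': 'def fib(n):', 'length': 64, 'temperature': 0.6})
    #   info = base64.urlsafe_b64encode(payload.encode('utf-8')).decode('utf-8')
    #   GET /generate?info=<info>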
    print(info)
    info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
    print(info)
    form = json.loads(info)
    pprint.pprint(form)
    # print(form)
    prompt = form['prompt']
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    if VERBOSE:
        print(prompt)
    try:
        generation, truncated = generate(prompt, length_limit, temperature)
        if truncated:
            message = TRUNCATION_MESSAGE 
        else:
            message = ''
        return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation, 'message': message}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}

@app.get('/infill')
# async def infill_maybe(request: Request):
async def infill_maybe(info: str):
    # form = await info.json()
    # form = await request.json()
    # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
    # fix padding, following https://stackoverflow.com/a/9956217/1319683
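    # Illustrative client-side encoding (assumed usage, not exercised here):
    #   payload = json.dumps({'parts': ['def fib(n):\n', '\n    return result\n'],
    #                         'length': 64, 'temperature': 0.6})
    #   info = base64.urlsafe_b64encode(payload.encode('utf-8')).decode('utf-8')
    #   GET /infill?info=<info>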
    print(info)
    info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
    print(info)
    form = json.loads(info)
    pprint.pprint(form)
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    max_retries = 1
    extra_sentinel = True
    try:
        if len(form['parts']) > 4:
            return {'result': 'error', 'text': ''.join(form['parts']), 'type': 'infill', 'message': f"error: Can't use more than 3 <infill> tokens in this demo (for efficiency)."}
        generation = infill(form['parts'], length_limit, temperature, extra_sentinel=extra_sentinel, max_retries=max_retries)
        generation['result'] = 'success'
        generation['type'] = 'infill'
        if generation['truncated']:
            generation['message'] = TRUNCATION_MESSAGE
        else:
            generation['message'] = ''
        return generation
        # return {'result': 'success', 'prefix': prefix, 'suffix': suffix,  'text': generation['text']}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        print(e)
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}


if __name__ == "__main__":
    # FastAPI apps have no .run() method; serve with uvicorn instead
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=PORT)