from threading import Thread
import falcon
from falcon.http_status import HTTPStatus
import json
import requests
import time
from Model import generate_completion
import sys


class AutoComplete(object):
    """Falcon resource that returns text completions for a prompt.

    Clients POST a JSON object with a required ``context`` string and any of
    the optional generation settings listed in ``_OPTIONS``.  On success the
    response body is JSON of the form
    ``{"sentences": <generate_completion result>, "time": <seconds>}``.
    """

    # (request JSON key, generate_completion keyword, value used when the key is absent)
    _OPTIONS = (
        ("gen_length", "length", 20),
        ("max_time", "max_time", -1),
        ("model_size", "model_name", "small"),
        ("temperature", "temperature", 0.7),
        ("max_tokens", "max_tokens", 256),
        ("top_p", "top_p", 0.95),
        ("top_k", "top_k", 40),
        # CTRL
        ("repetition_penalty", "repetition_penalty", 0.02),
        # PPLM
        ("step_size", "stepsize", 0.02),
        ("bow_or_discrim", "bag_of_words_or_discrim", "kitchen"),
        ("gm_scale", "gm_scale", None),
        ("kl_scale", "kl_scale", None),
        ("num_iterations", "num_iterations", None),
        ("use_sampling", "use_sampling", None),
    )

    @staticmethod
    def _reject(resp, status, message):
        """Set a plain-text error *message* and *status* on *resp*."""
        resp.body = message
        resp.status = status

    def on_post(self, req, resp, single_endpoint=True, x=None, y=None):
        """Generate completions for the JSON payload carried by *req*.

        ``single_endpoint``, ``x`` and ``y`` are accepted so the same resource
        can be mounted on ``/autocomplete/{x}`` and ``/autocomplete/{x}/{y}``;
        they are not used here.

        Responds with:
          * 400 when the body is not valid JSON or is not a JSON object;
          * 422 when ``context`` is missing or is not a string;
          * 200 with ``{"sentences": ..., "time": ...}`` otherwise.
        """
        try:
            json_data = json.loads(req.bounded_stream.read())
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError both derive from ValueError.
            self._reject(resp, falcon.HTTP_400, "The request body must be valid JSON")
            return
        if not isinstance(json_data, dict):
            self._reject(resp, falcon.HTTP_400, "The request body must be a JSON object")
            return

        # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
        start = time.perf_counter()

        if "context" not in json_data:
            self._reject(resp, falcon.HTTP_422, "The context field is required")
            return
        context = json_data["context"]
        if not isinstance(context, str):
            self._reject(resp, falcon.HTTP_422, "The context field must be a string")
            return
        context = context.rstrip()

        # NOTE(review): the "samples" field is accepted by clients but is not
        # forwarded to generate_completion — confirm whether it should be.
        options = {
            kwarg: json_data.get(key, default)
            for key, kwarg, default in self._OPTIONS
        }

        print(json_data)

        sentences = generate_completion(context, **options)

        resp.body = json.dumps({"sentences": sentences, 'time': time.perf_counter() - start})
        resp.status = falcon.HTTP_200
        # Presumably so the print() above shows up promptly when stdout is
        # buffered (e.g. under a WSGI server) — confirm.
        sys.stdout.flush()


class Request(Thread):
    """Background thread that POSTs ``data`` as JSON to ``end_point``.

    Call ``start()`` to send the request, then ``join()`` to wait for it and
    get the response body text.

    Attributes:
        end_point: URL the request is sent to.
        data: JSON-serialisable payload.
        ret: the ``requests`` response once ``run()`` has completed, else None.
        error: exception raised by ``requests.post`` inside ``run()``, else None.
    """

    def __init__(self, end_point, data, request_timeout=None):
        """Store the target URL and payload.

        ``request_timeout`` is forwarded to ``requests.post``; the default
        ``None`` keeps the previous behaviour of waiting indefinitely.
        """
        super().__init__()
        self.end_point = end_point
        self.data = data
        self.request_timeout = request_timeout
        self.ret = None
        self.error = None

    def run(self):
        """Send the POST request and record either the response or the error."""
        print("Requesting with url", self.end_point)
        try:
            self.ret = requests.post(
                url=self.end_point, json=self.data, timeout=self.request_timeout
            )
        except Exception as err:
            # Keep the failure so join() can re-raise it in the caller's thread,
            # instead of join() failing later with an unrelated AttributeError.
            self.error = err

    def join(self, timeout=None):
        """Wait for the request and return the response body text.

        Accepts the same ``timeout`` argument as ``Thread.join``.  Returns
        None if the thread is still running when the timeout expires.

        Raises:
            Exception: whatever ``requests.post`` raised inside ``run()``.
        """
        Thread.join(self, timeout)
        if self.error is not None:
            raise self.error
        if self.ret is None:
            return None
        return self.ret.text


class HandleCORS(object):
    """Falcon middleware that adds permissive CORS headers to every request."""

    def process_request(self, req, resp):
        """Attach wildcard CORS headers and answer OPTIONS preflights directly.

        For an OPTIONS request, an ``HTTPStatus`` carrying 200 and a newline
        body is raised, so no resource responder handles it.
        """
        cors_headers = (
            ('Access-Control-Allow-Origin', '*'),
            ('Access-Control-Allow-Methods', '*'),
            ('Access-Control-Allow-Headers', '*'),
        )
        for header_name, header_value in cors_headers:
            resp.set_header(header_name, header_value)

        is_preflight = req.method == 'OPTIONS'
        if is_preflight:
            raise HTTPStatus(falcon.HTTP_200, body='\n')


# One shared AutoComplete instance serves the bare route and both
# parameterised variants.
autocomplete = AutoComplete()
app = falcon.API(middleware=[HandleCORS()])
for _route in ('/autocomplete', '/autocomplete/{x}', '/autocomplete/{x}/{y}'):
    app.add_route(_route, autocomplete)
del _route  # keep the loop variable out of the module namespace

# Alias presumably for WSGI servers (e.g. mod_wsgi) that look up a callable
# named ``application`` — confirm against the deployment config.
application = app