import base64
import json
import os
import time
import requests
import yaml
import numpy as np
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
from modules.utils import get_available_models
from modules.models import load_model, unload_model
from modules.models_settings import (get_model_settings_from_yamls,
                                     update_model_parameters)

from modules import shared
from modules.text_generation import encode, generate_reply

params = {
    'port': int(os.environ.get('OPENEDAI_PORT', 5001)),
}
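# Example (sketch): override the port and enable debug logging via the environment:
#   OPENEDAI_PORT=5002 OPENEDAI_DEBUG=1 python server.py --extensions openai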

debug = 'OPENEDAI_DEBUG' in os.environ

# Slightly different defaults than OpenAI's API
# Data types matter: use 0.0 (not 0) where a float is expected, since default() coerces by type
default_req_params = {
    'max_new_tokens': 200,
    'temperature': 1.0,
    'top_p': 1.0,
    'top_k': 1,
    'repetition_penalty': 1.18,
    'encoder_repetition_penalty': 1.0,
    'suffix': None,
    'stream': False,
    'echo': False,
    'seed': -1,
    # 'n' : default(body, 'n', 1),  # 'n' doesn't have a direct map
    'truncation_length': 2048,
    'add_bos_token': True,
    'do_sample': True,
    'typical_p': 1.0,
    'epsilon_cutoff': 0.0,  # In units of 1e-4
    'eta_cutoff': 0.0,  # In units of 1e-4
    'tfs': 1.0,
    'top_a': 0.0,
    'min_length': 0,
    'no_repeat_ngram_size': 0,
    'num_beams': 1,
    'penalty_alpha': 0.0,
    'length_penalty': 1.0,
    'early_stopping': False,
    'mirostat_mode': 0,
    'mirostat_tau': 5.0,
    'mirostat_eta': 0.1,
    'ban_eos_token': False,
    'skip_special_tokens': True,
    'custom_stopping_strings': '',
}
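# These keys mirror the generation parameters accepted by generate_reply() in
# modules/text_generation; the handlers below map OpenAI request fields onto them.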

# Optional: install the sentence_transformers module and download the model
# below to enable /v1/embeddings
try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    pass

st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
embedding_model = None

# Helper: return dic[key], falling back to the default when the key is missing or
# the value's type doesn't match the default's (coercing when the conversion round-trips).
def default(dic, key, default):
    val = dic.get(key, default)
    if type(val) != type(default):
        # maybe it's just something like 1 instead of 1.0
        try:
            v = type(default)(val)
            if type(val)(v) == val:  # if it's the same value passed in, it's ok.
                return v
        except Exception:
            pass

        val = default
    return val
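# Illustrative behavior of default() (not executed; follows from the code above):
#   default({'temperature': 1}, 'temperature', 1.0)     -> 1.0  (int 1 coerces losslessly to float)
#   default({'temperature': 'hot'}, 'temperature', 1.0) -> 1.0  (bad type falls back to the default)
#   default({}, 'temperature', 1.0)                     -> 1.0  (missing key)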


def clamp(value, minvalue, maxvalue):
    return max(minvalue, min(value, maxvalue))


def float_list_to_base64(float_list):
    # Convert the list to a float32 array in the format the OpenAI client expects
    float_array = np.array(float_list, dtype="float32")

    # Get raw bytes
    bytes_array = float_array.tobytes()

    # Encode bytes into base64
    encoded_bytes = base64.b64encode(bytes_array)

    # Turn raw base64 encoded bytes into ASCII
    ascii_string = encoded_bytes.decode('ascii')
    return ascii_string
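# A client can recover the floats by reversing these steps, e.g. (sketch):
#   np.frombuffer(base64.b64decode(ascii_string), dtype=np.float32).tolist()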


class Handler(BaseHTTPRequestHandler):
    def send_access_control_headers(self):
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Credentials", "true")
        self.send_header(
            "Access-Control-Allow-Methods",
            "GET,HEAD,OPTIONS,POST,PUT"
        )
        self.send_header(
            "Access-Control-Allow-Headers",
            "Origin, Accept, X-Requested-With, Content-Type, "
            "Access-Control-Request-Method, Access-Control-Request-Headers, "
            "Authorization"
        )

    def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
        self.send_response(code)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        error_resp = {
            'error': {
                'message': message,
                'code': code,
                'type': error_type,
                'param': param,
            }
        }
        if internal_message:
            error_resp['internal_message'] = internal_message

        response = json.dumps(error_resp)
        self.wfile.write(response.encode('utf-8'))

    def do_OPTIONS(self):
        self.send_response(200)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        self.wfile.write("OK".encode('utf-8'))

    def do_GET(self):
        if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
            current_model_list = [shared.model_name]  # The real chat/completions model, maybe "None"
            embeddings_model_list = [st_model] if embedding_model else []  # The real sentence-transformers embeddings model
            pseudo_model_list = [  # many clients hard-code these names, so include a few as placeholders
                'gpt-3.5-turbo',  # /v1/chat/completions
                'text-curie-001',  # /v1/completions, 2k context
                'text-davinci-002'  # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
            ]

            is_legacy = 'engines' in self.path
            is_list = self.path in ['/v1/engines', '/v1/models']

            resp = ''

            if is_legacy and not is_list: # load model
                model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]

                resp = {
                    "id": model_name,
                    "object": "engine",
                    "owner": "self",
                    "ready": True,
                }
                if model_name not in pseudo_model_list + embeddings_model_list + current_model_list:  # Real models only
                    # Loaded with no extra args; it may work anyway.
                    # TODO: hack some heuristics into args for better results

                    shared.model_name = model_name
                    unload_model()

                    model_settings = get_model_settings_from_yamls(shared.model_name)
                    shared.settings.update(model_settings)
                    update_model_parameters(model_settings, initial=True)

                    if shared.settings['mode'] != 'instruct':
                        shared.settings['instruction_template'] = None

                    shared.model, shared.tokenizer = load_model(shared.model_name)

                    if not shared.model: # load failed.
                        shared.model_name = "None"
                        resp['id'] = "None"
                        resp['ready'] = False

            elif is_list:
                # TODO: LoRAs?
                available_model_list = get_available_models()
                all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list

                if is_legacy:
                    models = [{"id": id, "object": "engine", "owner": "user", "ready": True} for id in all_model_list]
                    if not shared.model:
                        models[0]['ready'] = False
                else:
                    models = [{"id": id, "object": "model", "owned_by": "user", "permission": []} for id in all_model_list]

                resp = {
                    "object": "list",
                    "data": models,
                }

            else:
                the_model_name = self.path[len('/v1/models/'):]
                resp = {
                    "id": the_model_name,
                    "object": "model",
                    "owned_by": "user",
                    "permission": []
                }

            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'application/json')
            self.end_headers()
            response = json.dumps(resp)
            self.wfile.write(response.encode('utf-8'))

        elif '/billing/usage' in self.path:
            # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            response = json.dumps({
                "total_usage": 0,
            })
            self.wfile.write(response.encode('utf-8'))

        else:
            self.send_error(404)

    def do_POST(self):
        if debug:
            print(self.headers)  # note: the python-openai client reports your OS and Python version in its request headers
        content_length = int(self.headers['Content-Length'])
        body = json.loads(self.rfile.read(content_length).decode('utf-8'))

        if debug:
            print(body)

        if '/completions' in self.path or '/generate' in self.path:

            if not shared.model:
                self.openai_error("No model loaded.")
                return

            is_legacy = '/generate' in self.path
            is_chat_request = 'chat' in self.path
            resp_list = 'data' if is_legacy else 'choices'

            # XXX model is ignored for now
            # model = body.get('model', shared.model_name) # ignored, use existing for now
            model = shared.model_name
            created_time = int(time.time())

            cmpl_id = f"chatcmpl-{created_time}" if is_chat_request else f"conv-{created_time}"

            # Request Parameters
            # Try to use openai defaults or map them to something with the same intent
            req_params = default_req_params.copy()
            stopping_strings = []

            if 'stop' in body:
                if isinstance(body['stop'], str):
                    stopping_strings.extend([body['stop']])
                elif isinstance(body['stop'], list):
                    stopping_strings.extend(body['stop'])

            truncation_length = default(shared.settings, 'truncation_length', 2048)
            truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)

            default_max_tokens = truncation_length if is_chat_request else 16  # completions default, chat default is 'inf' so we need to cap it.

            max_tokens_str = 'length' if is_legacy else 'max_tokens'
            max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
            # If the client assumed OpenAI's defaults, max_tokens may be far larger than the context; it is capped against the remaining context below.

            req_params['max_new_tokens'] = max_tokens
            req_params['truncation_length'] = truncation_length
            req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
            req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
            req_params['top_k'] = default(body, 'best_of', default_req_params['top_k'])  # OpenAI's 'best_of' has no direct equivalent; its value is reused as top_k
            req_params['suffix'] = default(body, 'suffix', default_req_params['suffix'])
            req_params['stream'] = default(body, 'stream', default_req_params['stream'])
            req_params['echo'] = default(body, 'echo', default_req_params['echo'])
            req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
            req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])

            is_streaming = req_params['stream']

            self.send_response(200)
            self.send_access_control_headers()
            if is_streaming:
                self.send_header('Content-Type', 'text/event-stream')
                self.send_header('Cache-Control', 'no-cache')
                # self.send_header('Connection', 'keep-alive')
            else:
                self.send_header('Content-Type', 'application/json')
            self.end_headers()

            token_count = 0
            completion_token_count = 0
            prompt = ''
            stream_object_type = ''
            object_type = ''

            if is_chat_request:
                # Chat Completions
                stream_object_type = 'chat.completion.chunk'
                object_type = 'chat.completion'

                messages = body['messages']

                role_formats = {
                    'user': 'user: {message}\n',
                    'assistant': 'assistant: {message}\n',
                    'system': '{message}',
                    'context': 'You are a helpful assistant. Answer as concisely as possible.',
                    'prompt': 'assistant:',
                }

                # Instruct-tuned models respond much better with their native template
                if shared.settings['instruction_template']:
                    try:
                        with open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r') as f:
                            instruct = yaml.safe_load(f)

                        template = instruct['turn_template']
                        system_message_template = "{message}"
                        system_message_default = instruct['context']
                        bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
                        user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
                        bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
                        bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
                
                        role_formats = {
                            'user': user_message_template,
                            'assistant': bot_message_template,
                            'system': system_message_template,
                            'context': system_message_default,
                            'prompt': bot_prompt,
                        }

                        if 'Alpaca' in shared.settings['instruction_template']:
                            stopping_strings.extend(['\n###'])
                        elif instruct['user']: # WizardLM and some others have no user prompt.
                            stopping_strings.extend(['\n' + instruct['user'], instruct['user']])

                        if debug:
                            print(f"Loaded instruction role format: {shared.settings['instruction_template']}")

                    except Exception as e:
                        stopping_strings.extend(['\nuser:'])

                        print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
                        print("Warning: Loaded default instruction-following template for model.")

                else:
                    stopping_strings.extend(['\nuser:'])
                    print("Warning: Loaded default instruction-following template for model.")

                system_msgs = []
                chat_msgs = []

                # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
                context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
                if context_msg:
                    system_msgs.extend([context_msg])

                # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
                if 'prompt' in body:
                    prompt_msg = role_formats['system'].format(message=body['prompt'])
                    system_msgs.extend([prompt_msg])

                for m in messages:
                    role = m['role']
                    content = m['content']
                    msg = role_formats[role].format(message=content)
                    if role == 'system':
                        system_msgs.extend([msg])
                    else:
                        chat_msgs.extend([msg])

                # can't really truncate the system messages
                system_msg = '\n'.join(system_msgs)
                if system_msg and system_msg[-1] != '\n':
                    system_msg = system_msg + '\n'

                system_token_count = len(encode(system_msg)[0])
                remaining_tokens = truncation_length - system_token_count
                chat_msg = ''

                while chat_msgs:
                    new_msg = chat_msgs.pop()
                    new_size = len(encode(new_msg)[0])
                    if new_size <= remaining_tokens:
                        chat_msg = new_msg + chat_msg
                        remaining_tokens -= new_size
                    else:
                        print(f"Warning: too many messages for context size, dropping {len(chat_msgs) + 1} oldest message(s).")
                        break

                prompt = system_msg + chat_msg + role_formats['prompt']
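                # With the fallback role_formats above, a single user message yields a prompt like:
                #   You are a helpful assistant. Answer as concisely as possible.\nuser: Hello\nassistant: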

                token_count = len(encode(prompt)[0])

            else:
                # Text Completions
                stream_object_type = 'text_completion.chunk'
                object_type = 'text_completion'

                # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
                if is_legacy:
                    prompt = body['context']  # Older engines.generate API
                else:
                    prompt = body['prompt']  # XXX this can be different types

                if isinstance(prompt, list):
                    self.openai_error("API Batched generation not yet supported.")
                    return

                token_count = len(encode(prompt)[0])
                if token_count >= truncation_length:
                    new_len = int(len(prompt) * truncation_length / token_count)
                    prompt = prompt[-new_len:]
                    new_token_count = len(encode(prompt)[0])
                    print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.")
                    token_count = new_token_count

            if truncation_length - token_count < req_params['max_new_tokens']:
                print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {truncation_length - token_count}")
                req_params['max_new_tokens'] = truncation_length - token_count
                print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")

            if is_streaming:
                # begin streaming
                chunk = {
                    "id": cmpl_id,
                    "object": stream_object_type,
                    "created": created_time,
                    "model": shared.model_name,
                    resp_list: [{
                        "index": 0,
                        "finish_reason": None,
                    }],
                }

                if stream_object_type == 'text_completion.chunk':
                    chunk[resp_list][0]["text"] = ""
                else:
                    # Some clients read 'delta', others 'message'; send both for compatibility.
                    chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
                    chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}

                response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
                self.wfile.write(response.encode('utf-8'))

            # generate reply #######################################
            if debug:
                print({'prompt': prompt, 'req_params': req_params})
            generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

            answer = ''
            seen_content = ''
            longest_stop_len = max([len(x) for x in stopping_strings] + [0])

            for a in generator:
                answer = a

                stop_string_found = False
                len_seen = len(seen_content)
                search_start = max(len_seen - longest_stop_len, 0)

                for string in stopping_strings:
                    idx = answer.find(string, search_start)
                    if idx != -1:
                        answer = answer[:idx]  # clip it.
                        stop_string_found = True

                if stop_string_found:
                    break

                # If something like "\nYo" is generated just before "\nYou:"
                # is completed, buffer and generate more, don't send it
                buffer_and_continue = False

                for string in stopping_strings:
                    for j in range(len(string) - 1, 0, -1):
                        if answer[-j:] == string[:j]:
                            buffer_and_continue = True
                            break
                    else:
                        continue
                    break

                if buffer_and_continue:
                    continue

                if is_streaming:
                    # Streaming
                    new_content = answer[len_seen:]

                    if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                        continue

                    seen_content = answer
                    chunk = {
                        "id": cmpl_id,
                        "object": stream_object_type,
                        "created": created_time,
                        "model": shared.model_name,
                        resp_list: [{
                            "index": 0,
                            "finish_reason": None,
                        }],
                    }

                    # strip extra leading space off new generated content
                    if len_seen == 0 and new_content[0] == ' ':
                        new_content = new_content[1:]

                    if stream_object_type == 'text_completion.chunk':
                        chunk[resp_list][0]['text'] = new_content
                    else:
                        # Some clients read 'delta', others 'message'; send both for compatibility.
                        chunk[resp_list][0]['message'] = {'content': new_content}
                        chunk[resp_list][0]['delta'] = {'content': new_content}
                    response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
                    self.wfile.write(response.encode('utf-8'))
                    completion_token_count += len(encode(new_content)[0])

            if is_streaming:
                chunk = {
                    "id": cmpl_id,
                    "object": stream_object_type,
                    "created": created_time,
                    "model": model,  # TODO: add Lora info?
                    resp_list: [{
                        "index": 0,
                        "finish_reason": "stop",
                    }],
                    "usage": {
                        "prompt_tokens": token_count,
                        "completion_tokens": completion_token_count,
                        "total_tokens": token_count + completion_token_count
                    }
                }
                if stream_object_type == 'text_completion.chunk':
                    chunk[resp_list][0]['text'] = ''
                else:
                    # Some clients read 'delta', others 'message'; send both for compatibility.
                    chunk[resp_list][0]['message'] = {'content': ''}
                    chunk[resp_list][0]['delta'] = {'content': ''}

                response = 'data: ' + json.dumps(chunk) + '\r\n\r\ndata: [DONE]\r\n\r\n'
                self.wfile.write(response.encode('utf-8'))
                # Finished if streaming.
                if debug:
                    if answer and answer[0] == ' ':
                        answer = answer[1:]
                    print({'answer': answer}, chunk)
                return

            # strip extra leading space off new generated content
            if answer and answer[0] == ' ':
                answer = answer[1:]

            if debug:
                print({'response': answer})

            completion_token_count = len(encode(answer)[0])
            stop_reason = "stop"
            if token_count + completion_token_count >= truncation_length:
                stop_reason = "length"

            resp = {
                "id": cmpl_id,
                "object": object_type,
                "created": created_time,
                "model": model,  # TODO: add Lora info?
                resp_list: [{
                    "index": 0,
                    "finish_reason": stop_reason,
                }],
                "usage": {
                    "prompt_tokens": token_count,
                    "completion_tokens": completion_token_count,
                    "total_tokens": token_count + completion_token_count
                }
            }

            if is_chat_request:
                resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
            else:
                resp[resp_list][0]["text"] = answer

            response = json.dumps(resp)
            self.wfile.write(response.encode('utf-8'))
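            # Example request (sketch; assumes the default port 5001):
            #   curl http://127.0.0.1:5001/v1/chat/completions -H "Content-Type: application/json" \
            #     -d '{"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64}'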

        elif '/edits' in self.path:
            if not shared.model:
                self.openai_error("No model loaded.")
                return

            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            created_time = int(time.time())

            # Using Alpaca format, this may work with other models too.
            instruction = body['instruction']
            input_text = body.get('input', '')  # renamed to avoid shadowing the builtin input()

            # Request parameters
            req_params = default_req_params.copy()
            stopping_strings = []

            # Alpaca's template is verbose, so it makes a good default prompt
            default_template = (
                "Below is an instruction that describes a task, paired with an input that provides further context. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
            )

            instruction_template = default_template
            
            # Use the special instruction/input/response template for anything trained like Alpaca
            if shared.settings['instruction_template']:
                if 'Alpaca' in shared.settings['instruction_template']:
                    stopping_strings.extend(['\n###'])
                else:
                    try:
                        with open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r') as f:
                            instruct = yaml.safe_load(f)

                        template = instruct['turn_template']
                        template = template\
                            .replace('<|user|>', instruct.get('user', ''))\
                            .replace('<|bot|>', instruct.get('bot', ''))\
                            .replace('<|user-message|>', '{instruction}\n{input}')

                        instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
                        if instruct['user']:
                            stopping_strings.extend(['\n' + instruct['user'], instruct['user']])

                    except Exception as e:
                        instruction_template = default_template
                        print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
                        print("Warning: Loaded default instruction-following template (Alpaca) for model.")
            else:
                stopping_strings.extend(['\n###'])
                print("Warning: Loaded default instruction-following template (Alpaca) for model.")
                

            edit_task = instruction_template.format(instruction=instruction, input=input_text)

            truncation_length = default(shared.settings, 'truncation_length', 2048)
            token_count = len(encode(edit_task)[0])
            max_tokens = truncation_length - token_count

            req_params['max_new_tokens'] = max_tokens
            req_params['truncation_length'] = truncation_length
            req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
            req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
            req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
            req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])

            if debug:
                print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
            
            generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)

            longest_stop_len = max([len(x) for x in stopping_strings] + [0])
            answer = ''
            seen_content = ''
            for a in generator:
                answer = a

                stop_string_found = False
                len_seen = len(seen_content)
                search_start = max(len_seen - longest_stop_len, 0)

                for string in stopping_strings:
                    idx = answer.find(string, search_start)
                    if idx != -1:
                        answer = answer[:idx]  # clip it.
                        stop_string_found = True

                if stop_string_found:
                    break


            # Some replies have an extra leading space to fit the instruction template; just clip it off.
            if edit_task[-1] != '\n' and answer and answer[0] == ' ':
                answer = answer[1:]

            completion_token_count = len(encode(answer)[0])

            resp = {
                "object": "edit",
                "created": created_time,
                "choices": [{
                    "text": answer,
                    "index": 0,
                }],
                "usage": {
                    "prompt_tokens": token_count,
                    "completion_tokens": completion_token_count,
                    "total_tokens": token_count + completion_token_count
                }
            }

            if debug:
                print({'answer': answer, 'completion_token_count': completion_token_count})

            response = json.dumps(resp)
            self.wfile.write(response.encode('utf-8'))
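            # Example request (sketch; follows OpenAI's /v1/edits shape):
            #   curl http://127.0.0.1:5001/v1/edits -H "Content-Type: application/json" \
            #     -d '{"instruction": "Fix the spelling mistakes", "input": "teh quick brown fox"}'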

        elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
            # Stable Diffusion callout wrapper for txt2img
            # Low-effort implementation for compatibility. With only "prompt" passed and DALL-E assumed,
            # the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
            # If you want high-quality, tailored results, use the Stable Diffusion API directly.
            # This API is too general to shape the result with specific tags like "masterpiece", etc.,
            # so it will probably work best with the stock SD models.
            # SD configuration is beyond the scope of this API.
            # The edits and variations endpoints (i.e. img2img) are not implemented here because they
            # would require multipart form data handling, and properly supporting the url return type
            # would require file management and a web server to serve the files... Perhaps later!

            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            width, height = [ int(x) for x in default(body, 'size', '1024x1024').split('x') ]  # ignore the restrictions on size
            response_format = default(body, 'response_format', 'url')  # or b64_json
            
            payload = {
                'prompt': body['prompt'],  # ignore prompt limit of 1000 characters
                'width': width,
                'height': height,
                'batch_size': default(body, 'n', 1)  # ignore the batch limits of max 10
            }

            resp = {
                'created': int(time.time()),
                'data': []
            }

            # TODO: support SD_WEBUI_AUTH username:password pair.
            sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"

            response = requests.post(url=sd_url, json=payload)
            r = response.json()
            # r['parameters']...
            for b64_json in r['images']:
                if response_format == 'b64_json':
                    resp['data'].extend([{'b64_json': b64_json}])
                else:
                    resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}])  # lazy: a data URI, not a hosted URL; clients that fetch it with requests.get() will fail

            response = json.dumps(resp)
            self.wfile.write(response.encode('utf-8'))

        elif '/embeddings' in self.path and embedding_model is not None:
            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            texts = body['input'] if 'input' in body else body['text']
            if isinstance(texts, str):
                texts = [texts]

            embeddings = embedding_model.encode(texts).tolist()

            def enc_emb(emb):
                # If base64 is specified, encode. Otherwise, do nothing.
                if body.get("encoding_format", "") == "base64":
                    return float_list_to_base64(emb)
                else:
                    return emb
            data = [{"object": "embedding", "embedding": enc_emb(emb), "index": n} for n, emb in enumerate(embeddings)]

            response = json.dumps({
                "object": "list",
                "data": data,
                "model": st_model,  # return the real model
                "usage": {
                    "prompt_tokens": 0,
                    "total_tokens": 0,
                }
            })

            if debug:
                print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
            self.wfile.write(response.encode('utf-8'))
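            # Client-side sketch for decoding a base64 response (assumes numpy on the client):
            #   vec = np.frombuffer(base64.b64decode(data[0]['embedding']), dtype=np.float32)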

        elif '/moderations' in self.path:
            # for now do nothing, just don't error.
            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            response = json.dumps({
                "id": "modr-5MWoLO",
                "model": "text-moderation-001",
                "results": [{
                    "categories": {
                        "hate": False,
                        "hate/threatening": False,
                        "self-harm": False,
                        "sexual": False,
                        "sexual/minors": False,
                        "violence": False,
                        "violence/graphic": False
                    },
                    "category_scores": {
                        "hate": 0.0,
                        "hate/threatening": 0.0,
                        "self-harm": 0.0,
                        "sexual": 0.0,
                        "sexual/minors": 0.0,
                        "violence": 0.0,
                        "violence/graphic": 0.0
                    },
                    "flagged": False
                }]
            })
            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/token-count':
            # NOT STANDARD. Lifted from the api extension, but still very useful for computing tokenized length client-side.
            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            tokens = encode(body['prompt'])[0]
            response = json.dumps({
                'results': [{
                    'tokens': len(tokens)
                }]
            })
            self.wfile.write(response.encode('utf-8'))
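            # Example (sketch): POST {"prompt": "some text"} here and read
            #   {"results": [{"tokens": N}]} back to size prompts client-side.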

        else:
            print(self.path, self.headers)
            self.send_error(404)


def run_server():
    global embedding_model
    try:
        embedding_model = SentenceTransformer(st_model)
        print(f"\nLoaded embedding model: {st_model}, max sequence length: {embedding_model.max_seq_length}")
    except Exception:
        print(f"\nFailed to load embedding model: {st_model}")

    server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
    server = ThreadingHTTPServer(server_addr, Handler)
    if shared.args.share:
        try:
            from flask_cloudflared import _run_cloudflared
            public_url = _run_cloudflared(params['port'], params['port'] + 1)
            print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
        except ImportError:
            print('You should install flask_cloudflared manually')
    else:
        print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
        
    server.serve_forever()


def setup():
    Thread(target=run_server, daemon=True).start()
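# The webui's extension loader calls setup() when the extension is enabled
# (e.g. with --extensions openai); the server then runs on a daemon thread
# so it exits together with the main process.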