|
import base64 |
|
import json |
|
import os |
|
import time |
|
import requests |
|
import yaml |
|
import numpy as np |
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer |
|
from threading import Thread |
|
from modules.utils import get_available_models |
|
from modules.models import load_model, unload_model |
|
from modules.models_settings import (get_model_settings_from_yamls, |
|
update_model_parameters) |
|
|
|
from modules import shared |
|
from modules.text_generation import encode, generate_reply |
|
|
|
# Server configuration; the listening port can be overridden with OPENEDAI_PORT.
params = {
    'port': int(os.environ.get('OPENEDAI_PORT', 5001)),
}

# Verbose request/response logging is enabled by setting OPENEDAI_DEBUG (any value).
debug = 'OPENEDAI_DEBUG' in os.environ

# Baseline generation parameters. Request handlers take a copy of this mapping
# and override individual entries from the request body and shared settings.
default_req_params = dict(
    max_new_tokens=200,
    temperature=1.0,
    top_p=1.0,
    top_k=1,
    repetition_penalty=1.18,
    encoder_repetition_penalty=1.0,
    suffix=None,
    stream=False,
    echo=False,
    seed=-1,
    truncation_length=2048,
    add_bos_token=True,
    do_sample=True,
    typical_p=1.0,
    epsilon_cutoff=0.0,
    eta_cutoff=0.0,
    tfs=1.0,
    top_a=0.0,
    min_length=0,
    no_repeat_ngram_size=0,
    num_beams=1,
    penalty_alpha=0.0,
    length_penalty=1.0,
    early_stopping=False,
    mirostat_mode=0,
    mirostat_tau=5.0,
    mirostat_eta=0.1,
    ban_eos_token=False,
    skip_special_tokens=True,
    custom_stopping_strings='',
)
|
|
|
|
|
|
|
# sentence_transformers is optional. If it is missing, SentenceTransformer stays
# undefined, run_server() fails to load a model, and /embeddings is not served.
try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    pass

# Name of the embedding model to load; overridable via OPENEDAI_EMBEDDING_MODEL.
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
# Populated by run_server(); remains None when loading the model fails.
embedding_model = None
|
|
|
|
|
def default(dic, key, default):
    """Return ``dic[key]`` coerced to the exact type of ``default``.

    A missing key yields ``default``. When the stored value has a different
    type, it is converted with ``type(default)(value)``. The conversion is
    accepted only if it round-trips, i.e. converting back with the original
    type gives an equal value. So ``"5"`` -> ``5`` and ``2.0`` -> ``2`` are
    accepted, while ``1.5`` -> ``1`` and ``"abc"`` -> int are rejected. Any
    rejected or failing conversion yields ``default``.

    Args:
        dic: Mapping to read from (e.g. a decoded JSON request body).
        key: Key to look up.
        default: Fallback value; its type is the target type.

    Returns:
        The (possibly converted) value, or ``default``.
    """
    val = dic.get(key, default)
    if type(val) is type(default):
        return val

    try:
        converted = type(default)(val)
        if type(val)(converted) == val:
            return converted
    except Exception:
        # Deliberate best effort: conversions such as int(None), int("abc") or
        # int(float("inf")) fall back to the default below. Unlike the previous
        # bare ``except:``, this no longer swallows KeyboardInterrupt/SystemExit.
        pass

    return default
|
|
|
|
|
def clamp(value, minvalue, maxvalue):
    """Limit ``value`` to the range [minvalue, maxvalue].

    The upper bound is applied first, then the lower bound, so if the bounds
    are inverted the lower bound wins. On ties the original object is kept
    (e.g. ``clamp(1, 0.001, 1.0)`` returns the int ``1``).
    """
    capped = maxvalue if maxvalue < value else value
    return capped if capped > minvalue else minvalue
|
|
|
|
|
def float_list_to_base64(float_list):
    """Pack numbers as native-endian float32 values and return them base64-encoded.

    This is the ``encoding_format="base64"`` representation of an embedding
    vector.
    """
    raw = np.array(float_list, dtype="float32").tobytes()
    return base64.b64encode(raw).decode('ascii')
|
|
|
|
|
class Handler(BaseHTTPRequestHandler):
    """OpenAI-compatible HTTP API served on top of the currently loaded model.

    GET serves model/engine listings and a billing-usage stub. The legacy
    ``/v1/engines/<name>`` route can also switch the loaded model.

    POST serves:

    - text and chat completions, optionally streamed as server-sent events;
    - edits;
    - image generation, only when ``SD_WEBUI_URL`` is set;
    - embeddings, only when an embedding model is loaded;
    - a stub moderation endpoint and token counting.

    Requests are validated before any status line is written, so malformed
    requests receive an OpenAI-style error response instead of a 200 followed
    by an error body.
    """

    def send_access_control_headers(self):
        """Add permissive CORS headers to the response being built."""
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Credentials", "true")
        self.send_header(
            "Access-Control-Allow-Methods",
            "GET,HEAD,OPTIONS,POST,PUT"
        )
        self.send_header(
            "Access-Control-Allow-Headers",
            "Origin, Accept, X-Requested-With, Content-Type, "
            "Access-Control-Request-Method, Access-Control-Request-Headers, "
            "Authorization"
        )

    def _send_json_response(self, obj, code=200):
        """Send a complete JSON response: status line, CORS headers and body."""
        self.send_response(code)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        self.wfile.write(json.dumps(obj).encode('utf-8'))

    def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
        """Send an OpenAI-style ``{"error": {...}}`` body with HTTP status ``code``.

        ``internal_message``, when non-empty, is added at the top level of the
        JSON body (next to ``error``) for debugging.
        """
        error_resp = {
            'error': {
                'message': message,
                'code': code,
                'type': error_type,
                'param': param,
            }
        }
        if internal_message:
            error_resp['internal_message'] = internal_message

        self._send_json_response(error_resp, code)

    def _invalid_request(self, message, param=''):
        """Send a 400 ``InvalidRequestError`` for a malformed client request."""
        self.openai_error(message, code=400, error_type='InvalidRequestError', param=param)

    def do_OPTIONS(self):
        """Answer CORS preflight requests."""
        self.send_response(200)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        self.wfile.write("OK".encode('utf-8'))

    def do_GET(self):
        """Serve the model/engine listings and the billing-usage stub; 404 otherwise."""
        if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
            self._send_json_response(self._models_response())
        elif '/billing/usage' in self.path:
            self._send_json_response({"total_usage": 0})
        else:
            self.send_error(404)

    def _models_response(self):
        """Build the JSON body for ``/v1/models`` and ``/v1/engines`` requests.

        NOTE: a legacy ``GET /v1/engines/<name>`` for a name that is neither the
        current model, the embedding model nor a pseudo model unloads the
        current model and loads ``<name>`` in its place.
        """
        current_model_list = [shared.model_name]
        embeddings_model_list = [st_model] if embedding_model else []
        pseudo_model_list = [
            'gpt-3.5-turbo',
            'text-curie-001',
            'text-davinci-002',
        ]

        is_legacy = 'engines' in self.path
        is_list = self.path in ['/v1/engines', '/v1/models']

        if is_legacy and not is_list:
            model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
            resp = {
                "id": model_name,
                "object": "engine",
                "owner": "self",
                "ready": True,
            }
            if model_name not in pseudo_model_list + embeddings_model_list + current_model_list:
                self._load_model_by_name(model_name)
                if not shared.model:
                    resp['id'] = "None"
                    resp['ready'] = False
            return resp

        if is_list:
            all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + get_available_models()
            if is_legacy:
                models = [{"id": name, "object": "engine", "owner": "user", "ready": True} for name in all_model_list]
                # The first entry is the current model; flag it when nothing is loaded.
                if not shared.model:
                    models[0]['ready'] = False
            else:
                models = [{"id": name, "object": "model", "owned_by": "user", "permission": []} for name in all_model_list]
            return {
                "object": "list",
                "data": models,
            }

        return {
            "id": self.path[len('/v1/models/'):],
            "object": "model",
            "owned_by": "user",
            "permission": [],
        }

    @staticmethod
    def _load_model_by_name(model_name):
        """Unload the current model and load ``model_name`` using its YAML settings.

        When loading yields no model, ``shared.model_name`` is set to the string "None".
        """
        shared.model_name = model_name
        unload_model()

        model_settings = get_model_settings_from_yamls(shared.model_name)
        shared.settings.update(model_settings)
        update_model_parameters(model_settings, initial=True)

        if shared.settings['mode'] != 'instruct':
            shared.settings['instruction_template'] = None

        shared.model, shared.tokenizer = load_model(shared.model_name)
        if not shared.model:
            shared.model_name = "None"

    def do_POST(self):
        """Decode the JSON body and dispatch to the matching endpoint handler."""
        if debug:
            print(self.headers)

        body = self._read_json_body()
        if body is None:
            return  # an error response has already been sent

        if debug:
            print(body)

        if '/completions' in self.path or '/generate' in self.path:
            self._handle_completions(body)
        elif '/edits' in self.path:
            self._handle_edits(body)
        elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
            self._handle_image_generation(body)
        elif '/embeddings' in self.path and embedding_model is not None:
            self._handle_embeddings(body)
        elif '/moderations' in self.path:
            self._handle_moderations()
        elif self.path == '/api/v1/token-count':
            self._handle_token_count(body)
        else:
            print(self.path, self.headers)
            self.send_error(404)

    def _read_json_body(self):
        """Read the request body and decode it as a JSON object.

        Returns:
            The decoded dict. An empty or missing body yields ``{}``.
            Returns None after sending a 400 error when the body is not a
            valid JSON object.
        """
        try:
            content_length = int(self.headers.get('Content-Length', 0))
            raw = self.rfile.read(content_length) if content_length > 0 else b''
            # UnicodeDecodeError and json.JSONDecodeError are both ValueErrors.
            body = json.loads(raw.decode('utf-8')) if raw else {}
        except ValueError as e:
            self.openai_error("Could not parse the request body as JSON.", code=400,
                              error_type='InvalidRequestError', internal_message=repr(e))
            return None

        if not isinstance(body, dict):
            self._invalid_request("The request body must be a JSON object.")
            return None
        return body

    @staticmethod
    def _request_stopping_strings(body):
        """Return the request's ``stop`` value (string or list) as a list of non-empty strings."""
        stop = body.get('stop')
        if isinstance(stop, str):
            stop = [stop]
        elif not isinstance(stop, list):
            return []
        # An empty stop string would match at every position and erase the answer.
        return [s for s in stop if isinstance(s, str) and s]

    @staticmethod
    def _load_instruction_template(name):
        """Load ``characters/instruction-following/<name>.yaml``, closing the file afterwards."""
        with open(f"characters/instruction-following/{name}.yaml", 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)

    @staticmethod
    def _ends_with_partial_stop(text, stopping_strings):
        """Return True if ``text`` ends with a proper prefix of any stopping string."""
        return any(
            text[-j:] == string[:j]
            for string in stopping_strings
            for j in range(len(string) - 1, 0, -1)
        )

    @staticmethod
    def _cut_at_stopping_strings(text, stopping_strings, search_start=0):
        """Truncate ``text`` at stopping strings found at or after ``search_start``.

        Strings are applied in list order, each on the result of the previous
        cut. Returns ``(text, found)``.
        """
        found = False
        for string in stopping_strings:
            idx = text.find(string, search_start)
            if idx != -1:
                text = text[:idx]
                found = True
        return text, found

    @staticmethod
    def _truncate_prompt(prompt, truncation_length):
        """Return ``(prompt, token_count)``, keeping only the tail of a prompt that does not fit.

        The kept character count is estimated proportionally to
        ``truncation_length / token_count``.
        """
        token_count = len(encode(prompt)[0])
        if token_count >= truncation_length:
            new_len = int(len(prompt) * truncation_length / token_count)
            # prompt[-0:] would keep the whole prompt, so handle 0 explicitly.
            prompt = prompt[-new_len:] if new_len > 0 else ''
            new_token_count = len(encode(prompt)[0])
            print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.")
            token_count = new_token_count
        return prompt, token_count

    def _build_chat_prompt(self, body, truncation_length):
        """Render a chat request into a single prompt string.

        Role formats come from the active instruction template when one is
        configured and loadable; otherwise a plain ``user:``/``assistant:``
        format is used. System messages (and ``body['prompt']``, if present)
        are always kept. The oldest chat messages are dropped when they do not
        fit in ``truncation_length`` tokens.

        Returns:
            ``(prompt, token_count, stopping_strings)``, where
            ``stopping_strings`` are derived from the template.

        Raises:
            KeyError: ``messages`` is missing, or a message lacks ``role`` or
                ``content``, or uses an unknown role.
        """
        stopping_strings = []
        messages = body['messages']

        role_formats = {
            'user': 'user: {message}\n',
            'assistant': 'assistant: {message}\n',
            'system': '{message}',
            'context': 'You are a helpful assistant. Answer as concisely as possible.',
            'prompt': 'assistant:',
        }

        template_name = shared.settings['instruction_template']
        if template_name:
            try:
                instruct = self._load_instruction_template(template_name)

                template = instruct['turn_template']
                system_message_template = "{message}"
                system_message_default = instruct['context']
                bot_start = template.find('<|bot|>')
                user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
                bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
                # The generation prompt is the bot turn up to where its message would start.
                bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')

                role_formats = {
                    'user': user_message_template,
                    'assistant': bot_message_template,
                    'system': system_message_template,
                    'context': system_message_default,
                    'prompt': bot_prompt,
                }

                if 'Alpaca' in template_name:
                    stopping_strings.append('\n###')
                elif instruct['user']:
                    stopping_strings.extend(['\n' + instruct['user'], instruct['user']])

                if debug:
                    print(f"Loaded instruction role format: {template_name}")

            except Exception as e:
                stopping_strings.append('\nuser:')
                print(f"Exception: When loading characters/instruction-following/{template_name}.yaml: {repr(e)}")
                print("Warning: Loaded default instruction-following template for model.")
        else:
            stopping_strings.append('\nuser:')
            print("Warning: Loaded default instruction-following template for model.")

        system_msgs = []
        chat_msgs = []

        context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
        if context_msg:
            system_msgs.append(context_msg)

        if 'prompt' in body:
            system_msgs.append(role_formats['system'].format(message=body['prompt']))

        for m in messages:
            role = m['role']
            msg = role_formats[role].format(message=m['content'])
            if role == 'system':
                system_msgs.append(msg)
            else:
                chat_msgs.append(msg)

        system_msg = '\n'.join(system_msgs)
        if system_msg and system_msg[-1] != '\n':
            system_msg = system_msg + '\n'

        # Keep the newest chat messages that fit after the system text.
        remaining_tokens = truncation_length - len(encode(system_msg)[0])
        kept_msgs = []
        while chat_msgs:
            new_msg = chat_msgs.pop()
            new_size = len(encode(new_msg)[0])
            if new_size <= remaining_tokens:
                kept_msgs.append(new_msg)
                remaining_tokens -= new_size
            else:
                print(f"Warning: too many messages for context size, dropping {len(chat_msgs) + 1} oldest message(s).")
                break
        chat_msg = ''.join(reversed(kept_msgs))

        prompt = system_msg + chat_msg + role_formats['prompt']
        token_count = len(encode(prompt)[0])
        return prompt, token_count, stopping_strings

    def _handle_completions(self, body):
        """Serve ``/completions``, ``/chat/completions`` and the legacy ``/generate`` route.

        The prompt is built and validated before any headers are sent. With
        ``stream`` set, the answer is sent as server-sent events ending in
        ``data: [DONE]``; otherwise a single JSON object is returned.
        """
        if not shared.model:
            self.openai_error("No model loaded.")
            return

        is_legacy = '/generate' in self.path
        is_chat_request = 'chat' in self.path
        # Legacy responses put the choices under 'data' instead of 'choices'.
        resp_list = 'data' if is_legacy else 'choices'

        model = shared.model_name
        created_time = int(time.time())
        cmpl_id = "chatcmpl-%d" % (created_time) if is_chat_request else "conv-%d" % (created_time)

        if is_chat_request:
            # Object names documented by the OpenAI chat completions API.
            stream_object_type = 'chat.completion.chunk'
            object_type = 'chat.completion'
        else:
            stream_object_type = 'text_completion.chunk'
            object_type = 'text_completion'

        req_params = default_req_params.copy()
        stopping_strings = self._request_stopping_strings(body)

        truncation_length = default(shared.settings, 'truncation_length', 2048)
        truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)

        # Chat defaults to the whole context, plain completions to 16 tokens,
        # unless shared.settings provides max_new_tokens.
        default_max_tokens = truncation_length if is_chat_request else 16
        max_tokens_str = 'length' if is_legacy else 'max_tokens'
        max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))

        req_params['max_new_tokens'] = max_tokens
        req_params['truncation_length'] = truncation_length
        req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999)
        req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
        # NOTE(review): OpenAI's 'best_of' is mapped onto top_k here.
        req_params['top_k'] = default(body, 'best_of', default_req_params['top_k'])
        req_params['suffix'] = default(body, 'suffix', default_req_params['suffix'])
        req_params['stream'] = default(body, 'stream', default_req_params['stream'])
        req_params['echo'] = default(body, 'echo', default_req_params['echo'])
        req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
        req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])

        is_streaming = req_params['stream']

        try:
            if is_chat_request:
                prompt, token_count, template_stops = self._build_chat_prompt(body, truncation_length)
                stopping_strings.extend(template_stops)
            else:
                prompt = body['context'] if is_legacy else body['prompt']
                if isinstance(prompt, list):
                    self.openai_error("API Batched generation not yet supported.")
                    return
                if not isinstance(prompt, str):
                    self._invalid_request("The prompt must be a string.", param='prompt')
                    return
                prompt, token_count = self._truncate_prompt(prompt, truncation_length)
        except (KeyError, TypeError) as e:
            self._invalid_request(f"Invalid or missing request field: {e!r}")
            return

        remaining_tokens = truncation_length - token_count
        if remaining_tokens < req_params['max_new_tokens']:
            print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {remaining_tokens}")
            # Never request zero or a negative number of new tokens.
            req_params['max_new_tokens'] = max(remaining_tokens, 1)
            print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")

        def make_chunk(content, finish_reason=None, role=None):
            """Build one streaming chunk carrying ``content``."""
            choice = {"index": 0, "finish_reason": finish_reason}
            if is_chat_request:
                message = {'content': content} if role is None else {'role': role, 'content': content}
                choice["message"] = message
                choice["delta"] = dict(message)
            else:
                choice["text"] = content
            return {
                "id": cmpl_id,
                "object": stream_object_type,
                "created": created_time,
                "model": model,
                resp_list: [choice],
            }

        def write_event(chunk, done=False):
            """Write ``chunk`` as a server-sent event, optionally followed by [DONE]."""
            event = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
            if done:
                event += 'data: [DONE]\r\n\r\n'
            self.wfile.write(event.encode('utf-8'))

        if is_streaming:
            self.send_response(200)
            self.send_access_control_headers()
            self.send_header('Content-Type', 'text/event-stream')
            self.send_header('Cache-Control', 'no-cache')
            self.end_headers()
            write_event(make_chunk('', role='assistant'))

        if debug:
            print({'prompt': prompt, 'req_params': req_params})
        generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

        answer = ''
        seen_content = ''  # text already sent to a streaming client
        completion_token_count = 0
        longest_stop_len = max([len(x) for x in stopping_strings] + [0])

        for a in generator:
            len_seen = len(seen_content)
            # Only the region that could hold a not-yet-seen stop string needs searching.
            search_start = max(len_seen - longest_stop_len, 0)
            answer, stop_string_found = self._cut_at_stopping_strings(a, stopping_strings, search_start)
            if stop_string_found:
                break

            if not is_streaming:
                continue

            # Hold back text that might be the beginning of a stopping string.
            if self._ends_with_partial_stop(answer, stopping_strings):
                continue

            new_content = answer[len_seen:]
            # Skip empty deltas and incomplete multi-byte characters (U+FFFD placeholders).
            if not new_content or chr(0xfffd) in new_content:
                continue

            seen_content = answer
            if len_seen == 0 and new_content[0] == ' ':
                new_content = new_content[1:]

            write_event(make_chunk(new_content))
            completion_token_count += len(encode(new_content)[0])

        if is_streaming:
            # Flush anything held back when generation ended: a partial stop-string
            # prefix that never completed, or text produced in the same step as a
            # stopping string. Previously this text was silently dropped.
            tail = answer[len(seen_content):]
            if not seen_content and tail[:1] == ' ':
                tail = tail[1:]
            if tail:
                completion_token_count += len(encode(tail)[0])

            chunk = make_chunk(tail, finish_reason="stop")
            chunk["usage"] = {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count,
            }
            write_event(chunk, done=True)

            if debug:
                if answer and answer[0] == ' ':
                    answer = answer[1:]
                print({'answer': answer}, chunk)
            return

        if answer and answer[0] == ' ':
            answer = answer[1:]

        if debug:
            print({'response': answer})

        completion_token_count = len(encode(answer)[0])
        stop_reason = "stop"
        if token_count + completion_token_count >= truncation_length:
            stop_reason = "length"

        resp = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": model,
            resp_list: [{
                "index": 0,
                "finish_reason": stop_reason,
            }],
            "usage": {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count,
            },
        }

        if is_chat_request:
            resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
        else:
            resp[resp_list][0]["text"] = answer

        self._send_json_response(resp)

    def _handle_edits(self, body):
        """Serve ``/edits`` by wrapping ``instruction`` and ``input`` in an instruction template.

        An Alpaca-style template is used unless a non-Alpaca instruction
        template is configured and loads successfully.
        """
        if not shared.model:
            self.openai_error("No model loaded.")
            return

        created_time = int(time.time())

        if 'instruction' not in body:
            self._invalid_request("Missing required parameter: instruction", param='instruction')
            return
        instruction = body['instruction']
        input_text = body.get('input', '')

        req_params = default_req_params.copy()
        stopping_strings = []

        default_template = (
            "Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
        )
        instruction_template = default_template

        template_name = shared.settings['instruction_template']
        if template_name:
            if 'Alpaca' in template_name:
                stopping_strings.append('\n###')
            else:
                try:
                    instruct = self._load_instruction_template(template_name)

                    template = instruct['turn_template']
                    template = (
                        template
                        .replace('<|user|>', instruct.get('user', ''))
                        .replace('<|bot|>', instruct.get('bot', ''))
                        .replace('<|user-message|>', '{instruction}\n{input}')
                    )
                    # Keep everything up to where the bot's message would start.
                    instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
                    if instruct['user']:
                        stopping_strings.extend(['\n' + instruct['user'], instruct['user']])

                except Exception as e:
                    instruction_template = default_template
                    # The fallback is the Alpaca template, so use its stop marker too.
                    stopping_strings.append('\n###')
                    print(f"Exception: When loading characters/instruction-following/{template_name}.yaml: {repr(e)}")
                    print("Warning: Loaded default instruction-following template (Alpaca) for model.")
        else:
            stopping_strings.append('\n###')
            print("Warning: Loaded default instruction-following template (Alpaca) for model.")

        try:
            edit_task = instruction_template.format(instruction=instruction, input=input_text)
        except (KeyError, IndexError, ValueError) as e:
            # Stray braces in a loaded template make str.format fail.
            self.openai_error("Failed to apply the instruction template.", internal_message=repr(e))
            return

        truncation_length = default(shared.settings, 'truncation_length', 2048)
        token_count = len(encode(edit_task)[0])
        # Use the rest of the context for the answer, but always ask for at least one token.
        max_tokens = max(truncation_length - token_count, 1)

        req_params['max_new_tokens'] = max_tokens
        req_params['truncation_length'] = truncation_length
        req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999)
        req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
        req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
        req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])

        if debug:
            print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})

        generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)

        answer = ''
        for a in generator:
            answer, stop_string_found = self._cut_at_stopping_strings(a, stopping_strings)
            if stop_string_found:
                break

        # Drop the leading space the model emits after a prompt that does not end in a newline.
        if not edit_task.endswith('\n') and answer and answer[0] == ' ':
            answer = answer[1:]

        completion_token_count = len(encode(answer)[0])

        resp = {
            "object": "edit",
            "created": created_time,
            "choices": [{
                "text": answer,
                "index": 0,
            }],
            "usage": {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count,
            },
        }

        if debug:
            print({'answer': answer, 'completion_token_count': completion_token_count})

        self._send_json_response(resp)

    def _handle_image_generation(self, body):
        """Proxy ``/images/generations`` to the Stable Diffusion WebUI txt2img API at ``SD_WEBUI_URL``."""
        if 'prompt' not in body:
            self._invalid_request("Missing required parameter: prompt", param='prompt')
            return

        try:
            width, height = [int(x) for x in default(body, 'size', '1024x1024').split('x')]
        except ValueError:
            self._invalid_request("size must have the form WIDTHxHEIGHT, e.g. 1024x1024", param='size')
            return
        response_format = default(body, 'response_format', 'url')

        payload = {
            'prompt': body['prompt'],
            'width': width,
            'height': height,
            'batch_size': default(body, 'n', 1),
        }

        resp = {
            'created': int(time.time()),
            'data': [],
        }

        sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"
        try:
            sd_response = requests.post(url=sd_url, json=payload)
            images = sd_response.json()['images']
        except (requests.RequestException, ValueError, KeyError, TypeError) as e:
            self.openai_error("Image generation failed.", internal_message=repr(e))
            return

        for b64_json in images:
            if response_format == 'b64_json':
                resp['data'].append({'b64_json': b64_json})
            else:
                resp['data'].append({'url': f'data:image/png;base64,{b64_json}'})

        self._send_json_response(resp)

    def _handle_embeddings(self, body):
        """Serve ``/embeddings`` using the loaded sentence-transformers model."""
        if 'input' in body:
            input_data = body['input']
        elif 'text' in body:
            input_data = body['text']
        else:
            self._invalid_request("Missing required parameter: input", param='input')
            return
        if isinstance(input_data, str):
            input_data = [input_data]

        embeddings = embedding_model.encode(input_data).tolist()

        use_base64 = body.get("encoding_format", "") == "base64"
        data = [
            {"object": "embedding", "embedding": float_list_to_base64(emb) if use_base64 else emb, "index": n}
            for n, emb in enumerate(embeddings)
        ]

        resp = {
            "object": "list",
            "data": data,
            "model": st_model,
            "usage": {
                "prompt_tokens": 0,
                "total_tokens": 0,
            },
        }

        if debug and embeddings:
            print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")

        self._send_json_response(resp)

    def _handle_moderations(self):
        """Serve ``/moderations`` with a fixed "nothing flagged" result."""
        self._send_json_response({
            "id": "modr-5MWoLO",
            "model": "text-moderation-001",
            "results": [{
                "categories": {
                    "hate": False,
                    "hate/threatening": False,
                    "self-harm": False,
                    "sexual": False,
                    "sexual/minors": False,
                    "violence": False,
                    "violence/graphic": False,
                },
                "category_scores": {
                    "hate": 0.0,
                    "hate/threatening": 0.0,
                    "self-harm": 0.0,
                    "sexual": 0.0,
                    "sexual/minors": 0.0,
                    "violence": 0.0,
                    "violence/graphic": 0.0,
                },
                "flagged": False,
            }],
        })

    def _handle_token_count(self, body):
        """Serve ``/api/v1/token-count``: the number of tokens ``encode`` yields for ``prompt``."""
        if 'prompt' not in body:
            self._invalid_request("Missing required parameter: prompt", param='prompt')
            return

        tokens = encode(body['prompt'])[0]
        self._send_json_response({
            'results': [{
                'tokens': len(tokens),
            }],
        })
|
|
|
|
|
def run_server():
    """Load the optional embedding model, then serve the API forever (blocking).

    Binds to 0.0.0.0 when ``shared.args.listen`` is set, otherwise to
    127.0.0.1, on ``params['port']``. With ``shared.args.share``, a tunnel is
    started through flask_cloudflared (if installed) and its public URL is
    printed.
    """
    global embedding_model
    try:
        embedding_model = SentenceTransformer(st_model)
        print(f"\nLoaded embedding model: {st_model}, max sequence length: {embedding_model.max_seq_length}")
    except Exception as e:
        # Best effort: without an embedding model the /embeddings endpoint is not
        # served. A NameError here means the optional sentence_transformers
        # import failed. The reason is printed instead of being silently discarded.
        print(f"\nFailed to load embedding model: {st_model} ({e!r})")

    server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
    server = ThreadingHTTPServer(server_addr, Handler)
    if shared.args.share:
        try:
            from flask_cloudflared import _run_cloudflared
            public_url = _run_cloudflared(params['port'], params['port'] + 1)
            print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
        except ImportError:
            print('You should install flask_cloudflared manually')
    else:
        print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')

    server.serve_forever()
|
|
|
|
|
def setup():
    """Start the API server on a daemon thread and return without waiting for it."""
    server_thread = Thread(target=run_server, daemon=True)
    server_thread.start()
|
|