Spaces:
Build error
Build error
File size: 4,263 Bytes
292c2df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import logging
import traceback
from functools import partial
import gradio as gr
import extensions
import modules.shared as shared
# Maps an extension name to a two-item list [enabled, position].
# load_extensions() stores [True, i], where i is the extension's index in
# shared.args.extensions; iterator() sorts on that index.
state = {}
# Names that load_extensions() is willing to load. It is not populated in the
# code shown here -- presumably filled in by other code; TODO confirm.
available_extensions = []
# Extension script modules whose setup() has already been invoked, so that
# load_extensions() does not call setup() twice for the same module.
setup_called = set()
def apply_settings(extension, name):
    """Override an extension's ``params`` entries with saved settings.

    For each key in ``extension.params``, the setting stored in
    ``shared.settings`` under ``"<name>-<key>"`` (if there is one) replaces
    the extension's current value. Extensions without a ``params``
    attribute are left untouched.

    Args:
        extension: The extension's script module.
        name: The extension name used as the settings-key prefix.
    """
    if hasattr(extension, 'params'):
        for key in extension.params:
            setting_id = f"{name}-{key}"
            if setting_id in shared.settings:
                extension.params[key] = shared.settings[setting_id]
def load_extensions():
    """Import and initialise the extensions listed in ``shared.args.extensions``.

    Names that are not in ``available_extensions`` are skipped. For each
    remaining name the module ``extensions.<name>.script`` is imported,
    saved settings are applied to it, and its ``setup()`` hook (if any) is
    run at most once per module. On success the extension is recorded in
    ``state`` as ``[True, position]``.

    A failure in one extension is logged together with its traceback and
    does not stop the remaining extensions from loading.
    """
    # Function-scope import so the block does not depend on a file-level change.
    import importlib

    for i, name in enumerate(shared.args.extensions):
        if name not in available_extensions:
            continue

        if name != 'api':
            logging.info(f'Loading the extension "{name}"...')

        try:
            # import_module() replaces building an import statement for exec().
            importlib.import_module(f"extensions.{name}.script")
            # Resolve the module through the package, the same way iterator() does.
            extension = getattr(extensions, name).script
            apply_settings(extension, name)
            if extension not in setup_called and hasattr(extension, "setup"):
                # Mark before calling so a failing setup() is not retried.
                setup_called.add(extension)
                extension.setup()

            state[name] = [True, i]
        except Exception:
            # `except Exception` lets KeyboardInterrupt / SystemExit propagate.
            # The message is an f-string so the extension name is reported.
            logging.error(f'Failed to load the extension "{name}".')
            traceback.print_exc()
def iterator():
    """Yield ``(script_module, name)`` for each enabled extension.

    Extensions come out ordered by the position stored in ``state``, i.e.
    the order in which they were specified on the command line.
    """
    ordered_names = sorted(state, key=lambda ext_name: state[ext_name][1])
    for ext_name in ordered_names:
        enabled = state[ext_name][0]
        if not enabled:
            continue
        yield getattr(extensions, ext_name).script, ext_name
def _apply_string_extensions(function_name, text):
    """Pass ``text`` through a string -> string hook of every extension.

    Each enabled extension that defines ``function_name`` is called in
    turn with the output of the previous one.

    Args:
        function_name: Name of the hook to look up on each extension.
        text: The initial string.

    Returns:
        The string after all matching hooks have been applied.
    """
    for ext, _ in iterator():
        if not hasattr(ext, function_name):
            continue
        modifier = getattr(ext, function_name)
        text = modifier(text)
    return text
def _apply_input_hijack(text, visible_text):
    """Let extensions with an armed ``input_hijack`` replace the user input.

    An extension takes part when it has an ``input_hijack`` mapping whose
    ``'state'`` entry is truthy. That entry is reset to ``False`` before
    the hijack is applied. The ``'value'`` entry is either a callable
    invoked with ``(text, visible_text)`` or a ready-made pair; in both
    cases the result is unpacked into the new ``text, visible_text``.

    Returns:
        The ``(text, visible_text)`` pair after all hijacks have run.
    """
    for ext, _ in iterator():
        if not hasattr(ext, 'input_hijack'):
            continue
        hijack = ext.input_hijack
        if not hijack['state']:
            continue

        hijack['state'] = False
        replacement = hijack['value']
        if callable(replacement):
            text, visible_text = replacement(text, visible_text)
        else:
            text, visible_text = replacement
    return text, visible_text
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
    """Build the chat prompt with an extension-provided generator, if any.

    The first enabled extension (in iteration order) whose
    ``custom_generate_chat_prompt`` attribute is not ``None`` is used.

    Args:
        text: Forwarded to the extension's generator.
        state: Forwarded to the extension's generator (this parameter
            shadows the module-level ``state`` inside this function).
        **kwargs: Forwarded unchanged.

    Returns:
        The generator's result, or ``None`` when no extension provides one.
    """
    candidates = [
        ext.custom_generate_chat_prompt
        for ext, _ in iterator()
        if hasattr(ext, 'custom_generate_chat_prompt')
    ]
    handler = next((fn for fn in candidates if fn is not None), None)
    if handler is None:
        return None
    return handler(text, state, **kwargs)
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
    """Run the tokenizer-override hook of every extension that defines it.

    Each matching hook is called as
    ``hook(state, prompt, input_ids, input_embeds)`` and must return the
    updated ``(prompt, input_ids, input_embeds)`` triple, which is fed to
    the next extension.

    Returns:
        The ``(prompt, input_ids, input_embeds)`` triple after all hooks.
    """
    for ext, _ in iterator():
        if not hasattr(ext, function_name):
            continue
        hook = getattr(ext, function_name)
        prompt, input_ids, input_embeds = hook(state, prompt, input_ids, input_embeds)
    return prompt, input_ids, input_embeds
# Dispatch table used by apply_extensions(): maps a hook type to the function
# that applies it. The partial() entries pre-bind the name of the extension
# attribute that the generic string / tokenizer appliers look up.
EXTENSION_MAP = {
"input": partial(_apply_string_extensions, "input_modifier"),
"output": partial(_apply_string_extensions, "output_modifier"),
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
"input_hijack": _apply_input_hijack,
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt
}
def apply_extensions(typ, *args, **kwargs):
    """Dispatch to the extension handler registered for ``typ``.

    Args:
        typ: A key of ``EXTENSION_MAP``.
        *args: Positional arguments forwarded to the handler.
        **kwargs: Keyword arguments forwarded to the handler.

    Returns:
        Whatever the selected handler returns.

    Raises:
        ValueError: If ``typ`` is not a key of ``EXTENSION_MAP``.
    """
    handler = EXTENSION_MAP.get(typ)
    if handler is None:
        raise ValueError(f"Invalid extension type {typ}")
    return handler(*args, **kwargs)
def create_extensions_block():
    """Create the Gradio UI elements of every extension that defines ``ui()``.

    Nothing is created when no enabled extension has a ``ui`` attribute.
    Otherwise a single column with ``elem_id="extensions"`` is opened, and
    each such extension gets a Markdown heading with its name followed by
    whatever its ``ui()`` hook builds.
    """
    has_any_ui = any(hasattr(ext, "ui") for ext, _ in iterator())
    if not has_any_ui:
        return

    # Creating the extension ui elements
    with gr.Column(elem_id="extensions"):
        for ext, ext_name in iterator():
            if not hasattr(ext, "ui"):
                continue
            gr.Markdown(f"\n### {ext_name}")
            ext.ui()
|