Spaces:
Runtime error
Runtime error
import json
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple

import gradio as gr
from gradio.components import Component  # cannot use TYPE_CHECKING here

from ..chat import ChatModel
from ..data import Role
from ..extras.misc import torch_gc
from ..hparams import GeneratingArguments
from .common import get_save_dir
from .locales import ALERTS

if TYPE_CHECKING:
    from .manager import Manager
class WebChatModel(ChatModel):
    """Chat model wrapper for the Gradio web UI.

    Wraps ``ChatModel`` so the UI can start with no model loaded
    (``lazy_init=True``) and later load/unload one from the settings the
    user picked in the web form. In demo mode the model configuration is
    read once from ``demo_config.json`` and (un)loading is disabled.
    """

    def __init__(
        self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
    ) -> None:
        # element lookup helper for the Gradio app
        self.manager = manager
        # when True, load/unload requests from the UI are rejected
        self.demo_mode = demo_mode
        self.model = None
        self.tokenizer = None
        self.generating_args = GeneratingArguments()

        if not lazy_init:  # read arguments from command line
            super().__init__()

        if demo_mode:  # load demo_config.json if it exists
            try:
                with open("demo_config.json", "r", encoding="utf-8") as f:
                    args = json.load(f)

                assert args.get("model_name_or_path", None) and args.get("template", None)
                super().__init__(args)
            except AssertionError:
                print("Please provide model name and template in `demo_config.json`.")
            except Exception:
                # best-effort: a missing/unreadable config only disables the demo model
                print("Cannot find `demo_config.json` at current directory.")

    @property
    def loaded(self) -> bool:
        """Whether a model is currently loaded.

        Must be a property: ``load_model`` checks ``self.loaded`` as an
        attribute, and a plain bound method would always be truthy there.
        """
        return self.model is not None

    def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
        """Load a model from the UI settings, yielding status messages.

        Args:
            data: mapping from Gradio components to their current values.

        Yields:
            Localized status/alert strings for the status box.
        """
        get = lambda name: data[self.manager.get_elem_by_name(name)]
        lang = get("top.lang")
        error = ""
        if self.loaded:
            error = ALERTS["err_exists"][lang]
        elif not get("top.model_name"):
            error = ALERTS["err_no_model"][lang]
        elif not get("top.model_path"):
            error = ALERTS["err_no_path"][lang]
        elif self.demo_mode:
            error = ALERTS["err_demo"][lang]

        if error:
            gr.Warning(error)
            yield error
            return

        if get("top.adapter_path"):
            # join multiple adapter checkpoints into one comma-separated argument
            adapter_name_or_path = ",".join(
                [
                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
                    for adapter in get("top.adapter_path")
                ]
            )
        else:
            adapter_name_or_path = None

        yield ALERTS["info_loading"][lang]
        args = dict(
            model_name_or_path=get("top.model_path"),
            adapter_name_or_path=adapter_name_or_path,
            finetuning_type=get("top.finetuning_type"),
            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
            template=get("top.template"),
            flash_attn=(get("top.booster") == "flash_attn"),
            use_unsloth=(get("top.booster") == "unsloth"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
        )
        super().__init__(args)
        yield ALERTS["info_loaded"][lang]

    def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
        """Unload the current model and free GPU memory, yielding status messages."""
        lang = data[self.manager.get_elem_by_name("top.lang")]

        if self.demo_mode:
            gr.Warning(ALERTS["err_demo"][lang])
            yield ALERTS["err_demo"][lang]
            return

        yield ALERTS["info_unloading"][lang]
        self.model = None
        self.tokenizer = None
        torch_gc()
        yield ALERTS["info_unloaded"][lang]

    def predict(
        self,
        chatbot: List[Tuple[str, str]],
        query: str,
        messages: Sequence[Tuple[str, str]],
        system: str,
        tools: str,
        max_new_tokens: int,
        top_p: float,
        temperature: float,
    ) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
        """Stream a reply to ``query``, yielding updated (chatbot, messages) pairs.

        When ``tools`` is set and the model emits a tool call, the call is
        re-serialized as JSON and recorded with the FUNCTION role; otherwise
        the raw response is recorded with the ASSISTANT role.
        """
        chatbot.append([query, ""])
        query_messages = messages + [{"role": Role.USER, "content": query}]
        response = ""
        for new_text in self.stream_chat(
            query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
        ):
            response += new_text
            if tools:
                # a tuple result means the template recognized a tool call
                result = self.template.format_tools.extract(response)
            else:
                result = response

            if isinstance(result, tuple):
                name, arguments = result
                arguments = json.loads(arguments)
                tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
                output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
                bot_text = "```json\n" + tool_call + "\n```"
            else:
                output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
                bot_text = result

            chatbot[-1] = [query, self.postprocess(bot_text)]
            yield chatbot, output_messages

    def postprocess(self, response: str) -> str:
        r"""Escape ``<`` and ``>`` outside fenced code blocks.

        Even-indexed segments of a ``"```"`` split are outside code fences;
        escaping there stops the chat widget from treating model output as
        HTML tags, while code blocks are left verbatim.
        """
        blocks = response.split("```")
        for i, block in enumerate(blocks):
            if i % 2 == 0:
                blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")
        return "```".join(blocks)