# api_for_chat / app.py
# Committed by ldhldh — commit 6fce119 ("Update app.py"), 4.44 kB
from threading import Thread
import gradio as gr
import inspect
from gradio import routes
from typing import List, Type
from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer
import requests, os, re, asyncio, json
# Make sure the main thread has an event loop. asyncio.get_event_loop() is
# deprecated when no loop is set and raises RuntimeError on newer Pythons
# (3.14+), so fall back to creating and installing one explicitly. An already
# installed loop is reused, matching the original behavior.
try:
    loop = asyncio.get_event_loop()
except RuntimeError:
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
# init code: replacement for gradio's internal routes.get_types (installed below).
def get_types(cls_set: List[Type], component: str):
    """Parse (description, type) pairs out of component-class docstrings.

    For ``component == "input"`` the second docstring line is parsed; for any
    other value the last docstring line is parsed. From that line, the text
    after the final ``:`` becomes the description, and the text between the
    last ``(`` and the first ``)`` becomes the type name.

    Returns:
        A pair of parallel lists ``(docset, types)``, one entry per class.
    """
    line_index = 1 if component == "input" else -1
    descriptions = []
    type_names = []
    for klass in cls_set:
        line = inspect.getdoc(klass).split("\n")[line_index]
        descriptions.append(line.split(":")[-1])
        type_names.append(line.split(")")[0].split("(")[-1])
    return descriptions, type_names
# Monkeypatch gradio's internal `routes.get_types` with the local get_types.
# NOTE(review): presumably gradio calls this while building its API
# documentation page; confirm against the installed gradio version, since
# this relies on a private gradio internal.
routes.get_types = get_types
# ---------------------------------------------------------------------------
# App code: model / tokenizer configuration and shared conversation state.
# ---------------------------------------------------------------------------

# Checkpoint served through the Petals swarm. Alternative checkpoints that were
# left commented out (not used): "daekeun-ml/Llama-2-ko-instruct-13B",
# "quantumaikr/llama-2-70b-fb16-korean".
model_name = "petals-team/StableBeluga2"

# The tokenizer is loaded eagerly at import time; the distributed model is
# loaded lazily by init() and stays None until that succeeds.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = None

# history[npc][user_id] -> text of earlier exchanges, spliced into the prompt.
# The "" -> {} entry is only a seed value.
history = {"": dict()}
def check(model_name):
    """Return True if the Petals health monitor reports *model_name* as healthy.

    Queries ``https://health.petals.dev/api/v1/state`` and looks for an entry
    in ``model_reports`` whose ``name`` equals *model_name* and whose
    ``state`` is ``"healthy"``.

    A network error, timeout, or undecodable response is treated as "not
    healthy" (returns False), so callers take their "no model" path instead of
    crashing or hanging on an unreachable monitor.
    """
    try:
        # Without a timeout, requests can wait indefinitely on a stalled server.
        response = requests.get("https://health.petals.dev/api/v1/state", timeout=10)
        reports = response.json().get("model_reports", [])
    except (requests.RequestException, ValueError):
        return False
    return any(
        report.get("name") == model_name and report.get("state") == "healthy"
        for report in reports
    )
def init():
    """Load the distributed model into the module-level ``model`` global.

    Nothing is loaded, and ``model`` is left unchanged, unless check() reports
    ``model_name`` as healthy.
    """
    global model
    if not check(model_name):
        return
    model = AutoDistributedModelForCausalLM.from_pretrained(model_name)
def chat(id, npc, text):
if model == None:
init()
return "no model"
# get_coin endpoint
response = requests.post("https://ldhldh-api-for-unity.hf.space/run/predict_6", json={
"data": [
id,
]}).json()
coin = response["data"][0]
if int(coin) == 0:
return "no coin"
# model inference
if check(model_name):
global history
if not npc in npc_story:
return "no npc"
if not npc in history:
history[npc] = {}
if not id in history[npc]:
history[npc][id] = ""
if len(history[npc][id].split("###")) > 10:
history[npc][id] = "###" + history[npc][id].split("###", 3)[3]
npc_list = str([k for k in npc_story.keys()]).replace('\'', '')
town_story = f"""[{id}์˜ ๋งˆ์„]
์™ธ๋”ด ๊ณณ์˜ ์กฐ๊ทธ๋งŒ ์„ฌ์— ์—ฌ๋Ÿฌ ์ฃผ๋ฏผ๋“ค์ด ๋ชจ์—ฌ ์‚ด๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.
ํ˜„์žฌ {npc_list}์ด ์‚ด๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค."""
system_message = f"""1. ๋‹น์‹ ์€ ํ•œ๊ตญ์–ด์— ๋Šฅ์ˆ™ํ•ฉ๋‹ˆ๋‹ค.
2. ๋‹น์‹ ์€ ์ง€๊ธˆ ์—ญํ• ๊ทน์„ ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. {npc}์˜ ๋ฐ˜์‘์„ ์ƒ์ƒํ•˜๊ณ  ๋งค๋ ฅ์ ์ด๊ฒŒ ํ‘œํ˜„ํ•ฉ๋‹ˆ๋‹ค.
3. ๋‹น์‹ ์€ {npc}์ž…๋‹ˆ๋‹ค. {npc}์˜ ์ž…์žฅ์—์„œ ์ƒ๊ฐํ•˜๊ณ  ๋งํ•ฉ๋‹ˆ๋‹ค.
4. ์ฃผ์–ด์ง€๋Š” ์ •๋ณด๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๊ฐœ์—ฐ์„ฑ์žˆ๊ณ  ์‹ค๊ฐ๋‚˜๋Š” {npc}์˜ ๋Œ€์‚ฌ๋ฅผ ์™„์„ฑํ•˜์„ธ์š”.
5. ์ฃผ์–ด์ง€๋Š” {npc}์˜ ์ •๋ณด๋ฅผ ์‹ ์ค‘ํ•˜๊ฒŒ ์ฝ๊ณ , ๊ณผํ•˜์ง€ ์•Š๊ณ  ๋‹ด๋ฐฑํ•˜๊ฒŒ ์บ๋ฆญํ„ฐ๋ฅผ ์—ฐ๊ธฐํ•˜์„ธ์š”.
6. User์˜ ์—ญํ• ์„ ์ ˆ๋Œ€๋กœ ์นจ๋ฒ”ํ•˜์ง€ ๋งˆ์„ธ์š”. ๊ฐ™์€ ๋ง์„ ๋ฐ˜๋ณตํ•˜์ง€ ๋งˆ์„ธ์š”.
7. {npc}์˜ ๋งํˆฌ๋ฅผ ์ง€์ผœ์„œ ์ž‘์„ฑํ•˜์„ธ์š”."""
prom = f"""<<SYS>>
{system_message}<</SYS>>
{town_story}
### ์บ๋ฆญํ„ฐ ์ •๋ณด: {npc_story[npc]}
### ๋ช…๋ น์–ด:
{npc}์˜ ์ •๋ณด๋ฅผ ์ฐธ๊ณ ํ•˜์—ฌ {npc}์ด ํ•  ๋ง์„ ์ƒํ™ฉ์— ๋งž์ถฐ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”.
{history[npc][id]}
### User:
{text}
### {npc}:
"""
inputs = tokenizer(prom, return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, do_sample=True, temperature=0.6, top_p=0.75, max_new_tokens=100)
output = tokenizer.decode(outputs[0])[len(prom)+3:-1].split("<")[0].split("###")[0].replace(". ", ".\n")
print(outputs)
print(output)
else:
output = "no model"
# add_transaction endpoint
response = requests.post("https://ldhldh-api-for-unity.hf.space/run/predict_5", json={
"data": [
id,
"inference",
"### input:\n" + prompt + "\n\n### output:\n" + output
]}).json()
d = response["data"][0]
return output
# Gradio app: a single text-in/text-out Interface for chat(id, npc, text),
# served through the queue (at most 32 pending requests).
with gr.Blocks() as demo:
    gr.Interface(
        fn=chat,
        inputs=["text", "text", "text"],
        outputs="text",
        # Korean: "Returns the AI reply. Creates a transaction internally."
        description="chat, ai 응답을 반환합니다. 내부적으로 트랜잭션 생성. \n /run/predict",
    )

# queue() already enables request queuing. The old launch(enable_queue=True)
# flag is deprecated in Gradio 3.x and no longer accepted by launch() in
# Gradio 4, so it is dropped.
demo.queue(max_size=32).launch()