import torch
from PIL import Image
import gradio as gr
import spaces
# `pipeline` was missing from this import even though the V5 path below uses it
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
import os
from threading import Thread

HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "kodetr/stunting-qa-v5"
MODELS = os.environ.get("MODELS")

TITLE = "<h1><center>KONSULTASI STUNTING</center></h1>"

DESCRIPTION = f"""
<center>
<p>
Developed By Tanwir
</p>
</center>
"""

CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
"""
# -------------------------------------
# ------- use model stunting V5 -------
# -------------------------------------
text_pipeline = pipeline(
    "text-generation",
    model=MODEL_ID,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)
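# Editor's note (assumption, not part of the original app): "kodetr/stunting-qa-v5"
# appears to be public, but if the repo were private or gated, the HF_TOKEN read
# above would need to be forwarded to the loader via the standard transformers
# `token` argument, e.g.:
#
# text_pipeline = pipeline(
#     "text-generation",
#     model=MODEL_ID,
#     model_kwargs={"torch_dtype": torch.bfloat16},
#     device_map="auto",
#     token=HF_TOKEN,
# )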
# -------------------------------------
# ------- use model stunting V6 -------
# -------------------------------------
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_ID,
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
# )
# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# @spaces.GPU requests ZeroGPU hardware for the generation call; this is why
# `spaces` is imported above (assumption: the Space runs on ZeroGPU).
@spaces.GPU
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    print(f'message is - {message}')
    print(f'history is - {history}')
    # System prompt (Indonesian): restrict the bot to stunting prevention, child
    # nutrition, and maternal health; politely refuse off-topic questions.
    conversation = [{"role": "system", "content": 'Anda adalah chatbot kesehatan masyarakat yang hanya memberikan informasi dan konsultasi terkait pencegahan stunting, gizi anak, dan kesehatan ibu. Tolak semua pertanyaan yang tidak relevan atau di luar konteks ini dengan sopan.'}]
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
    conversation.append({"role": "user", "content": message})
    print(f"Conversation is -\n{conversation}")
    # -------------------------------------
    # ------- use model stunting V5 -------
    # -------------------------------------
    terminators = [
        text_pipeline.tokenizer.eos_token_id,
        text_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]
    # The pipeline returns a list with one dict; for chat-style input,
    # "generated_text" holds the whole conversation with the new assistant
    # message appended last.
    outputs = text_pipeline(
        conversation,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=penalty
    )
    full_text = outputs[0]["generated_text"][-1]["content"]
    # The pipeline call does not stream token by token, so stream the reply
    # sentence by sentence instead.
    buffer = ""
    sentences = full_text.split(". ")
    for i, sentence in enumerate(sentences):
        # re-attach the separator removed by split(), except after the last part
        buffer += sentence.strip() + (". " if i < len(sentences) - 1 else "")
        yield buffer
    # -------------------------------------
    # ------- use model stunting V6 -------
    # -------------------------------------
    # prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    # inputs = tokenizer(prompt, return_tensors="pt").to(0)  # move input tensors to GPU 0 (use "cpu" when no GPU)
    # streamer = TextIteratorStreamer(tokenizer, timeout=60., skip_prompt=True, skip_special_tokens=True)
    # generate_kwargs = dict(
    #     inputs,
    #     streamer=streamer,
    #     top_k=top_k,
    #     top_p=top_p,
    #     repetition_penalty=penalty,
    #     max_new_tokens=max_new_tokens,
    #     do_sample=True,
    #     temperature=temperature,
    #     pad_token_id=128000,
    #     eos_token_id=[128001, 128008, 128009],  # Llama-3 style terminator ids
    # )
    # thread = Thread(target=model.generate, kwargs=generate_kwargs)
    # thread.start()
    # buffer = ""
    # for new_text in streamer:
    #     buffer += new_text
    #     yield buffer
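    # Editor's sketch (assumption, not from the original app): the V5 pipeline
    # can also stream per token by forwarding a TextIteratorStreamer to the
    # underlying generate() call and running the pipeline in a background
    # thread, mirroring the V6 path above:
    #
    # streamer = TextIteratorStreamer(
    #     text_pipeline.tokenizer, timeout=60., skip_prompt=True, skip_special_tokens=True
    # )
    # thread = Thread(target=text_pipeline, args=(conversation,), kwargs=dict(
    #     max_new_tokens=max_new_tokens,
    #     eos_token_id=terminators,
    #     do_sample=True,
    #     temperature=temperature,
    #     top_p=top_p,
    #     top_k=top_k,
    #     repetition_penalty=penalty,
    #     streamer=streamer,
    # ))
    # thread.start()
    # buffer = ""
    # for new_text in streamer:
    #     buffer += new_text
    #     yield buffer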
chatbot = gr.Chatbot(height=600)

with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.8,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=[
            ["Apa yang dimaksud tentang Stunting?"],
            ["Apa saja tanda-tanda anak mengalami stunting?"],
            ["Apa saja makanan yang bisa mencegah stunting?"],
            ["Bagaimana malnutrisi dapat mempengaruhi perkembangan otak anak?"],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()