# stunting-llm / app.py
# Source: kodetr's Hugging Face Space (commit efa98b1) — Gradio chat app
# for stunting (child malnutrition) consultation over kodetr/stunting-qa-v5.
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
from transformers import pipeline
# Optional Hugging Face access token; None when the env var is unset.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Checkpoint to load for both the model and tokenizer.
MODEL_ID = "kodetr/stunting-qa-v5"
# NOTE(review): MODELS is read from the environment but never used below —
# confirm whether it can be dropped.
MODELS = os.environ.get("MODELS")
# Page header and footer markup rendered via gr.HTML.
TITLE = "<h1><center>KONSULTASI STUNTING</center></h1>"
# Fix: removed the redundant f-prefix — the string has no placeholders.
DESCRIPTION = """
<center>
<p>
Developed By Tanwir
</p>
</center>
"""
# Custom CSS applied to the Blocks layout (styles the duplicate button).
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""
# -------------------------------------
# ------- use model stunting V5 -------
# -------------------------------------
# text_pipeline = pipeline(
# "text-generation",
# model=MODEL_ID,
# model_kwargs={"torch_dtype": torch.bfloat16},
# device_map="auto",
# )
# -------------------------------------
# ------- use model stunting V6 -------
# -------------------------------------
# Load the causal-LM weights in bfloat16 and let accelerate place layers
# across the available devices (device_map="auto"). This downloads the
# checkpoint on first run.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Tokenizer for the same checkpoint (provides the chat template used below).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
@spaces.GPU
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    """Stream a chat completion restricted to the stunting consultation domain.

    Args:
        message: Latest user message.
        history: Prior turns as (user, assistant) string pairs.
        temperature: Sampling temperature.
        max_new_tokens: Cap on generated tokens.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.
        penalty: Repetition penalty passed to `generate`.

    Yields:
        The accumulated assistant reply, growing as new tokens arrive.
    """
    print(f'message is - {message}')
    print(f'history is - {history}')
    # System prompt (Indonesian): answer only stunting-related questions
    # (prevention, causes, impact, nutrition, PHBS, government programs) and
    # politely refuse anything off-topic with a fixed refusal sentence.
    conversation = [{
        "role": "system",
        "content": (
            "Anda adalah chatbot kesehatan masyarakat yang hanya memberikan informasi dan konsultasi terkait stunting. "
            "Jawab hanya pertanyaan yang berhubungan dengan pencegahan stunting, penyebab stunting, dampak stunting, "
            "intervensi gizi, perilaku hidup bersih dan sehat (PHBS), serta kebijakan atau program pemerintah terkait stunting. "
            "Jika pengguna mengajukan pertanyaan di luar topik stunting, tolak dengan sopan dan tegas dengan mengatakan: "
            "\"Maaf, saya hanya bisa membantu terkait topik stunting. Untuk pertanyaan di luar itu, silakan konsultasikan dengan pihak atau layanan yang sesuai.\" "
            "Jangan menjawab atau berspekulasi terhadap pertanyaan yang tidak relevan dengan stunting. Fokuskan seluruh respons hanya pada konteks edukasi, pencegahan, dan penanggulangan stunting."
        )
    }]
    # Replay prior (user, assistant) pairs, then append the new user turn.
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
    conversation.append({"role": "user", "content": message})
    print(f"Conversation is -\n{conversation}")

    # Render the conversation through the model's chat template, then tokenize.
    prompt_text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    # Fix: move inputs to the model's actual device instead of hard-coded
    # `.to(0)` — the hard-coded index fails on CPU-only hosts and can mismatch
    # the placement chosen by device_map="auto".
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=60., skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=penalty,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        # NOTE(review): hard-coded token ids look like Llama-3 special tokens
        # (128000 = begin_of_text, 128001/128008/128009 = end markers) —
        # confirm they match this checkpoint's tokenizer.
        pad_token_id=128000,
        eos_token_id=[128001, 128008, 128009],
    )

    # Generation runs in a worker thread; the streamer feeds this generator
    # token-by-token, and we yield the cumulative text for Gradio streaming.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    # Fix: reap the worker so the thread does not outlive the request.
    thread.join()
# Transcript display shared with the ChatInterface below.
chatbot = gr.Chatbot(height=600)

with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    # Sampling controls, one spec per slider: (minimum, maximum, step, value, label).
    # Order must match stream_chat's extra parameters.
    slider_specs = [
        (0, 1, 0.1, 0.8, "Temperature"),
        (128, 4096, 1, 1024, "Max new tokens"),
        (0.0, 1.0, 0.1, 0.8, "top_p"),
        (1, 20, 1, 20, "top_k"),
        (0.0, 2.0, 0.1, 1.0, "Repetition penalty"),
    ]
    sliders = [
        gr.Slider(minimum=lo, maximum=hi, step=step, value=val, label=label, render=False)
        for lo, hi, step, val, label in slider_specs
    ]

    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=sliders,
        examples=[
            ["Apa yang dimaksud tentang Stunting?"],
            ["Apa saja tanda-tanda anak mengalami stunting?"],
            ["Apa saja makanan yang bisa mencegah stunting?"],
            ["Bagaimana malnutrisi dapat mempengaruhi perkembangan otak anak?"],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()