# stunting-llm / app.py
# Author: kodetr. Last commit: 1327db3 (verified), "Update app.py". Size: 6.27 kB.
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
from transformers import pipeline
# --- Runtime configuration read from the environment ---
# Optional Hugging Face access token (None when unset).
# NOTE(review): not referenced anywhere else in the visible file.
HF_TOKEN = os.getenv("HF_TOKEN")
# Hugging Face Hub repository id handed to from_pretrained.
MODEL_ID = "kodetr/stunting-qa-v5"
# NOTE(review): read from the environment but not used in the visible file.
MODELS = os.getenv("MODELS")

# --- Static page text and styling ---
# Page heading ("KONSULTASI STUNTING" = "stunting consultation").
TITLE = "<h1><center>KONSULTASI STUNTING</center></h1>"
# Author credit shown beneath the heading. Plain string: it has no placeholders.
DESCRIPTION = """
<center>
<p>
Developed By Tanwir
</p>
</center>
"""
# Extra CSS handed to gr.Blocks: styles any element carrying the
# "duplicate-button" class and centers <h3> headings.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""
# Load the tokenizer and the bfloat16 causal-LM weights once, at import time,
# so every stream_chat request reuses the same objects. device_map="auto" lets
# transformers/accelerate decide where the weights are placed.
# An earlier ("V5") variant wrapped MODEL_ID in
# transformers.pipeline("text-generation") instead of loading the model directly.
# NOTE(review): this direct-load path was labelled "V6" while MODEL_ID names
# stunting-qa-v5. Confirm which checkpoint is intended.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# System instruction prepended to every conversation. In English: "You are a
# public-health chatbot that only provides information and consultation on
# stunting prevention, child nutrition and maternal health. Politely decline
# every question that is irrelevant or outside this context."
SYSTEM_PROMPT = 'Anda adalah chatbot kesehatan masyarakat yang hanya memberikan informasi dan konsultasi terkait pencegahan stunting, gizi anak, dan kesehatan ibu. Tolak semua pertanyaan yang tidak relevan atau di luar konteks ini dengan sopan.'

# Special-token ids passed to generate().
# NOTE(review): these are hard-coded. They presumably match a Llama-3.x
# tokenizer: 128000 = <|begin_of_text|>, reused here as padding, and
# 128001/128008/128009 = <|end_of_text|>/<|eom_id|>/<|eot_id|>.
# Re-check them if MODEL_ID is switched to a model with a different vocabulary.
_PAD_TOKEN_ID = 128000
_EOS_TOKEN_IDS = (128001, 128008, 128009)


def _build_conversation(message, history):
    """Return the role/content message list fed to the chat template.

    Args:
        message: The user's new message.
        history: Prior turns from gr.ChatInterface, either as
            ``[user, assistant]`` pairs ("tuples" format) or as
            ``{"role": ..., "content": ...}`` dicts ("messages" format).
            ``None`` is treated as no history.

    Returns:
        A list of dicts in this order: the system prompt, the prior turns,
        and ``message`` as the final user turn.
    """
    conversation = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages" format: keep only the keys the chat template needs.
            conversation.append({"role": turn["role"], "content": turn["content"]})
            continue
        prompt, answer = turn
        conversation.append({"role": "user", "content": prompt})
        # A pair may have no reply. Don't hand None content to the chat template.
        if answer is not None:
            conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message})
    return conversation


@spaces.GPU
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    """Stream the model's reply to ``message`` for gr.ChatInterface.

    ``model.generate`` runs in a background thread. Decoded text is read back
    through a ``TextIteratorStreamer``. ``spaces.GPU`` is the Hugging Face
    Spaces GPU decorator applied to this handler.

    Args:
        message: The user's new message.
        history: Prior turns, as ``[user, assistant]`` pairs or role/content
            dicts (see ``_build_conversation``).
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding, and ``top_p``/``top_k`` are then ignored.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling.
        penalty: Repetition penalty. Must be > 0; 1.0 means no penalty.

    Yields:
        The reply accumulated so far, after each streamed chunk.

    Raises:
        gr.Error: If ``penalty`` is not positive.
    """
    print(f'message is - {message}')
    print(f'history is - {history}')
    conversation = _build_conversation(message, history)
    print(f"Conversation is -\n{conversation}")

    # Slider values may arrive as ints (e.g. 2 rather than 2.0). transformers'
    # logits processors require float temperature/penalty and an int top_k.
    temperature = float(temperature)
    penalty = float(penalty)
    if penalty <= 0:
        # transformers would raise this inside the worker thread, where the
        # error never reaches the UI; it would just wait for the streamer timeout.
        raise gr.Error("Repetition penalty must be greater than 0.")

    # The rendered template already contains its special tokens (e.g. BOS), so
    # the tokenizer must not prepend them a second time.
    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    # Place inputs on the model's device instead of a hard-coded CUDA index.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60., skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        inputs,  # tokenized prompt tensors (input_ids, attention_mask)
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=penalty,
        pad_token_id=_PAD_TOKEN_ID,
        eos_token_id=list(_EOS_TOKEN_IDS),
    )
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        # Sampling rejects temperature <= 0, so treat 0 as greedy decoding.
        generate_kwargs["do_sample"] = False

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    # The streamer ends when generate() finishes; reap the worker thread.
    thread.join()
# Chat transcript widget. It is created here and handed to ChatInterface below.
chatbot = gr.Chatbot(height=600)

with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    # ChatInterface calls stream_chat(message, history, *additional_inputs).
    # The slider order therefore must match stream_chat's parameters:
    # temperature, max_new_tokens, top_p, top_k, penalty.
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.8,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                # transformers rejects a repetition penalty <= 0, so the slider
                # starts at 0.1 rather than 0.0. A value of 1.0 means no penalty.
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
        # Sample questions (Indonesian): what stunting is, signs that a child
        # is stunted, foods that help prevent it, and how malnutrition affects
        # a child's brain development.
        examples=[
            ["Apa yang dimaksud tentang Stunting?"],
            ["Apa saja tanda-tanda anak mengalami stunting?"],
            ["Apa saja makanan yang bisa mencegah stunting?"],
            ["Bagaimana malnutrisi dapat mempengaruhi perkembangan otak anak?"],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()