import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
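

# Load the fine-tuned ICD-11 clinical terminology model and its tokenizer once at startup;
# bfloat16 weights with device_map="auto" place the model on the available accelerator.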
model_name = "manjunathshiva/ICD11-Clinical-Terminology-16bit"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)


@spaces.GPU(duration=120)
def generate_text(prompt, max_length, temperature):
    # Wrap the user prompt in the model's chat template.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

    # Sample a completion; max_length caps the number of newly generated tokens.
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        top_k=100,
        top_p=0.95,
    )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
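

# Custom dark-theme CSS styling for the Gradio interface.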
css = """
body {
    background-color: #1a1a2e;
    color: #e0e0e0;
    font-family: 'Arial', sans-serif;
}
.container {
    max-width: 900px;
    margin: auto;
    padding: 20px;
}
.gradio-container {
    background-color: #16213e;
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.header {
    background-color: #0f3460;
    padding: 20px;
    border-radius: 15px 15px 0 0;
    text-align: center;
    margin-bottom: 20px;
}
.header h1 {
    color: #e94560;
    font-size: 2.5em;
    margin-bottom: 10px;
}
.header p {
    color: #a0a0a0;
}
.header img {
    max-width: 300px;
    border-radius: 10px;
    margin: 15px auto;
    display: block;
}
.input-group, .output-group {
    background-color: #1a1a2e;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}
.input-group label, .output-group label {
    color: #e94560;
    font-weight: bold;
}
.generate-btn {
    background-color: #e94560 !important;
    color: white !important;
    border: none !important;
    border-radius: 5px !important;
    padding: 10px 20px !important;
    font-size: 16px !important;
    cursor: pointer !important;
    transition: background-color 0.3s ease !important;
}
.generate-btn:hover {
    background-color: #c81e45 !important;
}
.example-prompts {
    background-color: #1f2b47;
    padding: 15px;
    border-radius: 10px;
    margin-bottom: 20px;
}
.example-prompts h3 {
    color: #e94560;
    margin-bottom: 10px;
}
.example-prompts ul {
    list-style-type: none;
    padding-left: 0;
}
.example-prompts li {
    margin-bottom: 5px;
    cursor: pointer;
    transition: color 0.3s ease;
}
.example-prompts li:hover {
    color: #e94560;
}
"""


example_prompts = [
    "What is the code for Coronary microvascular disease?",
    "What is the ICD 11 code for Diseases of coronary artery, unspecified?",
    "The code is BA85.0. What does it mean?",
]


with gr.Blocks(css=css) as iface:
    gr.HTML(
        """
        <div class="header">
            <h1>manjunathshiva/ICD11-Clinical-Terminology-16bit</h1>
            <p>Generate ICD-11 codes with the fine-tuned Llama 3.1 8B model manjunathshiva/ICD11-Clinical-Terminology-16bit. Enter a prompt and let the AI create!</p>
            <img src="https://cdn-uploads.huggingface.co/production/uploads/64c75c1237333ccfef30a602/tmOlbERGKP7JSODa6T06J.jpeg" alt="Llama">
        </div>
        """
    )

    with gr.Group():
        with gr.Group(elem_classes="example-prompts"):
            gr.HTML("<h3>Example Prompts:</h3>")
            example_buttons = [gr.Button(prompt) for prompt in example_prompts]

        with gr.Group(elem_classes="input-group"):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=5)
            max_length = gr.Slider(minimum=1, maximum=500, value=128, step=1, label="Max Length")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
            generate_btn = gr.Button("Generate", elem_classes="generate-btn")

        with gr.Group(elem_classes="output-group"):
            output = gr.Textbox(label="Generated Text", lines=10)

    generate_btn.click(generate_text, inputs=[prompt, max_length, temperature], outputs=output)
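
    # Clicking an example button copies its label text into the prompt box.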
    for button in example_buttons:
        button.click(lambda x: x, inputs=[button], outputs=[prompt])


iface.launch()