import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the model and tokenizer once at import time (shared by all requests).
# NOTE(review): the UI text below describes this as a Llama 3.1 8B fine-tune
# for ICD-11 clinical terminology -- confirm against the model card.
model_name = "manjunathshiva/ICD11-Clinical-Terminology-16bit"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # half the memory of fp32; needs bf16-capable hardware
    device_map="auto"  # let accelerate place/shard weights across available devices
)
@spaces.GPU(duration=120)
def generate_text(prompt, max_length, temperature):
    """Generate an assistant reply for *prompt* using the chat template.

    Args:
        prompt: User message inserted into the chat template.
        max_length: Maximum number of NEW tokens to generate
            (mapped to ``max_new_tokens``).
        temperature: Sampling temperature (> 0; higher = more random).

    Returns:
        The generated text only -- the prompt tokens and special tokens
        are stripped before decoding.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    formatted_prompt = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )
    # add_special_tokens=False: the chat template already emits BOS/role
    # tokens; re-tokenizing with specials enabled would prepend a duplicate
    # BOS token (a known chat-template pitfall with Llama tokenizers).
    inputs = tokenizer(
        formatted_prompt, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        do_sample=True,
        temperature=temperature,
        top_k=100,
        top_p=0.95,
        # Llama tokenizers ship without a pad token; use EOS so generate()
        # does not warn / guess.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Slice off the prompt so only newly generated tokens are decoded.
    return tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
# Custom CSS
css = """
body {
background-color: #1a1a2e;
color: #e0e0e0;
font-family: 'Arial', sans-serif;
}
.container {
max-width: 900px;
margin: auto;
padding: 20px;
}
.gradio-container {
background-color: #16213e;
border-radius: 15px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.header {
background-color: #0f3460;
padding: 20px;
border-radius: 15px 15px 0 0;
text-align: center;
margin-bottom: 20px;
}
.header h1 {
color: #e94560;
font-size: 2.5em;
margin-bottom: 10px;
}
.header p {
color: #a0a0a0;
}
.header img {
max-width: 300px;
border-radius: 10px;
margin: 15px auto;
display: block;
}
.input-group, .output-group {
background-color: #1a1a2e;
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
}
.input-group label, .output-group label {
color: #e94560;
font-weight: bold;
}
.generate-btn {
background-color: #e94560 !important;
color: white !important;
border: none !important;
border-radius: 5px !important;
padding: 10px 20px !important;
font-size: 16px !important;
cursor: pointer !important;
transition: background-color 0.3s ease !important;
}
.generate-btn:hover {
background-color: #c81e45 !important;
}
.example-prompts {
background-color: #1f2b47;
padding: 15px;
border-radius: 10px;
margin-bottom: 20px;
}
.example-prompts h3 {
color: #e94560;
margin-bottom: 10px;
}
.example-prompts ul {
list-style-type: none;
padding-left: 0;
}
.example-prompts li {
margin-bottom: 5px;
cursor: pointer;
transition: color 0.3s ease;
}
.example-prompts li:hover {
color: #e94560;
}
"""
# Example prompts rendered as one clickable button each in the UI below;
# clicking a button copies its text into the prompt textbox.
example_prompts = [
    "What is the code for Coronary microvascular disease ?",
    "What is the ICD 11 code for Diseases of coronary artery, unspecified ?",
    "The code is BA85.0. What does it mean?",
]
# Gradio interface: header, example-prompt buttons, input controls,
# generated-text output.
with gr.Blocks(css=css) as iface:
    gr.HTML(
        """
        <div class="header">
            <h1>manjunathshiva/ICD11-Clinical-Terminology-16bit</h1>
            <p>Generate ICD 11 code using the fine-tuned Llama3.1 8B model : manjunathshiva/ICD11-Clinical-Terminology-16bit model. Enter a prompt and let the AI create!</p>
            <img src="https://cdn-uploads.huggingface.co/production/uploads/64c75c1237333ccfef30a602/tmOlbERGKP7JSODa6T06J.jpeg" alt="Llama">
        </div>
        """
    )
    with gr.Group():
        with gr.Group(elem_classes="example-prompts"):
            gr.HTML("<h3>Example Prompts:</h3>")
            example_buttons = [gr.Button(text) for text in example_prompts]
        with gr.Group(elem_classes="input-group"):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=5)
            max_length = gr.Slider(minimum=1, maximum=500, value=128, step=1, label="Max Length")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
            generate_btn = gr.Button("Generate", elem_classes="generate-btn")
        with gr.Group(elem_classes="output-group"):
            output = gr.Textbox(label="Generated Text", lines=10)
    generate_btn.click(generate_text, inputs=[prompt, max_length, temperature], outputs=output)
    # Wire each example button to fill the prompt box with its own text.
    # Bind the text as a lambda default argument: a Button is an event
    # trigger, not a value-carrying input component, and the default-arg
    # binding also sidesteps the late-binding-closure pitfall.
    for button, example in zip(example_buttons, example_prompts):
        button.click(lambda example=example: example, inputs=[], outputs=[prompt])
# Launch the app
iface.launch()