manjunathshiva commited on
Commit
b7a4928
·
verified ·
1 Parent(s): 88c32bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -161
app.py CHANGED
@@ -1,163 +1,3 @@
1
  import gradio as gr
2
- import spaces
3
- import torch
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
- # Load the model and tokenizer
7
- model_name = "manjunathshiva/ICD11-Clinical-Terminology-16bit"
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForCausalLM.from_pretrained(
10
- model_name,
11
- torch_dtype=torch.bfloat16,
12
- device_map="auto"
13
- )
14
-
15
- @spaces.GPU(duration=120)
16
- def generate_text(prompt, max_length, temperature):
17
- messages = [
18
- {"role": "system", "content": "You are a helpful assistant."},
19
- {"role": "user", "content": prompt}
20
- ]
21
- formatted_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
22
-
23
- inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
24
-
25
- outputs = model.generate(
26
- **inputs,
27
- max_new_tokens=max_length,
28
- do_sample=True,
29
- temperature=temperature,
30
- top_k=100,
31
- top_p=0.95,
32
- )
33
-
34
- return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
35
-
36
-
37
- # Custom CSS
38
- css = """
39
- body {
40
- background-color: #1a1a2e;
41
- color: #e0e0e0;
42
- font-family: 'Arial', sans-serif;
43
- }
44
- .container {
45
- max-width: 900px;
46
- margin: auto;
47
- padding: 20px;
48
- }
49
- .gradio-container {
50
- background-color: #16213e;
51
- border-radius: 15px;
52
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
53
- }
54
- .header {
55
- background-color: #0f3460;
56
- padding: 20px;
57
- border-radius: 15px 15px 0 0;
58
- text-align: center;
59
- margin-bottom: 20px;
60
- }
61
- .header h1 {
62
- color: #e94560;
63
- font-size: 2.5em;
64
- margin-bottom: 10px;
65
- }
66
- .header p {
67
- color: #a0a0a0;
68
- }
69
- .header img {
70
- max-width: 300px;
71
- border-radius: 10px;
72
- margin: 15px auto;
73
- display: block;
74
- }
75
- .input-group, .output-group {
76
- background-color: #1a1a2e;
77
- padding: 20px;
78
- border-radius: 10px;
79
- margin-bottom: 20px;
80
- }
81
- .input-group label, .output-group label {
82
- color: #e94560;
83
- font-weight: bold;
84
- }
85
- .generate-btn {
86
- background-color: #e94560 !important;
87
- color: white !important;
88
- border: none !important;
89
- border-radius: 5px !important;
90
- padding: 10px 20px !important;
91
- font-size: 16px !important;
92
- cursor: pointer !important;
93
- transition: background-color 0.3s ease !important;
94
- }
95
- .generate-btn:hover {
96
- background-color: #c81e45 !important;
97
- }
98
- .example-prompts {
99
- background-color: #1f2b47;
100
- padding: 15px;
101
- border-radius: 10px;
102
- margin-bottom: 20px;
103
- }
104
- .example-prompts h3 {
105
- color: #e94560;
106
- margin-bottom: 10px;
107
- }
108
- .example-prompts ul {
109
- list-style-type: none;
110
- padding-left: 0;
111
- }
112
- .example-prompts li {
113
- margin-bottom: 5px;
114
- cursor: pointer;
115
- transition: color 0.3s ease;
116
- }
117
- .example-prompts li:hover {
118
- color: #e94560;
119
- }
120
- """
121
-
122
- # Example prompts
123
- example_prompts = [
124
- "What is the code for Coronary microvascular disease ?",
125
- "What is the ICD 11 code for Diseases of coronary artery, unspecified ?",
126
- "The code is BA85.0. What does it mean?",
127
- ]
128
-
129
- # Gradio interface
130
- # Gradio interface
131
- with gr.Blocks(css=css) as iface:
132
- gr.HTML(
133
- """
134
- <div class="header">
135
- <h1>manjunathshiva/ICD11-Clinical-Terminology-16bit</h1>
136
- <p>Generate ICD 11 code using the fintetuned Llama3.1 8B model : manjunathshiva/ICD11-Clinical-Terminology-16bit model. Enter a prompt and let the AI create!</p>
137
- <img src="https://cdn-uploads.huggingface.co/production/uploads/64c75c1237333ccfef30a602/tmOlbERGKP7JSODa6T06J.jpeg" alt="Llama">
138
- </div>
139
- """
140
- )
141
-
142
- with gr.Group():
143
- with gr.Group(elem_classes="example-prompts"):
144
- gr.HTML("<h3>Example Prompts:</h3>")
145
- example_buttons = [gr.Button(prompt) for prompt in example_prompts]
146
-
147
- with gr.Group(elem_classes="input-group"):
148
- prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=5)
149
- max_length = gr.Slider(minimum=1, maximum=500, value=128, step=1, label="Max Length")
150
- temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
151
- generate_btn = gr.Button("Generate", elem_classes="generate-btn")
152
-
153
- with gr.Group(elem_classes="output-group"):
154
- output = gr.Textbox(label="Generated Text", lines=10)
155
-
156
- generate_btn.click(generate_text, inputs=[prompt, max_length, temperature], outputs=output)
157
-
158
- # Set up example prompt buttons
159
- for button in example_buttons:
160
- button.click(lambda x: x, inputs=[button], outputs=[prompt])
161
-
162
- # Launch the app
163
- iface.launch()
 
1
  import gradio as gr
 
 
 
2
 
3
+ gr.load("models/manjunathshiva/ICD11-Clinical-Terminology-16bit").launch()