Update Gradio interface with configurable parameters and error handling
Browse files
app.py
CHANGED
@@ -5,58 +5,52 @@ from transformers import T5Tokenizer, T5ForConditionalGeneration
|
|
5 |
model = T5ForConditionalGeneration.from_pretrained('t5-small')
|
6 |
tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
7 |
|
8 |
-
def generate_clinical_report(input_text):
|
9 |
"""
|
10 |
-
Generate a clinical report from the input text using the
|
11 |
"""
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
# Create Gradio interface with
|
31 |
demo = gr.Interface(
|
32 |
fn=generate_clinical_report,
|
33 |
inputs=[
|
34 |
gr.Textbox(
|
35 |
lines=8,
|
36 |
placeholder="Enter clinical notes here...",
|
37 |
-
label="Clinical Notes"
|
38 |
-
|
39 |
-
)
|
40 |
-
|
41 |
-
|
42 |
-
gr.
|
43 |
-
|
44 |
-
label="Generated Clinical Report",
|
45 |
-
elem_id="output-box"
|
46 |
-
)
|
47 |
],
|
|
|
48 |
title="Clinical Report Generator",
|
49 |
description="Generate professional clinical reports from clinical notes using a T5 model.",
|
50 |
-
|
51 |
-
|
52 |
-
["Follow-up visit for diabetes management. Blood sugar levels have been stable with current medication regimen."]
|
53 |
-
],
|
54 |
-
theme=gr.themes.Soft(),
|
55 |
-
css="""
|
56 |
-
#input-box { background-color: #f6f6f6; }
|
57 |
-
#output-box { background-color: #f0f7ff; }
|
58 |
-
""",
|
59 |
-
flagging_mode="never"
|
60 |
)
|
61 |
|
62 |
# Launch the app with optimized configuration for Hugging Face Spaces
|
|
|
5 |
model = T5ForConditionalGeneration.from_pretrained('t5-small')
|
6 |
tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
7 |
|
8 |
+
def generate_clinical_report(input_text, max_length=256, num_beams=4, no_repeat_ngram_size=3, length_penalty=2.0, early_stopping=True):
|
9 |
"""
|
10 |
+
Generate a clinical report from the input text using the T5 model with configurable parameters.
|
11 |
"""
|
12 |
+
try:
|
13 |
+
# Prepare input text
|
14 |
+
input_ids = tokenizer.encode("summarize: " + input_text, return_tensors="pt", max_length=512, truncation=True)
|
15 |
+
|
16 |
+
# Generate report with provided parameters
|
17 |
+
outputs = model.generate(
|
18 |
+
input_ids,
|
19 |
+
max_length=max_length,
|
20 |
+
num_beams=num_beams,
|
21 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
22 |
+
length_penalty=length_penalty,
|
23 |
+
early_stopping=early_stopping,
|
24 |
+
bad_words_ids=[[tokenizer.encode(word, add_special_tokens=False)[0]]
|
25 |
+
for word in ['http', 'www', '.com', '.org']]
|
26 |
+
)
|
27 |
+
|
28 |
+
# Decode and return the generated report
|
29 |
+
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
30 |
+
except Exception as e:
|
31 |
+
print(f"Error generating report: {str(e)}")
|
32 |
+
return f"Error: {str(e)}"
|
33 |
|
34 |
+
# Create Gradio interface with API configuration
|
35 |
demo = gr.Interface(
|
36 |
fn=generate_clinical_report,
|
37 |
inputs=[
|
38 |
gr.Textbox(
|
39 |
lines=8,
|
40 |
placeholder="Enter clinical notes here...",
|
41 |
+
label="Clinical Notes"
|
42 |
+
),
|
43 |
+
gr.Slider(minimum=50, maximum=500, value=256, step=1, label="Max Length"),
|
44 |
+
gr.Slider(minimum=1, maximum=8, value=4, step=1, label="Num Beams"),
|
45 |
+
gr.Slider(minimum=1, maximum=5, value=3, step=1, label="No Repeat Ngram Size"),
|
46 |
+
gr.Slider(minimum=0.1, maximum=5.0, value=2.0, step=0.1, label="Length Penalty"),
|
47 |
+
gr.Checkbox(value=True, label="Early Stopping")
|
|
|
|
|
|
|
48 |
],
|
49 |
+
outputs=gr.Textbox(lines=8, label="Generated Clinical Report"),
|
50 |
title="Clinical Report Generator",
|
51 |
description="Generate professional clinical reports from clinical notes using a T5 model.",
|
52 |
+
allow_flagging="never",
|
53 |
+
analytics_enabled=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
)
|
55 |
|
56 |
# Launch the app with optimized configuration for Hugging Face Spaces
|