|
import gradio as gr |
|
from transformers import T5Tokenizer, T5ForConditionalGeneration |
|
from fastapi import FastAPI |
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
|
|
|
app = FastAPI()

# CORS: any origin may call the API, but credentials (cookies / auth headers)
# are not allowed cross-origin. The CORS spec forbids a wildcard origin on
# credentialed responses; with allow_origins=["*"] plus allow_credentials=True,
# Starlette works around that by echoing the caller's Origin header, which
# would let any website make cookie-bearing cross-origin requests. Nothing in
# this module uses cookies or auth, so credentials stay off. If they are ever
# needed, replace "*" with an explicit list of trusted origins first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Pretrained checkpoint shared by the seq2seq model and its tokenizer; both
# must come from the same checkpoint so token ids line up.
_CHECKPOINT = "t5-small"

model = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT)
tokenizer = T5Tokenizer.from_pretrained(_CHECKPOINT)
|
|
|
def generate_clinical_report(input_text): |
|
""" |
|
Generate a clinical report from the input text using the fine-tuned T5 model. |
|
""" |
|
|
|
input_ids = tokenizer.encode("summarize: " + input_text, return_tensors="pt", max_length=512, truncation=True) |
|
|
|
|
|
outputs = model.generate( |
|
input_ids, |
|
max_length=256, |
|
num_beams=4, |
|
no_repeat_ngram_size=3, |
|
length_penalty=2.0, |
|
early_stopping=True, |
|
bad_words_ids=[[tokenizer.encode(word, add_special_tokens=False)[0]] |
|
for word in ['http', 'www', '.com', '.org']] |
|
) |
|
|
|
|
|
return tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
|
|
|
def _build_demo():
    """Assemble the Gradio UI around ``generate_clinical_report``.

    The UI has one multi-line text box for the notes and one for the
    generated report. The ``elem_id`` values are the hooks the inline CSS
    uses to tint each box.
    """
    notes_box = gr.Textbox(
        lines=8,
        placeholder="Enter clinical notes here...",
        label="Clinical Notes",
        elem_id="input-box",
    )
    report_box = gr.Textbox(
        lines=8,
        label="Generated Clinical Report",
        elem_id="output-box",
    )
    sample_notes = (
        "Patient presented with severe abdominal pain in the lower right quadrant. Temperature 38.5°C, BP 130/85.",
        "Follow-up visit for diabetes management. Blood sugar levels have been stable with current medication regimen.",
    )
    return gr.Interface(
        fn=generate_clinical_report,
        inputs=[notes_box],
        outputs=[report_box],
        title="Clinical Report Generator",
        description="Generate professional clinical reports from clinical notes using a T5 model.",
        # One example row per sample, since the interface has a single input.
        examples=[[note] for note in sample_notes],
        theme=gr.themes.Soft(),
        css="""
#input-box { background-color: #f6f6f6; }
#output-box { background-color: #f0f7ff; }
""",
        # NOTE(review): newer Gradio releases (5.x) deprecate `allow_flagging`
        # in favour of `flagging_mode`; confirm against the pinned version.
        allow_flagging="never",
    )


demo = _build_demo()
|
|
|
|
|
# Serve the Gradio UI at the root path of the FastAPI app. The app returned by
# mount_gradio_app replaces the module-level `app`, so an ASGI server loading
# `app` gets both the CORS middleware and the UI.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")
|
|
|
|
|
if __name__ == "__main__":
    import os

    # A public Gradio share link is opt-in (GRADIO_SHARE=1/true/yes). This UI
    # accepts clinical notes and launches without `auth`, so a share link
    # would let anyone holding the URL submit and read data through it.
    share = os.environ.get("GRADIO_SHARE", "").strip().lower() in {"1", "true", "yes"}

    # NOTE: this serves `demo` directly with Gradio's own server. The FastAPI
    # `app` and its CORS middleware are used only when an ASGI server loads
    # `app`, for example with uvicorn.
    demo.launch(server_name="0.0.0.0", share=share)
|
|