File size: 2,440 Bytes
647a796 0f9b256 e213acb 647a796 2248522 647a796 2248522 647a796 a344135 2248522 a344135 2248522 a344135 647a796 2248522 647a796 2248522 647a796 52bab4e 2248522 647a796 0f9b256 647a796 0f9b256 75a2106 1310cf1 79a19e9 4c4c9e8 0f9b256 75a2106 1310cf1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# CORS options for the FastAPI app: only the GitHub Pages origin is allowed,
# with credentials, and with every HTTP method and request header accepted.
_CORS_SETTINGS = {
    "allow_origins": ["https://pdarleyjr.github.io"],  # GitHub Pages domain
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

# FastAPI application object with the CORS middleware registered on it.
app = FastAPI()
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# Load the base T5 model and its tokenizer once, at import time, from the
# same pretrained checkpoint name.
_CHECKPOINT = 't5-small'
model = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT)
tokenizer = T5Tokenizer.from_pretrained(_CHECKPOINT)
def generate_clinical_report(input_text):
    """Generate a clinical report from the input text using the T5 model.

    The notes are prefixed with ``"summarize: "``, truncated to 512 input
    tokens, and decoded with 4-beam search limited to 256 output tokens.

    Args:
        input_text: Free-text clinical notes.

    Returns:
        The decoded model output with special tokens removed, or a string
        starting with ``"Error: "`` when the input is blank / not a string
        or when tokenization or generation raises.
    """
    # Blank or non-string input has nothing to summarize; report that
    # explicitly instead of running the model or leaking a TypeError message.
    if not isinstance(input_text, str) or not input_text.strip():
        return "Error: no clinical notes provided."

    try:
        # Prepare input text
        input_ids = tokenizer.encode(
            "summarize: " + input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )

        # Ban each URL fragment as its complete token-id sequence. Keeping
        # only the first sub-token (as before) bans whatever piece the
        # tokenizer happens to emit first, which need not be the fragment
        # itself and may be a piece shared with ordinary text.
        banned_sequences = []
        for word in ('http', 'www', '.com', '.org'):
            token_ids = tokenizer.encode(word, add_special_tokens=False)
            if token_ids:
                banned_sequences.append(token_ids)

        # Generate report
        outputs = model.generate(
            input_ids,
            max_length=256,
            num_beams=4,
            no_repeat_ngram_size=3,
            length_penalty=2.0,
            early_stopping=True,
            # An empty list is not a valid value; fall back to "no ban".
            bad_words_ids=banned_sequences or None,
        )

        # Decode and return the generated report
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: surface the failure as text rather than crashing the
        # request handler.
        print(f"Error generating report: {str(e)}")
        return f"Error: {str(e)}"
# Gradio interface: one multi-line textbox for the clinical notes and one
# for the generated report, wired to generate_clinical_report.
_notes_box = gr.Textbox(
    lines=8,
    placeholder="Enter clinical notes here...",
    label="Clinical Notes",
)
_report_box = gr.Textbox(
    lines=8,
    label="Generated Clinical Report",
)

demo = gr.Interface(
    fn=generate_clinical_report,
    inputs=_notes_box,
    outputs=_report_box,
    title="Clinical Report Generator",
    description="Generate professional clinical reports from clinical notes using a T5 model.",
    allow_flagging="never",
)
# Script entry point: mount the Gradio UI on the FastAPI app, enable the
# request queue, then launch Gradio on all interfaces at port 7860.
if __name__ == "__main__":
    # NOTE(review): demo.launch() appears to start Gradio's own server rather
    # than serve the FastAPI `app` returned by the mount, so the CORS
    # middleware registered on `app` may not be in effect - confirm.
    app = gr.mount_gradio_app(app, demo, path="/")
    demo.queue()  # Default queue settings

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": True,   # Expose the API documentation
        "share": False,     # Not needed in Spaces
        "debug": True,
        "root_path": "",
        "max_threads": 40,  # Raised thread limit for better performance
    }
    demo.launch(**launch_options)
|