T5 / app.py
pdarleyjr's picture
Update CORS configuration for GitHub Pages integration
1310cf1
raw
history blame
2.83 kB
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# FastAPI application that will host the Gradio UI; CORS is configured on it
# immediately below.
app = FastAPI()

# Cross-origin policy: only the GitHub Pages site may call this backend.
# Credentials are permitted and any HTTP method or request header is accepted.
_cors_options = {
    "allow_origins": ["https://pdarleyjr.github.io"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Pretrained checkpoint shared by the model weights and the tokenizer.
_CHECKPOINT = 't5-small'

# Base (not fine-tuned) T5 tokenizer and seq2seq model, loaded from the same
# checkpoint so their vocabularies match.
tokenizer = T5Tokenizer.from_pretrained(_CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT)
def generate_clinical_report(input_text):
    """
    Generate a report from free-text clinical notes with the T5 model.

    The notes are prefixed with T5's "summarize: " task prefix and passed to
    the module-level ``model`` / ``tokenizer`` (loaded from the base
    't5-small' checkpoint, not a fine-tuned one). Beam search is used, and
    URL-like fragments are banned from the output.

    Args:
        input_text: The clinical notes entered by the user. ``None`` or a
            blank/whitespace-only string yields an empty report.

    Returns:
        The decoded generated text, with special tokens removed.
    """
    # Nothing to summarize: skip the model call instead of generating from a
    # prompt that contains only the task prefix.
    if input_text is None or not input_text.strip():
        return ""

    # T5 selects its task via a text prefix; inputs longer than 512 tokens
    # are truncated.
    input_ids = tokenizer.encode(
        "summarize: " + input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )

    # Ban URL-like fragments. Each entry must be the word's FULL token-id
    # sequence: keeping only the first sub-token of a multi-piece word such as
    # ".com" can ban an unrelated, common piece (e.g. a leading word-boundary
    # or "." token) and distort ordinary output. Sequences are matched as the
    # tokenizer segments the word in isolation, so occurrences segmented
    # differently in context may still get through.
    bad_words_ids = [
        ids
        for ids in (
            tokenizer.encode(word, add_special_tokens=False)
            for word in ("http", "www", ".com", ".org")
        )
        if ids  # skip any word that encodes to no tokens
    ]

    outputs = model.generate(
        input_ids,
        max_length=256,
        num_beams=4,
        no_repeat_ngram_size=3,
        length_penalty=2.0,
        early_stopping=True,
        bad_words_ids=bad_words_ids,
    )

    # Single input, so return the first (best) beam as plain text.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Input component: free-text clinical notes typed or pasted by the user.
_notes_box = gr.Textbox(
    lines=8,
    placeholder="Enter clinical notes here...",
    label="Clinical Notes",
    elem_id="input-box",
)

# Output component: the text returned by generate_clinical_report.
_report_box = gr.Textbox(
    lines=8,
    label="Generated Clinical Report",
    elem_id="output-box",
)

# Sample notes offered under the form; each inner list holds one example's
# values for the (single) input component.
_EXAMPLE_NOTES = [
    ["Patient presented with severe abdominal pain in the lower right quadrant. Temperature 38.5°C, BP 130/85."],
    ["Follow-up visit for diabetes management. Blood sugar levels have been stable with current medication regimen."],
]

# Background colours for the two text boxes, selected via their elem_id values.
_CUSTOM_CSS = """
#input-box { background-color: #f6f6f6; }
#output-box { background-color: #f0f7ff; }
"""

# Gradio interface wiring the notes box to the report generator; flagging is
# turned off.
demo = gr.Interface(
    fn=generate_clinical_report,
    inputs=[_notes_box],
    outputs=[_report_box],
    title="Clinical Report Generator",
    description="Generate professional clinical reports from clinical notes using a T5 model.",
    examples=_EXAMPLE_NOTES,
    theme=gr.themes.Soft(),
    css=_CUSTOM_CSS,
    flagging_mode="never",
)
# Turn on Gradio's request queue for the interface. queue() hands back the
# same Blocks object, so it can be mounted directly.
queued_demo = demo.queue()

# Serve the Gradio UI at the root path of the CORS-enabled FastAPI app.
app = gr.mount_gradio_app(app, queued_demo, path="/")
# Entry point when the script is executed directly (e.g. `python app.py`).
if __name__ == "__main__":
    # NOTE(review): demo.launch() starts Gradio's own server, so the CORS
    # middleware registered on `app` is presumably not applied on this path.
    # If the GitHub Pages origin must be allowed, serve `app` with an ASGI
    # server instead (e.g. `uvicorn app:app --host 0.0.0.0 --port 7860`).
    # Confirm against the actual deployment.
    #
    # `enable_queue` is intentionally not passed: launch() no longer accepts
    # it in Gradio 4+ (this file uses the Gradio 5 `flagging_mode` option),
    # and queueing is already configured through demo.queue().
    demo.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,  # Default Gradio port
        share=True,
        max_threads=40,
        show_error=True,
        root_path="",
    )