# T5 / app.py (Hugging Face Space by pdarleyjr)
# Commit 9e0003a: "Add FastAPI integration with proper CORS and queue handling"
# Page views: raw, history, blame; file size 2.23 kB
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# ASGI host application; the Gradio UI is mounted onto it later in this file.
app = FastAPI()

# Cross-origin requests are accepted only from the GitHub Pages front end;
# any HTTP method and header is allowed, and credentials are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    allow_origins=["https://pdarleyjr.github.io"],
)
# Both pieces come from the same pretrained "t5-small" checkpoint:
# the tokenizer maps text <-> token ids, the model performs generation.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
# URL fragments the model must never emit in a generated report.
BANNED_WORDS = ("http", "www", ".com", ".org")


def _banned_token_sequences(words=BANNED_WORDS):
    """Return ``bad_words_ids`` for ``model.generate`` that forbid *words*.

    Each word is encoded to its FULL token-id sequence. Banning only the first
    sub-token is wrong for words that split into several pieces: the first
    piece may be a generic token (e.g. a word-boundary or punctuation piece),
    which would then be forbidden everywhere while the word itself stayed
    allowed. Words that encode to no tokens are skipped.
    """
    sequences = []
    for word in words:
        ids = tokenizer.encode(word, add_special_tokens=False)
        if ids:
            sequences.append(ids)
    return sequences


def generate_clinical_report(input_text):
    """
    Generate a clinical report from the input text using the T5 model.

    The notes are prefixed with the ``"summarize: "`` task prompt, truncated
    to 512 tokens, and decoded with 4-beam search (at most 256 tokens, no
    repeated 3-grams, URL fragments from ``BANNED_WORDS`` forbidden).

    Args:
        input_text: Free-text clinical notes from the UI. ``None`` or
            whitespace-only input returns ``""`` without running the model.

    Returns:
        The generated report text. If tokenization or generation raises, the
        error is printed and a string of the form ``"Error: <message>"`` is
        returned instead, so the UI shows the failure rather than crashing.
    """
    if not input_text or not input_text.strip():
        return ""
    try:
        input_ids = tokenizer.encode(
            "summarize: " + input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        outputs = model.generate(
            input_ids,
            max_length=256,
            num_beams=4,
            no_repeat_ngram_size=3,
            length_penalty=2.0,
            early_stopping=True,
            # Encoding four short words is negligible next to beam search.
            bad_words_ids=_banned_token_sequences(),
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        print(f"Error generating report: {e}")
        return f"Error: {e}"
# Web UI: clinical notes go in, the generated report comes out; flagging is off.
# NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour of
# `flagging_mode` -- confirm against the Gradio version this Space pins.
demo = gr.Interface(
    title="Clinical Report Generator",
    description="Generate professional clinical reports from clinical notes using a T5 model.",
    fn=generate_clinical_report,
    inputs=gr.Textbox(
        label="Clinical Notes",
        placeholder="Enter clinical notes here...",
        lines=8,
    ),
    outputs=gr.Textbox(label="Generated Clinical Report", lines=8),
    allow_flagging="never",
)
# Configure the request queue at import time and BEFORE mounting. The
# max_size=20 limit then applies however the ASGI app is served: via
# `python app.py` below, or via an external server such as `uvicorn app:app`.
# Previously it ran only under __main__, after the app had been mounted.
demo.queue(max_size=20)

# Mount the Gradio UI at the root path of the FastAPI app.
app = gr.mount_gradio_app(app, demo, path="/")


if __name__ == "__main__":
    # Direct entry point: serve the combined FastAPI + Gradio app on port 7860.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")