File size: 2,415 Bytes
647a796
 
0f9b256
 
 
 
 
 
 
 
 
 
 
 
 
 
e213acb
647a796
 
 
 
2248522
647a796
2248522
647a796
a344135
 
 
 
2248522
a344135
 
2248522
 
 
 
 
a344135
 
 
 
 
 
 
 
 
647a796
2248522
647a796
 
2248522
 
 
 
 
 
 
 
 
647a796
52bab4e
2248522
647a796
 
0f9b256
647a796
4c4c9e8
0f9b256
1310cf1
79a19e9
4c4c9e8
 
 
 
0f9b256
 
1310cf1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

# FastAPI application object; other code in this file mounts the Gradio UI on it.
app = FastAPI()

# Cross-origin policy: only the GitHub Pages front end is an allowed origin;
# credentials are permitted and every method/header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["https://pdarleyjr.github.io"],  # GitHub Pages domain
)

# Pretrained checkpoint shared by the seq2seq model and its tokenizer, so the
# two can never drift apart.
_CHECKPOINT = 't5-small'
model = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT)
tokenizer = T5Tokenizer.from_pretrained(_CHECKPOINT)

def generate_clinical_report(input_text):
    """
    Generate a clinical report from the input text using the T5 model.

    The notes are prefixed with T5's "summarize: " task prompt, truncated to
    512 input tokens, and decoded with 4-beam search (at most 256 tokens).
    Generation is prevented from emitting the token sequences for a few
    URL-like fragments ("http", "www", ".com", ".org").

    Args:
        input_text: Free-text clinical notes (the Gradio textbox value).

    Returns:
        The generated report text. On empty input, or if tokenization or
        generation raises, a string starting with "Error: " instead.
    """
    try:
        # An empty textbox would otherwise be summarized as just the prompt.
        if not input_text or not input_text.strip():
            return "Error: no clinical notes provided."

        # Prepare input text
        input_ids = tokenizer.encode(
            "summarize: " + input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )

        # `bad_words_ids` takes one token-id sequence per banned word. Taking
        # only the first sub-token of each word (as before) can ban a generic
        # SentencePiece piece (e.g. a bare "▁" or ".") across the whole
        # output, so keep the full sequence for each word.
        bad_words_ids = [
            ids
            for ids in (
                tokenizer.encode(word, add_special_tokens=False)
                for word in ("http", "www", ".com", ".org")
            )
            if ids  # generate() rejects empty sequences
        ]

        # Generate report
        outputs = model.generate(
            input_ids,
            max_length=256,
            num_beams=4,
            no_repeat_ngram_size=3,
            length_penalty=2.0,
            early_stopping=True,
            bad_words_ids=bad_words_ids,
        )

        # Decode and return the generated report
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        print(f"Error generating report: {str(e)}")
        return f"Error: {str(e)}"

# Gradio UI: a multi-line clinical-notes box feeds generate_clinical_report,
# whose result is shown in a second multi-line box.
_notes_box = gr.Textbox(
    lines=8,
    placeholder="Enter clinical notes here...",
    label="Clinical Notes",
)
_report_box = gr.Textbox(
    lines=8,
    label="Generated Clinical Report",
)
demo = gr.Interface(
    fn=generate_clinical_report,
    inputs=_notes_box,
    outputs=_report_box,
    title="Clinical Report Generator",
    description="Generate professional clinical reports from clinical notes using a T5 model.",
    allow_flagging="never",  # disable the Flag button
)

# Enable the request queue, mount the Gradio UI on the FastAPI app, and launch.
if __name__ == "__main__":
    # Allow up to 3 concurrent requests. Gradio >= 4 replaced
    # queue(concurrency_count=...) with default_concurrency_limit and rejects
    # the old keyword; Gradio 3.x raises TypeError on the new one.
    try:
        demo.queue(default_concurrency_limit=3)
    except TypeError:
        demo.queue(concurrency_count=3)

    # NOTE(review): demo.launch() starts Gradio's own server, so the FastAPI
    # `app` returned here (and its CORS middleware) is mounted but not served.
    # For the CORS settings to apply, `app` must be run by an ASGI server
    # (e.g. uvicorn) instead of demo.launch() — confirm the intended deployment.
    app = gr.mount_gradio_app(app, demo, path="/")

    # `enable_queue` is intentionally not passed: it was removed from launch()
    # in Gradio 4 (TypeError), and demo.queue() above already enables queuing.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_api=True,  # Enable API documentation
        share=False,    # Not needed in Spaces
        debug=True,
        root_path="",
    )