pdarleyjr commited on
Commit
2248522
·
1 Parent(s): a344135

Simplify Gradio interface for better API compatibility

Browse files
Files changed (1) hide show
  1. app.py +22 -29
app.py CHANGED
@@ -5,22 +5,22 @@ from transformers import T5Tokenizer, T5ForConditionalGeneration
5
  model = T5ForConditionalGeneration.from_pretrained('t5-small')
6
  tokenizer = T5Tokenizer.from_pretrained('t5-small')
7
 
8
- def generate_clinical_report(input_text, max_length=256, num_beams=4, no_repeat_ngram_size=3, length_penalty=2.0, early_stopping=True):
9
  """
10
- Generate a clinical report from the input text using the T5 model with configurable parameters.
11
  """
12
  try:
13
  # Prepare input text
14
  input_ids = tokenizer.encode("summarize: " + input_text, return_tensors="pt", max_length=512, truncation=True)
15
 
16
- # Generate report with provided parameters
17
  outputs = model.generate(
18
  input_ids,
19
- max_length=max_length,
20
- num_beams=num_beams,
21
- no_repeat_ngram_size=no_repeat_ngram_size,
22
- length_penalty=length_penalty,
23
- early_stopping=early_stopping,
24
  bad_words_ids=[[tokenizer.encode(word, add_special_tokens=False)[0]]
25
  for word in ['http', 'www', '.com', '.org']]
26
  )
@@ -31,34 +31,27 @@ def generate_clinical_report(input_text, max_length=256, num_beams=4, no_repeat_
31
  print(f"Error generating report: {str(e)}")
32
  return f"Error: {str(e)}"
33
 
34
- # Create Gradio interface with API configuration
35
  demo = gr.Interface(
36
  fn=generate_clinical_report,
37
- inputs=[
38
- gr.Textbox(
39
- lines=8,
40
- placeholder="Enter clinical notes here...",
41
- label="Clinical Notes"
42
- ),
43
- gr.Slider(minimum=50, maximum=500, value=256, step=1, label="Max Length"),
44
- gr.Slider(minimum=1, maximum=8, value=4, step=1, label="Num Beams"),
45
- gr.Slider(minimum=1, maximum=5, value=3, step=1, label="No Repeat Ngram Size"),
46
- gr.Slider(minimum=0.1, maximum=5.0, value=2.0, step=0.1, label="Length Penalty"),
47
- gr.Checkbox(value=True, label="Early Stopping")
48
- ],
49
- outputs=gr.Textbox(lines=8, label="Generated Clinical Report"),
50
  title="Clinical Report Generator",
51
  description="Generate professional clinical reports from clinical notes using a T5 model.",
52
- allow_flagging="never",
53
- analytics_enabled=False
54
  )
55
 
56
- # Launch the app with optimized configuration for Hugging Face Spaces
57
  if __name__ == "__main__":
58
- demo.queue() # Enable queue with default settings
59
  demo.launch(
60
  server_name="0.0.0.0",
61
- share=True,
62
- show_error=True,
63
- max_threads=40 # Increase max threads for better performance
64
  )
 
5
  model = T5ForConditionalGeneration.from_pretrained('t5-small')
6
  tokenizer = T5Tokenizer.from_pretrained('t5-small')
7
 
8
+ def generate_clinical_report(input_text):
9
  """
10
+ Generate a clinical report from the input text using the T5 model.
11
  """
12
  try:
13
  # Prepare input text
14
  input_ids = tokenizer.encode("summarize: " + input_text, return_tensors="pt", max_length=512, truncation=True)
15
 
16
+ # Generate report
17
  outputs = model.generate(
18
  input_ids,
19
+ max_length=256,
20
+ num_beams=4,
21
+ no_repeat_ngram_size=3,
22
+ length_penalty=2.0,
23
+ early_stopping=True,
24
  bad_words_ids=[[tokenizer.encode(word, add_special_tokens=False)[0]]
25
  for word in ['http', 'www', '.com', '.org']]
26
  )
 
31
  print(f"Error generating report: {str(e)}")
32
  return f"Error: {str(e)}"
33
 
34
+ # Create Gradio interface
35
  demo = gr.Interface(
36
  fn=generate_clinical_report,
37
+ inputs=gr.Textbox(
38
+ lines=8,
39
+ placeholder="Enter clinical notes here...",
40
+ label="Clinical Notes"
41
+ ),
42
+ outputs=gr.Textbox(
43
+ lines=8,
44
+ label="Generated Clinical Report"
45
+ ),
 
 
 
 
46
  title="Clinical Report Generator",
47
  description="Generate professional clinical reports from clinical notes using a T5 model.",
48
+ allow_flagging="never"
 
49
  )
50
 
51
+ # Launch the app
52
  if __name__ == "__main__":
53
+ demo.queue()
54
  demo.launch(
55
  server_name="0.0.0.0",
56
+ share=True
 
 
57
  )