karthi311 committed
Commit d648110 · verified · 1 parent: 16a9e04

Update app.py
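Adds import spacy.cli and a spacy.cli.download("en_core_web_sm") call before spacy.load, so the en_core_web_sm model is fetched automatically in a fresh environment.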

Files changed (1)
  1. app.py +227 -225
app.py CHANGED
@@ -1,225 +1,227 @@
-import torch
-import gradio as gr
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoModelForCausalLM
-from pydub import AudioSegment
-from sentence_transformers import SentenceTransformer, util
-import spacy
-import json
-from faster_whisper import WhisperModel
-
-# Audio conversion from MP4 to MP3
-def convert_mp4_to_mp3(mp4_path, mp3_path):
-    try:
-        audio = AudioSegment.from_file(mp4_path, format="mp4")
-        audio.export(mp3_path, format="mp3")
-    except Exception as e:
-        raise RuntimeError(f"Error converting MP4 to MP3: {e}")
-
-
-# Check if CUDA is available for GPU acceleration
-if torch.cuda.is_available():
-    device = "cuda"
-    compute_type = "float16"
-else:
-    device = "cpu"
-    compute_type = "int8"
-
-
-# Load Faster Whisper Model for transcription
-def load_faster_whisper():
-    model = WhisperModel("deepdml/faster-whisper-large-v3-turbo-ct2", device=device, compute_type=compute_type)
-    return model
-
-
-# Load NLP model and other helpers
-nlp = spacy.load("en_core_web_sm")
-embedder = SentenceTransformer("all-MiniLM-L6-v2")
-
-tokenizer = AutoTokenizer.from_pretrained("Mahalingam/DistilBart-Med-Summary")
-model = AutoModelForSeq2SeqLM.from_pretrained("Mahalingam/DistilBart-Med-Summary")
-
-summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)
-
-
-soap_prompts = {
-    "subjective": "Personal reports, symptoms described by patients, or personal health concerns. Details reflecting individual symptoms or health descriptions.",
-    "objective": "Observable facts, clinical findings, professional observations, specific medical specialties, and diagnoses.",
-    "assessment": "Clinical assessments, expertise-based opinions on conditions, and significance of medical interventions. Focused on medical evaluations or patient condition summaries.",
-    "plan": "Future steps, recommendations for treatment, follow-up instructions, and healthcare management plans."
-}
-soap_embeddings = {section: embedder.encode(prompt, convert_to_tensor=True) for section, prompt in soap_prompts.items()}
-
-
-# Load Mistral model and tokenizer
-def load_mistral_model():
-    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
-    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
-    model.to(device)
-    return model, tokenizer
-
-
-# Initialize Mistral
-mistral_model, mistral_tokenizer = load_mistral_model()
-
-# Query function for Mistral
-def mistral_query(user_prompt, soap_note):
-    combined_prompt = f"User Instructions:\n{user_prompt}\n\nContext:\n{soap_note}"
-    try:
-        inputs = mistral_tokenizer(combined_prompt, return_tensors="pt", truncation=True, max_length=4096).to(device)
-        outputs = mistral_model.generate(
-            **inputs,
-            max_new_tokens=512,
-            temperature=0.7,
-            num_beams=4,
-            no_repeat_ngram_size=3
-        )
-        return mistral_tokenizer.decode(outputs[0], skip_special_tokens=True)
-    except Exception as e:
-        return f"Error generating response: {e}"
-
-
-# Convert the response to JSON format
-def convert_to_json(template):
-    try:
-        lines = template.split("\n")
-        json_data = {}
-        section = None
-        for line in lines:
-            if line.endswith(":"):
-                section = line[:-1]
-                json_data[section] = []
-            elif section:
-                json_data[section].append(line.strip())
-        return json.dumps(json_data, indent=2)
-    except Exception as e:
-        return f"Error converting to JSON: {e}"
-
-
-# Transcription using Faster Whisper
-def transcribe_audio(mp4_path):
-    try:
-        print(f"Processing MP4 file: {mp4_path}")
-        model = load_faster_whisper()
-        mp3_path = "output_audio.mp3"
-        convert_mp4_to_mp3(mp4_path, mp3_path)
-
-        # Transcribe using Faster Whisper
-        segments, info = model.transcribe(mp3_path, beam_size=5)
-        transcription = " ".join([seg.text for seg in segments])
-        return transcription
-    except Exception as e:
-        return f"Error during transcription: {e}"
-
-
-# Classify the sentence to the correct SOAP section
-def classify_sentence(sentence):
-    similarities = {section: util.pytorch_cos_sim(embedder.encode(sentence), soap_embeddings[section]) for section in soap_prompts.keys()}
-    return max(similarities, key=similarities.get)
-
-
-# Summarize the section if it's too long
-def summarize_section(section_text):
-    if len(section_text.split()) < 50:
-        return section_text
-    target_length = int(len(section_text.split()) * 0.65)
-    inputs = tokenizer.encode(section_text, return_tensors="pt", truncation=True, max_length=1024)
-    summary_ids = model.generate(
-        inputs,
-        max_length=target_length,
-        min_length=int(target_length * 0.60),
-        length_penalty=1.0,
-        num_beams=4
-    )
-    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
-
-# Analyze the SOAP content and divide into sections
-def soap_analysis(text):
-    doc = nlp(text)
-    soap_note = {section: "" for section in soap_prompts.keys()}
-
-    for sentence in doc.sents:
-        section = classify_sentence(sentence.text)
-        soap_note[section] += sentence.text + " "
-
-    # Summarize each section of the SOAP note
-    for section in soap_note:
-        soap_note[section] = summarize_section(soap_note[section].strip())
-
-    return format_soap_output(soap_note)
-
-
-# Format the SOAP note output
-def format_soap_output(soap_note):
-    return (
-        f"Subjective:\n{soap_note['subjective']}\n\n"
-        f"Objective:\n{soap_note['objective']}\n\n"
-        f"Assessment:\n{soap_note['assessment']}\n\n"
-        f"Plan:\n{soap_note['plan']}\n"
-    )
-
-
-# Process file function for audio to SOAP
-def process_file(mp4_file, user_prompt):
-    transcription = transcribe_audio(mp4_file.name)
-    print("Transcribed Text: ", transcription)
-
-    soap_note = soap_analysis(transcription)
-    print("SOAP Notes: ", soap_note)
-
-    template_output = mistral_query(user_prompt, soap_note)
-    print("Template: ", template_output)
-
-    json_output = convert_to_json(template_output)
-
-    return soap_note, template_output, json_output
-
-
-# Process text function for text input to SOAP
-def process_text(text, user_prompt):
-    soap_note = soap_analysis(text)
-    print(soap_note)
-
-    template_output = mistral_query(user_prompt, soap_note)
-    print(template_output)
-    json_output = convert_to_json(template_output)
-
-    return soap_note, template_output, json_output
-
-
-# Launch the Gradio interface
-def launch_gradio():
-    with gr.Blocks(theme=gr.themes.Default()) as demo:
-        gr.Markdown("# SOAP Note Generator")
-        with gr.Tab("Audio to SOAP"):
-            gr.Interface(
-                fn=process_file,
-                inputs=[
-                    gr.File(label="Upload MP4 File"),
-                    gr.Textbox(label="Enter Prompt for Template", placeholder="Enter a detailed prompt...", lines=6),
-                ],
-                outputs=[
-                    gr.Textbox(label="SOAP Note"),
-                    gr.Textbox(label="Generated Template from Mistral"),
-                    gr.Textbox(label="JSON Output"),
-                ],
-            )
-        with gr.Tab("Text to SOAP"):
-            gr.Interface(
-                fn=process_text,
-                inputs=[
-                    gr.Textbox(label="Enter Text", placeholder="Enter medical notes...", lines=6),
-                    gr.Textbox(label="Enter Prompt for Template", placeholder="Enter a detailed prompt...", lines=6),
-                ],
-                outputs=[
-                    gr.Textbox(label="SOAP Note"),
-                    gr.Textbox(label="Generated Template from Mistral"),
-                    gr.Textbox(label="JSON Output"),
-                ],
-            )
-    demo.launch(share=True, debug=True)
-
-
-# Run the Gradio app
-if __name__ == "__main__":
-    launch_gradio()
 
 
 
+import torch
+import gradio as gr
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoModelForCausalLM
+from pydub import AudioSegment
+from sentence_transformers import SentenceTransformer, util
+import spacy
+import spacy.cli
+spacy.cli.download("en_core_web_sm")
+import json
+from faster_whisper import WhisperModel
+
+# Audio conversion from MP4 to MP3
+def convert_mp4_to_mp3(mp4_path, mp3_path):
+    try:
+        audio = AudioSegment.from_file(mp4_path, format="mp4")
+        audio.export(mp3_path, format="mp3")
+    except Exception as e:
+        raise RuntimeError(f"Error converting MP4 to MP3: {e}")
+
+
+# Check if CUDA is available for GPU acceleration
+if torch.cuda.is_available():
+    device = "cuda"
+    compute_type = "float16"
+else:
+    device = "cpu"
+    compute_type = "int8"
+
+
+# Load Faster Whisper Model for transcription
+def load_faster_whisper():
+    model = WhisperModel("deepdml/faster-whisper-large-v3-turbo-ct2", device=device, compute_type=compute_type)
+    return model
+
+
+# Load NLP model and other helpers
+nlp = spacy.load("en_core_web_sm")
+embedder = SentenceTransformer("all-MiniLM-L6-v2")
+
+tokenizer = AutoTokenizer.from_pretrained("Mahalingam/DistilBart-Med-Summary")
+model = AutoModelForSeq2SeqLM.from_pretrained("Mahalingam/DistilBart-Med-Summary")
+
+summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)
+
+
+soap_prompts = {
+    "subjective": "Personal reports, symptoms described by patients, or personal health concerns. Details reflecting individual symptoms or health descriptions.",
+    "objective": "Observable facts, clinical findings, professional observations, specific medical specialties, and diagnoses.",
+    "assessment": "Clinical assessments, expertise-based opinions on conditions, and significance of medical interventions. Focused on medical evaluations or patient condition summaries.",
+    "plan": "Future steps, recommendations for treatment, follow-up instructions, and healthcare management plans."
+}
+soap_embeddings = {section: embedder.encode(prompt, convert_to_tensor=True) for section, prompt in soap_prompts.items()}
+
+
+# Load Mistral model and tokenizer
+def load_mistral_model():
+    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
+    model.to(device)
+    return model, tokenizer
+
+
+# Initialize Mistral
+mistral_model, mistral_tokenizer = load_mistral_model()
+
+# Query function for Mistral
+def mistral_query(user_prompt, soap_note):
+    combined_prompt = f"User Instructions:\n{user_prompt}\n\nContext:\n{soap_note}"
+    try:
+        inputs = mistral_tokenizer(combined_prompt, return_tensors="pt", truncation=True, max_length=4096).to(device)
+        outputs = mistral_model.generate(
+            **inputs,
+            max_new_tokens=512,
+            temperature=0.7,
+            num_beams=4,
+            no_repeat_ngram_size=3
+        )
+        return mistral_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    except Exception as e:
+        return f"Error generating response: {e}"
+
+
+# Convert the response to JSON format
+def convert_to_json(template):
+    try:
+        lines = template.split("\n")
+        json_data = {}
+        section = None
+        for line in lines:
+            if line.endswith(":"):
+                section = line[:-1]
+                json_data[section] = []
+            elif section:
+                json_data[section].append(line.strip())
+        return json.dumps(json_data, indent=2)
+    except Exception as e:
+        return f"Error converting to JSON: {e}"
+
+
+# Transcription using Faster Whisper
+def transcribe_audio(mp4_path):
+    try:
+        print(f"Processing MP4 file: {mp4_path}")
+        model = load_faster_whisper()
+        mp3_path = "output_audio.mp3"
+        convert_mp4_to_mp3(mp4_path, mp3_path)
+
+        # Transcribe using Faster Whisper
+        segments, info = model.transcribe(mp3_path, beam_size=5)
+        transcription = " ".join([seg.text for seg in segments])
+        return transcription
+    except Exception as e:
+        return f"Error during transcription: {e}"
+
+
+# Classify the sentence to the correct SOAP section
+def classify_sentence(sentence):
+    similarities = {section: util.pytorch_cos_sim(embedder.encode(sentence), soap_embeddings[section]) for section in soap_prompts.keys()}
+    return max(similarities, key=similarities.get)
+
+
+# Summarize the section if it's too long
+def summarize_section(section_text):
+    if len(section_text.split()) < 50:
+        return section_text
+    target_length = int(len(section_text.split()) * 0.65)
+    inputs = tokenizer.encode(section_text, return_tensors="pt", truncation=True, max_length=1024)
+    summary_ids = model.generate(
+        inputs,
+        max_length=target_length,
+        min_length=int(target_length * 0.60),
+        length_penalty=1.0,
+        num_beams=4
+    )
+    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+
+# Analyze the SOAP content and divide into sections
+def soap_analysis(text):
+    doc = nlp(text)
+    soap_note = {section: "" for section in soap_prompts.keys()}
+
+    for sentence in doc.sents:
+        section = classify_sentence(sentence.text)
+        soap_note[section] += sentence.text + " "
+
+    # Summarize each section of the SOAP note
+    for section in soap_note:
+        soap_note[section] = summarize_section(soap_note[section].strip())
+
+    return format_soap_output(soap_note)
+
+
+# Format the SOAP note output
+def format_soap_output(soap_note):
+    return (
+        f"Subjective:\n{soap_note['subjective']}\n\n"
+        f"Objective:\n{soap_note['objective']}\n\n"
+        f"Assessment:\n{soap_note['assessment']}\n\n"
+        f"Plan:\n{soap_note['plan']}\n"
+    )
+
+
+# Process file function for audio to SOAP
+def process_file(mp4_file, user_prompt):
+    transcription = transcribe_audio(mp4_file.name)
+    print("Transcribed Text: ", transcription)
+
+    soap_note = soap_analysis(transcription)
+    print("SOAP Notes: ", soap_note)
+
+    template_output = mistral_query(user_prompt, soap_note)
+    print("Template: ", template_output)
+
+    json_output = convert_to_json(template_output)
+
+    return soap_note, template_output, json_output
+
+
+# Process text function for text input to SOAP
+def process_text(text, user_prompt):
+    soap_note = soap_analysis(text)
+    print(soap_note)
+
+    template_output = mistral_query(user_prompt, soap_note)
+    print(template_output)
+    json_output = convert_to_json(template_output)
+
+    return soap_note, template_output, json_output
+
+
+# Launch the Gradio interface
+def launch_gradio():
+    with gr.Blocks(theme=gr.themes.Default()) as demo:
+        gr.Markdown("# SOAP Note Generator")
+        with gr.Tab("Audio to SOAP"):
+            gr.Interface(
+                fn=process_file,
+                inputs=[
+                    gr.File(label="Upload MP4 File"),
+                    gr.Textbox(label="Enter Prompt for Template", placeholder="Enter a detailed prompt...", lines=6),
+                ],
+                outputs=[
+                    gr.Textbox(label="SOAP Note"),
+                    gr.Textbox(label="Generated Template from Mistral"),
+                    gr.Textbox(label="JSON Output"),
+                ],
+            )
+        with gr.Tab("Text to SOAP"):
+            gr.Interface(
+                fn=process_text,
+                inputs=[
+                    gr.Textbox(label="Enter Text", placeholder="Enter medical notes...", lines=6),
+                    gr.Textbox(label="Enter Prompt for Template", placeholder="Enter a detailed prompt...", lines=6),
+                ],
+                outputs=[
+                    gr.Textbox(label="SOAP Note"),
+                    gr.Textbox(label="Generated Template from Mistral"),
+                    gr.Textbox(label="JSON Output"),
+                ],
+            )
+    demo.launch(share=True, debug=True)
+
+
+# Run the Gradio app
+if __name__ == "__main__":
+    launch_gradio()
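
Note: the new spacy.cli.download("en_core_web_sm") call runs on every startup, even when the model package is already installed. A minimal sketch of a download-on-miss variant (not part of this commit; same model name assumed):

```python
import spacy
import spacy.cli


def load_spacy_model(name: str = "en_core_web_sm"):
    """Load a spaCy pipeline, downloading the package only if it is missing."""
    try:
        return spacy.load(name)
    except OSError:
        # spacy.load raises OSError when the model package is not installed.
        spacy.cli.download(name)
        return spacy.load(name)


nlp = load_spacy_model()
```

On a fresh environment this behaves like the committed version; once the model is cached it skips the download step on later restarts.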