Jiaaaaaaax committed
Commit 8fb710a · verified · 1 Parent(s): fe3550f

Upload 3 files

Files changed (3)
  1. README (1).md +12 -0
  2. app (2).py +457 -0
  3. requirements (1).txt +9 -0
README (1).md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: Motivational Interviewing Gemini
+ emoji: 😻
+ colorFrom: blue
+ colorTo: yellow
+ sdk: streamlit
+ sdk_version: 1.41.1
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app (2).py ADDED
@@ -0,0 +1,457 @@
+ import streamlit as st
+ import pandas as pd
+ import google.generativeai as genai
+ import whisper
+ import torch
+ import re
+ import numpy as np
+ import tempfile
+ import os
+ import json
+ from pathlib import Path
+ from moviepy import VideoFileClip
+ from pyannote.audio import Pipeline
+
+ # time and ffmpeg-python are available for audio handling; note that Whisper
+ # also requires the ffmpeg binary to be installed on the system for decoding.
+ import time
+ import ffmpeg
+
+ # MediaProcessor handles media processing (transcription and diarization)
+ class MediaProcessor:
+     def __init__(self, auth_token: str):
+         """
+         Initialize with a HuggingFace auth token for speaker diarization
+         """
+         # Run on the GPU when one is available
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         # Load Whisper model
+         self.whisper_model = whisper.load_model("medium", device=device)
+         # Initialize PyAnnote speaker diarization pipeline
+         self.diarization_pipeline = Pipeline.from_pretrained(
+             "pyannote/speaker-diarization-3.0",
+             use_auth_token=auth_token
+         )
+         self.diarization_pipeline.to(torch.device(device))
+         self.supported_formats = {
+             'audio': ['.mp3', '.wav', '.m4a', '.ogg', '.flac'],
+             'video': ['.mp4', '.avi', '.mov', '.mkv', '.webm']
+         }
+
+     def process_media(self, file, progress_bar=None) -> pd.DataFrame:
+         """Process audio or video file and return transcript DataFrame"""
+         file_ext = Path(file.name).suffix.lower()
+
+         with tempfile.TemporaryDirectory() as temp_dir:
+             temp_path = Path(temp_dir) / file.name
+
+             # Save uploaded file
+             with open(temp_path, 'wb') as f:
+                 f.write(file.getvalue())
+
+             # Convert video to audio if necessary
+             if file_ext in self.supported_formats['video']:
+                 audio_path = self._extract_audio_from_video(temp_path)
+             else:
+                 audio_path = temp_path
+
+             # Process audio
+             return self._process_audio_file(audio_path, progress_bar)
+
+     def _extract_audio_from_video(self, video_path: Path) -> Path:
+         """Extract audio from video file"""
+         audio_path = video_path.with_suffix('.wav')
+         video = VideoFileClip(str(video_path))
+         video.audio.write_audiofile(str(audio_path))
+         video.close()
+         return audio_path
+
+     def _process_audio_file(self, audio_path: Path, progress_bar) -> pd.DataFrame:
+         """
+         Process audio file with transcription and diarization
+         Returns DataFrame with speaker-separated transcript
+         """
+         # st.progress elements take the status text as a keyword argument;
+         # calling .text() on the element would replace the bar with plain text
+         if progress_bar:
+             progress_bar.progress(0.1, text="Transcribing audio...")
+
+         # Transcribe audio using Whisper
+         transcription = self.whisper_model.transcribe(str(audio_path))
+
+         if progress_bar:
+             progress_bar.progress(0.5, text="Performing speaker diarization...")
+
+         # Perform speaker diarization
+         diarization = self.diarization_pipeline(str(audio_path))
+
+         if progress_bar:
+             progress_bar.progress(0.8, text="Aligning transcription with speakers...")
+
+         # Align transcription with speaker segments
+         transcript_data = self._align_transcript_with_speakers(
+             transcription, diarization
+         )
+
+         if progress_bar:
+             progress_bar.progress(1.0, text="Processing complete!")
+
+         return pd.DataFrame(transcript_data)
+
+     def _align_transcript_with_speakers(self, transcription, diarization):
+         """
+         Align transcription with speaker segments
+         Returns list of dicts with aligned data
+         """
+         # Prepare a list to hold the aligned segments
+         segments = []
+         # Iterate over diarization turns; itertracks yields (segment, track,
+         # speaker label), the reliable way to read labels from the pipeline
+         for segment, _, speaker in diarization.itertracks(yield_label=True):
+             # Find corresponding text from transcription
+             text = self._find_text_in_timerange(
+                 transcription['segments'],
+                 segment.start,
+                 segment.end
+             )
+             if text:
+                 segments.append({
+                     # Assumes the first diarized speaker is the provider;
+                     # the transcript can be edited in the UI if this is wrong
+                     'P or C': 'P' if speaker == 'SPEAKER_00' else 'C',
+                     'Content of Utterance': text,
+                     'Start Time': segment.start,
+                     'End Time': segment.end,
+                     'Speaker': speaker
+                 })
+         return segments
+
+     @staticmethod
+     def _find_text_in_timerange(segments, start_time, end_time):
+         """Find transcribed text within a time range"""
+         relevant_segments = [
+             seg['text'] for seg in segments
+             if (seg['start'] >= start_time and seg['end'] <= end_time)
+         ]
+         return ' '.join(relevant_segments).strip()
+
+ # MITIAnalyzer handles analysis and scoring using the Google Gemini API
+ class MITIAnalyzer:
+     def __init__(self, api_key):
+         # Set the API key for Google Gemini
+         genai.configure(api_key=api_key)
+         self.global_scores = {
+             "cultivating_change": None,
+             "softening_sustain-talk": None,
+             "partnership": None,
+             "empathy": None
+         }
+         self.behavior_counts = {
+             "gi": 0,             # Giving Information
+             "persuade": 0,
+             "persuade_with": 0,  # Persuade with Permission
+             "question": 0,
+             "sr": 0,             # Simple Reflection
+             "cr": 0,             # Complex Reflection
+             "affirm": 0,
+             "seek": 0,           # Seeking Collaboration
+             "emphasize": 0,      # Emphasizing Autonomy
+             "confront": 0
+         }
+
+     def extract_score(self, response_text):
+         """Extract numerical score from Gemini API response"""
+         # Look for patterns like "Score: X" or "I would rate this as X"
+         score_patterns = [
+             r"score.*?([1-5])",
+             r"rate.*?([1-5])",
+             r"([1-5]).*?out of 5"
+         ]
+
+         for pattern in score_patterns:
+             match = re.search(pattern, response_text.lower())
+             if match:
+                 return int(match.group(1))
+         return None
+
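+     # Illustrative behavior (hypothetical inputs): extract_score("I would
+     # rate this a 4 out of 5") returns 4 via the "rate" pattern, while text
+     # with no 1-5 digit near a scoring phrase returns None.
+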
+     def analyze_transcript(self, transcript_df):
+         """Analyze transcript and generate all MITI scores"""
+         # Analyze global scores
+         model = genai.GenerativeModel('gemini-1.5-flash')
+         generation_config = genai.GenerationConfig(max_output_tokens=2048)
+         for i, dimension in enumerate(self.global_scores.keys(), start=1):
+             prompt = self.load_prompt(
+                 f"prompts/prompts/0{i}-MITI-global-{dimension.replace('_', '-')}.md"
+             )
+
+             full_prompt = f"{prompt}\n\n<transcript>\n{transcript_df.to_csv(index=False)}\n</transcript>"
+
+             response = model.generate_content(
+                 full_prompt,
+                 generation_config=generation_config
+             )
+             score = self.extract_score(response.text)
+             self.global_scores[dimension] = score
+
+         # Analyze behavior counts
+         self.count_behaviors(transcript_df)
+
+     def count_behaviors(self, transcript_df):
+         """Count specific behaviors in transcript"""
+         model = genai.GenerativeModel('gemini-1.5-flash')
+         generation_config = genai.GenerationConfig(max_output_tokens=2048)
+         # Create behavior detection prompt
+         behavior_prompt = """
+         You are an expert in Motivational Interviewing. Analyze the following therapist utterance and identify any of these behaviors:
+         - Giving Information (GI)
+         - Persuade
+         - Persuade with Permission
+         - Question (Q)
+         - Simple Reflection (SR)
+         - Complex Reflection (CR)
+         - Affirm (AF)
+         - Seeking Collaboration (Seek)
+         - Emphasizing Autonomy (Emphasize)
+         - Confront
+
+         Return results in JSON format, e.g., {"GI":1, "Persuade":0, ...}
+         """
+
+         for _, row in transcript_df.iterrows():
+             if row['P or C'] == 'P':  # Provider/Therapist utterance
+                 behavior_full_prompt = f"{behavior_prompt}\n\nUtterance: {row['Content of Utterance']}"
+                 response = model.generate_content(
+                     behavior_full_prompt,
+                     generation_config=generation_config
+                 )
+                 try:
+                     # Extract JSON from the response, stripping any Markdown
+                     # code fences Gemini may wrap around it
+                     raw = re.sub(r"^```(?:json)?\s*|\s*```$", "", response.text.strip())
+                     behaviors = json.loads(raw)
+                     for behavior, count in behaviors.items():
+                         key = behavior.lower().replace(" ", "_")
+                         # Map abbreviated keys onto behavior_counts keys
+                         aliases = {"q": "question", "af": "affirm",
+                                    "persuade_with_permission": "persuade_with"}
+                         key = aliases.get(key, key)
+                         if key in self.behavior_counts:
+                             self.behavior_counts[key] += count
+                 except Exception as e:
+                     st.warning(f"Could not parse behaviors for utterance: {row['Content of Utterance']}\nError: {e}")
+
+     def calculate_summary_scores(self):
+         """Calculate MITI summary scores"""
+         summary = {}
+
+         # Technical Global
+         if all(self.global_scores[s] is not None for s in ['cultivating_change', 'softening_sustain-talk']):
+             summary['technical'] = (self.global_scores['cultivating_change'] +
+                                     self.global_scores['softening_sustain-talk']) / 2
+
+         # Relational Global
+         if all(self.global_scores[s] is not None for s in ['partnership', 'empathy']):
+             summary['relational'] = (self.global_scores['partnership'] +
+                                      self.global_scores['empathy']) / 2
+
+         # % Complex Reflections
+         total_reflections = self.behavior_counts['sr'] + self.behavior_counts['cr']
+         if total_reflections > 0:
+             summary['pct_cr'] = (self.behavior_counts['cr'] / total_reflections) * 100
+
+         # Reflection-to-Question Ratio
+         if self.behavior_counts['question'] > 0:
+             summary['r_to_q'] = total_reflections / self.behavior_counts['question']
+
+         # Total MI-Adherent
+         summary['total_mia'] = (self.behavior_counts['seek'] +
+                                 self.behavior_counts['affirm'] +
+                                 self.behavior_counts['emphasize'])
+
+         # Total MI Non-Adherent
+         summary['total_mina'] = (self.behavior_counts['confront'] +
+                                  self.behavior_counts['persuade'])
+
+         return summary
+
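+     # Worked example (hypothetical counts): with sr=3, cr=1, question=2,
+     # seek=1, affirm=2, emphasize=0, confront=1 and persuade=1, this yields
+     # pct_cr = 1/4 * 100 = 25.0, r_to_q = 4/2 = 2.0, total_mia = 3 and
+     # total_mina = 2.
+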
+     @staticmethod
+     def load_prompt(filename):
+         """Load prompt from file"""
+         try:
+             with open(filename, 'r') as f:
+                 return f.read()
+         except Exception as e:
+             st.error(f"Could not load prompt file: {filename}\nError: {e}")
+             return ""
+
+ def render_miti_results(analyzer):
+     """Render MITI results in Streamlit"""
+     st.header("MITI Evaluation Results")
+
+     # Global Scores
+     st.subheader("Global Scores")
+     global_scores_df = pd.DataFrame(analyzer.global_scores.items(), columns=['Dimension', 'Score'])
+     st.table(global_scores_df)
+
+     # Behavior Counts
+     st.subheader("Behavior Counts")
+     counts_df = pd.DataFrame(analyzer.behavior_counts.items(), columns=['Behavior', 'Count'])
+     st.table(counts_df)
+
+     # Summary Scores
+     st.subheader("Summary Scores")
+     summary = analyzer.calculate_summary_scores()
+     if summary:
+         summary_df = pd.DataFrame(summary.items(), columns=['Metric', 'Value'])
+         st.table(summary_df)
+     else:
+         st.write("No summary scores available.")
+
+ def export_results(analyzer, export_format):
+     """Export results in specified format"""
+     results = {
+         'global_scores': analyzer.global_scores,
+         'behavior_counts': analyzer.behavior_counts,
+         'summary_scores': analyzer.calculate_summary_scores()
+     }
+     if export_format == "JSON":
+         return json.dumps(results, indent=2)
+     elif export_format == "CSV":
+         # Flatten results to a two-column CSV
+         all_results = {**analyzer.global_scores, **analyzer.behavior_counts, **analyzer.calculate_summary_scores()}
+         df = pd.DataFrame(list(all_results.items()), columns=['Metric', 'Value'])
+         return df.to_csv(index=False)
+     elif export_format == "TXT":
+         # Plain text format
+         output = ""
+         output += "Global Scores:\n"
+         for k, v in analyzer.global_scores.items():
+             output += f"{k}: {v}\n"
+         output += "\nBehavior Counts:\n"
+         for k, v in analyzer.behavior_counts.items():
+             output += f"{k}: {v}\n"
+         output += "\nSummary Scores:\n"
+         for k, v in analyzer.calculate_summary_scores().items():
+             output += f"{k}: {v}\n"
+         return output
+
+ def main():
+     st.title("MITI Session Analyzer")
+
+     # Hide Streamlit's default hamburger menu and footer
+     hide_streamlit_style = """
+         <style>
+         #MainMenu {visibility: hidden;}
+         footer {visibility: hidden;}
+         </style>
+     """
+     st.markdown(hide_streamlit_style, unsafe_allow_html=True)
+
+     # Initialize processors
+     if 'media_processor' not in st.session_state:
+         if "HF_AUTH_TOKEN" not in st.secrets:
+             st.error("Hugging Face Auth Token not found. Please add it to Streamlit secrets.")
+             return
+         st.session_state.media_processor = MediaProcessor(
+             auth_token=st.secrets["HF_AUTH_TOKEN"]
+         )
+     if 'miti_analyzer' not in st.session_state:
+         if "GEMINI_API_KEY" not in st.secrets:
+             st.error("Gemini API Key not found. Please add it to Streamlit secrets.")
+             return
+         st.session_state.miti_analyzer = MITIAnalyzer(
+             api_key=st.secrets["GEMINI_API_KEY"]
+         )
+
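+     # For local runs, st.secrets is read from .streamlit/secrets.toml, e.g.
+     # (placeholder values):
+     #   HF_AUTH_TOKEN = "hf_..."
+     #   GEMINI_API_KEY = "AIza..."
+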
+     # File upload section
+     st.subheader("Upload Session Recording or Transcript")
+
+     file_type = st.radio(
+         "Select input type:",
+         ["Audio/Video Recording", "Text Transcript"]
+     )
+
+     if file_type == "Audio/Video Recording":
+         supported_formats = (
+             st.session_state.media_processor.supported_formats['audio'] +
+             st.session_state.media_processor.supported_formats['video']
+         )
+
+         uploaded_file = st.file_uploader(
+             "Upload recording",
+             type=[fmt[1:] for fmt in supported_formats]
+         )
+
+         if uploaded_file:
+             progress_bar = st.progress(0)
+             with st.spinner("Processing media file..."):
+                 try:
+                     transcript_df = st.session_state.media_processor.process_media(
+                         uploaded_file,
+                         progress_bar
+                     )
+                     st.session_state.transcript_df = transcript_df
+
+                     # Display transcript
+                     st.subheader("Generated Transcript")
+                     st.dataframe(transcript_df)
+
+                     # Allow transcript editing
+                     if st.checkbox("Edit Transcript"):
+                         st.session_state.transcript_df = st.data_editor(
+                             transcript_df,
+                             num_rows="dynamic"
+                         )
+
+                 except Exception as e:
+                     st.error(f"Error processing file: {str(e)}")
+
+     else:  # Text Transcript
+         uploaded_file = st.file_uploader(
+             "Upload transcript (CSV format)",
+             type=['csv']
+         )
+
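+         # The CSV should provide at least the columns the analyzer reads:
+         # 'P or C' and 'Content of Utterance'. Illustrative rows:
+         #   P or C,Content of Utterance
+         #   P,"What brings you in today?"
+         #   C,"I've been thinking about cutting back."
+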
+         if uploaded_file:
+             try:
+                 transcript_df = pd.read_csv(uploaded_file)
+                 st.session_state.transcript_df = transcript_df
+                 st.subheader("Transcript")
+                 st.dataframe(transcript_df)
+                 # Allow transcript editing
+                 if st.checkbox("Edit Transcript"):
+                     st.session_state.transcript_df = st.data_editor(
+                         transcript_df,
+                         num_rows="dynamic"
+                     )
+
+             except Exception as e:
+                 st.error(f"Error reading transcript: {str(e)}")
+
+     # Analysis section
+     if 'transcript_df' in st.session_state:
+         st.subheader("MITI Analysis")
+
+         if st.button("Generate MITI Ratings"):
+             with st.spinner("Analyzing session..."):
+                 st.session_state.miti_analyzer.analyze_transcript(
+                     st.session_state.transcript_df
+                 )
+             render_miti_results(st.session_state.miti_analyzer)
+
+             # Save results
+             st.session_state.last_results = st.session_state.miti_analyzer
+
+     # Export options
+     if 'last_results' in st.session_state:
+         st.subheader("Export Analysis Report")
+         export_format = st.selectbox(
+             "Select export format",
+             ["JSON", "CSV", "TXT"]
+         )
+
+         # Offer the download directly: a st.download_button nested inside a
+         # regular st.button would vanish on the rerun its own click triggers
+         report_data = export_results(
+             st.session_state.last_results,
+             export_format
+         )
+         file_extension = export_format.lower()
+         mime_types = {"JSON": "application/json", "CSV": "text/csv", "TXT": "text/plain"}
+         st.download_button(
+             label="Download Report",
+             data=report_data,
+             file_name=f"miti_analysis.{file_extension}",
+             mime=mime_types[export_format]
+         )
+
+ if __name__ == "__main__":
+     main()
requirements (1).txt ADDED
@@ -0,0 +1,9 @@
+ streamlit
+ pandas
+ google-generativeai
+ git+https://github.com/openai/whisper.git
+ torch
+ numpy
+ moviepy
+ pyannote.audio
+ ffmpeg-python