chandru1652's picture
Initial public commit
81cdd5f
raw
history blame
10.3 kB
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import json
import logging
import random
import re
from dataclasses import replace
from pathlib import Path
from config import BASE_DIR, RANDOMIZE_CHOICES
from models import Case, CaseSummary, AnswerLog, ConversationTurn, QuestionOutcome, ClinicalMCQ
# --- Logging setup ---
# Optional basic logging configuration (INFO level, timestamped records);
# adjust or drop it if logging is configured elsewhere.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by every helper in this file.
logger = logging.getLogger(__name__)
def fetch_report(report_path: Path):
    """Read a JSON report file and return its decoded content.

    Args:
        report_path: Location of the JSON report file to read.

    Returns:
        The decoded JSON content of the file, or an empty string when the
        file does not exist (the miss is logged as an error).

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    try:
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # locale encoding, matching how the manifest CSV is opened.
        with open(report_path, 'r', encoding='utf-8') as f:
            report = json.load(f)
    except FileNotFoundError:
        logger.error(f"ERROR: Could not find report file: {report_path}")
        return ""
    logger.info(f"Successfully loaded '{report_path}' into memory.")
    return report
def get_available_reports(reports_csv_path: Path):
    """Read the manifest CSV and build the demo's available cases.

    Each row must provide the columns ``case_id``, ``case_condition_name``,
    ``report_path``, ``download_image_url`` and ``findings``. Rows whose
    report file does not exist (paths are resolved relative to ``BASE_DIR``)
    or whose download image URL is empty are skipped with a warning.

    Args:
        reports_csv_path: Path to the manifest CSV file.

    Returns:
        A dict mapping ``str(case_id)`` to the ``Case`` built from that row.
        The dict is empty when the manifest is missing, lacks required
        headers, or cannot be processed; such problems are logged, not raised.
    """
    available_reports: dict[str, Case] = {}
    if not reports_csv_path.is_file():
        logger.warning(f"Manifest CSV file not found at {reports_csv_path}. AVAILABLE_REPORTS will be empty.")
        return available_reports

    required_headers = {'case_id', 'case_condition_name', 'report_path', 'download_image_url', 'findings'}
    try:
        # newline='' lets the csv module handle line endings inside quoted fields itself.
        with open(reports_csv_path, mode='r', encoding='utf-8', newline='') as csvfile:
            reader = csv.DictReader(csvfile)
            # fieldnames is None for an empty file; treat that as "no headers".
            missing_headers = required_headers - set(reader.fieldnames or ())
            if missing_headers:
                logger.error(
                    f"CSV file {reports_csv_path} is missing one or more required headers: {missing_headers}"
                )
                return available_reports

            for row in reader:
                case_id = row['case_id']
                condition_name = row['case_condition_name']
                report_path_from_csv = row['report_path']  # e.g., static/reports/report1.txt or empty
                download_image_url_from_csv = row['download_image_url']
                potential_findings = row['findings']

                # Paths from the CSV are relative to BASE_DIR. A missing cell
                # (None or "") is treated as "no report" instead of raising.
                abs_report_path = BASE_DIR / report_path_from_csv if report_path_from_csv else None
                if abs_report_path is None or not abs_report_path.is_file():
                    logger.warning(
                        f"Report file not found for case '{case_id}' at '{abs_report_path}'. Skipping this entry.")
                    continue

                if not download_image_url_from_csv:
                    logger.warning(
                        f"Download image url not found for case '{case_id}'. Skipping this entry.")
                    continue

                # Load from the same absolute path that was just validated, so
                # the result does not depend on the current working directory.
                ground_truth_labels = fetch_report(abs_report_path)
                available_reports[str(case_id)] = Case(
                    id=case_id,
                    condition_name=condition_name,
                    ground_truth_labels=ground_truth_labels,
                    download_image_url=download_image_url_from_csv,
                    potential_findings=potential_findings,
                )
            logger.info(f"Loaded {len(available_reports)} report/image pairs from CSV.")
    except Exception as e:
        # Best-effort loader: log the failure and return whatever was collected so far.
        logger.error(f"Error reading or processing CSV file {reports_csv_path}: {e}", exc_info=True)
    return available_reports
def get_json_from_model_response(response_text: str) -> dict:
    """Parse a JSON object out of a model response.

    The response is first searched for a fenced ```json code block holding an
    object. If none is present, the span from the first ``{`` to the last
    ``}`` is parsed; when that span is not valid JSON (for example because
    text containing ``}`` follows the object), only the first JSON object
    starting at the first ``{`` is decoded.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        The decoded JSON object.

    Raises:
        Exception: If the fenced block cannot be decoded, or if the response
            contains no ``{...}`` span at all.
        json.JSONDecodeError: If the unfenced fallback cannot be decoded.
    """
    # Look for a JSON object starting with { and ending with } inside a ```json fence.
    fenced_match = re.search(r"```json\s*(\{.*?\})\s*```", response_text, re.DOTALL)
    if fenced_match:
        json_str = fenced_match.group(1)
        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to decode JSON after extraction: {e}")
            # Chain the decode error so the root cause stays in the traceback.
            raise Exception(f"Could not parse JSON from extracted block: {json_str}") from e

    # Fallback if the model misses the markdown block
    logger.warning("Could not find a ```json block. Falling back to raw search.")
    braced_match = re.search(r"(\{.*\})", response_text, re.DOTALL)
    if braced_match:
        try:
            return json.loads(braced_match.group(1))
        except json.JSONDecodeError:
            # The greedy span may have swallowed trailing text that contains
            # braces; decode just the first object and ignore what follows.
            remainder = response_text[braced_match.start():]
            decoded, _ = json.JSONDecoder().raw_decode(remainder)
            return decoded
    raise Exception(f"Could not find or parse JSON object in the API response: {response_text}")
def get_potential_findings(case: Case) -> str:
    """Return the ``potential_findings`` value stored on the given case."""
    findings = case.potential_findings
    return findings
def build_summary_template(case: Case, rag_cache: dict) -> CaseSummary:
    """Create the initial summary for a case from its static data.

    The condition name and potential findings come from ``case``; the
    guideline resource is the comma-separated list of citations stored for
    ``case.id`` in ``rag_cache`` (empty when there are none). The
    interpretation and rationale are left blank.
    """
    cached_entry = rag_cache.get(case.id, {})
    cited_sources = cached_entry.get("citations", [])
    # No cached citations means an empty resource string.
    resource_text = ', '.join(str(source) for source in cited_sources) if cited_sources else ""
    return CaseSummary(
        med_gemma_interpretation="",
        potential_findings=get_potential_findings(case),
        rationale=[],
        guideline_specific_resource=resource_text,
        condition=case.condition_name,
    )
def populate_rationale(summary_template: CaseSummary, conversation_history: list[ConversationTurn]) -> CaseSummary:
    """Fill in the rationale and interpretation from the user's answers.

    A question counts as correct when either the first or the second attempt
    equals the model's answer key. Each turn contributes one ``AnswerLog``
    with a single "Correct" or "Incorrect" outcome, and the interpretation
    text is chosen from the resulting accuracy percentage.

    Args:
        summary_template: Summary whose potential findings, guideline
            resource and condition are copied into the result.
        conversation_history: One turn per question asked, each carrying the
            MCQ and the user's attempts.

    Returns:
        A new ``CaseSummary`` with ``med_gemma_interpretation`` and
        ``rationale`` populated.
    """
    correct_count = 0
    total_questions = len(conversation_history)
    rationale_logs = []

    for turn in conversation_history:
        mcq = turn.clinicalMcq
        choices = mcq.choices
        model_answer_key = mcq.answer
        user_attempt1_key = turn.userResponse.attempt1
        user_attempt2_key = turn.userResponse.attempt2

        answered_correctly = user_attempt1_key == model_answer_key or user_attempt2_key == model_answer_key
        if answered_correctly:
            correct_count += 1
            correct_answer_text = choices.get(
                model_answer_key, f"N/A - Model Answer Key '{model_answer_key}' not found.")
            outcome = QuestionOutcome(type="Correct", text=correct_answer_text)
        else:
            # Report the latest attempt the user made.
            user_attempt_key = user_attempt2_key if user_attempt2_key else user_attempt1_key
            # Use a fallback text (as for the correct answer) instead of
            # raising KeyError when the key is missing or not among the choices.
            incorrect_text = choices.get(
                user_attempt_key, f"N/A - User Answer Key '{user_attempt_key}' not found.")
            outcome = QuestionOutcome(type="Incorrect", text=incorrect_text)

        rationale_logs.append(AnswerLog(question=mcq.question, outcomes=[outcome]))

    accuracy = (correct_count / total_questions) * 100 if total_questions > 0 else 0

    # NOTE(review): the perfect-score message says "on your first attempt",
    # but second-attempt matches are also counted as correct above — confirm
    # the intended wording with the product owner.
    if accuracy == 100:
        interpretation = f"Wonderful job! You achieved a perfect score of {accuracy:.0f}%, correctly identifying all key findings on your first attempt."
    elif accuracy >= 50:
        interpretation = f"Good job. You scored {accuracy:.0f}%, showing a solid understanding of the key findings for this case."
    else:
        interpretation = f"This was a challenging case, and you scored {accuracy:.0f}%. More preparation is needed. Review the rationale below for details."

    return CaseSummary(
        med_gemma_interpretation=interpretation,
        potential_findings=summary_template.potential_findings,
        rationale=rationale_logs,
        guideline_specific_resource=summary_template.guideline_specific_resource,
        condition=summary_template.condition,
    )
def randomize_mcqs(original_mcqs: list[ClinicalMCQ]) -> list[ClinicalMCQ]:
    """
    Shuffle the answer choices of each clinical MCQ.

    When ``RANDOMIZE_CHOICES`` is falsy the input list is returned as is.
    Otherwise every question is copied with its choice texts shuffled across
    the same (sorted) keys and its answer key updated to follow the correct
    text. A question that fails to randomize is logged and kept unchanged.
    """
    if not RANDOMIZE_CHOICES:
        return original_mcqs

    shuffled_mcqs = []
    for mcq in original_mcqs:
        try:
            # Remember the correct answer by its text, since its key will move.
            answer_text = mcq.choices[mcq.answer]

            # Shuffle the choice texts in place.
            texts = list(mcq.choices.values())
            random.shuffle(texts)

            # Re-attach the shuffled texts to the original keys in sorted order.
            ordered_keys = sorted(mcq.choices.keys())
            remapped_choices = dict(zip(ordered_keys, texts))

            # Find the key that now holds the correct answer's text.
            answer_key = next(key for key, text in remapped_choices.items() if text == answer_text)

            # dataclasses.replace builds a new instance; the original is left untouched.
            result = replace(mcq, choices=remapped_choices, answer=answer_key)
        except Exception as e:
            # Any failure (e.g. KeyError from a bad answer key) falls back to
            # the original, unmodified question.
            logger.warning(f"Warning: Could not randomize question '{mcq.id}'. Returning original. Error: {e}")
            result = mcq
        shuffled_mcqs.append(result)

    return shuffled_mcqs