# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import json
import logging
import random
import re
from dataclasses import replace
from pathlib import Path

from config import BASE_DIR, RANDOMIZE_CHOICES
from models import Case, CaseSummary, AnswerLog, ConversationTurn, QuestionOutcome, ClinicalMCQ

# --- Configuration ---
# Configure basic logging (optional, adjust as needed)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# A JSON object wrapped in a ```json fenced block. The lazy `.*?` must be
# followed by the closing fence, so the match stops at the fence rather than
# at the first `}` and nested objects are captured whole.
_FENCED_JSON_RE = re.compile(r"```json\s*(\{.*?\})\s*```", re.DOTALL)
# Fallback: greedy span from the first `{` to the last `}` in the text.
_BARE_JSON_RE = re.compile(r"(\{.*\})", re.DOTALL)


def fetch_report(report_path: Path):
    """Load a JSON report file and return its decoded content.

    Args:
        report_path: Path of the JSON report file. A relative path is
            resolved by ``open`` against the current working directory.

    Returns:
        The decoded JSON value, or an empty string if the file does not exist
        or does not contain valid JSON (the problem is logged).
    """
    try:
        with open(report_path, 'r', encoding='utf-8') as f:
            report = json.load(f)
    except FileNotFoundError:
        logger.error(f"ERROR: Could not find report file: {report_path}")
        return ""
    except json.JSONDecodeError as e:
        # Keep the same "" sentinel as the missing-file case so one corrupt
        # report does not abort loading every other case.
        logger.error(f"ERROR: Report file is not valid JSON: {report_path} ({e})")
        return ""
    logger.info(f"Successfully loaded '{report_path}' into memory.")
    return report


def get_available_reports(reports_csv_path: Path) -> dict[str, Case]:
    """Read the reports manifest CSV and build the Cases available in this demo.

    The CSV must have the columns ``case_id``, ``case_condition_name``,
    ``report_path``, ``download_image_url`` and ``findings``. ``report_path``
    is relative to ``BASE_DIR``. A row is skipped, with a warning, when its
    report file does not exist or its download image URL is empty.

    Args:
        reports_csv_path: Path of the manifest CSV.

    Returns:
        Mapping of ``str(case_id)`` to Case. It is empty if the CSV is missing,
        lacks required headers, or cannot be processed (errors are logged).
    """
    available_reports: dict[str, Case] = {}
    if not reports_csv_path.is_file():
        logger.warning(f"Manifest CSV file not found at {reports_csv_path}. AVAILABLE_REPORTS will be empty.")
        return available_reports

    try:
        with open(reports_csv_path, mode='r', encoding='utf-8') as csvfile:
            reader = csv.DictReader(csvfile)
            required_headers = {'case_id', 'case_condition_name', 'report_path', 'download_image_url', 'findings'}
            # fieldnames is None for an empty file; treat that as "no headers".
            missing_headers = required_headers - set(reader.fieldnames or ())
            if missing_headers:
                logger.error(
                    f"CSV file {reports_csv_path} is missing one or more required headers: {missing_headers}"
                )
                return available_reports

            for row in reader:
                case_id = row['case_id']
                # Paths in the CSV are relative to BASE_DIR. A blank value, or a
                # value missing from a short row (None), resolves to BASE_DIR
                # itself, which is not a file and is skipped below.
                abs_report_path = BASE_DIR / (row['report_path'] or "")
                if not abs_report_path.is_file():
                    logger.warning(
                        f"Report file not found for case '{case_id}' at '{abs_report_path}'. Skipping this entry.")
                    continue

                download_image_url = row['download_image_url']
                if not download_image_url:
                    logger.warning(
                        f"Download image url not found for case '{case_id}'. Skipping this entry.")
                    continue

                available_reports[str(case_id)] = Case(
                    id=case_id,
                    condition_name=row['case_condition_name'],
                    # Read the BASE_DIR-resolved path that was just validated,
                    # not the raw CSV value, so loading does not depend on the CWD.
                    ground_truth_labels=fetch_report(abs_report_path),
                    download_image_url=download_image_url,
                    potential_findings=row['findings'],
                )
        logger.info(f"Loaded {len(available_reports)} report/image pairs from CSV.")
    except Exception as e:
        # Top-level boundary: a malformed manifest is logged rather than crashing the app.
        logger.error(f"Error reading or processing CSV file {reports_csv_path}: {e}", exc_info=True)
    return available_reports


def get_json_from_model_response(response_text: str) -> dict:
    """Extract and parse a JSON object from a model response.

    The object is first looked for inside a ```json fenced block. If there is
    no such block, the span from the first ``{`` to the last ``}`` is tried.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        The parsed JSON object.

    Raises:
        Exception: if a fenced block is found but cannot be parsed, or if the
            text contains no ``{...}`` span at all.
        json.JSONDecodeError: if the fallback ``{...}`` span is not valid JSON.
    """
    fenced = _FENCED_JSON_RE.search(response_text)
    if fenced:
        json_str = fenced.group(1)
        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to decode JSON after extraction: {e}")
            raise Exception(f"Could not parse JSON from extracted block: {json_str}") from e

    # Fallback if the model misses the markdown block
    logger.warning("Could not find a ```json block. Falling back to raw search.")
    bare = _BARE_JSON_RE.search(response_text)
    if bare:
        try:
            return json.loads(bare.group(1))
        except json.JSONDecodeError as e:
            # Log for parity with the fenced path, but re-raise unchanged so
            # callers catching JSONDecodeError keep working.
            logger.error(f"Failed to decode JSON from raw fallback search: {e}")
            raise
    raise Exception(f"Could not find or parse JSON object in the API response: {response_text}")


def get_potential_findings(case: Case) -> str:
    """Return the potential findings text stored on ``case``."""
    return case.potential_findings


def build_summary_template(case: Case, rag_cache: dict) -> CaseSummary:
    """Build a summary holding only the case's static data.

    Fills in potential findings, the guideline resource and the condition name.
    The interpretation and rationale are left empty for populate_rationale().

    Args:
        case: The case being summarized.
        rag_cache: Mapping of case id to a dict that may hold a
            ``"citations"`` list.

    Returns:
        A CaseSummary whose ``guideline_specific_resource`` is the citations
        joined with ", " (empty string if there are none).
    """
    # `or` also covers entries/values that are present but None.
    rag_data = rag_cache.get(case.id) or {}
    citations = rag_data.get("citations") or []
    citation_string = ', '.join(map(str, citations))
    return CaseSummary(
        med_gemma_interpretation="",
        potential_findings=get_potential_findings(case),
        rationale=[],
        guideline_specific_resource=citation_string,
        condition=case.condition_name
    )


def _interpretation_for(accuracy: float, all_first_attempt: bool) -> str:
    """Return the feedback sentence for an accuracy percentage (0-100)."""
    if accuracy == 100:
        if all_first_attempt:
            return (f"Wonderful job! You achieved a perfect score of {accuracy:.0f}%, "
                    "correctly identifying all key findings on your first attempt.")
        # A second-attempt answer still counts as correct, so don't claim
        # everything was right on the first try.
        return (f"Wonderful job! You achieved a perfect score of {accuracy:.0f}%, "
                "correctly identifying all key findings.")
    if accuracy >= 50:
        return (f"Good job. You scored {accuracy:.0f}%, "
                "showing a solid understanding of the key findings for this case.")
    return (f"This was a challenging case, and you scored {accuracy:.0f}%. "
            "More preparation is needed. Review the rationale below for details.")


def populate_rationale(summary_template: CaseSummary, conversation_history: list[ConversationTurn]) -> CaseSummary:
    """Score the user's answers and fill in the rationale and interpretation.

    A question counts as correct if either of the user's two attempts matches
    the MCQ's answer key. For an incorrect question, the outcome shows the text
    of the user's last attempt (attempt 2 if present, otherwise attempt 1).

    Args:
        summary_template: Summary carrying the static fields (potential
            findings, guideline resource, condition), e.g. from
            build_summary_template().
        conversation_history: One turn per answered MCQ.

    Returns:
        A new CaseSummary with the template's static fields, one AnswerLog per
        turn, and an accuracy-based interpretation.
    """
    correct_count = 0
    first_attempt_correct_count = 0
    rationale_logs = []

    for turn in conversation_history:
        mcq = turn.clinicalMcq
        choices = mcq.choices
        model_answer_key = mcq.answer
        attempt1_key = turn.userResponse.attempt1
        attempt2_key = turn.userResponse.attempt2

        if attempt1_key == model_answer_key or attempt2_key == model_answer_key:
            correct_count += 1
            if attempt1_key == model_answer_key:
                first_attempt_correct_count += 1
            correct_answer_text = choices.get(
                model_answer_key, f"N/A - Model Answer Key '{model_answer_key}' not found.")
            outcome = QuestionOutcome(type="Correct", text=correct_answer_text)
        else:
            # Report the user's final answer: attempt 2 when one was made.
            user_attempt_key = attempt2_key if attempt2_key else attempt1_key
            # .get() so a missing or unknown attempt key (e.g. an unanswered
            # question) does not abort the whole summary with a KeyError.
            incorrect_text = choices.get(
                user_attempt_key, f"N/A - User Answer Key '{user_attempt_key}' not found.")
            outcome = QuestionOutcome(type="Incorrect", text=incorrect_text)

        rationale_logs.append(AnswerLog(question=mcq.question, outcomes=[outcome]))

    total_questions = len(conversation_history)
    accuracy = (correct_count / total_questions) * 100 if total_questions > 0 else 0
    interpretation = _interpretation_for(accuracy, first_attempt_correct_count == total_questions)

    return CaseSummary(
        med_gemma_interpretation=interpretation,
        potential_findings=summary_template.potential_findings,
        rationale=rationale_logs,
        guideline_specific_resource=summary_template.guideline_specific_resource,
        condition=summary_template.condition,
    )


def _shuffle_choices(q: ClinicalMCQ) -> ClinicalMCQ:
    """Return a copy of ``q`` with its choice texts shuffled and answer re-keyed.

    Raises:
        KeyError: if ``q.answer`` is not one of ``q.choices``' keys.
    """
    # Remember the correct text before shuffling so its new key can be found.
    correct_answer_text = q.choices[q.answer]
    choice_texts = list(q.choices.values())
    random.shuffle(choice_texts)
    # Reassign the shuffled texts to the original keys in sorted key order.
    keys = sorted(q.choices.keys())
    new_choices = dict(zip(keys, choice_texts))
    new_answer_key = keys[choice_texts.index(correct_answer_text)]
    # dataclasses.replace builds a new instance; the input question is untouched.
    return replace(q, choices=new_choices, answer=new_answer_key)


def randomize_mcqs(original_mcqs: list[ClinicalMCQ]) -> list[ClinicalMCQ]:
    """Shuffle the answer choices of each MCQ when RANDOMIZE_CHOICES is enabled.

    Args:
        original_mcqs: Questions to randomize. They are not modified.

    Returns:
        ``original_mcqs`` itself if RANDOMIZE_CHOICES is falsy. Otherwise a new
        list in which each question is a shuffled copy. A question that cannot
        be randomized is kept in its original form, and a warning is logged.
    """
    if not RANDOMIZE_CHOICES:
        return original_mcqs

    randomized_questions = []
    for q in original_mcqs:
        try:
            randomized_questions.append(_shuffle_choices(q))
        except Exception as e:
            # Deliberate best-effort: one malformed question must not break the set.
            logger.warning(f"Warning: Could not randomize question '{q.id}'. Returning original. Error: {e}")
            randomized_questions.append(q)
    return randomized_questions