File size: 10,273 Bytes
81cdd5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import csv
import json
import logging
import random
import re
from dataclasses import replace
from pathlib import Path

from config import BASE_DIR, RANDOMIZE_CHOICES
from models import Case, CaseSummary, AnswerLog, ConversationTurn, QuestionOutcome, ClinicalMCQ

# --- Configuration ---
# Configure basic logging (optional, adjust as needed)
# NOTE(review): basicConfig at import time affects the root logger for the
# whole process — confirm no other module configures logging differently.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)


def fetch_report(report_path: Path):
    """Load and parse a JSON report file.

    Args:
        report_path: Path to the JSON report file on disk.

    Returns:
        The parsed JSON content on success, or the empty string "" when the
        file is missing or contains invalid JSON (kept as "" for backward
        compatibility with callers that treat a falsy value as "no report").
    """
    try:
        with open(report_path, 'r', encoding='utf-8') as f:
            report = json.load(f)
    except FileNotFoundError:
        logger.error(f"ERROR: Could not find report file: {report_path}")
        return ""
    except json.JSONDecodeError as e:
        # A malformed report should degrade like a missing one instead of
        # crashing the caller's load loop.
        logger.error(f"ERROR: Could not parse report file {report_path}: {e}")
        return ""
    logger.info(f"Successfully loaded '{report_path}' into memory.")
    return report


def get_available_reports(reports_csv_path: Path):
    """Reads available reports as Cases for this demo.

    Args:
        reports_csv_path: Path to the manifest CSV. Each row must provide the
            headers listed in ``required_headers`` below.

    Returns:
        dict mapping case id (str) -> Case. Empty when the manifest is
        missing, lacks required headers, or cannot be read.
    """
    available_reports: dict[str, Case] = {}
    if not reports_csv_path.is_file():
        logger.warning(f"Manifest CSV file not found at {reports_csv_path}. AVAILABLE_REPORTS will be empty.")
        return available_reports

    try:
        with open(reports_csv_path, mode='r', encoding='utf-8') as csvfile:
            reader = csv.DictReader(csvfile)
            required_headers = {'case_id', 'case_condition_name', 'report_path', 'download_image_url', 'findings'}
            if not required_headers.issubset(reader.fieldnames):
                logger.error(
                    f"CSV file {reports_csv_path} is missing one or more required headers: {required_headers - set(reader.fieldnames)}"
                )
                return available_reports

            for row in reader:
                case_id = row['case_id']
                condition_name = row['case_condition_name']
                report_path_from_csv = row['report_path']  # e.g., static/reports/report1.txt or empty
                download_image_url_from_csv = row['download_image_url']
                potential_findings = row['findings']

                # Construct absolute path for report file validation (paths from CSV are relative to BASE_DIR)
                abs_report_path_to_check = BASE_DIR / report_path_from_csv
                if not abs_report_path_to_check.is_file():
                    logger.warning(
                        f"Report file not found for case '{case_id}' at '{abs_report_path_to_check}'. Skipping this entry.")
                    continue

                if download_image_url_from_csv is None or download_image_url_from_csv == "":
                    logger.warning(
                        f"Download image url not found for case '{case_id}'. Skipping this entry.")
                    continue

                # Bug fix: load the report via the same absolute path that was
                # validated above; the raw CSV path is relative to BASE_DIR and
                # would fail to open when CWD differs from BASE_DIR.
                ground_truth_labels = fetch_report(abs_report_path_to_check)
                case = Case(
                    id=case_id,
                    condition_name=condition_name,
                    ground_truth_labels=ground_truth_labels,
                    download_image_url=download_image_url_from_csv,
                    potential_findings=potential_findings,
                )
                available_reports[str(case_id)] = case
            logger.info(f"Loaded {len(available_reports)} report/image pairs from CSV.")
    except Exception as e:
        # Boundary-level catch: a broken manifest should not take the app down.
        logger.error(f"Error reading or processing CSV file {reports_csv_path}: {e}", exc_info=True)
    return available_reports


def get_json_from_model_response(response_text: str) -> dict:
    """
    Robustly parses a JSON object from a response that may contain it
    within a markdown code block.

    Args:
        response_text: Raw model output, possibly wrapping the JSON in a
            ```json ... ``` fenced block.

    Returns:
        The parsed JSON object as a dict.

    Raises:
        Exception: If no JSON object can be found or parsed.
    """
    # This regex looks for a JSON object starting with { and ending with }
    # inside a ```json fenced block.
    json_match = re.search(r"```json\s*(\{.*?\})\s*```", response_text, re.DOTALL)
    if json_match:
        json_str = json_match.group(1)
        try:
            return json.loads(json_str)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to decode JSON after extraction: {e}")
            raise Exception(f"Could not parse JSON from extracted block: {json_str}")

    # Fallback if the model misses the markdown block
    logger.warning("Could not find a ```json block. Falling back to raw search.")
    json_match_fallback = re.search(r"(\{.*\})", response_text, re.DOTALL)
    if json_match_fallback:
        try:
            return json.loads(json_match_fallback.group(1))
        except json.JSONDecodeError as e:
            # Bug fix: previously this JSONDecodeError escaped uncaught,
            # bypassing the informative failure below. Log and fall through.
            logger.error(f"Failed to decode JSON from raw search: {e}")

    raise Exception(f"Could not find or parse JSON object in the API response: {response_text}")


def get_potential_findings(case: Case) -> str:
    """Return the potential findings string recorded on the given case."""
    return case.potential_findings


def build_summary_template(case: Case, rag_cache: dict) -> CaseSummary:
    """Builds summary template with static data like potential_findings, guideline_resources and condition."""
    # Pull any cached RAG citations for this case; default to no citations.
    rag_entry = rag_cache.get(case.id, {})
    cited = rag_entry.get("citations", [])
    citation_string = ', '.join(map(str, cited)) if cited else ""

    return CaseSummary(
        med_gemma_interpretation="",
        potential_findings=get_potential_findings(case),
        rationale=[],
        guideline_specific_resource=citation_string,
        condition=case.condition_name
    )


def populate_rationale(summary_template: CaseSummary, conversation_history: list[ConversationTurn]) -> CaseSummary:
    """Populates rationale and interpretation depending on user journey.

    Args:
        summary_template: Pre-built template carrying the static fields
            (potential findings, guideline resource, condition).
        conversation_history: One turn per MCQ, each holding the question,
            its choices, the model's answer key, and the user's attempt keys.

    Returns:
        A new CaseSummary with per-question AnswerLogs and an interpretation
        string based on first/second-attempt accuracy.
    """
    correct_count = 0
    total_questions = len(conversation_history)
    rationale_logs = []

    for turn in conversation_history:
        question = turn.clinicalMcq.question
        choices = turn.clinicalMcq.choices
        model_answer_key = turn.clinicalMcq.answer
        user_attempt1_key = turn.userResponse.attempt1
        user_attempt2_key = turn.userResponse.attempt2
        correct_answer_text = choices.get(model_answer_key, f"N/A - Model Answer Key '{model_answer_key}' not found.")

        outcomes = []
        if user_attempt1_key != model_answer_key and user_attempt2_key != model_answer_key:
            # Both attempts wrong: log the user's final (or only) wrong choice.
            user_attempt_key = user_attempt2_key if user_attempt2_key else user_attempt1_key
            # Bug fix: use .get so an unknown user key cannot raise KeyError
            # (the model-answer lookup above is already guarded the same way).
            incorrect_text = choices.get(user_attempt_key, f"N/A - User Answer Key '{user_attempt_key}' not found.")
            outcomes.append(QuestionOutcome(type="Incorrect", text=incorrect_text))
        else:
            correct_count += 1
        # The correct answer is always shown, whether or not the user found it.
        outcomes.append(QuestionOutcome(type="Correct", text=correct_answer_text))

        rationale_logs.append(AnswerLog(question=question, outcomes=outcomes))

    # Guard against division by zero for an empty history.
    accuracy = (correct_count / total_questions) * 100 if total_questions > 0 else 0

    if accuracy == 100:
        interpretation = f"Wonderful job! You achieved a perfect score of {accuracy:.0f}%, correctly identifying all key findings on your first attempt."
    elif accuracy >= 50:
        interpretation = f"Good job. You scored {accuracy:.0f}%, showing a solid understanding of the key findings for this case."
    else:
        interpretation = f"This was a challenging case, and you scored {accuracy:.0f}%. More preparation is needed. Review the rationale below for details."

    return CaseSummary(
        med_gemma_interpretation=interpretation,
        potential_findings=summary_template.potential_findings,
        rationale=rationale_logs,
        guideline_specific_resource=summary_template.guideline_specific_resource,
        condition=summary_template.condition,
    )


def randomize_mcqs(original_mcqs: list[ClinicalMCQ]) -> list[ClinicalMCQ]:
    """Return a list of MCQs with their answer choices shuffled.

    Shuffling is skipped entirely when RANDOMIZE_CHOICES is disabled. If a
    question cannot be shuffled (e.g. its answer key is missing from its
    choices), the original question is kept in place and a warning is logged.
    """
    if not RANDOMIZE_CHOICES:
        return original_mcqs

    shuffled_mcqs = []
    for mcq in original_mcqs:
        try:
            # Remember the text of the correct answer before shuffling.
            answer_text = mcq.choices[mcq.answer]

            # Shuffle only the choice texts; keys keep their sorted order.
            texts = list(mcq.choices.values())
            random.shuffle(texts)
            remapped = dict(zip(sorted(mcq.choices.keys()), texts))

            # Locate the key that now carries the correct answer's text.
            new_answer = next(k for k, v in remapped.items() if v == answer_text)

            # replace() yields a fresh instance, leaving the input MCQ untouched.
            shuffled_mcqs.append(replace(mcq, choices=remapped, answer=new_answer))
        except Exception as e:
            # Fail soft per question: keep the original and continue.
            logger.warning(f"Warning: Could not randomize question '{mcq.id}'. Returning original. Error: {e}")
            shuffled_mcqs.append(mcq)

    return shuffled_mcqs