# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil  # For zipping the cache directory
from dataclasses import asdict
from functools import wraps
from pathlib import Path

from flask import Blueprint, request, jsonify, current_app, send_from_directory

import case_util
import config
from background_task_manager import BackgroundTaskManager
from models import ConversationTurn

# Use pathlib to construct the path to the images directory.
# This is more robust than relative string paths.
IMAGE_DIR = Path(__file__).parent / 'data/images'

main_bp = Blueprint('main', __name__)
logger = logging.getLogger(__name__)
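

# IMAGE_DIR is defined but never used in this module; it presumably backs an
# image-serving endpoint registered elsewhere. A minimal sketch of what such a
# route could look like (the URL path and function name are assumptions, not
# part of this file):
#
#   @main_bp.route('/images/<path:filename>')
#   def get_image(filename):
#       return send_from_directory(IMAGE_DIR, filename)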


def log_full_cycle(response):
    """
    This function runs after a request and has access to both
    the incoming 'request' and the outgoing 'response'.
    """
    if response.status_code != 200:
        logger.error(
            f"Request: {request.method} {request.path} | "
            f"Response Status: {response.status}"
        )
    # You MUST return the response object.
    return response
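

# How this hook might be wired up (a sketch; the actual registration is not
# shown in this file). Flask calls after-request hooks with the response and
# requires them to return it:
#
#   main_bp.after_request(log_full_cycle)   # fires for this blueprint's requests
#   # or, app-wide, in the app factory:
#   # app.after_request(log_full_cycle)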


def get_case(case_id):
    """Returns the static case data for a single case ID."""
    available_reports = current_app.config["AVAILABLE_REPORTS"]
    if case_id not in available_reports:
        logger.error(f"Case ID {case_id} does not exist.")
        return jsonify({"error": f"Case ID {case_id} does not exist."}), 404
    return jsonify(asdict(available_reports[case_id]))


def get_cases():
    """Returns the static case data for all available cases."""
    available_reports = current_app.config["AVAILABLE_REPORTS"]
    return jsonify([asdict(case) for case in available_reports.values()])
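

# None of the view functions in this module carry @main_bp.route decorators,
# so they are presumably bound to URLs elsewhere. One way that wiring could
# look (the URL paths here are assumptions):
#
#   main_bp.add_url_rule('/api/cases', view_func=get_cases)
#   main_bp.add_url_rule('/api/cases/<case_id>', view_func=get_case)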


def rag_initialization_complete_required(f):
    """Decorator that returns 503 until the RAG background task has finished."""
    @wraps(f)
    def decorated_function(*args, **kwargs):
        task_manager: BackgroundTaskManager = current_app.config.get('TASK_MANAGER')
        # Fail fast if the RAG task has failed.
        if task_manager.get_error("rag_system"):
            return jsonify({"error": "A critical background task failed. Check application logs."}), 500
        # Ask the client to retry if the RAG task is still running.
        if not task_manager.is_task_done("rag_system"):
            logger.warning("RAG initialization is still running...")
            response = jsonify(
                {"status": "initializing", "message": "The system is starting up. Please try again in 60 seconds."})
            response.headers['Retry-After'] = 60
            return response, 503
        return f(*args, **kwargs)
    return decorated_function
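

# Example of applying the guard to a view function (a sketch; the route path
# is an assumption). Until the "rag_system" task finishes, the wrapped view
# answers 503 with a Retry-After header instead of running:
#
#   @main_bp.route('/api/cases/<case_id>/questions')
#   @rag_initialization_complete_required
#   def questions(case_id):
#       return get_all_questions(case_id)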


def get_all_questions(case_id):
    """Retrieves all questions for a given case ID, preferring cached data and
    generating questions live via the LLM on a cache miss."""
    logger.info(f"Retrieving all questions for case '{case_id}'")
    cache_manager = current_app.config['DEMO_CACHE']
    # 1. Check the cache first.
    if config.USE_CACHE and cache_manager:
        all_mcqs_sequence = cache_manager.get_all_mcqs_sequence(case_id)
        if all_mcqs_sequence:
            logger.info(f"CACHE HIT for case '{case_id}'")
            randomized_choices_mcqs = case_util.randomize_mcqs(all_mcqs_sequence)
            return jsonify([asdict(mcq) for mcq in randomized_choices_mcqs])
    # 2. CACHE MISS: generate live.
    logger.info(
        f"CACHE MISS or cache disabled for case '{case_id}'. Generating live questions...")
    llm_client = current_app.config['LLM_CLIENT']
    if not llm_client:
        logger.error(
            "LLM client (REST API) not initialized. Cannot process request.")
        return jsonify({"error": "LLM client not initialized."}), 500
    static_case_info = current_app.config['AVAILABLE_REPORTS'].get(case_id)
    if not static_case_info:
        logger.error(f"Static case data for ID {case_id} not found.")
        return jsonify({"error": f"Static case data for id {case_id} not found."}), 404
    rag_cache = current_app.config.get('RAG_CONTEXT_CACHE', {})
    prefetched_data = rag_cache.get(case_id, {})
    guideline_context_string = prefetched_data.get("context_string", "")
    live_generated_mcqs = llm_client.generate_all_questions(
        case_data=asdict(static_case_info),
        guideline_context=guideline_context_string
    )
    if live_generated_mcqs:
        # 3. WRITE-THROUGH: update the cache with the new questions if caching is enabled.
        if config.USE_CACHE and cache_manager:
            cache_manager.add_all_mcqs_to_case(case_id, live_generated_mcqs)
        randomized_choices_mcqs = case_util.randomize_mcqs(live_generated_mcqs)
        return jsonify([asdict(mcq) for mcq in randomized_choices_mcqs]), 200
    logger.error("MCQ sequence generation failed.")
    return jsonify({"error": "MCQ sequence generation failed."}), 500


def get_case_summary(case_id):
    """
    API endpoint to generate a case summary.

    First attempts to load the summary template from the cache, then falls
    back to building it on the fly.
    """
    data = request.get_json(force=True)
    conversation_history_data = data.get('conversation_history')
    if not conversation_history_data:
        logger.error(f"Missing 'conversation_history' in request body for case {case_id}.")
        return jsonify({"error": f"Missing 'conversation_history' in request body for case {case_id}."}), 400
    try:
        summary_template = None
        # First, try to get the summary template from the cache, if caching is enabled.
        cache_manager = current_app.config.get('DEMO_CACHE')
        if cache_manager:
            summary_template = cache_manager.get_summary_template(case_id)
            if summary_template:
                logger.info(f"Summary template for case {case_id} found in cache.")
        # If caching is disabled OR the template was not in the cache, build it now.
        if summary_template is None:
            logger.warning(f"Summary template for case {case_id} not in cache or cache disabled. Building on the fly.")
            static_case_info = current_app.config['AVAILABLE_REPORTS'].get(case_id)
            if not static_case_info:
                logger.error(f"Static case data for case {case_id} not found.")
                return jsonify({"error": f"Static case data for case {case_id} not found."}), 404
            summary_template = case_util.build_summary_template(
                static_case_info, current_app.config.get('RAG_CONTEXT_CACHE', {}))
            if cache_manager:
                cache_manager.save_summary_template(case_id, summary_template)
        if summary_template is None:
            logger.error(f"Summary template not found for case {case_id}.")
            return jsonify({"error": "An internal error occurred."}), 500
        # Once the summary template is ready, programmatically populate the
        # rationale based on the user's journey.
        conversation_turns = [ConversationTurn.from_dict(turn) for turn in conversation_history_data]
        summary = case_util.populate_rationale(summary_template, conversation_turns)
        return jsonify(asdict(summary)), 200
    except Exception as e:
        logger.error(f"Error generating summary for case {case_id}: {e}", exc_info=True)
        return jsonify({"error": f"An internal error occurred: {e}"}), 500


def download_cache_zip():
    """Zips the cache directory and serves it for download."""
    zip_filename = "rad-learn-cache.zip"
    # Create the zip file in a temporary directory.
    # Using /tmp is common in containerized environments.
    temp_dir = "/tmp"
    zip_base_path = os.path.join(temp_dir, "rad-learn-cache")  # shutil adds .zip
    zip_filepath = zip_base_path + ".zip"
    # Ensure the cache directory exists before trying to zip it.
    cache_manager = current_app.config.get('DEMO_CACHE')
    cache_directory = cache_manager.cache_directory
    if not os.path.isdir(cache_directory):
        logger.error(f"Cache directory not found at {cache_directory}")
        return jsonify({"error": f"Cache directory not found on server: {cache_directory}"}), 500
    try:
        logger.info(f"Creating zip archive of cache directory: {cache_directory} to {zip_filepath}")
        shutil.make_archive(
            zip_base_path,  # Base name; shutil adds the .zip extension.
            "zip",
            cache_directory,  # Root directory to archive.
        )
        logger.info("Zip archive created successfully.")
        # Serve the archive as a file download. Note: the temp file is not
        # removed here (see the cleanup sketch below).
        return send_from_directory(temp_dir, zip_filename, as_attachment=True)
    except Exception as e:
        logger.error(f"Error creating or sending zip archive of cache directory: {e}", exc_info=True)
        return jsonify({"error": f"Error creating or sending zip archive: {e}"}), 500