# ReqSolver-API / app.py
from fastapi import FastAPI
from pydantic import BaseModel, Field
from typing import List, Optional
import os
import json
# ---- Requirements Models ----
class RequirementInfo(BaseModel):
    """A single extracted requirement, together with its source context."""
    context: str = Field(..., description="Context for the requirement.")
    requirement: str = Field(..., description="The requirement itself.")
    document: Optional[str] = Field('', description="The document the requirement was extracted from.")

class ReqGroupingCategory(BaseModel):
    """Represents the category of requirements grouped together."""
    id: int = Field(..., description="ID of the grouping category")
    title: str = Field(..., description="Title given to the grouping category")
    requirements: List[RequirementInfo] = Field(
        ..., description="List of grouped requirements")

class ReqGroupingResponse(BaseModel):
    categories: List[ReqGroupingCategory]

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "categories": [
                        {
                            "id": 1,
                            "title": "Robustness & Resilience",
                            "requirements": [
                                {
                                    "context": "Subject to the operator’s policy and regulatory requirements, an AI service provided by the 6G network or UE shall be able to provide information regarding robustness scores.",
                                    "requirement": "Expose an overall robustness score to service consumers."
                                },
                                {
                                    "context": "The network can respond with a missing-resilience score for the used AI application.",
                                    "requirement": "Report a missing-resilience score that quantifies tolerance to missing or corrupted input data."
                                }
                            ]
                        },
                        {
                            "id": 2,
                            "title": "Environmental Sustainability",
                            "requirements": [
                                {
                                    "context": "What is the level of energy consumption per information request (per inference run of the AI)?",
                                    "requirement": "Report energy consumption per 1 000 inference requests."
                                },
                                {
                                    "context": "What is the portion of renewable energy of the energy consumed by the AI service?",
                                    "requirement": "Report the share of renewable energy in the AI service’s power mix."
                                },
                                {
                                    "context": "The application sets a requirement for the energy consumption needed for inference.",
                                    "requirement": "Allow the consumer to specify a maximum energy-per-inference threshold that must be met."
                                }
                            ]
                        },
                        {
                            "id": 3,
                            "title": "Explainability & Transparency",
                            "requirements": [
                                {
                                    "context": "Local explanation: the aim is to explain individual outputs provided by an ML model.",
                                    "requirement": "Support local explanations for single predictions."
                                },
                                {
                                    "context": "Global explanation: the aim is to explain the whole ML model behaviour.",
                                    "requirement": "Support global explanations that describe overall model logic."
                                },
                                {
                                    "context": "Third-party applications have explanations of AI agent reasoning.",
                                    "requirement": "Provide on-demand reasoning for predictions to authorised consumers."
                                }
                            ]
                        },
                        {
                            "id": 4,
                            "title": "Service Discovery & Criteria Negotiation",
                            "requirements": [
                                {
                                    "context": "A subscriber density prediction service is offered via an exposure interface.",
                                    "requirement": "Ensure AI services are discoverable through the exposure interface."
                                },
                                {
                                    "context": "The application requests further profile information regarding robustness, sustainability and explainability aspects.",
                                    "requirement": "Expose a profile that includes robustness, sustainability and explainability metrics."
                                },
                                {
                                    "context": "A service consumer shall be able to provide service criteria regarding robustness, environmental sustainability, and explainability when requesting an AI service from the 6G system.",
                                    "requirement": "Accept consumer-supplied criteria for robustness, sustainability and explainability."
                                },
                                {
                                    "context": "In some cases the AI service could not be fulfilled, or could fall back to a non-AI mechanism if the criteria cannot be met.",
                                    "requirement": "Support rejection or graceful fallback when agreed criteria are not satisfied."
                                }
                            ]
                        }
                    ]
                }
            ]
        }
    }

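# A minimal sketch of how a client payload validates against this model
# (hypothetical data, trimmed to a single category):
#
#   payload = {
#       "categories": [
#           {
#               "id": 1,
#               "title": "Robustness & Resilience",
#               "requirements": [
#                   {"context": "…", "requirement": "Expose a robustness score."}
#               ],
#           }
#       ]
#   }
#   grouping = ReqGroupingResponse.model_validate(payload)
#   grouping.categories[0].requirements[0].document  # -> '' (default)
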
# ---- Solution Models ----
class SolutionModel(BaseModel):
    Context: str = Field(..., description="Full context provided for this category.")
    Requirements: List[str] = Field(..., description="List of each requirement as a string.")
    Problem_Description: str = Field(..., alias="Problem Description",
                                     description="Description of the problem being solved.")
    Solution_Description: str = Field(..., alias="Solution Description",
                                      description="Detailed description of the solution.")
    References: Optional[str] = Field('', description="References to documents used for the solution.")

    # Pydantic v2 replacement for the v1-era `allow_population_by_field_name`:
    # allows fields to be populated by field name as well as by alias.
    model_config = {"populate_by_name": True}


class SolutionsResponse(BaseModel):
    solutions: List[SolutionModel]

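# Illustrative sketch (hypothetical data): the LLM replies with alias keys that
# contain spaces, and model_validate maps them onto the underscored fields;
# populate_by_name additionally allows construction from the field names, which
# the error path in find_solutions below relies on.
#
#   parsed = {
#       "Context": "6G exposure interface",
#       "Requirements": ["Expose a robustness score."],
#       "Problem Description": "Consumers cannot assess model robustness.",
#       "Solution Description": "Publish a scored robustness profile.",
#       "References": "",
#   }
#   solution = SolutionModel.model_validate(parsed)
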
# ---- FastAPI app ----
app = FastAPI()
# ---- LLM Integration ----
def ask_llm(user_message, model='compound-beta', system_prompt="You are a helpful assistant"):
    """Send a single-turn chat request to Groq and return the reply text."""
    from groq import Groq  # Imported lazily so the app can start without the module if needed
    client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        stream=False,
    )
    return response.choices[0].message.content

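# Example call (a sketch; assumes GROQ_API_KEY is exported and that the account
# has access to the default 'compound-beta' model):
#
#   reply = ask_llm("List three risks of exposing AI service metrics.")
#   print(reply)
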
solution_prompt = """You are an expert system designer. Your task is to find a solution that addresses as many of the provided requirements as possible, while carefully considering the given context.
**Use the live web search tool to browse the internet for reliable sources.**
Respond strictly in the following JSON format:
{
  "Context": "<Insert the full context provided for this category>",
  "Requirements": [
    "<List each requirement clearly as a string item>"
  ],
  "Problem Description": "<Describe the problem the solution is solving, without introducing it>",
  "Solution Description": "<Explain the proposed solution, detailing how it meets each of the specified requirements and aligns with the given context. Prioritize completeness and practicality.>",
  "References": "<The references to the documents used to write the solution>"
}
⚠️ Rules:
- Do not omit any part of the JSON structure.
- Escape newline characters inside string values as \\n so the JSON remains valid.
- Ensure all fields are present, even if empty.
- The solution must aim to maximize requirement satisfaction while respecting the context.
- Provide a clear and well-reasoned description of how your solution addresses each requirement.
"""
# ---- Endpoints ----
@app.get("/")
def greet_json():
    return {"Hey!": "SoA Finder is running!!"}

@app.post("/find_solutions", response_model=SolutionsResponse)
async def find_solutions(requirements: ReqGroupingResponse):
    solutions = []
    for category in requirements.categories:
        category_title = category.title
        category_requirements = category.requirements

        # Compose the LLM prompt: the template first, then the category details
        problem_description = solution_prompt
        problem_description += f"Category title: {category_title}\n"
        context_list = []
        requirement_list = []
        for req_item in category_requirements:
            context_list.append(f"- Context: {req_item.context}")
            requirement_list.append(f"- Requirement: {req_item.requirement}")
        problem_description += "Contexts:\n" + "\n".join(context_list) + "\n\n"
        problem_description += "Requirements:\n" + "\n".join(requirement_list)

        llm_response = ask_llm(problem_description)
        print(f"Solution for '{category_title}' category:")
        print(llm_response)

        # Clean and parse the LLM response
        try:
            # Strip Markdown code fences if present
            cleaned = llm_response.strip()
            if cleaned.startswith('```json'):
                cleaned = cleaned[7:]
            if cleaned.startswith('```'):
                cleaned = cleaned[3:]
            if cleaned.endswith('```'):
                cleaned = cleaned[:-3]
            cleaned = cleaned.strip()
            # Collapse over-escaped newlines (literal "\\n") into "\n" so json.loads accepts them
            cleaned = cleaned.replace('\\\\n', '\\n')
            parsed = json.loads(cleaned)
            # Alias-aware validation maps "Problem Description" etc. onto the model fields
            solution_obj = SolutionModel.model_validate(parsed)
            solutions.append(solution_obj)
        except Exception as e:
            # Surface the failure as a solution entry so the caller can debug it
            error_solution = SolutionModel(
                Context="",
                Requirements=[],
                Problem_Description=f"Failed to parse LLM response: {str(e)}",
                Solution_Description=f"Original LLM output: {llm_response}",
                References="",
            )
            solutions.append(error_solution)
    return SolutionsResponse(solutions=solutions)
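

# Local development entry point. A minimal sketch: the host/port values are
# assumptions (7860 is the conventional Hugging Face Spaces port), not part of
# the original deployment configuration.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)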