prompt-plus-plus / prompt_refiner.py
baconnier's picture
Update prompt_refiner.py
88acf81 verified
raw
history blame
4.69 kB
import json
import re
from typing import Optional, Dict, Any, Union
from pydantic import BaseModel, Field, validator
from huggingface_hub import InferenceClient
from huggingface_hub.errors import HfHubHTTPError
from variables import *
def _unescape_text(value):
    """Trim a string and expand its literal backslash-n / backslash-quote sequences.

    Surrounding whitespace is stripped, every two-character backslash-n
    sequence becomes a real newline, and every backslash-quote sequence
    becomes a plain double quote. Non-string values are returned untouched.
    """
    if not isinstance(value, str):
        return value
    trimmed = value.strip()
    return trimmed.replace('\\n', '\n').replace('\\"', '"')


class LLMResponse(BaseModel):
    """Validated shape of a refinement result.

    The three text fields are required; ``response_content`` optionally
    carries the raw payload the fields were taken from.
    """

    initial_prompt_evaluation: str = Field(..., description="Evaluation of the initial prompt")
    refined_prompt: str = Field(..., description="The refined version of the prompt")
    # May be a single string or a list of items.
    explanation_of_refinements: Union[str, list] = Field(..., description="Explanation of the refinements made")
    response_content: Optional[Dict[str, Any]] = Field(None, description="Raw response content")

    @validator('initial_prompt_evaluation', 'refined_prompt')
    def clean_text_fields(cls, v):
        """Strip and unescape the two plain-text fields."""
        return _unescape_text(v)

    @validator('explanation_of_refinements')
    def clean_refinements(cls, v):
        """Strip and unescape the explanation, element-wise when it is a list."""
        if isinstance(v, list):
            # Non-string list items pass through unchanged.
            return [_unescape_text(entry) for entry in v]
        return _unescape_text(v)
class PromptRefiner:
    """Hold an inference client plus meta prompts and parse model replies.

    The parsing helpers turn a raw model reply into a dict with the keys
    ``initial_prompt_evaluation``, ``refined_prompt``,
    ``explanation_of_refinements`` and ``response_content``.
    """

    # Text fields extracted from every reply, in extraction order.
    _FIELD_NAMES = (
        "initial_prompt_evaluation",
        "refined_prompt",
        "explanation_of_refinements",
    )

    # Tail of the per-field regex (the part after the key name):
    #   * the value is a lazy run of "backslash + any char" pairs or
    #     non-backslash chars, so an escaped quote can never end the value
    #     while unescaped inner quotes of malformed JSON are still tolerated;
    #   * optional whitespace is allowed before the ',' or '}' that follows
    #     the closing quote, so pretty-printed JSON matches as well.
    _FIELD_VALUE_TAIL = r'"\s*:\s*"((?:\\.|[^\\])*?)"\s*(?:,|\})'

    def __init__(self, api_token: str, meta_prompts):
        """Create the inference client and remember the meta prompts.

        Args:
            api_token: Token handed to ``InferenceClient``.
            meta_prompts: Stored unchanged on the instance.
        """
        self.client = InferenceClient(token=api_token, timeout=120)
        self.meta_prompts = meta_prompts

    def _sanitize_json_string(self, json_str: str) -> str:
        """Clean and prepare a JSON string for parsing.

        Removes surrounding whitespace and a leading byte-order mark (in
        either order), collapses every whitespace run, newlines included,
        into one space, and replaces bullet characters with ``*``.
        """
        # str.strip() does not treat U+FEFF as whitespace, so strip on both
        # sides of the BOM removal to cover "<space><BOM>{" as well.
        json_str = json_str.strip().lstrip('\ufeff').strip()
        json_str = re.sub(r'\s+', ' ', json_str)
        return json_str.replace('•', '*')

    def _extract_json_content(self, content: str) -> str:
        """Extract JSON content from between <json> tags.

        Returns the sanitized text of the first ``<json>...</json>`` section,
        or ``content`` unchanged when no such section exists.
        """
        json_match = re.search(r'<json>\s*(.*?)\s*</json>', content, re.DOTALL)
        if json_match is None:
            return content
        return self._sanitize_json_string(json_match.group(1))

    def _load_json_object(self, text: str) -> Optional[dict]:
        """Decode ``text`` as JSON and return it only if it is an object.

        A result that is itself a string is decoded once more, which covers
        double-encoded payloads. Returns ``None`` when the text is not valid
        JSON or decodes to anything other than a dict.
        """
        try:
            decoded = json.loads(text)
            if isinstance(decoded, str):
                decoded = json.loads(decoded)
        except json.JSONDecodeError:
            return None
        return decoded if isinstance(decoded, dict) else None

    def _parse_response(self, response_content: str) -> dict:
        """Parse a raw model reply into the normalized result dict.

        Strategies are tried in order and the first that yields a JSON object
        wins:

        1. the whole reply, sanitized;
        2. the section between ``<json>`` tags;
        3. per-field regex extraction.

        Any unexpected exception is reported on stdout and converted into the
        dict produced by ``_create_error_dict``; this method does not raise.
        """
        try:
            # Each preparer is called only if the previous strategy failed.
            for prepare in (self._sanitize_json_string, self._extract_json_content):
                parsed = self._load_json_object(prepare(response_content))
                if parsed is not None:
                    return self._normalize_json_output(parsed)
            return self._parse_with_regex(response_content)
        except Exception as e:
            print(f"Error parsing response: {str(e)}")
            print(f"Raw content: {response_content}")
            return self._create_error_dict(str(e))

    def _normalize_json_output(self, json_output: dict) -> dict:
        """Normalize JSON output to the expected format.

        Missing text fields default to an empty string; the full decoded
        object is kept under ``response_content``.
        """
        normalized = {name: json_output.get(name, "") for name in self._FIELD_NAMES}
        normalized["response_content"] = json_output
        return normalized

    def _parse_with_regex(self, content: str) -> dict:
        """Extract the text fields from ``content`` using regex patterns.

        Used when the reply is not valid JSON. Each field is searched for
        independently; a field that cannot be found becomes ``""``. Extracted
        values are returned verbatim, i.e. JSON escape sequences are not
        decoded here.
        """
        output = {}
        for key in self._FIELD_NAMES:
            pattern = '"' + re.escape(key) + self._FIELD_VALUE_TAIL
            match = re.search(pattern, content, re.DOTALL)
            output[key] = match.group(1) if match else ""
        # NOTE(review): LLMResponse declares response_content as an optional
        # dict, whereas this path stores the raw string - confirm callers
        # tolerate that.
        output["response_content"] = content
        return output

    def _create_error_dict(self, error_message: str) -> dict:
        """Create the standardized error response dictionary."""
        return {
            "initial_prompt_evaluation": f"Error parsing response: {error_message}",
            "refined_prompt": "",
            "explanation_of_refinements": "",
            "response_content": {"error": error_message}
        }
# Rest of your code remains the same...