File size: 4,685 Bytes
7ed59a1
 
88acf81
51d80c4
7ed59a1
 
51d80c4
 
 
 
 
88acf81
51d80c4
 
88acf81
51d80c4
 
 
 
7ed59a1
88acf81
 
 
 
 
 
 
 
7ed59a1
51d80c4
7ed59a1
 
 
88acf81
 
 
 
 
 
 
51d80c4
88acf81
 
 
 
 
 
7ed59a1
 
 
88acf81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ed59a1
88acf81
 
7ed59a1
88acf81
7ed59a1
88acf81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import json
import re
from typing import Optional, Dict, Any, Union
from pydantic import BaseModel, Field, validator
from huggingface_hub import InferenceClient
from huggingface_hub.errors import HfHubHTTPError
from variables import *

def _unescape_llm_text(text: str) -> str:
    """Strip surrounding whitespace and decode two literal escape sequences.

    The two-character sequence backslash + ``n`` becomes a real newline and
    backslash + double-quote becomes a plain double quote.
    """
    return text.strip().replace('\\n', '\n').replace('\\"', '"')


class LLMResponse(BaseModel):
    """Structured result of a prompt-refinement request.

    The validators strip the text fields and decode literal backslash-n /
    backslash-quote sequences left over from escaped model output.
    """

    initial_prompt_evaluation: str = Field(..., description="Evaluation of the initial prompt")
    refined_prompt: str = Field(..., description="The refined version of the prompt")
    explanation_of_refinements: Union[str, list] = Field(..., description="Explanation of the refinements made")
    # The JSON parsing paths in PromptRefiner store the decoded dict here,
    # while its regex fallback stores the raw reply string, so accept both
    # shapes instead of failing validation on the fallback path.
    response_content: Optional[Union[Dict[str, Any], str]] = Field(None, description="Raw response content")

    @validator('initial_prompt_evaluation', 'refined_prompt')
    def clean_text_fields(cls, v):
        """Unescape plain-text fields; non-string values pass through unchanged."""
        if isinstance(v, str):
            return _unescape_llm_text(v)
        return v

    @validator('explanation_of_refinements')
    def clean_refinements(cls, v):
        """Unescape a string explanation, or each string item of a list explanation.

        Non-string list items and any other value types are returned unchanged.
        """
        if isinstance(v, str):
            return _unescape_llm_text(v)
        if isinstance(v, list):
            return [_unescape_llm_text(item) if isinstance(item, str) else item for item in v]
        return v

class PromptRefiner:
    def __init__(self, api_token: str, meta_prompts, *, timeout: float = 120):
        """Create a refiner backed by a Hugging Face ``InferenceClient``.

        Args:
            api_token: Token passed to ``InferenceClient`` for authentication.
            meta_prompts: Stored as-is on ``self.meta_prompts``; its structure
                is defined by the caller (not visible in this section).
            timeout: Seconds forwarded to ``InferenceClient(timeout=...)``.
                Defaults to 120, the value that was previously hard-coded.
        """
        self.client = InferenceClient(token=api_token, timeout=timeout)
        self.meta_prompts = meta_prompts

    def _sanitize_json_string(self, json_str: str) -> str:
        """Clean and prepare JSON string for parsing."""
        json_str = json_str.lstrip('\ufeff').strip()
        json_str = json_str.replace('\n', ' ')
        json_str = re.sub(r'\s+', ' ', json_str)
        json_str = json_str.replace('•', '*')
        return json_str

    def _extract_json_content(self, content: str) -> str:
        """Extract JSON content from between <json> tags."""
        json_match = re.search(r'<json>\s*(.*?)\s*</json>', content, re.DOTALL)
        if json_match:
            return self._sanitize_json_string(json_match.group(1))
        return content

    def _parse_response(self, response_content: str) -> dict:
        try:
            # First attempt: Try to parse the entire content as JSON
            cleaned_content = self._sanitize_json_string(response_content)
            try:
                parsed_json = json.loads(cleaned_content)
                if isinstance(parsed_json, str):
                    parsed_json = json.loads(parsed_json)
                return self._normalize_json_output(parsed_json)
            except json.JSONDecodeError:
                # Second attempt: Try to extract JSON from <json> tags
                json_content = self._extract_json_content(response_content)
                try:
                    parsed_json = json.loads(json_content)
                    if isinstance(parsed_json, str):
                        parsed_json = json.loads(parsed_json)
                    return self._normalize_json_output(parsed_json)
                except json.JSONDecodeError:
                    # Third attempt: Try to parse using regex
                    return self._parse_with_regex(response_content)

        except Exception as e:
            print(f"Error parsing response: {str(e)}")
            print(f"Raw content: {response_content}")
            return self._create_error_dict(str(e))

    def _normalize_json_output(self, json_output: dict) -> dict:
        """Normalize JSON output to expected format."""
        return {
            "initial_prompt_evaluation": json_output.get("initial_prompt_evaluation", ""),
            "refined_prompt": json_output.get("refined_prompt", ""),
            "explanation_of_refinements": json_output.get("explanation_of_refinements", ""),
            "response_content": json_output
        }

    def _parse_with_regex(self, content: str) -> dict:
        """Parse content using regex patterns."""
        output = {}
        for key in ["initial_prompt_evaluation", "refined_prompt", "explanation_of_refinements"]:
            pattern = rf'"{key}":\s*"(.*?)"(?:,|\}})'
            match = re.search(pattern, content, re.DOTALL)
            output[key] = match.group(1) if match else ""
        
        output["response_content"] = content
        return output

    def _create_error_dict(self, error_message: str) -> dict:
        """Create standardized error response dictionary."""
        return {
            "initial_prompt_evaluation": f"Error parsing response: {error_message}",
            "refined_prompt": "",
            "explanation_of_refinements": "",
            "response_content": {"error": error_message}
        }

    # Rest of your code remains the same...