baconnier committed on
Commit
88acf81
·
verified ·
1 Parent(s): 51d80c4

Update prompt_refiner.py

Browse files
Files changed (1) hide show
  1. prompt_refiner.py +75 -142
prompt_refiner.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
  import re
3
- from typing import Optional, Dict, Any
4
  from pydantic import BaseModel, Field, validator
5
  from huggingface_hub import InferenceClient
6
  from huggingface_hub.errors import HfHubHTTPError
@@ -9,163 +9,96 @@ from variables import *
9
class LLMResponse(BaseModel):
    """Schema for the refiner LLM output: evaluation, refined prompt and explanation."""
    initial_prompt_evaluation: str = Field(..., description="Evaluation of the initial prompt")
    refined_prompt: str = Field(..., description="The refined version of the prompt")
    explanation_of_refinements: str = Field(..., description="Explanation of the refinements made")
    response_content: Optional[Dict[str, Any]] = Field(None, description="Raw response content")

    @validator('initial_prompt_evaluation', 'refined_prompt', 'explanation_of_refinements')
    def clean_text_fields(cls, v):
        # Un-escape literal "\n" and '\"' sequences the model may emit inside JSON strings.
        if isinstance(v, str):
            return v.strip().replace('\\n', '\n').replace('\\"', '"')
        return v
20
 
 
 
 
 
 
 
 
 
21
class PromptRefiner:
    """Wraps a HuggingFace InferenceClient to evaluate and refine user prompts."""

    def __init__(self, api_token: str, meta_prompts):
        # 120 s timeout: refinement calls on hosted models can be slow.
        self.client = InferenceClient(token=api_token, timeout=120)
        # Mapping of meta-prompt name -> template text.
        self.meta_prompts = meta_prompts
25
-
26
    def refine_prompt(self, prompt: str, meta_prompt_choice: str) -> tuple:
        """Refine *prompt* using the selected meta-prompt template.

        Returns a 4-tuple: (initial_prompt_evaluation, refined_prompt,
        explanation_of_refinements, full response dict). On failure a
        placeholder error tuple is returned instead of raising.
        """
        try:
            # Unknown choices fall back to the "star" template.
            selected_meta_prompt = self.meta_prompts.get(
                meta_prompt_choice,
                self.meta_prompts["star"]
            )

            messages = [
                {
                    "role": "system",
                    "content": 'You are an expert at refining and extending prompts. Given a basic prompt, provide a more relevant and detailed prompt.'
                },
                {
                    "role": "user",
                    # The template carries a placeholder that is swapped for the user's prompt.
                    "content": selected_meta_prompt.replace("[Insert initial prompt here]", prompt)
                }
            ]

            # NOTE(review): prompt_refiner_model presumably comes from
            # `from variables import *` — confirm it is defined there.
            response = self.client.chat_completion(
                model=prompt_refiner_model,
                messages=messages,
                max_tokens=3000,
                temperature=0.8
            )

            response_content = response.choices[0].message.content.strip()
            result = self._parse_response(response_content)

            # Create and validate LLMResponse
            llm_response = LLMResponse(**result)

            return (
                llm_response.initial_prompt_evaluation,
                llm_response.refined_prompt,
                llm_response.explanation_of_refinements,
                llm_response.dict()
            )

        except HfHubHTTPError as e:
            # Any hub HTTP failure is surfaced to the user as a timeout message.
            return self._create_error_response("Model timeout. Please try again later.")
        except Exception as e:
            return self._create_error_response(f"Unexpected error: {str(e)}")
 
 
 
68
 
69
    def _create_error_response(self, error_message: str) -> tuple:
        """Build the 4-tuple returned to callers when refinement fails.

        Shapes the error through LLMResponse so the tuple layout matches
        the success path of refine_prompt.
        """
        error_response = LLMResponse(
            initial_prompt_evaluation=f"Error: {error_message}",
            refined_prompt="The selected model is currently unavailable.",
            explanation_of_refinements="An error occurred during processing.",
            response_content={"error": error_message}
        )
        return (
            error_response.initial_prompt_evaluation,
            error_response.refined_prompt,
            error_response.explanation_of_refinements,
            error_response.dict()
        )
82
 
83
    def _parse_response(self, response_content: str) -> dict:
        """Extract the refinement fields from the raw model output.

        Tries JSON wrapped in <json> tags first, then a per-field regex
        scan; returns an error-shaped dict if decoding fails.
        """
        try:
            # First attempt: Try to extract JSON from <json> tags
            json_match = re.search(r'<json>\s*(.*?)\s*</json>', response_content, re.DOTALL)
            if json_match:
                json_str = json_match.group(1)
                # Collapse newline + following indentation into single spaces.
                json_str = re.sub(r'\n\s*', ' ', json_str)
                # NOTE(review): escaping every quote, re-wrapping the payload in
                # quotes, and double json.loads-ing is fragile — it only round-trips
                # when the model emitted a JSON-encoded string. Confirm against
                # real model outputs.
                json_str = json_str.replace('"', '\\"')
                json_output = json.loads(f'"{json_str}"')

                # Double-encoded payload: the first loads yields a string.
                if isinstance(json_output, str):
                    json_output = json.loads(json_output)

                return {
                    "initial_prompt_evaluation": json_output.get("initial_prompt_evaluation", ""),
                    "refined_prompt": json_output.get("refined_prompt", ""),
                    "explanation_of_refinements": json_output.get("explanation_of_refinements", ""),
                    "response_content": json_output
                }

            # Second attempt: Try to extract fields using regex
            output = {}
            for key in ["initial_prompt_evaluation", "refined_prompt", "explanation_of_refinements"]:
                pattern = rf'"{key}":\s*"(.*?)"(?:,|\}})'
                match = re.search(pattern, response_content, re.DOTALL)
                output[key] = match.group(1) if match else ""

            output["response_content"] = response_content
            return output

        except (json.JSONDecodeError, ValueError) as e:
            print(f"Error parsing response: {e}")
            print(f"Raw content: {response_content}")
            return {
                "initial_prompt_evaluation": "Error parsing response",
                "refined_prompt": "",
                "explanation_of_refinements": str(e),
                "response_content": str(e)
            }
122
 
123
    def apply_prompt(self, prompt: str, model: str) -> str:
        """Run *prompt* against *model* with a markdown-formatting system prompt.

        Streams the completion and returns the concatenated text, or an
        "Error: ..." string on any failure.
        """
        try:
            messages = [
                {
                    "role": "system",
                    "content": """You are a markdown formatting expert. Format your responses with proper spacing and structure following these rules:

1. Paragraph Spacing:
- Add TWO blank lines between major sections (##)
- Add ONE blank line between subsections (###)
- Add ONE blank line between paragraphs within sections
- Add ONE blank line before and after lists
- Add ONE blank line before and after code blocks
- Add ONE blank line before and after blockquotes

2. Section Formatting:
# Title

## Major Section

[blank line]
Content paragraph 1
[blank line]
Content paragraph 2
[blank line]"""
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ]

            response = self.client.chat_completion(
                model=model,
                messages=messages,
                max_tokens=3000,
                temperature=0.8,
                stream=True
            )

            # Accumulate streamed delta chunks into one string.
            full_response = ""
            for chunk in response:
                if chunk.choices[0].delta.content is not None:
                    full_response += chunk.choices[0].delta.content

            # NOTE(review): this collapses the blank lines the system prompt just
            # asked the model to produce — confirm that is intentional.
            return full_response.replace('\n\n', '\n').strip()

        except Exception as e:
            return f"Error: {str(e)}"
 
1
  import json
2
  import re
3
+ from typing import Optional, Dict, Any, Union
4
  from pydantic import BaseModel, Field, validator
5
  from huggingface_hub import InferenceClient
6
  from huggingface_hub.errors import HfHubHTTPError
 
9
class LLMResponse(BaseModel):
    """Schema for the refiner LLM output: evaluation, refined prompt and explanation.

    The explanation may arrive either as a single string or as a list of
    bullet strings, hence the Union type.
    """
    initial_prompt_evaluation: str = Field(..., description="Evaluation of the initial prompt")
    refined_prompt: str = Field(..., description="The refined version of the prompt")
    explanation_of_refinements: Union[str, list] = Field(..., description="Explanation of the refinements made")
    response_content: Optional[Dict[str, Any]] = Field(None, description="Raw response content")

    @staticmethod
    def _clean(text):
        # Shared unescape: strip whitespace and turn literal escape sequences
        # for newline/quote back into real characters.
        return text.strip().replace('\\n', '\n').replace('\\"', '"')

    @validator('initial_prompt_evaluation', 'refined_prompt')
    def clean_text_fields(cls, v):
        """Clean plain string fields."""
        if isinstance(v, str):
            return cls._clean(v)
        return v

    @validator('explanation_of_refinements')
    def clean_refinements(cls, v):
        """Clean the explanation, whether it is a string or a list of strings."""
        if isinstance(v, str):
            return cls._clean(v)
        if isinstance(v, list):
            return [cls._clean(item) if isinstance(item, str) else item for item in v]
        return v
28
+
29
class PromptRefiner:
    """Refines user prompts via a HuggingFace InferenceClient."""

    def __init__(self, api_token: str, meta_prompts):
        # 120 s timeout: hosted-model refinement calls can be slow.
        self.client = InferenceClient(token=api_token, timeout=120)
        # Mapping of meta-prompt name -> template text.
        self.meta_prompts = meta_prompts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ def _sanitize_json_string(self, json_str: str) -> str:
35
+ """Clean and prepare JSON string for parsing."""
36
+ json_str = json_str.lstrip('\ufeff').strip()
37
+ json_str = json_str.replace('\n', ' ')
38
+ json_str = re.sub(r'\s+', ' ', json_str)
39
+ json_str = json_str.replace('•', '*')
40
+ return json_str
41
 
42
+ def _extract_json_content(self, content: str) -> str:
43
+ """Extract JSON content from between <json> tags."""
44
+ json_match = re.search(r'<json>\s*(.*?)\s*</json>', content, re.DOTALL)
45
+ if json_match:
46
+ return self._sanitize_json_string(json_match.group(1))
47
+ return content
 
 
 
 
 
 
 
48
 
49
  def _parse_response(self, response_content: str) -> dict:
50
  try:
51
+ # First attempt: Try to parse the entire content as JSON
52
+ cleaned_content = self._sanitize_json_string(response_content)
53
+ try:
54
+ parsed_json = json.loads(cleaned_content)
55
+ if isinstance(parsed_json, str):
56
+ parsed_json = json.loads(parsed_json)
57
+ return self._normalize_json_output(parsed_json)
58
+ except json.JSONDecodeError:
59
+ # Second attempt: Try to extract JSON from <json> tags
60
+ json_content = self._extract_json_content(response_content)
61
+ try:
62
+ parsed_json = json.loads(json_content)
63
+ if isinstance(parsed_json, str):
64
+ parsed_json = json.loads(parsed_json)
65
+ return self._normalize_json_output(parsed_json)
66
+ except json.JSONDecodeError:
67
+ # Third attempt: Try to parse using regex
68
+ return self._parse_with_regex(response_content)
 
 
 
 
 
 
 
 
 
69
 
70
+ except Exception as e:
71
+ print(f"Error parsing response: {str(e)}")
72
  print(f"Raw content: {response_content}")
73
+ return self._create_error_dict(str(e))
 
 
 
 
 
74
 
75
+ def _normalize_json_output(self, json_output: dict) -> dict:
76
+ """Normalize JSON output to expected format."""
77
+ return {
78
+ "initial_prompt_evaluation": json_output.get("initial_prompt_evaluation", ""),
79
+ "refined_prompt": json_output.get("refined_prompt", ""),
80
+ "explanation_of_refinements": json_output.get("explanation_of_refinements", ""),
81
+ "response_content": json_output
82
+ }
83
+
84
+ def _parse_with_regex(self, content: str) -> dict:
85
+ """Parse content using regex patterns."""
86
+ output = {}
87
+ for key in ["initial_prompt_evaluation", "refined_prompt", "explanation_of_refinements"]:
88
+ pattern = rf'"{key}":\s*"(.*?)"(?:,|\}})'
89
+ match = re.search(pattern, content, re.DOTALL)
90
+ output[key] = match.group(1) if match else ""
91
+
92
+ output["response_content"] = content
93
+ return output
94
+
95
+ def _create_error_dict(self, error_message: str) -> dict:
96
+ """Create standardized error response dictionary."""
97
+ return {
98
+ "initial_prompt_evaluation": f"Error parsing response: {error_message}",
99
+ "refined_prompt": "",
100
+ "explanation_of_refinements": "",
101
+ "response_content": {"error": error_message}
102
+ }
103
+
104
+ # Rest of your code remains the same...