baconnier commited on
Commit
046b80d
·
verified ·
1 Parent(s): 8de21bb

Update prompt_refiner.py

Browse files
Files changed (1) hide show
  1. prompt_refiner.py +200 -190
prompt_refiner.py CHANGED
@@ -7,207 +7,217 @@ from huggingface_hub.errors import HfHubHTTPError
7
  from variables import *
8
 
9
  class LLMResponse(BaseModel):
10
- initial_prompt_evaluation: str = Field(..., description="Evaluation of the initial prompt")
11
- refined_prompt: str = Field(..., description="The refined version of the prompt")
12
- explanation_of_refinements: Union[str, List[str]] = Field(..., description="Explanation of the refinements made")
13
- response_content: Optional[Dict[str, Any]] = Field(None, description="Raw response content")
14
 
15
- @validator('initial_prompt_evaluation', 'refined_prompt')
16
- def clean_text_fields(cls, v):
17
- if isinstance(v, str):
18
- return v.strip().replace('\\n', '\n').replace('\\"', '"')
19
- return v
 
 
 
20
 
21
- @validator('explanation_of_refinements')
22
- def clean_refinements(cls, v):
23
- if isinstance(v, str):
24
- return v.strip().replace('\\n', '\n').replace('\\"', '"')
25
- elif isinstance(v, list):
26
- return [item.strip().replace('\\n', '\n').replace('\\"', '"').replace('•', '-')
27
- for item in v if isinstance(item, str)]
28
- return v
 
 
 
 
 
 
29
 
30
  class PromptRefiner:
31
- def __init__(self, api_token: str, meta_prompts: dict):
32
- self.client = InferenceClient(token=api_token, timeout=120)
33
- self.meta_prompts = meta_prompts
34
 
35
- def _clean_json_string(self, content: str) -> str:
36
- """Clean and prepare JSON string for parsing."""
37
- content = content.replace('•', '-') # Replace bullet points
38
- content = re.sub(r'\s+', ' ', content) # Normalize whitespace
39
- content = content.replace('\\"', '"') # Fix escaped quotes
40
- return content.strip()
41
 
42
- def _parse_response(self, response_content: str) -> dict:
43
- """Parse the LLM response with enhanced error handling."""
44
- try:
45
- # Extract content between <json> tags
46
- json_match = re.search(r'<json>\s*(.*?)\s*</json>', response_content, re.DOTALL)
47
- if json_match:
48
- json_str = self._clean_json_string(json_match.group(1))
49
- try:
50
- # Try parsing the cleaned JSON
51
- parsed_json = json.loads(json_str)
52
- if isinstance(parsed_json, str):
53
- parsed_json = json.loads(parsed_json)
54
-
55
- return {
56
- "initial_prompt_evaluation": parsed_json.get("initial_prompt_evaluation", ""),
57
- "refined_prompt": parsed_json.get("refined_prompt", ""),
58
- "explanation_of_refinements": parsed_json.get("explanation_of_refinements", ""),
59
- "response_content": parsed_json
60
- }
61
- except json.JSONDecodeError:
62
- # If JSON parsing fails, try regex parsing
63
- return self._parse_with_regex(json_str)
64
-
65
- # If no JSON tags found, try regex parsing
66
- return self._parse_with_regex(response_content)
67
 
68
- except Exception as e:
69
- print(f"Error parsing response: {str(e)}")
70
- print(f"Raw content: {response_content}")
71
- return self._create_error_dict(str(e))
72
 
73
- def _parse_with_regex(self, content: str) -> dict:
74
- """Parse content using regex when JSON parsing fails."""
75
- output = {}
76
-
77
- # Handle explanation_of_refinements list format
78
- refinements_match = re.search(r'"explanation_of_refinements":\s*\[(.*?)\]', content, re.DOTALL)
79
- if refinements_match:
80
- refinements_str = refinements_match.group(1)
81
- refinements = [
82
- item.strip().strip('"').strip("'").replace('•', '-')
83
- for item in re.findall(r'[•"]([^"•]+)[•"]', refinements_str)
84
- ]
85
- output["explanation_of_refinements"] = refinements
86
- else:
87
- # Try single string format
88
- pattern = r'"explanation_of_refinements":\s*"(.*?)"(?:,|\})'
89
- match = re.search(pattern, content, re.DOTALL)
90
- output["explanation_of_refinements"] = match.group(1).strip() if match else ""
91
 
92
- # Extract other fields
93
- for key in ["initial_prompt_evaluation", "refined_prompt"]:
94
- pattern = rf'"{key}":\s*"(.*?)"(?:,|\}})'
95
- match = re.search(pattern, content, re.DOTALL)
96
- output[key] = match.group(1).strip() if match else ""
97
-
98
- output["response_content"] = content
99
- return output
 
100
 
101
- def _create_error_dict(self, error_message: str) -> dict:
102
- """Create a standardized error response dictionary."""
103
- return {
104
- "initial_prompt_evaluation": f"Error parsing response: {error_message}",
105
- "refined_prompt": "",
106
- "explanation_of_refinements": "",
107
- "response_content": {"error": error_message}
108
- }
109
 
110
- def refine_prompt(self, prompt: str, meta_prompt_choice: str) -> Tuple[str, str, str, dict]:
111
- """Refine the given prompt using the selected meta prompt."""
112
- try:
113
- selected_meta_prompt = self.meta_prompts.get(
114
- meta_prompt_choice,
115
- self.meta_prompts["star"]
116
- )
117
-
118
- messages = [
119
- {
120
- "role": "system",
121
- "content": 'You are an expert at refining and extending prompts. Given a basic prompt, provide a more relevant and detailed prompt.'
122
- },
123
- {
124
- "role": "user",
125
- "content": selected_meta_prompt.replace("[Insert initial prompt here]", prompt)
126
- }
127
- ]
128
-
129
- response = self.client.chat_completion(
130
- model=prompt_refiner_model,
131
- messages=messages,
132
- max_tokens=3000,
133
- temperature=0.8
134
- )
135
-
136
- response_content = response.choices[0].message.content.strip()
137
- result = self._parse_response(response_content)
138
-
139
- try:
140
- llm_response = LLMResponse(**result)
141
- return (
142
- llm_response.initial_prompt_evaluation,
143
- llm_response.refined_prompt,
144
- llm_response.explanation_of_refinements,
145
- llm_response.dict()
146
- )
147
- except Exception as e:
148
- print(f"Error creating LLMResponse: {e}")
149
- return self._create_error_response(f"Error validating response: {str(e)}")
150
 
151
- except HfHubHTTPError as e:
152
- return self._create_error_response("Model timeout. Please try again later.")
153
- except Exception as e:
154
- return self._create_error_response(f"Unexpected error: {str(e)}")
155
 
156
- def _create_error_response(self, error_message: str) -> Tuple[str, str, str, dict]:
157
- """Create a standardized error response tuple."""
158
- return (
159
- f"Error: {error_message}",
160
- "The selected model is currently unavailable.",
161
- "An error occurred during processing.",
162
- {"error": error_message}
163
- )
164
 
165
- def apply_prompt(self, prompt: str, model: str) -> str:
166
- """Apply formatting to the prompt using the specified model."""
167
- try:
168
- messages = [
169
- {
170
- "role": "system",
171
- "content": """You are a markdown formatting expert. Format your responses with proper spacing and structure following these rules:
172
- 1. Paragraph Spacing:
173
- - Add TWO blank lines between major sections (##)
174
- - Add ONE blank line between subsections (###)
175
- - Add ONE blank line between paragraphs within sections
176
- - Add ONE blank line before and after lists
177
- - Add ONE blank line before and after code blocks
178
- - Add ONE blank line before and after blockquotes
179
-
180
- 2. Section Formatting:
181
- # Title
182
-
183
- ## Major Section
184
-
185
- [blank line]
186
- Content paragraph 1
187
- [blank line]
188
- Content paragraph 2
189
- [blank line]"""
190
- },
191
- {
192
- "role": "user",
193
- "content": prompt
194
- }
195
- ]
196
-
197
- response = self.client.chat_completion(
198
- model=model,
199
- messages=messages,
200
- max_tokens=3000,
201
- temperature=0.8,
202
- stream=True
203
- )
204
-
205
- full_response = ""
206
- for chunk in response:
207
- if chunk.choices[0].delta.content is not None:
208
- full_response += chunk.choices[0].delta.content
209
-
210
- return full_response.replace('\n\n', '\n').strip()
211
-
212
- except Exception as e:
213
- return f"Error: {str(e)}"
 
7
  from variables import *
8
 
9
  class LLMResponse(BaseModel):
10
+ initial_prompt_evaluation: str = Field(..., description="Evaluation of the initial prompt")
11
+ refined_prompt: str = Field(..., description="The refined version of the prompt")
12
+ explanation_of_refinements: Union[str, List[str]] = Field(..., description="Explanation of the refinements made")
13
+ response_content: Optional[Union[Dict[str, Any], str]] = Field(None, description="Raw response content")
14
 
15
+ @validator('response_content', pre=True)
16
+ def validate_response_content(cls, v):
17
+ if isinstance(v, str):
18
+ try:
19
+ return json.loads(v)
20
+ except json.JSONDecodeError:
21
+ return {"raw_content": v}
22
+ return v
23
 
24
+ @validator('initial_prompt_evaluation', 'refined_prompt')
25
+ def clean_text_fields(cls, v):
26
+ if isinstance(v, str):
27
+ return v.strip().replace('\\n', '\n').replace('\\"', '"')
28
+ return v
29
+
30
+ @validator('explanation_of_refinements')
31
+ def clean_refinements(cls, v):
32
+ if isinstance(v, str):
33
+ return v.strip().replace('\\n', '\n').replace('\\"', '"')
34
+ elif isinstance(v, list):
35
+ return [item.strip().replace('\\n', '\n').replace('\\"', '"').replace('•', '-')
36
+ for item in v if isinstance(item, str)]
37
+ return v
38
 
39
  class PromptRefiner:
40
+ def __init__(self, api_token: str, meta_prompts: dict):
41
+ self.client = InferenceClient(token=api_token, timeout=120)
42
+ self.meta_prompts = meta_prompts
43
 
44
+ def _clean_json_string(self, content: str) -> str:
45
+ """Clean and prepare JSON string for parsing."""
46
+ content = content.replace('•', '-') # Replace bullet points
47
+ content = re.sub(r'\s+', ' ', content) # Normalize whitespace
48
+ content = content.replace('\\"', '"') # Fix escaped quotes
49
+ return content.strip()
50
 
51
+ def _parse_response(self, response_content: str) -> dict:
52
+ """Parse the LLM response with enhanced error handling."""
53
+ try:
54
+ # Extract content between <json> tags
55
+ json_match = re.search(r'<json>\s*(.*?)\s*</json>', response_content, re.DOTALL)
56
+ if json_match:
57
+ json_str = self._clean_json_string(json_match.group(1))
58
+ try:
59
+ # Try parsing the cleaned JSON
60
+ parsed_json = json.loads(json_str)
61
+ if isinstance(parsed_json, str):
62
+ parsed_json = json.loads(parsed_json)
63
+
64
+ return {
65
+ "initial_prompt_evaluation": parsed_json.get("initial_prompt_evaluation", ""),
66
+ "refined_prompt": parsed_json.get("refined_prompt", ""),
67
+ "explanation_of_refinements": parsed_json.get("explanation_of_refinements", ""),
68
+ "response_content": parsed_json if isinstance(parsed_json, dict) else {"raw_content": parsed_json}
69
+ }
70
+ except json.JSONDecodeError:
71
+ # If JSON parsing fails, try regex parsing
72
+ return self._parse_with_regex(json_str)
73
+
74
+ # If no JSON tags found, try regex parsing
75
+ return self._parse_with_regex(response_content)
76
 
77
+ except Exception as e:
78
+ print(f"Error parsing response: {str(e)}")
79
+ print(f"Raw content: {response_content}")
80
+ return self._create_error_dict(str(e))
81
 
82
+ def _parse_with_regex(self, content: str) -> dict:
83
+ """Parse content using regex when JSON parsing fails."""
84
+ output = {}
85
+
86
+ # Handle explanation_of_refinements list format
87
+ refinements_match = re.search(r'"explanation_of_refinements":\s*$(.*?)$', content, re.DOTALL)
88
+ if refinements_match:
89
+ refinements_str = refinements_match.group(1)
90
+ refinements = [
91
+ item.strip().strip('"').strip("'").replace('•', '-')
92
+ for item in re.findall(r'[•"]([^"•]+)[•"]', refinements_str)
93
+ ]
94
+ output["explanation_of_refinements"] = refinements
95
+ else:
96
+ # Try single string format
97
+ pattern = r'"explanation_of_refinements":\s*"(.*?)"(?:,|\})'
98
+ match = re.search(pattern, content, re.DOTALL)
99
+ output["explanation_of_refinements"] = match.group(1).strip() if match else ""
100
 
101
+ # Extract other fields
102
+ for key in ["initial_prompt_evaluation", "refined_prompt"]:
103
+ pattern = rf'"{key}":\s*"(.*?)"(?:,|\}})'
104
+ match = re.search(pattern, content, re.DOTALL)
105
+ output[key] = match.group(1).strip() if match else ""
106
+
107
+ # Store the original content in a structured way
108
+ output["response_content"] = {"raw_content": content}
109
+ return output
110
 
111
+ def _create_error_dict(self, error_message: str) -> dict:
112
+ """Create a standardized error response dictionary."""
113
+ return {
114
+ "initial_prompt_evaluation": f"Error parsing response: {error_message}",
115
+ "refined_prompt": "",
116
+ "explanation_of_refinements": "",
117
+ "response_content": {"error": error_message}
118
+ }
119
 
120
+ def refine_prompt(self, prompt: str, meta_prompt_choice: str) -> Tuple[str, str, str, dict]:
121
+ """Refine the given prompt using the selected meta prompt."""
122
+ try:
123
+ selected_meta_prompt = self.meta_prompts.get(
124
+ meta_prompt_choice,
125
+ self.meta_prompts["star"]
126
+ )
127
+
128
+ messages = [
129
+ {
130
+ "role": "system",
131
+ "content": 'You are an expert at refining and extending prompts. Given a basic prompt, provide a more relevant and detailed prompt.'
132
+ },
133
+ {
134
+ "role": "user",
135
+ "content": selected_meta_prompt.replace("[Insert initial prompt here]", prompt)
136
+ }
137
+ ]
138
+
139
+ response = self.client.chat_completion(
140
+ model=prompt_refiner_model,
141
+ messages=messages,
142
+ max_tokens=3000,
143
+ temperature=0.8
144
+ )
145
+
146
+ response_content = response.choices[0].message.content.strip()
147
+ result = self._parse_response(response_content)
148
+
149
+ try:
150
+ llm_response = LLMResponse(**result)
151
+ return (
152
+ llm_response.initial_prompt_evaluation,
153
+ llm_response.refined_prompt,
154
+ llm_response.explanation_of_refinements,
155
+ llm_response.dict()
156
+ )
157
+ except Exception as e:
158
+ print(f"Error creating LLMResponse: {e}")
159
+ return self._create_error_response(f"Error validating response: {str(e)}")
160
 
161
+ except HfHubHTTPError as e:
162
+ return self._create_error_response("Model timeout. Please try again later.")
163
+ except Exception as e:
164
+ return self._create_error_response(f"Unexpected error: {str(e)}")
165
 
166
+ def _create_error_response(self, error_message: str) -> Tuple[str, str, str, dict]:
167
+ """Create a standardized error response tuple."""
168
+ return (
169
+ f"Error: {error_message}",
170
+ "The selected model is currently unavailable.",
171
+ "An error occurred during processing.",
172
+ {"error": error_message}
173
+ )
174
 
175
+ def apply_prompt(self, prompt: str, model: str) -> str:
176
+ """Apply formatting to the prompt using the specified model."""
177
+ try:
178
+ messages = [
179
+ {
180
+ "role": "system",
181
+ "content": """You are a markdown formatting expert. Format your responses with proper spacing and structure following these rules:
182
+ 1. Paragraph Spacing:
183
+ - Add TWO blank lines between major sections (##)
184
+ - Add ONE blank line between subsections (###)
185
+ - Add ONE blank line between paragraphs within sections
186
+ - Add ONE blank line before and after lists
187
+ - Add ONE blank line before and after code blocks
188
+ - Add ONE blank line before and after blockquotes
189
+
190
+ 2. Section Formatting:
191
+ # Title
192
+
193
+ ## Major Section
194
+
195
+ [blank line]
196
+ Content paragraph 1
197
+ [blank line]
198
+ Content paragraph 2
199
+ [blank line]"""
200
+ },
201
+ {
202
+ "role": "user",
203
+ "content": prompt
204
+ }
205
+ ]
206
+
207
+ response = self.client.chat_completion(
208
+ model=model,
209
+ messages=messages,
210
+ max_tokens=3000,
211
+ temperature=0.8,
212
+ stream=True
213
+ )
214
+
215
+ full_response = ""
216
+ for chunk in response:
217
+ if chunk.choices[0].delta.content is not None:
218
+ full_response += chunk.choices[0].delta.content
219
+
220
+ return full_response.replace('\n\n', '\n').strip()
221
+
222
+ except Exception as e:
223
+ return f"Error: {str(e)}"