liuze commited on
Commit
0ee12fc
·
1 Parent(s): 8181cd8
Files changed (2) hide show
  1. app.py +243 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import os
4
+ import re
5
+ import time
6
+ from typing import Dict, Tuple, Optional, Any
7
+ from generate import load_chat_model
8
+ from pydantic import BaseModel, ConfigDict
9
+ from jinja2 import Template
10
+ from typing_extensions import Literal
11
+ import yaml
12
+ from generate.chat_completion import ChatCompletionModel
13
+ from dotenv import load_dotenv
14
+ import logging
15
load_dotenv()  # load API keys etc. from a local .env file into os.environ


# Module-wide logging: INFO level, timestamped lines on stderr.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

logger = logging.getLogger(__name__)
21
+
22
# Example prompt/feedback pairs surfaced in the "Examples/示例" dropdown;
# fill_example() copies the selected entry into the input textboxes.
EXAMPLES = [
    {
        "title": "Example 1: Basic Improvement",
        "prompt": "From the following list of Wikipedia article titles, identify which article this sentence came from.\nRespond with just the article title and nothing else.\n\nArticle titles:\n{{titles}}\n\nSentence to classify:\n{{sentence}}",
        "feedback": "The categories can only be: technology, culture, history, other."
    }
    # TODO: Add more examples here
]
30
+
31
class Prompt(BaseModel):
    """A prompt template loaded from YAML, rendered as an f-string or Jinja2.

    Extra YAML keys beyond the declared fields are retained (extra='allow')
    and can be read back with :meth:`get_extra_field`.
    """

    name: str
    description: Optional[str] = None
    temperature: Optional[float] = None  # sampling temperature forwarded to the model
    max_tokens: Optional[int] = None     # generation cap forwarded to the model
    prompt: str                          # raw template text
    format: Optional[Literal["fstring", "jinja2"]] = "fstring"
    model_config = ConfigDict(extra='allow')

    @classmethod
    def from_yaml(cls, yaml_string: str) -> 'Prompt':
        """Parse a YAML document string into a Prompt instance."""
        yaml_data = yaml.safe_load(yaml_string)
        return cls(**yaml_data)

    def render(self, context: Optional[Dict[str, Any]] = None) -> str:
        """Render the template with *context* using the configured format.

        Fix: the original declared a mutable default argument
        (``context: Dict[str, Any] = {}``) which is shared across calls;
        use None plus a local default instead.
        """
        if context is None:
            context = {}
        if self.format == "jinja2":
            return self._render_jinja(context)
        else:
            return self._render_fstring(context)

    def _render_fstring(self, context: Dict[str, Any]) -> str:
        """Render the template by evaluating it as an f-string.

        SECURITY NOTE(review): eval() executes arbitrary expressions embedded
        in the template; only load prompt files from trusted sources.

        Raises:
            ValueError: if evaluation fails.
        """
        try:
            # Pass a copy so eval() cannot mutate the caller's dict —
            # eval inserts __builtins__ into the globals mapping it is given.
            return eval(f"f'''{self.prompt}'''", dict(context))
        except Exception as e:
            raise ValueError(f"Error rendering f-string: {e}")

    def _render_jinja(self, context: Dict[str, Any]) -> str:
        """Render the template with Jinja2.

        Raises:
            ValueError: if template construction or rendering fails.
        """
        try:
            template = Template(self.prompt)
            return template.render(**context)
        except Exception as e:
            raise ValueError(f"Error rendering Jinja template: {e}")

    def get_extra_field(self, field_name: str, default: Any = None) -> Any:
        """Safely get an extra field with a default value if it doesn't exist."""
        return getattr(self, field_name, default)
67
+
68
def demo_card_click(e: gr.EventData):
    # NOTE(review): DEMO_LIST is not defined anywhere in this file and this
    # handler is never wired to any component below — calling it would raise
    # NameError. Presumably leftover from an earlier revision; confirm or remove.
    # Reads the clicked card's index out of gradio's private event payload.
    index = e._data['component']['index']
    return DEMO_LIST[index]['description']
71
+
72
def analyze_prompt(target_prompt: str, feedback: str, language: str, model: ChatCompletionModel) -> str:
    """Run the analysis prompt template against *model* and return the report.

    Args:
        target_prompt: the user's prompt to be analyzed.
        feedback: the user's feedback about the prompt's behavior.
        language: output language for the analysis ("中文" or "English").
        model: chat model used to generate the analysis.

    Returns:
        The text between the <report>...</report> tags of the model output.

    Raises:
        ValueError: if the model output contains no <report> section.
    """
    # Use a context manager so the template file handle is always closed
    # (the original leaked the handle from a bare open()).
    with open('./prompts/analyze_prompt.yaml', 'r') as f:
        prompt_template = Prompt.from_yaml(f.read())
    prompt = prompt_template.render({
        "prompt": target_prompt,
        "feedback": feedback,
        "language": language
    })
    output = model.generate(prompt, max_tokens=prompt_template.max_tokens, temperature=prompt_template.temperature)
    output = output.message.content
    logger.info(f"Prompt Analysis: {output}")
    # Guard the extraction: the original's findall(...)[0] raised a bare
    # IndexError whenever the model omitted the tags.
    reports = re.findall(r"<report>(.*?)</report>", output, flags=re.DOTALL)
    if not reports:
        raise ValueError("Model output did not contain a <report>...</report> section")
    return reports[0]
84
+
85
def optimize_prompt(report: str, target_prompt: str, language: str, model: ChatCompletionModel) -> str:
    """Generate an optimized prompt from the analysis *report*.

    Args:
        report: the analysis report produced by analyze_prompt().
        target_prompt: the original prompt being improved.
        language: output language for the optimization.
        model: chat model used to generate the optimized prompt.

    Returns:
        The raw model output containing the optimized prompt.
    """
    # Use a context manager so the template file handle is always closed
    # (the original leaked the handle from a bare open()).
    with open('./prompts/optimize_prompt.yaml', 'r') as f:
        prompt_template = Prompt.from_yaml(f.read())
    prompt = prompt_template.render({
        "report": report,
        "prompt": target_prompt,
        "language": language
    })
    output = model.generate(prompt, max_tokens=prompt_template.max_tokens, temperature=prompt_template.temperature)
    output = output.message.content
    logger.info(f"Prompt Result: {output}")
    return output
96
+
97
def process_analysis(target_prompt: str, feedback: str, language: str, model_id: str) -> str:
    """First step: analyze the prompt and return the report.

    Fix: the original annotated the return as Tuple[str, int] but a single
    str is returned (and consumed as such by the gradio Textbox output).
    """
    model = load_chat_model(model_id)
    report = analyze_prompt(target_prompt, feedback, language, model)
    return report
102
+
103
def process_optimization(report: str, target_prompt: str, language: str, model_id: str) -> str:
    """Second step: generate the optimized prompt based on the analysis.

    Fix: the original annotated the return as Tuple[str, int] but a single
    str is returned (and consumed as such by the gradio Textbox output).
    """
    model = load_chat_model(model_id)
    optimization = optimize_prompt(report, target_prompt, language, model)
    return optimization
108
+
109
def save_results(prompt: str, feedback: str, analysis: str, optimization: str, lang: str, model: str):
    """Dump the whole optimization session to a timestamped JSON file.

    Returns the path of the file written, which gradio serves as a download.
    """
    payload = {
        "original_prompt": prompt,
        "feedback": feedback,
        "analysis": analysis,
        "optimized_prompt": optimization,
        "language": lang,
        "model": model,
    }

    # Timestamped file name so repeated downloads never overwrite each other.
    out_path = f"results_{time.strftime('%Y%m%d%H%M%S')}.json"
    with open(out_path, "w", encoding="utf-8") as fh:
        json.dump(payload, fh, ensure_ascii=False, indent=2)
    return out_path
125
+
126
def update_api_key(key: str, value: str):
    """Set environment variable *key* to *value*; no-op when either is empty.

    Security fix: the original logged the first five characters of the API
    key value — never write any part of a secret to the log; log only the
    variable name that changed.
    """
    if key and value:
        os.environ[key] = value
        logger.info(f"API key updated: {key}")
131
+
132
def fill_example(example_idx: int):
    """Auto-fill prompt and feedback with the selected example's content.

    Returns ("", "") for None, non-numeric, or out-of-range indices.
    Fix: the dropdown is created with value=None, and clearing the selection
    passes None here — the original crashed on int(None) with a TypeError.
    """
    try:
        idx = int(example_idx)
    except (TypeError, ValueError):
        return "", ""
    if 0 <= idx < len(EXAMPLES):
        example = EXAMPLES[idx]
        return example["prompt"], example["feedback"]
    return "", ""
138
+
139
# --- Gradio UI: two-column layout -------------------------------------------
# Left: API key entry, language/model selection, prompt + feedback inputs,
# example picker, action buttons. Right: analysis and optimization outputs.
with gr.Blocks() as demo:
    with gr.Row():
        # Left Column
        with gr.Column(scale=1):
            with gr.Row():
                # Which provider key to set; the value is written straight
                # into os.environ via update_api_key on every keystroke.
                key_input = gr.Dropdown(
                    ["DEEPSEEK_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY"],
                    label="API Key",
                )
                value_input = gr.Textbox(
                    show_label=True,
                    placeholder="Your API value...",
                    type="password",
                    label="API Value"
                )
                value_input.change(
                    fn=update_api_key,
                    inputs=[key_input, value_input],
                    outputs=[]
                )
            with gr.Row():
                language = gr.Dropdown(
                    choices=["中文", "English"],
                    value="中文",
                    label="Language/语言"
                )
                model = gr.Dropdown(
                    choices=["deepseek/deepseek-chat", "anthropic/claude-3-5-sonnet-latest", "openai/gpt-4o"],
                    value="deepseek/deepseek-chat",
                    label="Model/模型"
                )


            with gr.Row():
                prompt_input = gr.Textbox(
                    label="Original Prompt/待优化的Prompt",
                    placeholder="Enter your prompt here...",
                    lines=5
                )
            with gr.Row():
                feedback_input = gr.Textbox(
                    label="Feedback/反馈",
                    placeholder="Enter your feedback here...",
                    lines=3
                )
            # Add example module
            with gr.Row():
                # (title, index) tuples: the dropdown displays the title and
                # passes the integer index to fill_example below.
                example_dropdown = gr.Dropdown(
                    choices=[(example["title"], i) for i, example in enumerate(EXAMPLES)],
                    label="Examples/示例",
                    value=None
                )
            # Add example auto-fill event
            example_dropdown.change(
                fn=fill_example,
                inputs=[example_dropdown],
                outputs=[prompt_input, feedback_input]
            )

            with gr.Row():
                submit_btn = gr.Button("Optimize")
                download_btn = gr.Button("Download")

        # Right Column
        with gr.Column(scale=1):
            analysis_output = gr.Textbox(
                label="Prompt Analysis",
                lines=10,
            )
            optimization_output = gr.Textbox(
                label="Optimized Prompt",
                lines=20,
            )
            copy_btn = gr.Button("Copy to Clipboard")

    # First submit handler for analysis; .then() chains the optimization step
    # so it only runs after the analysis result has landed in analysis_output.
    submit_btn.click(
        fn=process_analysis,
        inputs=[prompt_input, feedback_input, language, model],
        outputs=analysis_output
    ).then(
        fn=process_optimization,
        inputs=[analysis_output, prompt_input, language, model],
        outputs=optimization_output
    )

    # Copy button handler — pure client-side JS; fn=None means no Python
    # callback runs, the js snippet receives the textbox value directly.
    copy_btn.click(
        fn=None,
        inputs=[optimization_output],
        outputs=None,
        js="(text) => {navigator.clipboard.writeText(text); return null;}"
    )

    # NOTE(review): the gr.File output is created inline here rather than in
    # the layout above — verify it actually renders; gradio typically expects
    # output components to be declared inside the layout first.
    download_btn.click(
        fn=save_results,
        inputs=[prompt_input, feedback_input, analysis_output, optimization_output, language, model],
        outputs=[
            gr.File(label="Download Results")
        ]
    )


if __name__ == "__main__":
    demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio
+ jinja2
+ pydantic
+ python-dotenv
+ pyyaml
+ generate-core