awacke1 commited on
Commit
d0d59f7
·
verified ·
1 Parent(s): a98adb2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +303 -0
app.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import openai
4
+ import pandas as pd
5
+ import time
6
+ from typing import List, Tuple
7
+ from uuid import uuid4
8
+
9
# 🔑 Read the OpenAI API key from the environment. os.getenv returns None
# when OPENAI_API_KEY is unset; main() checks `openai.api_key` and warns
# before any API call is attempted.
openai.api_key = os.getenv("OPENAI_API_KEY")
11
+
12
# 🆔 Per-session identifier used to label the simulated fine-tuned model.
def get_session_id():
    """Return this browser session's UUID, minting one on first access."""
    key = 'session_id'
    if key not in st.session_state:
        st.session_state[key] = str(uuid4())
    return st.session_state[key]
17
+
18
# 🧠 STaR Algorithm Implementation
class SelfTaughtReasoner:
    """Self-Taught Reasoner (STaR) driver.

    For each dataset row: generate a rationale + answer, record it, and
    on a wrong answer retry with the correct answer as a hint
    ("rationalization"). Fine-tuning is simulated (see fine_tune_model).
    """

    # Column layout shared by both collected data frames.
    _COLUMNS = ['Problem', 'Rationale', 'Answer', 'Is_Correct']

    def __init__(self, model_engine="text-davinci-003"):
        self.model_engine = model_engine      # OpenAI completion engine name
        self.prompt_examples = []             # few-shot examples (list of dicts)
        self.iterations = 0                   # completed STaR iterations
        self.generated_data = pd.DataFrame(columns=self._COLUMNS)
        self.rationalized_data = pd.DataFrame(columns=self._COLUMNS)
        self.fine_tuned_model = None          # 🏗️ Placeholder for fine-tuned model

    def add_prompt_example(self, problem: str, rationale: str, answer: str):
        """
        ➕ Adds a prompt example to the few-shot examples.
        """
        self.prompt_examples.append({
            'Problem': problem,
            'Rationale': rationale,
            'Answer': answer
        })

    def construct_prompt(self, problem: str, include_answer: bool = False, answer: str = "") -> str:
        """
        📝 Constructs the few-shot prompt for the OpenAI API call.
        """
        parts = []
        for example in self.prompt_examples:
            parts.append(f"Problem: {example['Problem']}\n")
            parts.append(f"Rationale: {example['Rationale']}\n")
            parts.append(f"Answer: {example['Answer']}\n\n")
        parts.append(f"Problem: {problem}\n")
        if include_answer:
            parts.append(f"Answer (as hint): {answer}\n")
        parts.append("Rationale:")
        # str.join avoids quadratic string concatenation.
        return "".join(parts)

    def _complete(self, prompt: str, max_tokens: int, temperature: float,
                  stop: List[str]) -> str:
        """🔌 One OpenAI completion call; returns the stripped text."""
        response = openai.Completion.create(
            engine=self.model_engine,
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=stop
        )
        return response.choices[0].text.strip()

    def _reason(self, prompt: str, error_label: str) -> Tuple[str, str]:
        """🤔 Generate a rationale, then an answer conditioned on it.

        Shared by generate_rationale_and_answer and rationalize (their
        bodies were duplicates). Returns ("", "") and surfaces a Streamlit
        error on API failure.
        """
        try:
            rationale = self._complete(prompt, 150, 0.7,
                                       ["\n\n", "Problem:", "Answer:"])
            # 📝 Now generate the answer using the rationale
            answer = self._complete(prompt + f" {rationale}\nAnswer:", 10, 0,
                                    ["\n", "\n\n", "Problem:"])
            return rationale, answer
        except Exception as e:
            st.error(f"❌ {error_label}: {e}")
            return "", ""

    def generate_rationale_and_answer(self, problem: str) -> Tuple[str, str]:
        """
        🤔 Generates a rationale and answer for a given problem.
        """
        return self._reason(self.construct_prompt(problem),
                            "Error generating rationale and answer")

    def rationalize(self, problem: str, correct_answer: str) -> Tuple[str, str]:
        """
        🧐 Generates a rationale for a given problem using the correct answer as a hint.
        """
        prompt = self.construct_prompt(problem, include_answer=True,
                                       answer=correct_answer)
        return self._reason(prompt, "Error during rationalization")

    def fine_tune_model(self):
        """
        🛠️ Fine-tunes the model on the generated rationales.
        This is a placeholder function as fine-tuning would require
        training a new model which is beyond the scope of this app.
        """
        # 🔄 In an actual implementation you would prepare the training data
        # and use OpenAI's fine-tuning API; here we only simulate the process.
        time.sleep(1)  # ⏳ Simulate time taken for fine-tuning
        self.fine_tuned_model = f"{self.model_engine}-fine-tuned-{get_session_id()}"
        st.success(f"✅ Model fine-tuned: {self.fine_tuned_model}")

    @staticmethod
    def _record(problem: str, rationale: str, answer: str, is_correct: bool) -> pd.DataFrame:
        """📝 One-row frame for concatenation (DataFrame.append was removed in pandas 2.0)."""
        return pd.DataFrame([{
            'Problem': problem,
            'Rationale': rationale,
            'Answer': answer,
            'Is_Correct': is_correct
        }])

    def run_iteration(self, dataset: pd.DataFrame):
        """
        🔄 Runs one iteration of the STaR process.

        *dataset* must provide 'Problem' and 'Answer' columns.
        """
        st.write(f"### Iteration {self.iterations + 1}")
        progress_bar = st.progress(0)
        total = len(dataset)
        # enumerate() gives the 0-based position; the frame's own index may
        # not be 0..n-1 (e.g. a filtered frame), which would break progress.
        for pos, (_, row) in enumerate(dataset.iterrows()):
            problem = row['Problem']
            correct_answer = str(row['Answer'])  # CSV answers may load as numbers
            # 🤖 Generate rationale and answer
            rationale, answer = self.generate_rationale_and_answer(problem)
            is_correct = (answer.lower() == correct_answer.lower())
            # 📝 Record the generated data (pd.concat: .append removed in pandas 2.x)
            self.generated_data = pd.concat(
                [self.generated_data,
                 self._record(problem, rationale, answer, is_correct)],
                ignore_index=True)
            # ❌ If incorrect, perform rationalization with the answer as a hint
            if not is_correct:
                rationale, answer = self.rationalize(problem, correct_answer)
                is_correct = (answer.lower() == correct_answer.lower())
                if is_correct:
                    self.rationalized_data = pd.concat(
                        [self.rationalized_data,
                         self._record(problem, rationale, answer, is_correct)],
                        ignore_index=True)
            progress_bar.progress((pos + 1) / total)
        # 🔧 Fine-tune the model on correct rationales
        st.write("🔄 Fine-tuning the model on correct rationales...")
        self.fine_tune_model()
        self.iterations += 1
173
+
174
# 🖥️ Streamlit App
def main():
    """Drive the four-step STaR demo UI:
    1) collect few-shot examples, 2) load a dataset,
    3) run STaR iterations, 4) probe the (simulated) fine-tuned model.
    """
    st.title("🤖 Self-Taught Reasoner (STaR) Demonstration")
    st.write("""
    This app demonstrates the **Self-Taught Reasoner (STaR)** workflow. Enter problems to solve, and see how the model generates rationales, filters correct answers, and fine-tunes itself iteratively.
    """)

    # 🧩 Initialize the Self-Taught Reasoner once per session
    if 'star' not in st.session_state:
        st.session_state.star = SelfTaughtReasoner()

    star = st.session_state.star

    # 📚 Section to add few-shot prompt examples
    st.header("🔹 Step 1: Add Few-Shot Prompt Examples")
    st.write("Provide a few examples with problems, rationales, and answers to bootstrap the reasoning process.")

    with st.form(key='prompt_form'):
        example_problem = st.text_area("📝 Example Problem", height=100)
        example_rationale = st.text_area("🧠 Example Rationale", height=150)
        example_answer = st.text_input("✅ Example Answer")
        submit_example = st.form_submit_button("➕ Add Example")

    if submit_example:
        if not example_problem or not example_rationale or not example_answer:
            st.warning("⚠️ Please fill in all fields to add an example.")
        else:
            star.add_prompt_example(example_problem, example_rationale, example_answer)
            st.success("🎉 Example added.")

    if star.prompt_examples:
        st.subheader("📌 Current Prompt Examples:")
        for idx, example in enumerate(star.prompt_examples):
            st.write(f"**📚 Example {idx + 1}:**")
            st.markdown(f"**Problem:**\n{example['Problem']}")
            st.markdown(f"**Rationale:**\n{example['Rationale']}")
            st.markdown(f"**Answer:**\n{example['Answer']}")

    # 🔍 Section to input dataset
    st.header("🔹 Step 2: Input Dataset")
    st.write("Provide a dataset of problems and correct answers for the STaR process.")

    dataset_input_method = st.radio("📥 How would you like to input the dataset?", ("Manual Entry", "Upload CSV"))

    if dataset_input_method == "Manual Entry":
        with st.form(key='dataset_form'):
            dataset_problems = st.text_area("📝 Enter problems and answers in the format 'Problem | Answer', one per line.", height=200)
            submit_dataset = st.form_submit_button("📤 Submit Dataset")

        if submit_dataset:
            if not dataset_problems:
                st.warning("⚠️ Please enter at least one problem and answer.")
            else:
                dataset = []
                for line in dataset_problems.strip().split('\n'):
                    if '|' in line:
                        # Split on the FIRST pipe so answers may contain '|'.
                        problem, answer = line.split('|', 1)
                        dataset.append({'Problem': problem.strip(), 'Answer': answer.strip()})
                    else:
                        st.error(f"❌ Invalid format in line: {line}")
                if dataset:
                    st.session_state.dataset = pd.DataFrame(dataset)
                    st.success("✅ Dataset loaded.")
    else:
        uploaded_file = st.file_uploader("📂 Upload a CSV file with 'Problem' and 'Answer' columns.", type=['csv'])
        if uploaded_file:
            try:
                st.session_state.dataset = pd.read_csv(uploaded_file)
                if 'Problem' not in st.session_state.dataset.columns or 'Answer' not in st.session_state.dataset.columns:
                    st.error("❌ CSV must contain 'Problem' and 'Answer' columns.")
                    del st.session_state.dataset
                else:
                    st.success("✅ Dataset loaded.")
            except Exception as e:
                st.error(f"❌ Error loading CSV: {e}")

    if 'dataset' in st.session_state:
        st.subheader("📊 Current Dataset:")
        st.dataframe(st.session_state.dataset.head())

    # 🏃‍♂️ Section to run the STaR process
    st.header("🔹 Step 3: Run STaR Process")
    num_iterations = st.number_input("🔢 Number of Iterations to Run:", min_value=1, max_value=10, value=1)
    run_star = st.button("🚀 Run STaR")

    if run_star:
        if not star.prompt_examples:
            st.warning("⚠️ Please add at least one prompt example before running STaR.")
        elif not openai.api_key:
            st.warning("⚠️ OpenAI API key not found. Please set the `OPENAI_API_KEY` environment variable.")
        elif 'dataset' not in st.session_state:
            # 🐛 Bug fix: previously st.session_state.dataset was read
            # unconditionally below and crashed when no dataset was loaded.
            st.warning("⚠️ Please load a dataset (Step 2) before running STaR.")
        else:
            for _ in range(int(num_iterations)):
                star.run_iteration(st.session_state.dataset)

            st.header("📈 Results")
            st.subheader("🧾 Generated Data")
            st.dataframe(star.generated_data)

            st.subheader("🧩 Rationalized Data")
            st.dataframe(star.rationalized_data)

            st.write("🔄 The model has been fine-tuned iteratively. You can now test it with new problems.")

    # 🧪 Section to test the fine-tuned model
    st.header("🔹 Step 4: Test the Fine-Tuned Model")
    test_problem = st.text_area("📝 Enter a new problem to solve:", height=100)
    test_button = st.button("✅ Solve Problem")

    if test_button:
        if not test_problem:
            st.warning("⚠️ Please enter a problem to solve.")
        elif not star.fine_tuned_model:
            st.warning("⚠️ The model has not been fine-tuned yet. Please run the STaR process first.")
        else:
            # 🤖 For demonstration we reuse generate_rationale_and_answer;
            # a real deployment would query star.fine_tuned_model instead.
            st.write("🔄 Generating rationale and answer using the fine-tuned model...")
            rationale, answer = star.generate_rationale_and_answer(test_problem)
            st.subheader("🧠 Rationale:")
            st.write(rationale)
            st.subheader("✅ Answer:")
            st.write(answer)

    # 📝 Footer
    st.write("---")
    st.write("🛠️ Developed as a demonstration of the **STaR** method.")
301
+
302
# Script entry point: run the Streamlit UI when executed directly.
if __name__ == "__main__":
    main()