awacke1 committed on
Commit
dcdf02a
·
verified ·
1 Parent(s): a96f1fb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +264 -0
app.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import time
from typing import Tuple
from uuid import uuid4

import openai
import pandas as pd
import streamlit as st
7
+
8
+ # πŸ”‘ Set the OpenAI API key from an environment variable
9
+ openai.api_key = os.getenv("OPENAI_API_KEY")
10
+
11
+ # πŸ†” Function to generate a unique session ID for caching
12
+ def get_session_id():
13
+ if 'session_id' not in st.session_state:
14
+ st.session_state.session_id = str(uuid4())
15
+ return st.session_state.session_id
16
+
17
+ # πŸ“š Predefined examples loaded from Python dictionaries
18
+ EXAMPLES = [
19
+ {
20
+ 'Problem': 'What is deductive reasoning?',
21
+ 'Rationale': 'Deductive reasoning starts from general premises to arrive at a specific conclusion.',
22
+ 'Answer': 'It involves deriving specific conclusions from general premises.'
23
+ },
24
+ {
25
+ 'Problem': 'What is inductive reasoning?',
26
+ 'Rationale': 'Inductive reasoning involves drawing generalizations based on specific observations.',
27
+ 'Answer': 'It involves forming general rules from specific examples.'
28
+ },
29
+ {
30
+ 'Problem': 'Explain abductive reasoning.',
31
+ 'Rationale': 'Abductive reasoning finds the most likely explanation for incomplete observations.',
32
+ 'Answer': 'It involves finding the best possible explanation.'
33
+ }
34
+ ]
35
+
36
+ # 🧠 STaR Algorithm Implementation
37
+ class SelfTaughtReasoner:
38
+ def __init__(self, model_engine="text-davinci-003"):
39
+ self.model_engine = model_engine
40
+ self.prompt_examples = EXAMPLES # Initialize with predefined examples
41
+ self.iterations = 0
42
+ self.generated_data = pd.DataFrame(columns=['Problem', 'Rationale', 'Answer', 'Is_Correct'])
43
+ self.rationalized_data = pd.DataFrame(columns=['Problem', 'Rationale', 'Answer', 'Is_Correct'])
44
+ self.fine_tuned_model = None # πŸ—οΈ Placeholder for fine-tuned model
45
+
46
+ def add_prompt_example(self, problem: str, rationale: str, answer: str):
47
+ """
48
+ βž• Adds a prompt example to the few-shot examples.
49
+ """
50
+ self.prompt_examples.append({
51
+ 'Problem': problem,
52
+ 'Rationale': rationale,
53
+ 'Answer': answer
54
+ })
55
+
56
+ def construct_prompt(self, problem: str, include_answer: bool = False, answer: str = "") -> str:
57
+ """
58
+ πŸ“ Constructs the prompt for the OpenAI API call.
59
+ """
60
+ prompt = ""
61
+ for example in self.prompt_examples:
62
+ prompt += f"Problem: {example['Problem']}\n"
63
+ prompt += f"Rationale: {example['Rationale']}\n"
64
+ prompt += f"Answer: {example['Answer']}\n\n"
65
+
66
+ prompt += f"Problem: {problem}\n"
67
+ if include_answer:
68
+ prompt += f"Answer (as hint): {answer}\n"
69
+ prompt += "Rationale:"
70
+ return prompt
71
+
72
+ def generate_rationale_and_answer(self, problem: str) -> Tuple[str, str]:
73
+ """
74
+ πŸ€” Generates a rationale and answer for a given problem.
75
+ """
76
+ prompt = self.construct_prompt(problem)
77
+ try:
78
+ response = openai.Completion.create(
79
+ engine=self.model_engine,
80
+ prompt=prompt,
81
+ max_tokens=150,
82
+ temperature=0.7,
83
+ top_p=1,
84
+ frequency_penalty=0,
85
+ presence_penalty=0,
86
+ stop=["\n\n", "Problem:", "Answer:"]
87
+ )
88
+ rationale = response.choices[0].text.strip()
89
+ # πŸ“ Now generate the answer using the rationale
90
+ prompt += f" {rationale}\nAnswer:"
91
+ answer_response = openai.Completion.create(
92
+ engine=self.model_engine,
93
+ prompt=prompt,
94
+ max_tokens=10,
95
+ temperature=0,
96
+ top_p=1,
97
+ frequency_penalty=0,
98
+ presence_penalty=0,
99
+ stop=["\n", "\n\n", "Problem:"]
100
+ )
101
+ answer = answer_response.choices[0].text.strip()
102
+ return rationale, answer
103
+ except Exception as e:
104
+ st.error(f"❌ Error generating rationale and answer: {e}")
105
+ return "", ""
106
+
107
+ def fine_tune_model(self):
108
+ """
109
+ πŸ› οΈ Fine-tunes the model on the generated rationales.
110
+ """
111
+ time.sleep(1) # ⏳ Simulate time taken for fine-tuning
112
+ self.fine_tuned_model = f"{self.model_engine}-fine-tuned-{get_session_id()}"
113
+ st.success(f"βœ… Model fine-tuned: {self.fine_tuned_model}")
114
+
115
+ def run_iteration(self, dataset: pd.DataFrame):
116
+ """
117
+ πŸ”„ Runs one iteration of the STaR process.
118
+ """
119
+ st.write(f"### Iteration {self.iterations + 1}")
120
+ progress_bar = st.progress(0)
121
+ total = len(dataset)
122
+ for idx, row in dataset.iterrows():
123
+ problem = row['Problem']
124
+ correct_answer = row['Answer']
125
+ # πŸ€– Generate rationale and answer
126
+ rationale, answer = self.generate_rationale_and_answer(problem)
127
+ is_correct = (answer.lower() == correct_answer.lower())
128
+ # πŸ“ Record the generated data
129
+ self.generated_data = self.generated_data.append({
130
+ 'Problem': problem,
131
+ 'Rationale': rationale,
132
+ 'Answer': answer,
133
+ 'Is_Correct': is_correct
134
+ }, ignore_index=True)
135
+ # ❌ If incorrect, perform rationalization
136
+ if not is_correct:
137
+ rationale, answer = self.rationalize(problem, correct_answer)
138
+ is_correct = (answer.lower() == correct_answer.lower())
139
+ if is_correct:
140
+ self.rationalized_data = self.rationalized_data.append({
141
+ 'Problem': problem,
142
+ 'Rationale': rationale,
143
+ 'Answer': answer,
144
+ 'Is_Correct': is_correct
145
+ }, ignore_index=True)
146
+ progress_bar.progress((idx + 1) / total)
147
+ # πŸ”§ Fine-tune the model on correct rationales
148
+ st.write("πŸ”„ Fine-tuning the model on correct rationales...")
149
+ self.fine_tune_model()
150
+ self.iterations += 1
151
+
152
+ # πŸ–₯️ Streamlit App
153
+ def main():
154
+ st.title("πŸ€– Self-Taught Reasoner (STaR) Demonstration")
155
+
156
+ # 🧩 Initialize the Self-Taught Reasoner
157
+ if 'star' not in st.session_state:
158
+ st.session_state.star = SelfTaughtReasoner()
159
+
160
+ star = st.session_state.star
161
+
162
+ # πŸ“ Wide format layout
163
+ col1, col2 = st.columns([1, 2]) # Column widths: col1 for input, col2 for display
164
+
165
+ # Step 1: Few-Shot Prompt Examples
166
+ with col1:
167
+ st.header("Step 1: Add Few-Shot Prompt Examples")
168
+ st.write("Choose an example from the dropdown or input your own.")
169
+
170
+ selected_example = st.selectbox(
171
+ "Select a predefined example",
172
+ [f"Example {i + 1}: {ex['Problem']}" for i, ex in enumerate(EXAMPLES)]
173
+ )
174
+
175
+ # Prefill with selected example
176
+ example_idx = int(selected_example.split(" ")[1]) - 1
177
+ example_problem = EXAMPLES[example_idx]['Problem']
178
+ example_rationale = EXAMPLES[example_idx]['Rationale']
179
+ example_answer = EXAMPLES[example_idx]['Answer']
180
+
181
+ st.text_area("Problem", value=example_problem, height=50, key="example_problem")
182
+ st.text_area("Rationale", value=example_rationale, height=100, key="example_rationale")
183
+ st.text_input("Answer", value=example_answer, key="example_answer")
184
+
185
+ if st.button("Add Example"):
186
+ star.add_prompt_example(st.session_state.example_problem, st.session_state.example_rationale, st.session_state.example_answer)
187
+ st.success("Example added successfully!")
188
+
189
+ with col2:
190
+ # Display current prompt examples
191
+ if star.prompt_examples:
192
+ st.subheader("Current Prompt Examples:")
193
+ for idx, example in enumerate(star.prompt_examples):
194
+ st.write(f"**Example {idx + 1}:**")
195
+ st.write(f"Problem: {example['Problem']}")
196
+ st.write(f"Rationale: {example['Rationale']}")
197
+ st.write(f"Answer: {example['Answer']}")
198
+
199
+ # Step 2: Input Dataset
200
+ st.header("Step 2: Input Dataset")
201
+ dataset_input_method = st.radio("How would you like to input the dataset?", ("Manual Entry", "Upload CSV"))
202
+
203
+ if dataset_input_method == "Manual Entry":
204
+ dataset_problems = st.text_area("Enter problems and answers in the format 'Problem | Answer', one per line.", height=200)
205
+ if st.button("Submit Dataset"):
206
+ dataset = []
207
+ lines = dataset_problems.strip().split('\n')
208
+ for line in lines:
209
+ if '|' in line:
210
+ problem, answer = line.split('|', 1)
211
+ dataset.append({'Problem': problem.strip(), 'Answer': answer.strip()})
212
+ st.session_state.dataset = pd.DataFrame(dataset)
213
+ st.success("Dataset loaded.")
214
+
215
+ else:
216
+ uploaded_file = st.file_uploader("Upload a CSV file with 'Problem' and 'Answer' columns.", type=['csv'])
217
+ if uploaded_file:
218
+ st.session_state.dataset = pd.read_csv(uploaded_file)
219
+ st.success("Dataset loaded.")
220
+
221
+ if 'dataset' in st.session_state:
222
+ st.subheader("Current Dataset:")
223
+ st.dataframe(st.session_state.dataset.head())
224
+
225
+ # Step 3: Run STaR Process
226
+ st.header("Step 3: Run STaR Process")
227
+ num_iterations = st.number_input("Number of Iterations to Run:", min_value=1, max_value=10, value=1)
228
+ if st.button("Run STaR"):
229
+ for _ in range(num_iterations):
230
+ star.run_iteration(st.session_state.dataset)
231
+
232
+ st.header("Results")
233
+ st.subheader("Generated Data")
234
+ st.dataframe(star.generated_data)
235
+
236
+ st.subheader("Rationalized Data")
237
+ st.dataframe(star.rationalized_data)
238
+
239
+ st.write("The model has been fine-tuned iteratively.")
240
+
241
+ # Step 4: Test the Fine-Tuned Model
242
+ st.header("Step 4: Test the Fine-Tuned Model")
243
+ test_problem = st.text_area("Enter a new problem to solve:", height=100)
244
+ if st.button("Solve Problem"):
245
+ if not test_problem:
246
+ st.warning("Please enter a problem to solve.")
247
+ else:
248
+ rationale, answer = star.generate_rationale_and_answer(test_problem)
249
+ st.subheader("Rationale:")
250
+ st.write(rationale)
251
+ st.subheader("Answer:")
252
+ st.write(answer)
253
+
254
+ # Footer with custom HTML/JS component
255
+ st.markdown("---")
256
+ st.write("Developed as a demonstration of the STaR method with enhanced Streamlit capabilities.")
257
+ st.components.v1.html("""
258
+ <div style="text-align: center; margin-top: 20px;">
259
+ <h3>πŸš€ Boost Your AI Reasoning with STaR! πŸš€</h3>
260
+ </div>
261
+ """)
262
+
263
+ if __name__ == "__main__":
264
+ main()