File size: 6,501 Bytes
1a0cf07
 
8ff6b24
 
1a0cf07
8ff6b24
1a0cf07
8ff6b24
 
118d254
8ff6b24
ed548d7
 
 
 
 
 
 
 
 
 
 
68d33d7
 
 
 
7d3f3d3
ed548d7
7d3f3d3
1a0cf07
 
 
 
 
 
7d3f3d3
 
 
194e86c
1a0cf07
 
7d3f3d3
8ff6b24
7d3f3d3
 
 
194e86c
1a0cf07
 
 
 
7d3f3d3
 
 
194e86c
1a0cf07
 
ed548d7
 
 
 
dfd0ee3
 
 
 
 
 
 
 
 
 
 
 
 
7d3f3d3
1a0cf07
ed548d7
1a0cf07
 
 
 
 
 
7d3f3d3
ed548d7
 
1a0cf07
 
7d3f3d3
d38edd5
 
ed548d7
5507ae8
7d3f3d3
 
1a0cf07
8ff6b24
1a0cf07
ed548d7
 
0e87c57
 
1a0cf07
 
ed548d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e87c57
 
 
ed548d7
7d3f3d3
 
 
ed548d7
 
 
7d3f3d3
1a0cf07
 
ed548d7
 
1a0cf07
ed548d7
474fca3
 
1a0cf07
0e87c57
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import html
import json
import logging
import multiprocessing
import os

import gradio as gr
from pkg_resources import resource_filename

from swiftsage.agents import SwiftSage
from swiftsage.utils.commons import PromptTemplate, api_configs, setup_logging

# Inference backend and default model IDs for the three SwiftSage roles.
# ENGINE selects the entry of ``api_configs`` used for every role; the model
# IDs only seed the default values of the model-ID text boxes in the UI.
#
# Alternative presets (swap in to target a different provider):
#   ENGINE = "Together"
#     swift    = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
#     feedback = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
#     sage     = "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo"
#   ENGINE = "Groq"
#     swift    = "llama-3.1-8b-instant"
#     feedback = "llama-3.1-8b-instant"
#     sage     = "llama-3.1-70b-versatile"
ENGINE = "SambaNova"
SWIFT_MODEL_ID, FEEDBACK_MODEL_ID, SAGE_MODEL_ID = (
    "Meta-Llama-3.1-8B-Instruct",
    "Meta-Llama-3.1-70B-Instruct",
    "Meta-Llama-3.1-405B-Instruct",
)

def _llm_config(model_id, temperature, top_p, max_tokens=8192):
    """Build the configuration dict for one SwiftSage LLM role.

    ``temperature`` and ``top_p`` arrive from UI text boxes as strings and are
    coerced with ``float()`` (raising ``ValueError`` on malformed input). Every
    role uses the API settings of the module-level ``ENGINE`` backend.
    """
    return {
        "model_id": model_id,
        "api_config": api_configs[ENGINE],
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_tokens": max_tokens,
    }


def _find_prompt_template_dir():
    """Return the first existing prompt-template directory, or ``None``.

    Candidates are checked in order: the installed ``swiftsage`` package data,
    a ``swiftsage`` folder one level above this script, one beside it, and the
    fixed Docker image path.
    """
    script_dir = os.path.dirname(__file__)
    candidates = (
        resource_filename('swiftsage', 'prompt_templates'),
        os.path.join(script_dir, '..', 'swiftsage', 'prompt_templates'),
        os.path.join(script_dir, 'swiftsage', 'prompt_templates'),
        '/app/swiftsage/prompt_templates',  # For Docker environments
    )
    return next((path for path in candidates if os.path.exists(path)), None)


def _render_log_html(messages):
    """Wrap log lines in a scrollable, wrapping ``<pre><code>`` block.

    The joined text is HTML-escaped so that model output or user input that
    contains ``<``, ``>`` or ``&`` (code comparisons, markup, ...) is shown
    literally instead of being interpreted by the browser.
    """
    body = html.escape("\n".join(messages))
    return (
        "<pre style='white-space: pre-wrap; max-height: 500px; overflow-y: scroll;'>"
        "<code class='log'>" + body + "</code></pre>"
    )


def solve_problem(problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, feedback_model_id, use_retrieval, start_with_sage, swift_temperature, swift_top_p, sage_temperature, sage_top_p, feedback_temperature, feedback_top_p):
    """Run SwiftSage on ``problem`` and format the result for the Gradio UI.

    Numeric arguments come from text boxes, so they are parsed here:
    ``max_iterations`` and ``reward_threshold`` with ``int()``, every
    temperature/top-p with ``float()``.

    Args:
        problem: Free-text problem passed to ``SwiftSage.solve``.
        max_iterations: Iteration budget passed to ``SwiftSage.solve``.
        reward_threshold: Threshold passed to ``SwiftSage.solve`` (labelled
            "feedback Threshold" in the UI).
        swift_model_id, sage_model_id, feedback_model_id: Model IDs for the
            three roles, all served by the ``ENGINE`` backend.
        use_retrieval, start_with_sage: Flags forwarded to ``SwiftSage``.
        swift_temperature, swift_top_p, sage_temperature, sage_top_p,
        feedback_temperature, feedback_top_p: Per-role sampling parameters.

    Returns:
        ``(reasoning, solution, log_html)``: the cleaned-up reasoning text,
        the cleaned-up answer, and the log messages as an escaped HTML block.

    Raises:
        ValueError: If a numeric field cannot be parsed.
        FileNotFoundError: If no prompt-template directory can be located.
    """
    max_iterations = int(max_iterations)
    reward_threshold = int(reward_threshold)

    swift_config = _llm_config(swift_model_id, swift_temperature, swift_top_p)
    feedback_config = _llm_config(feedback_model_id, feedback_temperature, feedback_top_p)
    sage_config = _llm_config(sage_model_id, sage_temperature, sage_top_p)

    prompt_template_dir = _find_prompt_template_dir()
    if prompt_template_dir is None:
        # Fail early with a clear message instead of handing None to SwiftSage.
        raise FileNotFoundError(
            "Could not find the SwiftSage 'prompt_templates' directory in the "
            "installed package, next to this script, or at /app/swiftsage/prompt_templates."
        )

    dataset = []
    embeddings = []  # TODO: for retrieval augmentation (not implemented yet now)
    s2 = SwiftSage(
        dataset,
        embeddings,
        prompt_template_dir,
        swift_config,
        sage_config,
        feedback_config,
        use_retrieval=use_retrieval,
        start_with_sage=start_with_sage,
    )

    reasoning, solution, messages = s2.solve(problem, max_iterations, reward_threshold)
    # Put a horizontal-rule separator before the generated code section.
    reasoning = reasoning.replace("The generated code is:", "\n---\nThe generated code is:").strip()
    # Drop the "Answer (from running the code):" header so only the answer remains.
    solution = solution.replace("Answer (from running the code):\n ", " ").strip()
    return reasoning, solution, _render_log_html(messages)


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header: centered title plus a one-line description.
    gr.HTML("<h1 style='text-align: center;'>🏦 Bank Failure Predictor</h1>")
    gr.HTML("<span>This tool predicts the likelihood of bank failure based on balance sheet data.</span>")

    # Model IDs for the three SwiftSage roles, pre-filled from the module constants.
    with gr.Row():
        swift_model_id = gr.Textbox(label="😄 Swift Model ID", value=SWIFT_MODEL_ID)
        feedback_model_id = gr.Textbox(label="🤔 Feedback Model ID", value=FEEDBACK_MODEL_ID)
        sage_model_id = gr.Textbox(label="😎 Sage Model ID", value=SAGE_MODEL_ID)

    # Search and sampling controls. They are plain text boxes; solve_problem
    # parses them with int()/float().
    with gr.Accordion(label="⚙️ Advanced Options", open=False):
        with gr.Row():
            with gr.Column():
                max_iterations = gr.Textbox(label="Max Iterations", value="5")
                reward_threshold = gr.Textbox(label="feedback Threshold", value="8")
            with gr.Column():
                top_p_swift = gr.Textbox(label="Top-p for Swift", value="0.9")
                temperature_swift = gr.Textbox(label="Temperature for Swift", value="0.5")
            with gr.Column():
                top_p_sage = gr.Textbox(label="Top-p for Sage", value="0.9")
                temperature_sage = gr.Textbox(label="Temperature for Sage", value="0.5")
            with gr.Column():
                top_p_feedback = gr.Textbox(label="Top-p for Feedback", value="0.9")
                temperature_feedback = gr.Textbox(label="Temperature for Feedback", value="0.5")

            # Hidden flags that always submit their defaults (retrieval
            # augmentation is not implemented yet).
            use_retrieval = gr.Checkbox(label="Use Retrieval Augmentation", value=False, visible=False)
            start_with_sage = gr.Checkbox(label="Start with Sage", value=False, visible=False)

    # The hint goes in `placeholder`, not `value`, so it is never submitted as
    # the problem when the user forgets to clear the box.
    problem = gr.Textbox(label="Input balance sheet data or parameters", placeholder="Enter the bank's financial data here...", lines=5)

    solve_button = gr.Button("🔮 Predict Failure Chance")
    reasoning_output = gr.Textbox(label="Prediction steps with Code", interactive=False)
    solution_output = gr.Textbox(label="Prediction Result", interactive=False)

    # Collapsible panel for the HTML-formatted log produced by solve_problem.
    with gr.Accordion(label="📜 Log Messages", open=False):
        log_output = gr.HTML("<p>No log messages yet.</p>")

    # `inputs` is passed positionally, so its order must match
    # solve_problem's parameter list exactly.
    solve_button.click(
        solve_problem,
        inputs=[problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, feedback_model_id, use_retrieval, start_with_sage, temperature_swift, top_p_swift, temperature_sage, top_p_sage, temperature_feedback, top_p_feedback],
        outputs=[reasoning_output, solution_output, log_output],
    )



if __name__ == '__main__':
    # Make sure the logs directory exists; exist_ok avoids the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs('logs', exist_ok=True)
    # Use the "spawn" start method for any multiprocessing workers.
    multiprocessing.set_start_method('spawn')
    # share=True asks Gradio for a public share link; show_api keeps the API page enabled.
    demo.launch(share=True, show_api=True)