import copy
import json
import os
import zipfile
import pandas as pd

import gradio as gr
import spaces
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from schema_to_sql import dd_to_sql
from utils import (
    MAX_NEW_TOKENS,
    TEMPERATURE,
    create_summary_tables,
    get_example_ai_model_output_many,
    get_example_ai_model_output_simple,
    get_prompt_with_files_uploaded,
)
from parsing import try_parsing_actual_model_output

LOCAL_DIR = "tsvs"
ZIP_PATH = "tsvs.zip"

# Falls back to False so transformers skips authentication when no token is set
AUTH_TOKEN = os.environ.get("HF_TOKEN", False)

BASE_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
LORA_ADAPTER = "uc-ctds/data-model-curator"

MAX_RETRY_ATTEMPTS = 3

print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

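# Load the model at import time; on any failure, fall back to a UI-only mode so
# the setup instructions below still render.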
model_loaded = False

try:
    # Tokenizers are device-agnostic, so no device_map is needed here
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=AUTH_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, token=AUTH_TOKEN)
    model = model.to("cuda")
    model = model.eval()

    # PeftModel.from_pretrained loads the adapter config itself, so no separate
    # PeftConfig object is needed
    model = PeftModel.from_pretrained(model, LORA_ADAPTER, token=AUTH_TOKEN)

    model_loaded = True
except Exception as exc:
    print(f"Model load failed: {exc}. Ensure HF_TOKEN is set and follow the setup instructions!")
    # continue on so the setup instructions still load


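# Request a GPU for up to 450 seconds per call (Hugging Face ZeroGPU)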
@spaces.GPU(duration=450)
def run_llm_inference(model_prompt):
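    """Run the fine-tuned model on the prompt and return a JSON data model string.

    The model occasionally emits malformed JSON, so after the initial attempt we
    regenerate up to MAX_RETRY_ATTEMPTS more times and fall back to a sentinel
    schema if every attempt fails to parse.
    """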
    retry_count = 1

    print("Tokenizing Input")
    inputs = tokenizer(model_prompt, return_tensors="pt")
    inputs = inputs.to(model.device)
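    # Record the prompt length so the echoed prompt tokens can be stripped from the output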
    prompt_length = inputs["input_ids"].shape[1]

    print("Generating Initial Response")
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
    )

    # Decode and parse output
    print("Decoding output")
    output_data_model = tokenizer.decode(outputs[0][prompt_length:])
    output_data_model = output_data_model.split("<|eot_id|>")[0]
    print(output_data_model)

    # Test output for JSON schema validity
    try:
        json.loads(output_data_model)
        print("Yay - model passed")
        return output_data_model
    except json.JSONDecodeError:
        valid_output = False

    while (valid_output is False) and (retry_count <= MAX_RETRY_ATTEMPTS):

        print(
            f"Attempt {retry_count} did not generate a proper JSON output, proceeding to attempt {retry_count+1} of {MAX_RETRY_ATTEMPTS+1}"
        )
        retry_count += 1

        # Try generating a new response
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
        )

        output_data_model = tokenizer.decode(outputs[0][prompt_length:])
        output_data_model = output_data_model.split("<|eot_id|>")[0]
        print(output_data_model)

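        # Attempt a lenient re-parse of the raw output; keep the repaired
        # version only if it did not come back with an error marker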
        parsed_output_data_model = try_parsing_actual_model_output(output_data_model)
        if "error" not in parsed_output_data_model:
            output_data_model = copy.deepcopy(parsed_output_data_model)

        # Test output for JSON schema validity
        try:
            json.loads(output_data_model)
            valid_output = True
            print("Yay - model passed")
            return output_data_model
        except json.JSONDecodeError:
            valid_output = False

    # Handle the case where the model never generated a valid JSON schema
    if (valid_output is False) and (retry_count > MAX_RETRY_ATTEMPTS):
        print(
            "Failed To Generate a Proper Schema, try checking the prompt or input TSVs and running again"
        )
        output_data_model = '{"nodes": [{"name": "Attempt Failed - Check logs for suggested next steps", "links": []}]}'

    return output_data_model


def gen_output_from_files_uploaded(filepaths: list[str] | None = None):
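    """Build a prompt from the uploaded TSV headers, run inference, and derive
    the SQL schema and summary tables from the generated data model."""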
    prompt_from_tsv_upload = get_prompt_with_files_uploaded(filepaths)

    # Run the model; model_response is a JSON string that still needs parsing
    model_response = run_llm_inference(prompt_from_tsv_upload)
    model_response_json = json.loads(model_response)

    # Create SQL Code
    try:
        sql, validation = dd_to_sql(model_response_json)
    except Exception:
        print(f"Errors converting to SQL, skipping...")
        sql = ""

    # Create Summary Table
    nodes_df, properties_df = pd.DataFrame(), pd.DataFrame()
    try:
        nodes_df, properties_df = create_summary_tables(model_response_json)
    except Exception as exc:
        print(f"summary table creation failed: {exc}")

    return model_response, sql, nodes_df, properties_df


def gen_output_from_example_simple():
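    """Load a precomputed single-TSV example output so the UI can be exercised
    without GPU access."""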
    model_response = get_example_ai_model_output_simple()
    model_response_json = json.loads(model_response)
    sql, validation = dd_to_sql(model_response_json)
    nodes_df, properties_df = create_summary_tables(model_response_json)

    return model_response, sql, nodes_df, properties_df


def gen_output_from_example_many():
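    """Load a precomputed multi-TSV example output so the UI can be exercised
    without GPU access."""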
    model_response = get_example_ai_model_output_many()
    model_response_json = json.loads(model_response)
    sql, validation = dd_to_sql(model_response_json)
    nodes_df, properties_df = create_summary_tables(model_response_json)

    return model_response, sql, nodes_df, properties_df


def zip_tsvs():
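    """Bundle every .tsv in LOCAL_DIR into a single zip for the download button."""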
    tsv_files = [f for f in os.listdir(LOCAL_DIR) if f.endswith(".tsv")]
    if not tsv_files:
        return None

    with zipfile.ZipFile(ZIP_PATH, "w") as zipf:
        for file in tsv_files:
            file_path = os.path.join(LOCAL_DIR, file)
            zipf.write(file_path, arcname=file)

    return ZIP_PATH


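# Build the Gradio UI: sample downloads, the upload widget, and output panels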
with gr.Blocks() as demo:
    gr.Markdown("# Demonstration of Llama Data Model Generator built with Meta Llama 3")

    gr.Markdown("## (Optional) Get Sample TSV(s) to Upload")

    gr.Markdown("### Example 1: A Simple TSV")
    download_simple_btn = gr.DownloadButton(
        label="Download Simple TSV", value="sample_metadata.tsv"
    )
    gr.Markdown("### Example 2: Many TSVs in a single .zip file.")
    download_btn = gr.DownloadButton(label="Download Many TSVs as .zip", value=zip_tsvs)
    gr.Markdown("You need to extract the .zip if you want to use them.")

    gr.Markdown("## Upload TSVs With Headers (No Data Rows Required)")
    files = gr.Files(
        label="Upload TSVs",
        file_types=[".tsv"],
        type="filepath",
    )

    gr.Markdown(
        "Depending on your Hugging Face subscription and the availability of free GPUs, "
        "this can take a few minutes to complete."
    )
    gr.Markdown(
        "Behind the scenes, our [Llama Data Model Generator](https://huggingface.co/uc-ctds/llama-data-model-generator) AI model is loaded "
        "onto GPUs and the uploaded TSVs are sent to the model in a specialized prompt. "
        "For more information about the model, see the model card linked above."
    )

    # Define Outputs
    with gr.Row(equal_height=True):
        json_out = gr.Code(
            label="Generated Data Model Output",
            value=json.dumps([]),
            language="json",
            interactive=True,
            show_label=True,
            container=True,
            elem_id="json-output",
        )
        sql_out = gr.Textbox(
            label="SQL Defined Relational Schema",
            value="",
            show_label=True,
            container=True,
        )

    with gr.Row():
        nodes_df_out = gr.Dataframe(label="Generated Node/Table Descriptions")
    with gr.Row():
        properties_df_out = gr.Dataframe(label="Generated Property Descriptions")

    # If files are uploaded, generate prompt and run model
    if model_loaded:
        files.upload(
            fn=gen_output_from_files_uploaded,
            inputs=files,
            outputs=[json_out, sql_out, nodes_df_out, properties_df_out],
        )

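    # Fallback buttons that load precomputed example outputs when GPU inference
    # is unavailable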
    gr.Markdown("Run out of FreeGPU or having issues? Try the example outputs!")
    simple_example_btn = gr.Button("Manually Load 'Simple' Example Output from Previous Run")
    simple_example_btn.click(
        fn=gen_output_from_example_simple,
        outputs=[json_out, sql_out, nodes_df_out, properties_df_out],
    )

    many_example_btn = gr.Button("Manually Load 'Many' Example Output from Previous Run")
    many_example_btn.click(
        fn=gen_output_from_example_many,
        outputs=[json_out, sql_out, nodes_df_out, properties_df_out],
    )

if __name__ == "__main__":
    demo.launch(share=True)