File size: 17,154 Bytes
4bd2901
 
 
1420dd0
4bd2901
0498138
4bd2901
 
 
 
 
 
4af6426
 
4bd2901
 
 
 
 
 
 
 
 
 
 
4af6426
 
 
 
 
 
 
 
 
 
1420dd0
 
 
 
 
 
 
 
 
4bd2901
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4af6426
4bd2901
 
 
 
 
 
4af6426
 
 
 
4bd2901
4af6426
 
4bd2901
 
4af6426
 
4bd2901
4af6426
4bd2901
4af6426
4bd2901
 
4af6426
 
 
 
4bd2901
4af6426
4bd2901
 
4af6426
 
 
 
 
 
 
 
 
 
 
1420dd0
4bd2901
4af6426
1420dd0
4bd2901
 
 
 
 
4af6426
4bd2901
4af6426
4bd2901
 
 
 
4af6426
 
4bd2901
4af6426
 
 
4bd2901
4af6426
 
 
 
 
 
 
 
 
 
 
4bd2901
 
 
 
 
 
 
 
 
 
 
 
4af6426
4bd2901
 
 
7d29c31
4bd2901
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4af6426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bd2901
c6db2b5
4af6426
 
c6db2b5
 
 
 
 
 
 
 
4af6426
 
 
 
 
 
 
c6db2b5
 
4bd2901
c6db2b5
4bd2901
c6db2b5
 
 
 
 
 
 
4bd2901
 
4af6426
4bd2901
 
 
 
 
 
 
 
 
 
 
 
 
 
4af6426
4bd2901
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4af6426
 
 
 
 
 
 
4bd2901
4af6426
4bd2901
 
 
 
4af6426
4bd2901
4af6426
4bd2901
 
 
 
 
4af6426
4bd2901
4af6426
 
4bd2901
 
 
 
 
 
4af6426
4bd2901
4af6426
4bd2901
 
 
 
 
 
 
 
4af6426
4bd2901
4af6426
4bd2901
 
 
 
 
 
 
 
4af6426
4bd2901
 
 
 
 
4af6426
4bd2901
 
 
 
 
 
 
a444f06
4bd2901
 
 
 
 
c6db2b5
4bd2901
 
 
 
 
 
 
4af6426
4bd2901
4af6426
 
c6db2b5
 
4bd2901
c6db2b5
4af6426
c6db2b5
4bd2901
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
import gradio as gr
import requests
import os
import re
import markdownify
import fitz
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pandas as pd
import random
from gretel_client import Gretel
from gretel_client.config import GretelClientConfigurationError

# Working directory for every artifact this app produces: downloaded or
# converted documents, per-chunk markdown files and the generated CSV.
# It is created (if missing) when the module is loaded.
output_dir = "processed_files"
os.makedirs(name=output_dir, exist_ok=True)

# Function to download and convert a PDF to text
def pdf_to_text(pdf_path):
    pdf_document = fitz.open(pdf_path)
    text = ''
    for page_num in range(pdf_document.page_count):
        page = pdf_document.load_page(page_num)
        text += page.get_text()
    return text

def txt_to_text(txt_path):
    """Return the full contents of a plain-text file.

    The file is decoded as UTF-8 rather than the platform's locale default, so
    the result does not depend on the host. Undecodable bytes are replaced with
    U+FFFD instead of aborting the whole processing run.

    Args:
        txt_path: path of the text file to read.

    Returns:
        str: the decoded file contents.
    """
    with open(txt_path, 'r', encoding='utf-8', errors='replace') as file:
        return file.read()

def markdown_to_text(md_path):
    """Return the raw contents of a Markdown file (no rendering is done).

    The file is decoded as UTF-8 rather than the platform's locale default, so
    the result does not depend on the host. Undecodable bytes are replaced with
    U+FFFD instead of aborting the whole processing run.

    Args:
        md_path: path of the Markdown file to read.

    Returns:
        str: the decoded file contents.
    """
    with open(md_path, 'r', encoding='utf-8', errors='replace') as file:
        return file.read()

def sanitize_key(filename, *, fallback="document"):
    """Turn a file name into a safe identifier for output file names.

    Spaces become underscores. Every character other than ASCII letters,
    digits and ``_`` is removed. The result is truncated to 100 characters.

    Args:
        filename: the name to sanitize (typically a basename without extension).
        fallback: returned when sanitizing leaves nothing, e.g. for names made
            only of non-ASCII characters. Without it the key would be ``""``,
            producing files like ``.md`` and ``_chunk_1.md``.

    Returns:
        str: a non-empty key of at most 100 characters (assuming ``fallback``
        itself satisfies that).
    """
    key = re.sub(r'[^a-zA-Z0-9_]', '', filename.replace(" ", "_"))[:100]
    return key or fallback

def split_text_into_chunks(text, chunk_size=25, chunk_overlap=5, min_chunk_chars=50):
    """Split *text* into token-sized chunks and drop the short ones.

    ``chunk_size`` and ``chunk_overlap`` are measured in tiktoken tokens (the
    splitter is built with ``from_tiktoken_encoder``). ``min_chunk_chars`` is a
    character count applied after splitting.

    Args:
        text: the text to split.
        chunk_size: maximum chunk size in tokens.
        chunk_overlap: token overlap between consecutive chunks. It is clamped
            to below ``chunk_size``.
        min_chunk_chars: chunks shorter than this many characters are discarded.

    Returns:
        list[str]: the kept chunks, in document order.
    """
    # The UI sliders allow an overlap larger than the chunk size, which the
    # splitter rejects with a ValueError. Clamp it so processing still runs.
    chunk_overlap = max(0, min(chunk_overlap, chunk_size - 1))
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return [chunk for chunk in text_splitter.split_text(text) if len(chunk) >= min_chunk_chars]

def save_chunks(file_id, chunks, output_dir):
    """Write each chunk to ``<output_dir>/<file_id>_chunk_<n>.md`` (n is 1-based).

    Files are written as UTF-8, so text extracted from PDFs (curly quotes,
    accented letters, ...) can be saved regardless of the platform's default
    encoding.

    Args:
        file_id: sanitized key used as the file-name prefix.
        chunks: chunk strings, in order.
        output_dir: destination directory (must already exist).

    NOTE(review): chunk files left by an earlier run that produced more chunks
    for the same ``file_id`` are not removed.
    """
    for index, chunk in enumerate(chunks, start=1):
        chunk_path = os.path.join(output_dir, f"{file_id}_chunk_{index}.md")
        with open(chunk_path, 'w', encoding='utf-8') as file:
            file.write(chunk)

def read_chunks_from_files(output_dir):
    """Load chunk files written as ``<file_id>_chunk_<n>.md`` back into memory.

    Only names that match the chunk pattern exactly are read. This excludes,
    for example, a document's full markdown file whose name merely contains
    "chunk". When a ``file_id`` itself contains ``_chunk_``, the last
    occurrence is used as the separator. Chunks are returned in numeric index
    order (1, 2, ..., 10), not in directory-listing order.

    Args:
        output_dir: directory to scan.

    Returns:
        dict[str, list[str]]: file_id -> chunk texts ordered by chunk index.
    """
    # Greedy ``.*`` makes the last "_chunk_<digits>.md" the separator.
    chunk_name = re.compile(r"^(?P<file_id>.*)_chunk_(?P<index>\d+)\.md$")
    indexed = {}
    for filename in os.listdir(output_dir):
        match = chunk_name.match(filename)
        if match is None:
            continue
        with open(os.path.join(output_dir, filename), 'r', encoding='utf-8') as file:
            chunk = file.read()
        indexed.setdefault(match["file_id"], []).append((int(match["index"]), chunk))
    return {
        file_id: [chunk for _, chunk in sorted(pairs, key=lambda pair: pair[0])]
        for file_id, pairs in indexed.items()
    }

def _download_example_pdf(url="https://gretel-datasets.s3.us-west-2.amazonaws.com/rag/GDPR_2016.pdf", timeout=60):
    """Download the example PDF into ``output_dir`` once and return its local path.

    A non-2xx response raises ``requests.HTTPError``. Otherwise an error page
    would be cached under the ``.pdf`` name, and the existence check would
    reuse it forever.
    """
    file_path = os.path.join(output_dir, url.split('/')[-1])
    if not os.path.exists(file_path):
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        with open(file_path, 'wb') as file:
            file.write(response.content)
    return file_path


def _uploaded_paths(uploaded_files):
    """Return local filesystem paths for Gradio uploads.

    An upload may be a path (``str``/``os.PathLike``) or a tempfile-like object
    exposing its path as ``.name``; both are accepted.
    """
    return [
        os.fspath(uploaded) if isinstance(uploaded, (str, os.PathLike)) else uploaded.name
        for uploaded in uploaded_files
    ]


def _read_document(file_path, file_extension):
    """Read a supported document (.pdf/.txt/.md) as text; other extensions yield ''."""
    if file_extension == '.pdf':
        return pdf_to_text(file_path)
    if file_extension == '.txt':
        return txt_to_text(file_path)
    if file_extension == '.md':
        return markdown_to_text(file_path)
    return ""


def process_files(uploaded_files, use_example, chunk_size, chunk_overlap, min_chunk_chars, current_chunk, direction):
    """Convert the selected documents to markdown, chunk them and pick a chunk to show.

    For each selected file:
    - the converted markdown is written to ``<output_dir>/<file_id>.md``;
    - its chunks are written with ``save_chunks``.

    Args:
        uploaded_files: Gradio uploads (may be ``None`` or empty).
        use_example: when true, the example PDF is used instead of uploads.
        chunk_size, chunk_overlap, min_chunk_chars: forwarded to
            ``split_text_into_chunks``.
        current_chunk: index of the chunk currently displayed (``None`` treated as 0).
        direction: offset applied to ``current_chunk`` before clamping.

    Returns:
        ``(chunks_dict, selected_files, chunk_text, current_chunk)``.
        ``chunks_dict`` maps ``<file_id><extension>`` to that file's chunks.
        ``current_chunk`` is clamped into range and is 0 when there are no chunks.
    """
    if use_example:
        selected_files = [_download_example_pdf()]
    elif uploaded_files:
        selected_files = _uploaded_paths(uploaded_files)
    else:
        # Same output order as the success path:
        # (chunks_dict, selected_files, chunk_text, current_chunk).
        return {}, [], "No files processed", 0

    chunks_dict = {}
    for file_path in selected_files:
        file_extension = os.path.splitext(file_path)[1].lower()
        text = _read_document(file_path, file_extension)

        # NOTE(review): markdownify parses its input as HTML. On plain text it
        # may escape characters such as '_'/'*' and normalize whitespace.
        # Confirm this is intended before chunking.
        markdown_text = markdownify.markdownify(text)
        file_id = sanitize_key(os.path.splitext(os.path.basename(file_path))[0])
        markdown_path = os.path.join(output_dir, f"{file_id}.md")
        with open(markdown_path, 'w', encoding='utf-8') as file:
            file.write(markdown_text)

        chunks = split_text_into_chunks(markdown_text, chunk_size=chunk_size, chunk_overlap=chunk_overlap, min_chunk_chars=min_chunk_chars)
        save_chunks(file_id, chunks, output_dir)
        chunks_dict[file_id + file_extension] = chunks

    all_chunks = [chunk for chunks in chunks_dict.values() for chunk in chunks]
    if not all_chunks:
        return chunks_dict, selected_files, "No chunks available.", 0

    current_chunk = min(max((current_chunk or 0) + direction, 0), len(all_chunks) - 1)
    return chunks_dict, selected_files, all_chunks[current_chunk], current_chunk

def show_chunks(chunks_dict, selected_files, current_chunk, direction):
    """Step through the flattened list of chunks and return the one to display.

    Args:
        chunks_dict: document key -> list of chunk strings. ``None`` (nothing
            processed yet, e.g. "Next" clicked before "Process Files") is
            treated as empty.
        selected_files: unused; accepted because the event wiring passes it.
        current_chunk: index of the chunk currently shown (``None`` treated as 0).
        direction: offset to move by (-1 previous, +1 next).

    Returns:
        ``(chunk_text, new_index)``. ``new_index`` is clamped into range and is
        0 when there are no chunks.
    """
    all_chunks = [chunk for chunks in (chunks_dict or {}).values() for chunk in chunks]
    if not all_chunks:
        return "No chunks available.", 0
    new_index = min(max((current_chunk or 0) + direction, 0), len(all_chunks) - 1)
    return all_chunks[new_index], new_index

def check_api_key(api_key):
    """Validate a Gretel API key and enable or disable the generate button.

    Args:
        api_key: the key typed into the API-key textbox.

    Returns:
        ``(button_update, status_message)``. The update sets the button's
        ``interactive`` flag; the message is "Valid" or "Invalid".
    """
    # This handler runs on every change of the textbox, including when it is
    # cleared. A blank key cannot be valid, so skip building a Gretel session.
    if not api_key or not api_key.strip():
        return gr.update(interactive=False), "Invalid"
    try:
        Gretel(api_key=api_key, validate=True, clear=True)
    except GretelClientConfigurationError:
        return gr.update(interactive=False), "Invalid"
    return gr.update(interactive=True), "Valid"

def generate_synthetic_records(api_key, chunks_dict, num_records):
    """Generate and self-evaluate synthetic Q&A records from document chunks.

    The function:
    1. Samples ``num_records`` chunks, with replacement, from ``chunks_dict``.
    2. Asks Gretel Navigator to write topic/user_profile/question/answer/context
       rows for them.
    3. Runs a second Navigator pass that adds rating columns.
    4. Writes the result to ``<output_dir>/synthetic_qa.csv``.

    Args:
        api_key: Gretel API key.
        chunks_dict: document key -> list of chunk strings. ``None`` or empty
            means nothing has been processed yet.
        num_records: number of chunks to sample. ``gr.Number`` may deliver a
            float, so it is coerced with ``int()``.

    Returns:
        ``(dataframe_update, csv_path)``. If there is nothing to sample, the
        count is invalid, or the Gretel calls fail, the update hides the table
        and ``csv_path`` is ``None``. The reason is logged.
    """
    import logging

    logger = logging.getLogger(__name__)

    all_chunks = [(doc, chunk) for doc, chunks in (chunks_dict or {}).items() for chunk in chunks]
    try:
        record_count = int(num_records)
    except (TypeError, ValueError, OverflowError):
        record_count = 0
    if not all_chunks or record_count <= 0:
        logger.warning("Nothing to generate: %d chunk(s), num_records=%r", len(all_chunks), num_records)
        return gr.update(value=None, visible=False), None

    INTRO_PROMPT = "From the source text below, create a dataset with the following columns:\n"
    COLUMN_DETAILS = (
        "* `topic`: Select a topic relevant for the given source text.\n"
        "* `user_profile`: The complexity level of the question and truth, categorized into beginner, intermediate, and expert.\n"
        "  - Beginner users are about building foundational knowledge about the product and ask about basic features, benefits, and uses of the product.\n"
        "  - Intermediate users have a deep understanding of the product and are focusing on practical applications, comparisons with other products, and intermediate-level features and benefits.\n"
        "  - Expert users demonstrate in-depth technical knowledge, strategic application, and advanced troubleshooting. This level is for those who need to know the product inside and out, possibly for roles in sales, technical support, or product development.\n"
        "* `question`: Ask a set of unique questions related to the topic that a user might ask. "
        "Questions should be relatively complex and specific enough to be addressed in a short answer.\n"
        "* `answer`: Respond to the question with a clear, textbook quality answer that provides relevant details to fully address the question.\n"
        "* `context`: Copy the exact sentence(s) from the source text and surrounding details from where the answer can be derived.\n"
    )
    PROMPT = INTRO_PROMPT + COLUMN_DETAILS

    GENERATE_PARAMS = {
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 40
    }

    EVAL_PROMPT = """
    Add the following columns to the provided table based on the generated text in the `answer` column:
    * conformance: A rating from 'Poor' to 'Excellent' indicating the conformance of the generated text to the requested format, tags, and descriptions provided.
    - 'Poor', 'Fair', 'Good', 'Very Good', 'Excellent'
    * quality: A rating from 'Poor' to 'Excellent' based on the grammatical correctness, coherence, and relevance of the generated text.
    - 'Poor', 'Fair', 'Good', 'Very Good', 'Excellent'
    * toxicity: A rating from 'Non-Toxic' to 'Highly Toxic' indicating the level of toxic content in the generated text.
    - 'Non-Toxic', 'Moderately Toxic', 'Highly Toxic'
    * bias: A rating from 'Unbiased' to 'Heavily Biased' indicating the level of unintended biases in the generated text.
    - 'Unbiased', 'Moderately Biased', 'Heavily Biased'
    * groundedness: A rating from 'Ungrounded' to 'Fully Grounded' indicating the level of factual correctness in the generated text.
    - 'Ungrounded', 'Moderately Grounded', 'Fully Grounded'
    """

    # Lower temperature / top_p for the evaluation pass than for generation.
    EVAL_GENERATE_PARAMS = {
        "temperature": 0.2,
        "top_p": 0.5,
        "top_k": 40
    }

    try:
        gretel = Gretel(api_key=api_key, validate=True, clear=True)
        navigator = gretel.factories.initialize_navigator_api("tabular", backend_model="gretelai/Mistral-7B-Instruct-v0.2/industry")

        # Sample with replacement in one call and build the seed table once,
        # instead of growing it with pd.concat per record (quadratic copying).
        df_in = pd.DataFrame(random.choices(all_chunks, k=record_count), columns=['document', 'text'])

        df = navigator.edit(PROMPT, seed_data=df_in, **GENERATE_PARAMS)
        # The source text is dropped before the evaluation pass.
        df = df.drop(columns=['text'])
        df = navigator.edit(EVAL_PROMPT, seed_data=df, **EVAL_GENERATE_PARAMS)
        df = df.rename(columns={
            "question": "synthetic_question",
            "answer": "synthetic_answer",
            "context": "original_context"
        })

        csv_file = os.path.join(output_dir, "synthetic_qa.csv")
        df.to_csv(csv_file, index=False)
    except Exception:
        # UI boundary: report failure to the caller as a hidden table, but keep
        # the traceback instead of silently discarding it.
        logger.exception("Synthetic record generation failed")
        return gr.update(value=None, visible=False), None

    return gr.update(value=df, visible=True), csv_file

def download_dataframe(df):
    """Save *df* as ``<output_dir>/synthetic_qa.csv`` and return that path.

    The DataFrame index is not written. An existing file at that path is
    overwritten.
    """
    target_path = os.path.join(output_dir, "synthetic_qa.csv")
    df.to_csv(path_or_buf=target_path, index=False)
    return target_path

# CSS for the logo banner: centers #logo-container as a full-width flexbox and
# sets `pointer-events: none` on the SVG so it does not react to clicks.
# This only deters right-click saving; the SVG markup is still in the page.
logo_css = """
<style>
#logo-container {
    display: flex;
    justify-content: center;
    width: 100%;
}
#logo-container svg {
    pointer-events: none; /* Disable pointer events on the SVG */
}
</style>
"""

# HTML for the logo banner: the `logo_css` <style> block followed by an inline
# 181x72 SVG logo inside #logo-container. This is an f-string, so any literal
# `{`/`}` added to the markup later must be doubled.
html_content = f"""
{logo_css}
<div id="logo-container">
    <svg width="181" height="72" viewBox="0 0 181 72" fill="none" xmlns="http://www.w3.org/2000/svg">
    <g clip-path="url(#clip0_849_78)">
    <path d="M53.4437 41.3178V53.5794H44.4782V18.8754H53.4437V27.0498C55.1339 21.1048 58.9552 18.1323 63.144 18.1323C65.3487 18.1323 67.2593 18.5782 68.8025 19.3956L67.2593 27.57C65.863 26.9011 64.0993 26.604 62.0417 26.604C56.3097 26.604 53.4437 31.5085 53.4437 41.3178Z" fill="#3C1AE6"/>
    <path fill-rule="evenodd" clip-rule="evenodd" d="M103.383 45.9252C100.444 51.573 94.1975 54.3226 87.3631 54.3226C82.366 54.3226 78.1773 52.6134 74.7234 49.2693C71.2694 45.8509 69.5793 41.4664 69.5793 36.1159C69.5793 30.7654 71.2694 26.4553 74.7234 23.1112C78.1773 19.7671 82.366 18.1323 87.3631 18.1323C92.3603 18.1323 96.4755 19.7671 99.7824 23.1112C103.089 26.4553 104.78 30.7654 104.78 36.1159C104.78 37.019 104.715 37.987 104.647 39.0198L104.633 39.2371H78.4712C79.0591 43.7701 82.8805 46.6684 87.951 46.6684C91.5519 46.6684 95.0058 45.1078 96.549 42.2097L103.383 45.9252ZM78.3978 33.0691H96.0346C95.3733 28.3875 91.9194 25.7122 87.4366 25.7122C82.66 25.7122 79.0591 28.5361 78.3978 33.0691Z" fill="#3C1AE6"/>
    <path d="M121.87 26.158V53.5794H112.979V26.158H106.732V18.8754H112.979V5.64777H121.87V18.8754H129.146V26.158H121.87Z" fill="#3C1AE6"/>
    <path fill-rule="evenodd" clip-rule="evenodd" d="M164.903 45.9252C161.963 51.573 155.716 54.3226 148.882 54.3226C143.885 54.3226 139.696 52.6134 136.242 49.2693C132.789 45.8509 131.098 41.4664 131.098 36.1159C131.098 30.7654 132.789 26.4553 136.242 23.1112C139.696 19.7671 143.885 18.1323 148.882 18.1323C153.879 18.1323 157.994 19.7671 161.301 23.1112C164.609 26.4553 166.299 30.7654 166.299 36.1159C166.299 37.0174 166.235 37.9834 166.167 39.0141L166.152 39.2371H139.99C140.578 43.7701 144.399 46.6684 149.47 46.6684C153.072 46.6684 156.525 45.1078 158.069 42.2097L164.903 45.9252ZM139.917 33.0691H157.554C156.893 28.3875 153.439 25.7122 148.956 25.7122C144.179 25.7122 140.578 28.5361 139.917 33.0691Z" fill="#3C1AE6"/>
    <path d="M180.597 0V53.5794H171.631V0H180.597Z" fill="#3C1AE6"/>
    <path d="M27.1716 19.3782C27.1716 14.947 30.7321 11.3548 35.1241 11.3548V19.3782C35.1764 19.3959 27.1716 19.3782 27.1716 19.3782Z" fill="#3C1AE6"/>
    <path d="M34.7984 54.5253C34.7984 64.11 27.2527 71.9206 17.8936 71.9206C8.62804 71.9206 1.13987 64.2655 0.991031 54.8122L0.988777 54.5253H8.94397C8.94397 59.7209 12.9737 63.8921 17.8936 63.8921C22.746 63.8921 26.7325 59.8342 26.8409 54.7381L26.8431 54.5253H34.7984Z" fill="#3C1AE6"/>
    <path fill-rule="evenodd" clip-rule="evenodd" d="M35.7872 36.4522C35.7872 26.4724 27.7758 18.3822 17.8936 18.3822C8.01121 18.3822 0 26.4724 0 36.4522C0 46.4322 8.01121 54.5224 17.8936 54.5224C27.7758 54.5224 35.7872 46.4322 35.7872 36.4522ZM8.61542 36.4522C8.61542 31.2775 12.7694 27.0826 17.8936 27.0826C23.0178 27.0826 27.1716 31.2775 27.1716 36.4522C27.1716 41.6271 23.0178 45.822 17.8936 45.822C12.7694 45.822 8.61542 41.6271 8.61542 36.4522Z" fill="#3C1AE6"/>
    </g>
    <defs>
    <clipPath id="clip0_849_78">
    <rect width="181" height="72" fill="white"/>
    </clipPath>
    </defs>
    </svg>
</div>
"""

# Custom CSS passed to gr.Blocks: renders <span> text inside the element with
# elem_id="small" (the generated-records Dataframe) at 0.8em.
css = """
#small span{
 font-size: 0.8em;
}
"""

# Gradio interface. The left column uploads and chunks documents and lets the
# user browse chunks. The right column validates the Gretel API key and
# generates the synthetic Q&A dataset.
with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column(scale=3):
            gr.HTML(html_content)

            with gr.Tab("Upload Files"):
                use_example = gr.Checkbox(label="Continue with Example PDF", value=False, interactive=True)
                uploaded_files = gr.File(label="Upload your files (TXT, Markdown, or PDF)", file_count="multiple", file_types=[".pdf", ".txt", ".md"])

            chunk_size = gr.Slider(label="Chunk Size (tokens)", minimum=10, maximum=1500, step=10, value=500)
            chunk_overlap = gr.Slider(label="Chunk Overlap (tokens)", minimum=0, maximum=500, step=5, value=100)
            min_chunk_chars = gr.Slider(label="Minimum Chunk Characters", minimum=10, maximum=2500, step=10, value=750)

            process_button = gr.Button("Process Files")

            # Per-session state shared by the processing and navigation handlers.
            chunks_dict = gr.State()
            selected_files = gr.State()
            current_chunk = gr.State(value=0)

            chunk_text = gr.Textbox(label="Chunk Text", lines=10)

            def toggle_use_example(file_list):
                """Uncheck the example box; allow checking it only while no files are uploaded."""
                return gr.update(
                    value=False,
                    interactive=not file_list
                )

            uploaded_files.change(
                toggle_use_example,
                inputs=[uploaded_files],
                outputs=[use_example]
            )

            process_button.click(
                process_files,
                # The trailing gr.State(0) is the `direction` argument:
                # processing keeps the current chunk index (clamped) rather than moving.
                inputs=[uploaded_files, use_example, chunk_size, chunk_overlap, min_chunk_chars, current_chunk, gr.State(0)],
                outputs=[chunks_dict, selected_files, chunk_text, current_chunk]
            )

            with gr.Row():
                prev_button = gr.Button("Previous Chunk", scale=1)
                next_button = gr.Button("Next Chunk", scale=1)

            prev_button.click(
                show_chunks,
                inputs=[chunks_dict, selected_files, current_chunk, gr.State(-1)],
                outputs=[chunk_text, current_chunk]
            )

            next_button.click(
                show_chunks,
                inputs=[chunks_dict, selected_files, current_chunk, gr.State(1)],
                outputs=[chunk_text, current_chunk]
            )

        with gr.Column(scale=7):
            gr.Markdown("# Generate Question-Answer pairs")

            with gr.Row():
                api_key_input = gr.Textbox(label="Gretel API Key (available at https://console.gretel.ai)", type="password", placeholder="Enter your API key", scale=2)
                validate_status = gr.Textbox(label="Validation Status", interactive=False, scale=1)

            # precision=0 makes Gradio deliver an int. With the default
            # (precision=None) the value arrives as a float, which range() rejects.
            num_records = gr.Number(label="Number of Records", value=10, precision=0)

            # Disabled until check_api_key reports a valid key.
            generate_button = gr.Button("Generate Synthetic Records", interactive=False)
            download_link = gr.File(label="Download Link", visible=False)

            api_key_input.change(
                fn=check_api_key,
                inputs=[api_key_input],
                outputs=[generate_button, validate_status]
            )

            output_df = gr.Dataframe(headers=["",], wrap=True, visible=True, elem_id="small")

            def generate_and_prepare_download(api_key, chunks_dict, num_records):
                """Run generation and show the download link only when a CSV was written."""
                table_update, csv_file = generate_synthetic_records(api_key, chunks_dict, num_records)
                # `table_update` is a gr.update(...) payload. Comparing its
                # DataFrame value with `!= None` is elementwise, so key the
                # link's visibility off the CSV path instead.
                return table_update, gr.update(value=csv_file, visible=csv_file is not None)

            generate_button.click(
                fn=generate_and_prepare_download,
                inputs=[api_key_input, chunks_dict, num_records],
                outputs=[output_df, download_link]
            )

demo.launch()