# Gradio app (app.py): chunk uploaded documents and generate synthetic Q&A records with Gretel.
# Revision 7d29c31 ("Update app.py") by mvansegbroeck.
import logging
import os
import random
import re

import fitz
import gradio as gr
import markdownify
import pandas as pd
import requests
from gretel_client import Gretel
from gretel_client.config import GretelClientConfigurationError
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Working directory for everything this app writes: the downloaded example
# PDF, per-document markdown, chunk files and the generated CSV. It is created
# at import time so later writes can assume it exists.
output_dir = "processed_files"
os.makedirs(name=output_dir, exist_ok=True)
def pdf_to_text(pdf_path):
    """Extract the text of every page of a local PDF file.

    Args:
        pdf_path: Path to a PDF file on disk.

    Returns:
        The text of all pages concatenated in page order, as produced by
        PyMuPDF's ``Page.get_text()``.
    """
    # The context manager closes the document (and its file handle) even if
    # extraction fails part-way through.
    with fitz.open(pdf_path) as pdf_document:
        return "".join(page.get_text() for page in pdf_document)
def txt_to_text(txt_path, encoding='utf-8'):
    """Return the full contents of a plain-text file.

    Args:
        txt_path: Path to the text file.
        encoding: Encoding used to decode the file. Defaults to UTF-8 so the
            result does not depend on the host's locale default encoding.

    Returns:
        The file contents as a single string.
    """
    with open(txt_path, 'r', encoding=encoding) as file:
        return file.read()
def markdown_to_text(md_path, encoding='utf-8'):
    """Return the raw contents of a Markdown file (no rendering is done).

    Args:
        md_path: Path to the Markdown file.
        encoding: Encoding used to decode the file. Defaults to UTF-8 so the
            result does not depend on the host's locale default encoding.

    Returns:
        The file contents as a single string.
    """
    with open(md_path, 'r', encoding=encoding) as file:
        return file.read()
def sanitize_key(filename, max_length=100, fallback='document'):
    """Turn a file name into a safe identifier for dict keys and output file names.

    Spaces become underscores, every character outside ``[A-Za-z0-9_]`` is
    removed, and the result is truncated to ``max_length`` characters.

    Args:
        filename: Raw name to sanitize (callers pass a base name without its
            extension).
        max_length: Maximum length of the returned key.
        fallback: Returned when sanitizing leaves nothing (e.g. a name made
            only of non-ASCII characters). This keeps callers from building
            paths such as ``".md"`` or ``"_chunk_1.md"`` from an empty key.

    Returns:
        The sanitized key, or ``fallback`` if it would otherwise be empty.
    """
    key = re.sub(r'[^a-zA-Z0-9_]', '', filename.replace(" ", "_"))
    return key[:max_length] or fallback
def split_text_into_chunks(text, chunk_size=25, chunk_overlap=5, min_chunk_chars=50):
    """Split text into token-sized chunks and drop fragments that are too short.

    Args:
        text: The text to split.
        chunk_size: Target chunk length, measured in tokens by the tiktoken
            encoder the splitter is built from.
        chunk_overlap: Number of tokens shared by consecutive chunks.
        min_chunk_chars: Chunks with fewer characters than this are discarded.

    Returns:
        The surviving chunks, in the order the splitter produced them.
    """
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    kept = []
    for piece in splitter.split_text(text):
        if len(piece) < min_chunk_chars:
            continue
        kept.append(piece)
    return kept
def save_chunks(file_id, chunks, output_dir, encoding='utf-8'):
    """Write each chunk to its own markdown file.

    Files are named ``{file_id}_chunk_{n}.md`` with ``n`` counting from 1.
    Existing files with the same name are overwritten.

    Args:
        file_id: Document key used as the file-name prefix.
        chunks: Chunk strings, written in iteration order.
        output_dir: Existing directory to write into.
        encoding: Encoding for the written files. Defaults to UTF-8 so text
            extracted from PDFs can be written regardless of the platform's
            locale encoding.
    """
    for i, chunk in enumerate(chunks, start=1):
        chunk_path = os.path.join(output_dir, f"{file_id}_chunk_{i}.md")
        with open(chunk_path, 'w', encoding=encoding) as file:
            file.write(chunk)
def read_chunks_from_files(output_dir, encoding='utf-8'):
    """Load chunk files from a directory, grouped by document key.

    Only files named ``<file_id>_chunk_<n>.md`` are read. Other ``.md`` files
    (such as ``<file_id>.md`` or a name that merely contains "chunk") are
    ignored. The key is everything before the *last* ``_chunk_<n>.md``
    suffix, so keys that themselves contain ``_chunk_`` stay intact.

    Args:
        output_dir: Directory to scan (not recursive).
        encoding: Encoding used to read the chunk files.

    Returns:
        Dict mapping each file_id to its chunk texts, ordered by chunk number.
        The ordering is numeric (1, 2, ..., 10), not lexicographic.
    """
    indexed = {}
    # Sorted for a deterministic key order; os.listdir order is arbitrary.
    for filename in sorted(os.listdir(output_dir)):
        # Greedy (.*) backtracks to the last "_chunk_<digits>.md" suffix.
        match = re.fullmatch(r'(.*)_chunk_(\d+)\.md', filename)
        if match is None:
            continue
        file_id, index = match.group(1), int(match.group(2))
        chunk_path = os.path.join(output_dir, filename)
        with open(chunk_path, 'r', encoding=encoding) as file:
            indexed.setdefault(file_id, []).append((index, file.read()))
    return {
        file_id: [chunk for _, chunk in sorted(pairs, key=lambda pair: pair[0])]
        for file_id, pairs in indexed.items()
    }
def _download_example_pdf(url, dest_dir, timeout=60):
    """Download ``url`` into ``dest_dir`` once (cached by basename) and return the local path."""
    file_path = os.path.join(dest_dir, url.split('/')[-1])
    if not os.path.exists(file_path):
        response = requests.get(url, timeout=timeout)
        # Fail on HTTP errors instead of caching an error page as the PDF.
        response.raise_for_status()
        # Write to a temporary name and rename into place only when complete,
        # so an interrupted download is never mistaken for a cached file.
        tmp_path = file_path + '.part'
        with open(tmp_path, 'wb') as file:
            file.write(response.content)
        os.replace(tmp_path, file_path)
    return file_path


def _uploaded_file_paths(uploaded_files):
    """Return local paths for upload values: objects exposing ``.name`` or plain path strings."""
    return [getattr(uploaded_file, 'name', uploaded_file) for uploaded_file in uploaded_files]


def _extract_text(file_path, file_extension):
    """Return the raw text of a .pdf/.txt/.md file, or "" for any other extension."""
    if file_extension == '.pdf':
        return pdf_to_text(file_path)
    if file_extension == '.txt':
        return txt_to_text(file_path)
    if file_extension == '.md':
        return markdown_to_text(file_path)
    return ""


def _clamp_index(index, length):
    """Clamp ``index`` into ``[0, length - 1]``; returns 0 when ``length`` is 0."""
    return max(0, min(index, length - 1))


def process_files(uploaded_files, use_example, chunk_size, chunk_overlap, min_chunk_chars, current_chunk, direction):
    """Convert the selected documents to markdown chunks and pick one to display.

    Args:
        uploaded_files: Values from the upload widget; may be None or empty.
        use_example: If true, process the example GDPR PDF instead of the
            uploads. It is downloaded on first use and cached in ``output_dir``.
        chunk_size: Target chunk size in tokens.
        chunk_overlap: Token overlap between consecutive chunks.
        min_chunk_chars: Chunks shorter than this many characters are dropped.
        current_chunk: Index of the chunk currently displayed (None -> 0).
        direction: Offset applied to ``current_chunk``; 0 keeps it in place.

    Returns:
        Tuple ``(chunks_dict, selected_files, chunk_text, current_chunk)``:
        - ``chunks_dict`` maps ``<file_id><ext>`` to that file's chunks.
        - ``selected_files`` lists the processed paths.
        - ``chunk_text`` is the chunk to display.
        - ``current_chunk`` is its index, always >= 0.
        With no input selected, returns ``({}, [], "No files processed", 0)``
        so the UI state keeps the types the other handlers expect.

    Side effects:
        Writes ``<file_id>.md`` and ``<file_id>_chunk_<n>.md`` into ``output_dir``.
    """
    if use_example:
        example_file_url = "https://gretel-datasets.s3.us-west-2.amazonaws.com/rag/GDPR_2016.pdf"
        selected_files = [_download_example_pdf(example_file_url, output_dir)]
    elif uploaded_files:
        selected_files = _uploaded_file_paths(uploaded_files)
    else:
        return {}, [], "No files processed", 0

    chunks_dict = {}
    for file_path in selected_files:
        file_extension = os.path.splitext(file_path)[1].lower()
        text = _extract_text(file_path, file_extension)
        # NOTE(review): markdownify is an HTML-to-Markdown converter. On plain
        # text or existing Markdown it may escape characters such as * and _;
        # confirm this is intended.
        markdown_text = markdownify.markdownify(text)
        file_id = sanitize_key(os.path.splitext(os.path.basename(file_path))[0])
        markdown_path = os.path.join(output_dir, f"{file_id}.md")
        with open(markdown_path, 'w', encoding='utf-8') as file:
            file.write(markdown_text)
        chunks = split_text_into_chunks(
            markdown_text,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            min_chunk_chars=min_chunk_chars,
        )
        save_chunks(file_id, chunks, output_dir)
        chunks_dict[file_id + file_extension] = chunks

    all_chunks = [chunk for chunks in chunks_dict.values() for chunk in chunks]
    current_chunk = _clamp_index((current_chunk or 0) + direction, len(all_chunks))
    chunk_text = all_chunks[current_chunk] if all_chunks else "No chunks available."
    return chunks_dict, selected_files, chunk_text, current_chunk
def show_chunks(chunks_dict, selected_files, current_chunk, direction):
    """Step through the processed chunks and return the one to display.

    Args:
        chunks_dict: Mapping of document key to chunk list, as returned by
            ``process_files``. ``None`` (nothing processed yet) is treated as
            empty.
        selected_files: Unused; accepted so the signature matches the UI
            event's inputs.
        current_chunk: Index currently shown; ``None`` is treated as 0.
        direction: Step to apply, e.g. -1 for previous and +1 for next.

    Returns:
        ``(chunk_text, current_chunk)``. The index is clamped to the valid
        range, and is 0 when there are no chunks.
    """
    all_chunks = [chunk for chunks in (chunks_dict or {}).values() for chunk in chunks]
    if not all_chunks:
        return "No chunks available.", 0
    current_chunk = max(0, min((current_chunk or 0) + direction, len(all_chunks) - 1))
    return all_chunks[current_chunk], current_chunk
def check_api_key(api_key):
    """Validate a Gretel API key and enable or disable the generate button.

    Constructs a ``Gretel`` client with ``validate=True`` and ``clear=True``.
    NOTE(review): validation presumably contacts the Gretel service; confirm.

    Returns:
        ``(button_update, status)``: a ``gr.update`` that makes the button
        interactive only for a valid key, and the text "Valid" or "Invalid".

    Raises:
        Any exception other than ``GretelClientConfigurationError`` raised
        while constructing the client propagates unchanged.
    """
    try:
        Gretel(api_key=api_key, validate=True, clear=True)
    except GretelClientConfigurationError:
        return gr.update(interactive=False), "Invalid"
    return gr.update(interactive=True), "Valid"
def generate_synthetic_records(api_key, chunks_dict, num_records):
    """Generate synthetic Q&A records from document chunks with Gretel Navigator.

    ``num_records`` (document, chunk) pairs are sampled with replacement from
    ``chunks_dict``. Navigator (backend ``gretelai/Mistral-7B-Instruct-v0.2/industry``)
    is asked to produce topic / user_profile / question / answer / context
    columns. The seed ``text`` column is then dropped, and a second,
    lower-temperature pass is asked to add rating columns (conformance,
    quality, toxicity, bias, groundedness). Finally question / answer /
    context are renamed to ``synthetic_question`` / ``synthetic_answer`` /
    ``original_context``.

    Args:
        api_key: Gretel API key.
        chunks_dict: Mapping of document key to list of chunk strings, as
            produced by ``process_files``. ``None`` or empty means nothing to
            sample.
        num_records: Number of records to generate. Floats such as ``10.0``
            are truncated to int.

    Returns:
        ``(table_update, csv_path)``.
        - On success: a ``gr.update`` showing the result DataFrame, and the
          path of the CSV written to ``<output_dir>/synthetic_qa.csv``.
        - When there are no chunks, ``num_records`` is non-positive or
          non-numeric, or generation raises: a hidden, empty update and
          ``None``. Generation errors are logged with their traceback.

    Raises:
        Errors from constructing the ``Gretel`` client or initializing the
        navigator are not caught here.
    """
    failure = (gr.update(value=None, visible=False), None)

    # gr.Number can deliver floats (e.g. 10.0) or None; sampling needs an int.
    try:
        record_count = int(num_records)
    except (TypeError, ValueError, OverflowError):
        return failure
    seed_pool = [
        (doc, chunk)
        for doc, chunks in (chunks_dict or {}).items()
        for chunk in chunks
    ]
    if record_count <= 0 or not seed_pool:
        return failure

    gretel = Gretel(api_key=api_key, validate=True, clear=True)
    navigator = gretel.factories.initialize_navigator_api("tabular", backend_model="gretelai/Mistral-7B-Instruct-v0.2/industry")
    INTRO_PROMPT = "From the source text below, create a dataset with the following columns:\n"
    COLUMN_DETAILS = (
        "* `topic`: Select a topic relevant for the given source text.\n"
        "* `user_profile`: The complexity level of the question and truth, categorized into beginner, intermediate, and expert.\n"
        " - Beginner users are about building foundational knowledge about the product and ask about basic features, benefits, and uses of the product.\n"
        " - Intermediate users have a deep understanding of the product and are focusing on practical applications, comparisons with other products, and intermediate-level features and benefits.\n"
        " - Expert users demonstrate in-depth technical knowledge, strategic application, and advanced troubleshooting. This level is for those who need to know the product inside and out, possibly for roles in sales, technical support, or product development.\n"
        "* `question`: Ask a set of unique questions related to the topic that a user might ask. "
        "Questions should be relatively complex and specific enough to be addressed in a short answer.\n"
        "* `answer`: Respond to the question with a clear, textbook quality answer that provides relevant details to fully address the question.\n"
        "* `context`: Copy the exact sentence(s) from the source text and surrounding details from where the answer can be derived.\n"
    )
    PROMPT = INTRO_PROMPT + COLUMN_DETAILS
    GENERATE_PARAMS = {
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 40
    }
    EVAL_PROMPT = """
Add the following columns to the provided table based on the generated text in the `answer` column:
* conformance: A rating from 'Poor' to 'Excellent' indicating the conformance of the generated text to the requested format, tags, and descriptions provided.
- 'Poor', 'Fair', 'Good', 'Very Good', 'Excellent'
* quality: A rating from 'Poor' to 'Excellent' based on the grammatical correctness, coherence, and relevance of the generated text.
- 'Poor', 'Fair', 'Good', 'Very Good', 'Excellent'
* toxicity: A rating from 'Non-Toxic' to 'Highly Toxic' indicating the level of toxic content in the generated text.
- 'Non-Toxic', 'Moderately Toxic', 'Highly Toxic'
* bias: A rating from 'Unbiased' to 'Heavily Biased' indicating the level of unintended biases in the generated text.
- 'Unbiased', 'Moderately Biased', 'Heavily Biased'
* groundedness: A rating from 'Ungrounded' to 'Fully Grounded' indicating the level of factual correctness in the generated text.
- 'Ungrounded', 'Moderately Grounded', 'Fully Grounded'
"""
    # Lower temperature / top_p for the evaluation pass.
    EVAL_GENERATE_PARAMS = {
        "temperature": 0.2,
        "top_p": 0.5,
        "top_k": 40
    }
    try:
        # Sample all seeds at once (with replacement) and build the frame in a
        # single step, instead of concatenating one row per iteration.
        df_in = pd.DataFrame(random.choices(seed_pool, k=record_count), columns=['document', 'text'])
        df = navigator.edit(PROMPT, seed_data=df_in, **GENERATE_PARAMS)
        df = df.drop(columns=['text'])
        df = navigator.edit(EVAL_PROMPT, seed_data=df, **EVAL_GENERATE_PARAMS)
        df = df.rename(columns={
            "question": "synthetic_question",
            "answer": "synthetic_answer",
            "context": "original_context"
        })
        csv_file = os.path.join(output_dir, "synthetic_qa.csv")
        df.to_csv(csv_file, index=False)
        return gr.update(value=df, visible=True), csv_file
    except Exception:
        # UI boundary: keep the app responsive but record why generation failed.
        logging.getLogger(__name__).exception("Synthetic record generation failed")
        return failure
def download_dataframe(df):
    """Save ``df`` as ``synthetic_qa.csv`` inside ``output_dir`` and return that path.

    The row index is not written, and any existing file at that path is
    overwritten.
    """
    csv_path = os.path.join(output_dir, "synthetic_qa.csv")
    df.to_csv(path_or_buf=csv_path, index=False)
    return csv_path
# CSS emitted ahead of the logo markup. It centers the logo container with
# flexbox and sets `pointer-events: none` on the SVG, which is intended to
# keep the logo from reacting to clicks (e.g. discouraging right-click saving).
logo_css = """
<style>
#logo-container {
display: flex;
justify-content: center;
width: 100%;
}
#logo-container svg {
pointer-events: none; /* Disable pointer events on the SVG */
}
</style>
"""
# Logo header rendered with gr.HTML in the left column: the `logo_css` style
# block interpolated by this f-string, followed by an inline SVG wordmark
# inside the #logo-container div.
html_content = f"""
{logo_css}
<div id="logo-container">
<svg width="181" height="72" viewBox="0 0 181 72" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_849_78)">
<path d="M53.4437 41.3178V53.5794H44.4782V18.8754H53.4437V27.0498C55.1339 21.1048 58.9552 18.1323 63.144 18.1323C65.3487 18.1323 67.2593 18.5782 68.8025 19.3956L67.2593 27.57C65.863 26.9011 64.0993 26.604 62.0417 26.604C56.3097 26.604 53.4437 31.5085 53.4437 41.3178Z" fill="#3C1AE6"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M103.383 45.9252C100.444 51.573 94.1975 54.3226 87.3631 54.3226C82.366 54.3226 78.1773 52.6134 74.7234 49.2693C71.2694 45.8509 69.5793 41.4664 69.5793 36.1159C69.5793 30.7654 71.2694 26.4553 74.7234 23.1112C78.1773 19.7671 82.366 18.1323 87.3631 18.1323C92.3603 18.1323 96.4755 19.7671 99.7824 23.1112C103.089 26.4553 104.78 30.7654 104.78 36.1159C104.78 37.019 104.715 37.987 104.647 39.0198L104.633 39.2371H78.4712C79.0591 43.7701 82.8805 46.6684 87.951 46.6684C91.5519 46.6684 95.0058 45.1078 96.549 42.2097L103.383 45.9252ZM78.3978 33.0691H96.0346C95.3733 28.3875 91.9194 25.7122 87.4366 25.7122C82.66 25.7122 79.0591 28.5361 78.3978 33.0691Z" fill="#3C1AE6"/>
<path d="M121.87 26.158V53.5794H112.979V26.158H106.732V18.8754H112.979V5.64777H121.87V18.8754H129.146V26.158H121.87Z" fill="#3C1AE6"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M164.903 45.9252C161.963 51.573 155.716 54.3226 148.882 54.3226C143.885 54.3226 139.696 52.6134 136.242 49.2693C132.789 45.8509 131.098 41.4664 131.098 36.1159C131.098 30.7654 132.789 26.4553 136.242 23.1112C139.696 19.7671 143.885 18.1323 148.882 18.1323C153.879 18.1323 157.994 19.7671 161.301 23.1112C164.609 26.4553 166.299 30.7654 166.299 36.1159C166.299 37.0174 166.235 37.9834 166.167 39.0141L166.152 39.2371H139.99C140.578 43.7701 144.399 46.6684 149.47 46.6684C153.072 46.6684 156.525 45.1078 158.069 42.2097L164.903 45.9252ZM139.917 33.0691H157.554C156.893 28.3875 153.439 25.7122 148.956 25.7122C144.179 25.7122 140.578 28.5361 139.917 33.0691Z" fill="#3C1AE6"/>
<path d="M180.597 0V53.5794H171.631V0H180.597Z" fill="#3C1AE6"/>
<path d="M27.1716 19.3782C27.1716 14.947 30.7321 11.3548 35.1241 11.3548V19.3782C35.1764 19.3959 27.1716 19.3782 27.1716 19.3782Z" fill="#3C1AE6"/>
<path d="M34.7984 54.5253C34.7984 64.11 27.2527 71.9206 17.8936 71.9206C8.62804 71.9206 1.13987 64.2655 0.991031 54.8122L0.988777 54.5253H8.94397C8.94397 59.7209 12.9737 63.8921 17.8936 63.8921C22.746 63.8921 26.7325 59.8342 26.8409 54.7381L26.8431 54.5253H34.7984Z" fill="#3C1AE6"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M35.7872 36.4522C35.7872 26.4724 27.7758 18.3822 17.8936 18.3822C8.01121 18.3822 0 26.4724 0 36.4522C0 46.4322 8.01121 54.5224 17.8936 54.5224C27.7758 54.5224 35.7872 46.4322 35.7872 36.4522ZM8.61542 36.4522C8.61542 31.2775 12.7694 27.0826 17.8936 27.0826C23.0178 27.0826 27.1716 31.2775 27.1716 36.4522C27.1716 41.6271 23.0178 45.822 17.8936 45.822C12.7694 45.822 8.61542 41.6271 8.61542 36.4522Z" fill="#3C1AE6"/>
</g>
<defs>
<clipPath id="clip0_849_78">
<rect width="181" height="72" fill="white"/>
</clipPath>
</defs>
</svg>
</div>
"""
# Custom CSS passed to gr.Blocks. It sets the font size of <span> elements
# inside the component with elem_id="small" (the output Dataframe) to 0.8em.
css = """
#small span{
font-size: 0.8em;
}
"""
# Gradio interface. The left column uploads and chunks documents; the right
# column validates the Gretel API key and generates synthetic Q&A records.
with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column(scale=3):
            gr.HTML(html_content)
            with gr.Tab("Upload Files"):
                use_example = gr.Checkbox(label="Continue with Example PDF", value=False, interactive=True)
                uploaded_files = gr.File(label="Upload your files (TXT, Markdown, or PDF)", file_count="multiple", file_types=[".pdf", ".txt", ".md"])
                chunk_size = gr.Slider(label="Chunk Size (tokens)", minimum=10, maximum=1500, step=10, value=500)
                chunk_overlap = gr.Slider(label="Chunk Overlap (tokens)", minimum=0, maximum=500, step=5, value=100)
                min_chunk_chars = gr.Slider(label="Minimum Chunk Characters", minimum=10, maximum=2500, step=10, value=750)
                process_button = gr.Button("Process Files")
                # Per-session state shared by the event handlers below.
                chunks_dict = gr.State()
                selected_files = gr.State()
                current_chunk = gr.State(value=0)
                chunk_text = gr.Textbox(label="Chunk Text", lines=10)

                def toggle_use_example(file_list):
                    """Uncheck the example box, and allow checking it only while no files are uploaded."""
                    return gr.update(
                        value=False,
                        interactive=not file_list,
                    )

                uploaded_files.change(
                    toggle_use_example,
                    inputs=[uploaded_files],
                    outputs=[use_example]
                )
                # direction=0: (re)process without moving the chunk cursor.
                process_button.click(
                    process_files,
                    inputs=[uploaded_files, use_example, chunk_size, chunk_overlap, min_chunk_chars, current_chunk, gr.State(0)],
                    outputs=[chunks_dict, selected_files, chunk_text, current_chunk]
                )
                with gr.Row():
                    prev_button = gr.Button("Previous Chunk", scale=1)
                    next_button = gr.Button("Next Chunk", scale=1)
                prev_button.click(
                    show_chunks,
                    inputs=[chunks_dict, selected_files, current_chunk, gr.State(-1)],
                    outputs=[chunk_text, current_chunk]
                )
                next_button.click(
                    show_chunks,
                    inputs=[chunks_dict, selected_files, current_chunk, gr.State(1)],
                    outputs=[chunk_text, current_chunk]
                )
        with gr.Column(scale=7):
            gr.Markdown("# Generate Question-Answer pairs")
            with gr.Row():
                api_key_input = gr.Textbox(label="Gretel API Key (available at https://console.gretel.ai)", type="password", placeholder="Enter your API key", scale=2)
                validate_status = gr.Textbox(label="Validation Status", interactive=False, scale=1)
            # precision=0 makes the component deliver an int record count.
            num_records = gr.Number(label="Number of Records", value=10, precision=0)
            generate_button = gr.Button("Generate Synthetic Records", interactive=False)
            download_link = gr.File(label="Download Link", visible=False)
            api_key_input.change(
                fn=check_api_key,
                inputs=[api_key_input],
                outputs=[generate_button, validate_status]
            )
            output_df = gr.Dataframe(headers=["",], wrap=True, visible=True, elem_id="small")

            def generate_and_prepare_download(api_key, chunks_dict, num_records):
                """Run generation and show the download link only when a CSV was written."""
                table_update, csv_file = generate_synthetic_records(api_key, chunks_dict, num_records)
                # Decide visibility from the CSV path. `DataFrame != None` is an
                # element-wise DataFrame, not a bool, so it cannot serve as `visible`.
                return table_update, gr.update(value=csv_file, visible=csv_file is not None)

            generate_button.click(
                fn=generate_and_prepare_download,
                inputs=[api_key_input, chunks_dict, num_records],
                outputs=[output_df, download_link]
            )

if __name__ == "__main__":
    demo.launch()