jeffrey
add data creation studio waitlist button
6a7f3cc
import os
import tempfile
from typing import List, Callable
import gradio as gr
import pandas as pd
from autorag.data.parse import langchain_parse
from autorag.data.parse.base import _add_last_modified_datetime
from autorag.data.parse.llamaparse import llama_parse
from autorag.data.qa.schema import Raw
from autorag.utils import result_to_dataframe
from llama_index.llms.openai import OpenAI
from src.create import default_create, fast_create, advanced_create
from src.util import on_submit_openai_key, on_submit_llama_cloud_key, on_submit_upstage_key
@result_to_dataframe(["texts", "path", "page", "last_modified_datetime"])
def original_parse(fn: Callable, **kwargs):
    """Call an undecorated parse function and return its result as a DataFrame.

    Callers in this file pass the ``__wrapped__`` (undecorated) version of an
    AutoRAG parse function as ``fn``; ``kwargs`` are forwarded to it unchanged.
    The raw result is passed through ``_add_last_modified_datetime``
    (presumably adding the fourth, last-modified column — confirm in AutoRAG),
    and ``result_to_dataframe`` converts it into a DataFrame with the columns
    ``texts``, ``path``, ``page`` and ``last_modified_datetime``.
    """
    return _add_last_modified_datetime(fn(**kwargs))
def change_lang_choice(lang: str) -> str:
    """Translate a UI language label into its two-letter language code.

    Supported labels are "English" -> "en", "한국어" -> "ko" and
    "日本語" -> "ja". Any other label raises ``KeyError``.
    """
    label_to_code = dict((
        ("English", "en"),
        ("한국어", "ko"),
        ("日本語", "ja"),
    ))
    return label_to_code[lang]
def change_visible_status_api_key(parse_method: str):
    """Show only the API-key row required by the chosen parse method.

    Returns a pair of ``gr.update`` objects. The first targets the Llama Cloud
    key row and the second targets the Upstage key row. A row is visible only
    when its parser is selected; every other method hides both rows.
    """
    show_llama_cloud_row = parse_method == "llama-parse"
    show_upstage_row = parse_method == "upstage🇰🇷"
    return gr.update(visible=show_llama_cloud_row), gr.update(visible=show_upstage_row)
def run_parse(file_lists: List[str], parse_method: str, original_raw_df, progress=gr.Progress()):
    """Parse the uploaded documents into a raw DataFrame.

    Args:
        file_lists: Paths of the uploaded files (``gr.File`` with
            ``type="filepath"``). This is ``None`` or empty when nothing has
            been uploaded.
        parse_method: Parser label chosen in the "Parsing Method" dropdown.
        original_raw_df: The current raw state. It is returned unchanged on
            any validation failure, so a previous result is not lost.
        progress: Gradio progress tracker.

    Returns:
        ``(status_message, raw_df)`` for the parse status box and the raw state.
    """
    if not file_lists:
        return "Please upload files first.", original_raw_df
    if not parse_method:
        return "Please select a parse method first.", original_raw_df

    progress(0.05)
    # Call the undecorated parse functions via ``__wrapped__``.
    # ``original_parse`` then builds the DataFrame itself.
    langchain_parse_original = langchain_parse.__wrapped__
    raw_df: pd.DataFrame
    if parse_method in ("pdfminer", "pdfplumber", "pypdfium2", "pypdf", "pymupdf"):
        raw_df = original_parse(langchain_parse_original,
                                data_path_list=file_lists, parse_method=parse_method)
    elif parse_method == "llama-parse":
        # An empty string is no more usable than a missing key.
        if not os.getenv("LLAMA_CLOUD_API_KEY"):
            return "Please submit your Llama Cloud API key first.", original_raw_df
        raw_df = original_parse(llama_parse.__wrapped__, data_path_list=file_lists)
    elif parse_method == "upstage🇰🇷":
        if not os.getenv("UPSTAGE_API_KEY"):
            return "Please submit your Upstage API key first.", original_raw_df
        raw_df = original_parse(langchain_parse_original,
                                data_path_list=file_lists, parse_method="upstagedocumentparse")
    else:
        return "Unsupported parse method.", original_raw_df
    progress(0.8)
    return "Parsing Complete. Download at the bottom button.", raw_df
def run_chunk(use_existed_raw: bool, raw_df: pd.DataFrame, raw_file: str, chunk_method: str, chunk_size: int,
              chunk_overlap: int, lang: str = "English", original_corpus_df=None, progress=gr.Progress()):
    """Chunk a raw DataFrame into a corpus.

    Args:
        use_existed_raw: If True, chunk ``raw_df`` (the result of the parse
            step). Otherwise read the uploaded ``raw_file``.
        raw_df: Raw DataFrame from the parse step. This is ``None`` until
            parsing has run.
        raw_file: Path of an uploaded raw.parquet. This is ``None`` if nothing
            was uploaded.
        chunk_method: One of "Token", "Sentence", "Semantic", "Recursive".
        chunk_size: Target chunk size (not used by "Semantic").
        chunk_overlap: Overlap between chunks (not used by "Semantic").
        lang: UI language label, mapped by ``change_lang_choice`` and passed as
            ``add_file_name``.
        original_corpus_df: The current corpus state. It is returned unchanged
            on any validation failure.
        progress: Gradio progress tracker.

    Returns:
        ``(status_message, corpus_df)`` for the chunk status box and corpus state.
    """
    lang = change_lang_choice(lang)
    if use_existed_raw:
        if raw_df is None:
            return "Please run parsing first, or upload your own raw.parquet.", original_corpus_df
    else:
        if not raw_file:
            return "Please upload raw.parquet first.", original_corpus_df
        raw_df = pd.read_parquet(raw_file, engine="pyarrow")

    # The UI sliders allow an overlap (up to 256) larger than the chunk size
    # (down to 128). The size-based splitters reject that, so report it here
    # instead of crashing.
    if chunk_method in ("Token", "Sentence", "Recursive") and chunk_overlap > chunk_size:
        return "Chunk overlap must not be larger than chunk size.", original_corpus_df
    # Semantic chunking embeds with the "openai" model. This presumably needs the
    # same OPENAI_API_KEY that run_qa checks — confirm against AutoRAG.
    if chunk_method == "Semantic" and not os.getenv("OPENAI_API_KEY"):
        return "Please submit your OpenAI API key first.", original_corpus_df

    raw_instance = Raw(raw_df)
    if chunk_method in ("Token", "Sentence"):
        corpus = raw_instance.chunk("llama_index_chunk", chunk_method=chunk_method, chunk_size=chunk_size,
                                    chunk_overlap=chunk_overlap, add_file_name=lang)
    elif chunk_method == "Semantic":
        # llama_index's SemanticSplitterNodeParser expects this threshold as a
        # percentile in [0, 100] (its default is 95). A value like 0.95 would
        # split at almost every sentence boundary.
        corpus = raw_instance.chunk("llama_index_chunk", chunk_method="Semantic_llama_index",
                                    embed_model="openai", breakpoint_percentile_threshold=95,
                                    add_file_name=lang)
    elif chunk_method == "Recursive":
        corpus = raw_instance.chunk("langchain_chunk", chunk_method="recursivecharacter",
                                    add_file_name=lang, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    else:
        # gr.Error only has an effect when raised; gr.Warning shows a
        # notification and still lets the status message below be returned.
        gr.Warning("Unsupported chunk method.")
        return "Unsupported chunk method.", original_corpus_df
    progress(0.8)
    return "Chunking Complete. Download at the bottom button.", corpus.data
def run_qa(use_existed_corpus: bool, corpus_df: pd.DataFrame, corpus_file: str, qa_method: str,
           model_name: str, qa_cnt: int, batch_size: int, lang: str = "English", original_qa_df=None,
           progress=gr.Progress()):
    """Generate a QA dataset from a corpus with an OpenAI model.

    Args:
        use_existed_corpus: If True, use ``corpus_df`` (the result of the chunk
            step). Otherwise read the uploaded ``corpus_file``.
        corpus_df: Corpus DataFrame from the chunk step. This is ``None`` until
            chunking has run.
        corpus_file: Path of an uploaded corpus.parquet. This is ``None`` if
            nothing was uploaded.
        qa_method: One of "default", "fast", "advanced".
        model_name: OpenAI model name. This is ``None`` until one is selected.
        qa_cnt: Number of QA pairs to request (passed as ``n``).
        batch_size: Batch size forwarded to the creation function.
        lang: UI language label, mapped by ``change_lang_choice``.
        original_qa_df: The current QA state. It is returned unchanged on any
            validation failure.
        progress: Gradio progress tracker, forwarded to the creation function.

    Returns:
        ``(status_message, qa_df)`` for the QA status box and QA state.
    """
    lang = change_lang_choice(lang)
    # Run the cheap checks first, so a file is read only when creation can proceed.
    # gr.Error only has an effect when raised; gr.Warning shows a
    # notification while still returning the status message.
    if not os.getenv("OPENAI_API_KEY"):
        gr.Warning("Please submit your OpenAI API key first.")
        return "Please submit your OpenAI API key first.", original_qa_df
    if model_name is None:
        gr.Warning("Please select a model first.")
        return "Please select a model first.", original_qa_df
    creators = {"default": default_create, "fast": fast_create, "advanced": advanced_create}
    create_fn = creators.get(qa_method)
    if create_fn is None:
        gr.Warning("Unsupported QA method.")
        return "Unsupported QA method.", original_qa_df

    if use_existed_corpus:
        if corpus_df is None:
            return "Please run chunking first, or upload your own corpus.parquet.", original_qa_df
    else:
        if not corpus_file:
            return "Please upload corpus.parquet first.", original_qa_df
        corpus_df = pd.read_parquet(corpus_file, engine="pyarrow")

    llm = OpenAI(model=model_name)
    qa = create_fn(corpus_df, llm=llm, n=qa_cnt, lang=lang, progress=progress, batch_size=batch_size)
    return "QA Creation Complete. Download at the bottom button.", qa.data
def download_state(state: pd.DataFrame, change_name: str):
    """Yield a path to ``state`` written as ``<change_name>.parquet``.

    This is used as the callable ``value`` of the ``gr.DownloadButton``
    components, with a state and a fixed name (e.g. "raw") as ``inputs``.

    Args:
        state: DataFrame to export. This is ``None`` before the matching step
            has produced anything (the ``gr.State`` objects start empty).
        change_name: Base file name, without the ``.parquet`` extension.

    Yields:
        The parquet file path, or ``None`` when there is nothing to export.
    """
    if state is None:
        # This is a generator, so ``return <value>`` would never reach the
        # caller. Yield None so the button explicitly has no file.
        yield None
        return
    with tempfile.TemporaryDirectory() as temp_dir:
        filename = os.path.join(temp_dir, f"{change_name}.parquet")
        state.to_parquet(filename, engine="pyarrow")
        # Yield inside the ``with`` block so the temporary directory still
        # exists while the consumer handles the path. It is removed only when
        # the generator is resumed or closed.
        yield filename
# Gradio UI: API-key rows on top, then three columns for the
# parse -> chunk -> QA pipeline. Intermediate DataFrames are kept in
# per-session gr.State objects, so each step can reuse the previous result.
with gr.Blocks(theme="earneleh/paris") as demo:
    raw_df_state = gr.State()
    corpus_df_state = gr.State()
    qa_df_state = gr.State()
    gr.HTML("<h1>AutoRAG Data Creation 🛠️</h1>")
    with gr.Row():
        openai_key_textbox = gr.Textbox(label="Please input your OpenAI API key and press Enter.", type="password",
                                        info="You can get your API key from https://platform.openai.com/account/api-keys\n\n"
                                             "AutoRAG does not store your API key.",
                                        autofocus=True)
        api_key_status_box = gr.Textbox(label="OpenAI API status", value="Not Set", interactive=False)
        lang_choice = gr.Radio(["English", "한국어", "日本語"], label="Language",
                               value="English", info="Choose language. En, Ko, Ja are supported.",
                               interactive=True)
    # Hidden until a parser that needs the key is selected
    # (see change_visible_status_api_key).
    with gr.Row(visible=False) as llama_cloud_api_key_row:
        llama_key_textbox = gr.Textbox(label="Please input your Llama Cloud API key and press Enter.", type="password",
                                       info="You can get your API key from https://docs.cloud.llamaindex.ai/llamacloud/getting_started/api_key\n\n"
                                            "AutoRAG does not store your API key.",)
        llama_key_status_box = gr.Textbox(label="Llama Cloud API status", value="Not Set", interactive=False)
    with gr.Row(visible=False) as upstage_api_key_row:
        upstage_key_textbox = gr.Textbox(label="Please input your Upstage API key and press Enter.", type="password",
                                         info="You can get your API key from https://upstage.ai/\n\n"
                                              "AutoRAG does not store your API key.",)
        upstage_key_status_box = gr.Textbox(label="Upstage API status", value="Not Set", interactive=False)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 1. Parse your PDF files\n\nUpload your PDF files and convert them into raw.parquet.")
            document_file_input = gr.File(label="Upload Files", type="filepath", file_count="multiple")
            parse_choice = gr.Dropdown(
                ["pdfminer", "pdfplumber", "pypdfium2", "pypdf", "pymupdf", "llama-parse", "upstage🇰🇷"],
                label="Parsing Method", info="Choose parsing method that you want")
            parse_button = gr.Button(value="Run Parsing")
            parse_status = gr.Textbox(value="Not Started", interactive=False)
            raw_download_button = gr.DownloadButton(value=download_state, inputs=[raw_df_state, gr.State("raw")],
                                                    label="Download raw.parquet")
        with gr.Column(scale=1):
            gr.Markdown(
                "## 2. Chunk your raw.parquet\n\nUse parsed raw.parquet or upload your own. It will make a corpus.parquet."
            )
            # Shown only when "Use previous raw.parquet" is unchecked.
            raw_file_input = gr.File(label="Upload raw.parquet", type="filepath", file_count="single", visible=False)
            use_previous_raw_file = gr.Checkbox(label="Use previous raw.parquet", value=True)
            chunk_choice = gr.Dropdown(
                ["Token", "Sentence", "Semantic", "Recursive"],
                label="Chunking Method", info="Choose chunking method that you want")
            chunk_size = gr.Slider(minimum=128, maximum=1024, step=128, label="Chunk Size", value=256)
            chunk_overlap = gr.Slider(minimum=16, maximum=256, step=16, label="Chunk Overlap", value=32)
            chunk_button = gr.Button(value="Run Chunking")
            chunk_status = gr.Textbox(value="Not Started", interactive=False)
            corpus_download_button = gr.DownloadButton(label="Download corpus.parquet",
                                                       value=download_state, inputs=[corpus_df_state, gr.State("corpus")])
        with gr.Column(scale=1):
            gr.Markdown(
                "## 3. Create QA dataset from your corpus.parquet\n\nQA dataset is essential to run AutoRAG. Upload corpus.parquet & select QA method and run.")
            gr.HTML("<b style='color: red; background-color: black; font-weight: bold;'>Warning: QA Creation uses an OpenAI model, which can be costly. Start with a small batch to gauge expenses.</b>")
            # Shown only when "Use previous corpus.parquet" is unchecked.
            corpus_file_input = gr.File(label="Upload corpus.parquet", type="filepath", file_count="single",
                                        visible=False)
            use_previous_corpus_file = gr.Checkbox(label="Use previous corpus.parquet", value=True)
            qa_choice = gr.Radio(["default", "fast", "advanced"], label="QA Method",
                                 info="Choose QA method that you want")
            model_choice = gr.Radio(["gpt-4o-mini", "gpt-4o"], label="Select model for data creation",
                                    )
            qa_cnt = gr.Slider(minimum=20, maximum=150, step=5, label="Number of QA pairs", value=80)
            batch_size = gr.Slider(minimum=1, maximum=16, step=1,
                                   label="Batch Size to OpenAI model. If there is an error, decrease this.", value=16)
            run_qa_button = gr.Button(value="Run QA Creation")
            qa_status = gr.Textbox(value="Not Started", interactive=False)
            gr.Markdown("### Do you want to customize your QA dataset? Join a waitlist for AutoRAG data creation studio.")
            gr.Button("Join Data Creation Studio Waitlist", link="https://tally.so/r/wdDo6N")
            qa_download_button = gr.DownloadButton(label="Download qa.parquet",
                                                   value=download_state, inputs=[qa_df_state, gr.State("qa")])
    #================================================================================================#
    # Logics (event listeners are registered inside the Blocks context)
    use_previous_raw_file.change(lambda x: gr.update(visible=not x), inputs=[use_previous_raw_file],
                                 outputs=[raw_file_input])
    use_previous_corpus_file.change(lambda x: gr.update(visible=not x), inputs=[use_previous_corpus_file],
                                    outputs=[corpus_file_input])
    openai_key_textbox.submit(on_submit_openai_key, inputs=[openai_key_textbox], outputs=api_key_status_box)
    # Parsing
    parse_button.click(run_parse, inputs=[document_file_input, parse_choice, raw_df_state],
                       outputs=[parse_status, raw_df_state])
    # Chunking
    chunk_button.click(run_chunk, inputs=[use_previous_raw_file, raw_df_state, raw_file_input, chunk_choice, chunk_size, chunk_overlap,
                                          lang_choice, corpus_df_state],
                       outputs=[chunk_status, corpus_df_state])
    # QA Creation
    run_qa_button.click(run_qa, inputs=[use_previous_corpus_file, corpus_df_state, corpus_file_input, qa_choice,
                                        model_choice, qa_cnt, batch_size, lang_choice,
                                        qa_df_state],
                        outputs=[qa_status, qa_df_state])
    # API Key visibility
    parse_choice.change(change_visible_status_api_key, inputs=[parse_choice],
                        outputs=[llama_cloud_api_key_row, upstage_api_key_row])
    llama_key_textbox.submit(on_submit_llama_cloud_key, inputs=[llama_key_textbox], outputs=llama_key_status_box)
    upstage_key_textbox.submit(on_submit_upstage_key, inputs=[upstage_key_textbox], outputs=upstage_key_status_box)
# Start the Gradio server. There is no ``if __name__ == "__main__":`` guard,
# so the server also starts whenever this module is imported.
demo.launch(debug=False, share=False)