asoria's picture
asoria HF staff
Minor details for RAG
117da13
raw
history blame
10.2 kB
import gradio as gr
from gradio_huggingfacehub_search import HuggingfaceHubSearch
import nbformat as nbf
from huggingface_hub import HfApi
from httpx import Client
import logging
import pandas as pd
from utils.notebook_utils import (
eda_cells,
replace_wildcards,
rag_cells,
embeggins_cells,
)
from dotenv import load_dotenv
import os
# TODOs:
# Improve UI code preview
# Add template for training
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
NOTEBOOKS_REPOSITORY = os.getenv("NOTEBOOKS_REPOSITORY")

# Fail fast on missing configuration. Explicit raises (instead of `assert`)
# keep these checks active even when Python runs with optimizations (-O).
if HF_TOKEN is None:
    raise RuntimeError("You need to set HF_TOKEN in your environment variables")
if NOTEBOOKS_REPOSITORY is None:
    raise RuntimeError(
        "You need to set NOTEBOOKS_REPOSITORY in your environment variables"
    )

BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
HEADERS = {"Accept": "application/json", "Content-Type": "application/json"}
# Shared HTTP client reused by every datasets-server request in this module.
client = Client(headers=HEADERS)

logging.basicConfig(level=logging.INFO)
def get_compatible_libraries(dataset: str):
    """Fetch the libraries the datasets-server reports as able to load *dataset*.

    Args:
        dataset: Hub dataset id, e.g. ``"scikit-learn/iris"``.

    Returns:
        The decoded JSON body of the ``/compatible-libraries`` endpoint.

    Raises:
        Exception: any HTTP status, transport or JSON decoding error is logged
            and re-raised unchanged.
    """
    try:
        # Pass the id through ``params`` so httpx URL-encodes it instead of
        # splicing raw user input into the query string.
        response = client.get(
            f"{BASE_DATASETS_SERVER_URL}/compatible-libraries",
            params={"dataset": dataset},
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        logging.error(f"Error fetching compatible libraries: {e}")
        raise
def create_notebook_file(cells, notebook_name):
    """Write *cells* to a Jupyter (nbformat v4) notebook file.

    Args:
        cells: Iterable of dicts with a ``"cell_type"`` key (``"code"`` creates a
            code cell, anything else a markdown cell) and a ``"source"`` key
            holding either a string or a list of lines.
        notebook_name: Path of the ``.ipynb`` file to create or overwrite.
    """

    def _source_text(cell):
        # Normalize list-of-lines sources the same way for code AND markdown
        # cells; previously only code cells were joined.
        source = cell["source"]
        return source if isinstance(source, str) else "\n".join(source)

    nb = nbf.v4.new_notebook()
    nb["cells"] = [
        nbf.v4.new_code_cell(_source_text(cell))
        if cell["cell_type"] == "code"
        else nbf.v4.new_markdown_cell(_source_text(cell))
        for cell in cells
    ]
    # Explicit UTF-8: generated cells can contain emoji / non-ASCII text that
    # would fail to encode under a non-UTF-8 platform default encoding.
    with open(notebook_name, "w", encoding="utf-8") as f:
        nbf.write(nb, f)
    logging.info(f"Notebook {notebook_name} created successfully")
def get_first_rows_as_df(dataset: str, config: str, split: str, limit: int):
    """Return up to *limit* shuffled preview rows of a dataset split.

    Calls the datasets-server ``/first-rows`` endpoint, extracts the ``"row"``
    payload of every returned entry, shuffles them and keeps *limit* rows.

    Args:
        dataset: Hub dataset id.
        config: Config (subset) name.
        split: Split name.
        limit: Maximum number of rows in the returned DataFrame.

    Returns:
        A ``pandas.DataFrame`` with at most *limit* rows.

    Raises:
        Exception: any HTTP, decoding or payload-shape error is logged and
            re-raised unchanged.
    """
    try:
        # ``params`` lets httpx URL-encode config/split names, which may
        # contain characters that are not query-string safe.
        resp = client.get(
            f"{BASE_DATASETS_SERVER_URL}/first-rows",
            params={"dataset": dataset, "config": config, "split": split},
        )
        resp.raise_for_status()
        content = resp.json()
        rows = [entry["row"] for entry in content["rows"]]
        # Shuffle the preview rows before keeping `limit` of them.
        return pd.DataFrame.from_dict(rows).sample(frac=1).head(limit)
    except Exception as e:
        logging.error(f"Error fetching first rows: {e}")
        raise
def longest_string_column(df):
    """Return the name of the column holding the longest string value.

    Only ``object``/``string`` dtype columns are candidates, and only cells
    that are real ``str`` values count toward a column's length. Other cells
    (lists, dicts, numbers, missing values) count as length 0. This means a
    nested column is never chosen because of its element count, and object
    columns holding non-strings do not make the ``.str`` accessor raise.

    Args:
        df: DataFrame to inspect.

    Returns:
        The first column whose longest string is strictly longer than those of
        all earlier candidates, or ``None`` if no column contains a non-empty
        string.
    """
    longest_col = None
    max_length = 0
    for col in df.select_dtypes(include=["object", "string"]).columns:
        col_max = max(
            (len(value) for value in df[col] if isinstance(value, str)),
            default=0,
        )
        if col_max > max_length:
            max_length = col_max
            longest_col = col
    return longest_col
def generate_eda_cells(dataset_id):
    """Stream (code preview, status markdown) pairs for an EDA notebook.

    Thin generator wrapper over ``generate_cells`` using the EDA template.
    """
    yield from generate_cells(dataset_id, cells=eda_cells, notebook_type="eda")
def generate_rag_cells(dataset_id):
    """Stream (code preview, status markdown) pairs for a RAG notebook.

    Thin generator wrapper over ``generate_cells`` using the RAG template.
    """
    yield from generate_cells(dataset_id, cells=rag_cells, notebook_type="rag")
def generate_embedding_cells(dataset_id):
    """Stream (code preview, status markdown) pairs for an embeddings notebook.

    Thin generator wrapper over ``generate_cells`` using the embeddings template.
    """
    yield from generate_cells(
        dataset_id, cells=embeggins_cells, notebook_type="embeddings"
    )
def _push_to_hub(
    dataset_id,
    notebook_file,
):
    """Upload a local notebook file to the ``NOTEBOOKS_REPOSITORY`` dataset repo.

    The file is stored at the repository root under its base name.

    Args:
        dataset_id: Source dataset id; only used in log messages.
        notebook_file: Local path of the notebook to upload.

    Raises:
        Exception: whatever ``HfApi.upload_file`` raises, after it is logged.
    """
    logging.info(f"Pushing notebook to hub: {dataset_id} on file {notebook_file}")
    notebook_name = os.path.basename(notebook_file)
    api = HfApi(token=HF_TOKEN)
    try:
        logging.info(f"About to push {notebook_file} - {dataset_id}")
        api.upload_file(
            path_or_fileobj=notebook_file,
            path_in_repo=notebook_name,
            repo_id=NOTEBOOKS_REPOSITORY,
            repo_type="dataset",
        )
    except Exception:
        # logging.exception logs at ERROR level and attaches the traceback;
        # the file name is passed as a proper %s argument.
        logging.exception("Failed to push notebook %s", notebook_file)
        raise
def generate_cells(dataset_id, cells, notebook_type="eda"):
    """Build a notebook for *dataset_id* and stream progress to the UI.

    This is a generator. Every ``yield`` produces a
    ``(code_preview, status_markdown)`` pair, which the click handlers route to
    the ``code_component`` and ``go_to_notebook`` outputs.

    Steps:
      1. Ask the datasets-server which libraries can load the dataset and
         require pandas support.
      2. Load a few preview rows of the first config/split. These determine
         which template sections apply (numeric / categorical columns) and
         which column is used as the text column.
      3. Fill the template wildcards, stream a truncated code preview, write
         the ``.ipynb`` file, push it to ``NOTEBOOKS_REPOSITORY`` and yield a
         Colab link to the pushed file.

    Every failure is reported through ``status_markdown`` and ends the generator.

    Args:
        dataset_id: Hub dataset id, e.g. ``"scikit-learn/iris"``.
        cells: Template cells (dicts with ``"cell_type"`` and ``"source"``).
        notebook_type: ``"eda"``, ``"rag"`` or ``"embeddings"``. It selects the
            column requirements checked below and the output file suffix.
    """
    logging.info(f"Generating notebook for dataset {dataset_id}")
    try:
        libraries = get_compatible_libraries(dataset_id)
    except Exception as err:
        # A bare gr.Error(...) is only shown when raised, and an early
        # `return` from a generator yields nothing, so report the failure
        # through the status output like the other failure paths.
        logging.error(f"Failed to fetch compatible libraries: {err}")
        yield "", "## ❌ Unable to retrieve dataset info from HF Hub ❌"
        return

    if not libraries:
        logging.error("Dataset not compatible with pandas library - not libraries")
        yield "", "## ❌ This dataset is not compatible with pandas library ❌"
        return

    pandas_library = next(
        (lib for lib in libraries.get("libraries", []) if lib["library"] == "pandas"),
        None,
    )
    if not pandas_library:
        logging.error("Dataset not compatible with pandas library - not pandas library")
        yield "", "## ❌ This dataset is not compatible with pandas library ❌"
        return

    first_config_loading_code = pandas_library["loading_codes"][0]
    first_code = first_config_loading_code["code"]
    first_config = first_config_loading_code["config_name"]
    first_split = next(iter(first_config_loading_code["arguments"]["splits"]))

    try:
        df = get_first_rows_as_df(dataset_id, first_config, first_split, 3)
    except Exception as err:
        logging.error(f"Failed to fetch first rows: {err}")
        yield "", "## ❌ Unable to retrieve dataset rows from HF Hub ❌"
        return

    longest_col = longest_string_column(df)
    html_code = f"<iframe src='https://huggingface.co/datasets/{dataset_id}/embed/viewer' width='80%' height='560px'></iframe>"
    has_numeric_columns = len(df.select_dtypes(include=["number"]).columns) > 0
    # Same dtype family that longest_string_column scans, so the two agree
    # when text is stored with pandas' dedicated string dtype.
    has_categoric_columns = (
        len(df.select_dtypes(include=["object", "string"]).columns) > 0
    )

    # RAG/embedding templates need a text column to embed.
    if notebook_type in ("rag", "embeddings") and (
        not has_categoric_columns or longest_col is None
    ):
        logging.error(
            "Dataset does not have categorical columns, which are required for RAG generation."
        )
        yield (
            "",
            "## ❌ This dataset does not have categorical columns, which are required for Embeddings/RAG generation ❌",
        )
        return
    if notebook_type == "eda" and not (has_categoric_columns or has_numeric_columns):
        logging.error(
            "Dataset does not have categorical or numeric columns, which are required for EDA generation."
        )
        yield (
            "",
            "## ❌ This dataset does not have categorical or numeric columns, which are required for EDA generation ❌",
        )
        return

    wildcards = ["{dataset_name}", "{first_code}", "{html_code}", "{longest_col}"]
    # NOTE(review): replace_wildcards is not visible here. Passing "" rather
    # than None when no string column exists assumes it does text replacement
    # (str.replace rejects None) - confirm against utils.notebook_utils.
    replacements = [
        dataset_id,
        first_code,
        html_code,
        longest_col if longest_col is not None else "",
    ]
    cells = replace_wildcards(
        cells, wildcards, replacements, has_numeric_columns, has_categoric_columns
    )

    generated_text = ""
    # Show only the first 30 lines, would like to have a scroll in gr.Code https://github.com/gradio-app/gradio/issues/9192
    for cell in cells:
        if cell["cell_type"] == "markdown":
            continue
        source = cell["source"]
        # Sources may be a list of lines; join them the same way
        # create_notebook_file does for code cells.
        if not isinstance(source, str):
            source = "\n".join(source)
        generated_text += source + "\n\n"
        yield generated_text, ""
        if generated_text.count("\n") > 30:
            generated_text += (
                "## See more lines available in the generated notebook 🤗 ......"
            )
            yield generated_text, ""
            break

    notebook_name = f"{dataset_id.replace('/', '-')}-{notebook_type}.ipynb"
    create_notebook_file(cells, notebook_name=notebook_name)
    try:
        _push_to_hub(dataset_id, notebook_name)
    except Exception:
        # _push_to_hub already logged the failure.
        yield (
            generated_text,
            "## ❌ Unable to push the generated notebook to the Hub ❌",
        )
        return

    # Link to the repository the notebook was actually pushed to.
    notebook_link = (
        "https://colab.research.google.com/#fileId=https%3A//huggingface.co/datasets/"
        f"{NOTEBOOKS_REPOSITORY}/blob/main/{notebook_name}"
    )
    yield (
        generated_text,
        f"## ✅ Here you have the [generated notebook]({notebook_link}) ✅",
    )
with gr.Blocks(fill_height=True, fill_width=True) as demo:
    gr.Markdown("# 🤖 Dataset notebook creator 🕵️")
    with gr.Row(equal_height=True):
        with gr.Column(scale=2):
            # Hidden field that only carries the description column of the examples.
            text_input = gr.Textbox(label="Suggested notebook type", visible=False)
            dataset_name = HuggingfaceHubSearch(
                label="Hub Dataset ID",
                placeholder="Search for dataset id on Huggingface",
                search_type="dataset",
                value="",
            )
            dataset_samples = gr.Examples(
                examples=[
                    [
                        "scikit-learn/iris",
                        "Try this dataset for Exploratory Data Analysis",
                    ],
                    [
                        "infinite-dataset-hub/GlobaleCuisineRecipes",
                        "Try this dataset for Embeddings generation",
                    ],
                    [
                        "infinite-dataset-hub/GlobalBestSellersSummaries",
                        "Try this dataset for RAG generation",
                    ],
                ],
                inputs=[dataset_name, text_input],
                cache_examples=False,
            )

            @gr.render(inputs=dataset_name)
            def embed(name):
                """Render the Hub dataset viewer for the selected dataset id."""
                import html

                if not name:
                    return gr.Markdown("### No dataset provided")
                # Use the generic viewer URL (no hard-coded "default/train"),
                # so datasets without that config/split still render. The
                # user-typed id is escaped before going into raw HTML.
                safe_name = html.escape(name, quote=True)
                html_code = f"""
                <iframe
                src="https://huggingface.co/datasets/{safe_name}/embed/viewer"
                frameborder="0"
                width="100%"
                height="350px"
                ></iframe>
                """
                return gr.HTML(value=html_code, elem_classes="viewer")

            with gr.Row():
                generate_eda_btn = gr.Button("Exploratory Data Analysis")
                generate_embedding_btn = gr.Button("Embeddings")
                generate_rag_btn = gr.Button("RAG")
                generate_training_btn = gr.Button(
                    "Training - Coming soon", interactive=False
                )
        with gr.Column(scale=2):
            code_component = gr.Code(
                language="python", label="Notebook Code Preview", lines=40
            )
            go_to_notebook = gr.Markdown("", visible=True)

    # Each button streams (code preview, status markdown) pairs from its generator.
    generate_eda_btn.click(
        generate_eda_cells,
        inputs=[dataset_name],
        outputs=[code_component, go_to_notebook],
    )
    generate_embedding_btn.click(
        generate_embedding_cells,
        inputs=[dataset_name],
        outputs=[code_component, go_to_notebook],
    )
    generate_rag_btn.click(
        generate_rag_cells,
        inputs=[dataset_name],
        outputs=[code_component, go_to_notebook],
    )
    gr.Markdown(
        "🚧 Note: Some code may not be compatible with datasets that contain binary data or complex structures. 🚧"
    )
demo.launch()