import gradio as gr
import requests
import logging
import os
import json
import shutil
import glob
import queue
import lancedb
from datetime import datetime
from dotenv import load_dotenv, set_key
import yaml
import pandas as pd
from typing import List, Optional
from pydantic import BaseModel

# Set up logging
log_queue = queue.Queue()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

load_dotenv('indexing/.env')
API_BASE_URL = os.getenv('API_BASE_URL', 'http://localhost:8012')
LLM_API_BASE = os.getenv('LLM_API_BASE', 'http://localhost:11434')
EMBEDDINGS_API_BASE = os.getenv('EMBEDDINGS_API_BASE', 'http://localhost:11434')
ROOT_DIR = os.getenv('ROOT_DIR', 'indexing')


# Data models
class IndexingRequest(BaseModel):
    llm_model: str
    embed_model: str
    llm_api_base: str
    embed_api_base: str
    root: str
    verbose: bool = False
    nocache: bool = False
    resume: Optional[str] = None
    reporter: str = "rich"
    emit: List[str] = ["parquet"]
    custom_args: Optional[str] = None


class PromptTuneRequest(BaseModel):
    root: str = f"./{ROOT_DIR}"
    domain: Optional[str] = None
    method: str = "random"
    limit: int = 15
    language: Optional[str] = None
    max_tokens: int = 2000
    chunk_size: int = 200
    no_entity_types: bool = False
    output: str = f"./{ROOT_DIR}/prompts"


class QueueHandler(logging.Handler):
    """Logging handler that pushes formatted records onto a queue for the UI."""

    def __init__(self, log_queue):
        super().__init__()
        self.log_queue = log_queue

    def emit(self, record):
        self.log_queue.put(self.format(record))


queue_handler = QueueHandler(log_queue)
logging.getLogger().addHandler(queue_handler)


def update_logs():
    """Drain the log queue and return its contents as a single string."""
    logs = []
    while not log_queue.empty():
        logs.append(log_queue.get())
    return "\n".join(logs)


##########SETTINGS################

def load_settings():
    config_path = os.getenv('GRAPHRAG_CONFIG', 'config.yaml')
    if os.path.exists(config_path):
        with open(config_path, 'r') as config_file:
            config = yaml.safe_load(config_file)
    else:
        config = {}

    settings = {
        'llm_model': os.getenv('LLM_MODEL', config.get('llm_model')),
        'embedding_model': os.getenv('EMBEDDINGS_MODEL', config.get('embedding_model')),
        'community_level': int(os.getenv('COMMUNITY_LEVEL', config.get('community_level', 2))),
        'token_limit': int(os.getenv('TOKEN_LIMIT', config.get('token_limit', 4096))),
        'api_key': os.getenv('GRAPHRAG_API_KEY', config.get('api_key')),
        'api_base': os.getenv('LLM_API_BASE', config.get('api_base')),
        'embeddings_api_base': os.getenv('EMBEDDINGS_API_BASE', config.get('embeddings_api_base')),
        'api_type': os.getenv('API_TYPE', config.get('api_type', 'openai')),
    }
    return settings


#######FILE_MANAGEMENT##############

def list_output_files(root_dir):
    output_dir = os.path.join(root_dir, "output")
    files = []
    for root, _, filenames in os.walk(output_dir):
        for filename in filenames:
            files.append(os.path.join(root, filename))
    return files


def update_file_list():
    files = list_input_files()
    return gr.update(choices=[f["path"] for f in files])


def update_file_content(file_path):
    if not file_path:
        return ""
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()
        return content
    except Exception as e:
        logging.error(f"Error reading file: {str(e)}")
        return f"Error reading file: {str(e)}"


def list_output_folders():
    output_dir = os.path.join(ROOT_DIR, "output")
    folders = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f))]
    return sorted(folders, reverse=True)


def update_output_folder_list():
    folders = list_output_folders()
    return gr.update(choices=folders, value=folders[0] if folders else None)
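# Expected on-disk layout under ROOT_DIR (assumed from the path handling in
# this module; the run-folder name is illustrative):
#
#   {ROOT_DIR}/
#       input/                              uploaded source documents
#       output/20240915-142310/artifacts/   per-run parquet/graphml artifacts
#       lancedb/                            embedding store
#       prompts/                            prompt-tuning output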
def list_folder_contents(folder_path):
    """List a folder's contents, tagging directories and file extensions."""
    contents = []
    if os.path.exists(folder_path):
        for item in os.listdir(folder_path):
            item_path = os.path.join(folder_path, item)
            if os.path.isdir(item_path):
                contents.append(f"[DIR] {item}")
            else:
                _, ext = os.path.splitext(item)
                contents.append(f"[{ext[1:].upper()}] {item}")
    return contents


def update_folder_content_list(folder_name):
    if isinstance(folder_name, list) and folder_name:
        folder_name = folder_name[0]
    elif not folder_name:
        return gr.update(choices=[])
    folder_path = os.path.join(ROOT_DIR, "output", folder_name, "artifacts")
    contents = list_folder_contents(folder_path)
    return gr.update(choices=contents)


def handle_content_selection(folder_name, selected_item):
    if isinstance(selected_item, list) and selected_item:
        selected_item = selected_item[0]  # Take the first item if it's a list
    if isinstance(selected_item, str) and selected_item.startswith("[DIR]"):
        dir_name = selected_item[6:]  # Remove the "[DIR] " prefix
        sub_contents = list_folder_contents(os.path.join(ROOT_DIR, "output", folder_name, "artifacts", dir_name))
        return gr.update(choices=sub_contents), "", ""
    elif isinstance(selected_item, str):
        # Remove the file-type prefix if present
        file_name = selected_item.split("] ")[1] if "]" in selected_item else selected_item
        file_path = os.path.join(ROOT_DIR, "output", folder_name, "artifacts", file_name)
        file_size = os.path.getsize(file_path)
        file_type = os.path.splitext(file_name)[1]
        file_info = f"File: {file_name}\nSize: {file_size} bytes\nType: {file_type}"
        content = read_file_content(file_path)
        return gr.update(), file_info, content
    else:
        return gr.update(), "", ""


def initialize_selected_folder(folder_name):
    if not folder_name:
        return "Please select a folder first.", gr.update(choices=[])
    folder_path = os.path.join(ROOT_DIR, "output", folder_name, "artifacts")
    if not os.path.exists(folder_path):
        return f"Artifacts folder not found in '{folder_name}'.", gr.update(choices=[])
    contents = list_folder_contents(folder_path)
    return f"Folder '{folder_name}/artifacts' initialized with {len(contents)} items.", gr.update(choices=contents)


def upload_file(file):
    if file is not None:
        input_dir = os.path.join(ROOT_DIR, 'input')
        os.makedirs(input_dir, exist_ok=True)
        # Move the uploaded temp file into the input directory under its original name
        original_filename = file.name
        destination_path = os.path.join(input_dir, os.path.basename(original_filename))
        shutil.move(file.name, destination_path)
        logging.info(f"File uploaded and moved to: {destination_path}")
        status = f"File uploaded: {os.path.basename(original_filename)}"
    else:
        status = "No file uploaded"
    # Get the updated file list
    updated_file_list = [f["path"] for f in list_input_files()]
    return status, gr.update(choices=updated_file_list), update_logs()


def list_input_files():
    input_dir = os.path.join(ROOT_DIR, 'input')
    files = []
    if os.path.exists(input_dir):
        files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
    return [{"name": f, "path": os.path.join(input_dir, f)} for f in files]


def delete_file(file_path):
    try:
        os.remove(file_path)
        logging.info(f"File deleted: {file_path}")
        status = f"File deleted: {os.path.basename(file_path)}"
    except Exception as e:
        logging.error(f"Error deleting file: {str(e)}")
        status = f"Error deleting file: {str(e)}"
    # Get the updated file list
    updated_file_list = [f["path"] for f in list_input_files()]
    return status, gr.update(choices=updated_file_list), update_logs()
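# Illustrative return shape of list_input_files() (file names hypothetical):
#   [{"name": "report.txt", "path": "indexing/input/report.txt"}, ...]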
def read_file_content(file_path):
    try:
        if file_path.endswith('.parquet'):
            df = pd.read_parquet(file_path)
            # Summarize the DataFrame instead of dumping raw bytes
            info = f"Parquet File: {os.path.basename(file_path)}\n"
            info += f"Rows: {len(df)}, Columns: {len(df.columns)}\n\n"
            info += "Column Names:\n" + "\n".join(df.columns) + "\n\n"
            info += "First 5 rows:\n"
            info += df.head().to_string() + "\n\n"
            info += "Basic Statistics:\n"
            info += df.describe().to_string()
            return info
        else:
            with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
                content = file.read()
            return content
    except Exception as e:
        logging.error(f"Error reading file: {str(e)}")
        return f"Error reading file: {str(e)}"


def save_file_content(file_path, content):
    try:
        with open(file_path, 'w') as file:
            file.write(content)
        logging.info(f"File saved: {file_path}")
        status = f"File saved: {os.path.basename(file_path)}"
    except Exception as e:
        logging.error(f"Error saving file: {str(e)}")
        status = f"Error saving file: {str(e)}"
    return status, update_logs()


def manage_data():
    db = lancedb.connect(f"{ROOT_DIR}/lancedb")
    tables = db.table_names()
    table_info = ""
    if tables:
        table = db.open_table(tables[0])
        table_info = f"Table: {tables[0]}\nSchema: {table.schema}"
    input_files = list_input_files()
    return {
        "database_info": f"Tables: {', '.join(tables)}\n\n{table_info}",
        "input_files": input_files,
    }


def find_latest_graph_file(root_dir):
    pattern = os.path.join(root_dir, "output", "*", "artifacts", "*.graphml")
    graph_files = glob.glob(pattern)
    if not graph_files:
        # If no files were found, retry against the latest run folder, skipping .DS_Store
        output_dir = os.path.join(root_dir, "output")
        run_dirs = [d for d in os.listdir(output_dir)
                    if os.path.isdir(os.path.join(output_dir, d)) and d != ".DS_Store"]
        if run_dirs:
            latest_run = max(run_dirs)
            pattern = os.path.join(root_dir, "output", latest_run, "artifacts", "*.graphml")
            graph_files = glob.glob(pattern)

    if not graph_files:
        return None

    # Pick the most recently modified file
    latest_file = max(graph_files, key=os.path.getmtime)
    return latest_file


def find_latest_output_folder():
    root_dir = f"{ROOT_DIR}/output"
    folders = [f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))]
    if not folders:
        raise ValueError("No output folders found")

    # Sort folders by creation time, most recent first
    sorted_folders = sorted(folders, key=lambda x: os.path.getctime(os.path.join(root_dir, x)), reverse=True)

    latest_folder = None
    timestamp = None
    for folder in sorted_folders:
        try:
            # Try to parse the folder name as a timestamp
            timestamp = datetime.strptime(folder, "%Y%m%d-%H%M%S")
            latest_folder = folder
            break
        except ValueError:
            # Skip folders whose names are not valid timestamps
            continue

    if latest_folder is None:
        raise ValueError("No valid timestamp folders found")

    latest_path = os.path.join(root_dir, latest_folder)
    artifacts_path = os.path.join(latest_path, "artifacts")
    if not os.path.exists(artifacts_path):
        raise ValueError(f"Artifacts folder not found in {latest_path}")

    return latest_path, latest_folder
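# Run folders are expected to be named after their start timestamp, e.g.
# "20240915-142310" (matching the "%Y%m%d-%H%M%S" strptime format above);
# anything else in output/ is skipped. The example name is illustrative.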
def initialize_data():
    global entity_df, relationship_df, text_unit_df, report_df, covariate_df

    tables = {
        "entity_df": "create_final_nodes",
        "relationship_df": "create_final_edges",
        "text_unit_df": "create_final_text_units",
        "report_df": "create_final_reports",
        "covariate_df": "create_final_covariates",
    }

    timestamp = None  # Initialize timestamp to None
    try:
        latest_output_folder, timestamp = find_latest_output_folder()
        artifacts_folder = os.path.join(latest_output_folder, "artifacts")
        for df_name, file_prefix in tables.items():
            file_pattern = os.path.join(artifacts_folder, f"{file_prefix}*.parquet")
            matching_files = glob.glob(file_pattern)
            if matching_files:
                latest_file = max(matching_files, key=os.path.getctime)
                df = pd.read_parquet(latest_file)
                globals()[df_name] = df
                logging.info(f"Successfully loaded {df_name} from {latest_file}")
            else:
                logging.warning(f"No matching file found for {df_name} in {artifacts_folder}. Initializing as an empty DataFrame.")
                globals()[df_name] = pd.DataFrame()
    except Exception as e:
        logging.error(f"Error initializing data: {str(e)}")
        for df_name in tables.keys():
            globals()[df_name] = pd.DataFrame()

    return timestamp


# Call initialize_data and store the timestamp
current_timestamp = initialize_data()


###########MODELS##################

def normalize_api_base(api_base: str) -> str:
    """Normalize an API base URL by removing trailing slashes and /v1 or /api suffixes."""
    api_base = api_base.rstrip('/')
    if api_base.endswith('/v1') or api_base.endswith('/api'):
        api_base = api_base[:-3]
    return api_base


def is_ollama_api(base_url: str) -> bool:
    """Check whether the given base URL serves the Ollama API."""
    try:
        response = requests.get(f"{normalize_api_base(base_url)}/api/tags")
        return response.status_code == 200
    except requests.RequestException:
        return False


def get_ollama_models(base_url: str) -> List[str]:
    """Fetch available models from the Ollama API."""
    try:
        response = requests.get(f"{normalize_api_base(base_url)}/api/tags")
        response.raise_for_status()
        models = response.json().get('models', [])
        return [model['name'] for model in models]
    except requests.RequestException as e:
        logger.error(f"Error fetching Ollama models: {str(e)}")
        return []


def get_openai_compatible_models(base_url: str) -> List[str]:
    """Fetch available models from an OpenAI-compatible API."""
    try:
        response = requests.get(f"{normalize_api_base(base_url)}/v1/models")
        response.raise_for_status()
        models = response.json().get('data', [])
        return [model['id'] for model in models]
    except requests.RequestException as e:
        logger.error(f"Error fetching OpenAI-compatible models: {str(e)}")
        return []


def get_local_models(base_url: str) -> List[str]:
    """Get available models based on the API type."""
    if is_ollama_api(base_url):
        return get_ollama_models(base_url)
    else:
        return get_openai_compatible_models(base_url)


def get_model_params(base_url: str, model_name: str) -> dict:
    """Get model parameters for Ollama models; returns {} for other APIs."""
    if is_ollama_api(base_url):
        try:
            response = requests.post(f"{normalize_api_base(base_url)}/api/show", json={"name": model_name})
            response.raise_for_status()
            model_info = response.json()
            return model_info.get('parameters', {})
        except requests.RequestException as e:
            logger.error(f"Error fetching Ollama model parameters: {str(e)}")
    return {}


#########API###########

def start_indexing(request: IndexingRequest):
    url = f"{API_BASE_URL}/v1/index"
    try:
        response = requests.post(url, json=request.dict())
        response.raise_for_status()
        result = response.json()
        return result['message'], gr.update(interactive=False), gr.update(interactive=True)
    except requests.RequestException as e:
        logger.error(f"Error starting indexing: {str(e)}")
        return f"Error: {str(e)}", gr.update(interactive=True), gr.update(interactive=False)
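# Illustrative request body POSTed to {API_BASE_URL}/v1/index. Model names are
# placeholders; the backend contract is assumed from IndexingRequest above:
#
#   IndexingRequest(
#       llm_model="llama3",
#       embed_model="nomic-embed-text",
#       llm_api_base="http://localhost:11434",
#       embed_api_base="http://localhost:11434",
#       root="indexing",
#   ).dict()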
def check_indexing_status():
    url = f"{API_BASE_URL}/v1/index_status"
    try:
        response = requests.get(url)
        response.raise_for_status()
        result = response.json()
        return result['status'], "\n".join(result['logs'])
    except requests.RequestException as e:
        logger.error(f"Error checking indexing status: {str(e)}")
        return "Error", f"Failed to check indexing status: {str(e)}"


def start_prompt_tuning(request: PromptTuneRequest):
    url = f"{API_BASE_URL}/v1/prompt_tune"
    try:
        response = requests.post(url, json=request.dict())
        response.raise_for_status()
        result = response.json()
        return result['message'], gr.update(interactive=False)
    except requests.RequestException as e:
        logger.error(f"Error starting prompt tuning: {str(e)}")
        return f"Error: {str(e)}", gr.update(interactive=True)


def check_prompt_tuning_status():
    url = f"{API_BASE_URL}/v1/prompt_tune_status"
    try:
        response = requests.get(url)
        response.raise_for_status()
        result = response.json()
        return result['status'], "\n".join(result['logs'])
    except requests.RequestException as e:
        logger.error(f"Error checking prompt tuning status: {str(e)}")
        return "Error", f"Failed to check prompt tuning status: {str(e)}"


def update_model_params(model_name):
    # get_model_params needs an API base as well as the model name
    params = get_model_params(LLM_API_BASE, model_name)
    return gr.update(value=json.dumps(params, indent=2))
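# Both status endpoints are polled the same way and are assumed to return JSON
# of the form {"status": ..., "logs": [...]}, e.g. (response values illustrative):
#   GET {API_BASE_URL}/v1/index_status
#   -> {"status": "running", "logs": ["indexing started"]}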
###########################

css = """
html, body { margin: 0; padding: 0; height: 100vh; overflow: hidden; }

.gradio-container {
    margin: 0 !important;
    padding: 0 !important;
    width: 100vw !important;
    max-width: 100vw !important;
    height: 100vh !important;
    max-height: 100vh !important;
    overflow: auto;
    display: flex;
    flex-direction: column;
}

#main-container { flex: 1; display: flex; overflow: hidden; }

#left-column, #right-column { height: 100%; overflow-y: auto; padding: 10px; }
#left-column { flex: 1; }
#right-column { flex: 2; display: flex; flex-direction: column; }

#chat-container {
    flex: 0 0 auto; /* Don't allow this to grow */
    height: 100%;
    display: flex;
    flex-direction: column;
    overflow: hidden;
    border: 1px solid var(--color-accent);
    border-radius: 8px;
    padding: 10px;
    overflow-y: auto;
}
#chatbot { overflow-y: hidden; height: 100%; }
#chat-input-row { margin-top: 10px; }

#visualization-plot {
    width: 100%;
    aspect-ratio: 1 / 1;
    max-height: 600px; /* Adjust this value as needed */
}
#vis-controls-row { display: flex; justify-content: space-between; align-items: center; margin-top: 10px; }
#vis-controls-row > * { flex: 1; margin: 0 5px; }
#vis-status { margin-top: 10px; }

/* Chat input styling */
#chat-input-row { display: flex; flex-direction: column; }
#chat-input-row > div { width: 100% !important; }
#chat-input-row input[type="text"] { width: 100% !important; }

/* Adjust padding for all containers */
.gr-box, .gr-form, .gr-panel { padding: 10px !important; }

/* Ensure all textboxes and textareas have full height */
.gr-textbox, .gr-textarea { height: auto !important; min-height: 100px !important; }

/* Ensure all dropdowns have full width */
.gr-dropdown { width: 100% !important; }

:root {
    --color-background: #2C3639;
    --color-foreground: #3F4E4F;
    --color-accent: #A27B5C;
    --color-text: #DCD7C9;
}

body, .gradio-container { background-color: var(--color-background); color: var(--color-text); }
.gr-button { background-color: var(--color-accent); color: var(--color-text); }
.gr-input, .gr-textarea, .gr-dropdown {
    background-color: var(--color-foreground);
    color: var(--color-text);
    border: 1px solid var(--color-accent);
}
.gr-panel { background-color: var(--color-foreground); border: 1px solid var(--color-accent); }
.gr-box { border-radius: 8px; margin-bottom: 10px; background-color: var(--color-foreground); }
.gr-padded { padding: 10px; }
.gr-form { background-color: var(--color-foreground); }
.gr-input-label, .gr-radio-label { color: var(--color-text); }
.gr-checkbox-label { color: var(--color-text); }
.gr-markdown { color: var(--color-text); }
.gr-accordion { background-color: var(--color-foreground); border: 1px solid var(--color-accent); }
.gr-accordion-header { background-color: var(--color-accent); color: var(--color-text); }

#visualization-container {
    display: flex;
    flex-direction: column;
    border: 2px solid var(--color-accent);
    border-radius: 8px;
    margin-top: 20px;
    padding: 10px;
    background-color: var(--color-foreground);
    height: calc(100vh - 300px); /* Adjust this value as needed */
}
#visualization-plot { width: 100%; height: 100%; }

#log-container {
    background-color: var(--color-foreground);
    border: 1px solid var(--color-accent);
    border-radius: 8px;
    padding: 10px;
    margin-top: 20px;
    overflow-y: auto;
}

.setting-accordion .label-wrap { cursor: pointer; }
.setting-accordion .icon { transition: transform 0.3s ease; }
.setting-accordion[open] .icon { transform: rotate(90deg); }

.gr-form.gr-box { border: none !important; background: none !important; }

.model-params { border-top: 1px solid var(--color-accent); margin-top: 10px; padding-top: 10px; }
"""


def create_interface():
    settings = load_settings()
    llm_api_base = normalize_api_base(settings['api_base'])
    embeddings_api_base = normalize_api_base(settings['embeddings_api_base'])

    with gr.Blocks(theme=gr.themes.Base(), css=css) as demo:
        gr.Markdown("# GraphRAG Indexer")

        with gr.Tabs():
            with gr.TabItem("Indexing"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("## Indexing Configuration")
                        with gr.Row():
                            llm_name = gr.Dropdown(label="LLM Model", choices=[], value=settings['llm_model'], allow_custom_value=True)
                            refresh_llm_btn = gr.Button("🔄", size='sm', scale=0)
                        with gr.Row():
                            embed_name = gr.Dropdown(label="Embedding Model", choices=[], value=settings['embedding_model'], allow_custom_value=True)
                            refresh_embed_btn = gr.Button("🔄", size='sm', scale=0)
                        save_config_button = gr.Button("Save Configuration", variant="primary")
                        config_status = gr.Textbox(label="Configuration Status", lines=2)

                with gr.Row():
                    with gr.Column(scale=1):
                        root_dir = gr.Textbox(label="Root Directory (Edit in .env file)", value=f"{ROOT_DIR}")
                        with gr.Group():
                            verbose = gr.Checkbox(label="Verbose", interactive=True, value=True)
                            nocache = gr.Checkbox(label="No Cache", interactive=True, value=True)
                        with gr.Accordion("Advanced Options", open=True):
                            resume = gr.Textbox(label="Resume Timestamp (optional)")
                            reporter = gr.Dropdown(
                                label="Reporter",
                                choices=["rich", "print", "none"],
                                value="rich",
                                interactive=True
                            )
                            emit_formats = gr.CheckboxGroup(
                                label="Emit Formats",
                                choices=["json", "csv", "parquet"],
                                value=["parquet"],
                                interactive=True
                            )
                            custom_args = gr.Textbox(label="Custom CLI Arguments", placeholder="--arg1 value1 --arg2 value2")

                    with gr.Column(scale=1):
                        gr.Markdown("## Indexing Output")
                        index_output = gr.Textbox(label="Output", lines=10)
                        index_status = gr.Textbox(label="Status", lines=2)
                        run_index_button = gr.Button("Run Indexing", variant="primary")
                        check_status_button = gr.Button("Check Indexing Status")

            with gr.TabItem("Prompt Tuning"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("## Prompt Tuning Configuration")
                        pt_root = gr.Textbox(label="Root Directory", value=f"{ROOT_DIR}", interactive=True)
                        pt_domain = gr.Textbox(label="Domain (optional)")
                        pt_method = gr.Dropdown(
                            label="Method",
                            choices=["random", "top", "all"],
                            value="random",
                            interactive=True
                        )
                        pt_limit = gr.Number(label="Limit", value=15, precision=0, interactive=True)
                        pt_language = gr.Textbox(label="Language (optional)")
                        pt_max_tokens = gr.Number(label="Max Tokens", value=2000, precision=0, interactive=True)
                        pt_chunk_size = gr.Number(label="Chunk Size", value=200, precision=0, interactive=True)
                        pt_no_entity_types = gr.Checkbox(label="No Entity Types", value=False)
                        pt_output_dir = gr.Textbox(label="Output Directory", value=f"{ROOT_DIR}/prompts", interactive=True)
                        save_pt_config_button = gr.Button("Save Prompt Tuning Configuration", variant="primary")

                    with gr.Column(scale=1):
                        gr.Markdown("## Prompt Tuning Output")
                        pt_output = gr.Textbox(label="Output", lines=10)
                        pt_status = gr.Textbox(label="Status", lines=10)
                        run_pt_button = gr.Button("Run Prompt Tuning", variant="primary")
                        check_pt_status_button = gr.Button("Check Prompt Tuning Status")

            with gr.TabItem("Data Management"):
                with gr.Row():
                    with gr.Column(scale=1):
                        with gr.Accordion("File Upload", open=True):
                            file_upload = gr.File(label="Upload File", file_types=[".txt", ".csv", ".parquet"])
                            upload_btn = gr.Button("Upload File", variant="primary")
                            upload_output = gr.Textbox(label="Upload Status", visible=True)

                        with gr.Accordion("File Management", open=True):
                            file_list = gr.Dropdown(label="Select File", choices=[], interactive=True)
                            refresh_btn = gr.Button("Refresh File List", variant="secondary")
                            file_content = gr.TextArea(label="File Content", lines=10)
                            with gr.Row():
                                delete_btn = gr.Button("Delete Selected File", variant="stop")
                                save_btn = gr.Button("Save Changes", variant="primary")
                            operation_status = gr.Textbox(label="Operation Status", visible=True)

                    with gr.Column(scale=1):
                        with gr.Accordion("Output Folders", open=True):
                            output_folder_list = gr.Dropdown(label="Select Output Folder", choices=[], interactive=True)
                            refresh_output_btn = gr.Button("Refresh Output Folders", variant="secondary")
                            folder_content_list = gr.Dropdown(label="Folder Contents", choices=[], interactive=True, multiselect=False)
                            file_info = gr.Textbox(label="File Info", lines=3)
                            output_content = gr.TextArea(label="File Content", lines=10)

        # Event handlers
        def refresh_llm_models():
            models = get_local_models(llm_api_base)
            return gr.update(choices=models)

        def refresh_embed_models():
            models = get_local_models(embeddings_api_base)
            return gr.update(choices=models)

        refresh_llm_btn.click(refresh_llm_models, outputs=[llm_name])
        refresh_embed_btn.click(refresh_embed_models, outputs=[embed_name])

        # Initialize model lists on page load
        demo.load(refresh_llm_models, outputs=[llm_name])
        demo.load(refresh_embed_models, outputs=[embed_name])

        # The components are passed as Gradio inputs so the callback receives
        # their current values rather than the initial .value defaults.
        def run_indexing(llm_model, embed_model, root, verbose_flag, nocache_flag,
                         resume_ts, reporter_choice, emit_choices, custom):
            request = IndexingRequest(
                llm_model=llm_model,
                embed_model=embed_model,
                llm_api_base=llm_api_base,
                embed_api_base=embeddings_api_base,
                root=root,
                verbose=verbose_flag,
                nocache=nocache_flag,
                resume=resume_ts if resume_ts else None,
                reporter=reporter_choice,
                emit=list(emit_choices),
                custom_args=custom if custom else None
            )
            return start_indexing(request)

        run_index_button.click(
            run_indexing,
            inputs=[llm_name, embed_name, root_dir, verbose, nocache, resume, reporter, emit_formats, custom_args],
            outputs=[index_output, run_index_button, check_status_button]
        )

        check_status_button.click(
            check_indexing_status,
            outputs=[index_status, index_output]
        )
        # Same pattern for prompt tuning: build the request from the
        # components' current values.
        def run_prompt_tuning(root, domain, method, limit, language, max_tokens,
                              chunk_size, no_entity_types, output_dir):
            request = PromptTuneRequest(
                root=root,
                domain=domain if domain else None,
                method=method,
                limit=int(limit),
                language=language if language else None,
                max_tokens=int(max_tokens),
                chunk_size=int(chunk_size),
                no_entity_types=no_entity_types,
                output=output_dir
            )
            result, button_update = start_prompt_tuning(request)
            return result, button_update, gr.update(value=f"Request: {request.dict()}")

        run_pt_button.click(
            run_prompt_tuning,
            inputs=[pt_root, pt_domain, pt_method, pt_limit, pt_language,
                    pt_max_tokens, pt_chunk_size, pt_no_entity_types, pt_output_dir],
            outputs=[pt_output, run_pt_button, pt_status]
        )

        check_pt_status_button.click(
            check_prompt_tuning_status,
            outputs=[pt_status, pt_output]
        )

        # Real-time status updates as prompt tuning settings change
        pt_root.change(lambda x: gr.update(value=f"Root Directory changed to: {x}"), inputs=[pt_root], outputs=[pt_status])
        pt_limit.change(lambda x: gr.update(value=f"Limit changed to: {x}"), inputs=[pt_limit], outputs=[pt_status])
        pt_max_tokens.change(lambda x: gr.update(value=f"Max Tokens changed to: {x}"), inputs=[pt_max_tokens], outputs=[pt_status])
        pt_chunk_size.change(lambda x: gr.update(value=f"Chunk Size changed to: {x}"), inputs=[pt_chunk_size], outputs=[pt_status])
        pt_output_dir.change(lambda x: gr.update(value=f"Output Directory changed to: {x}"), inputs=[pt_output_dir], outputs=[pt_status])

        # Event handlers for Data Management
        upload_btn.click(
            upload_file,
            inputs=[file_upload],
            outputs=[upload_output, file_list, operation_status]
        )

        refresh_btn.click(
            update_file_list,
            outputs=[file_list]
        )

        refresh_output_btn.click(
            update_output_folder_list,
            outputs=[output_folder_list]
        )

        file_list.change(
            update_file_content,
            inputs=[file_list],
            outputs=[file_content]
        )

        # delete_file/save_file_content also return a drained log string; drop
        # it here since this tab has no log panel, so the status message is
        # not overwritten.
        delete_btn.click(
            lambda path: delete_file(path)[:2],
            inputs=[file_list],
            outputs=[operation_status, file_list]
        )

        save_btn.click(
            lambda path, content: save_file_content(path, content)[0],
            inputs=[file_list, file_content],
            outputs=[operation_status]
        )

        output_folder_list.change(
            update_folder_content_list,
            inputs=[output_folder_list],
            outputs=[folder_content_list]
        )

        folder_content_list.change(
            handle_content_selection,
            inputs=[output_folder_list, folder_content_list],
            outputs=[folder_content_list, file_info, output_content]
        )

        # Event handler for saving configuration
        save_config_button.click(
            update_env_file,
            inputs=[llm_name, embed_name],
            outputs=[config_status]
        )

        # Event handler for saving prompt tuning configuration
        save_pt_config_button.click(
            save_prompt_tuning_config,
            inputs=[pt_root, pt_domain, pt_method, pt_limit, pt_language,
                    pt_max_tokens, pt_chunk_size, pt_no_entity_types, pt_output_dir],
            outputs=[pt_status]
        )

        # Initialize file list and output folder list
        demo.load(update_file_list, outputs=[file_list])
        demo.load(update_output_folder_list, outputs=[output_folder_list])

    return demo


def update_env_file(llm_model, embed_model):
    env_path = os.path.join(ROOT_DIR, '.env')
    set_key(env_path, 'LLM_MODEL', llm_model)
    set_key(env_path, 'EMBEDDINGS_MODEL', embed_model)
    # Reload the environment variables
    load_dotenv(env_path, override=True)
    return f"Environment updated: LLM_MODEL={llm_model}, EMBEDDINGS_MODEL={embed_model}"


def save_prompt_tuning_config(root, domain, method, limit, language, max_tokens, chunk_size, no_entity_types, output_dir):
    config = {
        'prompt_tuning': {
            'root': root,
            'domain': domain,
            'method': method,
            'limit': limit,
            'language': language,
            'max_tokens': max_tokens,
            'chunk_size': chunk_size,
            'no_entity_types': no_entity_types,
            'output': output_dir
        }
    }
    config_path = os.path.join(ROOT_DIR, 'prompt_tuning_config.yaml')
    with open(config_path, 'w') as f:
        yaml.dump(config, f)
    return f"Prompt Tuning configuration saved to {config_path}"


demo = create_interface()

if __name__ == "__main__":
    demo.launch(server_port=7861)
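# Typical local run (script name hypothetical; the port is set above):
#   python app.py
#   -> open http://localhost:7861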