| | import os |
| | import shutil |
| | import time |
| | import csv |
| | import uuid |
| | from itertools import cycle |
| | from typing import List, Tuple, Optional |
| | from datetime import datetime |
| | import gradio as gr |
| |
|
| | from .data_fetcher import read_hacker_news_rss, format_published_time |
| | from .model_trainer import ( |
| | authenticate_hf, |
| | train_with_dataset, |
| | get_top_hits, |
| | load_embedding_model, |
| | upload_model_to_hub |
| | ) |
| | from .config import AppConfig |
| | from .vibe_logic import VibeChecker |
| | from sentence_transformers import SentenceTransformer |
| |
|
| | class HackerNewsFineTuner: |
| | """ |
| | Encapsulates all application logic and state for a single user session. |
| | """ |
| |
|
| | def __init__(self, config: AppConfig = AppConfig): |
| | |
| | self.config = config |
| | |
| | |
| | self.session_id = str(uuid.uuid4()) |
| | |
| | |
| | self.session_root = self.config.ARTIFACTS_DIR / self.session_id |
| | self.output_dir = self.session_root / "embedding_gemma_finetuned" |
| | self.dataset_export_file = self.session_root / "training_dataset.csv" |
| | |
| | |
| | os.makedirs(self.output_dir, exist_ok=True) |
| | print(f"[{self.session_id}] New session started. Artifacts: {self.session_root}") |
| |
|
| | |
| | self.model: Optional[SentenceTransformer] = None |
| | self.vibe_checker: Optional[VibeChecker] = None |
| | self.titles: List[str] = [] |
| | self.last_hn_dataset: List[List[str]] = [] |
| | self.imported_dataset: List[List[str]] = [] |
| |
|
| | |
| | authenticate_hf(self.config.HF_TOKEN) |
| |
|
| | def _update_vibe_checker(self): |
| | """Initializes or updates the VibeChecker with the current model state.""" |
| | if self.model: |
| | self.vibe_checker = VibeChecker( |
| | model=self.model, |
| | query_anchor=self.config.QUERY_ANCHOR, |
| | task_name=self.config.TASK_NAME |
| | ) |
| | else: |
| | self.vibe_checker = None |
| |
|
| | |
| |
|
| | def refresh_data_and_model(self) -> Tuple[List[str], str]: |
| | """ |
| | Reloads model and fetches data. |
| | Returns: |
| | - List of titles (for the UI) |
| | - Status message string |
| | """ |
| | print(f"[{self.session_id}] Reloading model and data...") |
| |
|
| | self.last_hn_dataset = [] |
| | self.imported_dataset = [] |
| |
|
| | |
| | try: |
| | self.model = load_embedding_model(self.config.MODEL_NAME) |
| | self._update_vibe_checker() |
| | except Exception as e: |
| | error_msg = f"CRITICAL ERROR: Model failed to load. {e}" |
| | print(error_msg) |
| | self.model = None |
| | self._update_vibe_checker() |
| | return [], error_msg |
| |
|
| | |
| | news_feed, status_msg = read_hacker_news_rss(self.config) |
| | titles_out = [] |
| | status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}" |
| |
|
| | if news_feed is not None and news_feed.entries: |
| | titles_out = [item.title for item in news_feed.entries] |
| | else: |
| | titles_out = ["Error fetching news."] |
| | gr.Warning(f"Data reload failed. {status_msg}") |
| |
|
| | self.titles = titles_out |
| |
|
| | |
| | return self.titles, status_value |
| |
|
| | |
| | def import_additional_dataset(self, file_path: str) -> str: |
| | if not file_path: |
| | return "Please upload a CSV file." |
| | new_dataset, num_imported = [], 0 |
| | try: |
| | with open(file_path, 'r', newline='', encoding='utf-8') as f: |
| | reader = csv.reader(f) |
| | try: |
| | header = next(reader) |
| | |
| | if not (header and header[0].lower().strip() == 'anchor'): |
| | f.seek(0) |
| | except StopIteration: |
| | return "Error: Uploaded file is empty." |
| |
|
| | for row in reader: |
| | if len(row) == 3: |
| | new_dataset.append([s.strip() for s in row]) |
| | num_imported += 1 |
| | if num_imported == 0: |
| | raise ValueError("No valid rows found.") |
| | self.imported_dataset = new_dataset |
| | return f"Imported {num_imported} triplets." |
| | except Exception as e: |
| | return f"Import failed: {e}" |
| |
|
| | def export_dataset(self) -> Optional[str]: |
| | if not self.last_hn_dataset: |
| | gr.Warning("No dataset generated yet.") |
| | return None |
| | |
| | file_path = self.dataset_export_file |
| | try: |
| | with open(file_path, 'w', newline='', encoding='utf-8') as f: |
| | writer = csv.writer(f) |
| | writer.writerow(['Anchor', 'Positive', 'Negative']) |
| | writer.writerows(self.last_hn_dataset) |
| | gr.Info(f"Dataset exported.") |
| | return str(file_path) |
| | except Exception as e: |
| | gr.Error(f"Export failed: {e}") |
| | return None |
| |
|
| | def download_model(self) -> Optional[str]: |
| | if not os.path.exists(self.output_dir): |
| | gr.Warning("No model trained yet.") |
| | return None |
| | |
| | timestamp = int(time.time()) |
| | try: |
| | base_name = self.session_root / f"model_finetuned_{timestamp}" |
| | archive_path = shutil.make_archive( |
| | base_name=str(base_name), |
| | format='zip', |
| | root_dir=self.output_dir, |
| | ) |
| | gr.Info(f"Model zipped.") |
| | return archive_path |
| | except Exception as e: |
| | gr.Error(f"Zip failed: {e}") |
| | return None |
| |
|
| | def upload_model(self, repo_name: str, oauth_token_str: str) -> str: |
| | """ |
| | Calls the model trainer upload function using the session's output directory. |
| | """ |
| | if not os.path.exists(self.output_dir): |
| | return "❌ Error: No trained model found in this session. Run training first." |
| | if not repo_name.strip(): |
| | return "❌ Error: Please specify a repository name." |
| | |
| | return upload_model_to_hub(self.output_dir, repo_name, oauth_token_str) |
| |
|
| |
|
| | |
| | def _create_hn_dataset(self, pos_ids: List[int], neg_ids: List[int]) -> List[List[str]]: |
| | """ |
| | Creates triplets (Anchor, Positive, Negative) from the selected indices. |
| | Uses cycling to balance the dataset if the number of positives != negatives. |
| | """ |
| | if not pos_ids or not neg_ids: |
| | return [] |
| |
|
| | |
| | pos_titles = [self.titles[i] for i in pos_ids] |
| | neg_titles = [self.titles[i] for i in neg_ids] |
| |
|
| | dataset = [] |
| |
|
| | |
| | |
| | |
| | |
| | if len(pos_titles) >= len(neg_titles): |
| | |
| | neg_cycle = cycle(neg_titles) |
| | for p_title in pos_titles: |
| | dataset.append([self.config.QUERY_ANCHOR, p_title, next(neg_cycle)]) |
| | else: |
| | |
| | pos_cycle = cycle(pos_titles) |
| | for n_title in neg_titles: |
| | dataset.append([self.config.QUERY_ANCHOR, next(pos_cycle), n_title]) |
| |
|
| | return dataset |
| |
|
| | def training(self, pos_ids: List[int], neg_ids: List[int]) -> str: |
| | """ |
| | Main training entry point. |
| | Args: |
| | pos_ids: Indices of stories marked as "Favorite" |
| | neg_ids: Indices of stories marked as "Dislike" |
| | """ |
| | if self.model is None: |
| | raise gr.Error("Model not loaded.") |
| | |
| | if self.imported_dataset: |
| | self.last_hn_dataset = self.imported_dataset |
| | else: |
| | |
| | if not pos_ids: |
| | raise gr.Error("Please select at least one 'Favorite' story.") |
| | if not neg_ids: |
| | raise gr.Error("Please select at least one 'Dislike' story.") |
| | |
| | |
| | self.last_hn_dataset = self._create_hn_dataset(pos_ids, neg_ids) |
| | |
| | if not self.last_hn_dataset: |
| | raise gr.Error("Dataset generation failed (Empty dataset).") |
| |
|
| | def semantic_search_fn() -> str: |
| | return get_top_hits(model=self.model, target_titles=self.titles, task_name=self.config.TASK_NAME, query=self.config.QUERY_ANCHOR) |
| |
|
| | result = "### Search (Before):\n" + f"{semantic_search_fn()}\n\n" |
| | print(f"[{self.session_id}] Starting Training with {len(self.last_hn_dataset)} examples...") |
| | |
| | train_with_dataset( |
| | model=self.model, |
| | dataset=self.last_hn_dataset, |
| | output_dir=self.output_dir, |
| | task_name=self.config.TASK_NAME, |
| | search_fn=semantic_search_fn |
| | ) |
| | |
| | self._update_vibe_checker() |
| | print(f"[{self.session_id}] Training Complete.") |
| |
|
| | result += "### Search (After):\n" + f"{semantic_search_fn()}" |
| | return result |
| |
|
| | def is_model_tuned(self) -> bool: |
| | return True if self.last_hn_dataset else False |
| |
|
| | |
| | def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update]: |
| | model_name = "<unsaved>" |
| | if self.last_hn_dataset: |
| | model_name = f"./{self.output_dir}" |
| |
|
| | info_text = (f"**Session:** {self.session_id[:6]}<br>" |
| | f"**Base Model:** `{self.config.MODEL_NAME}`<br>" |
| | f"**Tuned Model:** `{model_name}`") |
| |
|
| | if not self.vibe_checker: |
| | return "N/A", "Model Loading...", gr.update(value=self._generate_vibe_css("gray")), info_text |
| | if not news_text or len(news_text.split()) < 3: |
| | return "N/A", "Text too short", gr.update(value=self._generate_vibe_css("gray")), info_text |
| |
|
| | try: |
| | vibe_result = self.vibe_checker.check(news_text) |
| | status = vibe_result.status_html.split('>')[1].split('<')[0] |
| | return f"{vibe_result.raw_score:.4f}", status, gr.update(value=self._generate_vibe_css(vibe_result.color_hsl)), info_text |
| | except Exception as e: |
| | return "N/A", f"Error: {e}", gr.update(value=self._generate_vibe_css("gray")), info_text |
| |
|
| | def _generate_vibe_css(self, color: str) -> str: |
| | """Generates a style block to update the Mood Lamp textbox background.""" |
| | return f"<style>#mood_lamp input {{ background-color: {color} !important; transition: background-color 0.5s ease; }}</style>" |
| |
|
| | |
| | def fetch_and_display_mood_feed(self) -> str: |
| | if not self.vibe_checker: |
| | return "Model not ready. Please wait or reload." |
| | |
| | feed, status = read_hacker_news_rss(self.config) |
| | if not feed or not feed.entries: |
| | return f"**Feed Error:** {status}" |
| |
|
| | scored_entries = [] |
| | for entry in feed.entries: |
| | title = entry.get('title') |
| | if not title: continue |
| | |
| | vibe_result = self.vibe_checker.check(title) |
| | scored_entries.append({ |
| | "title": title, |
| | "link": entry.get('link', '#'), |
| | "comments": entry.get('comments', '#'), |
| | "published": format_published_time(entry.published_parsed), |
| | "mood": vibe_result |
| | }) |
| |
|
| | scored_entries.sort(key=lambda x: x["mood"].raw_score, reverse=True) |
| |
|
| | model_name = "<unsaved>" |
| | if self.last_hn_dataset: |
| | model_name = f"./{self.output_dir}" |
| |
|
| | md = (f"## Hacker News Top Stories\n" |
| | f"**Session:** {self.session_id[:6]}<br>" |
| | f"**Base Model:** `{self.config.MODEL_NAME}`<br>" |
| | f"**Tuned Model:** `{model_name}`<br>" |
| | f"**Updated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" |
| | "| Vibe | Score | Title | Comments | Published |\n|---|---|---|---|---|\n") |
| | |
| | for item in scored_entries: |
| | md += (f"| {item['mood'].status_html} " |
| | f"| {item['mood'].raw_score:.4f} " |
| | f"| [{item['title']}]({item['link']}) " |
| | f"| [Comments]({item['comments']}) " |
| | f"| {item['published']} |\n") |
| | return md |
| |
|