# Page-capture residue (hosting status banner read "Spaces: Sleeping"); not part of the program.
import os | |
import re | |
import logging | |
import uuid | |
import time | |
from datetime import datetime, timezone, timedelta | |
from collections import defaultdict | |
from typing import Optional, Dict, Any, List | |
import asyncio | |
import subprocess | |
import json | |
import tempfile | |
from fastapi import FastAPI, HTTPException, Body, BackgroundTasks, Path, Request | |
from fastapi.responses import StreamingResponse | |
from pydantic import BaseModel, Field | |
import openai # For your custom API | |
import google.generativeai as genai # For Gemini API | |
from google.generativeai.types import GenerationConfig | |
# --- Imports for YouTube Transcript API --- | |
# Note: These are not directly used in the yt-dlp path but are good to keep if you ever add fallback methods. | |
from youtube_transcript_api import TranscriptsDisabled, NoTranscriptFound, CouldNotRetrieveTranscript | |
# --- Logging Configuration --- | |
# Configure the root logger once at import time: INFO threshold, timestamped records.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# Module-scoped logger used by every function in this file.
logger = logging.getLogger(__name__)
# --- Configuration --- | |
# Fallbacks for the OpenAI-compatible endpoint; the CUSTOM_API_BASE_URL and
# CUSTOM_API_MODEL environment variables take precedence where they are read.
CUSTOM_API_BASE_URL_DEFAULT = "https://api-q3ieh5raqfuad9o8.aistudio-app.com/v1"
CUSTOM_API_MODEL_DEFAULT = "gemma3:27b"

# Gemini model used when a request does not name one.
DEFAULT_GEMINI_MODEL = "gemini-1.5-flash-latest"

# Timeouts, in seconds, for the Gemini call and for the summary call.
GEMINI_REQUEST_TIMEOUT_SECONDS = 5 * 60
SUMMARY_REQUEST_TIMEOUT_SECONDS = 3 * 60

# Transcripts longer than this many characters are truncated before use.
MAX_TRANSCRIPT_CHARS = 750_000

# Cookies file handed to yt-dlp via --cookies.
COOKIES_FILE_PATH = "private.txt"

# --- In-Memory Task Storage ---
# Maps task_id -> task record (status, result, error, timestamps, request params).
tasks_db: Dict[str, Dict[str, Any]] = dict()
# --- Pydantic Models --- | |
class ChatPayload(BaseModel):
    """Request body accepted by the direct chat handler."""

    # Prompt text, sent to the custom API as a single user message.
    message: str
    # Sampling temperature; validated to lie in [0.0, 1.0], 0.6 when omitted.
    temperature: float = Field(default=0.6, ge=0.0, le=1.0)
class GeminiTaskRequest(BaseModel):
    """Request body accepted by the task-submission handler."""

    # Prompt text; also used as the instruction placed ahead of a fetched transcript.
    message: str
    # Optional URL; YouTube URLs trigger the transcript path before any Gemini fallback.
    url: Optional[str] = Field(default=None)
    # Optional Gemini model name; DEFAULT_GEMINI_MODEL is used when absent.
    gemini_model: Optional[str] = Field(default=None)
    # Optional per-request key; the GEMINI_API_KEY environment variable is the fallback.
    api_key: Optional[str] = Field(
        default=None,
        description="Gemini API Key (optional; uses Space secret if not provided)",
    )
class TaskSubmissionResponse(BaseModel):
    """Response body returned by the task-submission handler."""
    # Identifier under which the task record is stored in tasks_db.
    task_id: str
    # Task state at response time; the submission handler sends "COMPLETED" or "PENDING".
    status: str
    # URL of the status route for this task, built with Request.url_for.
    task_detail_url: str
class TaskStatusResponse(BaseModel):
    """Snapshot of one task record as returned by the status handler."""

    # Identifier the caller asked about.
    task_id: str
    # Current state string copied from the stored task record.
    status: str
    # When the task record was created.
    submitted_at: datetime
    # When the task record was last modified.
    last_updated_at: datetime
    # Generated text, present once the task has produced output.
    result: Optional[str] = Field(default=None)
    # Failure description, present when the task has recorded an error.
    error: Optional[str] = Field(default=None)
# --- RateLimiter Class --- | |
class RateLimiter:
    """In-memory sliding-window request limiter keyed by a client identifier.

    Each key maps to the timestamps of its accepted requests. A request is
    rejected once ``max_requests`` accepted requests already fall inside the
    trailing ``time_window``.

    Rejected requests are NOT recorded. Recording them would keep pushing the
    window forward for a client that retries, so it could never recover, and
    it would let a single abusive client grow its timestamp list without bound.
    Keys whose timestamps have all expired are removed so idle clients do not
    accumulate in memory.
    """

    def __init__(self, max_requests: int, time_window: timedelta):
        """Create a limiter allowing ``max_requests`` per ``time_window`` per key."""
        self.max_requests = max_requests
        self.time_window = time_window
        self.requests: Dict[str, list] = defaultdict(list)

    def _cleanup_old_requests(self, user_ip: str) -> None:
        """Drop timestamps for ``user_ip`` that have aged out of the window."""
        cutoff = time.time() - self.time_window.total_seconds()
        # .get() avoids creating an empty defaultdict entry for unknown keys.
        recent = [stamp for stamp in self.requests.get(user_ip, ()) if stamp > cutoff]
        if recent:
            self.requests[user_ip] = recent
        else:
            self.requests.pop(user_ip, None)

    def is_rate_limited(self, user_ip: str) -> bool:
        """Return True if ``user_ip`` is over its quota; otherwise record the request.

        Returns:
            False when the request is accepted (and counted), True when it is
            rejected (and not counted).
        """
        self._cleanup_old_requests(user_ip)
        if len(self.requests.get(user_ip, ())) >= self.max_requests:
            return True
        self.requests[user_ip].append(time.time())
        return False

    def get_current_count(self, user_ip: str) -> int:
        """Return how many accepted requests ``user_ip`` has inside the window."""
        self._cleanup_old_requests(user_ip)
        return len(self.requests.get(user_ip, ()))
# Shared limiter used by the direct chat handler: quota of 15 per client key per one-day window.
rate_limiter = RateLimiter(
    time_window=timedelta(days=1),
    max_requests=15,
)
def get_user_ip(request: Request) -> str:
    """Return the best available client address for ``request``.

    The first entry of the ``X-Forwarded-For`` header is preferred, with
    surrounding whitespace removed so that "1.2.3.4" and " 1.2.3.4" map to the
    same rate-limit key. If the header is missing or its first entry is empty,
    the connection's peer address is used instead.

    Returns:
        The client address, or the literal ``"unknown"`` when the request
        carries no peer information (``request.client`` is ``None``).

    NOTE(review): ``X-Forwarded-For`` is client-controlled unless a trusted
    proxy overwrites it, so callers can choose their own rate-limit key.
    Confirm the deployment sits behind such a proxy.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        first_hop = forwarded.split(",")[0].strip()
        if first_hop:
            return first_hop
    client = request.client
    # request.client can be None, e.g. for some test clients or socket types.
    return client.host if client is not None else "unknown"
# --- FastAPI App Initialization --- | |
# Application object for this service; the metadata below is passed to FastAPI as-is.
app = FastAPI(
    version="3.0.0",
    title="Dual Chat & Async Gemini API with YouTube Transcript Summarizer",
    description="Made by Cody from chrunos.com. Fetches YouTube transcripts for summarization, with a Gemini fallback.",
)
# --- Helper Functions --- | |
def is_youtube_url(url: Optional[str]) -> bool:
    """Return True if ``url`` looks like a YouTube video URL.

    Accepted forms (scheme optional): ``youtube.com``, ``youtu.be`` and
    ``youtube-nocookie.com`` hosts, optionally prefixed by ``www.``, ``m.`` or
    ``music.``, followed by a path that contains an 11-character token either
    directly, after ``watch?v=`` / ``embed/`` / ``v/`` / ``shorts/``, or after
    a ``v=`` query parameter in any position.

    Leading and trailing whitespace is ignored. ``None`` and empty strings
    return False.
    """
    if not url:
        return False
    youtube_regex = (
        r'(https?://)?((www|m|music)\.)?'
        r'(youtube|youtu|youtube-nocookie)\.(com|be)/'
        # `.+[?&]v=` also covers URLs where `v` is not the first query parameter.
        r'(watch\?v=|embed/|v/|shorts/|.+[?&]v=)?([^&=%\?]{11})'
    )
    return re.match(youtube_regex, url.strip()) is not None
def extract_video_id(url: str) -> Optional[str]:
    """Extract the 11-character video ID from a YouTube URL.

    The scheme and host are removed before matching so that hostname text
    (for example ``youtube-noc`` in ``youtube-nocookie.com``) can never be
    mistaken for an ID. Patterns are tried from most to least specific:

    1. a ``v=`` query parameter,
    2. an ``embed/``, ``v/``, ``shorts/`` or ``live/`` path segment,
    3. any path segment that starts with an ID (covers ``youtu.be/<id>``).

    A candidate must be exactly 11 ID characters and must not be followed by
    another ID character, so longer tokens such as channel IDs are not
    truncated into a bogus video ID.

    Returns:
        The video ID, or None if ``url`` is not a YouTube URL or no ID is found.
    """
    if not is_youtube_url(url):
        return None

    without_scheme = url.strip().split("://", 1)[-1]
    slash_at = without_scheme.find("/")
    path_and_query = without_scheme[slash_at:] if slash_at != -1 else ""

    id_char = r"[0-9A-Za-z_-]"
    video_id = rf"({id_char}{{11}})(?!{id_char})"
    patterns = (
        rf"[?&]v={video_id}",
        rf"/(?:embed|v|shorts|live)/{video_id}",
        rf"/{video_id}",
    )
    for pattern in patterns:
        match = re.search(pattern, path_and_query)
        if match:
            return match.group(1)
    logger.warning(f"Could not extract YouTube video ID from URL: {url}")
    return None
def parse_vtt_content(vtt_content: str) -> str:
    """
    Parse VTT subtitle content and return only the spoken text.

    Removes the WEBVTT header block (including metadata lines such as
    ``Kind:`` / ``Language:``), NOTE lines, cue identifiers, timestamp lines,
    inline tags and duplicate lines. HTML entities are decoded (``&amp;`` ->
    ``&``) rather than blanked out, and a bare ``&`` in ordinary text is left
    untouched. The de-duplication is designed for auto-generated captions that
    build up line by line, where each new line repeats the previous one plus
    extra words.

    Returns:
        The transcript as a single whitespace-normalised string; empty if no
        text is found.
    """
    import html  # local import: only needed for entity decoding here

    # A BOM or leading blank lines would otherwise hide the WEBVTT signature.
    content = vtt_content.lstrip('\ufeff \t\r\n')
    # The header runs from "WEBVTT" to the first blank line (or first cue timing).
    in_header = content.startswith('WEBVTT')

    text_lines: List[str] = []
    seen_lines = set()
    for raw_line in content.split('\n'):
        line = raw_line.strip()
        if in_header:
            if line and '-->' not in line:
                continue  # header metadata, not caption text
            in_header = False
        # Skip empty lines, WEBVTT header, NOTE lines, timestamp lines, and cue settings
        if (not line or line.startswith(('WEBVTT', 'NOTE')) or '-->' in line
                or line.isdigit() or 'align:' in line or 'position:' in line):
            continue
        # Strip tags first so that escaped markup (&lt;b&gt;) survives as literal text.
        clean_line = re.sub(r'<[^>]+>', '', line)
        clean_line = html.unescape(clean_line)
        clean_line = re.sub(r'\s+', ' ', clean_line).strip()
        # Skip if line is empty after cleaning or if it's an exact duplicate
        if not clean_line or clean_line in seen_lines:
            continue
        # Rolling captions: the new line contains the previous one (or vice versa).
        if text_lines and (clean_line in text_lines[-1] or text_lines[-1] in clean_line):
            # Keep whichever of the two is longer.
            if len(clean_line) > len(text_lines[-1]):
                seen_lines.discard(text_lines[-1])
                text_lines[-1] = clean_line
                seen_lines.add(clean_line)
            continue
        seen_lines.add(clean_line)
        text_lines.append(clean_line)
    full_text = ' '.join(text_lines)
    return re.sub(r'\s+', ' ', full_text).strip()
async def get_transcript_with_yt_dlp_cookies(video_id: str, task_id: str) -> Optional[str]:
    """
    Download English VTT subtitles for ``video_id`` with yt-dlp and return their text.

    yt-dlp is run in a worker thread with the cookies file at COOKIES_FILE_PATH,
    writing into a temporary directory. The resulting ``<video_id>.en.vtt`` file
    is parsed with parse_vtt_content and truncated to MAX_TRANSCRIPT_CHARS.

    Returns:
        The transcript text, or None when the cookies file is missing, yt-dlp
        fails or times out, no subtitle file is produced, the file yields no
        text, or any other error occurs (all failures are logged).
    """
    logger.info(f"[Task {task_id}] Attempting transcript fetch for video ID: {video_id} using yt-dlp.")

    cookies_present = os.path.exists(COOKIES_FILE_PATH)
    if not cookies_present:
        logger.error(f"[Task {task_id}] Cookies file not found at {COOKIES_FILE_PATH}. Cannot fetch transcript.")
        return None

    try:
        watch_url = f"https://www.youtube.com/watch?v={video_id}"
        with tempfile.TemporaryDirectory() as workdir:
            output_template = os.path.join(workdir, "%(id)s.%(ext)s")
            yt_dlp_args = [
                "yt-dlp",
                "--skip-download",
                "--write-auto-subs",
                "--write-subs",
                "--sub-lang", "en",
                "--sub-format", "vtt",
                "--cookies", COOKIES_FILE_PATH,
                "-o", output_template,
                watch_url,
            ]
            logger.info(f"[Task {task_id}] Running yt-dlp command...")
            # subprocess.run blocks, so it is pushed to a thread to keep the event loop free.
            completed = await asyncio.to_thread(
                subprocess.run, yt_dlp_args, capture_output=True, text=True, timeout=60
            )
            if completed.returncode != 0:
                logger.error(f"[Task {task_id}] yt-dlp failed. Stderr: {completed.stderr}")
                return None

            subtitle_file_path = os.path.join(workdir, f"{video_id}.en.vtt")
            if not os.path.exists(subtitle_file_path):
                logger.warning(f"[Task {task_id}] Subtitle file not found at {subtitle_file_path}.")
                return None
            logger.info(f"[Task {task_id}] Found subtitle file: {os.path.basename(subtitle_file_path)}")

            with open(subtitle_file_path, 'r', encoding='utf-8') as vtt_file:
                raw_vtt = vtt_file.read()

            transcript_text = parse_vtt_content(raw_vtt)
            if not transcript_text:
                logger.warning(f"[Task {task_id}] No text extracted from VTT file.")
                return None
            logger.info(f"[Task {task_id}] Transcript fetched successfully. Length: {len(transcript_text)}")

            too_long = len(transcript_text) > MAX_TRANSCRIPT_CHARS
            if too_long:
                logger.warning(f"[Task {task_id}] Truncating transcript from {len(transcript_text)} to {MAX_TRANSCRIPT_CHARS} chars.")
                return transcript_text[:MAX_TRANSCRIPT_CHARS]
            return transcript_text
    except subprocess.TimeoutExpired:
        logger.error(f"[Task {task_id}] yt-dlp command timed out for video ID: {video_id}")
        return None
    except Exception as e:
        logger.error(f"[Task {task_id}] Error fetching transcript with yt-dlp: {e}", exc_info=True)
        return None
# --- Internal Business Logic --- | |
async def get_summary_from_custom_api(message: str, temperature: float = 0.6) -> Optional[str]:
    """
    Send ``message`` to the custom OpenAI-compatible API and return the full reply.

    The key comes from CUSTOM_API_SECRET_KEY; base URL and model come from
    CUSTOM_API_BASE_URL / CUSTOM_API_MODEL with module-level defaults. The
    request is non-streaming because the caller needs the complete text.

    Returns:
        The stripped response text, or None if the key is not configured, the
        API returns no content, or the call raises (failures are logged).
    """
    api_key = os.getenv("CUSTOM_API_SECRET_KEY")
    base_url = os.getenv("CUSTOM_API_BASE_URL", CUSTOM_API_BASE_URL_DEFAULT)
    model_name = os.getenv("CUSTOM_API_MODEL", CUSTOM_API_MODEL_DEFAULT)

    if not api_key:
        logger.error("Custom API key ('CUSTOM_API_SECRET_KEY') is not configured for internal summary generation.")
        return None

    try:
        logger.info(f"Requesting summary from Custom API ({base_url}) with model {model_name}.")
        from openai import AsyncOpenAI

        client = AsyncOpenAI(
            api_key=api_key,
            base_url=base_url,
            timeout=SUMMARY_REQUEST_TIMEOUT_SECONDS,
        )
        completion = await client.chat.completions.create(
            model=model_name,
            temperature=temperature,
            messages=[{"role": "user", "content": message}],
            stream=False,  # We need the full response for a summary
        )

        # Walk choices -> message -> content, tolerating any missing level.
        choices = completion.choices
        first_message = choices[0].message if choices else None
        summary = first_message.content if first_message else None
        if not summary:
            logger.warning("Custom API call for summary returned no content.")
            return None

        logger.info(f"Successfully generated summary of length {len(summary)}.")
        return summary.strip()
    except Exception as e:
        logger.error(f"Error during internal Custom API call for summary: {e}", exc_info=True)
        return None
async def process_gemini_request_background(
    task_id: str,
    user_message: str,
    input_url: Optional[str],
    requested_gemini_model: str,
    gemini_key_to_use: str
):
    """
    Fallback background job: send the original request to the Gemini API.

    Used when a transcript cannot be obtained or summarization fails. The task
    record in ``tasks_db`` is moved to PROCESSING, then to COMPLETED (with
    ``result``) or FAILED (with ``error``). ``last_updated_at`` is refreshed on
    every exit path.

    Args:
        task_id: Key of an existing record in ``tasks_db``.
        user_message: Prompt text sent as the first content part.
        input_url: Optional URL attached as a ``file_data`` part.
        requested_gemini_model: Gemini model name to instantiate.
        gemini_key_to_use: API key passed to ``genai.configure``.

    NOTE(review): ``genai.configure`` appears to set process-wide state, so
    concurrent tasks using different keys may interfere with each other.
    Confirm against the SDK.
    """
    logger.info(f"[Task {task_id}] Starting background Gemini processing. Model: {requested_gemini_model}, URL: {input_url}")
    tasks_db[task_id]["status"] = "PROCESSING"
    tasks_db[task_id]["last_updated_at"] = datetime.now(timezone.utc)
    try:
        genai.configure(api_key=gemini_key_to_use)
        model_instance = genai.GenerativeModel(model_name=requested_gemini_model)
        # This function now primarily handles the non-transcript case
        content_parts = [{"text": user_message}]
        if input_url:
            logger.info(f"[Task {task_id}] Providing Gemini with the URL directly: {input_url}")
            content_parts.append({
                "file_data": {
                    "mime_type": "video/youtube",  # Assuming Gemini can handle it
                    "file_uri": input_url
                }
            })
        gemini_contents = [{"parts": content_parts}]
        generation_config = GenerationConfig(candidate_count=1)
        request_options = {"timeout": GEMINI_REQUEST_TIMEOUT_SECONDS}
        logger.info(f"[Task {task_id}] Sending request to Gemini API...")
        response = await model_instance.generate_content_async(
            gemini_contents, stream=False, generation_config=generation_config, request_options=request_options
        )

        # The SDK's `text` / `parts` quick accessors raise ValueError when the
        # response has no usable candidate (e.g. a blocked prompt). Catch that
        # here so the block-reason branch below can run, instead of the error
        # falling through to the generic handler.
        try:
            full_response_text = response.text or ''
        except (AttributeError, ValueError):
            full_response_text = ''
        if not full_response_text:
            try:
                full_response_text = ''.join(
                    part.text for part in response.parts if hasattr(part, 'text')
                )
            except (AttributeError, ValueError):
                full_response_text = ''

        prompt_feedback = getattr(response, 'prompt_feedback', None)
        block_reason_obj = getattr(prompt_feedback, 'block_reason', None)

        if not full_response_text and block_reason_obj:
            block_reason = getattr(block_reason_obj, 'name', str(block_reason_obj))
            logger.warning(f"[Task {task_id}] Gemini content blocked: {block_reason}")
            tasks_db[task_id]["status"] = "FAILED"
            tasks_db[task_id]["error"] = f"Content blocked by Gemini due to: {block_reason}"
        elif full_response_text:
            logger.info(f"[Task {task_id}] Gemini processing successful. Result length: {len(full_response_text)}")
            tasks_db[task_id]["status"] = "COMPLETED"
            tasks_db[task_id]["result"] = full_response_text
        else:
            logger.warning(f"[Task {task_id}] Gemini returned no content and no block reason.")
            tasks_db[task_id]["status"] = "FAILED"
            tasks_db[task_id]["error"] = "Gemini returned no content."
    except Exception as e:
        logger.error(f"[Task {task_id}] Error during Gemini background processing: {e}", exc_info=True)
        tasks_db[task_id]["status"] = "FAILED"
        tasks_db[task_id]["error"] = str(e)
    finally:
        tasks_db[task_id]["last_updated_at"] = datetime.now(timezone.utc)
# --- API Endpoints --- | |
async def direct_chat(payload: ChatPayload, request: Request):
    """
    Stream a chat completion from the custom API as plain text.

    The caller is rate-limited by client address first (HTTP 429 when over
    quota). A missing CUSTOM_API_SECRET_KEY yields HTTP 500. Errors raised
    while streaming are logged and reported as a final text chunk, because the
    response has already started by then.
    """
    logger.info(f"Direct chat request. Message: '{payload.message[:50]}...'")

    caller_ip = get_user_ip(request)
    if rate_limiter.is_rate_limited(caller_ip):
        raise HTTPException(
            status_code=429,
            detail={"error": "Rate limit exceeded. Please try again tomorrow."},
        )

    api_key = os.getenv("CUSTOM_API_SECRET_KEY")
    base_url = os.getenv("CUSTOM_API_BASE_URL", CUSTOM_API_BASE_URL_DEFAULT)
    model_name = os.getenv("CUSTOM_API_MODEL", CUSTOM_API_MODEL_DEFAULT)
    if not api_key:
        logger.error("'CUSTOM_API_SECRET_KEY' is not configured for /chat.")
        raise HTTPException(status_code=500, detail="API key not configured.")

    async def custom_api_streamer():
        """Yield content deltas from the upstream stream as they arrive."""
        try:
            from openai import AsyncOpenAI

            client = AsyncOpenAI(api_key=api_key, base_url=base_url, timeout=60.0)
            upstream = await client.chat.completions.create(
                model=model_name,
                temperature=payload.temperature,
                messages=[{"role": "user", "content": payload.message}],
                stream=True,
            )
            async for piece in upstream:
                if not piece.choices:
                    continue
                delta = piece.choices[0].delta
                if delta and delta.content:
                    yield delta.content
        except Exception as e:
            logger.error(f"Error streaming from Custom API: {e}", exc_info=True)
            yield f"Error processing with Custom API: {str(e)}"

    return StreamingResponse(custom_api_streamer(), media_type="text/plain")
async def submit_gemini_task(
    request_data: GeminiTaskRequest,
    background_tasks: BackgroundTasks,
    http_request: Request
):
    """
    Create a task and resolve it via transcript summarization or Gemini.

    Primary path: if ``request_data.url`` is a YouTube URL with an extractable
    video ID, the transcript is fetched and summarized inline by the custom
    API; on success the task is stored as COMPLETED and returned immediately.

    Fallback path: otherwise (or if either step fails) the request is queued
    for background Gemini processing and the task is returned as PENDING.

    Raises:
        HTTPException: 400 if the fallback is needed but no Gemini API key is
            available from the request or the GEMINI_API_KEY environment
            variable. The task record is removed first, since the caller never
            receives its ID and it would otherwise stay PENDING forever.
    """
    task_id = str(uuid.uuid4())
    logger.info(f"Received task {task_id}. URL: {request_data.url}")
    # Initialize the task in the database
    created_at = datetime.now(timezone.utc)
    tasks_db[task_id] = {
        "status": "PENDING", "result": None, "error": None,
        "submitted_at": created_at,
        "last_updated_at": created_at,
        "request_params": request_data.model_dump()
    }

    # --- Primary Path: YouTube Transcript Summarization ---
    video_id = None
    if request_data.url and is_youtube_url(request_data.url):
        video_id = extract_video_id(request_data.url)
    if video_id:
        transcript_text = await get_transcript_with_yt_dlp_cookies(video_id, task_id)
        if transcript_text:
            logger.info(f"[Task {task_id}] Transcript found. Proceeding with summarization.")
            # Use the user's message as a prompt for the transcript
            summarization_prompt = f"{request_data.message}\n\nVideo Transcript:\n{transcript_text}"
            summary_text = await get_summary_from_custom_api(summarization_prompt)
            if summary_text:
                logger.info(f"[Task {task_id}] Summarization successful. Task complete.")
                tasks_db[task_id]["status"] = "COMPLETED"
                tasks_db[task_id]["result"] = summary_text
                tasks_db[task_id]["last_updated_at"] = datetime.now(timezone.utc)
                return TaskSubmissionResponse(
                    task_id=task_id,
                    status="COMPLETED",
                    task_detail_url=str(http_request.url_for('get_gemini_task_status', task_id=task_id))
                )
            else:
                logger.warning(f"[Task {task_id}] Summarization via custom API failed.")
                # Fall through to Gemini fallback
        else:
            logger.warning(f"[Task {task_id}] Transcript fetch failed.")
            # Fall through to Gemini fallback

    # --- Fallback Path: Gemini Background Processing ---
    logger.info(f"[Task {task_id}] No transcript or summarization failed. Falling back to background Gemini task.")
    gemini_key_to_use = request_data.api_key or os.getenv("GEMINI_API_KEY")
    if not gemini_key_to_use:
        logger.error(f"[Task {task_id}] Gemini API Key is missing.")
        # Drop the record so rejected submissions do not pile up as PENDING tasks.
        tasks_db.pop(task_id, None)
        raise HTTPException(status_code=400, detail="Gemini API Key is required for the fallback process.")
    requested_model = request_data.gemini_model or DEFAULT_GEMINI_MODEL
    background_tasks.add_task(
        process_gemini_request_background,
        task_id,
        request_data.message,
        request_data.url,
        requested_model,
        gemini_key_to_use
    )
    logger.info(f"[Task {task_id}] Task submitted to background Gemini processing.")
    return TaskSubmissionResponse(
        task_id=task_id,
        status="PENDING",  # It's pending in the background
        task_detail_url=str(http_request.url_for('get_gemini_task_status', task_id=task_id))
    )
async def get_gemini_task_status(task_id: str = Path(..., description="The ID of the task to retrieve")):
    """
    Return the stored state of one task.

    Raises:
        HTTPException: 404 when ``task_id`` has no record in ``tasks_db``.
    """
    record = tasks_db.get(task_id)
    if record:
        return TaskStatusResponse(
            task_id=task_id,
            status=record["status"],
            submitted_at=record["submitted_at"],
            last_updated_at=record["last_updated_at"],
            result=record.get("result"),
            error=record.get("error"),
        )
    raise HTTPException(status_code=404, detail="Task ID not found.")
async def read_root():
    """Return a static message indicating the service is running."""
    status_message = "API for YouTube summarization and Gemini tasks is running."
    return {"message": status_message}