import os
import re
import logging
import uuid
import time
from datetime import datetime, timezone, timedelta
from collections import defaultdict
from typing import Optional, Dict, Any
import asyncio
import subprocess
import tempfile
from fastapi import FastAPI, HTTPException, Body, BackgroundTasks, Path, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from openai import AsyncOpenAI # Async client for the custom OpenAI-compatible API
import google.generativeai as genai # For Gemini API
from google.generativeai.types import GenerationConfig
# --- Imports for YouTube Transcript API ---
# Note: these exception types are not used in the yt-dlp path, but are kept in case fallback methods are added later.
from youtube_transcript_api import TranscriptsDisabled, NoTranscriptFound, CouldNotRetrieveTranscript
# --- Logging Configuration ---
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
# --- Configuration ---
CUSTOM_API_BASE_URL_DEFAULT = "https://api-q3ieh5raqfuad9o8.aistudio-app.com/v1"
CUSTOM_API_MODEL_DEFAULT = "gemma3:27b"
DEFAULT_GEMINI_MODEL = "gemini-1.5-flash-latest"
GEMINI_REQUEST_TIMEOUT_SECONDS = 300
SUMMARY_REQUEST_TIMEOUT_SECONDS = 180 # A separate timeout for the summary task
MAX_TRANSCRIPT_CHARS = 750000
COOKIES_FILE_PATH = "private.txt"
# --- In-Memory Task Storage ---
tasks_db: Dict[str, Dict[str, Any]] = {}
# --- Pydantic Models ---
class ChatPayload(BaseModel):
message: str
temperature: float = Field(0.6, ge=0.0, le=1.0)
class GeminiTaskRequest(BaseModel):
message: str
url: Optional[str] = None
gemini_model: Optional[str] = None
api_key: Optional[str] = Field(None, description="Gemini API Key (optional; uses Space secret if not provided)")
class TaskSubmissionResponse(BaseModel):
task_id: str
status: str
task_detail_url: str
class TaskStatusResponse(BaseModel):
task_id: str
status: str
submitted_at: datetime
last_updated_at: datetime
result: Optional[str] = None
error: Optional[str] = None
# --- RateLimiter Class ---
class RateLimiter:
def __init__(self, max_requests: int, time_window: timedelta):
self.max_requests = max_requests
self.time_window = time_window
self.requests: Dict[str, list] = defaultdict(list)
def _cleanup_old_requests(self, user_ip: str) -> None:
current_time = time.time()
self.requests[user_ip] = [
timestamp for timestamp in self.requests[user_ip]
if current_time - timestamp < self.time_window.total_seconds()
]
    def is_rate_limited(self, user_ip: str) -> bool:
        self._cleanup_old_requests(user_ip)
        current_count = len(self.requests[user_ip])
        current_time = time.time()
        # Every attempt is recorded, so requests rejected here still count toward the window.
        self.requests[user_ip].append(current_time)
        return (current_count + 1) > self.max_requests
def get_current_count(self, user_ip: str) -> int:
self._cleanup_old_requests(user_ip)
return len(self.requests[user_ip])
rate_limiter = RateLimiter(max_requests=15, time_window=timedelta(days=1))
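# Example of the limiter's semantics (illustrative only): with max_requests=15 and a
# one-day window, the 16th call to is_rate_limited() from the same IP within 24 hours
# returns True; older timestamps age out as the sliding window moves forward.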
def get_user_ip(request: Request) -> str:
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0]
return request.client.host
# --- FastAPI App Initialization ---
app = FastAPI(
title="Dual Chat & Async Gemini API with YouTube Transcript Summarizer",
description="Made by Cody from chrunos.com. Fetches YouTube transcripts for summarization, with a Gemini fallback.",
version="3.0.0"
)
# --- Helper Functions ---
def is_youtube_url(url: Optional[str]) -> bool:
if not url:
return False
youtube_regex = (
r'(https?://)?(www\.)?'
r'(youtube|youtu|youtube-nocookie)\.(com|be)/'
r'(watch\?v=|embed/|v/|shorts/|.+\?v=)?([^&=%\?]{11})'
)
return re.match(youtube_regex, url) is not None
def extract_video_id(url: str) -> Optional[str]:
if not is_youtube_url(url):
return None
patterns = [
r'(?:v=|\/)([0-9A-Za-z_-]{11}).*',
r'(?:embed\/|v\/|shorts\/)([0-9A-Za-z_-]{11}).*',
r'youtu\.be\/([0-9A-Za-z_-]{11}).*'
]
for pattern in patterns:
match = re.search(pattern, url)
if match:
return match.group(1)
logger.warning(f"Could not extract YouTube video ID from URL: {url}")
return None
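# Illustrative examples of the URL formats handled above (video ID is arbitrary):
#   extract_video_id("https://www.youtube.com/watch?v=dQw4w9WgXcQ")   -> "dQw4w9WgXcQ"
#   extract_video_id("https://youtu.be/dQw4w9WgXcQ")                  -> "dQw4w9WgXcQ"
#   extract_video_id("https://www.youtube.com/shorts/dQw4w9WgXcQ")    -> "dQw4w9WgXcQ"
#   extract_video_id("https://example.com/not-youtube")               -> None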
def parse_vtt_content(vtt_content: str) -> str:
"""
Parse VTT subtitle content and extract only the text, removing timestamps,
formatting, and duplicate lines. This parser is designed to handle captions
that build up line-by-line.
"""
lines = vtt_content.split('\n')
text_lines = []
seen_lines = set()
for line in lines:
line = line.strip()
# Skip empty lines, WEBVTT header, NOTE lines, timestamp lines, and cue settings
if (not line or line.startswith('WEBVTT') or line.startswith('NOTE') or
'-->' in line or line.isdigit() or 'align:' in line or 'position:' in line):
continue
# Clean up HTML tags and entities
clean_line = re.sub(r'<[^>]+>', '', line)
clean_line = re.sub(r'&[^;]+;', ' ', clean_line)
clean_line = re.sub(r'\s+', ' ', clean_line).strip()
# Skip if line is empty after cleaning or if it's an exact duplicate
if not clean_line or clean_line in seen_lines:
continue
# This logic helps handle auto-generated captions where a new line
# contains the previous line plus new words.
if text_lines and (clean_line in text_lines[-1] or text_lines[-1] in clean_line):
# If the new line is longer, replace the previous one
if len(clean_line) > len(text_lines[-1]):
seen_lines.discard(text_lines[-1])
text_lines[-1] = clean_line
seen_lines.add(clean_line)
# Otherwise, if it's shorter or a substring, skip it
continue
seen_lines.add(clean_line)
text_lines.append(clean_line)
full_text = ' '.join(text_lines)
return re.sub(r'\s+', ' ', full_text).strip()
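# Minimal sketch of the roll-up behaviour handled above (assumed auto-caption style):
# a cue "so today we are" followed by a cue "so today we are going to talk" collapses
# into the single longer line, so the joined transcript reads
# "so today we are going to talk" rather than repeating the overlapping words.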
async def get_transcript_with_yt_dlp_cookies(video_id: str, task_id: str) -> Optional[str]:
"""
Fetches transcript using yt-dlp with a cookies file.
"""
logger.info(f"[Task {task_id}] Attempting transcript fetch for video ID: {video_id} using yt-dlp.")
if not os.path.exists(COOKIES_FILE_PATH):
logger.error(f"[Task {task_id}] Cookies file not found at {COOKIES_FILE_PATH}. Cannot fetch transcript.")
return None
try:
video_url = f"https://www.youtube.com/watch?v={video_id}"
with tempfile.TemporaryDirectory() as temp_dir:
cmd = [
"yt-dlp",
"--skip-download",
"--write-auto-subs",
"--write-subs",
"--sub-lang", "en",
"--sub-format", "vtt",
"--cookies", COOKIES_FILE_PATH,
"-o", os.path.join(temp_dir, "%(id)s.%(ext)s"),
video_url
]
logger.info(f"[Task {task_id}] Running yt-dlp command...")
result = await asyncio.to_thread(
subprocess.run, cmd, capture_output=True, text=True, timeout=60
)
if result.returncode != 0:
logger.error(f"[Task {task_id}] yt-dlp failed. Stderr: {result.stderr}")
return None
            # Assumes yt-dlp writes the English track as "<video_id>.en.vtt"; a regional
            # variant such as "en-US" would not match and is treated as "not found" below.
            subtitle_file_path = os.path.join(temp_dir, f"{video_id}.en.vtt")
if not os.path.exists(subtitle_file_path):
logger.warning(f"[Task {task_id}] Subtitle file not found at {subtitle_file_path}.")
return None
logger.info(f"[Task {task_id}] Found subtitle file: {os.path.basename(subtitle_file_path)}")
with open(subtitle_file_path, 'r', encoding='utf-8') as f:
subtitle_content = f.read()
transcript_text = parse_vtt_content(subtitle_content)
if not transcript_text:
logger.warning(f"[Task {task_id}] No text extracted from VTT file.")
return None
logger.info(f"[Task {task_id}] Transcript fetched successfully. Length: {len(transcript_text)}")
if len(transcript_text) > MAX_TRANSCRIPT_CHARS:
logger.warning(f"[Task {task_id}] Truncating transcript from {len(transcript_text)} to {MAX_TRANSCRIPT_CHARS} chars.")
return transcript_text[:MAX_TRANSCRIPT_CHARS]
return transcript_text
except subprocess.TimeoutExpired:
logger.error(f"[Task {task_id}] yt-dlp command timed out for video ID: {video_id}")
return None
except Exception as e:
logger.error(f"[Task {task_id}] Error fetching transcript with yt-dlp: {e}", exc_info=True)
return None
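# Standalone usage sketch (assumptions: no event loop is already running, and
# private.txt contains cookies in the Netscape cookies.txt format that yt-dlp's
# --cookies option expects):
#   transcript = asyncio.run(get_transcript_with_yt_dlp_cookies("dQw4w9WgXcQ", "debug-task"))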
# --- Internal Business Logic ---
async def get_summary_from_custom_api(message: str, temperature: float = 0.6) -> Optional[str]:
"""
Calls the custom API (used by /chat) to get a complete response.
This is an internal-facing function designed for non-streaming, complete results.
"""
custom_api_key_secret = os.getenv("CUSTOM_API_SECRET_KEY")
custom_api_base_url = os.getenv("CUSTOM_API_BASE_URL", CUSTOM_API_BASE_URL_DEFAULT)
custom_api_model = os.getenv("CUSTOM_API_MODEL", CUSTOM_API_MODEL_DEFAULT)
if not custom_api_key_secret:
logger.error("Custom API key ('CUSTOM_API_SECRET_KEY') is not configured for internal summary generation.")
return None
try:
logger.info(f"Requesting summary from Custom API ({custom_api_base_url}) with model {custom_api_model}.")
client = AsyncOpenAI(
api_key=custom_api_key_secret,
base_url=custom_api_base_url,
timeout=SUMMARY_REQUEST_TIMEOUT_SECONDS
)
completion = await client.chat.completions.create(
model=custom_api_model,
temperature=temperature,
messages=[{"role": "user", "content": message}],
stream=False # We need the full response for a summary
)
if completion.choices and completion.choices[0].message and completion.choices[0].message.content:
summary = completion.choices[0].message.content
logger.info(f"Successfully generated summary of length {len(summary)}.")
return summary.strip()
else:
logger.warning("Custom API call for summary returned no content.")
return None
except Exception as e:
logger.error(f"Error during internal Custom API call for summary: {e}", exc_info=True)
return None
async def process_gemini_request_background(
task_id: str,
user_message: str,
input_url: Optional[str],
requested_gemini_model: str,
gemini_key_to_use: str
):
"""
The fallback background process that sends the original request to the Gemini API.
This is used when a transcript cannot be obtained or summarization fails.
"""
logger.info(f"[Task {task_id}] Starting background Gemini processing. Model: {requested_gemini_model}, URL: {input_url}")
tasks_db[task_id]["status"] = "PROCESSING"
tasks_db[task_id]["last_updated_at"] = datetime.now(timezone.utc)
try:
genai.configure(api_key=gemini_key_to_use)
model_instance = genai.GenerativeModel(model_name=requested_gemini_model)
# This function now primarily handles the non-transcript case
content_parts = [{"text": user_message}]
if input_url:
logger.info(f"[Task {task_id}] Providing Gemini with the URL directly: {input_url}")
content_parts.append({
"file_data": {
"mime_type": "video/youtube", # Assuming Gemini can handle it
"file_uri": input_url
}
})
gemini_contents = [{"parts": content_parts}]
generation_config = GenerationConfig(candidate_count=1)
request_options = {"timeout": GEMINI_REQUEST_TIMEOUT_SECONDS}
logger.info(f"[Task {task_id}] Sending request to Gemini API...")
response = await model_instance.generate_content_async(
gemini_contents, stream=False, generation_config=generation_config, request_options=request_options
)
full_response_text = getattr(response, 'text', '')
if not full_response_text and hasattr(response, 'parts'):
full_response_text = ''.join(part.text for part in response.parts if hasattr(part, 'text'))
if not full_response_text and response.prompt_feedback and response.prompt_feedback.block_reason:
block_reason = response.prompt_feedback.block_reason.name
logger.warning(f"[Task {task_id}] Gemini content blocked: {block_reason}")
tasks_db[task_id]["status"] = "FAILED"
tasks_db[task_id]["error"] = f"Content blocked by Gemini due to: {block_reason}"
elif full_response_text:
logger.info(f"[Task {task_id}] Gemini processing successful. Result length: {len(full_response_text)}")
tasks_db[task_id]["status"] = "COMPLETED"
tasks_db[task_id]["result"] = full_response_text
else:
logger.warning(f"[Task {task_id}] Gemini returned no content and no block reason.")
tasks_db[task_id]["status"] = "FAILED"
tasks_db[task_id]["error"] = "Gemini returned no content."
except Exception as e:
logger.error(f"[Task {task_id}] Error during Gemini background processing: {e}", exc_info=True)
tasks_db[task_id]["status"] = "FAILED"
tasks_db[task_id]["error"] = str(e)
finally:
tasks_db[task_id]["last_updated_at"] = datetime.now(timezone.utc)
# --- API Endpoints ---
@app.post("/chat", response_class=StreamingResponse)
async def direct_chat(payload: ChatPayload, request: Request):
logger.info(f"Direct chat request. Message: '{payload.message[:50]}...'")
user_ip = get_user_ip(request)
if rate_limiter.is_rate_limited(user_ip):
raise HTTPException(
status_code=429,
detail={"error": "Rate limit exceeded. Please try again tomorrow."}
)
custom_api_key_secret = os.getenv("CUSTOM_API_SECRET_KEY")
custom_api_base_url = os.getenv("CUSTOM_API_BASE_URL", CUSTOM_API_BASE_URL_DEFAULT)
custom_api_model = os.getenv("CUSTOM_API_MODEL", CUSTOM_API_MODEL_DEFAULT)
if not custom_api_key_secret:
logger.error("'CUSTOM_API_SECRET_KEY' is not configured for /chat.")
raise HTTPException(status_code=500, detail="API key not configured.")
async def custom_api_streamer():
try:
client = AsyncOpenAI(
api_key=custom_api_key_secret,
base_url=custom_api_base_url,
timeout=60.0
)
stream = await client.chat.completions.create(
model=custom_api_model,
temperature=payload.temperature,
messages=[{"role": "user", "content": payload.message}],
stream=True
)
async for chunk in stream:
if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
except Exception as e:
logger.error(f"Error streaming from Custom API: {e}", exc_info=True)
yield f"Error processing with Custom API: {str(e)}"
return StreamingResponse(custom_api_streamer(), media_type="text/plain")
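# Example request against /chat (sketch; host and port depend on how the app is served):
#   curl -N -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello there", "temperature": 0.6}'
# The -N flag disables curl's buffering so the plain-text stream is printed as it arrives.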
@app.post("/gemini/submit_task", response_model=TaskSubmissionResponse)
async def submit_gemini_task(
request_data: GeminiTaskRequest,
background_tasks: BackgroundTasks,
http_request: Request
):
task_id = str(uuid.uuid4())
logger.info(f"Received task {task_id}. URL: {request_data.url}")
# Initialize the task in the database
tasks_db[task_id] = {
"status": "PENDING", "result": None, "error": None,
"submitted_at": datetime.now(timezone.utc),
"last_updated_at": datetime.now(timezone.utc),
"request_params": request_data.model_dump()
}
# --- Primary Path: YouTube Transcript Summarization ---
video_id = None
if request_data.url and is_youtube_url(request_data.url):
video_id = extract_video_id(request_data.url)
if video_id:
transcript_text = await get_transcript_with_yt_dlp_cookies(video_id, task_id)
if transcript_text:
logger.info(f"[Task {task_id}] Transcript found. Proceeding with summarization.")
# Use the user's message as a prompt for the transcript
summarization_prompt = f"{request_data.message}\n\nVideo Transcript:\n{transcript_text}"
summary_text = await get_summary_from_custom_api(summarization_prompt)
if summary_text:
logger.info(f"[Task {task_id}] Summarization successful. Task complete.")
tasks_db[task_id]["status"] = "COMPLETED"
tasks_db[task_id]["result"] = summary_text
tasks_db[task_id]["last_updated_at"] = datetime.now(timezone.utc)
return TaskSubmissionResponse(
task_id=task_id,
status="COMPLETED",
task_detail_url=str(http_request.url_for('get_gemini_task_status', task_id=task_id))
)
else:
logger.warning(f"[Task {task_id}] Summarization via custom API failed.")
# Fall through to Gemini fallback
else:
logger.warning(f"[Task {task_id}] Transcript fetch failed.")
# Fall through to Gemini fallback
# --- Fallback Path: Gemini Background Processing ---
logger.info(f"[Task {task_id}] No transcript or summarization failed. Falling back to background Gemini task.")
gemini_key_to_use = request_data.api_key or os.getenv("GEMINI_API_KEY")
if not gemini_key_to_use:
logger.error(f"[Task {task_id}] Gemini API Key is missing.")
raise HTTPException(status_code=400, detail="Gemini API Key is required for the fallback process.")
requested_model = request_data.gemini_model or DEFAULT_GEMINI_MODEL
background_tasks.add_task(
process_gemini_request_background,
task_id,
request_data.message,
request_data.url,
requested_model,
gemini_key_to_use
)
logger.info(f"[Task {task_id}] Task submitted to background Gemini processing.")
return TaskSubmissionResponse(
task_id=task_id,
status="PENDING", # It's pending in the background
task_detail_url=str(http_request.url_for('get_gemini_task_status', task_id=task_id))
)
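# Example flow for the async endpoint (sketch; host and port are assumptions):
#   1. Submit a task:
#        curl -X POST http://localhost:7860/gemini/submit_task \
#             -H "Content-Type: application/json" \
#             -d '{"message": "Summarize this video.", "url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ"}'
#      If the transcript + custom-API summary succeeds, the response already reports
#      status "COMPLETED"; otherwise it is "PENDING" while the Gemini fallback runs.
#   2. Poll the task until it is COMPLETED or FAILED:
#        curl http://localhost:7860/gemini/task/<task_id>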
@app.get("/gemini/task/{task_id}", response_model=TaskStatusResponse)
async def get_gemini_task_status(task_id: str = Path(..., description="The ID of the task to retrieve")):
task = tasks_db.get(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task ID not found.")
return TaskStatusResponse(
task_id=task_id,
status=task["status"],
submitted_at=task["submitted_at"],
last_updated_at=task["last_updated_at"],
result=task.get("result"),
error=task.get("error")
)
@app.get("/")
async def read_root():
return {"message": "API for YouTube summarization and Gemini tasks is running."}