EL GHAFRAOUI AYOUB committed on
Commit 54f5afe · 1 Parent(s): 70ba739
.cruft.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "template": "https://github.com/superlinear-ai/poetry-cookiecutter",
+   "commit": "b7f2fb0f123aae0a01d2ab015db31f52d2d8cc21",
+   "checkout": null,
+   "context": {
+     "cookiecutter": {
+       "project_type": "package",
+       "project_name": "RAGLite",
+       "project_description": "A Python toolkit for Retrieval-Augmented Generation (RAG) with SQLite or PostgreSQL.",
+       "project_url": "https://github.com/superlinear-ai/raglite",
+       "author_name": "Laurent Sorber",
+       "author_email": "[email protected]",
+       "python_version": "3.10",
+       "development_environment": "strict",
+       "with_conventional_commits": "1",
+       "with_fastapi_api": "0",
+       "with_typer_cli": "0",
+       "continuous_integration": "GitHub",
+       "private_package_repository_name": "",
+       "private_package_repository_url": "",
+       "__docker_image": "python:$PYTHON_VERSION-slim",
+       "__docstring_style": "NumPy",
+       "__project_name_kebab_case": "raglite",
+       "__project_name_snake_case": "raglite",
+       "_template": "https://github.com/superlinear-ai/poetry-cookiecutter"
+     }
+   },
+   "directory": null
+ }
.devcontainer/devcontainer.json ADDED
@@ -0,0 +1,65 @@
+ {
+   "name": "raglite",
+   "dockerComposeFile": "../docker-compose.yml",
+   "service": "devcontainer",
+   "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}/",
+   "remoteUser": "user",
+   "overrideCommand": true,
+   "postStartCommand": "cp --update /opt/build/poetry/poetry.lock /workspaces/${localWorkspaceFolderBasename}/ && mkdir -p /workspaces/${localWorkspaceFolderBasename}/.git/hooks/ && cp --update /opt/build/git/* /workspaces/${localWorkspaceFolderBasename}/.git/hooks/",
+   "customizations": {
+     "vscode": {
+       "extensions": [
+         "charliermarsh.ruff",
+         "GitHub.vscode-github-actions",
+         "GitHub.vscode-pull-request-github",
+         "ms-python.mypy-type-checker",
+         "ms-python.python",
+         "ms-toolsai.jupyter",
+         "ryanluker.vscode-coverage-gutters",
+         "tamasfe.even-better-toml",
+         "visualstudioexptteam.vscodeintellicode"
+       ],
+       "settings": {
+         "coverage-gutters.coverageFileNames": [
+           "reports/coverage.xml"
+         ],
+         "editor.codeActionsOnSave": {
+           "source.fixAll": "explicit",
+           "source.organizeImports": "explicit"
+         },
+         "editor.formatOnSave": true,
+         "[python]": {
+           "editor.defaultFormatter": "charliermarsh.ruff"
+         },
+         "[toml]": {
+           "editor.formatOnSave": false
+         },
+         "editor.rulers": [
+           100
+         ],
+         "files.autoSave": "onFocusChange",
+         "jupyter.kernels.excludePythonEnvironments": [
+           "/usr/local/bin/python"
+         ],
+         "mypy-type-checker.importStrategy": "fromEnvironment",
+         "mypy-type-checker.preferDaemon": true,
+         "notebook.codeActionsOnSave": {
+           "notebook.source.fixAll": "explicit",
+           "notebook.source.organizeImports": "explicit"
+         },
+         "notebook.formatOnSave.enabled": true,
+         "python.defaultInterpreterPath": "/opt/raglite-env/bin/python",
+         "python.terminal.activateEnvironment": false,
+         "python.testing.pytestEnabled": true,
+         "ruff.importStrategy": "fromEnvironment",
+         "ruff.logLevel": "warning",
+         "terminal.integrated.defaultProfile.linux": "zsh",
+         "terminal.integrated.profiles.linux": {
+           "zsh": {
+             "path": "/usr/bin/zsh"
+           }
+         }
+       }
+     }
+   }
+ }
.dockerignore ADDED
@@ -0,0 +1,5 @@
+ # Caches
+ .*_cache/
+
+ # Git
+ .git/
.github/dependabot.yml ADDED
@@ -0,0 +1,29 @@
+ version: 2
+
+ updates:
+   - package-ecosystem: github-actions
+     directory: /
+     schedule:
+       interval: monthly
+     commit-message:
+       prefix: "ci"
+       prefix-development: "ci"
+       include: scope
+     groups:
+       ci-dependencies:
+         patterns:
+           - "*"
+   - package-ecosystem: pip
+     directory: /
+     schedule:
+       interval: monthly
+     commit-message:
+       prefix: "chore"
+       prefix-development: "build"
+       include: scope
+     allow:
+       - dependency-type: development
+     versioning-strategy: increase
+     groups:
+       development-dependencies:
+         dependency-type: development
.github/workflows/publish.yml ADDED
@@ -0,0 +1,27 @@
+ name: Publish
+
+ on:
+   release:
+     types:
+       - created
+
+ jobs:
+   publish:
+     runs-on: ubuntu-latest
+
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v4
+
+       - name: Set up Python
+         uses: actions/setup-python@v5
+         with:
+           python-version: "3.10"
+
+       - name: Install Poetry
+         run: pip install --no-input poetry
+
+       - name: Publish package
+         run: |
+           poetry config pypi-token.pypi "${{ secrets.POETRY_PYPI_TOKEN_PYPI }}"
+           poetry publish --build
.github/workflows/test.yml ADDED
@@ -0,0 +1,47 @@
+ name: Test
+
+ on:
+   push:
+     branches:
+       - main
+       - master
+   pull_request:
+
+ jobs:
+   test:
+     runs-on: ubuntu-latest
+
+     strategy:
+       fail-fast: false
+       matrix:
+         python-version: ["3.10", "3.11"]
+
+     name: Python ${{ matrix.python-version }}
+
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v4
+
+       - name: Set up Node.js
+         uses: actions/setup-node@v4
+         with:
+           node-version: 21
+
+       - name: Install @devcontainers/cli
+         run: npm install --location=global @devcontainers/[email protected]
+
+       - name: Start Dev Container
+         run: |
+           git config --global init.defaultBranch main
+           PYTHON_VERSION=${{ matrix.python-version }} OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} devcontainer up --workspace-folder .
+
+       - name: Lint package
+         run: devcontainer exec --workspace-folder . poe lint
+
+       - name: Test package
+         run: devcontainer exec --workspace-folder . poe test
+
+       - name: Upload coverage
+         uses: codecov/codecov-action@v4
+         with:
+           files: reports/coverage.xml
.gitignore ADDED
@@ -0,0 +1,76 @@
+ # Chainlit
+ .chainlit/
+ .files/
+ chainlit.md
+
+ # Coverage.py
+ htmlcov/
+ reports/
+
+ # cruft
+ *.rej
+
+ # Data
+ *.csv*
+ *.dat*
+ *.pickle*
+ *.xls*
+ *.zip*
+ data/
+
+ # direnv
+ .envrc
+
+ # dotenv
+ .env
+
+ # rerankers
+ .*_cache/
+
+ # Hypothesis
+ .hypothesis/
+
+ # Jupyter
+ *.ipynb
+ .ipynb_checkpoints/
+ notebooks/
+
+ # macOS
+ .DS_Store
+
+ # mypy
+ .dmypy.json
+ .mypy_cache/
+
+ # Node.js
+ node_modules/
+
+ # Poetry
+ .venv/
+ dist/
+
+ # PyCharm
+ .idea/
+
+ # pyenv
+ .python-version
+
+ # pytest
+ .pytest_cache/
+
+ # Python
+ __pycache__/
+ *.py[cdo]
+
+ # RAGLite
+ *.db
+ *.sqlite
+
+ # Ruff
+ .ruff_cache/
+
+ # Terraform
+ .terraform/
+
+ # VS Code
+ .vscode/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,83 @@
+ # https://pre-commit.com
+ default_install_hook_types: [commit-msg, pre-commit]
+ default_stages: [commit, manual]
+ fail_fast: true
+ repos:
+   - repo: meta
+     hooks:
+       - id: check-useless-excludes
+   - repo: https://github.com/pre-commit/pygrep-hooks
+     rev: v1.10.0
+     hooks:
+       - id: python-check-mock-methods
+       - id: python-use-type-annotations
+       - id: rst-backticks
+       - id: rst-directive-colons
+       - id: rst-inline-touching-normal
+       - id: text-unicode-replacement-char
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v4.5.0
+     hooks:
+       - id: check-added-large-files
+       - id: check-ast
+       - id: check-builtin-literals
+       - id: check-case-conflict
+       - id: check-docstring-first
+       - id: check-json
+       - id: check-merge-conflict
+       - id: check-shebang-scripts-are-executable
+       - id: check-symlinks
+       - id: check-toml
+       - id: check-vcs-permalinks
+       - id: check-xml
+       - id: check-yaml
+       - id: debug-statements
+       - id: destroyed-symlinks
+       - id: detect-private-key
+       - id: end-of-file-fixer
+         types: [python]
+       - id: fix-byte-order-marker
+       - id: mixed-line-ending
+       - id: name-tests-test
+         args: [--pytest-test-first]
+       - id: trailing-whitespace
+         types: [python]
+   - repo: local
+     hooks:
+       - id: commitizen
+         name: commitizen
+         entry: cz check
+         args: [--commit-msg-file]
+         require_serial: true
+         language: system
+         stages: [commit-msg]
+       - id: ruff-check
+         name: ruff check
+         entry: ruff check
+         args: ["--force-exclude", "--extend-fixable=ERA001,F401,F841,T201,T203"]
+         require_serial: true
+         language: system
+         types_or: [python, pyi]
+       - id: ruff-format
+         name: ruff format
+         entry: ruff format
+         args: [--force-exclude]
+         require_serial: true
+         language: system
+         types_or: [python, pyi]
+       - id: shellcheck
+         name: shellcheck
+         entry: shellcheck
+         args: [--check-sourced]
+         language: system
+         types: [shell]
+       - id: poetry-check
+         name: poetry check
+         entry: poetry check
+         language: system
+         pass_filenames: false
+       - id: mypy
+         name: mypy
+         entry: mypy
+         language: system
+         types: [python]
CHANGELOG.md ADDED
@@ -0,0 +1,53 @@
+ ## v0.2.0 (2024-10-21)
+
+ ### Feat
+
+ - add Chainlit frontend (#33)
+
+ ## v0.1.4 (2024-10-15)
+
+ ### Fix
+
+ - fix optimal chunking edge cases (#32)
+
+ ## v0.1.3 (2024-10-13)
+
+ ### Fix
+
+ - upgrade pdftext (#30)
+ - improve chunk and segment ordering (#29)
+
+ ## v0.1.2 (2024-10-08)
+
+ ### Fix
+
+ - avoid pdftext v0.3.11 (#27)
+
+ ## v0.1.1 (2024-10-07)
+
+ ### Fix
+
+ - patch rerankers flashrank issue (#22)
+
+ ## v0.1.0 (2024-10-07)
+
+ ### Feat
+
+ - add reranking (#20)
+ - add LiteLLM and late chunking (#19)
+ - add PostgreSQL support (#18)
+ - make query adapter minimally invasive (#16)
+ - upgrade default CPU model to Phi-3.5-mini (#15)
+ - add evaluation (#14)
+ - infer missing font sizes (#12)
+ - automatically adjust number of RAG contexts (#10)
+ - improve exception feedback for extraction (#9)
+ - optimize config for CPU and GPU (#7)
+ - simplify document insertion (#6)
+ - implement basic features (#2)
+ - initial commit
+
+ ### Fix
+
+ - lazily import optional dependencies (#11)
+ - improve indexing of multiple documents (#8)
Dockerfile ADDED
@@ -0,0 +1,101 @@
+ # syntax=docker/dockerfile:1
+ ARG PYTHON_VERSION=3.10
+ FROM python:$PYTHON_VERSION-slim AS base
+
+ # Remove docker-clean so we can keep the apt cache in Docker build cache.
+ RUN rm /etc/apt/apt.conf.d/docker-clean
+
+ # Configure Python to print tracebacks on crash [1], and to not buffer stdout and stderr [2].
+ # [1] https://docs.python.org/3/using/cmdline.html#envvar-PYTHONFAULTHANDLER
+ # [2] https://docs.python.org/3/using/cmdline.html#envvar-PYTHONUNBUFFERED
+ ENV PYTHONFAULTHANDLER 1
+ ENV PYTHONUNBUFFERED 1
+
+ # Create a non-root user and switch to it [1].
+ # [1] https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user
+ ARG UID=1000
+ ARG GID=$UID
+ RUN groupadd --gid $GID user && \
+     useradd --create-home --gid $GID --uid $UID user --no-log-init && \
+     chown user /opt/
+ USER user
+
+ # Create and activate a virtual environment.
+ ENV VIRTUAL_ENV /opt/raglite-env
+ ENV PATH $VIRTUAL_ENV/bin:$PATH
+ RUN python -m venv $VIRTUAL_ENV
+
+ # Set the working directory.
+ WORKDIR /workspaces/raglite/
+
+
+
+ FROM base as poetry
+
+ USER root
+
+ # Install Poetry in a separate venv so it doesn't pollute the main venv.
+ ENV POETRY_VERSION 1.8.0
+ ENV POETRY_VIRTUAL_ENV /opt/poetry-env
+ RUN --mount=type=cache,target=/root/.cache/pip/ \
+     python -m venv $POETRY_VIRTUAL_ENV && \
+     $POETRY_VIRTUAL_ENV/bin/pip install poetry~=$POETRY_VERSION && \
+     ln -s $POETRY_VIRTUAL_ENV/bin/poetry /usr/local/bin/poetry
+
+ # Install compilers that may be required for certain packages or platforms.
+ RUN --mount=type=cache,target=/var/cache/apt/ \
+     --mount=type=cache,target=/var/lib/apt/ \
+     apt-get update && \
+     apt-get install --no-install-recommends --yes build-essential
+
+ USER user
+
+ # Install the run time Python dependencies in the virtual environment.
+ COPY --chown=user:user poetry.lock* pyproject.toml /workspaces/raglite/
+ RUN mkdir -p /home/user/.cache/pypoetry/ && mkdir -p /home/user/.config/pypoetry/ && \
+     mkdir -p src/raglite/ && touch src/raglite/__init__.py && touch README.md
+ RUN --mount=type=cache,uid=$UID,gid=$GID,target=/home/user/.cache/pypoetry/ \
+     poetry install --only main --all-extras --no-interaction
+
+
+
+ FROM poetry as dev
+
+ # Install development tools: curl, git, gpg, ssh, starship, sudo, vim, and zsh.
+ USER root
+ RUN --mount=type=cache,target=/var/cache/apt/ \
+     --mount=type=cache,target=/var/lib/apt/ \
+     apt-get update && \
+     apt-get install --no-install-recommends --yes curl git gnupg ssh sudo vim zsh && \
+     sh -c "$(curl -fsSL https://starship.rs/install.sh)" -- "--yes" && \
+     usermod --shell /usr/bin/zsh user && \
+     echo 'user ALL=(root) NOPASSWD:ALL' > /etc/sudoers.d/user && chmod 0440 /etc/sudoers.d/user
+ RUN git config --system --add safe.directory '*'
+ USER user
+
+ # Install the development Python dependencies in the virtual environment.
+ RUN --mount=type=cache,uid=$UID,gid=$GID,target=/home/user/.cache/pypoetry/ \
+     poetry install --all-extras --no-interaction
+
+ # Persist output generated during docker build so that we can restore it in the dev container.
+ COPY --chown=user:user .pre-commit-config.yaml /workspaces/raglite/
+ RUN mkdir -p /opt/build/poetry/ && cp poetry.lock /opt/build/poetry/ && \
+     git init && pre-commit install --install-hooks && \
+     mkdir -p /opt/build/git/ && cp .git/hooks/commit-msg .git/hooks/pre-commit /opt/build/git/
+
+ # Configure the non-root user's shell.
+ ENV ANTIDOTE_VERSION 1.8.6
+ RUN git clone --branch v$ANTIDOTE_VERSION --depth=1 https://github.com/mattmc3/antidote.git ~/.antidote/ && \
+     echo 'zsh-users/zsh-syntax-highlighting' >> ~/.zsh_plugins.txt && \
+     echo 'zsh-users/zsh-autosuggestions' >> ~/.zsh_plugins.txt && \
+     echo 'source ~/.antidote/antidote.zsh' >> ~/.zshrc && \
+     echo 'antidote load' >> ~/.zshrc && \
+     echo 'eval "$(starship init zsh)"' >> ~/.zshrc && \
+     echo 'HISTFILE=~/.history/.zsh_history' >> ~/.zshrc && \
+     echo 'HISTSIZE=1000' >> ~/.zshrc && \
+     echo 'SAVEHIST=1000' >> ~/.zshrc && \
+     echo 'setopt share_history' >> ~/.zshrc && \
+     echo 'bindkey "^[[A" history-beginning-search-backward' >> ~/.zshrc && \
+     echo 'bindkey "^[[B" history-beginning-search-forward' >> ~/.zshrc && \
+     mkdir ~/.history/ && \
+     zsh -c 'source ~/.zshrc'
docker-compose.yml ADDED
@@ -0,0 +1,60 @@
+ version: "3.9"
+
+ services:
+
+   devcontainer:
+     build:
+       context: .
+       target: dev
+       args:
+         PYTHON_VERSION: ${PYTHON_VERSION:-3.10}
+         UID: ${UID:-1000}
+         GID: ${GID:-1000}
+     environment:
+       - OPENAI_API_KEY
+       - POETRY_PYPI_TOKEN_PYPI
+     depends_on:
+       - postgres
+     networks:
+       - raglite-network
+     volumes:
+       - ..:/workspaces
+       - command-history-volume:/home/user/.history/
+
+   dev:
+     extends: devcontainer
+     stdin_open: true
+     tty: true
+     entrypoint: []
+     command: [ "sh", "-c", "sudo chown user $$SSH_AUTH_SOCK && cp --update /opt/build/poetry/poetry.lock /workspaces/raglite/ && mkdir -p /workspaces/raglite/.git/hooks/ && cp --update /opt/build/git/* /workspaces/raglite/.git/hooks/ && zsh" ]
+     environment:
+       - OPENAI_API_KEY
+       - POETRY_PYPI_TOKEN_PYPI
+       - SSH_AUTH_SOCK=/run/host-services/ssh-auth.sock
+     depends_on:
+       - postgres
+     networks:
+       - raglite-network
+     volumes:
+       - ~/.gitconfig:/etc/gitconfig
+       - ~/.ssh/known_hosts:/home/user/.ssh/known_hosts
+       - ${SSH_AGENT_AUTH_SOCK:-/run/host-services/ssh-auth.sock}:/run/host-services/ssh-auth.sock
+     profiles:
+       - dev
+
+   postgres:
+     image: pgvector/pgvector:pg16
+     environment:
+       POSTGRES_USER: raglite_user
+       POSTGRES_PASSWORD: raglite_password
+     networks:
+       - raglite-network
+     tmpfs:
+       - /var/lib/postgresql/data
+
+ networks:
+   raglite-network:
+     driver: bridge
+
+ volumes:
+   command-history-volume:
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,186 @@
+ [build-system] # https://python-poetry.org/docs/pyproject/#poetry-and-pep-517
+ requires = ["poetry-core>=1.0.0"]
+ build-backend = "poetry.core.masonry.api"
+
+ [tool.poetry] # https://python-poetry.org/docs/pyproject/
+ name = "raglite"
+ version = "0.2.0"
+ description = "A Python toolkit for Retrieval-Augmented Generation (RAG) with SQLite or PostgreSQL."
+ authors = ["Laurent Sorber <[email protected]>"]
+ readme = "README.md"
+ repository = "https://github.com/superlinear-ai/raglite"
+
+ [tool.commitizen] # https://commitizen-tools.github.io/commitizen/config/
+ bump_message = "bump(release): v$current_version → v$new_version"
+ tag_format = "v$version"
+ update_changelog_on_bump = true
+ version_provider = "poetry"
+
+ [tool.poetry.dependencies] # https://python-poetry.org/docs/dependency-specification/
+ # Python:
+ python = ">=3.10,<4.0"
+ # Markdown conversion:
+ pdftext = ">=0.3.13"
+ pypandoc-binary = { version = ">=1.13", optional = true }
+ scikit-learn = ">=1.4.2"
+ # Markdown formatting:
+ markdown-it-py = ">=3.0.0"
+ mdformat-gfm = ">=0.3.6"
+ # Sentence and chunk splitting:
+ numpy = ">=1.26.4"
+ scipy = ">=1.5.0"
+ spacy = ">=3.7.0,<3.8.0"
+ # Large Language Models:
+ huggingface-hub = ">=0.22.0"
+ litellm = ">=1.47.1"
+ llama-cpp-python = ">=0.2.88"
+ pydantic = ">=2.7.0"
+ # Approximate Nearest Neighbors:
+ pynndescent = ">=0.5.12"
+ # Reranking:
+ langdetect = ">=1.0.9"
+ rerankers = { extras = ["flashrank"], version = ">=0.5.3" }
+ # Storage:
+ pg8000 = ">=1.31.2"
+ sqlmodel-slim = ">=0.0.18"
+ # Progress:
+ tqdm = ">=4.66.0"
+ # Evaluation:
+ pandas = ">=2.1.0"
+ ragas = { version = ">=0.1.12", optional = true }
+ # CLI:
+ typer = ">=0.12.5"
+ # Frontend:
+ chainlit = { version = ">=1.2.0", optional = true }
+
+ [tool.poetry.extras] # https://python-poetry.org/docs/pyproject/#extras
+ chainlit = ["chainlit"]
+ pandoc = ["pypandoc-binary"]
+ ragas = ["ragas"]
+
+ [tool.poetry.group.test.dependencies] # https://python-poetry.org/docs/master/managing-dependencies/
+ commitizen = ">=3.29.1"
+ coverage = { extras = ["toml"], version = ">=7.4.4" }
+ mypy = ">=1.9.0"
+ poethepoet = ">=0.25.0"
+ pre-commit = ">=3.7.0"
+ pytest = ">=8.1.1"
+ pytest-mock = ">=3.14.0"
+ ruff = ">=0.5.7"
+ safety = ">=3.1.0"
+ shellcheck-py = ">=0.10.0.1"
+ typeguard = ">=4.2.1"
+ xx_sent_ud_sm = { url = "https://github.com/explosion/spacy-models/releases/download/xx_sent_ud_sm-3.7.0/xx_sent_ud_sm-3.7.0-py3-none-any.whl" }
+
+ [tool.poetry.group.dev.dependencies] # https://python-poetry.org/docs/master/managing-dependencies/
+ cruft = ">=2.15.0"
+ ipykernel = ">=6.29.4"
+ ipython = ">=8.8.0"
+ ipywidgets = ">=8.1.2"
+ matplotlib = ">=3.9.0"
+ memory-profiler = ">=0.61.0"
+ pdoc = ">=14.4.0"
+
+ [tool.poetry.scripts] # https://python-poetry.org/docs/pyproject/#scripts
+ raglite = "raglite:cli"
+
+ [tool.coverage.report] # https://coverage.readthedocs.io/en/latest/config.html#report
+ fail_under = 50
+ precision = 1
+ show_missing = true
+ skip_covered = true
+
+ [tool.coverage.run] # https://coverage.readthedocs.io/en/latest/config.html#run
+ branch = true
+ command_line = "--module pytest"
+ data_file = "reports/.coverage"
+ source = ["src"]
+
+ [tool.coverage.xml] # https://coverage.readthedocs.io/en/latest/config.html#xml
+ output = "reports/coverage.xml"
+
+ [tool.mypy] # https://mypy.readthedocs.io/en/latest/config_file.html
+ junit_xml = "reports/mypy.xml"
+ strict = true
+ disallow_subclassing_any = false
+ disallow_untyped_decorators = false
+ ignore_missing_imports = true
+ pretty = true
+ show_column_numbers = true
+ show_error_codes = true
+ show_error_context = true
+ warn_unreachable = true
+
+ [tool.pytest.ini_options] # https://docs.pytest.org/en/latest/reference/reference.html#ini-options-ref
+ addopts = "--color=yes --exitfirst --failed-first --strict-config --strict-markers --verbosity=2 --junitxml=reports/pytest.xml"
+ filterwarnings = ["error", "ignore::DeprecationWarning", "ignore::pytest.PytestUnraisableExceptionWarning"]
+ testpaths = ["src", "tests"]
+ xfail_strict = true
+
+ [tool.ruff] # https://github.com/charliermarsh/ruff
+ fix = true
+ line-length = 100
+ src = ["src", "tests"]
+ target-version = "py310"
+
+ [tool.ruff.lint]
+ select = ["A", "ASYNC", "B", "BLE", "C4", "C90", "D", "DTZ", "E", "EM", "ERA", "F", "FBT", "FLY", "FURB", "G", "I", "ICN", "INP", "INT", "ISC", "LOG", "N", "NPY", "PERF", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "Q", "RET", "RSE", "RUF", "S", "SIM", "SLF", "SLOT", "T10", "T20", "TCH", "TID", "TRY", "UP", "W", "YTT"]
+ ignore = ["D203", "D213", "E501", "RET504", "RUF002", "S101", "S307"]
+ unfixable = ["ERA001", "F401", "F841", "T201", "T203"]
+
+ [tool.ruff.lint.flake8-tidy-imports]
+ ban-relative-imports = "all"
+
+ [tool.ruff.lint.pycodestyle]
+ max-doc-length = 100
+
+ [tool.ruff.lint.pydocstyle]
+ convention = "numpy"
+
+ [tool.poe.tasks] # https://github.com/nat-n/poethepoet
+
+ [tool.poe.tasks.docs]
+ help = "Generate this package's docs"
+ cmd = """
+     pdoc
+       --docformat $docformat
+       --output-directory $outputdirectory
+       raglite
+     """
+
+ [[tool.poe.tasks.docs.args]]
+ help = "The docstring style (default: numpy)"
+ name = "docformat"
+ options = ["--docformat"]
+ default = "numpy"
+
+ [[tool.poe.tasks.docs.args]]
+ help = "The output directory (default: docs)"
+ name = "outputdirectory"
+ options = ["--output-directory"]
+ default = "docs"
+
+ [tool.poe.tasks.lint]
+ help = "Lint this package"
+
+ [[tool.poe.tasks.lint.sequence]]
+ cmd = """
+     pre-commit run
+       --all-files
+       --color always
+     """
+
+ [[tool.poe.tasks.lint.sequence]]
+ shell = "safety check --continue-on-error --full-report"
+
+ [tool.poe.tasks.test]
+ help = "Test this package"
+
+ [[tool.poe.tasks.test.sequence]]
+ cmd = "coverage run"
+
+ [[tool.poe.tasks.test.sequence]]
+ cmd = "coverage report"
+
+ [[tool.poe.tasks.test.sequence]]
+ cmd = "coverage xml"
src/raglite/__init__.py ADDED
@@ -0,0 +1,41 @@
+ """RAGLite."""
+
+ from raglite._cli import cli
+ from raglite._config import RAGLiteConfig
+ from raglite._eval import answer_evals, evaluate, insert_evals
+ from raglite._insert import insert_document
+ from raglite._query_adapter import update_query_adapter
+ from raglite._rag import async_rag, rag
+ from raglite._search import (
+     hybrid_search,
+     keyword_search,
+     rerank_chunks,
+     retrieve_chunks,
+     retrieve_segments,
+     vector_search,
+ )
+
+ __all__ = [
+     # Config
+     "RAGLiteConfig",
+     # Insert
+     "insert_document",
+     # Search
+     "hybrid_search",
+     "keyword_search",
+     "vector_search",
+     "retrieve_chunks",
+     "retrieve_segments",
+     "rerank_chunks",
+     # RAG
+     "async_rag",
+     "rag",
+     # Query adapter
+     "update_query_adapter",
+     # Evaluate
+     "insert_evals",
+     "answer_evals",
+     "evaluate",
+     # CLI
+     "cli",
+ ]
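The exports above lend themselves to a short end-to-end sketch. The snippet below is illustrative only: the document path is hypothetical, and the keyword arguments of `rag` are an assumption inferred from how `async_rag` is called in src/raglite/_chainlit.py; `hybrid_search` and `retrieve_chunks` are used exactly as they are in that frontend.

# Illustrative usage sketch of the public API (assumptions noted above).
from pathlib import Path

from raglite import RAGLiteConfig, hybrid_search, insert_document, rag, retrieve_chunks

config = RAGLiteConfig(db_url="sqlite:///raglite.sqlite")  # Default SQLite database.
insert_document(Path("manual.pdf"), config=config)  # Hypothetical document path.
chunk_ids, _ = hybrid_search(query="How is the database configured?", num_results=10, config=config)
chunks = retrieve_chunks(chunk_ids=chunk_ids, config=config)
for token in rag(prompt="How is the database configured?", search=chunks, config=config):
    print(token, end="")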
src/raglite/_chainlit.py ADDED
@@ -0,0 +1,117 @@
+ """Chainlit frontend for RAGLite."""
+
+ import os
+ from pathlib import Path
+
+ import chainlit as cl
+ from chainlit.input_widget import Switch, TextInput
+
+ from raglite import (
+     RAGLiteConfig,
+     async_rag,
+     hybrid_search,
+     insert_document,
+     rerank_chunks,
+     retrieve_chunks,
+ )
+ from raglite._markdown import document_to_markdown
+
+ async_insert_document = cl.make_async(insert_document)
+ async_hybrid_search = cl.make_async(hybrid_search)
+ async_retrieve_chunks = cl.make_async(retrieve_chunks)
+ async_rerank_chunks = cl.make_async(rerank_chunks)
+
+
+ @cl.on_chat_start
+ async def start_chat() -> None:
+     """Initialize the chat."""
+     # Disable tokenizers parallelism to avoid the deadlock warning.
+     os.environ["TOKENIZERS_PARALLELISM"] = "false"
+     # Add Chainlit settings with which the user can configure the RAGLite config.
+     default_config = RAGLiteConfig()
+     config = RAGLiteConfig(
+         db_url=os.environ.get("RAGLITE_DB_URL", default_config.db_url),
+         llm=os.environ.get("RAGLITE_LLM", default_config.llm),
+         embedder=os.environ.get("RAGLITE_EMBEDDER", default_config.embedder),
+     )
+     settings = await cl.ChatSettings(  # type: ignore[no-untyped-call]
+         [
+             TextInput(id="db_url", label="Database URL", initial=str(config.db_url)),
+             TextInput(id="llm", label="LLM", initial=config.llm),
+             TextInput(id="embedder", label="Embedder", initial=config.embedder),
+             Switch(id="vector_search_query_adapter", label="Query adapter", initial=True),
+         ]
+     ).send()
+     await update_config(settings)
+
+
+ @cl.on_settings_update  # type: ignore[arg-type]
+ async def update_config(settings: cl.ChatSettings) -> None:
+     """Update the RAGLite config."""
+     # Update the RAGLite config given the Chainlit settings.
+     config = RAGLiteConfig(
+         db_url=settings["db_url"],  # type: ignore[index]
+         llm=settings["llm"],  # type: ignore[index]
+         embedder=settings["embedder"],  # type: ignore[index]
+         vector_search_query_adapter=settings["vector_search_query_adapter"],  # type: ignore[index]
+     )
+     cl.user_session.set("config", config)  # type: ignore[no-untyped-call]
+     # Run a search to prime the pipeline if it's a local pipeline.
+     # TODO: Don't do this for SQLite once we switch from PyNNDescent to sqlite-vec.
+     if str(config.db_url).startswith("sqlite") or config.embedder.startswith("llama-cpp-python"):
+         # async with cl.Step(name="initialize", type="retrieval"):
+         query = "Hello world"
+         chunk_ids, _ = await async_hybrid_search(query=query, config=config)
+         _ = await async_rerank_chunks(query=query, chunk_ids=chunk_ids, config=config)
+
+
+ @cl.on_message
+ async def handle_message(user_message: cl.Message) -> None:
+     """Respond to a user message."""
+     # Get the config and message history from the user session.
+     config: RAGLiteConfig = cl.user_session.get("config")  # type: ignore[no-untyped-call]
+     # Determine what to do with the attachments.
+     inline_attachments = []
+     for file in user_message.elements:
+         if file.path:
+             doc_md = document_to_markdown(Path(file.path))
+             if len(doc_md) // 3 <= 5 * (config.chunk_max_size // 3):
+                 # Document is small enough to attach to the context.
+                 inline_attachments.append(f"{Path(file.path).name}:\n\n{doc_md}")
+             else:
+                 # Document is too large and must be inserted into the database.
+                 async with cl.Step(name="insert", type="run") as step:
+                     step.input = Path(file.path).name
+                     await async_insert_document(Path(file.path), config=config)
+     # Append any inline attachments to the user prompt.
+     user_prompt = f"{user_message.content}\n\n" + "\n\n".join(
+         f'<attachment index="{i}">\n{attachment.strip()}\n</attachment>'
+         for i, attachment in enumerate(inline_attachments)
+     )
+     # Search for relevant contexts for RAG.
+     async with cl.Step(name="search", type="retrieval") as step:
+         step.input = user_message.content
+         chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config)
+         chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config)
+         step.output = chunks
+         step.elements = [  # Show the top 3 chunks inline.
+             cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3]
+         ]
+     # Rerank the chunks.
+     async with cl.Step(name="rerank", type="rerank") as step:
+         step.input = chunks
+         chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config)
+         step.output = chunks
+         step.elements = [  # Show the top 3 chunks inline.
+             cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3]
+         ]
+     # Stream the LLM response.
+     assistant_message = cl.Message(content="")
+     async for token in async_rag(
+         prompt=user_prompt,
+         search=chunks,
+         messages=cl.chat_context.to_openai()[-5:],  # type: ignore[no-untyped-call]
+         config=config,
+     ):
+         await assistant_message.stream_token(token)
+     await assistant_message.update()  # type: ignore[no-untyped-call]
src/raglite/_cli.py ADDED
@@ -0,0 +1,39 @@
+ """RAGLite CLI."""
+
+ import os
+
+ import typer
+
+ from raglite._config import RAGLiteConfig
+
+ cli = typer.Typer()
+
+
+ @cli.callback()
+ def main() -> None:
+     """RAGLite CLI."""
+
+
+ @cli.command()
+ def chainlit(
+     db_url: str = typer.Option(RAGLiteConfig().db_url, help="Database URL"),
+     llm: str = typer.Option(RAGLiteConfig().llm, help="LiteLLM LLM"),
+     embedder: str = typer.Option(RAGLiteConfig().embedder, help="LiteLLM embedder"),
+ ) -> None:
+     """Serve a Chainlit frontend."""
+     # Set the environment variables for the Chainlit frontend.
+     os.environ["RAGLITE_DB_URL"] = os.environ.get("RAGLITE_DB_URL", db_url)
+     os.environ["RAGLITE_LLM"] = os.environ.get("RAGLITE_LLM", llm)
+     os.environ["RAGLITE_EMBEDDER"] = os.environ.get("RAGLITE_EMBEDDER", embedder)
+     # Import Chainlit here as it's an optional dependency.
+     try:
+         from chainlit.cli import run_chainlit
+     except ImportError as error:
+         error_message = "To serve a Chainlit frontend, please install the `chainlit` extra."
+         raise ImportError(error_message) from error
+     # Serve the frontend.
+     run_chainlit(__file__.replace("_cli.py", "_chainlit.py"))
+
+
+ if __name__ == "__main__":
+     cli()
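Because `cli` is a plain Typer app (exposed as the `raglite` script in pyproject.toml), it can also be exercised in-process. A small sketch, assuming Typer's bundled test runner and the `--db-url` option name that Typer derives from the `db_url` parameter above:

# Sketch: drive the Typer CLI in-process and inspect the help of the chainlit command.
from typer.testing import CliRunner

from raglite import cli

runner = CliRunner()
result = runner.invoke(cli, ["chainlit", "--help"])
print(result.exit_code)  # 0 on success.
print(result.output)     # Shows --db-url, --llm, and --embedder options.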
src/raglite/_config.py ADDED
@@ -0,0 +1,61 @@
+ """RAGLite config."""
+
+ import contextlib
+ import os
+ from dataclasses import dataclass, field
+ from io import StringIO
+
+ from llama_cpp import llama_supports_gpu_offload
+ from sqlalchemy.engine import URL
+
+ from raglite._flashrank import PatchedFlashRankRanker as FlashRankRanker
+
+ # Suppress rerankers output on import until [1] is fixed.
+ # [1] https://github.com/AnswerDotAI/rerankers/issues/36
+ with contextlib.redirect_stdout(StringIO()):
+     from rerankers.models.ranker import BaseRanker
+
+
+ @dataclass(frozen=True)
+ class RAGLiteConfig:
+     """Configuration for RAGLite."""
+
+     # Database config.
+     db_url: str | URL = "sqlite:///raglite.sqlite"
+     # LLM config used for generation.
+     llm: str = field(
+         default_factory=lambda: (
+             "llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@8192"
+             if llama_supports_gpu_offload()
+             else "llama-cpp-python/bartowski/Llama-3.2-3B-Instruct-GGUF/*Q4_K_M.gguf@4096"
+         )
+     )
+     llm_max_tries: int = 4
+     # Embedder config used for indexing.
+     embedder: str = field(
+         default_factory=lambda: (  # Nomic-embed may be better if only English is used.
+             "llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf"
+             if llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 4  # noqa: PLR2004
+             else "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf"
+         )
+     )
+     embedder_normalize: bool = True
+     embedder_sentence_window_size: int = 3
+     # Chunk config used to partition documents into chunks.
+     chunk_max_size: int = 1440  # Max number of characters per chunk.
+     # Vector search config.
+     vector_search_index_metric: str = "cosine"  # The query adapter supports "dot" and "cosine".
+     vector_search_query_adapter: bool = True
+     # Reranking config.
+     reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None = field(
+         default_factory=lambda: (
+             ("en", FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0)),
+             ("other", FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0)),
+         ),
+         compare=False,  # Exclude the reranker from comparison to avoid lru_cache misses.
+     )
+
+     def __post_init__(self) -> None:
+         # Late chunking with llama-cpp-python does not apply sentence windowing.
+         if self.embedder.startswith("llama-cpp-python"):
+             object.__setattr__(self, "embedder_sentence_window_size", 1)
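Because `RAGLiteConfig` is a frozen dataclass with llama-cpp-python defaults, moving to a hosted stack is a matter of overriding fields. A sketch, assuming LiteLLM model identifiers and the PostgreSQL service from docker-compose.yml; the model names and the target database name are examples, not defaults of this commit:

# Sketch: override the local defaults with a PostgreSQL URL and LiteLLM model names.
from raglite import RAGLiteConfig

my_config = RAGLiteConfig(
    # Credentials and hostname match the postgres service in docker-compose.yml;
    # the database name "postgres" is an assumption.
    db_url="postgresql://raglite_user:raglite_password@postgres:5432/postgres",
    llm="gpt-4o-mini",                    # Example LiteLLM model name.
    embedder="text-embedding-3-small",    # Example LiteLLM embedder name.
)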
src/raglite/_database.py ADDED
@@ -0,0 +1,341 @@
+ """PostgreSQL or SQLite database tables for RAGLite."""
+
+ import datetime
+ import json
+ from functools import lru_cache
+ from hashlib import sha256
+ from pathlib import Path
+ from typing import Any
+
+ import numpy as np
+ from litellm import get_model_info  # type: ignore[attr-defined]
+ from markdown_it import MarkdownIt
+ from pydantic import ConfigDict
+ from sqlalchemy.engine import Engine, make_url
+ from sqlmodel import (
+     JSON,
+     Column,
+     Field,
+     Relationship,
+     Session,
+     SQLModel,
+     create_engine,
+     text,
+ )
+
+ from raglite._config import RAGLiteConfig
+ from raglite._litellm import LlamaCppPythonLLM
+ from raglite._typing import Embedding, FloatMatrix, FloatVector, PickledObject
+
+
+ def hash_bytes(data: bytes, max_len: int = 16) -> str:
+     """Hash bytes to a hexadecimal string."""
+     return sha256(data, usedforsecurity=False).hexdigest()[:max_len]
+
+
+ class Document(SQLModel, table=True):
+     """A document."""
+
+     # Enable JSON columns.
+     model_config = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[assignment]
+
+     # Table columns.
+     id: str = Field(..., primary_key=True)
+     filename: str
+     url: str | None = Field(default=None)
+     metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
+
+     # Add relationships so we can access document.chunks and document.evals.
+     chunks: list["Chunk"] = Relationship(back_populates="document")
+     evals: list["Eval"] = Relationship(back_populates="document")
+
+     @staticmethod
+     def from_path(doc_path: Path, **kwargs: Any) -> "Document":
+         """Create a document from a file path."""
+         return Document(
+             id=hash_bytes(doc_path.read_bytes()),
+             filename=doc_path.name,
+             metadata_={
+                 "size": doc_path.stat().st_size,
+                 "created": doc_path.stat().st_ctime,
+                 "modified": doc_path.stat().st_mtime,
+                 **kwargs,
+             },
+         )
+
+
+ class Chunk(SQLModel, table=True):
+     """A document chunk."""
+
+     # Enable JSON columns.
+     model_config = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[assignment]
+
+     # Table columns.
+     id: str = Field(..., primary_key=True)
+     document_id: str = Field(..., foreign_key="document.id", index=True)
+     index: int = Field(..., index=True)
+     headings: str
+     body: str
+     metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
+
+     # Add relationships so we can access chunk.document and chunk.embeddings.
+     document: Document = Relationship(back_populates="chunks")
+     embeddings: list["ChunkEmbedding"] = Relationship(back_populates="chunk")
+
+     @staticmethod
+     def from_body(
+         document_id: str,
+         index: int,
+         body: str,
+         headings: str = "",
+         **kwargs: Any,
+     ) -> "Chunk":
+         """Create a chunk from Markdown."""
+         return Chunk(
+             id=hash_bytes(body.encode()),
+             document_id=document_id,
+             index=index,
+             headings=headings,
+             body=body,
+             metadata_=kwargs,
+         )
+
+     def extract_headings(self) -> str:
+         """Extract Markdown headings from the chunk, starting from the current Markdown headings."""
+         md = MarkdownIt()
+         heading_lines = [""] * 10
+         level = None
+         for doc in (self.headings, self.body):
+             for token in md.parse(doc):
+                 if token.type == "heading_open":
+                     level = int(token.tag[1])
+                 elif token.type == "heading_close":
+                     level = None
+                 elif level is not None:
+                     heading_content = token.content.strip().replace("\n", " ")
+                     heading_lines[level] = ("#" * level) + " " + heading_content
+                     heading_lines[level + 1 :] = [""] * len(heading_lines[level + 1 :])
+         headings = "\n".join([heading for heading in heading_lines if heading])
+         return headings
+
+     @property
+     def embedding_matrix(self) -> FloatMatrix:
+         """Return this chunk's multi-vector embedding matrix."""
+         # Uses the relationship chunk.embeddings to access the chunk_embedding table.
+         return np.vstack([embedding.embedding[np.newaxis, :] for embedding in self.embeddings])
+
+     def __hash__(self) -> int:
+         return hash(self.id)
+
+     def __repr__(self) -> str:
+         return json.dumps(
+             {
+                 "id": self.id,
+                 "document_id": self.document_id,
+                 "index": self.index,
+                 "headings": self.headings,
+                 "body": self.body[:100],
+                 "metadata": self.metadata_,
+             },
+             indent=4,
+         )
+
+     def __str__(self) -> str:
+         """Context representation of this chunk."""
+         return f"{self.headings.strip()}\n\n{self.body.strip()}".strip()
+
+
+ class ChunkEmbedding(SQLModel, table=True):
+     """A (sub-)chunk embedding."""
+
+     __tablename__ = "chunk_embedding"
+
+     # Enable Embedding columns.
+     model_config = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[assignment]
+
+     # Table columns.
+     id: int = Field(..., primary_key=True)
+     chunk_id: str = Field(..., foreign_key="chunk.id", index=True)
+     embedding: FloatVector = Field(..., sa_column=Column(Embedding(dim=-1)))
+
+     # Add relationship so we can access embedding.chunk.
+     chunk: Chunk = Relationship(back_populates="embeddings")
+
+     @classmethod
+     def set_embedding_dim(cls, dim: int) -> None:
+         """Modify the embedding column's dimension after class definition."""
+         cls.__table__.c["embedding"].type.dim = dim  # type: ignore[attr-defined]
+
+
+ class IndexMetadata(SQLModel, table=True):
+     """Vector and keyword search index metadata."""
+
+     __tablename__ = "index_metadata"
+
+     # Enable PickledObject columns.
+     model_config = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[assignment]
+
+     # Table columns.
+     id: str = Field(..., primary_key=True)
+     version: datetime.datetime = Field(
+         default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
+     )
+     metadata_: dict[str, Any] = Field(
+         default_factory=dict, sa_column=Column("metadata", PickledObject)
+     )
+
+     @staticmethod
+     @lru_cache(maxsize=4)
+     def _get(id_: str, *, config: RAGLiteConfig | None = None) -> dict[str, Any] | None:
+         engine = create_database_engine(config)
+         with Session(engine) as session:
+             index_metadata_record = session.get(IndexMetadata, id_)
+             if index_metadata_record is None:
+                 return None
+             return index_metadata_record.metadata_
+
+     @staticmethod
+     def get(id_: str = "default", *, config: RAGLiteConfig | None = None) -> dict[str, Any]:
+         metadata = IndexMetadata._get(id_, config=config) or {}
+         return metadata
+
+
+ class Eval(SQLModel, table=True):
+     """A RAG evaluation example."""
+
+     __tablename__ = "eval"
+
+     # Enable JSON columns.
+     model_config = ConfigDict(arbitrary_types_allowed=True)  # type: ignore[assignment]
+
+     # Table columns.
+     id: str = Field(..., primary_key=True)
+     document_id: str = Field(..., foreign_key="document.id", index=True)
+     chunk_ids: list[str] = Field(default_factory=list, sa_column=Column(JSON))
+     question: str
+     contexts: list[str] = Field(default_factory=list, sa_column=Column(JSON))
+     ground_truth: str
+     metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
+
+     # Add relationship so we can access eval.document.
+     document: Document = Relationship(back_populates="evals")
+
+     @staticmethod
+     def from_chunks(
+         question: str,
+         contexts: list[Chunk],
+         ground_truth: str,
+         **kwargs: Any,
+     ) -> "Eval":
+         """Create an eval from a question, its context chunks, and a ground truth answer."""
+         document_id = contexts[0].document_id
+         chunk_ids = [context.id for context in contexts]
+         return Eval(
+             id=hash_bytes(f"{document_id}-{chunk_ids}-{question}".encode()),
+             document_id=document_id,
+             chunk_ids=chunk_ids,
+             question=question,
+             contexts=[str(context) for context in contexts],
+             ground_truth=ground_truth,
+             metadata_=kwargs,
+         )
+
+
+ @lru_cache(maxsize=1)
+ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
+     """Create a database engine and initialize it."""
+     # Parse the database URL and validate that the database backend is supported.
+     config = config or RAGLiteConfig()
+     db_url = make_url(config.db_url)
+     db_backend = db_url.get_backend_name()
+     # Update database configuration.
+     connect_args = {}
+     if db_backend == "postgresql":
+         # Select the pg8000 driver if not set (psycopg2 is the default), and prefer SSL.
+         if "+" not in db_url.drivername:
+             db_url = db_url.set(drivername="postgresql+pg8000")
+         # Support setting the sslmode for pg8000.
+         if "pg8000" in db_url.drivername and "sslmode" in db_url.query:
+             query = dict(db_url.query)
+             if query.pop("sslmode") != "disable":
+                 connect_args["ssl_context"] = True
+             db_url = db_url.set(query=query)
+     elif db_backend == "sqlite":
+         # Optimize SQLite performance.
+         pragmas = {"journal_mode": "WAL", "synchronous": "NORMAL"}
+         db_url = db_url.update_query_dict(pragmas, append=True)
+     else:
+         error_message = "RAGLite only supports PostgreSQL and SQLite."
+         raise ValueError(error_message)
+     # Create the engine.
+     engine = create_engine(db_url, pool_pre_ping=True, connect_args=connect_args)
+     # Install database extensions.
+     if db_backend == "postgresql":
+         with Session(engine) as session:
+             session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
+             session.commit()
+     # If the user has configured a llama-cpp-python model, we ensure that LiteLLM's model info is up
+     # to date by loading that LLM.
+     if config.embedder.startswith("llama-cpp-python"):
+         _ = LlamaCppPythonLLM.llm(config.embedder, embedding=True)
+     llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
+     model_info = get_model_info(config.embedder, custom_llm_provider=llm_provider)
+     embedding_dim = model_info.get("output_vector_size") or -1
+     assert embedding_dim > 0
+     # Create all SQLModel tables.
+     ChunkEmbedding.set_embedding_dim(embedding_dim)
+     SQLModel.metadata.create_all(engine)
+     # Create backend-specific indexes.
+     if db_backend == "postgresql":
+         # Create a keyword search index with `tsvector` and a vector search index with `pgvector`.
+         with Session(engine) as session:
+             metrics = {"cosine": "cosine", "dot": "ip", "euclidean": "l2", "l1": "l1", "l2": "l2"}
+             session.execute(
+                 text("""
+                     CREATE INDEX IF NOT EXISTS keyword_search_chunk_index ON chunk USING GIN (to_tsvector('simple', body));
+                     """)
+             )
+             session.execute(
+                 text(f"""
+                     CREATE INDEX IF NOT EXISTS vector_search_chunk_index ON chunk_embedding
+                     USING hnsw (
+                         (embedding::halfvec({embedding_dim}))
+                         halfvec_{metrics[config.vector_search_index_metric]}_ops
+                     );
+                     """)
+             )
+             session.commit()
+     elif db_backend == "sqlite":
+         # Create a virtual table for keyword search on the chunk table.
+         # We use the chunk table as an external content table [1] to avoid duplicating the data.
+         # [1] https://www.sqlite.org/fts5.html#external_content_tables
+         with Session(engine) as session:
+             session.execute(
+                 text("""
+                     CREATE VIRTUAL TABLE IF NOT EXISTS keyword_search_chunk_index USING fts5(body, content='chunk', content_rowid='rowid');
+                     """)
+             )
+             session.execute(
+                 text("""
+                     CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_insert AFTER INSERT ON chunk BEGIN
+                         INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body);
+                     END;
+                     """)
+             )
+             session.execute(
+                 text("""
+                     CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_delete AFTER DELETE ON chunk BEGIN
+                         INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
+                     END;
+                     """)
+             )
+             session.execute(
+                 text("""
+                     CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_update AFTER UPDATE ON chunk BEGIN
+                         INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
+                         INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body);
+                     END;
+                     """)
+             )
+             session.commit()
+     return engine
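The table models and `create_database_engine` above can be combined to inspect what has been indexed. A short sketch, assuming documents were already inserted with `insert_document`:

# Sketch: open the configured database and list stored documents with their chunk counts.
from sqlmodel import Session, func, select

from raglite._config import RAGLiteConfig
from raglite._database import Chunk, Document, create_database_engine

engine = create_database_engine(RAGLiteConfig())
with Session(engine) as session:
    for document in session.exec(select(Document)).all():
        num_chunks = session.exec(
            select(func.count(Chunk.id)).where(Chunk.document_id == document.id)
        ).one()
        print(document.filename, num_chunks)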
src/raglite/_embed.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """String embedder."""
2
+
3
+ from functools import partial
4
+ from typing import Literal
5
+
6
+ import numpy as np
7
+ from litellm import embedding
8
+ from llama_cpp import LLAMA_POOLING_TYPE_NONE, Llama
9
+ from tqdm.auto import tqdm, trange
10
+
11
+ from raglite._config import RAGLiteConfig
12
+ from raglite._litellm import LlamaCppPythonLLM
13
+ from raglite._typing import FloatMatrix, IntVector
14
+
15
+
16
+ def _embed_sentences_with_late_chunking( # noqa: PLR0915
17
+ sentences: list[str], *, config: RAGLiteConfig | None = None
18
+ ) -> FloatMatrix:
19
+ """Embed a document's sentences with late chunking."""
20
+
21
+ def _count_tokens(
22
+ sentences: list[str], embedder: Llama, sentinel_char: str, sentinel_tokens: list[int]
23
+ ) -> list[int]:
24
+ # Join the sentences with the sentinel token and tokenise the result.
25
+ sentences_tokens = np.asarray(
26
+ embedder.tokenize(sentinel_char.join(sentences).encode(), add_bos=False), dtype=np.intp
27
+ )
28
+ # Map all sentinel token variants to the first one.
29
+ for sentinel_token in sentinel_tokens[1:]:
30
+ sentences_tokens[sentences_tokens == sentinel_token] = sentinel_tokens[0]
31
+ # Count how many tokens there are in between sentinel tokens to recover the token counts.
32
+ sentinel_indices = np.where(sentences_tokens == sentinel_tokens[0])[0]
33
+ num_tokens = np.diff(sentinel_indices, prepend=0, append=len(sentences_tokens))
34
+ assert len(num_tokens) == len(sentences), f"Sentinel `{sentinel_char}` appears in document"
35
+ num_tokens_list: list[int] = num_tokens.tolist()
36
+ return num_tokens_list
37
+
38
+ def _create_segment(
39
+ content_start_index: int,
40
+ max_tokens_preamble: int,
41
+ max_tokens_content: int,
42
+ num_tokens: IntVector,
43
+ ) -> tuple[int, int]:
44
+ # Compute the segment sentence start index so that the segment preamble has no more than
45
+ # max_tokens_preamble tokens between [segment_start_index, content_start_index).
46
+ cumsum_backwards = np.cumsum(num_tokens[:content_start_index][::-1])
47
+ offset_preamble = np.searchsorted(cumsum_backwards, max_tokens_preamble, side="right")
48
+ segment_start_index = content_start_index - int(offset_preamble)
49
+ # Allow a larger segment content if we didn't use all of the allowed preamble tokens.
50
+ max_tokens_content = max_tokens_content + (
51
+ max_tokens_preamble - np.sum(num_tokens[segment_start_index:content_start_index])
52
+ )
53
+ # Compute the segment sentence end index so that the segment content has no more than
54
+ # max_tokens_content tokens between [content_start_index, segment_end_index).
55
+ cumsum_forwards = np.cumsum(num_tokens[content_start_index:])
56
+ offset_segment = np.searchsorted(cumsum_forwards, max_tokens_content, side="right")
57
+ segment_end_index = content_start_index + int(offset_segment)
58
+ return segment_start_index, segment_end_index
59
+
60
+ # Assert that we're using a llama-cpp-python model, since API-based embedding models don't
61
+ # support outputting token-level embeddings.
62
+ config = config or RAGLiteConfig()
63
+ assert config.embedder.startswith("llama-cpp-python")
64
+ embedder = LlamaCppPythonLLM.llm(
65
+ config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
66
+ )
67
+ n_ctx = embedder.n_ctx()
68
+ n_batch = embedder.n_batch
69
+ # Identify the tokens corresponding to a sentinel character.
70
+ sentinel_char = "⊕"
71
+ sentinel_test = f"A{sentinel_char}B {sentinel_char} C.\n{sentinel_char}D"
72
+ sentinel_tokens = [
73
+ token
74
+ for token in embedder.tokenize(sentinel_test.encode(), add_bos=False)
75
+ if sentinel_char in embedder.detokenize([token]).decode()
76
+ ]
77
+ assert len(sentinel_tokens), f"Sentinel `{sentinel_char}` not supported by embedder"
78
+ # Compute the number of tokens per sentence. We use a method based on a sentinel token to
79
+ # minimise the number of calls to embedder.tokenize, which incurs a significant overhead
80
+ # (presumably to load the tokenizer) [1].
81
+ # TODO: Make token counting faster and more robust once [1] is fixed.
82
+ # [1] https://github.com/abetlen/llama-cpp-python/issues/1763
83
+ num_tokens_list: list[int] = []
84
+ sentence_batch, sentence_batch_len = [], 0
85
+ for i, sentence in enumerate(sentences):
86
+ sentence_batch.append(sentence)
87
+ sentence_batch_len += len(sentence)
88
+ if i == len(sentences) - 1 or sentence_batch_len > (n_ctx // 2):
89
+ num_tokens_list.extend(
90
+ _count_tokens(sentence_batch, embedder, sentinel_char, sentinel_tokens)
91
+ )
92
+ sentence_batch, sentence_batch_len = [], 0
93
+ num_tokens = np.asarray(num_tokens_list, dtype=np.intp)
94
+ # Compute the maximum number of tokens for each segment's preamble and content.
95
+ # Unfortunately, llama-cpp-python truncates the input to n_batch tokens and crashes if you try
96
+ # to increase it [1]. Until this is fixed, we have to limit max_tokens to n_batch.
97
+ # TODO: Improve the context window size once [1] is fixed.
98
+ # [1] https://github.com/abetlen/llama-cpp-python/issues/1762
99
+ max_tokens = min(n_ctx, n_batch) - 16
100
+ max_tokens_preamble = round(0.382 * max_tokens) # Golden ratio.
101
+ max_tokens_content = max_tokens - max_tokens_preamble
102
+ # Compute a list of segments, each consisting of a preamble and content.
103
+ segments = []
104
+ content_start_index = 0
105
+ while content_start_index < len(sentences):
106
+ segment_start_index, segment_end_index = _create_segment(
107
+ content_start_index, max_tokens_preamble, max_tokens_content, num_tokens
108
+ )
109
+ segments.append((segment_start_index, content_start_index, segment_end_index))
110
+ content_start_index = segment_end_index
111
+ # Embed the segments and apply late chunking.
112
+ sentence_embeddings_list: list[FloatMatrix] = []
113
+ if len(segments) > 1 or segments[0][2] > 128: # noqa: PLR2004
114
+ segments = tqdm(segments, desc="Embedding", unit="segment", dynamic_ncols=True)
115
+ for segment in segments:
116
+ # Get the token embeddings of the entire segment, including preamble and content.
117
+ segment_start_index, content_start_index, segment_end_index = segment
118
+ segment_sentences = sentences[segment_start_index:segment_end_index]
119
+ segment_embedding = np.asarray(embedder.embed("".join(segment_sentences)))
120
+ # Split the segment embeddings into embedding matrices per sentence.
121
+ segment_tokens = num_tokens[segment_start_index:segment_end_index]
122
+ sentence_size = np.round(
123
+ len(segment_embedding) * (segment_tokens / np.sum(segment_tokens))
124
+ ).astype(np.intp)
125
+ sentence_matrices = np.split(segment_embedding, np.cumsum(sentence_size)[:-1])
126
+ # Compute the segment sentence embeddings by averaging the token embeddings.
127
+ content_sentence_embeddings = [
128
+ np.mean(sentence_matrix, axis=0, keepdims=True)
129
+ for sentence_matrix in sentence_matrices[content_start_index - segment_start_index :]
130
+ ]
131
+ sentence_embeddings_list.append(np.vstack(content_sentence_embeddings))
132
+ sentence_embeddings = np.vstack(sentence_embeddings_list)
133
+ # Normalise the sentence embeddings to unit norm and cast to half precision.
134
+ if config.embedder_normalize:
135
+ sentence_embeddings /= np.linalg.norm(sentence_embeddings, axis=1, keepdims=True)
136
+ sentence_embeddings = sentence_embeddings.astype(np.float16)
137
+ return sentence_embeddings
138
+
139
+
140
+ def _embed_sentences_with_windowing(
141
+ sentences: list[str], *, config: RAGLiteConfig | None = None
142
+ ) -> FloatMatrix:
143
+ """Embed a document's sentences with windowing."""
144
+
145
+ def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> FloatMatrix:
146
+ # Embed the batch of strings.
147
+ if config.embedder.startswith("llama-cpp-python"):
148
+ # LiteLLM doesn't yet support registering a custom embedder, so we handle it here.
149
+ # Additionally, we manually pool the token embeddings to obtain sentence
150
+ # embeddings because token embeddings are universally supported, while sequence
151
+ # embeddings are only supported by some models.
152
+ embedder = LlamaCppPythonLLM.llm(
153
+ config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
154
+ )
155
+ embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)])
156
+ else:
157
+ # Use LiteLLM's API to embed the batch of strings.
158
+ response = embedding(config.embedder, string_batch)
159
+ embeddings = np.asarray([item["embedding"] for item in response["data"]])
160
+ # Normalise the embeddings to unit norm and cast to half precision.
161
+ if config.embedder_normalize:
162
+ embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
163
+ embeddings = embeddings.astype(np.float16)
164
+ return embeddings
165
+
166
+ # Window the sentences with a lookback of `config.embedder_sentence_window_size - 1` sentences.
167
+ config = config or RAGLiteConfig()
168
+ sentence_windows = [
169
+ "".join(sentences[max(0, i - (config.embedder_sentence_window_size - 1)) : i + 1])
170
+ for i in range(len(sentences))
171
+ ]
172
+ # Embed the sentence windows in batches.
173
+ batch_size = 64
174
+ batch_range = (
175
+ partial(trange, desc="Embedding", unit="batch", dynamic_ncols=True)
176
+ if len(sentence_windows) > batch_size
177
+ else range
178
+ )
179
+ batch_embeddings = [
180
+ _embed_string_batch(sentence_windows[i : i + batch_size], config=config)
181
+ for i in batch_range(0, len(sentence_windows), batch_size) # type: ignore[operator]
182
+ ]
183
+ sentence_embeddings = np.vstack(batch_embeddings)
184
+ return sentence_embeddings
185
+
186
+
187
+ def sentence_embedding_type(
188
+ *,
189
+ config: RAGLiteConfig | None = None,
190
+ ) -> Literal["late_chunking", "windowing"]:
191
+ """Return the type of sentence embeddings."""
192
+ config = config or RAGLiteConfig()
193
+ return "late_chunking" if config.embedder.startswith("llama-cpp-python") else "windowing"
194
+
195
+
196
+ def embed_sentences(sentences: list[str], *, config: RAGLiteConfig | None = None) -> FloatMatrix:
197
+ """Embed the sentences of a document as a NumPy matrix with one row per sentence."""
198
+ config = config or RAGLiteConfig()
199
+ if sentence_embedding_type(config=config) == "late_chunking":
200
+ sentence_embeddings = _embed_sentences_with_late_chunking(sentences, config=config)
201
+ else:
202
+ sentence_embeddings = _embed_sentences_with_windowing(sentences, config=config)
203
+ return sentence_embeddings
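
A minimal usage sketch of `embed_sentences` with the default configuration (not part of the committed diff; the sample sentences are illustrative and the default embedder is downloaded on first use):

```python
from raglite._config import RAGLiteConfig
from raglite._embed import embed_sentences

config = RAGLiteConfig()  # Default llama-cpp-python embedder.
sentences = ["RAGLite is a RAG toolkit. ", "It works with SQLite or PostgreSQL. "]
embeddings = embed_sentences(sentences, config=config)
print(embeddings.shape, embeddings.dtype)  # One row per sentence, float16, unit norm if normalised.
```
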
src/raglite/_eval.py ADDED
@@ -0,0 +1,257 @@
1
+ """Generation and evaluation of evals."""
2
+
3
+ from random import randint
4
+ from typing import ClassVar
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from pydantic import BaseModel, Field, field_validator
9
+ from sqlmodel import Session, func, select
10
+ from tqdm.auto import tqdm, trange
11
+
12
+ from raglite._config import RAGLiteConfig
13
+ from raglite._database import Chunk, Document, Eval, create_database_engine
14
+ from raglite._extract import extract_with_llm
15
+ from raglite._rag import rag
16
+ from raglite._search import hybrid_search, retrieve_segments, vector_search
17
+ from raglite._typing import SearchMethod
18
+
19
+
20
+ def insert_evals( # noqa: C901
21
+ *, num_evals: int = 100, max_contexts_per_eval: int = 20, config: RAGLiteConfig | None = None
22
+ ) -> None:
23
+ """Generate and insert evals into the database."""
24
+
25
+ class QuestionResponse(BaseModel):
26
+ """A specific question about the content of a set of document contexts."""
27
+
28
+ question: str = Field(
29
+ ...,
30
+ description="A specific question about the content of a set of document contexts.",
31
+ min_length=1,
32
+ )
33
+ system_prompt: ClassVar[str] = """
34
+ You are given a set of contexts extracted from a document.
35
+ You are a subject matter expert on the document's topic.
36
+ Your task is to generate a question to quiz other subject matter experts on the information in the provided context.
37
+ The question MUST satisfy ALL of the following criteria:
38
+ - The question SHOULD integrate as much of the provided context as possible.
39
+ - The question MUST NOT be a general or open question, but MUST instead be as specific to the provided context as possible.
40
+ - The question MUST be completely answerable using ONLY the information in the provided context, without depending on any background information.
41
+ - The question MUST be entirely self-contained and able to be understood in full WITHOUT access to the provided context.
42
+ - The question MUST NOT reference the existence of the context, directly or indirectly.
43
+ - The question MUST treat the context as if its contents are entirely part of your working memory.
44
+ """.strip()
45
+
46
+ @field_validator("question")
47
+ @classmethod
48
+ def validate_question(cls, value: str) -> str:
49
+ """Validate the question."""
50
+ question = value.strip().lower()
51
+ if "context" in question or "document" in question or "question" in question:
52
+ raise ValueError
53
+ if not question.endswith("?"):
54
+ raise ValueError
55
+ return value
56
+
57
+ config = config or RAGLiteConfig()
58
+ engine = create_database_engine(config)
59
+ with Session(engine) as session:
60
+ for _ in trange(num_evals, desc="Generating evals", unit="eval", dynamic_ncols=True):
61
+ # Sample a random document from the database.
62
+ seed_document = session.exec(select(Document).order_by(func.random()).limit(1)).first()
63
+ if seed_document is None:
64
+ error_message = "First run `insert_document()` before generating evals."
65
+ raise ValueError(error_message)
66
+ # Sample a random chunk from that document.
67
+ seed_chunk = session.exec(
68
+ select(Chunk)
69
+ .where(Chunk.document_id == seed_document.id)
70
+ .order_by(func.random())
71
+ .limit(1)
72
+ ).first()
73
+ if seed_chunk is None:
74
+ continue
75
+ # Expand the seed chunk into a set of related chunks.
76
+ related_chunk_ids, _ = vector_search(
77
+ np.mean(seed_chunk.embedding_matrix, axis=0, keepdims=True),
78
+ num_results=randint(2, max_contexts_per_eval // 2), # noqa: S311
79
+ config=config,
80
+ )
81
+ related_chunks = retrieve_segments(related_chunk_ids, config=config)
82
+ # Extract a question from the seed chunk's related chunks.
83
+ try:
84
+ question_response = extract_with_llm(
85
+ QuestionResponse, related_chunks, config=config
86
+ )
87
+ except ValueError:
88
+ continue
89
+ else:
90
+ question = question_response.question
91
+ # Search for candidate chunks to answer the generated question.
92
+ candidate_chunk_ids, _ = hybrid_search(
93
+ question, num_results=max_contexts_per_eval, config=config
94
+ )
95
+ candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids]
96
+
97
+ # Determine which candidate chunks are relevant to answer the generated question.
98
+ class ContextEvalResponse(BaseModel):
99
+ """Indicate whether the provided context can be used to answer a given question."""
100
+
101
+ hit: bool = Field(
102
+ ...,
103
+ description="True if the provided context contains (a part of) the answer to the given question, false otherwise.",
104
+ )
105
+ system_prompt: ClassVar[str] = f"""
106
+ You are given a context extracted from a document.
107
+ You are a subject matter expert on the document's topic.
108
+ Your task is to answer whether the provided context contains (a part of) the answer to this question: "{question}"
109
+ An example of a context that does NOT contain (a part of) the answer is a table of contents.
110
+ """.strip()
111
+
112
+ relevant_chunks = []
113
+ for candidate_chunk in tqdm(
114
+ candidate_chunks, desc="Evaluating chunks", unit="chunk", dynamic_ncols=True
115
+ ):
116
+ try:
117
+ context_eval_response = extract_with_llm(
118
+ ContextEvalResponse, str(candidate_chunk), config=config
119
+ )
120
+ except ValueError: # noqa: PERF203
121
+ pass
122
+ else:
123
+ if context_eval_response.hit:
124
+ relevant_chunks.append(candidate_chunk)
125
+ if not relevant_chunks:
126
+ continue
127
+
128
+ # Answer the question using the relevant chunks.
129
+ class AnswerResponse(BaseModel):
130
+ """Answer a question using the provided context."""
131
+
132
+ answer: str = Field(
133
+ ...,
134
+ description="A complete answer to the given question using the provided context.",
135
+ min_length=1,
136
+ )
137
+ system_prompt: ClassVar[str] = f"""
138
+ You are given a set of contexts extracted from a document.
139
+ You are a subject matter expert on the document's topic.
140
+ Your task is to generate a complete answer to the following question using the provided context: "{question}"
141
+ The answer MUST satisfy ALL of the following criteria:
142
+ - The answer MUST integrate as much of the provided context as possible.
143
+ - The answer MUST be entirely self-contained and able to be understood in full WITHOUT access to the provided context.
144
+ - The answer MUST NOT reference the existence of the context, directly or indirectly.
145
+ - The answer MUST treat the context as if its contents are entirely part of your working memory.
146
+ """.strip()
147
+
148
+ try:
149
+ answer_response = extract_with_llm(
150
+ AnswerResponse,
151
+ [str(relevant_chunk) for relevant_chunk in relevant_chunks],
152
+ config=config,
153
+ )
154
+ except ValueError:
155
+ continue
156
+ else:
157
+ answer = answer_response.answer
158
+ # Store the eval in the database.
159
+ eval_ = Eval.from_chunks(
160
+ question=question,
161
+ contexts=relevant_chunks,
162
+ ground_truth=answer,
163
+ )
164
+ session.add(eval_)
165
+ session.commit()
166
+
167
+
168
+ def answer_evals(
169
+ num_evals: int = 100,
170
+ search: SearchMethod = hybrid_search,
171
+ *,
172
+ config: RAGLiteConfig | None = None,
173
+ ) -> pd.DataFrame:
174
+ """Read evals from the database and answer them with RAG."""
175
+ # Read evals from the database.
176
+ config = config or RAGLiteConfig()
177
+ engine = create_database_engine(config)
178
+ with Session(engine) as session:
179
+ evals = session.exec(select(Eval).limit(num_evals)).all()
180
+ # Answer evals with RAG.
181
+ answers: list[str] = []
182
+ contexts: list[list[str]] = []
183
+ for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
184
+ response = rag(eval_.question, search=search, config=config)
185
+ answer = "".join(response)
186
+ answers.append(answer)
187
+ chunk_ids, _ = search(eval_.question, config=config)
188
+ contexts.append(retrieve_segments(chunk_ids, config=config))
189
+ # Collect the answered evals.
190
+ answered_evals: dict[str, list[str] | list[list[str]]] = {
191
+ "question": [eval_.question for eval_ in evals],
192
+ "answer": answers,
193
+ "contexts": contexts,
194
+ "ground_truth": [eval_.ground_truth for eval_ in evals],
195
+ "ground_truth_contexts": [eval_.contexts for eval_ in evals],
196
+ }
197
+ answered_evals_df = pd.DataFrame.from_dict(answered_evals)
198
+ return answered_evals_df
199
+
200
+
201
+ def evaluate(
202
+ answered_evals: pd.DataFrame | int = 100,
203
+ config: RAGLiteConfig | None = None,
204
+ ) -> pd.DataFrame:
205
+ """Evaluate the performance of a set of answered evals with Ragas."""
206
+ try:
207
+ from datasets import Dataset
208
+ from langchain_community.chat_models import ChatLiteLLM
209
+ from langchain_community.embeddings import LlamaCppEmbeddings
210
+ from langchain_community.llms import LlamaCpp
211
+ from ragas import RunConfig
212
+ from ragas import evaluate as ragas_evaluate
213
+
214
+ from raglite._litellm import LlamaCppPythonLLM
215
+ except ImportError as import_error:
216
+ error_message = "To use the `evaluate` function, please install the `ragas` extra."
217
+ raise ImportError(error_message) from import_error
218
+
219
+ # Create a set of answered evals if not provided.
220
+ config = config or RAGLiteConfig()
221
+ answered_evals_df = (
222
+ answered_evals
223
+ if isinstance(answered_evals, pd.DataFrame)
224
+ else answer_evals(num_evals=answered_evals, config=config)
225
+ )
226
+ # Load the LLM.
227
+ if config.llm.startswith("llama-cpp-python"):
228
+ llm = LlamaCppPythonLLM().llm(model=config.llm)
229
+ lc_llm = LlamaCpp(
230
+ model_path=llm.model_path,
231
+ n_batch=llm.n_batch,
232
+ n_ctx=llm.n_ctx(),
233
+ n_gpu_layers=-1,
234
+ verbose=llm.verbose,
235
+ )
236
+ else:
237
+ lc_llm = ChatLiteLLM(model=config.llm) # type: ignore[call-arg]
238
+ # Load the embedder.
239
+ if not config.embedder.startswith("llama-cpp-python"):
240
+ error_message = "Currently, only `llama-cpp-python` embedders are supported."
241
+ raise NotImplementedError(error_message)
242
+ embedder = LlamaCppPythonLLM().llm(model=config.embedder, embedding=True)
243
+ lc_embedder = LlamaCppEmbeddings( # type: ignore[call-arg]
244
+ model_path=embedder.model_path,
245
+ n_batch=embedder.n_batch,
246
+ n_ctx=embedder.n_ctx(),
247
+ n_gpu_layers=-1,
248
+ verbose=embedder.verbose,
249
+ )
250
+ # Evaluate the answered evals with Ragas.
251
+ evaluation_df = ragas_evaluate(
252
+ dataset=Dataset.from_pandas(answered_evals_df),
253
+ llm=lc_llm,
254
+ embeddings=lc_embedder,
255
+ run_config=RunConfig(max_workers=1),
256
+ ).to_pandas()
257
+ return evaluation_df
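
A hedged end-to-end sketch of the eval workflow above (assumes documents were already inserted with `insert_document()`; not part of the commit):

```python
from raglite._eval import answer_evals, evaluate, insert_evals

insert_evals(num_evals=10)                # Generate and store (question, contexts, answer) evals.
answered_df = answer_evals(num_evals=10)  # Answer the stored evals with RAG.
scores_df = evaluate(answered_df)         # Score the answers with Ragas (needs the `ragas` extra).
```
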
src/raglite/_extract.py ADDED
@@ -0,0 +1,69 @@
1
+ """Extract structured data from unstructured text with an LLM."""
2
+
3
+ from typing import Any, TypeVar
4
+
5
+ from litellm import completion
6
+ from pydantic import BaseModel, ValidationError
7
+
8
+ from raglite._config import RAGLiteConfig
9
+
10
+ T = TypeVar("T", bound=BaseModel)
11
+
12
+
13
+ def extract_with_llm(
14
+ return_type: type[T],
15
+ user_prompt: str | list[str],
16
+ config: RAGLiteConfig | None = None,
17
+ **kwargs: Any,
18
+ ) -> T:
19
+ """Extract structured data from unstructured text with an LLM.
20
+
21
+ This function expects a `return_type.system_prompt: ClassVar[str]` that contains the system
22
+ prompt to use. Example:
23
+
24
+ from typing import ClassVar
25
+ from pydantic import BaseModel, Field
26
+
27
+ class MyNameResponse(BaseModel):
28
+ my_name: str = Field(..., description="The user's name.")
29
+ system_prompt: ClassVar[str] = "The system prompt to use (excluded from JSON schema)."
30
+
31
+ my_name_response = extract_with_llm(MyNameResponse, "My name is Thomas A. Anderson.")
32
+ """
33
+ # Load the default config if not provided.
34
+ config = config or RAGLiteConfig()
35
+ # Update the system prompt with the JSON schema of the return type to help the LLM.
36
+ system_prompt = "\n".join((
38
+ return_type.system_prompt.strip(), # type: ignore[attr-defined]
39
+ "Format your response according to this JSON schema:",
40
+ str(return_type.model_json_schema()),
41
+ ))
41
+ # Concatenate the user prompt if it is a list of strings.
42
+ if isinstance(user_prompt, list):
43
+ user_prompt = "\n\n".join(
44
+ f'<context index="{i}">\n{chunk.strip()}\n</context>'
45
+ for i, chunk in enumerate(user_prompt)
46
+ )
47
+ # Extract structured data from the unstructured input.
48
+ for _ in range(config.llm_max_tries):
49
+ response = completion(
50
+ model=config.llm,
51
+ messages=[
52
+ {"role": "system", "content": system_prompt},
53
+ {"role": "user", "content": user_prompt},
54
+ ],
55
+ response_format={"type": "json_object", "schema": return_type.model_json_schema()},
56
+ **kwargs,
57
+ )
58
+ try:
59
+ instance = return_type.model_validate_json(response["choices"][0]["message"]["content"])
60
+ except (KeyError, ValueError, ValidationError) as e:
61
+ # Malformed response, not a JSON string, or not a valid instance of the return type.
62
+ last_exception = e
63
+ continue
64
+ else:
65
+ break
66
+ else:
67
+ error_message = f"Failed to extract {return_type} from input {user_prompt}."
68
+ raise ValueError(error_message) from last_exception
69
+ return instance
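
As a complement to the docstring's example, a sketch of extraction from multiple contexts (the model class and field names below are illustrative only; each context string is wrapped in a `<context>` tag by the function):

```python
from typing import ClassVar

from pydantic import BaseModel, Field

from raglite._extract import extract_with_llm


class CityResponse(BaseModel):
    city: str = Field(..., description="The city mentioned in the contexts.")
    system_prompt: ClassVar[str] = "Extract the city mentioned in the provided contexts."


contexts = ["Thomas lives in Sydney.", "He moved there in 2012."]
response = extract_with_llm(CityResponse, contexts)
print(response.city)
```
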
src/raglite/_flashrank.py ADDED
@@ -0,0 +1,41 @@
1
+ """Patched version of FlashRankRanker that fixes incorrect reranking [1].
2
+
3
+ [1] https://github.com/AnswerDotAI/rerankers/issues/39
4
+ """
5
+
6
+ import contextlib
7
+ from io import StringIO
8
+ from typing import Any
9
+
10
+ from flashrank import RerankRequest
11
+
12
+ # Suppress rerankers output on import until [1] is fixed.
13
+ # [1] https://github.com/AnswerDotAI/rerankers/issues/36
14
+ with contextlib.redirect_stdout(StringIO()):
15
+ from rerankers.documents import Document
16
+ from rerankers.models.flashrank_ranker import FlashRankRanker
17
+ from rerankers.results import RankedResults, Result
18
+ from rerankers.utils import prep_docs
19
+
20
+
21
+ class PatchedFlashRankRanker(FlashRankRanker):
22
+ def rank(
23
+ self,
24
+ query: str,
25
+ docs: str | list[str] | Document | list[Document],
26
+ doc_ids: list[str] | list[int] | None = None,
27
+ metadata: list[dict[str, Any]] | None = None,
28
+ ) -> RankedResults:
29
+ docs = prep_docs(docs, doc_ids, metadata)
30
+ passages = [{"id": doc_idx, "text": doc.text} for doc_idx, doc in enumerate(docs)]
31
+ rerank_request = RerankRequest(query=query, passages=passages)
32
+ flashrank_results = self.model.rerank(rerank_request)
33
+ ranked_results = [
34
+ Result(
35
+ document=docs[result["id"]], # This patches the incorrect ranking in the original.
36
+ score=result["score"],
37
+ rank=idx + 1,
38
+ )
39
+ for idx, result in enumerate(flashrank_results)
40
+ ]
41
+ return RankedResults(results=ranked_results, query=query, has_scores=True)
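
A sketch of how the patched ranker could be used (the constructor arguments are assumptions about the `rerankers`/`flashrank` API and are not taken from this commit):

```python
# Hypothetical usage of the patched FlashRank reranker.
docs = ["RAGLite supports SQLite and PostgreSQL.", "Bananas are yellow."]
ranker = PatchedFlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0)  # Model name is an assumption.
ranked = ranker.rank(query="Which databases does RAGLite support?", docs=docs)
print(ranked.results[0].document.text)  # The most relevant document comes first.
```
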
src/raglite/_insert.py ADDED
@@ -0,0 +1,160 @@
1
+ """Index documents."""
2
+
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ from sqlalchemy.engine import make_url
7
+ from sqlmodel import Session, select
8
+ from tqdm.auto import tqdm
9
+
10
+ from raglite._config import RAGLiteConfig
11
+ from raglite._database import Chunk, ChunkEmbedding, Document, IndexMetadata, create_database_engine
12
+ from raglite._embed import embed_sentences, sentence_embedding_type
13
+ from raglite._markdown import document_to_markdown
14
+ from raglite._split_chunks import split_chunks
15
+ from raglite._split_sentences import split_sentences
16
+ from raglite._typing import FloatMatrix
17
+
18
+
19
+ def _create_chunk_records(
20
+ document_id: str,
21
+ chunks: list[str],
22
+ chunk_embeddings: list[FloatMatrix],
23
+ config: RAGLiteConfig,
24
+ ) -> tuple[list[Chunk], list[list[ChunkEmbedding]]]:
25
+ """Process chunks into chunk and chunk embedding records."""
26
+ # Create the chunk records.
27
+ chunk_records, headings = [], ""
28
+ for i, chunk in enumerate(chunks):
29
+ # Create and append the chunk record.
30
+ record = Chunk.from_body(document_id=document_id, index=i, body=chunk, headings=headings)
31
+ chunk_records.append(record)
32
+ # Update the Markdown headings with those of this chunk.
33
+ headings = record.extract_headings()
34
+ # Create the chunk embedding records.
35
+ chunk_embedding_records = []
36
+ if sentence_embedding_type(config=config) == "late_chunking":
37
+ # Every chunk record is associated with a list of chunk embedding records, one for each of
38
+ # the sentences in the chunk.
39
+ for chunk_record, chunk_embedding in zip(chunk_records, chunk_embeddings, strict=True):
40
+ chunk_embedding_records.append(
41
+ [
42
+ ChunkEmbedding(chunk_id=chunk_record.id, embedding=sentence_embedding)
43
+ for sentence_embedding in chunk_embedding
44
+ ]
45
+ )
46
+ else:
47
+ # Embed the full chunks, including the current Markdown headings.
48
+ full_chunk_embeddings = embed_sentences([str(chunk) for chunk in chunks], config=config)
49
+ # Every chunk record is associated with a list of chunk embedding records. The chunk
50
+ # embedding records each correspond to a linear combination of a sentence embedding and an
51
+ # embedding of the full chunk with Markdown headings.
52
+ α = 0.382 # Golden ratio. # noqa: PLC2401
53
+ for chunk_record, chunk_embedding, full_chunk_embedding in zip(
54
+ chunk_records, chunk_embeddings, full_chunk_embeddings, strict=True
55
+ ):
56
+ chunk_embedding_records.append(
57
+ [
58
+ ChunkEmbedding(
59
+ chunk_id=chunk_record.id,
60
+ embedding=α * sentence_embedding + (1 - α) * full_chunk_embedding,
61
+ )
62
+ for sentence_embedding in chunk_embedding
63
+ ]
64
+ )
65
+ return chunk_records, chunk_embedding_records
66
+
67
+
68
+ def insert_document(doc_path: Path, *, config: RAGLiteConfig | None = None) -> None: # noqa: PLR0915
69
+ """Insert a document into the database and update the index."""
70
+ # Use the default config if not provided.
71
+ config = config or RAGLiteConfig()
72
+ db_backend = make_url(config.db_url).get_backend_name()
73
+ # Preprocess the document into chunks and chunk embeddings.
74
+ with tqdm(total=5, unit="step", dynamic_ncols=True) as pbar:
75
+ pbar.set_description("Initializing database")
76
+ engine = create_database_engine(config)
77
+ pbar.update(1)
78
+ pbar.set_description("Converting to Markdown")
79
+ doc = document_to_markdown(doc_path)
80
+ pbar.update(1)
81
+ pbar.set_description("Splitting sentences")
82
+ sentences = split_sentences(doc, max_len=config.chunk_max_size)
83
+ pbar.update(1)
84
+ pbar.set_description("Embedding sentences")
85
+ sentence_embeddings = embed_sentences(sentences, config=config)
86
+ pbar.update(1)
87
+ pbar.set_description("Splitting chunks")
88
+ chunks, chunk_embeddings = split_chunks(
89
+ sentences=sentences,
90
+ sentence_embeddings=sentence_embeddings,
91
+ sentence_window_size=config.embedder_sentence_window_size,
92
+ max_size=config.chunk_max_size,
93
+ )
94
+ pbar.update(1)
95
+ # Create and store the chunk records.
96
+ with Session(engine) as session:
97
+ # Add the document to the document table.
98
+ document_record = Document.from_path(doc_path)
99
+ if session.get(Document, document_record.id) is None:
100
+ session.add(document_record)
101
+ session.commit()
102
+ # Create the chunk records to insert into the chunk table.
103
+ chunk_records, chunk_embedding_records = _create_chunk_records(
104
+ document_record.id, chunks, chunk_embeddings, config
105
+ )
106
+ # Store the chunk and chunk embedding records.
107
+ for chunk_record, chunk_embedding_record_list in tqdm(
108
+ zip(chunk_records, chunk_embedding_records, strict=True),
109
+ desc="Inserting chunks",
110
+ total=len(chunk_records),
111
+ unit="chunk",
112
+ dynamic_ncols=True,
113
+ ):
114
+ if session.get(Chunk, chunk_record.id) is not None:
115
+ continue
116
+ session.add(chunk_record)
117
+ session.add_all(chunk_embedding_record_list)
118
+ session.commit()
119
+ # Manually update the vector search chunk index for SQLite.
120
+ if db_backend == "sqlite":
121
+ from pynndescent import NNDescent
122
+
123
+ with Session(engine) as session:
124
+ # Get the vector search chunk index from the database, or create a new one.
125
+ index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
126
+ chunk_ids = index_metadata.metadata_.get("chunk_ids", [])
127
+ chunk_sizes = index_metadata.metadata_.get("chunk_sizes", [])
128
+ # Get the unindexed chunks.
129
+ unindexed_chunks = list(session.exec(select(Chunk).offset(len(chunk_ids))).all())
130
+ if not unindexed_chunks:
131
+ return
132
+ # Assemble the unindexed chunk embeddings into a NumPy array.
133
+ unindexed_chunk_embeddings = [chunk.embedding_matrix for chunk in unindexed_chunks]
134
+ X = np.vstack(unindexed_chunk_embeddings) # noqa: N806
135
+ # Index the unindexed chunks.
136
+ with tqdm(
137
+ total=len(unindexed_chunks),
138
+ desc="Indexing chunks",
139
+ unit="chunk",
140
+ dynamic_ncols=True,
141
+ ) as pbar:
142
+ # Fit or update the ANN index.
143
+ if len(chunk_ids) == 0:
144
+ nndescent = NNDescent(X, metric=config.vector_search_index_metric)
145
+ else:
146
+ nndescent = index_metadata.metadata_["index"]
147
+ nndescent.update(X)
148
+ # Prepare the ANN index so it can to handle query vectors not in the training set.
149
+ nndescent.prepare()
150
+ # Update the index metadata and mark it as dirty by recreating the dictionary.
151
+ index_metadata.metadata_ = {
152
+ **index_metadata.metadata_,
153
+ "index": nndescent,
154
+ "chunk_ids": chunk_ids + [c.id for c in unindexed_chunks],
155
+ "chunk_sizes": chunk_sizes + [len(em) for em in unindexed_chunk_embeddings],
156
+ }
157
+ # Store the updated vector search chunk index.
158
+ session.add(index_metadata)
159
+ session.commit()
160
+ pbar.update(len(unindexed_chunks))
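
A minimal sketch of inserting a document with the default configuration (the file path is illustrative):

```python
from pathlib import Path

from raglite._insert import insert_document

# Converts the document to Markdown, splits and embeds its sentences, chunks it,
# and (for SQLite) updates the PyNNDescent vector search index.
insert_document(Path("paper.pdf"))
```
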
src/raglite/_litellm.py ADDED
@@ -0,0 +1,261 @@
1
+ """Add support for llama-cpp-python models to LiteLLM."""
2
+
3
+ import asyncio
4
+ import logging
5
+ import warnings
6
+ from collections.abc import AsyncIterator, Callable, Iterator
7
+ from functools import cache
8
+ from typing import Any, ClassVar, cast
9
+
10
+ import httpx
11
+ import litellm
12
+ from litellm import ( # type: ignore[attr-defined]
13
+ CustomLLM,
14
+ GenericStreamingChunk,
15
+ ModelResponse,
16
+ convert_to_model_response_object,
17
+ )
18
+ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
19
+ from llama_cpp import ( # type: ignore[attr-defined]
20
+ ChatCompletionRequestMessage,
21
+ CreateChatCompletionResponse,
22
+ CreateChatCompletionStreamResponse,
23
+ Llama,
24
+ LlamaRAMCache,
25
+ )
26
+
27
+ # Reduce the logging level for LiteLLM and flashrank.
28
+ logging.getLogger("litellm").setLevel(logging.WARNING)
29
+ logging.getLogger("flashrank").setLevel(logging.WARNING)
30
+
31
+
32
+ class LlamaCppPythonLLM(CustomLLM):
33
+ """A llama-cpp-python provider for LiteLLM.
34
+
35
+ This provider enables using llama-cpp-python models with LiteLLM. The LiteLLM model
36
+ specification is "llama-cpp-python/<hugging_face_repo_id>/<filename>@<n_ctx>", where n_ctx is
37
+ an optional parameter that specifies the context size of the model. If n_ctx is not provided or
38
+ if it's set to 0, the model's default context size is used.
39
+
40
+ Example usage:
41
+
42
+ ```python
43
+ from litellm import completion
44
+
45
+ response = completion(
46
+ model="llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4092",
47
+ messages=[{"role": "user", "content": "Hello world!"}],
48
+ # stream=True
49
+ )
50
+ ```
51
+ """
52
+
53
+ # Create a lock to prevent concurrent access to llama-cpp-python models.
54
+ streaming_lock: ClassVar[asyncio.Lock] = asyncio.Lock()
55
+
56
+ # The set of supported OpenAI parameters is the intersection of [1] and [2]. Not included:
57
+ # max_completion_tokens, stream_options, n, user, logprobs, top_logprobs, extra_headers.
58
+ # [1] https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion
59
+ # [2] https://docs.litellm.ai/docs/completion/input
60
+ supported_openai_params: ClassVar[list[str]] = [
61
+ "functions", # Deprecated
62
+ "function_call", # Deprecated
63
+ "tools",
64
+ "tool_choice",
65
+ "temperature",
66
+ "top_p",
67
+ "top_k",
68
+ "min_p",
69
+ "typical_p",
70
+ "stop",
71
+ "seed",
72
+ "response_format",
73
+ "max_tokens",
74
+ "presence_penalty",
75
+ "frequency_penalty",
76
+ "repeat_penalty",
77
+ "tfs_z",
78
+ "mirostat_mode",
79
+ "mirostat_tau",
80
+ "mirostat_eta",
81
+ "logit_bias",
82
+ ]
83
+
84
+ @staticmethod
85
+ @cache
86
+ def llm(model: str, **kwargs: Any) -> Llama:
87
+ # Drop the llama-cpp-python prefix from the model.
88
+ repo_id_filename = model.replace("llama-cpp-python/", "")
89
+ # Convert the LiteLLM model string to repo_id, filename, and n_ctx.
90
+ repo_id, filename = repo_id_filename.rsplit("/", maxsplit=1)
91
+ n_ctx = 0
92
+ if len(filename_n_ctx := filename.rsplit("@", maxsplit=1)) == 2: # noqa: PLR2004
93
+ filename, n_ctx_str = filename_n_ctx
94
+ n_ctx = int(n_ctx_str)
95
+ # Load the LLM.
96
+ with warnings.catch_warnings(): # Filter huggingface_hub warning about HF_TOKEN.
97
+ warnings.filterwarnings("ignore", category=UserWarning)
98
+ llm = Llama.from_pretrained(
99
+ repo_id=repo_id,
100
+ filename=filename,
101
+ n_ctx=n_ctx,
102
+ n_gpu_layers=-1,
103
+ verbose=False,
104
+ **kwargs,
105
+ )
106
+ # Enable caching.
107
+ llm.set_cache(LlamaRAMCache())
108
+ # Register the model info with LiteLLM.
109
+ litellm.register_model( # type: ignore[attr-defined]
110
+ {
111
+ model: {
112
+ "max_tokens": llm.n_ctx(),
113
+ "max_input_tokens": llm.n_ctx(),
114
+ "max_output_tokens": None,
115
+ "input_cost_per_token": 0.0,
116
+ "output_cost_per_token": 0.0,
117
+ "output_vector_size": llm.n_embd() if kwargs.get("embedding") else None,
118
+ "litellm_provider": "llama-cpp-python",
119
+ "mode": "embedding" if kwargs.get("embedding") else "completion",
120
+ "supported_openai_params": LlamaCppPythonLLM.supported_openai_params,
121
+ "supports_function_calling": True,
122
+ "supports_parallel_function_calling": True,
123
+ "supports_vision": False,
124
+ }
125
+ }
126
+ )
127
+ return llm
128
+
129
+ def completion( # noqa: PLR0913
130
+ self,
131
+ model: str,
132
+ messages: list[ChatCompletionRequestMessage],
133
+ api_base: str,
134
+ custom_prompt_dict: dict[str, Any],
135
+ model_response: ModelResponse,
136
+ print_verbose: Callable, # type: ignore[type-arg]
137
+ encoding: str,
138
+ api_key: str,
139
+ logging_obj: Any,
140
+ optional_params: dict[str, Any],
141
+ acompletion: Callable | None = None, # type: ignore[type-arg]
142
+ litellm_params: dict[str, Any] | None = None,
143
+ logger_fn: Callable | None = None, # type: ignore[type-arg]
144
+ headers: dict[str, Any] | None = None,
145
+ timeout: float | httpx.Timeout | None = None,
146
+ client: HTTPHandler | None = None,
147
+ ) -> ModelResponse:
148
+ llm = self.llm(model)
149
+ llama_cpp_python_params = {
150
+ k: v for k, v in optional_params.items() if k in self.supported_openai_params
151
+ }
152
+ response = cast(
153
+ CreateChatCompletionResponse,
154
+ llm.create_chat_completion(messages=messages, **llama_cpp_python_params),
155
+ )
156
+ litellm_model_response: ModelResponse = convert_to_model_response_object(
157
+ response_object=response,
158
+ model_response_object=model_response,
159
+ response_type="completion",
160
+ stream=False,
161
+ )
162
+ return litellm_model_response
163
+
164
+ def streaming( # noqa: PLR0913
165
+ self,
166
+ model: str,
167
+ messages: list[ChatCompletionRequestMessage],
168
+ api_base: str,
169
+ custom_prompt_dict: dict[str, Any],
170
+ model_response: ModelResponse,
171
+ print_verbose: Callable, # type: ignore[type-arg]
172
+ encoding: str,
173
+ api_key: str,
174
+ logging_obj: Any,
175
+ optional_params: dict[str, Any],
176
+ acompletion: Callable | None = None, # type: ignore[type-arg]
177
+ litellm_params: dict[str, Any] | None = None,
178
+ logger_fn: Callable | None = None, # type: ignore[type-arg]
179
+ headers: dict[str, Any] | None = None,
180
+ timeout: float | httpx.Timeout | None = None,
181
+ client: HTTPHandler | None = None,
182
+ ) -> Iterator[GenericStreamingChunk]:
183
+ llm = self.llm(model)
184
+ llama_cpp_python_params = {
185
+ k: v for k, v in optional_params.items() if k in self.supported_openai_params
186
+ }
187
+ stream = cast(
188
+ Iterator[CreateChatCompletionStreamResponse],
189
+ llm.create_chat_completion(messages=messages, **llama_cpp_python_params, stream=True),
190
+ )
191
+ for chunk in stream:
192
+ choices = chunk.get("choices", [])
193
+ for choice in choices:
194
+ text = choice.get("delta", {}).get("content", None)
195
+ finish_reason = choice.get("finish_reason")
196
+ litellm_generic_streaming_chunk = GenericStreamingChunk(
197
+ text=text, # type: ignore[typeddict-item]
198
+ is_finished=bool(finish_reason),
199
+ finish_reason=finish_reason, # type: ignore[typeddict-item]
200
+ usage=None,
201
+ index=choice.get("index"), # type: ignore[typeddict-item]
202
+ provider_specific_fields={
203
+ "id": chunk.get("id"),
204
+ "model": chunk.get("model"),
205
+ "created": chunk.get("created"),
206
+ "object": chunk.get("object"),
207
+ },
208
+ )
209
+ yield litellm_generic_streaming_chunk
210
+
211
+ async def astreaming( # type: ignore[misc,override] # noqa: PLR0913
212
+ self,
213
+ model: str,
214
+ messages: list[ChatCompletionRequestMessage],
215
+ api_base: str,
216
+ custom_prompt_dict: dict[str, Any],
217
+ model_response: ModelResponse,
218
+ print_verbose: Callable, # type: ignore[type-arg]
219
+ encoding: str,
220
+ api_key: str,
221
+ logging_obj: Any,
222
+ optional_params: dict[str, Any],
223
+ acompletion: Callable | None = None, # type: ignore[type-arg]
224
+ litellm_params: dict[str, Any] | None = None,
225
+ logger_fn: Callable | None = None, # type: ignore[type-arg]
226
+ headers: dict[str, Any] | None = None,
227
+ timeout: float | httpx.Timeout | None = None, # noqa: ASYNC109
228
+ client: AsyncHTTPHandler | None = None,
229
+ ) -> AsyncIterator[GenericStreamingChunk]:
230
+ # Start a synchronous stream.
231
+ stream = self.streaming(
232
+ model,
233
+ messages,
234
+ api_base,
235
+ custom_prompt_dict,
236
+ model_response,
237
+ print_verbose,
238
+ encoding,
239
+ api_key,
240
+ logging_obj,
241
+ optional_params,
242
+ acompletion,
243
+ litellm_params,
244
+ logger_fn,
245
+ headers,
246
+ timeout,
247
+ )
248
+ await asyncio.sleep(0) # Yield control to the event loop after initialising the context.
249
+ # Wrap the synchronous stream in an asynchronous stream.
250
+ async with LlamaCppPythonLLM.streaming_lock:
251
+ for litellm_generic_streaming_chunk in stream:
252
+ yield litellm_generic_streaming_chunk
253
+ await asyncio.sleep(0) # Yield control to the event loop after each token.
254
+
255
+
256
+ # Register the LlamaCppPythonLLM provider.
257
+ if not any(provider["provider"] == "llama-cpp-python" for provider in litellm.custom_provider_map):
258
+ litellm.custom_provider_map.append(
259
+ {"provider": "llama-cpp-python", "custom_handler": LlamaCppPythonLLM()}
260
+ )
261
+ litellm.suppress_debug_info = True
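
Building on the docstring's example, a streaming sketch now that the provider is registered (the model string and `@4096` context size are assumptions; the token extraction mirrors the pattern used in `_rag.py`):

```python
from litellm import completion

stream = completion(
    model="llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4096",
    messages=[{"role": "user", "content": "Hello world!"}],
    stream=True,
)
for chunk in stream:
    print(chunk["choices"][0]["delta"].get("content") or "", end="", flush=True)
```
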
src/raglite/_markdown.py ADDED
@@ -0,0 +1,221 @@
1
+ """Convert any document to Markdown."""
2
+
3
+ import re
4
+ from copy import deepcopy
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import mdformat
9
+ import numpy as np
10
+ from pdftext.extraction import dictionary_output
11
+ from sklearn.cluster import KMeans
12
+
13
+
14
+ def parsed_pdf_to_markdown(pages: list[dict[str, Any]]) -> list[str]: # noqa: C901, PLR0915
15
+ """Convert a PDF parsed with pdftext to Markdown."""
16
+
17
+ def add_heading_level_metadata(pages: list[dict[str, Any]]) -> list[dict[str, Any]]: # noqa: C901
18
+ """Add heading level metadata to a PDF parsed with pdftext."""
19
+
20
+ def extract_font_size(span: dict[str, Any]) -> float:
21
+ """Extract the font size from a text span."""
22
+ font_size: float = 1.0
23
+ if span["font"]["size"] > 1: # A value of 1 appears to mean "unknown" in pdftext.
24
+ font_size = span["font"]["size"]
25
+ elif digit_sequences := re.findall(r"\d+", span["font"]["name"] or ""):
26
+ font_size = float(digit_sequences[-1])
27
+ elif "\n" not in span["text"]: # Occasionally a span can contain a newline character.
28
+ if round(span["rotation"]) in (0.0, 180.0, -180.0):
29
+ font_size = span["bbox"][3] - span["bbox"][1]
30
+ elif round(span["rotation"]) in (90.0, -90.0, 270.0, -270.0):
31
+ font_size = span["bbox"][2] - span["bbox"][0]
32
+ return font_size
33
+
34
+ # Copy the pages.
35
+ pages = deepcopy(pages)
36
+ # Extract an array of all font sizes used by the text spans.
37
+ font_sizes = np.asarray(
38
+ [
39
+ extract_font_size(span)
40
+ for page in pages
41
+ for block in page["blocks"]
42
+ for line in block["lines"]
43
+ for span in line["spans"]
44
+ ]
45
+ )
46
+ font_sizes = np.round(font_sizes * 2) / 2
47
+ unique_font_sizes, counts = np.unique(font_sizes, return_counts=True)
48
+ # Determine the paragraph font size as the mode font size.
49
+ tiny = unique_font_sizes < min(5, np.max(unique_font_sizes))
50
+ counts[tiny] = -counts[tiny]
51
+ mode = np.argmax(counts)
52
+ counts[tiny] = -counts[tiny]
53
+ mode_font_size = unique_font_sizes[mode]
54
+ # Determine (at most) 6 heading font sizes by clustering font sizes larger than the mode.
55
+ heading_font_sizes = unique_font_sizes[mode + 1 :]
56
+ if len(heading_font_sizes) > 0:
57
+ heading_counts = counts[mode + 1 :]
58
+ kmeans = KMeans(n_clusters=min(6, len(heading_font_sizes)), random_state=42)
59
+ kmeans.fit(heading_font_sizes[:, np.newaxis], sample_weight=heading_counts)
60
+ heading_font_sizes = np.sort(np.ravel(kmeans.cluster_centers_))[::-1]
61
+ # Add heading level information to the text spans and lines.
62
+ for page in pages:
63
+ for block in page["blocks"]:
64
+ for line in block["lines"]:
65
+ if "md" not in line:
66
+ line["md"] = {}
67
+ heading_level = np.zeros(8) # 0-5: <h1>-<h6>, 6: <p>, 7: <small>
68
+ for span in line["spans"]:
69
+ if "md" not in span:
70
+ span["md"] = {}
71
+ span_font_size = extract_font_size(span)
72
+ if span_font_size < mode_font_size:
73
+ idx = 7
74
+ elif span_font_size == mode_font_size:
75
+ idx = 6
76
+ else:
77
+ idx = np.argmin(np.abs(heading_font_sizes - span_font_size)) # type: ignore[assignment]
78
+ span["md"]["heading_level"] = idx + 1
79
+ heading_level[idx] += len(span["text"])
80
+ line["md"]["heading_level"] = np.argmax(heading_level) + 1
81
+ return pages
82
+
83
+ def add_emphasis_metadata(pages: list[dict[str, Any]]) -> list[dict[str, Any]]:
84
+ """Add emphasis metadata such as bold and italic to a PDF parsed with pdftext."""
85
+ # Copy the pages.
86
+ pages = deepcopy(pages)
87
+ # Add emphasis metadata to the text spans.
88
+ for page in pages:
89
+ for block in page["blocks"]:
90
+ for line in block["lines"]:
91
+ if "md" not in line:
92
+ line["md"] = {}
93
+ for span in line["spans"]:
94
+ if "md" not in span:
95
+ span["md"] = {}
96
+ span["md"]["bold"] = span["font"]["weight"] > 500 # noqa: PLR2004
97
+ span["md"]["italic"] = "ital" in (span["font"]["name"] or "").lower()
98
+ line["md"]["bold"] = all(
99
+ span["md"]["bold"] for span in line["spans"] if span["text"].strip()
100
+ )
101
+ line["md"]["italic"] = all(
102
+ span["md"]["italic"] for span in line["spans"] if span["text"].strip()
103
+ )
104
+ return pages
105
+
106
+ def strip_page_numbers(pages: list[dict[str, Any]]) -> list[dict[str, Any]]:
107
+ """Strip page numbers from a PDF parsed with pdftext."""
108
+ # Copy the pages.
109
+ pages = deepcopy(pages)
110
+ # Remove lines that only contain a page number.
111
+ for page in pages:
112
+ for block in page["blocks"]:
113
+ block["lines"] = [
114
+ line
115
+ for line in block["lines"]
116
+ if not re.match(
117
+ r"^\s*[#0]*\d+\s*$", "".join(span["text"] for span in line["spans"])
118
+ )
119
+ ]
120
+ return pages
121
+
122
+ def convert_to_markdown(pages: list[dict[str, Any]]) -> list[str]: # noqa: C901, PLR0912
123
+ """Convert a list of pages to Markdown."""
124
+ pages_md = []
125
+ for page in pages:
126
+ page_md = ""
127
+ for block in page["blocks"]:
128
+ block_text = ""
129
+ for line in block["lines"]:
130
+ # Build the line text and style the spans.
131
+ line_text = ""
132
+ for span in line["spans"]:
133
+ if (
134
+ not line["md"]["bold"]
135
+ and not line["md"]["italic"]
136
+ and span["md"]["bold"]
137
+ and span["md"]["italic"]
138
+ ):
139
+ line_text += f"***{span['text']}***"
140
+ elif not line["md"]["bold"] and span["md"]["bold"]:
141
+ line_text += f"**{span['text']}**"
142
+ elif not line["md"]["italic"] and span["md"]["italic"]:
143
+ line_text += f"*{span['text']}*"
144
+ else:
145
+ line_text += span["text"]
146
+ # Add emphasis to the line (if it's not a heading or whitespace).
147
+ line_text = line_text.rstrip()
148
+ line_is_whitespace = not line_text.strip()
149
+ line_is_heading = line["md"]["heading_level"] <= 6 # noqa: PLR2004
150
+ if not line_is_heading and not line_is_whitespace:
151
+ if line["md"]["bold"] and line["md"]["italic"]:
152
+ line_text = f"***{line_text}***"
153
+ elif line["md"]["bold"]:
154
+ line_text = f"**{line_text}**"
155
+ elif line["md"]["italic"]:
156
+ line_text = f"*{line_text}*"
157
+ # Set the heading level.
158
+ if line_is_heading and not line_is_whitespace:
159
+ line_text = f"{'#' * line['md']['heading_level']} {line_text}"
160
+ line_text += "\n"
161
+ block_text += line_text
162
+ block_text = block_text.rstrip() + "\n\n"
163
+ page_md += block_text
164
+ pages_md.append(page_md.strip())
165
+ return pages_md
166
+
167
+ def merge_split_headings(pages: list[str]) -> list[str]:
168
+ """Merge headings that are split across lines."""
169
+
170
+ def _merge_split_headings(match: re.Match[str]) -> str:
171
+ atx_headings = [line.strip("# ").strip() for line in match.group().splitlines()]
172
+ return f"{match.group(1)} {' '.join(atx_headings)}\n\n"
173
+
174
+ pages_md = [
175
+ re.sub(
176
+ r"^(#+)[ \t]+[^\n]+\n+(?:^\1[ \t]+[^\n]+\n+)+",
177
+ _merge_split_headings,
178
+ page,
179
+ flags=re.MULTILINE,
180
+ )
181
+ for page in pages
182
+ ]
183
+ return pages_md
184
+
185
+ # Add heading level metadata.
186
+ pages = add_heading_level_metadata(pages)
187
+ # Add emphasis metadata.
188
+ pages = add_emphasis_metadata(pages)
189
+ # Strip page numbers.
190
+ pages = strip_page_numbers(pages)
191
+ # Convert the pages to Markdown.
192
+ pages_md = convert_to_markdown(pages)
193
+ # Merge headings that are split across lines.
194
+ pages_md = merge_split_headings(pages_md)
195
+ return pages_md
196
+
197
+
198
+ def document_to_markdown(doc_path: Path) -> str:
199
+ """Convert any document to GitHub Flavored Markdown."""
200
+ # Convert the file's content to GitHub Flavored Markdown.
201
+ if doc_path.suffix == ".pdf":
202
+ # Parse the PDF with pdftext and convert it to Markdown.
203
+ pages = dictionary_output(doc_path, sort=True, keep_chars=False)
204
+ doc = "\n\n".join(parsed_pdf_to_markdown(pages))
205
+ else:
206
+ try:
207
+ # Use pandoc for everything else.
208
+ import pypandoc
209
+
210
+ doc = pypandoc.convert_file(doc_path, to="gfm")
211
+ except ImportError as error:
212
+ error_message = (
213
+ "To convert files to Markdown with pandoc, please install the `pandoc` extra."
214
+ )
215
+ raise ImportError(error_message) from error
216
+ except RuntimeError:
217
+ # File format not supported, fall back to reading the text.
218
+ doc = doc_path.read_text()
219
+ # Improve Markdown quality.
220
+ doc = mdformat.text(doc)
221
+ return doc
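
A short sketch of the conversion entry point (the file name is illustrative; pandoc is only needed for non-PDF formats):

```python
from pathlib import Path

from raglite._markdown import document_to_markdown

markdown = document_to_markdown(Path("report.pdf"))
print(markdown[:200])
```
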
src/raglite/_query_adapter.py ADDED
@@ -0,0 +1,162 @@
1
+ """Compute and update an optimal query adapter."""
2
+
3
+ import numpy as np
4
+ from sqlmodel import Session, col, select
5
+ from tqdm.auto import tqdm
6
+
7
+ from raglite._config import RAGLiteConfig
8
+ from raglite._database import Chunk, ChunkEmbedding, Eval, IndexMetadata, create_database_engine
9
+ from raglite._embed import embed_sentences
10
+ from raglite._search import vector_search
11
+
12
+
13
+ def update_query_adapter( # noqa: PLR0915, C901
14
+ *,
15
+ max_triplets: int = 4096,
16
+ max_triplets_per_eval: int = 64,
17
+ optimize_top_k: int = 40,
18
+ config: RAGLiteConfig | None = None,
19
+ ) -> None:
20
+ """Compute an optimal query adapter and update the database with it.
21
+
22
+ This function computes an optimal linear transform A, called a 'query adapter', that is used to
23
+ transform a query embedding q as A @ q before searching for the nearest neighbouring chunks in
24
+ order to improve the quality of the search results.
25
+
26
+ Given a set of triplets (qᵢ, pᵢ, nᵢ), we want to find the query adapter A that increases the
27
+ score pᵢ'qᵢ of the positive chunk pᵢ and decreases the score nᵢ'qᵢ of the negative chunk nᵢ.
28
+
29
+ If the nearest neighbour search uses the dot product as its relevance score, we can find the
30
+ optimal query adapter by solving the following relaxed Procrustes optimisation problem with a
31
+ bound on the Frobenius norm of A:
32
+
33
+ A* = argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
34
+ Σᵢ (pᵢ - nᵢ)' A qᵢ
35
+ trace[ (P - N) A Q' ] where Q := [q₁'; ...; qₖ']
36
+ P := [p₁'; ...; pₖ']
37
+ N := [n₁'; ...; nₖ']
38
+ trace[ Q' (P - N) A ]
39
+ trace[ M A ] where M := Q' (P - N)
40
+ s.t. ||A||_F == 1
41
+ = M' / ||M||_F
42
+
43
+ If the nearest neighbour search uses the cosine similarity as its relevance score, we can find
44
+ the optimal query adapter by solving the following orthogonal Procrustes optimisation problem
45
+ with an orthogonality constraint on A:
46
+
47
+ A* = argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
48
+ Σᵢ (pᵢ - nᵢ)' A qᵢ
49
+ trace[ (P - N) A Q' ]
50
+ trace[ Q' (P - N) A ]
51
+ trace[ M A ]
52
+ trace[ U Σ V' A ] where U Σ V' := M is the SVD of M
53
+ trace[ Σ V' A U ]
54
+ s.t. A'A == 𝕀
55
+ = V U'
56
+
57
+ Additionally, we want to limit the effect of A* so that it adjusts q just enough to invert
58
+ incorrectly ordered (q, p, n) triplets, but not so much as to affect the correctly ordered ones.
59
+ To achieve this, we'll rewrite M as α(M / s) + (1 - α)𝕀, where s scales M to the same norm as 𝕀,
60
+ and choose the smallest α that ranks (q, p, n) correctly. If α = 0, the relevance score gap
61
+ between an incorrect (p, n) pair would be B := (p - n)' q < 0. If α = 1, the relevance score gap
62
+ would be A := (p - n)' (p - n) / ||p - n|| > 0. For a target relevance score gap of say
63
+ C := 5% * A, the optimal α is then given by αA + (1 - α)B = C => α = (B - C) / (B - A).
64
+ """
65
+ config = config or RAGLiteConfig()
66
+ config_no_query_adapter = RAGLiteConfig(
67
+ **{**config.__dict__, "vector_search_query_adapter": False}
68
+ )
69
+ engine = create_database_engine(config)
70
+ with Session(engine) as session:
71
+ # Get random evals from the database.
72
+ chunk_embedding = session.exec(select(ChunkEmbedding).limit(1)).first()
73
+ if chunk_embedding is None:
74
+ error_message = "First run `insert_document()` to insert documents."
75
+ raise ValueError(error_message)
76
+ evals = session.exec(
77
+ select(Eval).order_by(Eval.id).limit(max(8, max_triplets // max_triplets_per_eval))
78
+ ).all()
79
+ if len(evals) * max_triplets_per_eval < len(chunk_embedding.embedding):
80
+ error_message = "First run `insert_evals()` to generate sufficient evals."
81
+ raise ValueError(error_message)
82
+ # Loop over the evals to generate (q, p, n) triplets.
83
+ Q = np.zeros((0, len(chunk_embedding.embedding))) # noqa: N806
84
+ P = np.zeros_like(Q) # noqa: N806
85
+ N = np.zeros_like(Q) # noqa: N806
86
+ for eval_ in tqdm(
87
+ evals, desc="Extracting triplets from evals", unit="eval", dynamic_ncols=True
88
+ ):
89
+ # Embed the question.
90
+ question_embedding = embed_sentences([eval_.question], config=config)
91
+ # Retrieve chunks that would be used to answer the question.
92
+ chunk_ids, _ = vector_search(
93
+ question_embedding, num_results=optimize_top_k, config=config_no_query_adapter
94
+ )
95
+ retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()
96
+ # Extract (q, p, n) triplets by comparing the retrieved chunks with the eval.
97
+ num_triplets = 0
98
+ for i, retrieved_chunk in enumerate(retrieved_chunks):
99
+ # Select irrelevant chunks.
100
+ if retrieved_chunk.id not in eval_.chunk_ids:
101
+ # Look up all positive chunks (each represented by the mean of its multi-vector
102
+ # embedding) that are ranked lower than this negative one (represented by the
103
+ # embedding in the multi-vector embedding that best matches the query).
104
+ p_mean = [
105
+ np.mean(chunk.embedding_matrix, axis=0, keepdims=True)
106
+ for chunk in retrieved_chunks[i + 1 :]
107
+ if chunk is not None and chunk.id in eval_.chunk_ids
108
+ ]
109
+ n_top = retrieved_chunk.embedding_matrix[
110
+ np.argmax(retrieved_chunk.embedding_matrix @ question_embedding.T),
111
+ np.newaxis,
112
+ :,
113
+ ]
114
+ # Filter out any (p, n, q) triplets for which the mean positive embedding ranks
115
+ # higher than the top negative one.
116
+ p_mean = [p_e for p_e in p_mean if (n_top - p_e) @ question_embedding.T > 0]
117
+ if not p_mean:
118
+ continue
119
+ # Stack the (p, n, q) triplets.
120
+ p = np.vstack(p_mean)
121
+ n = np.repeat(n_top, p.shape[0], axis=0)
122
+ q = np.repeat(question_embedding, p.shape[0], axis=0)
123
+ num_triplets += p.shape[0]
124
+ # Append the (query, positive, negative) tuples to the Q, P, N matrices.
125
+ Q = np.vstack([Q, q]) # noqa: N806
126
+ P = np.vstack([P, p]) # noqa: N806
127
+ N = np.vstack([N, n]) # noqa: N806
128
+ # Check if we have sufficient triplets for this eval.
129
+ if num_triplets >= max_triplets_per_eval:
130
+ break
131
+ # Check if we have sufficient triplets to compute the query adapter.
132
+ if Q.shape[0] > max_triplets:
133
+ Q, P, N = Q[:max_triplets, :], P[:max_triplets, :], N[:max_triplets, :] # noqa: N806
134
+ break
135
+ # Normalise the rows of Q, P, N.
136
+ Q /= np.linalg.norm(Q, axis=1, keepdims=True) # noqa: N806
137
+ P /= np.linalg.norm(P, axis=1, keepdims=True) # noqa: N806
138
+ N /= np.linalg.norm(N, axis=1, keepdims=True) # noqa: N806
139
+ # Compute the optimal weighted query adapter A*.
140
+ # TODO: Matmul in float16 is extremely slow compared to single or double precision, why?
141
+ gap_before = np.sum((P - N) * Q, axis=1)
142
+ gap_after = 2 * (1 - np.sum(P * N, axis=1)) / np.linalg.norm(P - N, axis=1)
143
+ gap_target = 0.05 * gap_after
144
+ α = (gap_before - gap_target) / (gap_before - gap_after) # noqa: PLC2401
145
+ MT = (α[:, np.newaxis] * (P - N)).T @ Q # noqa: N806
146
+ s = np.linalg.norm(MT, ord="fro") / np.sqrt(MT.shape[0])
147
+ MT = np.mean(α) * (MT / s) + np.mean(1 - α) * np.eye(Q.shape[1]) # noqa: N806
148
+ if config.vector_search_index_metric == "dot":
149
+ # Use the relaxed Procrustes solution.
150
+ A_star = MT / np.linalg.norm(MT, ord="fro") # noqa: N806
151
+ elif config.vector_search_index_metric == "cosine":
152
+ # Use the orthogonal Procrustes solution.
153
+ U, _, VT = np.linalg.svd(MT, full_matrices=False) # noqa: N806
154
+ A_star = U @ VT # noqa: N806
155
+ else:
156
+ error_message = f"Unsupported ANN metric: {config.vector_search_index_metric}"
157
+ raise ValueError(error_message)
158
+ # Store the optimal query adapter in the database.
159
+ index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
160
+ index_metadata.metadata_ = {**index_metadata.metadata_, "query_adapter": A_star}
161
+ session.add(index_metadata)
162
+ session.commit()
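
An illustrative NumPy sketch of the two closed-form solutions derived in the docstring, using toy unit-norm triplet matrices (not part of the commit):

```python
import numpy as np

rng = np.random.default_rng(42)
Q, P, N = (rng.standard_normal((8, 4)) for _ in range(3))
Q, P, N = (X / np.linalg.norm(X, axis=1, keepdims=True) for X in (Q, P, N))
MT = (P - N).T @ Q  # M' where M := Q' (P - N).
A_dot = MT / np.linalg.norm(MT, ord="fro")  # Relaxed Procrustes solution (dot product metric).
U, _, VT = np.linalg.svd(MT, full_matrices=False)
A_cos = U @ VT  # Orthogonal Procrustes solution (cosine metric), equal to V U' in the docstring.
# The orthogonal solution never scores worse than the identity on the training triplets:
assert np.trace((P - N) @ A_cos @ Q.T) >= np.trace((P - N) @ Q.T) - 1e-9
```
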
src/raglite/_rag.py ADDED
@@ -0,0 +1,166 @@
1
+ """Retrieval-augmented generation."""
2
+
3
+ from collections.abc import AsyncIterator, Iterator
4
+
5
+ from litellm import acompletion, completion, get_model_info # type: ignore[attr-defined]
6
+
7
+ from raglite._config import RAGLiteConfig
8
+ from raglite._database import Chunk
9
+ from raglite._litellm import LlamaCppPythonLLM
10
+ from raglite._search import hybrid_search, rerank_chunks, retrieve_segments
11
+ from raglite._typing import SearchMethod
12
+
13
+ RAG_SYSTEM_PROMPT = """
14
+ You are a friendly and knowledgeable assistant that provides complete and insightful answers.
15
+ Answer the user's question using only the context below.
16
+ When responding, you MUST NOT reference the existence of the context, directly or indirectly.
17
+ Instead, you MUST treat the context as if its contents are entirely part of your working memory.
18
+ """.strip()
19
+
20
+
21
+ def _max_contexts(
22
+ prompt: str,
23
+ *,
24
+ max_contexts: int = 5,
25
+ context_neighbors: tuple[int, ...] | None = (-1, 1),
26
+ messages: list[dict[str, str]] | None = None,
27
+ config: RAGLiteConfig | None = None,
28
+ ) -> int:
29
+ """Determine the maximum number of contexts for RAG."""
30
+ # If the user has configured a llama-cpp-python model, we ensure that LiteLLM's model info is up
31
+ # to date by loading that LLM.
32
+ config = config or RAGLiteConfig()
33
+ if config.llm.startswith("llama-cpp-python"):
34
+ _ = LlamaCppPythonLLM.llm(config.llm)
35
+ # Get the model's maximum context size.
36
+ llm_provider = "llama-cpp-python" if config.llm.startswith("llama-cpp") else None
37
+ model_info = get_model_info(config.llm, custom_llm_provider=llm_provider)
38
+ max_tokens = model_info.get("max_tokens") or 2048
39
+ # Reduce the maximum number of contexts to take into account the LLM's context size.
40
+ max_context_tokens = (
41
+ max_tokens
42
+ - sum(len(message["content"]) // 3 for message in messages or []) # Previous messages.
43
+ - len(RAG_SYSTEM_PROMPT) // 3 # System prompt.
44
+ - len(prompt) // 3 # User prompt.
45
+ )
46
+ max_tokens_per_context = config.chunk_max_size // 3
47
+ max_tokens_per_context *= 1 + len(context_neighbors or [])
48
+ max_contexts = min(max_contexts, max_context_tokens // max_tokens_per_context)
49
+ if max_contexts <= 0:
50
+ error_message = "Not enough context tokens available for RAG."
51
+ raise ValueError(error_message)
52
+ return max_contexts
53
+
54
+
55
+ def _contexts( # noqa: PLR0913
56
+ prompt: str,
57
+ *,
58
+ max_contexts: int = 5,
59
+ context_neighbors: tuple[int, ...] | None = (-1, 1),
60
+ search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
61
+ messages: list[dict[str, str]] | None = None,
62
+ config: RAGLiteConfig | None = None,
63
+ ) -> list[str]:
64
+ """Retrieve contexts for RAG."""
65
+ # Determine the maximum number of contexts.
66
+ max_contexts = _max_contexts(
67
+ prompt,
68
+ max_contexts=max_contexts,
69
+ context_neighbors=context_neighbors,
70
+ messages=messages,
71
+ config=config,
72
+ )
73
+ # Retrieve the top chunks.
74
+ config = config or RAGLiteConfig()
75
+ chunks: list[str] | list[Chunk]
76
+ if callable(search):
77
+ # If the user has configured a reranker, we retrieve extra contexts to rerank.
78
+ extra_contexts = 3 * max_contexts if config.reranker else 0
79
+ # Retrieve relevant contexts.
80
+ chunk_ids, _ = search(prompt, num_results=max_contexts + extra_contexts, config=config)
81
+ # Rerank the relevant contexts.
82
+ chunks = rerank_chunks(query=prompt, chunk_ids=chunk_ids, config=config)
83
+ else:
84
+ # The user has passed a list of chunk_ids or chunks directly.
85
+ chunks = search
86
+ # Extend the top contexts with their neighbors and group chunks into contiguous segments.
87
+ segments = retrieve_segments(chunks[:max_contexts], neighbors=context_neighbors, config=config)
88
+ return segments
89
+
90
+
91
+ def rag( # noqa: PLR0913
92
+ prompt: str,
93
+ *,
94
+ max_contexts: int = 5,
95
+ context_neighbors: tuple[int, ...] | None = (-1, 1),
96
+ search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
97
+ messages: list[dict[str, str]] | None = None,
98
+ system_prompt: str = RAG_SYSTEM_PROMPT,
99
+ config: RAGLiteConfig | None = None,
100
+ ) -> Iterator[str]:
101
+ """Retrieval-augmented generation."""
102
+ # Get the contexts for RAG as contiguous segments of chunks.
103
+ config = config or RAGLiteConfig()
104
+ segments = _contexts(
105
+ prompt,
106
+ max_contexts=max_contexts,
107
+ context_neighbors=context_neighbors,
108
+ search=search,
109
+ config=config,
110
+ )
111
+ system_prompt = f"{system_prompt}\n\n" + "\n\n".join(
112
+ f'<context index="{i}">\n{segment.strip()}\n</context>'
113
+ for i, segment in enumerate(segments)
114
+ )
115
+ # Stream the LLM response.
116
+ stream = completion(
117
+ model=config.llm,
118
+ messages=[
119
+ *(messages or []),
120
+ {"role": "system", "content": system_prompt},
121
+ {"role": "user", "content": prompt},
122
+ ],
123
+ stream=True,
124
+ )
125
+ for output in stream:
126
+ token: str = output["choices"][0]["delta"].get("content") or ""
127
+ yield token
128
+
129
+
130
+ async def async_rag( # noqa: PLR0913
131
+ prompt: str,
132
+ *,
133
+ max_contexts: int = 5,
134
+ context_neighbors: tuple[int, ...] | None = (-1, 1),
135
+ search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
136
+ messages: list[dict[str, str]] | None = None,
137
+ system_prompt: str = RAG_SYSTEM_PROMPT,
138
+ config: RAGLiteConfig | None = None,
139
+ ) -> AsyncIterator[str]:
140
+ """Retrieval-augmented generation."""
141
+ # Get the contexts for RAG as contiguous segments of chunks.
142
+ config = config or RAGLiteConfig()
143
+ segments = _contexts(
144
+ prompt,
145
+ max_contexts=max_contexts,
146
+ context_neighbors=context_neighbors,
147
+ search=search,
148
+ config=config,
149
+ )
150
+ system_prompt = f"{system_prompt}\n\n" + "\n\n".join(
151
+ f'<context index="{i}">\n{segment.strip()}\n</context>'
152
+ for i, segment in enumerate(segments)
153
+ )
154
+ # Stream the LLM response.
155
+ async_stream = await acompletion(
156
+ model=config.llm,
157
+ messages=[
158
+ *(messages or []),
159
+ {"role": "system", "content": system_prompt},
160
+ {"role": "user", "content": prompt},
161
+ ],
162
+ stream=True,
163
+ )
164
+ async for output in async_stream:
165
+ token: str = output["choices"][0]["delta"].get("content") or ""
166
+ yield token
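
A minimal usage sketch of the streaming rag() entrypoint defined above. The database URL and question are illustrative assumptions (the database must already contain inserted documents); rag and RAGLiteConfig are re-exported from the top-level raglite package, as tests/test_rag.py below confirms.

from raglite import RAGLiteConfig, rag

# Assumed: a pre-populated SQLite database; swap in your own db_url and embedder.
config = RAGLiteConfig(db_url="sqlite:///raglite.sqlite")
for token in rag("What does it mean for two events to be simultaneous?", config=config):
    print(token, end="", flush=True)
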
src/raglite/_search.py ADDED
@@ -0,0 +1,270 @@
1
+ """Query documents."""
2
+
3
+ import re
4
+ import string
5
+ from collections import defaultdict
6
+ from collections.abc import Sequence
7
+ from itertools import groupby
8
+ from typing import cast
9
+
10
+ import numpy as np
11
+ from langdetect import detect
12
+ from sqlalchemy.engine import make_url
13
+ from sqlmodel import Session, and_, col, or_, select, text
14
+
15
+ from raglite._config import RAGLiteConfig
16
+ from raglite._database import Chunk, ChunkEmbedding, IndexMetadata, create_database_engine
17
+ from raglite._embed import embed_sentences
18
+ from raglite._typing import FloatMatrix
19
+
20
+
21
+ def vector_search(
22
+ query: str | FloatMatrix,
23
+ *,
24
+ num_results: int = 3,
25
+ config: RAGLiteConfig | None = None,
26
+ ) -> tuple[list[str], list[float]]:
27
+ """Search chunks using ANN vector search."""
28
+ # Read the config.
29
+ config = config or RAGLiteConfig()
30
+ db_backend = make_url(config.db_url).get_backend_name()
31
+ # Get the index metadata (including the query adapter, and in the case of SQLite, the index).
32
+ index_metadata = IndexMetadata.get("default", config=config)
33
+ # Embed the query.
34
+ query_embedding = (
35
+ embed_sentences([query], config=config)[0, :] if isinstance(query, str) else np.ravel(query)
36
+ )
37
+ # Apply the query adapter to the query embedding.
38
+ Q = index_metadata.get("query_adapter") # noqa: N806
39
+ if config.vector_search_query_adapter and Q is not None:
40
+ query_embedding = (Q @ query_embedding).astype(query_embedding.dtype)
41
+ # Search for the multi-vector chunk embeddings that are most similar to the query embedding.
42
+ if db_backend == "postgresql":
43
+ # Check that the selected metric is supported by pgvector.
44
+ metrics = {"cosine": "<=>", "dot": "<#>", "euclidean": "<->", "l1": "<+>", "l2": "<->"}
45
+ if config.vector_search_index_metric not in metrics:
46
+ error_message = f"Unsupported metric {config.vector_search_index_metric}."
47
+ raise ValueError(error_message)
48
+ # With pgvector, we can obtain the nearest neighbours and similarities with a single query.
49
+ engine = create_database_engine(config)
50
+ with Session(engine) as session:
51
+ distance_func = getattr(
52
+ ChunkEmbedding.embedding, f"{config.vector_search_index_metric}_distance"
53
+ )
54
+ distance = distance_func(query_embedding).label("distance")
55
+ results = session.exec(
56
+ select(ChunkEmbedding.chunk_id, distance).order_by(distance).limit(8 * num_results)
57
+ )
58
+ chunk_ids_, distance = zip(*results, strict=True)
59
+ chunk_ids, similarity = np.asarray(chunk_ids_), 1.0 - np.asarray(distance)
60
+ elif db_backend == "sqlite":
61
+ # Load the NNDescent index.
62
+ index = index_metadata.get("index")
63
+ ids = np.asarray(index_metadata.get("chunk_ids"))
64
+ cumsum = np.cumsum(np.asarray(index_metadata.get("chunk_sizes")))
65
+ # Find the neighbouring multi-vector indices.
66
+ from pynndescent import NNDescent
67
+
68
+ multi_vector_indices, distance = cast(NNDescent, index).query(
69
+ query_embedding[np.newaxis, :], k=8 * num_results
70
+ )
71
+ similarity = 1 - distance[0, :]
72
+ # Transform the multi-vector indices into chunk indices, and then to chunk ids.
73
+ chunk_indices = np.searchsorted(cumsum, multi_vector_indices[0, :], side="right") + 1
74
+ chunk_ids = np.asarray([ids[chunk_index - 1] for chunk_index in chunk_indices])
75
+ # Score each unique chunk id as the mean similarity of its multi-vector hits. Chunk ids with
76
+ # fewer hits are padded with the minimum similarity of the result set.
77
+ unique_chunk_ids, counts = np.unique(chunk_ids, return_counts=True)
78
+ score = np.full(
79
+ (len(unique_chunk_ids), np.max(counts)), np.min(similarity), dtype=similarity.dtype
80
+ )
81
+ for i, (unique_chunk_id, count) in enumerate(zip(unique_chunk_ids, counts, strict=True)):
82
+ score[i, :count] = similarity[chunk_ids == unique_chunk_id]
83
+ pooled_similarity = np.mean(score, axis=1)
84
+ # Sort the chunk ids by their adjusted similarity.
85
+ sorted_indices = np.argsort(pooled_similarity)[::-1]
86
+ unique_chunk_ids = unique_chunk_ids[sorted_indices][:num_results]
87
+ pooled_similarity = pooled_similarity[sorted_indices][:num_results]
88
+ return unique_chunk_ids.tolist(), pooled_similarity.tolist()
89
+
90
+
91
+ def keyword_search(
92
+ query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None
93
+ ) -> tuple[list[str], list[float]]:
94
+ """Search chunks using BM25 keyword search."""
95
+ # Read the config.
96
+ config = config or RAGLiteConfig()
97
+ db_backend = make_url(config.db_url).get_backend_name()
98
+ # Connect to the database.
99
+ engine = create_database_engine(config)
100
+ with Session(engine) as session:
101
+ if db_backend == "postgresql":
102
+ # Convert the query to a tsquery [1].
103
+ # [1] https://www.postgresql.org/docs/current/textsearch-controls.html
104
+ query_escaped = re.sub(r"[&|!():<>\"]", " ", query)
105
+ tsv_query = " | ".join(query_escaped.split())
106
+ # Perform keyword search with tsvector.
107
+ statement = text("""
108
+ SELECT id as chunk_id, ts_rank(to_tsvector('simple', body), to_tsquery('simple', :query)) AS score
109
+ FROM chunk
110
+ WHERE to_tsvector('simple', body) @@ to_tsquery('simple', :query)
111
+ ORDER BY score DESC
112
+ LIMIT :limit;
113
+ """)
114
+ results = session.execute(statement, params={"query": tsv_query, "limit": num_results})
115
+ elif db_backend == "sqlite":
116
+ # Convert the query to an FTS5 query [1].
117
+ # [1] https://www.sqlite.org/fts5.html#full_text_query_syntax
118
+ query_escaped = re.sub(f"[{re.escape(string.punctuation)}]", "", query)
119
+ fts5_query = " OR ".join(query_escaped.split())
120
+ # Perform keyword search with FTS5. In FTS5, BM25 scores are negative [1], so we
121
+ # negate them to make them positive.
122
+ # [1] https://www.sqlite.org/fts5.html#the_bm25_function
123
+ statement = text("""
124
+ SELECT chunk.id as chunk_id, -bm25(keyword_search_chunk_index) as score
125
+ FROM chunk JOIN keyword_search_chunk_index ON chunk.rowid = keyword_search_chunk_index.rowid
126
+ WHERE keyword_search_chunk_index MATCH :match
127
+ ORDER BY score DESC
128
+ LIMIT :limit;
129
+ """)
130
+ results = session.execute(statement, params={"match": fts5_query, "limit": num_results})
131
+ # Unpack the results.
132
+ chunk_ids, keyword_score = zip(*results, strict=True)
133
+ chunk_ids, keyword_score = list(chunk_ids), list(keyword_score) # type: ignore[assignment]
134
+ return chunk_ids, keyword_score # type: ignore[return-value]
135
+
136
+
137
+ def reciprocal_rank_fusion(
138
+ rankings: list[list[str]], *, k: int = 60
139
+ ) -> tuple[list[str], list[float]]:
140
+ """Reciprocal Rank Fusion."""
141
+ # Compute the RRF score.
142
+ chunk_ids = {chunk_id for ranking in rankings for chunk_id in ranking}
143
+ chunk_id_score: defaultdict[str, float] = defaultdict(float)
144
+ for ranking in rankings:
145
+ chunk_id_index = {chunk_id: i for i, chunk_id in enumerate(ranking)}
146
+ for chunk_id in chunk_ids:
147
+ chunk_id_score[chunk_id] += 1 / (k + chunk_id_index.get(chunk_id, len(chunk_id_index)))
148
+ # Rank RRF results according to descending RRF score.
149
+ rrf_chunk_ids, rrf_score = zip(
150
+ *sorted(chunk_id_score.items(), key=lambda x: x[1], reverse=True), strict=True
151
+ )
152
+ return list(rrf_chunk_ids), list(rrf_score)
153
+
154
+
155
+ def hybrid_search(
156
+ query: str, *, num_results: int = 3, num_rerank: int = 100, config: RAGLiteConfig | None = None
157
+ ) -> tuple[list[str], list[float]]:
158
+ """Search chunks by combining ANN vector search with BM25 keyword search."""
159
+ # Run both searches.
160
+ vs_chunk_ids, _ = vector_search(query, num_results=num_rerank, config=config)
161
+ ks_chunk_ids, _ = keyword_search(query, num_results=num_rerank, config=config)
162
+ # Combine the results with Reciprocal Rank Fusion (RRF).
163
+ chunk_ids, hybrid_score = reciprocal_rank_fusion([vs_chunk_ids, ks_chunk_ids])
164
+ chunk_ids, hybrid_score = chunk_ids[:num_results], hybrid_score[:num_results]
165
+ return chunk_ids, hybrid_score
166
+
167
+
168
+ def retrieve_chunks(
169
+ chunk_ids: list[str],
170
+ *,
171
+ config: RAGLiteConfig | None = None,
172
+ ) -> list[Chunk]:
173
+ """Retrieve chunks by their ids."""
174
+ config = config or RAGLiteConfig()
175
+ engine = create_database_engine(config)
176
+ with Session(engine) as session:
177
+ chunks = list(session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all())
178
+ chunks = sorted(chunks, key=lambda chunk: chunk_ids.index(chunk.id))
179
+ return chunks
180
+
181
+
182
+ def retrieve_segments(
183
+ chunk_ids: list[str] | list[Chunk],
184
+ *,
185
+ neighbors: tuple[int, ...] | None = (-1, 1),
186
+ config: RAGLiteConfig | None = None,
187
+ ) -> list[str]:
188
+ """Group chunks into contiguous segments and retrieve them."""
189
+ # Retrieve the chunks.
190
+ config = config or RAGLiteConfig()
191
+ chunks: list[Chunk] = (
192
+ retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment]
193
+ if all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
194
+ else chunk_ids
195
+ )
196
+ # Extend the chunks with their neighbouring chunks.
197
+ if neighbors:
198
+ engine = create_database_engine(config)
199
+ with Session(engine) as session:
200
+ neighbor_conditions = [
201
+ and_(Chunk.document_id == chunk.document_id, Chunk.index == chunk.index + offset)
202
+ for chunk in chunks
203
+ for offset in neighbors
204
+ ]
205
+ chunks += list(session.exec(select(Chunk).where(or_(*neighbor_conditions))).all())
206
+ # Keep only the unique chunks.
207
+ chunks = list(set(chunks))
208
+ # Sort the chunks by document_id and index (needed for groupby).
209
+ chunks = sorted(chunks, key=lambda chunk: (chunk.document_id, chunk.index))
210
+ # Group the chunks into contiguous segments.
211
+ segments: list[list[Chunk]] = []
212
+ for _, group in groupby(chunks, key=lambda chunk: chunk.document_id):
213
+ segment: list[Chunk] = []
214
+ for chunk in group:
215
+ if not segment or chunk.index == segment[-1].index + 1:
216
+ segment.append(chunk)
217
+ else:
218
+ segments.append(segment)
219
+ segment = [chunk]
220
+ segments.append(segment)
221
+ # Rank segments according to the aggregate relevance of their chunks.
222
+ chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)}
223
+ segments.sort(
224
+ key=lambda segment: sum(chunk_id_to_score.get(chunk.id, 0.0) for chunk in segment),
225
+ reverse=True,
226
+ )
227
+ # Convert the segments into strings.
228
+ segments = [
229
+ segment[0].headings.strip() + "\n\n" + "".join(chunk.body for chunk in segment).strip() # type: ignore[misc]
230
+ for segment in segments
231
+ ]
232
+ return segments # type: ignore[return-value]
233
+
234
+
235
+ def rerank_chunks(
236
+ query: str,
237
+ chunk_ids: list[str] | list[Chunk],
238
+ *,
239
+ config: RAGLiteConfig | None = None,
240
+ ) -> list[Chunk]:
241
+ """Rerank chunks according to their relevance to a given query."""
242
+ # Retrieve the chunks.
243
+ config = config or RAGLiteConfig()
244
+ chunks: list[Chunk] = (
245
+ retrieve_chunks(chunk_ids, config=config) # type: ignore[arg-type,assignment]
246
+ if all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
247
+ else chunk_ids
248
+ )
249
+ # Early exit if no reranker is configured.
250
+ if not config.reranker:
251
+ return chunks
252
+ # Select the reranker.
253
+ if isinstance(config.reranker, Sequence):
254
+ # Detect the languages of the chunks and queries.
255
+ langs = {detect(str(chunk)) for chunk in chunks}
256
+ langs.add(detect(query))
257
+ # If all chunks and the query are in the same language, use a language-specific reranker.
258
+ rerankers = dict(config.reranker)
259
+ if len(langs) == 1 and (lang := next(iter(langs))) in rerankers:
260
+ reranker = rerankers[lang]
261
+ else:
262
+ reranker = rerankers.get("other")
263
+ else:
264
+ # A specific reranker was configured.
265
+ reranker = config.reranker
266
+ # Rerank the chunks.
267
+ if reranker:
268
+ results = reranker.rank(query=query, docs=[str(chunk) for chunk in chunks])
269
+ chunks = [chunks[result.doc_id] for result in results.results]
270
+ return chunks
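
An end-to-end retrieval sketch tying together the functions above (hybrid search, optional reranking, and segment retrieval). The query, database URL, and result counts are assumptions; all four names are re-exported from raglite, as the tests below show.

from raglite import RAGLiteConfig, hybrid_search, rerank_chunks, retrieve_segments

config = RAGLiteConfig(db_url="sqlite:///raglite.sqlite")  # assumed, pre-populated database
query = "What does it mean for two events to be simultaneous?"
chunk_ids, scores = hybrid_search(query, num_results=10, config=config)  # RRF of vector + keyword search
chunks = rerank_chunks(query, chunk_ids, config=config)  # returns the chunks unchanged if no reranker is configured
segments = retrieve_segments(chunks[:3], neighbors=(-1, 1), config=config)  # extend with neighbouring chunks
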
src/raglite/_split_chunks.py ADDED
@@ -0,0 +1,102 @@
1
+ """Split a document into semantic chunks."""
2
+
3
+ import re
4
+
5
+ import numpy as np
6
+ from scipy.optimize import linprog
7
+ from scipy.sparse import coo_matrix
8
+
9
+ from raglite._typing import FloatMatrix
10
+
11
+
12
+ def split_chunks( # noqa: C901, PLR0915
13
+ sentences: list[str],
14
+ sentence_embeddings: FloatMatrix,
15
+ sentence_window_size: int = 3,
16
+ max_size: int = 1440,
17
+ ) -> tuple[list[str], list[FloatMatrix]]:
18
+ """Split sentences into optimal semantic chunks with corresponding sentence embeddings."""
19
+ # Validate the input.
20
+ sentence_length = np.asarray([len(sentence) for sentence in sentences])
21
+ if not np.all(sentence_length <= max_size):
22
+ error_message = "Sentence with length larger than chunk max_size detected."
23
+ raise ValueError(error_message)
24
+ if not np.all(np.linalg.norm(sentence_embeddings, axis=1) > 0.0):
25
+ error_message = "Sentence embeddings with zero norm detected."
26
+ raise ValueError(error_message)
27
+ # Exit early if there is only one chunk to return.
28
+ if len(sentences) <= 1 or sum(sentence_length) <= max_size:
29
+ return ["".join(sentences)] if sentences else sentences, [sentence_embeddings]
30
+ # Normalise the sentence embeddings to unit norm.
31
+ X = sentence_embeddings.astype(np.float32) # noqa: N806
32
+ X = X / np.linalg.norm(X, axis=1, keepdims=True) # noqa: N806
33
+ # Select nonoutlying sentences and remove the discourse vector.
34
+ q15, q85 = np.quantile(sentence_length, [0.15, 0.85])
35
+ nonoutlying_sentences = (q15 <= sentence_length) & (sentence_length <= q85)
36
+ discourse = np.mean(X[nonoutlying_sentences, :], axis=0)
37
+ discourse = discourse / np.linalg.norm(discourse)
38
+ if not np.any(np.linalg.norm(X - discourse[np.newaxis, :], axis=1) <= np.finfo(X.dtype).eps):
39
+ X = X - np.outer(X @ discourse, discourse) # noqa: N806
40
+ X = X / np.linalg.norm(X, axis=1, keepdims=True) # noqa: N806
41
+ # For each partition point in the list of sentences, compute the similarity of the windows
42
+ # before and after the partition point. Sentence embeddings are assumed to be of the sentence
43
+ # itself and at most the (sentence_window_size - 1) sentences that precede it.
44
+ sentence_window_size = min(len(sentences) - 1, sentence_window_size)
45
+ windows_before = X[:-sentence_window_size]
46
+ windows_after = X[sentence_window_size:]
47
+ partition_similarity = np.ones(len(sentences) - 1, dtype=X.dtype)
48
+ partition_similarity[: len(windows_before)] = np.sum(windows_before * windows_after, axis=1)
49
+ # Make partition similarity nonnegative before modification and optimisation.
50
+ partition_similarity = np.maximum(
51
+ (partition_similarity + 1) / 2, np.sqrt(np.finfo(X.dtype).eps)
52
+ )
53
+ # Modify the partition similarity to encourage splitting on Markdown headings.
54
+ prev_sentence_is_heading = True
55
+ for i, sentence in enumerate(sentences[:-1]):
56
+ is_heading = bool(re.match(r"^#+\s", sentence.replace("\n", "").strip()))
57
+ if is_heading:
58
+ # Encourage splitting before a heading.
59
+ if not prev_sentence_is_heading:
60
+ partition_similarity[i - 1] = partition_similarity[i - 1] / 4
61
+ # Don't split immediately after a heading.
62
+ partition_similarity[i] = 1.0
63
+ prev_sentence_is_heading = is_heading
64
+ # Solve an optimisation problem to find the best partition points.
65
+ sentence_length_cumsum = np.cumsum(sentence_length)
66
+ row_indices = []
67
+ col_indices = []
68
+ data = []
69
+ for i in range(len(sentences) - 1):
70
+ r = sentence_length_cumsum[i - 1] if i > 0 else 0
71
+ idx = np.searchsorted(sentence_length_cumsum - r, max_size)
72
+ assert idx > i
73
+ if idx == len(sentence_length_cumsum):
74
+ break
75
+ cols = list(range(i, idx))
76
+ col_indices.extend(cols)
77
+ row_indices.extend([i] * len(cols))
78
+ data.extend([1] * len(cols))
79
+ A = coo_matrix( # noqa: N806
80
+ (data, (row_indices, col_indices)),
81
+ shape=(max(row_indices) + 1, len(sentences) - 1),
82
+ dtype=np.float32,
83
+ )
84
+ b_ub = np.ones(A.shape[0], dtype=np.float32)
85
+ res = linprog(
86
+ partition_similarity,
87
+ A_ub=-A,
88
+ b_ub=-b_ub,
89
+ bounds=(0, 1),
90
+ integrality=[1] * A.shape[1],
91
+ )
92
+ if not res.success:
93
+ error_message = "Optimization of chunk partitions failed."
94
+ raise ValueError(error_message)
95
+ # Split the sentences and their window embeddings into optimal chunks.
96
+ partition_indices = (np.where(res.x)[0] + 1).tolist()
97
+ chunks = [
98
+ "".join(sentences[i:j])
99
+ for i, j in zip([0, *partition_indices], [*partition_indices, len(sentences)], strict=True)
100
+ ]
101
+ chunk_embeddings = np.split(sentence_embeddings, partition_indices)
102
+ return chunks, chunk_embeddings
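
A small sketch of how split_chunks consumes sentences and their embeddings. The all-ones placeholder embeddings mirror tests/test_split_chunks.py; a real caller would pass embeddings produced by embed_sentences.

import numpy as np

from raglite._split_chunks import split_chunks

sentences = ["Sentence one. ", "Sentence two. ", "Sentence three. "]
sentence_embeddings = np.ones((len(sentences), 768), dtype=np.float16)  # placeholder embeddings
chunks, chunk_embeddings = split_chunks(
    sentences, sentence_embeddings, sentence_window_size=3, max_size=1440
)
# With such short sentences this takes the early exit and returns a single chunk.
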
src/raglite/_split_sentences.py ADDED
@@ -0,0 +1,76 @@
1
+ """Sentence splitter."""
2
+
3
+ import re
4
+
5
+ import spacy
6
+ from markdown_it import MarkdownIt
7
+ from spacy.language import Language
8
+
9
+
10
+ @Language.component("_mark_additional_sentence_boundaries")
11
+ def _mark_additional_sentence_boundaries(doc: spacy.tokens.Doc) -> spacy.tokens.Doc:
12
+ """Mark additional sentence boundaries in Markdown documents."""
13
+
14
+ def get_markdown_heading_indexes(doc: str) -> list[tuple[int, int]]:
15
+ """Get the indexes of the headings in a Markdown document."""
16
+ md = MarkdownIt()
17
+ tokens = md.parse(doc)
18
+ headings = []
19
+ lines = doc.splitlines(keepends=True)
20
+ char_idx = [0]
21
+ for line in lines:
22
+ char_idx.append(char_idx[-1] + len(line))
23
+ for token in tokens:
24
+ if token.type == "heading_open":
25
+ start_line, end_line = token.map # type: ignore[misc]
26
+ heading_start = char_idx[start_line]
27
+ heading_end = char_idx[end_line]
28
+ headings.append((heading_start, heading_end))
29
+ return headings
30
+
31
+ headings = get_markdown_heading_indexes(doc.text)
32
+ for heading_start, heading_end in headings:
33
+ # Mark the start of a heading as a new sentence.
34
+ for token in doc:
35
+ if heading_start <= token.idx:
36
+ token.is_sent_start = True
37
+ break
38
+ # Mark the end of a heading as a new sentence.
39
+ for token in doc:
40
+ if heading_end <= token.idx:
41
+ token.is_sent_start = True
42
+ break
43
+ return doc
44
+
45
+
46
+ def split_sentences(doc: str, max_len: int | None = None) -> list[str]:
47
+ """Split a document into sentences."""
48
+ # Split sentences with spaCy.
49
+ try:
50
+ nlp = spacy.load("xx_sent_ud_sm")
51
+ except OSError as error:
52
+ error_message = "Please install `xx_sent_ud_sm` with `pip install https://github.com/explosion/spacy-models/releases/download/xx_sent_ud_sm-3.7.0/xx_sent_ud_sm-3.7.0-py3-none-any.whl`."
53
+ raise ImportError(error_message) from error
54
+ nlp.add_pipe("_mark_additional_sentence_boundaries", before="senter")
55
+ sentences = [sent.text_with_ws for sent in nlp(doc).sents if sent.text.strip()]
56
+ # Apply additional splits on paragraphs and sentences because spaCy's splitting is not perfect.
57
+ if max_len is not None:
58
+ for pattern in (r"(?<=\n\n)", r"(?<=\.\s)"):
59
+ sentences = [
60
+ part
61
+ for sent in sentences
62
+ for part in ([sent] if len(sent) <= max_len else re.split(pattern, sent))
63
+ ]
64
+ # Recursively split long sentences in the middle if they are still too long.
65
+ if max_len is not None:
66
+ while any(len(sentence) > max_len for sentence in sentences):
67
+ sentences = [
68
+ part
69
+ for sent in sentences
70
+ for part in (
71
+ [sent]
72
+ if len(sent) <= max_len
73
+ else [sent[: len(sent) // 2], sent[len(sent) // 2 :]]
74
+ )
75
+ ]
76
+ return sentences
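
A sketch of the sentence-splitting step in context, following tests/test_embed.py. The document path is an assumption, and document_to_markdown / embed_sentences come from raglite._markdown and raglite._embed, which are not part of this section.

from raglite import RAGLiteConfig
from raglite._embed import embed_sentences
from raglite._markdown import document_to_markdown
from raglite._split_sentences import split_sentences

config = RAGLiteConfig()  # default embedder; adjust db_url and embedder as needed
doc = document_to_markdown("paper.pdf")  # assumed input document
sentences = split_sentences(doc, max_len=config.chunk_max_size)
sentence_embeddings = embed_sentences(sentences, config=config)
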
src/raglite/_typing.py ADDED
@@ -0,0 +1,145 @@
1
+ """RAGLite typing."""
2
+
3
+ import io
4
+ import pickle
5
+ from collections.abc import Callable
6
+ from typing import Any, Protocol
7
+
8
+ import numpy as np
9
+ from sqlalchemy.engine import Dialect
10
+ from sqlalchemy.sql.operators import Operators
11
+ from sqlalchemy.types import Float, LargeBinary, TypeDecorator, TypeEngine, UserDefinedType
12
+
13
+ from raglite._config import RAGLiteConfig
14
+
15
+ FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]]
16
+ FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]]
17
+ IntVector = np.ndarray[tuple[int], np.dtype[np.intp]]
18
+
19
+
20
+ class SearchMethod(Protocol):
21
+ def __call__(
22
+ self, query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None
23
+ ) -> tuple[list[str], list[float]]: ...
24
+
25
+
26
+ class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]):
27
+ """A NumPy array column type for SQLAlchemy."""
28
+
29
+ impl = LargeBinary
30
+
31
+ def process_bind_param(
32
+ self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect
33
+ ) -> bytes | None:
34
+ """Convert a NumPy array to bytes."""
35
+ if value is None:
36
+ return None
37
+ buffer = io.BytesIO()
38
+ np.save(buffer, value, allow_pickle=False, fix_imports=False)
39
+ return buffer.getvalue()
40
+
41
+ def process_result_value(
42
+ self, value: bytes | None, dialect: Dialect
43
+ ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None:
44
+ """Convert bytes to a NumPy array."""
45
+ if value is None:
46
+ return None
47
+ return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False) # type: ignore[no-any-return]
48
+
49
+
50
+ class PickledObject(TypeDecorator[object]):
51
+ """A pickled object column type for SQLAlchemy."""
52
+
53
+ impl = LargeBinary
54
+
55
+ def process_bind_param(self, value: object | None, dialect: Dialect) -> bytes | None:
56
+ """Convert a Python object to bytes."""
57
+ if value is None:
58
+ return None
59
+ return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL, fix_imports=False)
60
+
61
+ def process_result_value(self, value: bytes | None, dialect: Dialect) -> object | None:
62
+ """Convert bytes to a Python object."""
63
+ if value is None:
64
+ return None
65
+ return pickle.loads(value, fix_imports=False) # type: ignore[no-any-return] # noqa: S301
66
+
67
+
68
+ class HalfVecComparatorMixin(UserDefinedType.Comparator[FloatVector]):
69
+ """A mixin that provides comparison operators for halfvecs."""
70
+
71
+ def cosine_distance(self, other: FloatVector) -> Operators:
72
+ """Compute the cosine distance."""
73
+ return self.op("<=>", return_type=Float)(other)
74
+
75
+ def dot_distance(self, other: FloatVector) -> Operators:
76
+ """Compute the dot product distance."""
77
+ return self.op("<#>", return_type=Float)(other)
78
+
79
+ def euclidean_distance(self, other: FloatVector) -> Operators:
80
+ """Compute the Euclidean distance."""
81
+ return self.op("<->", return_type=Float)(other)
82
+
83
+ def l1_distance(self, other: FloatVector) -> Operators:
84
+ """Compute the L1 distance."""
85
+ return self.op("<+>", return_type=Float)(other)
86
+
87
+ def l2_distance(self, other: FloatVector) -> Operators:
88
+ """Compute the L2 distance."""
89
+ return self.op("<->", return_type=Float)(other)
90
+
91
+
92
+ class HalfVec(UserDefinedType[FloatVector]):
93
+ """A PostgreSQL half-precision vector column type for SQLAlchemy."""
94
+
95
+ cache_ok = True # HalfVec is immutable.
96
+
97
+ def __init__(self, dim: int | None = None) -> None:
98
+ super().__init__()
99
+ self.dim = dim
100
+
101
+ def get_col_spec(self, **kwargs: Any) -> str:
102
+ return f"halfvec({self.dim})"
103
+
104
+ def bind_processor(self, dialect: Dialect) -> Callable[[FloatVector | None], str | None]:
105
+ """Process NumPy ndarray to PostgreSQL halfvec format for bound parameters."""
106
+
107
+ def process(value: FloatVector | None) -> str | None:
108
+ return f"[{','.join(str(x) for x in np.ravel(value))}]" if value is not None else None
109
+
110
+ return process
111
+
112
+ def result_processor(
113
+ self, dialect: Dialect, coltype: Any
114
+ ) -> Callable[[str | None], FloatVector | None]:
115
+ """Process PostgreSQL halfvec format to NumPy ndarray."""
116
+
117
+ def process(value: str | None) -> FloatVector | None:
118
+ if value is None:
119
+ return None
120
+ return np.fromstring(value.strip("[]"), sep=",", dtype=np.float16)
121
+
122
+ return process
123
+
124
+ class comparator_factory(HalfVecComparatorMixin): # noqa: N801
125
+ ...
126
+
127
+
128
+ class Embedding(TypeDecorator[FloatVector]):
129
+ """An embedding column type for SQLAlchemy."""
130
+
131
+ cache_ok = True # Embedding is immutable.
132
+
133
+ impl = NumpyArray
134
+
135
+ def __init__(self, dim: int = -1):
136
+ super().__init__()
137
+ self.dim = dim
138
+
139
+ def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[FloatVector]:
140
+ if dialect.name == "postgresql":
141
+ return dialect.type_descriptor(HalfVec(self.dim))
142
+ return dialect.type_descriptor(NumpyArray())
143
+
144
+ class comparator_factory(HalfVecComparatorMixin): # noqa: N801
145
+ ...
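
A sketch of a custom callable that satisfies the SearchMethod protocol above and can therefore be passed as the search argument of rag. Delegating to vector_search is an illustrative choice, not something this commit prescribes.

from raglite import RAGLiteConfig, vector_search
from raglite._typing import SearchMethod


def vector_only_search(
    query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None
) -> tuple[list[str], list[float]]:
    # Any function with this signature type-checks against the SearchMethod protocol.
    return vector_search(query, num_results=num_results, config=config)


search_method: SearchMethod = vector_only_search
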
src/raglite/py.typed ADDED
File without changes
tests/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """RAGLite test suite."""
tests/conftest.py ADDED
@@ -0,0 +1,102 @@
1
+ """Fixtures for the tests."""
2
+
3
+ import os
4
+ import socket
5
+ import tempfile
6
+ from collections.abc import Generator
7
+ from pathlib import Path
8
+
9
+ import pytest
10
+ from sqlalchemy import create_engine, text
11
+
12
+ from raglite import RAGLiteConfig, insert_document
13
+
14
+ POSTGRES_URL = "postgresql+pg8000://raglite_user:raglite_password@postgres:5432/postgres"
15
+
16
+
17
+ def is_postgres_running() -> bool:
18
+ """Check if PostgreSQL is running."""
19
+ try:
20
+ with socket.create_connection(("postgres", 5432), timeout=1):
21
+ return True
22
+ except OSError:
23
+ return False
24
+
25
+
26
+ def is_openai_available() -> bool:
27
+ """Check if an OpenAI API key is set."""
28
+ return bool(os.environ.get("OPENAI_API_KEY"))
29
+
30
+
31
+ def pytest_sessionstart(session: pytest.Session) -> None:
32
+ """Reset the PostgreSQL and SQLite databases."""
33
+ if is_postgres_running():
34
+ engine = create_engine(POSTGRES_URL, isolation_level="AUTOCOMMIT")
35
+ with engine.connect() as conn:
36
+ for variant in ["local", "remote"]:
37
+ conn.execute(text(f"DROP DATABASE IF EXISTS raglite_test_{variant}"))
38
+ conn.execute(text(f"CREATE DATABASE raglite_test_{variant}"))
39
+
40
+
41
+ @pytest.fixture(scope="session")
42
+ def sqlite_url() -> Generator[str, None, None]:
43
+ """Create a temporary SQLite database file and return the database URL."""
44
+ with tempfile.TemporaryDirectory() as temp_dir:
45
+ db_file = Path(temp_dir) / "raglite_test.sqlite"
46
+ yield f"sqlite:///{db_file}"
47
+
48
+
49
+ @pytest.fixture(
50
+ scope="session",
51
+ params=[
52
+ pytest.param("sqlite", id="sqlite"),
53
+ pytest.param(
54
+ POSTGRES_URL,
55
+ id="postgres",
56
+ marks=pytest.mark.skipif(not is_postgres_running(), reason="PostgreSQL is not running"),
57
+ ),
58
+ ],
59
+ )
60
+ def database(request: pytest.FixtureRequest) -> str:
61
+ """Get a database URL to test RAGLite with."""
62
+ db_url: str = (
63
+ request.getfixturevalue("sqlite_url") if request.param == "sqlite" else request.param
64
+ )
65
+ return db_url
66
+
67
+
68
+ @pytest.fixture(
69
+ scope="session",
70
+ params=[
71
+ pytest.param(
72
+ "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf",
73
+ id="bge_m3",
74
+ ),
75
+ pytest.param(
76
+ "text-embedding-3-small",
77
+ id="openai_text_embedding_3_small",
78
+ marks=pytest.mark.skipif(not is_openai_available(), reason="OpenAI API key is not set"),
79
+ ),
80
+ ],
81
+ )
82
+ def embedder(request: pytest.FixtureRequest) -> str:
83
+ """Get an embedder model URL to test RAGLite with."""
84
+ embedder: str = request.param
85
+ return embedder
86
+
87
+
88
+ @pytest.fixture(scope="session")
89
+ def raglite_test_config(database: str, embedder: str) -> RAGLiteConfig:
90
+ """Create a lightweight in-memory config for testing SQLite and PostgreSQL."""
91
+ # Select the database based on the embedder.
92
+ variant = "local" if embedder.startswith("llama-cpp-python") else "remote"
93
+ if "postgres" in database:
94
+ database = database.replace("/postgres", f"/raglite_test_{variant}")
95
+ elif "sqlite" in database:
96
+ database = database.replace(".sqlite", f"_{variant}.sqlite")
97
+ # Create a RAGLite config for the given database and embedder.
98
+ db_config = RAGLiteConfig(db_url=database, embedder=embedder)
99
+ # Insert a document and update the index.
100
+ doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
101
+ insert_document(doc_path, config=db_config)
102
+ return db_config
tests/specrel.pdf ADDED
Binary file (178 kB).
 
tests/test_embed.py ADDED
@@ -0,0 +1,26 @@
1
+ """Test RAGLite's embedding functionality."""
2
+
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+
7
+ from raglite import RAGLiteConfig
8
+ from raglite._embed import embed_sentences
9
+ from raglite._markdown import document_to_markdown
10
+ from raglite._split_sentences import split_sentences
11
+
12
+
13
+ def test_embed(embedder: str) -> None:
14
+ """Test embedding a document."""
15
+ raglite_test_config = RAGLiteConfig(embedder=embedder, embedder_normalize=True)
16
+ doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
17
+ doc = document_to_markdown(doc_path)
18
+ sentences = split_sentences(doc, max_len=raglite_test_config.chunk_max_size)
19
+ sentence_embeddings = embed_sentences(sentences, config=raglite_test_config)
20
+ assert isinstance(sentences, list)
21
+ assert isinstance(sentence_embeddings, np.ndarray)
22
+ assert len(sentences) == len(sentence_embeddings)
23
+ assert sentence_embeddings.shape[1] >= 128 # noqa: PLR2004
24
+ assert sentence_embeddings.dtype == np.float16
25
+ assert np.all(np.isfinite(sentence_embeddings))
26
+ assert np.allclose(np.linalg.norm(sentence_embeddings, axis=1), 1.0, rtol=1e-3)
tests/test_import.py ADDED
@@ -0,0 +1,8 @@
1
+ """Test RAGLite."""
2
+
3
+ import raglite
4
+
5
+
6
+ def test_import() -> None:
7
+ """Test that the package can be imported."""
8
+ assert isinstance(raglite.__name__, str)
tests/test_markdown.py ADDED
@@ -0,0 +1,22 @@
1
+ """Test Markdown conversion."""
2
+
3
+ from pathlib import Path
4
+
5
+ from raglite._markdown import document_to_markdown
6
+
7
+
8
+ def test_pdf_with_missing_font_sizes() -> None:
9
+ """Test conversion of a PDF with missing font sizes."""
10
+ # Convert a PDF whose parsed font sizes are all equal to 1.
11
+ doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
12
+ doc = document_to_markdown(doc_path)
13
+ # Verify that we can reconstruct the font sizes and heading levels regardless of the missing
14
+ # font size data.
15
+ expected_heading = """
16
+ # ON THE ELECTRODYNAMICS OF MOVING BODIES
17
+
18
+ ## By A. EINSTEIN June 30, 1905
19
+
20
+ It is known that Maxwell
21
+ """.strip()
22
+ assert doc.startswith(expected_heading)
tests/test_rag.py ADDED
@@ -0,0 +1,40 @@
1
+ """Test RAGLite's RAG functionality."""
2
+
3
+ import os
4
+ from typing import TYPE_CHECKING
5
+
6
+ import pytest
7
+ from llama_cpp import llama_supports_gpu_offload
8
+
9
+ from raglite import RAGLiteConfig, hybrid_search, rag, retrieve_chunks
10
+
11
+ if TYPE_CHECKING:
12
+ from raglite._database import Chunk
13
+ from raglite._typing import SearchMethod
14
+
15
+
16
+ def is_accelerator_available() -> bool:
17
+ """Check if an accelerator is available."""
18
+ return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 # noqa: PLR2004
19
+
20
+
21
+ @pytest.mark.skipif(not is_accelerator_available(), reason="No accelerator available")
22
+ def test_rag(raglite_test_config: RAGLiteConfig) -> None:
23
+ """Test Retrieval-Augmented Generation."""
24
+ # Assemble different types of search inputs for RAG.
25
+ prompt = "What does it mean for two events to be simultaneous?"
26
+ search_inputs: list[SearchMethod | list[str] | list[Chunk]] = [
27
+ hybrid_search, # A search method as input.
28
+ hybrid_search(prompt, config=raglite_test_config)[0], # Chunk ids as input.
29
+ retrieve_chunks( # Chunks as input.
30
+ hybrid_search(prompt, config=raglite_test_config)[0], config=raglite_test_config
31
+ ),
32
+ ]
33
+ # Answer a question with RAG.
34
+ for search_input in search_inputs:
35
+ stream = rag(prompt, search=search_input, config=raglite_test_config)
36
+ answer = ""
37
+ for update in stream:
38
+ assert isinstance(update, str)
39
+ answer += update
40
+ assert "simultaneous" in answer.lower()
tests/test_rerank.py ADDED
@@ -0,0 +1,55 @@
1
+ """Test RAGLite's reranking functionality."""
2
+
3
+ import pytest
4
+ from rerankers.models.ranker import BaseRanker
5
+
6
+ from raglite import RAGLiteConfig, hybrid_search, rerank_chunks, retrieve_chunks
7
+ from raglite._database import Chunk
8
+ from raglite._flashrank import PatchedFlashRankRanker as FlashRankRanker
9
+
10
+
11
+ @pytest.fixture(
12
+ params=[
13
+ pytest.param(None, id="no_reranker"),
14
+ pytest.param(FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0), id="flashrank_english"),
15
+ pytest.param(
16
+ (
17
+ ("en", FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0)),
18
+ ("other", FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0)),
19
+ ),
20
+ id="flashrank_multilingual",
21
+ ),
22
+ ],
23
+ )
24
+ def reranker(
25
+ request: pytest.FixtureRequest,
26
+ ) -> BaseRanker | tuple[tuple[str, BaseRanker], ...] | None:
27
+ """Get a reranker to test RAGLite with."""
28
+ reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None = request.param
29
+ return reranker
30
+
31
+
32
+ def test_reranker(
33
+ raglite_test_config: RAGLiteConfig,
34
+ reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None,
35
+ ) -> None:
36
+ """Test inserting a document, updating the indexes, and searching for a query."""
37
+ # Update the config with the reranker.
38
+ raglite_test_config = RAGLiteConfig(
39
+ db_url=raglite_test_config.db_url, embedder=raglite_test_config.embedder, reranker=reranker
40
+ )
41
+ # Search for a query.
42
+ query = "What does it mean for two events to be simultaneous?"
43
+ chunk_ids, _ = hybrid_search(query, num_results=3, config=raglite_test_config)
44
+ # Retrieve the chunks.
45
+ chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
46
+ assert all(isinstance(chunk, Chunk) for chunk in chunks)
47
+ assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
48
+ # Rerank the chunks given an inverted chunk order.
49
+ reranked_chunks = rerank_chunks(query, chunks[::-1], config=raglite_test_config)
50
+ if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder:
51
+ assert reranked_chunks[0] == chunks[0]
52
+ # Test that we can also rerank given the chunk_ids only.
53
+ reranked_chunks = rerank_chunks(query, chunk_ids[::-1], config=raglite_test_config)
54
+ if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder:
55
+ assert reranked_chunks[0] == chunks[0]
tests/test_search.py ADDED
@@ -0,0 +1,48 @@
1
+ """Test RAGLite's search functionality."""
2
+
3
+ import pytest
4
+
5
+ from raglite import (
6
+ RAGLiteConfig,
7
+ hybrid_search,
8
+ keyword_search,
9
+ retrieve_chunks,
10
+ retrieve_segments,
11
+ vector_search,
12
+ )
13
+ from raglite._database import Chunk
14
+ from raglite._typing import SearchMethod
15
+
16
+
17
+ @pytest.fixture(
18
+ params=[
19
+ pytest.param(keyword_search, id="keyword_search"),
20
+ pytest.param(vector_search, id="vector_search"),
21
+ pytest.param(hybrid_search, id="hybrid_search"),
22
+ ],
23
+ )
24
+ def search_method(
25
+ request: pytest.FixtureRequest,
26
+ ) -> SearchMethod:
27
+ """Get a search method to test RAGLite with."""
28
+ search_method: SearchMethod = request.param
29
+ return search_method
30
+
31
+
32
+ def test_search(raglite_test_config: RAGLiteConfig, search_method: SearchMethod) -> None:
33
+ """Test searching for a query."""
34
+ # Search for a query.
35
+ query = "What does it mean for two events to be simultaneous?"
36
+ num_results = 5
37
+ chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config)
38
+ assert len(chunk_ids) == len(scores) == num_results
39
+ assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
40
+ assert all(isinstance(score, float) for score in scores)
41
+ # Retrieve the chunks.
42
+ chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
43
+ assert all(isinstance(chunk, Chunk) for chunk in chunks)
44
+ assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
45
+ assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks)
46
+ # Extend the chunks with their neighbours and group them into contiguous segments.
47
+ segments = retrieve_segments(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)
48
+ assert all(isinstance(segment, str) for segment in segments)
tests/test_split_chunks.py ADDED
@@ -0,0 +1,56 @@
1
+ """Test RAGLite's chunk splitting functionality."""
2
+
3
+ import numpy as np
4
+ import pytest
5
+
6
+ from raglite._split_chunks import split_chunks
7
+
8
+
9
+ @pytest.mark.parametrize(
10
+ "sentences",
11
+ [
12
+ pytest.param([], id="one_chunk:no_sentences"),
13
+ pytest.param(["Hello world"], id="one_chunk:one_sentence"),
14
+ pytest.param(["Hello world"] * 2, id="one_chunk:two_sentences"),
15
+ pytest.param(["Hello world"] * 3, id="one_chunk:three_sentences"),
16
+ pytest.param(["Hello world"] * 100, id="one_chunk:many_sentences"),
17
+ pytest.param(["Hello world", "X" * 1000], id="n_chunks:two_sentences_a"),
18
+ pytest.param(["X" * 1000, "Hello world"], id="n_chunks:two_sentences_b"),
19
+ pytest.param(["Hello world", "X" * 1000, "X" * 1000], id="n_chunks:three_sentences_a"),
20
+ pytest.param(["X" * 1000, "Hello world", "X" * 1000], id="n_chunks:three_sentences_b"),
21
+ pytest.param(["X" * 1000, "X" * 1000, "Hello world"], id="n_chunks:three_sentences_c"),
22
+ pytest.param(["X" * 1000] * 100, id="n_chunks:many_sentences_a"),
23
+ pytest.param(["X" * 100] * 1000, id="n_chunks:many_sentences_b"),
24
+ ],
25
+ )
26
+ def test_edge_cases(sentences: list[str]) -> None:
27
+ """Test chunk splitting edge cases."""
28
+ sentence_embeddings = np.ones((len(sentences), 768)).astype(np.float16)
29
+ chunks, chunk_embeddings = split_chunks(
30
+ sentences, sentence_embeddings, sentence_window_size=3, max_size=1440
31
+ )
32
+ assert isinstance(chunks, list)
33
+ assert isinstance(chunk_embeddings, list)
34
+ assert len(chunk_embeddings) == (len(chunks) if sentences else 1)
35
+ assert all(isinstance(chunk, str) for chunk in chunks)
36
+ assert all(isinstance(chunk_embedding, np.ndarray) for chunk_embedding in chunk_embeddings)
37
+ assert all(ce.dtype == sentence_embeddings.dtype for ce in chunk_embeddings)
38
+ assert sum(ce.shape[0] for ce in chunk_embeddings) == sentence_embeddings.shape[0]
39
+ assert all(ce.shape[1] == sentence_embeddings.shape[1] for ce in chunk_embeddings)
40
+
41
+
42
+ @pytest.mark.parametrize(
43
+ "sentences",
44
+ [
45
+ pytest.param(["Hello world" * 1000] + ["X"] * 100, id="first"),
46
+ pytest.param(["X"] * 50 + ["Hello world" * 1000] + ["X"] * 50, id="middle"),
47
+ pytest.param(["X"] * 100 + ["Hello world" * 1000], id="last"),
48
+ ],
49
+ )
50
+ def test_long_sentence(sentences: list[str]) -> None:
51
+ """Test chunking on sentences that are too long."""
52
+ sentence_embeddings = np.ones((len(sentences), 768)).astype(np.float16)
53
+ with pytest.raises(
54
+ ValueError, match="Sentence with length larger than chunk max_size detected."
55
+ ):
56
+ _ = split_chunks(sentences, sentence_embeddings, sentence_window_size=3, max_size=1440)