Commit 54f5afe · committed by EL GHAFRAOUI AYOUB
Parent(s): 70ba739
- .cruft.json +29 -0
- .devcontainer/devcontainer.json +65 -0
- .dockerignore +5 -0
- .github/dependabot.yml +29 -0
- .github/workflows/publish.yml +27 -0
- .github/workflows/test.yml +47 -0
- .gitignore +76 -0
- .pre-commit-config.yaml +83 -0
- CHANGELOG.md +53 -0
- Dockerfile +101 -0
- docker-compose.yml +60 -0
- poetry.lock +0 -0
- pyproject.toml +186 -0
- src/raglite/__init__.py +41 -0
- src/raglite/_chainlit.py +117 -0
- src/raglite/_cli.py +39 -0
- src/raglite/_config.py +61 -0
- src/raglite/_database.py +341 -0
- src/raglite/_embed.py +203 -0
- src/raglite/_eval.py +257 -0
- src/raglite/_extract.py +69 -0
- src/raglite/_flashrank.py +41 -0
- src/raglite/_insert.py +160 -0
- src/raglite/_litellm.py +261 -0
- src/raglite/_markdown.py +221 -0
- src/raglite/_query_adapter.py +162 -0
- src/raglite/_rag.py +166 -0
- src/raglite/_search.py +270 -0
- src/raglite/_split_chunks.py +102 -0
- src/raglite/_split_sentences.py +76 -0
- src/raglite/_typing.py +145 -0
- src/raglite/py.typed +0 -0
- tests/__init__.py +1 -0
- tests/conftest.py +102 -0
- tests/specrel.pdf +0 -0
- tests/test_embed.py +26 -0
- tests/test_import.py +8 -0
- tests/test_markdown.py +22 -0
- tests/test_rag.py +40 -0
- tests/test_rerank.py +55 -0
- tests/test_search.py +48 -0
- tests/test_split_chunks.py +56 -0
.cruft.json
ADDED
@@ -0,0 +1,29 @@
{
  "template": "https://github.com/superlinear-ai/poetry-cookiecutter",
  "commit": "b7f2fb0f123aae0a01d2ab015db31f52d2d8cc21",
  "checkout": null,
  "context": {
    "cookiecutter": {
      "project_type": "package",
      "project_name": "RAGLite",
      "project_description": "A Python toolkit for Retrieval-Augmented Generation (RAG) with SQLite or PostgreSQL.",
      "project_url": "https://github.com/superlinear-ai/raglite",
      "author_name": "Laurent Sorber",
      "author_email": "[email protected]",
      "python_version": "3.10",
      "development_environment": "strict",
      "with_conventional_commits": "1",
      "with_fastapi_api": "0",
      "with_typer_cli": "0",
      "continuous_integration": "GitHub",
      "private_package_repository_name": "",
      "private_package_repository_url": "",
      "__docker_image": "python:$PYTHON_VERSION-slim",
      "__docstring_style": "NumPy",
      "__project_name_kebab_case": "raglite",
      "__project_name_snake_case": "raglite",
      "_template": "https://github.com/superlinear-ai/poetry-cookiecutter"
    }
  },
  "directory": null
}
.devcontainer/devcontainer.json
ADDED
@@ -0,0 +1,65 @@
{
  "name": "raglite",
  "dockerComposeFile": "../docker-compose.yml",
  "service": "devcontainer",
  "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}/",
  "remoteUser": "user",
  "overrideCommand": true,
  "postStartCommand": "cp --update /opt/build/poetry/poetry.lock /workspaces/${localWorkspaceFolderBasename}/ && mkdir -p /workspaces/${localWorkspaceFolderBasename}/.git/hooks/ && cp --update /opt/build/git/* /workspaces/${localWorkspaceFolderBasename}/.git/hooks/",
  "customizations": {
    "vscode": {
      "extensions": [
        "charliermarsh.ruff",
        "GitHub.vscode-github-actions",
        "GitHub.vscode-pull-request-github",
        "ms-python.mypy-type-checker",
        "ms-python.python",
        "ms-toolsai.jupyter",
        "ryanluker.vscode-coverage-gutters",
        "tamasfe.even-better-toml",
        "visualstudioexptteam.vscodeintellicode"
      ],
      "settings": {
        "coverage-gutters.coverageFileNames": [
          "reports/coverage.xml"
        ],
        "editor.codeActionsOnSave": {
          "source.fixAll": "explicit",
          "source.organizeImports": "explicit"
        },
        "editor.formatOnSave": true,
        "[python]": {
          "editor.defaultFormatter": "charliermarsh.ruff"
        },
        "[toml]": {
          "editor.formatOnSave": false
        },
        "editor.rulers": [
          100
        ],
        "files.autoSave": "onFocusChange",
        "jupyter.kernels.excludePythonEnvironments": [
          "/usr/local/bin/python"
        ],
        "mypy-type-checker.importStrategy": "fromEnvironment",
        "mypy-type-checker.preferDaemon": true,
        "notebook.codeActionsOnSave": {
          "notebook.source.fixAll": "explicit",
          "notebook.source.organizeImports": "explicit"
        },
        "notebook.formatOnSave.enabled": true,
        "python.defaultInterpreterPath": "/opt/raglite-env/bin/python",
        "python.terminal.activateEnvironment": false,
        "python.testing.pytestEnabled": true,
        "ruff.importStrategy": "fromEnvironment",
        "ruff.logLevel": "warning",
        "terminal.integrated.defaultProfile.linux": "zsh",
        "terminal.integrated.profiles.linux": {
          "zsh": {
            "path": "/usr/bin/zsh"
          }
        }
      }
    }
  }
}
.dockerignore
ADDED
@@ -0,0 +1,5 @@
# Caches
.*_cache/

# Git
.git/
.github/dependabot.yml
ADDED
@@ -0,0 +1,29 @@
version: 2

updates:
  - package-ecosystem: github-actions
    directory: /
    schedule:
      interval: monthly
    commit-message:
      prefix: "ci"
      prefix-development: "ci"
      include: scope
    groups:
      ci-dependencies:
        patterns:
          - "*"
  - package-ecosystem: pip
    directory: /
    schedule:
      interval: monthly
    commit-message:
      prefix: "chore"
      prefix-development: "build"
      include: scope
    allow:
      - dependency-type: development
    versioning-strategy: increase
    groups:
      development-dependencies:
        dependency-type: development
.github/workflows/publish.yml
ADDED
@@ -0,0 +1,27 @@
name: Publish

on:
  release:
    types:
      - created

jobs:
  publish:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"

      - name: Install Poetry
        run: pip install --no-input poetry

      - name: Publish package
        run: |
          poetry config pypi-token.pypi "${{ secrets.POETRY_PYPI_TOKEN_PYPI }}"
          poetry publish --build
.github/workflows/test.yml
ADDED
@@ -0,0 +1,47 @@
name: Test

on:
  push:
    branches:
      - main
      - master
  pull_request:

jobs:
  test:
    runs-on: ubuntu-latest

    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.10", "3.11"]

    name: Python ${{ matrix.python-version }}

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: 21

      - name: Install @devcontainers/cli
        run: npm install --location=global @devcontainers/[email protected]

      - name: Start Dev Container
        run: |
          git config --global init.defaultBranch main
          PYTHON_VERSION=${{ matrix.python-version }} OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} devcontainer up --workspace-folder .

      - name: Lint package
        run: devcontainer exec --workspace-folder . poe lint

      - name: Test package
        run: devcontainer exec --workspace-folder . poe test

      - name: Upload coverage
        uses: codecov/codecov-action@v4
        with:
          files: reports/coverage.xml
.gitignore
ADDED
@@ -0,0 +1,76 @@
# Chainlit
.chainlit/
.files/
chainlit.md

# Coverage.py
htmlcov/
reports/

# cruft
*.rej

# Data
*.csv*
*.dat*
*.pickle*
*.xls*
*.zip*
data/

# direnv
.envrc

# dotenv
.env

# rerankers
.*_cache/

# Hypothesis
.hypothesis/

# Jupyter
*.ipynb
.ipynb_checkpoints/
notebooks/

# macOS
.DS_Store

# mypy
.dmypy.json
.mypy_cache/

# Node.js
node_modules/

# Poetry
.venv/
dist/

# PyCharm
.idea/

# pyenv
.python-version

# pytest
.pytest_cache/

# Python
__pycache__/
*.py[cdo]

# RAGLite
*.db
*.sqlite

# Ruff
.ruff_cache/

# Terraform
.terraform/

# VS Code
.vscode/
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,83 @@
# https://pre-commit.com
default_install_hook_types: [commit-msg, pre-commit]
default_stages: [commit, manual]
fail_fast: true
repos:
  - repo: meta
    hooks:
      - id: check-useless-excludes
  - repo: https://github.com/pre-commit/pygrep-hooks
    rev: v1.10.0
    hooks:
      - id: python-check-mock-methods
      - id: python-use-type-annotations
      - id: rst-backticks
      - id: rst-directive-colons
      - id: rst-inline-touching-normal
      - id: text-unicode-replacement-char
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: check-added-large-files
      - id: check-ast
      - id: check-builtin-literals
      - id: check-case-conflict
      - id: check-docstring-first
      - id: check-json
      - id: check-merge-conflict
      - id: check-shebang-scripts-are-executable
      - id: check-symlinks
      - id: check-toml
      - id: check-vcs-permalinks
      - id: check-xml
      - id: check-yaml
      - id: debug-statements
      - id: destroyed-symlinks
      - id: detect-private-key
      - id: end-of-file-fixer
        types: [python]
      - id: fix-byte-order-marker
      - id: mixed-line-ending
      - id: name-tests-test
        args: [--pytest-test-first]
      - id: trailing-whitespace
        types: [python]
  - repo: local
    hooks:
      - id: commitizen
        name: commitizen
        entry: cz check
        args: [--commit-msg-file]
        require_serial: true
        language: system
        stages: [commit-msg]
      - id: ruff-check
        name: ruff check
        entry: ruff check
        args: ["--force-exclude", "--extend-fixable=ERA001,F401,F841,T201,T203"]
        require_serial: true
        language: system
        types_or: [python, pyi]
      - id: ruff-format
        name: ruff format
        entry: ruff format
        args: [--force-exclude]
        require_serial: true
        language: system
        types_or: [python, pyi]
      - id: shellcheck
        name: shellcheck
        entry: shellcheck
        args: [--check-sourced]
        language: system
        types: [shell]
      - id: poetry-check
        name: poetry check
        entry: poetry check
        language: system
        pass_filenames: false
      - id: mypy
        name: mypy
        entry: mypy
        language: system
        types: [python]
CHANGELOG.md
ADDED
@@ -0,0 +1,53 @@
## v0.2.0 (2024-10-21)

### Feat

- add Chainlit frontend (#33)

## v0.1.4 (2024-10-15)

### Fix

- fix optimal chunking edge cases (#32)

## v0.1.3 (2024-10-13)

### Fix

- upgrade pdftext (#30)
- improve chunk and segment ordering (#29)

## v0.1.2 (2024-10-08)

### Fix

- avoid pdftext v0.3.11 (#27)

## v0.1.1 (2024-10-07)

### Fix

- patch rerankers flashrank issue (#22)

## v0.1.0 (2024-10-07)

### Feat

- add reranking (#20)
- add LiteLLM and late chunking (#19)
- add PostgreSQL support (#18)
- make query adapter minimally invasive (#16)
- upgrade default CPU model to Phi-3.5-mini (#15)
- add evaluation (#14)
- infer missing font sizes (#12)
- automatically adjust number of RAG contexts (#10)
- improve exception feedback for extraction (#9)
- optimize config for CPU and GPU (#7)
- simplify document insertion (#6)
- implement basic features (#2)
- initial commit

### Fix

- lazily import optional dependencies (#11)
- improve indexing of multiple documents (#8)
Dockerfile
ADDED
@@ -0,0 +1,101 @@
# syntax=docker/dockerfile:1
ARG PYTHON_VERSION=3.10
FROM python:$PYTHON_VERSION-slim AS base

# Remove docker-clean so we can keep the apt cache in Docker build cache.
RUN rm /etc/apt/apt.conf.d/docker-clean

# Configure Python to print tracebacks on crash [1], and to not buffer stdout and stderr [2].
# [1] https://docs.python.org/3/using/cmdline.html#envvar-PYTHONFAULTHANDLER
# [2] https://docs.python.org/3/using/cmdline.html#envvar-PYTHONUNBUFFERED
ENV PYTHONFAULTHANDLER 1
ENV PYTHONUNBUFFERED 1

# Create a non-root user and switch to it [1].
# [1] https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user
ARG UID=1000
ARG GID=$UID
RUN groupadd --gid $GID user && \
    useradd --create-home --gid $GID --uid $UID user --no-log-init && \
    chown user /opt/
USER user

# Create and activate a virtual environment.
ENV VIRTUAL_ENV /opt/raglite-env
ENV PATH $VIRTUAL_ENV/bin:$PATH
RUN python -m venv $VIRTUAL_ENV

# Set the working directory.
WORKDIR /workspaces/raglite/



FROM base AS poetry

USER root

# Install Poetry in a separate venv so it doesn't pollute the main venv.
ENV POETRY_VERSION 1.8.0
ENV POETRY_VIRTUAL_ENV /opt/poetry-env
RUN --mount=type=cache,target=/root/.cache/pip/ \
    python -m venv $POETRY_VIRTUAL_ENV && \
    $POETRY_VIRTUAL_ENV/bin/pip install poetry~=$POETRY_VERSION && \
    ln -s $POETRY_VIRTUAL_ENV/bin/poetry /usr/local/bin/poetry

# Install compilers that may be required for certain packages or platforms.
RUN --mount=type=cache,target=/var/cache/apt/ \
    --mount=type=cache,target=/var/lib/apt/ \
    apt-get update && \
    apt-get install --no-install-recommends --yes build-essential

USER user

# Install the runtime Python dependencies in the virtual environment.
COPY --chown=user:user poetry.lock* pyproject.toml /workspaces/raglite/
RUN mkdir -p /home/user/.cache/pypoetry/ && mkdir -p /home/user/.config/pypoetry/ && \
    mkdir -p src/raglite/ && touch src/raglite/__init__.py && touch README.md
RUN --mount=type=cache,uid=$UID,gid=$GID,target=/home/user/.cache/pypoetry/ \
    poetry install --only main --all-extras --no-interaction



FROM poetry AS dev

# Install development tools: curl, git, gpg, ssh, starship, sudo, vim, and zsh.
USER root
RUN --mount=type=cache,target=/var/cache/apt/ \
    --mount=type=cache,target=/var/lib/apt/ \
    apt-get update && \
    apt-get install --no-install-recommends --yes curl git gnupg ssh sudo vim zsh && \
    sh -c "$(curl -fsSL https://starship.rs/install.sh)" -- "--yes" && \
    usermod --shell /usr/bin/zsh user && \
    echo 'user ALL=(root) NOPASSWD:ALL' > /etc/sudoers.d/user && chmod 0440 /etc/sudoers.d/user
RUN git config --system --add safe.directory '*'
USER user

# Install the development Python dependencies in the virtual environment.
RUN --mount=type=cache,uid=$UID,gid=$GID,target=/home/user/.cache/pypoetry/ \
    poetry install --all-extras --no-interaction

# Persist output generated during docker build so that we can restore it in the dev container.
COPY --chown=user:user .pre-commit-config.yaml /workspaces/raglite/
RUN mkdir -p /opt/build/poetry/ && cp poetry.lock /opt/build/poetry/ && \
    git init && pre-commit install --install-hooks && \
    mkdir -p /opt/build/git/ && cp .git/hooks/commit-msg .git/hooks/pre-commit /opt/build/git/

# Configure the non-root user's shell.
ENV ANTIDOTE_VERSION 1.8.6
RUN git clone --branch v$ANTIDOTE_VERSION --depth=1 https://github.com/mattmc3/antidote.git ~/.antidote/ && \
    echo 'zsh-users/zsh-syntax-highlighting' >> ~/.zsh_plugins.txt && \
    echo 'zsh-users/zsh-autosuggestions' >> ~/.zsh_plugins.txt && \
    echo 'source ~/.antidote/antidote.zsh' >> ~/.zshrc && \
    echo 'antidote load' >> ~/.zshrc && \
    echo 'eval "$(starship init zsh)"' >> ~/.zshrc && \
    echo 'HISTFILE=~/.history/.zsh_history' >> ~/.zshrc && \
    echo 'HISTSIZE=1000' >> ~/.zshrc && \
    echo 'SAVEHIST=1000' >> ~/.zshrc && \
    echo 'setopt share_history' >> ~/.zshrc && \
    echo 'bindkey "^[[A" history-beginning-search-backward' >> ~/.zshrc && \
    echo 'bindkey "^[[B" history-beginning-search-forward' >> ~/.zshrc && \
    mkdir ~/.history/ && \
    zsh -c 'source ~/.zshrc'
docker-compose.yml
ADDED
@@ -0,0 +1,60 @@
version: "3.9"

services:

  devcontainer:
    build:
      context: .
      target: dev
      args:
        PYTHON_VERSION: ${PYTHON_VERSION:-3.10}
        UID: ${UID:-1000}
        GID: ${GID:-1000}
    environment:
      - OPENAI_API_KEY
      - POETRY_PYPI_TOKEN_PYPI
    depends_on:
      - postgres
    networks:
      - raglite-network
    volumes:
      - ..:/workspaces
      - command-history-volume:/home/user/.history/

  dev:
    extends: devcontainer
    stdin_open: true
    tty: true
    entrypoint: []
    command: [ "sh", "-c", "sudo chown user $$SSH_AUTH_SOCK && cp --update /opt/build/poetry/poetry.lock /workspaces/raglite/ && mkdir -p /workspaces/raglite/.git/hooks/ && cp --update /opt/build/git/* /workspaces/raglite/.git/hooks/ && zsh" ]
    environment:
      - OPENAI_API_KEY
      - POETRY_PYPI_TOKEN_PYPI
      - SSH_AUTH_SOCK=/run/host-services/ssh-auth.sock
    depends_on:
      - postgres
    networks:
      - raglite-network
    volumes:
      - ~/.gitconfig:/etc/gitconfig
      - ~/.ssh/known_hosts:/home/user/.ssh/known_hosts
      - ${SSH_AGENT_AUTH_SOCK:-/run/host-services/ssh-auth.sock}:/run/host-services/ssh-auth.sock
    profiles:
      - dev

  postgres:
    image: pgvector/pgvector:pg16
    environment:
      POSTGRES_USER: raglite_user
      POSTGRES_PASSWORD: raglite_password
    networks:
      - raglite-network
    tmpfs:
      - /var/lib/postgresql/data

networks:
  raglite-network:
    driver: bridge

volumes:
  command-history-volume:
poetry.lock
ADDED
The diff for this file is too large to render; see the raw diff.
pyproject.toml
ADDED
@@ -0,0 +1,186 @@
[build-system]  # https://python-poetry.org/docs/pyproject/#poetry-and-pep-517
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]  # https://python-poetry.org/docs/pyproject/
name = "raglite"
version = "0.2.0"
description = "A Python toolkit for Retrieval-Augmented Generation (RAG) with SQLite or PostgreSQL."
authors = ["Laurent Sorber <[email protected]>"]
readme = "README.md"
repository = "https://github.com/superlinear-ai/raglite"

[tool.commitizen]  # https://commitizen-tools.github.io/commitizen/config/
bump_message = "bump(release): v$current_version → v$new_version"
tag_format = "v$version"
update_changelog_on_bump = true
version_provider = "poetry"

[tool.poetry.dependencies]  # https://python-poetry.org/docs/dependency-specification/
# Python:
python = ">=3.10,<4.0"
# Markdown conversion:
pdftext = ">=0.3.13"
pypandoc-binary = { version = ">=1.13", optional = true }
scikit-learn = ">=1.4.2"
# Markdown formatting:
markdown-it-py = ">=3.0.0"
mdformat-gfm = ">=0.3.6"
# Sentence and chunk splitting:
numpy = ">=1.26.4"
scipy = ">=1.5.0"
spacy = ">=3.7.0,<3.8.0"
# Large Language Models:
huggingface-hub = ">=0.22.0"
litellm = ">=1.47.1"
llama-cpp-python = ">=0.2.88"
pydantic = ">=2.7.0"
# Approximate Nearest Neighbors:
pynndescent = ">=0.5.12"
# Reranking:
langdetect = ">=1.0.9"
rerankers = { extras = ["flashrank"], version = ">=0.5.3" }
# Storage:
pg8000 = ">=1.31.2"
sqlmodel-slim = ">=0.0.18"
# Progress:
tqdm = ">=4.66.0"
# Evaluation:
pandas = ">=2.1.0"
ragas = { version = ">=0.1.12", optional = true }
# CLI:
typer = ">=0.12.5"
# Frontend:
chainlit = { version = ">=1.2.0", optional = true }

[tool.poetry.extras]  # https://python-poetry.org/docs/pyproject/#extras
chainlit = ["chainlit"]
pandoc = ["pypandoc-binary"]
ragas = ["ragas"]

[tool.poetry.group.test.dependencies]  # https://python-poetry.org/docs/master/managing-dependencies/
commitizen = ">=3.29.1"
coverage = { extras = ["toml"], version = ">=7.4.4" }
mypy = ">=1.9.0"
poethepoet = ">=0.25.0"
pre-commit = ">=3.7.0"
pytest = ">=8.1.1"
pytest-mock = ">=3.14.0"
ruff = ">=0.5.7"
safety = ">=3.1.0"
shellcheck-py = ">=0.10.0.1"
typeguard = ">=4.2.1"
xx_sent_ud_sm = { url = "https://github.com/explosion/spacy-models/releases/download/xx_sent_ud_sm-3.7.0/xx_sent_ud_sm-3.7.0-py3-none-any.whl" }

[tool.poetry.group.dev.dependencies]  # https://python-poetry.org/docs/master/managing-dependencies/
cruft = ">=2.15.0"
ipykernel = ">=6.29.4"
ipython = ">=8.8.0"
ipywidgets = ">=8.1.2"
matplotlib = ">=3.9.0"
memory-profiler = ">=0.61.0"
pdoc = ">=14.4.0"

[tool.poetry.scripts]  # https://python-poetry.org/docs/pyproject/#scripts
raglite = "raglite:cli"

[tool.coverage.report]  # https://coverage.readthedocs.io/en/latest/config.html#report
fail_under = 50
precision = 1
show_missing = true
skip_covered = true

[tool.coverage.run]  # https://coverage.readthedocs.io/en/latest/config.html#run
branch = true
command_line = "--module pytest"
data_file = "reports/.coverage"
source = ["src"]

[tool.coverage.xml]  # https://coverage.readthedocs.io/en/latest/config.html#xml
output = "reports/coverage.xml"

[tool.mypy]  # https://mypy.readthedocs.io/en/latest/config_file.html
junit_xml = "reports/mypy.xml"
strict = true
disallow_subclassing_any = false
disallow_untyped_decorators = false
ignore_missing_imports = true
pretty = true
show_column_numbers = true
show_error_codes = true
show_error_context = true
warn_unreachable = true

[tool.pytest.ini_options]  # https://docs.pytest.org/en/latest/reference/reference.html#ini-options-ref
addopts = "--color=yes --exitfirst --failed-first --strict-config --strict-markers --verbosity=2 --junitxml=reports/pytest.xml"
filterwarnings = ["error", "ignore::DeprecationWarning", "ignore::pytest.PytestUnraisableExceptionWarning"]
testpaths = ["src", "tests"]
xfail_strict = true

[tool.ruff]  # https://github.com/charliermarsh/ruff
fix = true
line-length = 100
src = ["src", "tests"]
target-version = "py310"

[tool.ruff.lint]
select = ["A", "ASYNC", "B", "BLE", "C4", "C90", "D", "DTZ", "E", "EM", "ERA", "F", "FBT", "FLY", "FURB", "G", "I", "ICN", "INP", "INT", "ISC", "LOG", "N", "NPY", "PERF", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "Q", "RET", "RSE", "RUF", "S", "SIM", "SLF", "SLOT", "T10", "T20", "TCH", "TID", "TRY", "UP", "W", "YTT"]
ignore = ["D203", "D213", "E501", "RET504", "RUF002", "S101", "S307"]
unfixable = ["ERA001", "F401", "F841", "T201", "T203"]

[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"

[tool.ruff.lint.pycodestyle]
max-doc-length = 100

[tool.ruff.lint.pydocstyle]
convention = "numpy"

[tool.poe.tasks]  # https://github.com/nat-n/poethepoet

[tool.poe.tasks.docs]
help = "Generate this package's docs"
cmd = """
    pdoc
      --docformat $docformat
      --output-directory $outputdirectory
      raglite
    """

[[tool.poe.tasks.docs.args]]
help = "The docstring style (default: numpy)"
name = "docformat"
options = ["--docformat"]
default = "numpy"

[[tool.poe.tasks.docs.args]]
help = "The output directory (default: docs)"
name = "outputdirectory"
options = ["--output-directory"]
default = "docs"

[tool.poe.tasks.lint]
help = "Lint this package"

[[tool.poe.tasks.lint.sequence]]
cmd = """
    pre-commit run
      --all-files
      --color always
    """

[[tool.poe.tasks.lint.sequence]]
shell = "safety check --continue-on-error --full-report"

[tool.poe.tasks.test]
help = "Test this package"

[[tool.poe.tasks.test.sequence]]
cmd = "coverage run"

[[tool.poe.tasks.test.sequence]]
cmd = "coverage report"

[[tool.poe.tasks.test.sequence]]
cmd = "coverage xml"
src/raglite/__init__.py
ADDED
@@ -0,0 +1,41 @@
"""RAGLite."""

from raglite._cli import cli
from raglite._config import RAGLiteConfig
from raglite._eval import answer_evals, evaluate, insert_evals
from raglite._insert import insert_document
from raglite._query_adapter import update_query_adapter
from raglite._rag import async_rag, rag
from raglite._search import (
    hybrid_search,
    keyword_search,
    rerank_chunks,
    retrieve_chunks,
    retrieve_segments,
    vector_search,
)

__all__ = [
    # Config
    "RAGLiteConfig",
    # Insert
    "insert_document",
    # Search
    "hybrid_search",
    "keyword_search",
    "vector_search",
    "retrieve_chunks",
    "retrieve_segments",
    "rerank_chunks",
    # RAG
    "async_rag",
    "rag",
    # Query adapter
    "update_query_adapter",
    # Evaluate
    "insert_evals",
    "answer_evals",
    "evaluate",
    # CLI
    "cli",
]
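Taken together, these exports trace RAGLite's insert → search → rerank → generate pipeline. Below is a minimal usage sketch, assuming the default local SQLite config; the document path and query are illustrative, and `rag` is assumed to yield response tokens the way its `async_rag` counterpart does in `_chainlit.py`:

from pathlib import Path

from raglite import (
    RAGLiteConfig,
    hybrid_search,
    insert_document,
    rag,
    rerank_chunks,
    retrieve_chunks,
)

config = RAGLiteConfig(db_url="sqlite:///raglite.sqlite")
insert_document(Path("specrel.pdf"), config=config)  # Convert, chunk, embed, and index.
query = "What is special relativity?"
chunk_ids, _ = hybrid_search(query=query, num_results=10, config=config)
chunks = retrieve_chunks(chunk_ids=chunk_ids, config=config)  # Hydrate ids into chunks.
chunks = rerank_chunks(query=query, chunk_ids=chunks, config=config)
for token in rag(prompt=query, search=chunks, config=config):
    print(token, end="")  # Stream the grounded answer token by token.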
src/raglite/_chainlit.py
ADDED
@@ -0,0 +1,117 @@
"""Chainlit frontend for RAGLite."""

import os
from pathlib import Path

import chainlit as cl
from chainlit.input_widget import Switch, TextInput

from raglite import (
    RAGLiteConfig,
    async_rag,
    hybrid_search,
    insert_document,
    rerank_chunks,
    retrieve_chunks,
)
from raglite._markdown import document_to_markdown

async_insert_document = cl.make_async(insert_document)
async_hybrid_search = cl.make_async(hybrid_search)
async_retrieve_chunks = cl.make_async(retrieve_chunks)
async_rerank_chunks = cl.make_async(rerank_chunks)


@cl.on_chat_start
async def start_chat() -> None:
    """Initialize the chat."""
    # Disable tokenizer parallelism to avoid the deadlock warning.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Add Chainlit settings with which the user can configure the RAGLite config.
    default_config = RAGLiteConfig()
    config = RAGLiteConfig(
        db_url=os.environ.get("RAGLITE_DB_URL", default_config.db_url),
        llm=os.environ.get("RAGLITE_LLM", default_config.llm),
        embedder=os.environ.get("RAGLITE_EMBEDDER", default_config.embedder),
    )
    settings = await cl.ChatSettings(  # type: ignore[no-untyped-call]
        [
            TextInput(id="db_url", label="Database URL", initial=str(config.db_url)),
            TextInput(id="llm", label="LLM", initial=config.llm),
            TextInput(id="embedder", label="Embedder", initial=config.embedder),
            Switch(id="vector_search_query_adapter", label="Query adapter", initial=True),
        ]
    ).send()
    await update_config(settings)


@cl.on_settings_update  # type: ignore[arg-type]
async def update_config(settings: cl.ChatSettings) -> None:
    """Update the RAGLite config."""
    # Update the RAGLite config given the Chainlit settings.
    config = RAGLiteConfig(
        db_url=settings["db_url"],  # type: ignore[index]
        llm=settings["llm"],  # type: ignore[index]
        embedder=settings["embedder"],  # type: ignore[index]
        vector_search_query_adapter=settings["vector_search_query_adapter"],  # type: ignore[index]
    )
    cl.user_session.set("config", config)  # type: ignore[no-untyped-call]
    # Run a search to prime the pipeline if it's a local pipeline.
    # TODO: Don't do this for SQLite once we switch from PyNNDescent to sqlite-vec.
    if str(config.db_url).startswith("sqlite") or config.embedder.startswith("llama-cpp-python"):
        # async with cl.Step(name="initialize", type="retrieval"):
        query = "Hello world"
        chunk_ids, _ = await async_hybrid_search(query=query, config=config)
        _ = await async_rerank_chunks(query=query, chunk_ids=chunk_ids, config=config)


@cl.on_message
async def handle_message(user_message: cl.Message) -> None:
    """Respond to a user message."""
    # Get the config and message history from the user session.
    config: RAGLiteConfig = cl.user_session.get("config")  # type: ignore[no-untyped-call]
    # Determine what to do with the attachments.
    inline_attachments = []
    for file in user_message.elements:
        if file.path:
            doc_md = document_to_markdown(Path(file.path))
            if len(doc_md) // 3 <= 5 * (config.chunk_max_size // 3):
                # Document is small enough to attach to the context.
                inline_attachments.append(f"{Path(file.path).name}:\n\n{doc_md}")
            else:
                # Document is too large and must be inserted into the database.
                async with cl.Step(name="insert", type="run") as step:
                    step.input = Path(file.path).name
                    await async_insert_document(Path(file.path), config=config)
    # Append any inline attachments to the user prompt.
    user_prompt = f"{user_message.content}\n\n" + "\n\n".join(
        f'<attachment index="{i}">\n{attachment.strip()}\n</attachment>'
        for i, attachment in enumerate(inline_attachments)
    )
    # Search for relevant contexts for RAG.
    async with cl.Step(name="search", type="retrieval") as step:
        step.input = user_message.content
        chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config)
        chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config)
        step.output = chunks
        step.elements = [  # Show the top 3 chunks inline.
            cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3]
        ]
    # Rerank the chunks.
    async with cl.Step(name="rerank", type="rerank") as step:
        step.input = chunks
        chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config)
        step.output = chunks
        step.elements = [  # Show the top 3 chunks inline.
            cl.Text(content=str(chunk), display="inline") for chunk in chunks[:3]
        ]
    # Stream the LLM response.
    assistant_message = cl.Message(content="")
    async for token in async_rag(
        prompt=user_prompt,
        search=chunks,
        messages=cl.chat_context.to_openai()[-5:],  # type: ignore[no-untyped-call]
        config=config,
    ):
        await assistant_message.stream_token(token)
    await assistant_message.update()  # type: ignore[no-untyped-call]
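A note on the attachment branch above: `len(doc_md) // 3 <= 5 * (config.chunk_max_size // 3)` approximates token counts as one token per three characters and inlines a document only if it fits in roughly five chunks' worth of tokens. A worked restatement of that arithmetic under the default `chunk_max_size` of 1440 (the document length is hypothetical):

chunk_max_size = 1440  # Default from RAGLiteConfig: max characters per chunk.
doc_md = "x" * 7000  # A hypothetical 7000-character Markdown document.

doc_tokens = len(doc_md) // 3  # ~2333 tokens at roughly 3 characters per token.
budget = 5 * (chunk_max_size // 3)  # 2400 tokens: about five chunks' worth.
print(doc_tokens <= budget)  # True: small enough to inline as an attachment.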
src/raglite/_cli.py
ADDED
@@ -0,0 +1,39 @@
"""RAGLite CLI."""

import os

import typer

from raglite._config import RAGLiteConfig

cli = typer.Typer()


@cli.callback()
def main() -> None:
    """RAGLite CLI."""


@cli.command()
def chainlit(
    db_url: str = typer.Option(RAGLiteConfig().db_url, help="Database URL"),
    llm: str = typer.Option(RAGLiteConfig().llm, help="LiteLLM LLM"),
    embedder: str = typer.Option(RAGLiteConfig().embedder, help="LiteLLM embedder"),
) -> None:
    """Serve a Chainlit frontend."""
    # Set the environment variables for the Chainlit frontend.
    os.environ["RAGLITE_DB_URL"] = os.environ.get("RAGLITE_DB_URL", db_url)
    os.environ["RAGLITE_LLM"] = os.environ.get("RAGLITE_LLM", llm)
    os.environ["RAGLITE_EMBEDDER"] = os.environ.get("RAGLITE_EMBEDDER", embedder)
    # Import Chainlit here as it's an optional dependency.
    try:
        from chainlit.cli import run_chainlit
    except ImportError as error:
        error_message = "To serve a Chainlit frontend, please install the `chainlit` extra."
        raise ImportError(error_message) from error
    # Serve the frontend.
    run_chainlit(__file__.replace("_cli.py", "_chainlit.py"))


if __name__ == "__main__":
    cli()
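Since `pyproject.toml` registers `raglite = "raglite:cli"` as a console script, this Typer app is what the `raglite` command runs, with the keyword arguments exposed as `--db-url`-style flags (Typer's default naming). A quick sketch of exercising it programmatically with Typer's test runner, shown with `--help` so nothing is actually served:

from typer.testing import CliRunner

from raglite._cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["chainlit", "--help"])  # Equivalent to: raglite chainlit --help
print(result.output)  # Lists the --db-url, --llm, and --embedder options.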
src/raglite/_config.py
ADDED
@@ -0,0 +1,61 @@
"""RAGLite config."""

import contextlib
import os
from dataclasses import dataclass, field
from io import StringIO

from llama_cpp import llama_supports_gpu_offload
from sqlalchemy.engine import URL

from raglite._flashrank import PatchedFlashRankRanker as FlashRankRanker

# Suppress rerankers output on import until [1] is fixed.
# [1] https://github.com/AnswerDotAI/rerankers/issues/36
with contextlib.redirect_stdout(StringIO()):
    from rerankers.models.ranker import BaseRanker


@dataclass(frozen=True)
class RAGLiteConfig:
    """Configuration for RAGLite."""

    # Database config.
    db_url: str | URL = "sqlite:///raglite.sqlite"
    # LLM config used for generation.
    llm: str = field(
        default_factory=lambda: (
            "llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@8192"
            if llama_supports_gpu_offload()
            else "llama-cpp-python/bartowski/Llama-3.2-3B-Instruct-GGUF/*Q4_K_M.gguf@4096"
        )
    )
    llm_max_tries: int = 4
    # Embedder config used for indexing.
    embedder: str = field(
        default_factory=lambda: (  # Nomic-embed may be better if only English is used.
            "llama-cpp-python/lm-kit/bge-m3-gguf/*F16.gguf"
            if llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 4  # noqa: PLR2004
            else "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf"
        )
    )
    embedder_normalize: bool = True
    embedder_sentence_window_size: int = 3
    # Chunk config used to partition documents into chunks.
    chunk_max_size: int = 1440  # Max number of characters per chunk.
    # Vector search config.
    vector_search_index_metric: str = "cosine"  # The query adapter supports "dot" and "cosine".
    vector_search_query_adapter: bool = True
    # Reranking config.
    reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None = field(
        default_factory=lambda: (
            ("en", FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0)),
            ("other", FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0)),
        ),
        compare=False,  # Exclude the reranker from comparison to avoid lru_cache misses.
    )

    def __post_init__(self) -> None:
        # Late chunking with llama-cpp-python does not apply sentence windowing.
        if self.embedder.startswith("llama-cpp-python"):
            object.__setattr__(self, "embedder_sentence_window_size", 1)
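Because `RAGLiteConfig` is a frozen dataclass whose model defaults depend on GPU availability, overriding a field is just a matter of passing it at construction. A minimal sketch that points RAGLite at a PostgreSQL instance like the Compose file's service and at remote models; the database URL and LiteLLM model identifiers are illustrative:

from raglite import RAGLiteConfig

config = RAGLiteConfig(
    db_url="postgresql://raglite_user:raglite_password@localhost:5432/raglite",
    llm="gpt-4o-mini",  # Assumed: any LLM identifier that LiteLLM can route.
    embedder="text-embedding-3-small",  # Assumed: any LiteLLM-routable embedder.
)
print(config.chunk_max_size)  # 1440 characters per chunk by default.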
src/raglite/_database.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""PostgreSQL or SQLite database tables for RAGLite."""
|
2 |
+
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
from functools import lru_cache
|
6 |
+
from hashlib import sha256
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Any
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
from litellm import get_model_info # type: ignore[attr-defined]
|
12 |
+
from markdown_it import MarkdownIt
|
13 |
+
from pydantic import ConfigDict
|
14 |
+
from sqlalchemy.engine import Engine, make_url
|
15 |
+
from sqlmodel import (
|
16 |
+
JSON,
|
17 |
+
Column,
|
18 |
+
Field,
|
19 |
+
Relationship,
|
20 |
+
Session,
|
21 |
+
SQLModel,
|
22 |
+
create_engine,
|
23 |
+
text,
|
24 |
+
)
|
25 |
+
|
26 |
+
from raglite._config import RAGLiteConfig
|
27 |
+
from raglite._litellm import LlamaCppPythonLLM
|
28 |
+
from raglite._typing import Embedding, FloatMatrix, FloatVector, PickledObject
|
29 |
+
|
30 |
+
|
31 |
+
def hash_bytes(data: bytes, max_len: int = 16) -> str:
|
32 |
+
"""Hash bytes to a hexadecimal string."""
|
33 |
+
return sha256(data, usedforsecurity=False).hexdigest()[:max_len]
|
34 |
+
|
35 |
+
|
36 |
+
class Document(SQLModel, table=True):
|
37 |
+
"""A document."""
|
38 |
+
|
39 |
+
# Enable JSON columns.
|
40 |
+
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
|
41 |
+
|
42 |
+
# Table columns.
|
43 |
+
id: str = Field(..., primary_key=True)
|
44 |
+
filename: str
|
45 |
+
url: str | None = Field(default=None)
|
46 |
+
metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
|
47 |
+
|
48 |
+
# Add relationships so we can access document.chunks and document.evals.
|
49 |
+
chunks: list["Chunk"] = Relationship(back_populates="document")
|
50 |
+
evals: list["Eval"] = Relationship(back_populates="document")
|
51 |
+
|
52 |
+
@staticmethod
|
53 |
+
def from_path(doc_path: Path, **kwargs: Any) -> "Document":
|
54 |
+
"""Create a document from a file path."""
|
55 |
+
return Document(
|
56 |
+
id=hash_bytes(doc_path.read_bytes()),
|
57 |
+
filename=doc_path.name,
|
58 |
+
metadata_={
|
59 |
+
"size": doc_path.stat().st_size,
|
60 |
+
"created": doc_path.stat().st_ctime,
|
61 |
+
"modified": doc_path.stat().st_mtime,
|
62 |
+
**kwargs,
|
63 |
+
},
|
64 |
+
)
|
65 |
+
|
66 |
+
|
67 |
+
class Chunk(SQLModel, table=True):
|
68 |
+
"""A document chunk."""
|
69 |
+
|
70 |
+
# Enable JSON columns.
|
71 |
+
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
|
72 |
+
|
73 |
+
# Table columns.
|
74 |
+
id: str = Field(..., primary_key=True)
|
75 |
+
document_id: str = Field(..., foreign_key="document.id", index=True)
|
76 |
+
index: int = Field(..., index=True)
|
77 |
+
headings: str
|
78 |
+
body: str
|
79 |
+
metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
|
80 |
+
|
81 |
+
# Add relationships so we can access chunk.document and chunk.embeddings.
|
82 |
+
document: Document = Relationship(back_populates="chunks")
|
83 |
+
embeddings: list["ChunkEmbedding"] = Relationship(back_populates="chunk")
|
84 |
+
|
85 |
+
@staticmethod
|
86 |
+
def from_body(
|
87 |
+
document_id: str,
|
88 |
+
index: int,
|
89 |
+
body: str,
|
90 |
+
headings: str = "",
|
91 |
+
**kwargs: Any,
|
92 |
+
) -> "Chunk":
|
93 |
+
"""Create a chunk from Markdown."""
|
94 |
+
return Chunk(
|
95 |
+
id=hash_bytes(body.encode()),
|
96 |
+
document_id=document_id,
|
97 |
+
index=index,
|
98 |
+
headings=headings,
|
99 |
+
body=body,
|
100 |
+
metadata_=kwargs,
|
101 |
+
)
|
102 |
+
|
103 |
+
def extract_headings(self) -> str:
|
104 |
+
"""Extract Markdown headings from the chunk, starting from the current Markdown headings."""
|
105 |
+
md = MarkdownIt()
|
106 |
+
heading_lines = [""] * 10
|
107 |
+
level = None
|
108 |
+
for doc in (self.headings, self.body):
|
109 |
+
for token in md.parse(doc):
|
110 |
+
if token.type == "heading_open":
|
111 |
+
level = int(token.tag[1])
|
112 |
+
elif token.type == "heading_close":
|
113 |
+
level = None
|
114 |
+
elif level is not None:
|
115 |
+
heading_content = token.content.strip().replace("\n", " ")
|
116 |
+
heading_lines[level] = ("#" * level) + " " + heading_content
|
117 |
+
heading_lines[level + 1 :] = [""] * len(heading_lines[level + 1 :])
|
118 |
+
headings = "\n".join([heading for heading in heading_lines if heading])
|
119 |
+
return headings
|
120 |
+
|
121 |
+
@property
|
122 |
+
def embedding_matrix(self) -> FloatMatrix:
|
123 |
+
"""Return this chunk's multi-vector embedding matrix."""
|
124 |
+
# Uses the relationship chunk.embeddings to access the chunk_embedding table.
|
125 |
+
return np.vstack([embedding.embedding[np.newaxis, :] for embedding in self.embeddings])
|
126 |
+
|
127 |
+
def __hash__(self) -> int:
|
128 |
+
return hash(self.id)
|
129 |
+
|
130 |
+
def __repr__(self) -> str:
|
131 |
+
return json.dumps(
|
132 |
+
{
|
133 |
+
"id": self.id,
|
134 |
+
"document_id": self.document_id,
|
135 |
+
"index": self.index,
|
136 |
+
"headings": self.headings,
|
137 |
+
"body": self.body[:100],
|
138 |
+
"metadata": self.metadata_,
|
139 |
+
},
|
140 |
+
indent=4,
|
141 |
+
)
|
142 |
+
|
143 |
+
def __str__(self) -> str:
|
144 |
+
"""Context representation of this chunk."""
|
145 |
+
return f"{self.headings.strip()}\n\n{self.body.strip()}".strip()
|
146 |
+
|
147 |
+
|
148 |
+
class ChunkEmbedding(SQLModel, table=True):
|
149 |
+
"""A (sub-)chunk embedding."""
|
150 |
+
|
151 |
+
__tablename__ = "chunk_embedding"
|
152 |
+
|
153 |
+
# Enable Embedding columns.
|
154 |
+
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
|
155 |
+
|
156 |
+
# Table columns.
|
157 |
+
id: int = Field(..., primary_key=True)
|
158 |
+
chunk_id: str = Field(..., foreign_key="chunk.id", index=True)
|
159 |
+
embedding: FloatVector = Field(..., sa_column=Column(Embedding(dim=-1)))
|
160 |
+
|
161 |
+
# Add relationship so we can access embedding.chunk.
|
162 |
+
chunk: Chunk = Relationship(back_populates="embeddings")
|
163 |
+
|
164 |
+
@classmethod
|
165 |
+
def set_embedding_dim(cls, dim: int) -> None:
|
166 |
+
"""Modify the embedding column's dimension after class definition."""
|
167 |
+
cls.__table__.c["embedding"].type.dim = dim # type: ignore[attr-defined]
|
168 |
+
|
169 |
+
|
170 |
+
class IndexMetadata(SQLModel, table=True):
|
171 |
+
"""Vector and keyword search index metadata."""
|
172 |
+
|
173 |
+
__tablename__ = "index_metadata"
|
174 |
+
|
175 |
+
# Enable PickledObject columns.
|
176 |
+
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
|
177 |
+
|
178 |
+
# Table columns.
|
179 |
+
id: str = Field(..., primary_key=True)
|
180 |
+
version: datetime.datetime = Field(
|
181 |
+
default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
|
182 |
+
)
|
183 |
+
metadata_: dict[str, Any] = Field(
|
184 |
+
default_factory=dict, sa_column=Column("metadata", PickledObject)
|
185 |
+
)
|
186 |
+
|
187 |
+
@staticmethod
|
188 |
+
@lru_cache(maxsize=4)
|
189 |
+
def _get(id_: str, *, config: RAGLiteConfig | None = None) -> dict[str, Any] | None:
|
190 |
+
engine = create_database_engine(config)
|
191 |
+
with Session(engine) as session:
|
192 |
+
index_metadata_record = session.get(IndexMetadata, id_)
|
193 |
+
if index_metadata_record is None:
|
194 |
+
return None
|
195 |
+
return index_metadata_record.metadata_
|
196 |
+
|
197 |
+
@staticmethod
|
198 |
+
def get(id_: str = "default", *, config: RAGLiteConfig | None = None) -> dict[str, Any]:
|
199 |
+
metadata = IndexMetadata._get(id_, config=config) or {}
|
200 |
+
return metadata
|
201 |
+
|
202 |
+
|
203 |
+
class Eval(SQLModel, table=True):
|
204 |
+
"""A RAG evaluation example."""
|
205 |
+
|
206 |
+
__tablename__ = "eval"
|
207 |
+
|
208 |
+
# Enable JSON columns.
|
209 |
+
model_config = ConfigDict(arbitrary_types_allowed=True) # type: ignore[assignment]
|
210 |
+
|
211 |
+
# Table columns.
|
212 |
+
id: str = Field(..., primary_key=True)
|
213 |
+
document_id: str = Field(..., foreign_key="document.id", index=True)
|
214 |
+
chunk_ids: list[str] = Field(default_factory=list, sa_column=Column(JSON))
|
215 |
+
question: str
|
216 |
+
contexts: list[str] = Field(default_factory=list, sa_column=Column(JSON))
|
217 |
+
ground_truth: str
|
218 |
+
metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
|
219 |
+
|
220 |
+
# Add relationship so we can access eval.document.
|
221 |
+
document: Document = Relationship(back_populates="evals")
|
222 |
+
|
223 |
+
@staticmethod
|
224 |
+
def from_chunks(
|
225 |
+
question: str,
|
226 |
+
contexts: list[Chunk],
|
227 |
+
ground_truth: str,
|
228 |
+
**kwargs: Any,
|
229 |
+
) -> "Eval":
|
230 |
+
"""Create a chunk from Markdown."""
|
231 |
+
document_id = contexts[0].document_id
|
232 |
+
chunk_ids = [context.id for context in contexts]
|
233 |
+
return Eval(
|
234 |
+
id=hash_bytes(f"{document_id}-{chunk_ids}-{question}".encode()),
|
235 |
+
document_id=document_id,
|
236 |
+
chunk_ids=chunk_ids,
|
237 |
+
question=question,
|
238 |
+
contexts=[str(context) for context in contexts],
|
239 |
+
ground_truth=ground_truth,
|
240 |
+
metadata_=kwargs,
|
241 |
+
)
|
242 |
+
@lru_cache(maxsize=1)
def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
    """Create a database engine and initialize it."""
    # Parse the database URL and validate that the database backend is supported.
    config = config or RAGLiteConfig()
    db_url = make_url(config.db_url)
    db_backend = db_url.get_backend_name()
    # Update database configuration.
    connect_args = {}
    if db_backend == "postgresql":
        # Select the pg8000 driver if not set (psycopg2 is the default), and prefer SSL.
        if "+" not in db_url.drivername:
            db_url = db_url.set(drivername="postgresql+pg8000")
        # Support setting the sslmode for pg8000.
        if "pg8000" in db_url.drivername and "sslmode" in db_url.query:
            query = dict(db_url.query)
            if query.pop("sslmode") != "disable":
                connect_args["ssl_context"] = True
            db_url = db_url.set(query=query)
    elif db_backend == "sqlite":
        # Optimize SQLite performance.
        pragmas = {"journal_mode": "WAL", "synchronous": "NORMAL"}
        db_url = db_url.update_query_dict(pragmas, append=True)
    else:
        error_message = "RAGLite only supports PostgreSQL and SQLite."
        raise ValueError(error_message)
    # Create the engine.
    engine = create_engine(db_url, pool_pre_ping=True, connect_args=connect_args)
    # Install database extensions.
    if db_backend == "postgresql":
        with Session(engine) as session:
            session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
            session.commit()
    # If the user has configured a llama-cpp-python model, we ensure that LiteLLM's model info is up
    # to date by loading that LLM.
    if config.embedder.startswith("llama-cpp-python"):
        _ = LlamaCppPythonLLM.llm(config.embedder, embedding=True)
    llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
    model_info = get_model_info(config.embedder, custom_llm_provider=llm_provider)
    embedding_dim = model_info.get("output_vector_size") or -1
    assert embedding_dim > 0
    # Create all SQLModel tables.
    ChunkEmbedding.set_embedding_dim(embedding_dim)
    SQLModel.metadata.create_all(engine)
    # Create backend-specific indexes.
    if db_backend == "postgresql":
        # Create a keyword search index with `tsvector` and a vector search index with `pgvector`.
        with Session(engine) as session:
            metrics = {"cosine": "cosine", "dot": "ip", "euclidean": "l2", "l1": "l1", "l2": "l2"}
            session.execute(
                text("""
                CREATE INDEX IF NOT EXISTS keyword_search_chunk_index ON chunk USING GIN (to_tsvector('simple', body));
                """)
            )
            session.execute(
                text(f"""
                CREATE INDEX IF NOT EXISTS vector_search_chunk_index ON chunk_embedding
                USING hnsw (
                    (embedding::halfvec({embedding_dim}))
                    halfvec_{metrics[config.vector_search_index_metric]}_ops
                );
                """)
            )
            session.commit()
    elif db_backend == "sqlite":
        # Create a virtual table for keyword search on the chunk table.
        # We use the chunk table as an external content table [1] to avoid duplicating the data.
        # [1] https://www.sqlite.org/fts5.html#external_content_tables
        with Session(engine) as session:
            session.execute(
                text("""
                CREATE VIRTUAL TABLE IF NOT EXISTS keyword_search_chunk_index USING fts5(body, content='chunk', content_rowid='rowid');
                """)
            )
            session.execute(
                text("""
                CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_insert AFTER INSERT ON chunk BEGIN
                    INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body);
                END;
                """)
            )
            session.execute(
                text("""
                CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_delete AFTER DELETE ON chunk BEGIN
                    INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
                END;
                """)
            )
            session.execute(
                text("""
                CREATE TRIGGER IF NOT EXISTS keyword_search_chunk_index_auto_update AFTER UPDATE ON chunk BEGIN
                    INSERT INTO keyword_search_chunk_index(keyword_search_chunk_index, rowid, body) VALUES('delete', old.rowid, old.body);
                    INSERT INTO keyword_search_chunk_index(rowid, body) VALUES (new.rowid, new.body);
                END;
                """)
            )
            session.commit()
    return engine
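A minimal usage sketch of this engine factory, assuming only the `db_url` field of `RAGLiteConfig` that the function reads; both example database URLs are illustrative, not part of this commit:

from raglite._config import RAGLiteConfig
from raglite._database import create_database_engine

# SQLite: the WAL journal_mode and NORMAL synchronous pragmas are appended automatically.
engine = create_database_engine(RAGLiteConfig(db_url="sqlite:///raglite.sqlite"))
# PostgreSQL: the pg8000 driver is selected when none is specified, and sslmode is honoured.
engine = create_database_engine(RAGLiteConfig(db_url="postgresql://user:pass@localhost/raglite?sslmode=require"))
# Thanks to @lru_cache(maxsize=1), repeated calls with an equal config reuse the same engine.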
src/raglite/_embed.py
ADDED
@@ -0,0 +1,203 @@
"""String embedder."""

from functools import partial
from typing import Literal

import numpy as np
from litellm import embedding
from llama_cpp import LLAMA_POOLING_TYPE_NONE, Llama
from tqdm.auto import tqdm, trange

from raglite._config import RAGLiteConfig
from raglite._litellm import LlamaCppPythonLLM
from raglite._typing import FloatMatrix, IntVector


def _embed_sentences_with_late_chunking(  # noqa: PLR0915
    sentences: list[str], *, config: RAGLiteConfig | None = None
) -> FloatMatrix:
    """Embed a document's sentences with late chunking."""

    def _count_tokens(
        sentences: list[str], embedder: Llama, sentinel_char: str, sentinel_tokens: list[int]
    ) -> list[int]:
        # Join the sentences with the sentinel token and tokenise the result.
        sentences_tokens = np.asarray(
            embedder.tokenize(sentinel_char.join(sentences).encode(), add_bos=False), dtype=np.intp
        )
        # Map all sentinel token variants to the first one.
        for sentinel_token in sentinel_tokens[1:]:
            sentences_tokens[sentences_tokens == sentinel_token] = sentinel_tokens[0]
        # Count how many tokens there are in between sentinel tokens to recover the token counts.
        sentinel_indices = np.where(sentences_tokens == sentinel_tokens[0])[0]
        num_tokens = np.diff(sentinel_indices, prepend=0, append=len(sentences_tokens))
        assert len(num_tokens) == len(sentences), f"Sentinel `{sentinel_char}` appears in document"
        num_tokens_list: list[int] = num_tokens.tolist()
        return num_tokens_list

    def _create_segment(
        content_start_index: int,
        max_tokens_preamble: int,
        max_tokens_content: int,
        num_tokens: IntVector,
    ) -> tuple[int, int]:
        # Compute the segment sentence start index so that the segment preamble has no more than
        # max_tokens_preamble tokens between [segment_start_index, content_start_index).
        cumsum_backwards = np.cumsum(num_tokens[:content_start_index][::-1])
        offset_preamble = np.searchsorted(cumsum_backwards, max_tokens_preamble, side="right")
        segment_start_index = content_start_index - int(offset_preamble)
        # Allow a larger segment content if we didn't use all of the allowed preamble tokens.
        max_tokens_content = max_tokens_content + (
            max_tokens_preamble - np.sum(num_tokens[segment_start_index:content_start_index])
        )
        # Compute the segment sentence end index so that the segment content has no more than
        # max_tokens_content tokens between [content_start_index, segment_end_index).
        cumsum_forwards = np.cumsum(num_tokens[content_start_index:])
        offset_segment = np.searchsorted(cumsum_forwards, max_tokens_content, side="right")
        segment_end_index = content_start_index + int(offset_segment)
        return segment_start_index, segment_end_index

    # Assert that we're using a llama-cpp-python model, since API-based embedding models don't
    # support outputting token-level embeddings.
    config = config or RAGLiteConfig()
    assert config.embedder.startswith("llama-cpp-python")
    embedder = LlamaCppPythonLLM.llm(
        config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
    )
    n_ctx = embedder.n_ctx()
    n_batch = embedder.n_batch
    # Identify the tokens corresponding to a sentinel character.
    sentinel_char = "⊕"
    sentinel_test = f"A{sentinel_char}B {sentinel_char} C.\n{sentinel_char}D"
    sentinel_tokens = [
        token
        for token in embedder.tokenize(sentinel_test.encode(), add_bos=False)
        if sentinel_char in embedder.detokenize([token]).decode()
    ]
    assert len(sentinel_tokens), f"Sentinel `{sentinel_char}` not supported by embedder"
    # Compute the number of tokens per sentence. We use a method based on a sentinel token to
    # minimise the number of calls to embedder.tokenize, which incurs a significant overhead
    # (presumably to load the tokenizer) [1].
    # TODO: Make token counting faster and more robust once [1] is fixed.
    # [1] https://github.com/abetlen/llama-cpp-python/issues/1763
    num_tokens_list: list[int] = []
    sentence_batch, sentence_batch_len = [], 0
    for i, sentence in enumerate(sentences):
        sentence_batch.append(sentence)
        sentence_batch_len += len(sentence)
        if i == len(sentences) - 1 or sentence_batch_len > (n_ctx // 2):
            num_tokens_list.extend(
                _count_tokens(sentence_batch, embedder, sentinel_char, sentinel_tokens)
            )
            sentence_batch, sentence_batch_len = [], 0
    num_tokens = np.asarray(num_tokens_list, dtype=np.intp)
    # Compute the maximum number of tokens for each segment's preamble and content.
    # Unfortunately, llama-cpp-python truncates the input to n_batch tokens and crashes if you try
    # to increase it [1]. Until this is fixed, we have to limit max_tokens to n_batch.
    # TODO: Improve the context window size once [1] is fixed.
    # [1] https://github.com/abetlen/llama-cpp-python/issues/1762
    max_tokens = min(n_ctx, n_batch) - 16
    max_tokens_preamble = round(0.382 * max_tokens)  # Golden ratio.
    max_tokens_content = max_tokens - max_tokens_preamble
    # Compute a list of segments, each consisting of a preamble and content.
    segments = []
    content_start_index = 0
    while content_start_index < len(sentences):
        segment_start_index, segment_end_index = _create_segment(
            content_start_index, max_tokens_preamble, max_tokens_content, num_tokens
        )
        segments.append((segment_start_index, content_start_index, segment_end_index))
        content_start_index = segment_end_index
    # Embed the segments and apply late chunking.
    sentence_embeddings_list: list[FloatMatrix] = []
    if len(segments) > 1 or segments[0][2] > 128:  # noqa: PLR2004
        segments = tqdm(segments, desc="Embedding", unit="segment", dynamic_ncols=True)
    for segment in segments:
        # Get the token embeddings of the entire segment, including preamble and content.
        segment_start_index, content_start_index, segment_end_index = segment
        segment_sentences = sentences[segment_start_index:segment_end_index]
        segment_embedding = np.asarray(embedder.embed("".join(segment_sentences)))
        # Split the segment embeddings into embedding matrices per sentence.
        segment_tokens = num_tokens[segment_start_index:segment_end_index]
        sentence_size = np.round(
            len(segment_embedding) * (segment_tokens / np.sum(segment_tokens))
        ).astype(np.intp)
        sentence_matrices = np.split(segment_embedding, np.cumsum(sentence_size)[:-1])
        # Compute the segment sentence embeddings by averaging the token embeddings.
        content_sentence_embeddings = [
            np.mean(sentence_matrix, axis=0, keepdims=True)
            for sentence_matrix in sentence_matrices[content_start_index - segment_start_index :]
        ]
        sentence_embeddings_list.append(np.vstack(content_sentence_embeddings))
    sentence_embeddings = np.vstack(sentence_embeddings_list)
    # Normalise the sentence embeddings to unit norm and cast to half precision.
    if config.embedder_normalize:
        sentence_embeddings /= np.linalg.norm(sentence_embeddings, axis=1, keepdims=True)
        sentence_embeddings = sentence_embeddings.astype(np.float16)
    return sentence_embeddings


def _embed_sentences_with_windowing(
    sentences: list[str], *, config: RAGLiteConfig | None = None
) -> FloatMatrix:
    """Embed a document's sentences with windowing."""

    def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> FloatMatrix:
        # Embed the batch of strings.
        if config.embedder.startswith("llama-cpp-python"):
            # LiteLLM doesn't yet support registering a custom embedder, so we handle it here.
            # Additionally, we explicitly manually pool the token embeddings to obtain sentence
            # embeddings because token embeddings are universally supported, while sequence
            # embeddings are only supported by some models.
            embedder = LlamaCppPythonLLM.llm(
                config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
            )
            embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)])
        else:
            # Use LiteLLM's API to embed the batch of strings.
            response = embedding(config.embedder, string_batch)
            embeddings = np.asarray([item["embedding"] for item in response["data"]])
        # Normalise the embeddings to unit norm and cast to half precision.
        if config.embedder_normalize:
            embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings.astype(np.float16)
        return embeddings

    # Window the sentences with a lookback of `config.embedder_sentence_window_size - 1` sentences.
    config = config or RAGLiteConfig()
    sentence_windows = [
        "".join(sentences[max(0, i - (config.embedder_sentence_window_size - 1)) : i + 1])
        for i in range(len(sentences))
    ]
    # Embed the sentence windows in batches.
    batch_size = 64
    batch_range = (
        partial(trange, desc="Embedding", unit="batch", dynamic_ncols=True)
        if len(sentence_windows) > batch_size
        else range
    )
    batch_embeddings = [
        _embed_string_batch(sentence_windows[i : i + batch_size], config=config)
        for i in batch_range(0, len(sentence_windows), batch_size)  # type: ignore[operator]
    ]
    sentence_embeddings = np.vstack(batch_embeddings)
    return sentence_embeddings


def sentence_embedding_type(
    *,
    config: RAGLiteConfig | None = None,
) -> Literal["late_chunking", "windowing"]:
    """Return the type of sentence embeddings."""
    config = config or RAGLiteConfig()
    return "late_chunking" if config.embedder.startswith("llama-cpp-python") else "windowing"


def embed_sentences(sentences: list[str], *, config: RAGLiteConfig | None = None) -> FloatMatrix:
    """Embed the sentences of a document as a NumPy matrix with one row per sentence."""
    config = config or RAGLiteConfig()
    if sentence_embedding_type(config=config) == "late_chunking":
        sentence_embeddings = _embed_sentences_with_late_chunking(sentences, config=config)
    else:
        sentence_embeddings = _embed_sentences_with_windowing(sentences, config=config)
    return sentence_embeddings
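A short usage sketch, assuming the default embedder configured in `RAGLiteConfig`; the sentences are illustrative:

from raglite._config import RAGLiteConfig
from raglite._embed import embed_sentences, sentence_embedding_type

config = RAGLiteConfig()
sentences = ["RAGLite splits documents into sentences. ", "Each sentence gets an embedding. "]
# llama-cpp-python embedders take the late chunking path; API embedders are windowed.
print(sentence_embedding_type(config=config))
embeddings = embed_sentences(sentences, config=config)
print(embeddings.shape)  # One row per sentence.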
src/raglite/_eval.py
ADDED
@@ -0,0 +1,257 @@
"""Generation and evaluation of evals."""

from random import randint
from typing import ClassVar

import numpy as np
import pandas as pd
from pydantic import BaseModel, Field, field_validator
from sqlmodel import Session, func, select
from tqdm.auto import tqdm, trange

from raglite._config import RAGLiteConfig
from raglite._database import Chunk, Document, Eval, create_database_engine
from raglite._extract import extract_with_llm
from raglite._rag import rag
from raglite._search import hybrid_search, retrieve_segments, vector_search
from raglite._typing import SearchMethod


def insert_evals(  # noqa: C901
    *, num_evals: int = 100, max_contexts_per_eval: int = 20, config: RAGLiteConfig | None = None
) -> None:
    """Generate and insert evals into the database."""

    class QuestionResponse(BaseModel):
        """A specific question about the content of a set of document contexts."""

        question: str = Field(
            ...,
            description="A specific question about the content of a set of document contexts.",
            min_length=1,
        )
        system_prompt: ClassVar[str] = """
You are given a set of contexts extracted from a document.
You are a subject matter expert on the document's topic.
Your task is to generate a question to quiz other subject matter experts on the information in the provided context.
The question MUST satisfy ALL of the following criteria:
- The question SHOULD integrate as much of the provided context as possible.
- The question MUST NOT be a general or open question, but MUST instead be as specific to the provided context as possible.
- The question MUST be completely answerable using ONLY the information in the provided context, without depending on any background information.
- The question MUST be entirely self-contained and able to be understood in full WITHOUT access to the provided context.
- The question MUST NOT reference the existence of the context, directly or indirectly.
- The question MUST treat the context as if its contents are entirely part of your working memory.
            """.strip()

        @field_validator("question")
        @classmethod
        def validate_question(cls, value: str) -> str:
            """Validate the question."""
            question = value.strip().lower()
            if "context" in question or "document" in question or "question" in question:
                raise ValueError
            if not question.endswith("?"):
                raise ValueError
            return value

    config = config or RAGLiteConfig()
    engine = create_database_engine(config)
    with Session(engine) as session:
        for _ in trange(num_evals, desc="Generating evals", unit="eval", dynamic_ncols=True):
            # Sample a random document from the database.
            seed_document = session.exec(select(Document).order_by(func.random()).limit(1)).first()
            if seed_document is None:
                error_message = "First run `insert_document()` before generating evals."
                raise ValueError(error_message)
            # Sample a random chunk from that document.
            seed_chunk = session.exec(
                select(Chunk)
                .where(Chunk.document_id == seed_document.id)
                .order_by(func.random())
                .limit(1)
            ).first()
            if seed_chunk is None:
                continue
            # Expand the seed chunk into a set of related chunks.
            related_chunk_ids, _ = vector_search(
                np.mean(seed_chunk.embedding_matrix, axis=0, keepdims=True),
                num_results=randint(2, max_contexts_per_eval // 2),  # noqa: S311
                config=config,
            )
            related_chunks = retrieve_segments(related_chunk_ids, config=config)
            # Extract a question from the seed chunk's related chunks.
            try:
                question_response = extract_with_llm(
                    QuestionResponse, related_chunks, config=config
                )
            except ValueError:
                continue
            else:
                question = question_response.question
            # Search for candidate chunks to answer the generated question.
            candidate_chunk_ids, _ = hybrid_search(
                question, num_results=max_contexts_per_eval, config=config
            )
            candidate_chunks = [session.get(Chunk, chunk_id) for chunk_id in candidate_chunk_ids]

            # Determine which candidate chunks are relevant to answer the generated question.
            class ContextEvalResponse(BaseModel):
                """Indicate whether the provided context can be used to answer a given question."""

                hit: bool = Field(
                    ...,
                    description="True if the provided context contains (a part of) the answer to the given question, false otherwise.",
                )
                system_prompt: ClassVar[str] = f"""
You are given a context extracted from a document.
You are a subject matter expert on the document's topic.
Your task is to answer whether the provided context contains (a part of) the answer to this question: "{question}"
An example of a context that does NOT contain (a part of) the answer is a table of contents.
                """.strip()

            relevant_chunks = []
            for candidate_chunk in tqdm(
                candidate_chunks, desc="Evaluating chunks", unit="chunk", dynamic_ncols=True
            ):
                try:
                    context_eval_response = extract_with_llm(
                        ContextEvalResponse, str(candidate_chunk), config=config
                    )
                except ValueError:  # noqa: PERF203
                    pass
                else:
                    if context_eval_response.hit:
                        relevant_chunks.append(candidate_chunk)
            if not relevant_chunks:
                continue

            # Answer the question using the relevant chunks.
            class AnswerResponse(BaseModel):
                """Answer a question using the provided context."""

                answer: str = Field(
                    ...,
                    description="A complete answer to the given question using the provided context.",
                    min_length=1,
                )
                system_prompt: ClassVar[str] = f"""
You are given a set of contexts extracted from a document.
You are a subject matter expert on the document's topic.
Your task is to generate a complete answer to the following question using the provided context: "{question}"
The answer MUST satisfy ALL of the following criteria:
- The answer MUST integrate as much of the provided context as possible.
- The answer MUST be entirely self-contained and able to be understood in full WITHOUT access to the provided context.
- The answer MUST NOT reference the existence of the context, directly or indirectly.
- The answer MUST treat the context as if its contents are entirely part of your working memory.
                """.strip()

            try:
                answer_response = extract_with_llm(
                    AnswerResponse,
                    [str(relevant_chunk) for relevant_chunk in relevant_chunks],
                    config=config,
                )
            except ValueError:
                continue
            else:
                answer = answer_response.answer
            # Store the eval in the database.
            eval_ = Eval.from_chunks(
                question=question,
                contexts=relevant_chunks,
                ground_truth=answer,
            )
            session.add(eval_)
            session.commit()


def answer_evals(
    num_evals: int = 100,
    search: SearchMethod = hybrid_search,
    *,
    config: RAGLiteConfig | None = None,
) -> pd.DataFrame:
    """Read evals from the database and answer them with RAG."""
    # Read evals from the database.
    config = config or RAGLiteConfig()
    engine = create_database_engine(config)
    with Session(engine) as session:
        evals = session.exec(select(Eval).limit(num_evals)).all()
    # Answer evals with RAG.
    answers: list[str] = []
    contexts: list[list[str]] = []
    for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
        response = rag(eval_.question, search=search, config=config)
        answer = "".join(response)
        answers.append(answer)
        chunk_ids, _ = search(eval_.question, config=config)
        contexts.append(retrieve_segments(chunk_ids))
    # Collect the answered evals.
    answered_evals: dict[str, list[str] | list[list[str]]] = {
        "question": [eval_.question for eval_ in evals],
        "answer": answers,
        "contexts": contexts,
        "ground_truth": [eval_.ground_truth for eval_ in evals],
        "ground_truth_contexts": [eval_.contexts for eval_ in evals],
    }
    answered_evals_df = pd.DataFrame.from_dict(answered_evals)
    return answered_evals_df


def evaluate(
    answered_evals: pd.DataFrame | int = 100,
    config: RAGLiteConfig | None = None,
) -> pd.DataFrame:
    """Evaluate the performance of a set of answered evals with Ragas."""
    try:
        from datasets import Dataset
        from langchain_community.chat_models import ChatLiteLLM
        from langchain_community.embeddings import LlamaCppEmbeddings
        from langchain_community.llms import LlamaCpp
        from ragas import RunConfig
        from ragas import evaluate as ragas_evaluate

        from raglite._litellm import LlamaCppPythonLLM
    except ImportError as import_error:
        error_message = "To use the `evaluate` function, please install the `ragas` extra."
        raise ImportError(error_message) from import_error

    # Create a set of answered evals if not provided.
    config = config or RAGLiteConfig()
    answered_evals_df = (
        answered_evals
        if isinstance(answered_evals, pd.DataFrame)
        else answer_evals(num_evals=answered_evals, config=config)
    )
    # Load the LLM.
    if config.llm.startswith("llama-cpp-python"):
        llm = LlamaCppPythonLLM().llm(model=config.llm)
        lc_llm = LlamaCpp(
            model_path=llm.model_path,
            n_batch=llm.n_batch,
            n_ctx=llm.n_ctx(),
            n_gpu_layers=-1,
            verbose=llm.verbose,
        )
    else:
        lc_llm = ChatLiteLLM(model=config.llm)  # type: ignore[call-arg]
    # Load the embedder.
    if not config.embedder.startswith("llama-cpp-python"):
        error_message = "Currently, only `llama-cpp-python` embedders are supported."
        raise NotImplementedError(error_message)
    embedder = LlamaCppPythonLLM().llm(model=config.embedder, embedding=True)
    lc_embedder = LlamaCppEmbeddings(  # type: ignore[call-arg]
        model_path=embedder.model_path,
        n_batch=embedder.n_batch,
        n_ctx=embedder.n_ctx(),
        n_gpu_layers=-1,
        verbose=embedder.verbose,
    )
    # Evaluate the answered evals with Ragas.
    evaluation_df = ragas_evaluate(
        dataset=Dataset.from_pandas(answered_evals_df),
        llm=lc_llm,
        embeddings=lc_embedder,
        run_config=RunConfig(max_workers=1),
    ).to_pandas()
    return evaluation_df
src/raglite/_extract.py
ADDED
@@ -0,0 +1,69 @@
"""Extract structured data from unstructured text with an LLM."""

from typing import Any, TypeVar

from litellm import completion
from pydantic import BaseModel, ValidationError

from raglite._config import RAGLiteConfig

T = TypeVar("T", bound=BaseModel)


def extract_with_llm(
    return_type: type[T],
    user_prompt: str | list[str],
    config: RAGLiteConfig | None = None,
    **kwargs: Any,
) -> T:
    """Extract structured data from unstructured text with an LLM.

    This function expects a `return_type.system_prompt: ClassVar[str]` that contains the system
    prompt to use. Example:

        from typing import ClassVar
        from pydantic import BaseModel, Field

        class MyNameResponse(BaseModel):
            my_name: str = Field(..., description="The user's name.")
            system_prompt: ClassVar[str] = "The system prompt to use (excluded from JSON schema)."

        my_name_response = extract_with_llm(MyNameResponse, "My name is Thomas A. Anderson.")
    """
    # Load the default config if not provided.
    config = config or RAGLiteConfig()
    # Update the system prompt with the JSON schema of the return type to help the LLM.
    system_prompt = (
        return_type.system_prompt.strip() + "\n"  # type: ignore[attr-defined]
        + "Format your response according to this JSON schema:\n"
        + str(return_type.model_json_schema())
    )
    # Concatenate the user prompt if it is a list of strings.
    if isinstance(user_prompt, list):
        user_prompt = "\n\n".join(
            f'<context index="{i}">\n{chunk.strip()}\n</context>'
            for i, chunk in enumerate(user_prompt)
        )
    # Extract structured data from the unstructured input.
    for _ in range(config.llm_max_tries):
        response = completion(
            model=config.llm,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            response_format={"type": "json_object", "schema": return_type.model_json_schema()},
            **kwargs,
        )
        try:
            instance = return_type.model_validate_json(response["choices"][0]["message"]["content"])
        except (KeyError, ValueError, ValidationError) as e:
            # Malformed response, not a JSON string, or not a valid instance of the return type.
            last_exception = e
            continue
        else:
            break
    else:
        error_message = f"Failed to extract {return_type} from input {user_prompt}."
        raise ValueError(error_message) from last_exception
    return instance
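A brief sketch of the list-of-contexts path, which the docstring example doesn't show; the `TopicResponse` model is hypothetical, made up for illustration:

from typing import ClassVar

from pydantic import BaseModel, Field

from raglite._extract import extract_with_llm

class TopicResponse(BaseModel):  # Hypothetical illustration model.
    topic: str = Field(..., description="The shared topic of the provided contexts.")
    system_prompt: ClassVar[str] = "Identify the shared topic of the provided contexts."

# Each list element is wrapped in a numbered <context index="i">...</context> block.
topic_response = extract_with_llm(TopicResponse, ["Chunk one.", "Chunk two."])
print(topic_response.topic)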
src/raglite/_flashrank.py
ADDED
@@ -0,0 +1,41 @@
"""Patched version of FlashRankRanker that fixes incorrect reranking [1].

[1] https://github.com/AnswerDotAI/rerankers/issues/39
"""

import contextlib
from io import StringIO
from typing import Any

from flashrank import RerankRequest

# Suppress rerankers output on import until [1] is fixed.
# [1] https://github.com/AnswerDotAI/rerankers/issues/36
with contextlib.redirect_stdout(StringIO()):
    from rerankers.documents import Document
    from rerankers.models.flashrank_ranker import FlashRankRanker
    from rerankers.results import RankedResults, Result
    from rerankers.utils import prep_docs


class PatchedFlashRankRanker(FlashRankRanker):
    def rank(
        self,
        query: str,
        docs: str | list[str] | Document | list[Document],
        doc_ids: list[str] | list[int] | None = None,
        metadata: list[dict[str, Any]] | None = None,
    ) -> RankedResults:
        docs = prep_docs(docs, doc_ids, metadata)
        passages = [{"id": doc_idx, "text": doc.text} for doc_idx, doc in enumerate(docs)]
        rerank_request = RerankRequest(query=query, passages=passages)
        flashrank_results = self.model.rerank(rerank_request)
        ranked_results = [
            Result(
                document=docs[result["id"]],  # This patches the incorrect ranking in the original.
                score=result["score"],
                rank=idx + 1,
            )
            for idx, result in enumerate(flashrank_results)
        ]
        return RankedResults(results=ranked_results, query=query, has_scores=True)
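A usage sketch of the patched ranker; the assumption that `FlashRankRanker` accepts a FlashRank model name as its first argument, and the model name itself, are illustrative rather than confirmed by this commit:

from raglite._flashrank import PatchedFlashRankRanker

ranker = PatchedFlashRankRanker("ms-marco-MiniLM-L-12-v2")  # Model name is illustrative.
ranked = ranker.rank(
    query="What does late chunking do?",
    docs=["Late chunking pools token embeddings per sentence.", "An unrelated paragraph."],
)
# The patch ensures each Result carries the document its score was computed for.
for result in ranked.results:
    print(result.rank, result.score)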
src/raglite/_insert.py
ADDED
@@ -0,0 +1,160 @@
"""Index documents."""

from pathlib import Path

import numpy as np
from sqlalchemy.engine import make_url
from sqlmodel import Session, select
from tqdm.auto import tqdm

from raglite._config import RAGLiteConfig
from raglite._database import Chunk, ChunkEmbedding, Document, IndexMetadata, create_database_engine
from raglite._embed import embed_sentences, sentence_embedding_type
from raglite._markdown import document_to_markdown
from raglite._split_chunks import split_chunks
from raglite._split_sentences import split_sentences
from raglite._typing import FloatMatrix


def _create_chunk_records(
    document_id: str,
    chunks: list[str],
    chunk_embeddings: list[FloatMatrix],
    config: RAGLiteConfig,
) -> tuple[list[Chunk], list[list[ChunkEmbedding]]]:
    """Process chunks into chunk and chunk embedding records."""
    # Create the chunk records.
    chunk_records, headings = [], ""
    for i, chunk in enumerate(chunks):
        # Create and append the chunk record.
        record = Chunk.from_body(document_id=document_id, index=i, body=chunk, headings=headings)
        chunk_records.append(record)
        # Update the Markdown headings with those of this chunk.
        headings = record.extract_headings()
    # Create the chunk embedding records.
    chunk_embedding_records = []
    if sentence_embedding_type(config=config) == "late_chunking":
        # Every chunk record is associated with a list of chunk embedding records, one for each of
        # the sentences in the chunk.
        for chunk_record, chunk_embedding in zip(chunk_records, chunk_embeddings, strict=True):
            chunk_embedding_records.append(
                [
                    ChunkEmbedding(chunk_id=chunk_record.id, embedding=sentence_embedding)
                    for sentence_embedding in chunk_embedding
                ]
            )
    else:
        # Embed the full chunks, including the current Markdown headings.
        full_chunk_embeddings = embed_sentences([str(chunk) for chunk in chunks], config=config)
        # Every chunk record is associated with a list of chunk embedding records. The chunk
        # embedding records each correspond to a linear combination of a sentence embedding and an
        # embedding of the full chunk with Markdown headings.
        α = 0.382  # Golden ratio.  # noqa: PLC2401
        for chunk_record, chunk_embedding, full_chunk_embedding in zip(
            chunk_records, chunk_embeddings, full_chunk_embeddings, strict=True
        ):
            chunk_embedding_records.append(
                [
                    ChunkEmbedding(
                        chunk_id=chunk_record.id,
                        embedding=α * sentence_embedding + (1 - α) * full_chunk_embedding,
                    )
                    for sentence_embedding in chunk_embedding
                ]
            )
    return chunk_records, chunk_embedding_records


def insert_document(doc_path: Path, *, config: RAGLiteConfig | None = None) -> None:  # noqa: PLR0915
    """Insert a document into the database and update the index."""
    # Use the default config if not provided.
    config = config or RAGLiteConfig()
    db_backend = make_url(config.db_url).get_backend_name()
    # Preprocess the document into chunks and chunk embeddings.
    with tqdm(total=5, unit="step", dynamic_ncols=True) as pbar:
        pbar.set_description("Initializing database")
        engine = create_database_engine(config)
        pbar.update(1)
        pbar.set_description("Converting to Markdown")
        doc = document_to_markdown(doc_path)
        pbar.update(1)
        pbar.set_description("Splitting sentences")
        sentences = split_sentences(doc, max_len=config.chunk_max_size)
        pbar.update(1)
        pbar.set_description("Embedding sentences")
        sentence_embeddings = embed_sentences(sentences, config=config)
        pbar.update(1)
        pbar.set_description("Splitting chunks")
        chunks, chunk_embeddings = split_chunks(
            sentences=sentences,
            sentence_embeddings=sentence_embeddings,
            sentence_window_size=config.embedder_sentence_window_size,
            max_size=config.chunk_max_size,
        )
        pbar.update(1)
    # Create and store the chunk records.
    with Session(engine) as session:
        # Add the document to the document table.
        document_record = Document.from_path(doc_path)
        if session.get(Document, document_record.id) is None:
            session.add(document_record)
            session.commit()
        # Create the chunk records to insert into the chunk table.
        chunk_records, chunk_embedding_records = _create_chunk_records(
            document_record.id, chunks, chunk_embeddings, config
        )
        # Store the chunk and chunk embedding records.
        for chunk_record, chunk_embedding_record_list in tqdm(
            zip(chunk_records, chunk_embedding_records, strict=True),
            desc="Inserting chunks",
            total=len(chunk_records),
            unit="chunk",
            dynamic_ncols=True,
        ):
            if session.get(Chunk, chunk_record.id) is not None:
                continue
            session.add(chunk_record)
            session.add_all(chunk_embedding_record_list)
            session.commit()
    # Manually update the vector search chunk index for SQLite.
    if db_backend == "sqlite":
        from pynndescent import NNDescent

        with Session(engine) as session:
            # Get the vector search chunk index from the database, or create a new one.
            index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
            chunk_ids = index_metadata.metadata_.get("chunk_ids", [])
            chunk_sizes = index_metadata.metadata_.get("chunk_sizes", [])
            # Get the unindexed chunks.
            unindexed_chunks = list(session.exec(select(Chunk).offset(len(chunk_ids))).all())
            if not unindexed_chunks:
                return
            # Assemble the unindexed chunk embeddings into a NumPy array.
            unindexed_chunk_embeddings = [chunk.embedding_matrix for chunk in unindexed_chunks]
            X = np.vstack(unindexed_chunk_embeddings)  # noqa: N806
            # Index the unindexed chunks.
            with tqdm(
                total=len(unindexed_chunks),
                desc="Indexing chunks",
                unit="chunk",
                dynamic_ncols=True,
            ) as pbar:
                # Fit or update the ANN index.
                if len(chunk_ids) == 0:
                    nndescent = NNDescent(X, metric=config.vector_search_index_metric)
                else:
                    nndescent = index_metadata.metadata_["index"]
                    nndescent.update(X)
                # Prepare the ANN index so it can handle query vectors not in the training set.
                nndescent.prepare()
                # Update the index metadata and mark it as dirty by recreating the dictionary.
                index_metadata.metadata_ = {
                    **index_metadata.metadata_,
                    "index": nndescent,
                    "chunk_ids": chunk_ids + [c.id for c in unindexed_chunks],
                    "chunk_sizes": chunk_sizes + [len(em) for em in unindexed_chunk_embeddings],
                }
                # Store the updated vector search chunk index.
                session.add(index_metadata)
                session.commit()
                pbar.update(len(unindexed_chunks))
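An end-to-end ingestion sketch, assuming a PDF on disk; the path is illustrative:

from pathlib import Path

from raglite._config import RAGLiteConfig
from raglite._insert import insert_document

# Converts to Markdown, splits into sentences and chunks, embeds, and updates the index
# (pgvector on PostgreSQL; a PyNNDescent index stored in IndexMetadata on SQLite).
insert_document(Path("paper.pdf"), config=RAGLiteConfig())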
src/raglite/_litellm.py
ADDED
@@ -0,0 +1,261 @@
"""Add support for llama-cpp-python models to LiteLLM."""

import asyncio
import logging
import warnings
from collections.abc import AsyncIterator, Callable, Iterator
from functools import cache
from typing import Any, ClassVar, cast

import httpx
import litellm
from litellm import (  # type: ignore[attr-defined]
    CustomLLM,
    GenericStreamingChunk,
    ModelResponse,
    convert_to_model_response_object,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from llama_cpp import (  # type: ignore[attr-defined]
    ChatCompletionRequestMessage,
    CreateChatCompletionResponse,
    CreateChatCompletionStreamResponse,
    Llama,
    LlamaRAMCache,
)

# Reduce the logging level for LiteLLM and flashrank.
logging.getLogger("litellm").setLevel(logging.WARNING)
logging.getLogger("flashrank").setLevel(logging.WARNING)


class LlamaCppPythonLLM(CustomLLM):
    """A llama-cpp-python provider for LiteLLM.

    This provider enables using llama-cpp-python models with LiteLLM. The LiteLLM model
    specification is "llama-cpp-python/<hugging_face_repo_id>/<filename>@<n_ctx>", where n_ctx is
    an optional parameter that specifies the context size of the model. If n_ctx is not provided or
    if it's set to 0, the model's default context size is used.

    Example usage:

    ```python
    from litellm import completion

    response = completion(
        model="llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4092",
        messages=[{"role": "user", "content": "Hello world!"}],
        # stream=True
    )
    ```
    """

    # Create a lock to prevent concurrent access to llama-cpp-python models.
    streaming_lock: ClassVar[asyncio.Lock] = asyncio.Lock()

    # The set of supported OpenAI parameters is the intersection of [1] and [2]. Not included:
    # max_completion_tokens, stream_options, n, user, logprobs, top_logprobs, extra_headers.
    # [1] https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion
    # [2] https://docs.litellm.ai/docs/completion/input
    supported_openai_params: ClassVar[list[str]] = [
        "functions",  # Deprecated
        "function_call",  # Deprecated
        "tools",
        "tool_choice",
        "temperature",
        "top_p",
        "top_k",
        "min_p",
        "typical_p",
        "stop",
        "seed",
        "response_format",
        "max_tokens",
        "presence_penalty",
        "frequency_penalty",
        "repeat_penalty",
        "tfs_z",
        "mirostat_mode",
        "mirostat_tau",
        "mirostat_eta",
        "logit_bias",
    ]

    @staticmethod
    @cache
    def llm(model: str, **kwargs: Any) -> Llama:
        # Drop the llama-cpp-python prefix from the model.
        repo_id_filename = model.replace("llama-cpp-python/", "")
        # Convert the LiteLLM model string to repo_id, filename, and n_ctx.
        repo_id, filename = repo_id_filename.rsplit("/", maxsplit=1)
        n_ctx = 0
        if len(filename_n_ctx := filename.rsplit("@", maxsplit=1)) == 2:  # noqa: PLR2004
            filename, n_ctx_str = filename_n_ctx
            n_ctx = int(n_ctx_str)
        # Load the LLM.
        with warnings.catch_warnings():  # Filter huggingface_hub warning about HF_TOKEN.
            warnings.filterwarnings("ignore", category=UserWarning)
            llm = Llama.from_pretrained(
                repo_id=repo_id,
                filename=filename,
                n_ctx=n_ctx,
                n_gpu_layers=-1,
                verbose=False,
                **kwargs,
            )
        # Enable caching.
        llm.set_cache(LlamaRAMCache())
        # Register the model info with LiteLLM.
        litellm.register_model(  # type: ignore[attr-defined]
            {
                model: {
                    "max_tokens": llm.n_ctx(),
                    "max_input_tokens": llm.n_ctx(),
                    "max_output_tokens": None,
                    "input_cost_per_token": 0.0,
                    "output_cost_per_token": 0.0,
                    "output_vector_size": llm.n_embd() if kwargs.get("embedding") else None,
                    "litellm_provider": "llama-cpp-python",
                    "mode": "embedding" if kwargs.get("embedding") else "completion",
                    "supported_openai_params": LlamaCppPythonLLM.supported_openai_params,
                    "supports_function_calling": True,
                    "supports_parallel_function_calling": True,
                    "supports_vision": False,
                }
            }
        )
        return llm

    def completion(  # noqa: PLR0913
        self,
        model: str,
        messages: list[ChatCompletionRequestMessage],
        api_base: str,
        custom_prompt_dict: dict[str, Any],
        model_response: ModelResponse,
        print_verbose: Callable,  # type: ignore[type-arg]
        encoding: str,
        api_key: str,
        logging_obj: Any,
        optional_params: dict[str, Any],
        acompletion: Callable | None = None,  # type: ignore[type-arg]
        litellm_params: dict[str, Any] | None = None,
        logger_fn: Callable | None = None,  # type: ignore[type-arg]
        headers: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
        client: HTTPHandler | None = None,
    ) -> ModelResponse:
        llm = self.llm(model)
        llama_cpp_python_params = {
            k: v for k, v in optional_params.items() if k in self.supported_openai_params
        }
        response = cast(
            CreateChatCompletionResponse,
            llm.create_chat_completion(messages=messages, **llama_cpp_python_params),
        )
        litellm_model_response: ModelResponse = convert_to_model_response_object(
            response_object=response,
            model_response_object=model_response,
            response_type="completion",
            stream=False,
        )
        return litellm_model_response

    def streaming(  # noqa: PLR0913
        self,
        model: str,
        messages: list[ChatCompletionRequestMessage],
        api_base: str,
        custom_prompt_dict: dict[str, Any],
        model_response: ModelResponse,
        print_verbose: Callable,  # type: ignore[type-arg]
        encoding: str,
        api_key: str,
        logging_obj: Any,
        optional_params: dict[str, Any],
        acompletion: Callable | None = None,  # type: ignore[type-arg]
        litellm_params: dict[str, Any] | None = None,
        logger_fn: Callable | None = None,  # type: ignore[type-arg]
        headers: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,
        client: HTTPHandler | None = None,
    ) -> Iterator[GenericStreamingChunk]:
        llm = self.llm(model)
        llama_cpp_python_params = {
            k: v for k, v in optional_params.items() if k in self.supported_openai_params
        }
        stream = cast(
            Iterator[CreateChatCompletionStreamResponse],
            llm.create_chat_completion(messages=messages, **llama_cpp_python_params, stream=True),
        )
        for chunk in stream:
            choices = chunk.get("choices", [])
            for choice in choices:
                text = choice.get("delta", {}).get("content", None)
                finish_reason = choice.get("finish_reason")
                litellm_generic_streaming_chunk = GenericStreamingChunk(
                    text=text,  # type: ignore[typeddict-item]
                    is_finished=bool(finish_reason),
                    finish_reason=finish_reason,  # type: ignore[typeddict-item]
                    usage=None,
                    index=choice.get("index"),  # type: ignore[typeddict-item]
                    provider_specific_fields={
                        "id": chunk.get("id"),
                        "model": chunk.get("model"),
                        "created": chunk.get("created"),
                        "object": chunk.get("object"),
                    },
                )
                yield litellm_generic_streaming_chunk

    async def astreaming(  # type: ignore[misc,override]  # noqa: PLR0913
        self,
        model: str,
        messages: list[ChatCompletionRequestMessage],
        api_base: str,
        custom_prompt_dict: dict[str, Any],
        model_response: ModelResponse,
        print_verbose: Callable,  # type: ignore[type-arg]
        encoding: str,
        api_key: str,
        logging_obj: Any,
        optional_params: dict[str, Any],
        acompletion: Callable | None = None,  # type: ignore[type-arg]
        litellm_params: dict[str, Any] | None = None,
        logger_fn: Callable | None = None,  # type: ignore[type-arg]
        headers: dict[str, Any] | None = None,
        timeout: float | httpx.Timeout | None = None,  # noqa: ASYNC109
        client: AsyncHTTPHandler | None = None,
    ) -> AsyncIterator[GenericStreamingChunk]:
        # Start a synchronous stream.
        stream = self.streaming(
            model,
            messages,
            api_base,
            custom_prompt_dict,
            model_response,
            print_verbose,
            encoding,
            api_key,
            logging_obj,
            optional_params,
            acompletion,
            litellm_params,
            logger_fn,
            headers,
            timeout,
        )
        await asyncio.sleep(0)  # Yield control to the event loop after initialising the context.
        # Wrap the synchronous stream in an asynchronous stream.
        async with LlamaCppPythonLLM.streaming_lock:
            for litellm_generic_streaming_chunk in stream:
                yield litellm_generic_streaming_chunk
                await asyncio.sleep(0)  # Yield control to the event loop after each token.


# Register the LlamaCppPythonLLM provider.
if not any(provider["provider"] == "llama-cpp-python" for provider in litellm.custom_provider_map):
    litellm.custom_provider_map.append(
        {"provider": "llama-cpp-python", "custom_handler": LlamaCppPythonLLM()}
    )
litellm.suppress_debug_info = True
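A small sketch of the model-string convention that `llm()` parses; the repo and filename are illustrative placeholders, not models this commit prescribes:

from raglite._litellm import LlamaCppPythonLLM

# "llama-cpp-python/<repo_id>/<filename>@<n_ctx>"; omitting @n_ctx (or using 0) keeps the
# model's default context size. The llm() call is cached, so repeated lookups are cheap.
llm = LlamaCppPythonLLM.llm("llama-cpp-python/org/model-gguf/*Q8_0.gguf@2048", embedding=True)
print(llm.n_embd())  # Registered with LiteLLM as this model's output_vector_size.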
src/raglite/_markdown.py
ADDED
@@ -0,0 +1,221 @@
"""Convert any document to Markdown."""

import re
from copy import deepcopy
from pathlib import Path
from typing import Any

import mdformat
import numpy as np
from pdftext.extraction import dictionary_output
from sklearn.cluster import KMeans


def parsed_pdf_to_markdown(pages: list[dict[str, Any]]) -> list[str]:  # noqa: C901, PLR0915
    """Convert a PDF parsed with pdftext to Markdown."""

    def add_heading_level_metadata(pages: list[dict[str, Any]]) -> list[dict[str, Any]]:  # noqa: C901
        """Add heading level metadata to a PDF parsed with pdftext."""

        def extract_font_size(span: dict[str, Any]) -> float:
            """Extract the font size from a text span."""
            font_size: float = 1.0
            if span["font"]["size"] > 1:  # A value of 1 appears to mean "unknown" in pdftext.
                font_size = span["font"]["size"]
            elif digit_sequences := re.findall(r"\d+", span["font"]["name"] or ""):
                font_size = float(digit_sequences[-1])
            elif "\n" not in span["text"]:  # Occasionally a span can contain a newline character.
                if round(span["rotation"]) in (0.0, 180.0, -180.0):
                    font_size = span["bbox"][3] - span["bbox"][1]
                elif round(span["rotation"]) in (90.0, -90.0, 270.0, -270.0):
                    font_size = span["bbox"][2] - span["bbox"][0]
            return font_size

        # Copy the pages.
        pages = deepcopy(pages)
        # Extract an array of all font sizes used by the text spans.
        font_sizes = np.asarray(
            [
                extract_font_size(span)
                for page in pages
                for block in page["blocks"]
                for line in block["lines"]
                for span in line["spans"]
            ]
        )
        font_sizes = np.round(font_sizes * 2) / 2
        unique_font_sizes, counts = np.unique(font_sizes, return_counts=True)
        # Determine the paragraph font size as the mode font size.
        tiny = unique_font_sizes < min(5, np.max(unique_font_sizes))
        counts[tiny] = -counts[tiny]
        mode = np.argmax(counts)
        counts[tiny] = -counts[tiny]
        mode_font_size = unique_font_sizes[mode]
        # Determine (at most) 6 heading font sizes by clustering font sizes larger than the mode.
        heading_font_sizes = unique_font_sizes[mode + 1 :]
        if len(heading_font_sizes) > 0:
            heading_counts = counts[mode + 1 :]
            kmeans = KMeans(n_clusters=min(6, len(heading_font_sizes)), random_state=42)
            kmeans.fit(heading_font_sizes[:, np.newaxis], sample_weight=heading_counts)
            heading_font_sizes = np.sort(np.ravel(kmeans.cluster_centers_))[::-1]
        # Add heading level information to the text spans and lines.
        for page in pages:
            for block in page["blocks"]:
                for line in block["lines"]:
                    if "md" not in line:
                        line["md"] = {}
                    heading_level = np.zeros(8)  # 0-5: <h1>-<h6>, 6: <p>, 7: <small>
                    for span in line["spans"]:
                        if "md" not in span:
                            span["md"] = {}
                        span_font_size = extract_font_size(span)
                        if span_font_size < mode_font_size:
                            idx = 7
                        elif span_font_size == mode_font_size:
                            idx = 6
                        else:
                            idx = np.argmin(np.abs(heading_font_sizes - span_font_size))  # type: ignore[assignment]
                        span["md"]["heading_level"] = idx + 1
                        heading_level[idx] += len(span["text"])
                    line["md"]["heading_level"] = np.argmax(heading_level) + 1
        return pages

    def add_emphasis_metadata(pages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Add emphasis metadata such as bold and italic to a PDF parsed with pdftext."""
        # Copy the pages.
        pages = deepcopy(pages)
        # Add emphasis metadata to the text spans.
        for page in pages:
            for block in page["blocks"]:
                for line in block["lines"]:
                    if "md" not in line:
                        line["md"] = {}
                    for span in line["spans"]:
                        if "md" not in span:
                            span["md"] = {}
                        span["md"]["bold"] = span["font"]["weight"] > 500  # noqa: PLR2004
                        span["md"]["italic"] = "ital" in (span["font"]["name"] or "").lower()
                    line["md"]["bold"] = all(
                        span["md"]["bold"] for span in line["spans"] if span["text"].strip()
                    )
                    line["md"]["italic"] = all(
                        span["md"]["italic"] for span in line["spans"] if span["text"].strip()
                    )
        return pages

    def strip_page_numbers(pages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Strip page numbers from a PDF parsed with pdftext."""
        # Copy the pages.
        pages = deepcopy(pages)
        # Remove lines that only contain a page number.
        for page in pages:
            for block in page["blocks"]:
                block["lines"] = [
                    line
                    for line in block["lines"]
                    if not re.match(
                        r"^\s*[#0]*\d+\s*$", "".join(span["text"] for span in line["spans"])
                    )
                ]
        return pages

    def convert_to_markdown(pages: list[dict[str, Any]]) -> list[str]:  # noqa: C901, PLR0912
        """Convert a list of pages to Markdown."""
        pages_md = []
        for page in pages:
            page_md = ""
            for block in page["blocks"]:
                block_text = ""
                for line in block["lines"]:
                    # Build the line text and style the spans.
                    line_text = ""
                    for span in line["spans"]:
                        if (
                            not line["md"]["bold"]
                            and not line["md"]["italic"]
                            and span["md"]["bold"]
                            and span["md"]["italic"]
                        ):
                            line_text += f"***{span['text']}***"
                        elif not line["md"]["bold"] and span["md"]["bold"]:
                            line_text += f"**{span['text']}**"
                        elif not line["md"]["italic"] and span["md"]["italic"]:
                            line_text += f"*{span['text']}*"
                        else:
                            line_text += span["text"]
                    # Add emphasis to the line (if it's not a heading or whitespace).
                    line_text = line_text.rstrip()
                    line_is_whitespace = not line_text.strip()
                    line_is_heading = line["md"]["heading_level"] <= 6  # noqa: PLR2004
                    if not line_is_heading and not line_is_whitespace:
                        if line["md"]["bold"] and line["md"]["italic"]:
                            line_text = f"***{line_text}***"
                        elif line["md"]["bold"]:
                            line_text = f"**{line_text}**"
                        elif line["md"]["italic"]:
                            line_text = f"*{line_text}*"
                    # Set the heading level.
                    if line_is_heading and not line_is_whitespace:
                        line_text = f"{'#' * line['md']['heading_level']} {line_text}"
                    line_text += "\n"
                    block_text += line_text
                block_text = block_text.rstrip() + "\n\n"
                page_md += block_text
            pages_md.append(page_md.strip())
        return pages_md

    def merge_split_headings(pages: list[str]) -> list[str]:
        """Merge headings that are split across lines."""

        def _merge_split_headings(match: re.Match[str]) -> str:
            atx_headings = [line.strip("# ").strip() for line in match.group().splitlines()]
            return f"{match.group(1)} {' '.join(atx_headings)}\n\n"

        pages_md = [
            re.sub(
                r"^(#+)[ \t]+[^\n]+\n+(?:^\1[ \t]+[^\n]+\n+)+",
                _merge_split_headings,
                page,
                flags=re.MULTILINE,
            )
            for page in pages
        ]
        return pages_md

    # Add heading level metadata.
    pages = add_heading_level_metadata(pages)
    # Add emphasis metadata.
    pages = add_emphasis_metadata(pages)
    # Strip page numbers.
    pages = strip_page_numbers(pages)
    # Convert the pages to Markdown.
    pages_md = convert_to_markdown(pages)
    # Merge headings that are split across lines.
    pages_md = merge_split_headings(pages_md)
    return pages_md


def document_to_markdown(doc_path: Path) -> str:
    """Convert any document to GitHub Flavored Markdown."""
    # Convert the file's content to GitHub Flavored Markdown.
    if doc_path.suffix == ".pdf":
|
202 |
+
# Parse the PDF with pdftext and convert it to Markdown.
|
203 |
+
pages = dictionary_output(doc_path, sort=True, keep_chars=False)
|
204 |
+
doc = "\n\n".join(parsed_pdf_to_markdown(pages))
|
205 |
+
else:
|
206 |
+
try:
|
207 |
+
# Use pandoc for everything else.
|
208 |
+
import pypandoc
|
209 |
+
|
210 |
+
doc = pypandoc.convert_file(doc_path, to="gfm")
|
211 |
+
except ImportError as error:
|
212 |
+
error_message = (
|
213 |
+
"To convert files to Markdown with pandoc, please install the `pandoc` extra."
|
214 |
+
)
|
215 |
+
raise ImportError(error_message) from error
|
216 |
+
except RuntimeError:
|
217 |
+
# File format not supported, fall back to reading the text.
|
218 |
+
doc = doc_path.read_text()
|
219 |
+
# Improve Markdown quality.
|
220 |
+
doc = mdformat.text(doc)
|
221 |
+
return doc
|
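To make the font-size heuristic in add_heading_level_metadata concrete, here is a minimal standalone sketch of the mode computation (the font sizes below are made up for illustration):

import numpy as np

# Hypothetical half-point-rounded font sizes harvested from a parsed PDF.
font_sizes = np.asarray([10.0, 10.0, 10.0, 10.0, 12.0, 12.0, 18.0, 4.0, 4.0, 4.0, 4.0, 4.0])
unique_font_sizes, counts = np.unique(font_sizes, return_counts=True)
# Temporarily negate the counts of tiny fonts (< 5pt) so they cannot win the mode.
tiny = unique_font_sizes < min(5, np.max(unique_font_sizes))
counts[tiny] = -counts[tiny]
print(unique_font_sizes[np.argmax(counts)])  # 10.0: the paragraph size, even though 4.0 occurs most often.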
src/raglite/_query_adapter.py
ADDED
@@ -0,0 +1,162 @@
"""Compute and update an optimal query adapter."""

import numpy as np
from sqlmodel import Session, col, select
from tqdm.auto import tqdm

from raglite._config import RAGLiteConfig
from raglite._database import Chunk, ChunkEmbedding, Eval, IndexMetadata, create_database_engine
from raglite._embed import embed_sentences
from raglite._search import vector_search


def update_query_adapter(  # noqa: PLR0915, C901
    *,
    max_triplets: int = 4096,
    max_triplets_per_eval: int = 64,
    optimize_top_k: int = 40,
    config: RAGLiteConfig | None = None,
) -> None:
    """Compute an optimal query adapter and update the database with it.

    This function computes an optimal linear transform A, called a 'query adapter', that is used to
    transform a query embedding q as A @ q before searching for the nearest neighbouring chunks in
    order to improve the quality of the search results.

    Given a set of triplets (qᵢ, pᵢ, nᵢ), we want to find the query adapter A that increases the
    score pᵢ'qᵢ of the positive chunk pᵢ and decreases the score nᵢ'qᵢ of the negative chunk nᵢ.

    If the nearest neighbour search uses the dot product as its relevance score, we can find the
    optimal query adapter by solving the following relaxed Procrustes optimisation problem with a
    bound on the Frobenius norm of A:

    A* = argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
       = argmax Σᵢ (pᵢ - nᵢ)' A qᵢ
       = argmax trace[ (P - N) A Q' ]    where Q := [q₁'; ...; qₖ']
                                               P := [p₁'; ...; pₖ']
                                               N := [n₁'; ...; nₖ']
       = argmax trace[ Q' (P - N) A ]
       = argmax trace[ M A ]             where M := Q' (P - N)
         s.t. ||A||_F == 1
       = M' / ||M||_F

    If the nearest neighbour search uses the cosine similarity as its relevance score, we can find
    the optimal query adapter by solving the following orthogonal Procrustes optimisation problem
    with an orthogonality constraint on A:

    A* = argmax Σᵢ pᵢ' (A qᵢ) - nᵢ' (A qᵢ)
       = argmax Σᵢ (pᵢ - nᵢ)' A qᵢ
       = argmax trace[ (P - N) A Q' ]
       = argmax trace[ Q' (P - N) A ]
       = argmax trace[ M A ]
       = argmax trace[ U Σ V' A ]        where U Σ V' := M is the SVD of M
       = argmax trace[ Σ V' A U ]
         s.t. A'A == 𝕀
       = V U'

    Additionally, we want to limit the effect of A* so that it adjusts q just enough to invert
    incorrectly ordered (q, p, n) triplets, but not so much as to affect the correctly ordered ones.
    To achieve this, we'll rewrite M as α(M / s) + (1 - α)𝕀, where s scales M to the same norm as 𝕀,
    and choose the smallest α that ranks (q, p, n) correctly. If α = 0, the relevance score gap
    between an incorrect (p, n) pair would be B := (p - n)' q < 0. If α = 1, the relevance score gap
    would be A := (p - n)' (p - n) / ||p - n|| > 0. For a target relevance score gap of say
    C := 5% * A, the optimal α is then given by αA + (1 - α)B = C => α = (B - C) / (B - A).
    """
    config = config or RAGLiteConfig()
    config_no_query_adapter = RAGLiteConfig(
        **{**config.__dict__, "vector_search_query_adapter": False}
    )
    engine = create_database_engine(config)
    with Session(engine) as session:
        # Get random evals from the database.
        chunk_embedding = session.exec(select(ChunkEmbedding).limit(1)).first()
        if chunk_embedding is None:
            error_message = "First run `insert_document()` to insert documents."
            raise ValueError(error_message)
        evals = session.exec(
            select(Eval).order_by(Eval.id).limit(max(8, max_triplets // max_triplets_per_eval))
        ).all()
        if len(evals) * max_triplets_per_eval < len(chunk_embedding.embedding):
            error_message = "First run `insert_evals()` to generate sufficient evals."
            raise ValueError(error_message)
        # Loop over the evals to generate (q, p, n) triplets.
        Q = np.zeros((0, len(chunk_embedding.embedding)))  # noqa: N806
        P = np.zeros_like(Q)  # noqa: N806
        N = np.zeros_like(Q)  # noqa: N806
        for eval_ in tqdm(
            evals, desc="Extracting triplets from evals", unit="eval", dynamic_ncols=True
        ):
            # Embed the question.
            question_embedding = embed_sentences([eval_.question], config=config)
            # Retrieve chunks that would be used to answer the question.
            chunk_ids, _ = vector_search(
                question_embedding, num_results=optimize_top_k, config=config_no_query_adapter
            )
            retrieved_chunks = session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all()
            # Extract (q, p, n) triplets by comparing the retrieved chunks with the eval.
            num_triplets = 0
            for i, retrieved_chunk in enumerate(retrieved_chunks):
                # Select irrelevant chunks.
                if retrieved_chunk.id not in eval_.chunk_ids:
                    # Look up all positive chunks (each represented by the mean of its multi-vector
                    # embedding) that are ranked lower than this negative one (represented by the
                    # embedding in the multi-vector embedding that best matches the query).
                    p_mean = [
                        np.mean(chunk.embedding_matrix, axis=0, keepdims=True)
                        for chunk in retrieved_chunks[i + 1 :]
                        if chunk is not None and chunk.id in eval_.chunk_ids
                    ]
                    n_top = retrieved_chunk.embedding_matrix[
                        np.argmax(retrieved_chunk.embedding_matrix @ question_embedding.T),
                        np.newaxis,
                        :,
                    ]
                    # Filter out any (p, n, q) triplets for which the mean positive embedding ranks
                    # higher than the top negative one.
                    p_mean = [p_e for p_e in p_mean if (n_top - p_e) @ question_embedding.T > 0]
                    if not p_mean:
                        continue
                    # Stack the (p, n, q) triplets.
                    p = np.vstack(p_mean)
                    n = np.repeat(n_top, p.shape[0], axis=0)
                    q = np.repeat(question_embedding, p.shape[0], axis=0)
                    num_triplets += p.shape[0]
                    # Append the (query, positive, negative) tuples to the Q, P, N matrices.
                    Q = np.vstack([Q, q])  # noqa: N806
                    P = np.vstack([P, p])  # noqa: N806
                    N = np.vstack([N, n])  # noqa: N806
                    # Check if we have sufficient triplets for this eval.
                    if num_triplets >= max_triplets_per_eval:
                        break
            # Check if we have sufficient triplets to compute the query adapter.
            if Q.shape[0] > max_triplets:
                Q, P, N = Q[:max_triplets, :], P[:max_triplets, :], N[:max_triplets, :]  # noqa: N806
                break
        # Normalise the rows of Q, P, N.
        Q /= np.linalg.norm(Q, axis=1, keepdims=True)  # noqa: N806
        P /= np.linalg.norm(P, axis=1, keepdims=True)  # noqa: N806
        N /= np.linalg.norm(N, axis=1, keepdims=True)  # noqa: N806
        # Compute the optimal weighted query adapter A*.
        # TODO: Matmul in float16 is extremely slow compared to single or double precision, why?
        gap_before = np.sum((P - N) * Q, axis=1)
        gap_after = 2 * (1 - np.sum(P * N, axis=1)) / np.linalg.norm(P - N, axis=1)
        gap_target = 0.05 * gap_after
        α = (gap_before - gap_target) / (gap_before - gap_after)  # noqa: PLC2401
        MT = (α[:, np.newaxis] * (P - N)).T @ Q  # noqa: N806
        s = np.linalg.norm(MT, ord="fro") / np.sqrt(MT.shape[0])
        MT = np.mean(α) * (MT / s) + np.mean(1 - α) * np.eye(Q.shape[1])  # noqa: N806
        if config.vector_search_index_metric == "dot":
            # Use the relaxed Procrustes solution.
            A_star = MT / np.linalg.norm(MT, ord="fro")  # noqa: N806
        elif config.vector_search_index_metric == "cosine":
            # Use the orthogonal Procrustes solution.
            U, _, VT = np.linalg.svd(MT, full_matrices=False)  # noqa: N806
            A_star = U @ VT  # noqa: N806
        else:
            error_message = f"Unsupported ANN metric: {config.vector_search_index_metric}"
            raise ValueError(error_message)
        # Store the optimal query adapter in the database.
        index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
        index_metadata.metadata_ = {**index_metadata.metadata_, "query_adapter": A_star}
        session.add(index_metadata)
        session.commit()
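As a numerical sanity check of the orthogonal Procrustes solution A* = V U' derived in the docstring above (a standalone sketch, not part of the library):

import numpy as np

rng = np.random.default_rng(42)
M = rng.standard_normal((8, 8))  # Stand-in for M := Q' (P - N).
U, S, VT = np.linalg.svd(M)
A_star = VT.T @ U.T  # The orthogonal Procrustes solution A* = V U'.
# trace[M A*] attains the upper bound Σᵢ σᵢ over all orthogonal A...
assert np.isclose(np.trace(M @ A_star), np.sum(S))
# ...and in particular beats a randomly sampled orthogonal matrix.
Q_random, _ = np.linalg.qr(rng.standard_normal((8, 8)))
assert np.trace(M @ A_star) >= np.trace(M @ Q_random)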
src/raglite/_rag.py
ADDED
@@ -0,0 +1,166 @@
"""Retrieval-augmented generation."""

from collections.abc import AsyncIterator, Iterator

from litellm import acompletion, completion, get_model_info  # type: ignore[attr-defined]

from raglite._config import RAGLiteConfig
from raglite._database import Chunk
from raglite._litellm import LlamaCppPythonLLM
from raglite._search import hybrid_search, rerank_chunks, retrieve_segments
from raglite._typing import SearchMethod

RAG_SYSTEM_PROMPT = """
You are a friendly and knowledgeable assistant that provides complete and insightful answers.
Answer the user's question using only the context below.
When responding, you MUST NOT reference the existence of the context, directly or indirectly.
Instead, you MUST treat the context as if its contents are entirely part of your working memory.
""".strip()


def _max_contexts(
    prompt: str,
    *,
    max_contexts: int = 5,
    context_neighbors: tuple[int, ...] | None = (-1, 1),
    messages: list[dict[str, str]] | None = None,
    config: RAGLiteConfig | None = None,
) -> int:
    """Determine the maximum number of contexts for RAG."""
    # If the user has configured a llama-cpp-python model, we ensure that LiteLLM's model info is up
    # to date by loading that LLM.
    config = config or RAGLiteConfig()
    if config.llm.startswith("llama-cpp-python"):
        _ = LlamaCppPythonLLM.llm(config.llm)
    # Get the model's maximum context size.
    llm_provider = "llama-cpp-python" if config.llm.startswith("llama-cpp") else None
    model_info = get_model_info(config.llm, custom_llm_provider=llm_provider)
    max_tokens = model_info.get("max_tokens") or 2048
    # Reduce the maximum number of contexts to take into account the LLM's context size.
    max_context_tokens = (
        max_tokens
        - sum(len(message["content"]) // 3 for message in messages or [])  # Previous messages.
        - len(RAG_SYSTEM_PROMPT) // 3  # System prompt.
        - len(prompt) // 3  # User prompt.
    )
    max_tokens_per_context = config.chunk_max_size // 3
    max_tokens_per_context *= 1 + len(context_neighbors or [])
    max_contexts = min(max_contexts, max_context_tokens // max_tokens_per_context)
    if max_contexts <= 0:
        error_message = "Not enough context tokens available for RAG."
        raise ValueError(error_message)
    return max_contexts


def _contexts(  # noqa: PLR0913
    prompt: str,
    *,
    max_contexts: int = 5,
    context_neighbors: tuple[int, ...] | None = (-1, 1),
    search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
    messages: list[dict[str, str]] | None = None,
    config: RAGLiteConfig | None = None,
) -> list[str]:
    """Retrieve contexts for RAG."""
    # Determine the maximum number of contexts.
    max_contexts = _max_contexts(
        prompt,
        max_contexts=max_contexts,
        context_neighbors=context_neighbors,
        messages=messages,
        config=config,
    )
    # Retrieve the top chunks.
    config = config or RAGLiteConfig()
    chunks: list[str] | list[Chunk]
    if callable(search):
        # If the user has configured a reranker, we retrieve extra contexts to rerank.
        extra_contexts = 3 * max_contexts if config.reranker else 0
        # Retrieve relevant contexts.
        chunk_ids, _ = search(prompt, num_results=max_contexts + extra_contexts, config=config)
        # Rerank the relevant contexts.
        chunks = rerank_chunks(query=prompt, chunk_ids=chunk_ids, config=config)
    else:
        # The user has passed a list of chunk_ids or chunks directly.
        chunks = search
    # Extend the top contexts with their neighbors and group chunks into contiguous segments.
    segments = retrieve_segments(chunks[:max_contexts], neighbors=context_neighbors, config=config)
    return segments


def rag(  # noqa: PLR0913
    prompt: str,
    *,
    max_contexts: int = 5,
    context_neighbors: tuple[int, ...] | None = (-1, 1),
    search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
    messages: list[dict[str, str]] | None = None,
    system_prompt: str = RAG_SYSTEM_PROMPT,
    config: RAGLiteConfig | None = None,
) -> Iterator[str]:
    """Retrieval-augmented generation."""
    # Get the contexts for RAG as contiguous segments of chunks.
    config = config or RAGLiteConfig()
    segments = _contexts(
        prompt,
        max_contexts=max_contexts,
        context_neighbors=context_neighbors,
        search=search,
        config=config,
    )
    system_prompt = f"{system_prompt}\n\n" + "\n\n".join(
        f'<context index="{i}">\n{segment.strip()}\n</context>'
        for i, segment in enumerate(segments)
    )
    # Stream the LLM response.
    stream = completion(
        model=config.llm,
        messages=[
            *(messages or []),
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    for output in stream:
        token: str = output["choices"][0]["delta"].get("content") or ""
        yield token


async def async_rag(  # noqa: PLR0913
    prompt: str,
    *,
    max_contexts: int = 5,
    context_neighbors: tuple[int, ...] | None = (-1, 1),
    search: SearchMethod | list[str] | list[Chunk] = hybrid_search,
    messages: list[dict[str, str]] | None = None,
    system_prompt: str = RAG_SYSTEM_PROMPT,
    config: RAGLiteConfig | None = None,
) -> AsyncIterator[str]:
    """Retrieval-augmented generation."""
    # Get the contexts for RAG as contiguous segments of chunks.
    config = config or RAGLiteConfig()
    segments = _contexts(
        prompt,
        max_contexts=max_contexts,
        context_neighbors=context_neighbors,
        search=search,
        config=config,
    )
    system_prompt = f"{system_prompt}\n\n" + "\n\n".join(
        f'<context index="{i}">\n{segment.strip()}\n</context>'
        for i, segment in enumerate(segments)
    )
    # Stream the LLM response.
    async_stream = await acompletion(
        model=config.llm,
        messages=[
            *(messages or []),
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    async for output in async_stream:
        token: str = output["choices"][0]["delta"].get("content") or ""
        yield token
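A minimal usage sketch of the streaming rag() API above; the database URL and LLM id are placeholders, and documents must already have been inserted with insert_document():

from raglite import RAGLiteConfig, rag

config = RAGLiteConfig(db_url="sqlite:///raglite.sqlite", llm="gpt-4o-mini")
stream = rag("What does it mean for two events to be simultaneous?", config=config)
answer = "".join(stream)  # Drain the token iterator into the full answer.
print(answer)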
src/raglite/_search.py
ADDED
@@ -0,0 +1,270 @@
"""Query documents."""

import re
import string
from collections import defaultdict
from collections.abc import Sequence
from itertools import groupby
from typing import cast

import numpy as np
from langdetect import detect
from sqlalchemy.engine import make_url
from sqlmodel import Session, and_, col, or_, select, text

from raglite._config import RAGLiteConfig
from raglite._database import Chunk, ChunkEmbedding, IndexMetadata, create_database_engine
from raglite._embed import embed_sentences
from raglite._typing import FloatMatrix


def vector_search(
    query: str | FloatMatrix,
    *,
    num_results: int = 3,
    config: RAGLiteConfig | None = None,
) -> tuple[list[str], list[float]]:
    """Search chunks using ANN vector search."""
    # Read the config.
    config = config or RAGLiteConfig()
    db_backend = make_url(config.db_url).get_backend_name()
    # Get the index metadata (including the query adapter, and in the case of SQLite, the index).
    index_metadata = IndexMetadata.get("default", config=config)
    # Embed the query.
    query_embedding = (
        embed_sentences([query], config=config)[0, :] if isinstance(query, str) else np.ravel(query)
    )
    # Apply the query adapter to the query embedding.
    Q = index_metadata.get("query_adapter")  # noqa: N806
    if config.vector_search_query_adapter and Q is not None:
        query_embedding = (Q @ query_embedding).astype(query_embedding.dtype)
    # Search for the multi-vector chunk embeddings that are most similar to the query embedding.
    if db_backend == "postgresql":
        # Check that the selected metric is supported by pgvector.
        metrics = {"cosine": "<=>", "dot": "<#>", "euclidean": "<->", "l1": "<+>", "l2": "<->"}
        if config.vector_search_index_metric not in metrics:
            error_message = f"Unsupported metric {config.vector_search_index_metric}."
            raise ValueError(error_message)
        # With pgvector, we can obtain the nearest neighbours and similarities with a single query.
        engine = create_database_engine(config)
        with Session(engine) as session:
            distance_func = getattr(
                ChunkEmbedding.embedding, f"{config.vector_search_index_metric}_distance"
            )
            distance = distance_func(query_embedding).label("distance")
            results = session.exec(
                select(ChunkEmbedding.chunk_id, distance).order_by(distance).limit(8 * num_results)
            )
            chunk_ids_, distance = zip(*results, strict=True)
            chunk_ids, similarity = np.asarray(chunk_ids_), 1.0 - np.asarray(distance)
    elif db_backend == "sqlite":
        # Load the NNDescent index.
        index = index_metadata.get("index")
        ids = np.asarray(index_metadata.get("chunk_ids"))
        cumsum = np.cumsum(np.asarray(index_metadata.get("chunk_sizes")))
        # Find the neighbouring multi-vector indices.
        from pynndescent import NNDescent

        multi_vector_indices, distance = cast(NNDescent, index).query(
            query_embedding[np.newaxis, :], k=8 * num_results
        )
        similarity = 1 - distance[0, :]
        # Transform the multi-vector indices into chunk indices, and then to chunk ids.
        chunk_indices = np.searchsorted(cumsum, multi_vector_indices[0, :], side="right") + 1
        chunk_ids = np.asarray([ids[chunk_index - 1] for chunk_index in chunk_indices])
    # Score each unique chunk id as the mean similarity of its multi-vector hits. Chunk ids with
    # fewer hits are padded with the minimum similarity of the result set.
    unique_chunk_ids, counts = np.unique(chunk_ids, return_counts=True)
    score = np.full(
        (len(unique_chunk_ids), np.max(counts)), np.min(similarity), dtype=similarity.dtype
    )
    for i, (unique_chunk_id, count) in enumerate(zip(unique_chunk_ids, counts, strict=True)):
        score[i, :count] = similarity[chunk_ids == unique_chunk_id]
    pooled_similarity = np.mean(score, axis=1)
    # Sort the chunk ids by their adjusted similarity.
    sorted_indices = np.argsort(pooled_similarity)[::-1]
    unique_chunk_ids = unique_chunk_ids[sorted_indices][:num_results]
    pooled_similarity = pooled_similarity[sorted_indices][:num_results]
    return unique_chunk_ids.tolist(), pooled_similarity.tolist()


def keyword_search(
    query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None
) -> tuple[list[str], list[float]]:
    """Search chunks using BM25 keyword search."""
    # Read the config.
    config = config or RAGLiteConfig()
    db_backend = make_url(config.db_url).get_backend_name()
    # Connect to the database.
    engine = create_database_engine(config)
    with Session(engine) as session:
        if db_backend == "postgresql":
            # Convert the query to a tsquery [1].
            # [1] https://www.postgresql.org/docs/current/textsearch-controls.html
            query_escaped = re.sub(r"[&|!():<>\"]", " ", query)
            tsv_query = " | ".join(query_escaped.split())
            # Perform keyword search with tsvector.
            statement = text("""
                SELECT id as chunk_id, ts_rank(to_tsvector('simple', body), to_tsquery('simple', :query)) AS score
                FROM chunk
                WHERE to_tsvector('simple', body) @@ to_tsquery('simple', :query)
                ORDER BY score DESC
                LIMIT :limit;
                """)
            results = session.execute(statement, params={"query": tsv_query, "limit": num_results})
        elif db_backend == "sqlite":
            # Convert the query to an FTS5 query [1].
            # [1] https://www.sqlite.org/fts5.html#full_text_query_syntax
            query_escaped = re.sub(f"[{re.escape(string.punctuation)}]", "", query)
            fts5_query = " OR ".join(query_escaped.split())
            # Perform keyword search with FTS5. In FTS5, BM25 scores are negative [1], so we
            # negate them to make them positive.
            # [1] https://www.sqlite.org/fts5.html#the_bm25_function
            statement = text("""
                SELECT chunk.id as chunk_id, -bm25(keyword_search_chunk_index) as score
                FROM chunk JOIN keyword_search_chunk_index ON chunk.rowid = keyword_search_chunk_index.rowid
                WHERE keyword_search_chunk_index MATCH :match
                ORDER BY score DESC
                LIMIT :limit;
                """)
            results = session.execute(statement, params={"match": fts5_query, "limit": num_results})
        # Unpack the results.
        chunk_ids, keyword_score = zip(*results, strict=True)
        chunk_ids, keyword_score = list(chunk_ids), list(keyword_score)  # type: ignore[assignment]
    return chunk_ids, keyword_score  # type: ignore[return-value]


def reciprocal_rank_fusion(
    rankings: list[list[str]], *, k: int = 60
) -> tuple[list[str], list[float]]:
    """Reciprocal Rank Fusion."""
    # Compute the RRF score.
    chunk_ids = {chunk_id for ranking in rankings for chunk_id in ranking}
    chunk_id_score: defaultdict[str, float] = defaultdict(float)
    for ranking in rankings:
        chunk_id_index = {chunk_id: i for i, chunk_id in enumerate(ranking)}
        for chunk_id in chunk_ids:
            chunk_id_score[chunk_id] += 1 / (k + chunk_id_index.get(chunk_id, len(chunk_id_index)))
    # Rank RRF results according to descending RRF score.
    rrf_chunk_ids, rrf_score = zip(
        *sorted(chunk_id_score.items(), key=lambda x: x[1], reverse=True), strict=True
    )
    return list(rrf_chunk_ids), list(rrf_score)


def hybrid_search(
    query: str, *, num_results: int = 3, num_rerank: int = 100, config: RAGLiteConfig | None = None
) -> tuple[list[str], list[float]]:
    """Search chunks by combining ANN vector search with BM25 keyword search."""
    # Run both searches.
    vs_chunk_ids, _ = vector_search(query, num_results=num_rerank, config=config)
    ks_chunk_ids, _ = keyword_search(query, num_results=num_rerank, config=config)
    # Combine the results with Reciprocal Rank Fusion (RRF).
    chunk_ids, hybrid_score = reciprocal_rank_fusion([vs_chunk_ids, ks_chunk_ids])
    chunk_ids, hybrid_score = chunk_ids[:num_results], hybrid_score[:num_results]
    return chunk_ids, hybrid_score


def retrieve_chunks(
    chunk_ids: list[str],
    *,
    config: RAGLiteConfig | None = None,
) -> list[Chunk]:
    """Retrieve chunks by their ids."""
    config = config or RAGLiteConfig()
    engine = create_database_engine(config)
    with Session(engine) as session:
        chunks = list(session.exec(select(Chunk).where(col(Chunk.id).in_(chunk_ids))).all())
    chunks = sorted(chunks, key=lambda chunk: chunk_ids.index(chunk.id))
    return chunks


def retrieve_segments(
    chunk_ids: list[str] | list[Chunk],
    *,
    neighbors: tuple[int, ...] | None = (-1, 1),
    config: RAGLiteConfig | None = None,
) -> list[str]:
    """Group chunks into contiguous segments and retrieve them."""
    # Retrieve the chunks.
    config = config or RAGLiteConfig()
    chunks: list[Chunk] = (
        retrieve_chunks(chunk_ids, config=config)  # type: ignore[arg-type,assignment]
        if all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
        else chunk_ids
    )
    # Extend the chunks with their neighbouring chunks.
    if neighbors:
        engine = create_database_engine(config)
        with Session(engine) as session:
            neighbor_conditions = [
                and_(Chunk.document_id == chunk.document_id, Chunk.index == chunk.index + offset)
                for chunk in chunks
                for offset in neighbors
            ]
            chunks += list(session.exec(select(Chunk).where(or_(*neighbor_conditions))).all())
    # Keep only the unique chunks.
    chunks = list(set(chunks))
    # Sort the chunks by document_id and index (needed for groupby).
    chunks = sorted(chunks, key=lambda chunk: (chunk.document_id, chunk.index))
    # Group the chunks into contiguous segments.
    segments: list[list[Chunk]] = []
    for _, group in groupby(chunks, key=lambda chunk: chunk.document_id):
        segment: list[Chunk] = []
        for chunk in group:
            if not segment or chunk.index == segment[-1].index + 1:
                segment.append(chunk)
            else:
                segments.append(segment)
                segment = [chunk]
        segments.append(segment)
    # Rank segments according to the aggregate relevance of their chunks.
    chunk_id_to_score = {chunk.id: 1 / (i + 1) for i, chunk in enumerate(chunks)}
    segments.sort(
        key=lambda segment: sum(chunk_id_to_score.get(chunk.id, 0.0) for chunk in segment),
        reverse=True,
    )
    # Convert the segments into strings.
    segments = [
        segment[0].headings.strip() + "\n\n" + "".join(chunk.body for chunk in segment).strip()  # type: ignore[misc]
        for segment in segments
    ]
    return segments  # type: ignore[return-value]


def rerank_chunks(
    query: str,
    chunk_ids: list[str] | list[Chunk],
    *,
    config: RAGLiteConfig | None = None,
) -> list[Chunk]:
    """Rerank chunks according to their relevance to a given query."""
    # Retrieve the chunks.
    config = config or RAGLiteConfig()
    chunks: list[Chunk] = (
        retrieve_chunks(chunk_ids, config=config)  # type: ignore[arg-type,assignment]
        if all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
        else chunk_ids
    )
    # Early exit if no reranker is configured.
    if not config.reranker:
        return chunks
    # Select the reranker.
    if isinstance(config.reranker, Sequence):
        # Detect the languages of the chunks and queries.
        langs = {detect(str(chunk)) for chunk in chunks}
        langs.add(detect(query))
        # If all chunks and the query are in the same language, use a language-specific reranker.
        rerankers = dict(config.reranker)
        if len(langs) == 1 and (lang := next(iter(langs))) in rerankers:
            reranker = rerankers[lang]
        else:
            reranker = rerankers.get("other")
    else:
        # A specific reranker was configured.
        reranker = config.reranker
    # Rerank the chunks.
    if reranker:
        results = reranker.rank(query=query, docs=[str(chunk) for chunk in chunks])
        chunks = [chunks[result.doc_id] for result in results.results]
    return chunks
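As a worked example of the reciprocal rank fusion above (a standalone sketch with made-up chunk ids; a chunk missing from a ranking is penalised with that ranking's worst rank):

k = 60
rankings = [["c1", "c2", "c3"], ["c3", "c1", "c4"]]  # E.g. vector and keyword results.
score: dict[str, float] = {}
for ranking in rankings:
    index = {chunk_id: i for i, chunk_id in enumerate(ranking)}
    for chunk_id in {"c1", "c2", "c3", "c4"}:
        score[chunk_id] = score.get(chunk_id, 0.0) + 1 / (k + index.get(chunk_id, len(index)))
print(sorted(score.items(), key=lambda x: x[1], reverse=True))
# "c1" (ranks 0 and 1) narrowly beats "c3" (ranks 2 and 0).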
src/raglite/_split_chunks.py
ADDED
@@ -0,0 +1,102 @@
"""Split a document into semantic chunks."""

import re

import numpy as np
from scipy.optimize import linprog
from scipy.sparse import coo_matrix

from raglite._typing import FloatMatrix


def split_chunks(  # noqa: C901, PLR0915
    sentences: list[str],
    sentence_embeddings: FloatMatrix,
    sentence_window_size: int = 3,
    max_size: int = 1440,
) -> tuple[list[str], list[FloatMatrix]]:
    """Split sentences into optimal semantic chunks with corresponding sentence embeddings."""
    # Validate the input.
    sentence_length = np.asarray([len(sentence) for sentence in sentences])
    if not np.all(sentence_length <= max_size):
        error_message = "Sentence with length larger than chunk max_size detected."
        raise ValueError(error_message)
    if not np.all(np.linalg.norm(sentence_embeddings, axis=1) > 0.0):
        error_message = "Sentence embeddings with zero norm detected."
        raise ValueError(error_message)
    # Exit early if there is only one chunk to return.
    if len(sentences) <= 1 or sum(sentence_length) <= max_size:
        return ["".join(sentences)] if sentences else sentences, [sentence_embeddings]
    # Normalise the sentence embeddings to unit norm.
    X = sentence_embeddings.astype(np.float32)  # noqa: N806
    X = X / np.linalg.norm(X, axis=1, keepdims=True)  # noqa: N806
    # Select nonoutlying sentences and remove the discourse vector.
    q15, q85 = np.quantile(sentence_length, [0.15, 0.85])
    nonoutlying_sentences = (q15 <= sentence_length) & (sentence_length <= q85)
    discourse = np.mean(X[nonoutlying_sentences, :], axis=0)
    discourse = discourse / np.linalg.norm(discourse)
    if not np.any(np.linalg.norm(X - discourse[np.newaxis, :], axis=1) <= np.finfo(X.dtype).eps):
        X = X - np.outer(X @ discourse, discourse)  # noqa: N806
        X = X / np.linalg.norm(X, axis=1, keepdims=True)  # noqa: N806
    # For each partition point in the list of sentences, compute the similarity of the windows
    # before and after the partition point. Sentence embeddings are assumed to be of the sentence
    # itself and at most the (sentence_window_size - 1) sentences that precede it.
    sentence_window_size = min(len(sentences) - 1, sentence_window_size)
    windows_before = X[:-sentence_window_size]
    windows_after = X[sentence_window_size:]
    partition_similarity = np.ones(len(sentences) - 1, dtype=X.dtype)
    partition_similarity[: len(windows_before)] = np.sum(windows_before * windows_after, axis=1)
    # Make partition similarity nonnegative before modification and optimisation.
    partition_similarity = np.maximum(
        (partition_similarity + 1) / 2, np.sqrt(np.finfo(X.dtype).eps)
    )
    # Modify the partition similarity to encourage splitting on Markdown headings.
    prev_sentence_is_heading = True
    for i, sentence in enumerate(sentences[:-1]):
        is_heading = bool(re.match(r"^#+\s", sentence.replace("\n", "").strip()))
        if is_heading:
            # Encourage splitting before a heading.
            if not prev_sentence_is_heading:
                partition_similarity[i - 1] = partition_similarity[i - 1] / 4
            # Don't split immediately after a heading.
            partition_similarity[i] = 1.0
        prev_sentence_is_heading = is_heading
    # Solve an optimisation problem to find the best partition points.
    sentence_length_cumsum = np.cumsum(sentence_length)
    row_indices = []
    col_indices = []
    data = []
    for i in range(len(sentences) - 1):
        r = sentence_length_cumsum[i - 1] if i > 0 else 0
        idx = np.searchsorted(sentence_length_cumsum - r, max_size)
        assert idx > i
        if idx == len(sentence_length_cumsum):
            break
        cols = list(range(i, idx))
        col_indices.extend(cols)
        row_indices.extend([i] * len(cols))
        data.extend([1] * len(cols))
    A = coo_matrix(  # noqa: N806
        (data, (row_indices, col_indices)),
        shape=(max(row_indices) + 1, len(sentences) - 1),
        dtype=np.float32,
    )
    b_ub = np.ones(A.shape[0], dtype=np.float32)
    res = linprog(
        partition_similarity,
        A_ub=-A,
        b_ub=-b_ub,
        bounds=(0, 1),
        integrality=[1] * A.shape[1],
    )
    if not res.success:
        error_message = "Optimization of chunk partitions failed."
        raise ValueError(error_message)
    # Split the sentences and their window embeddings into optimal chunks.
    partition_indices = (np.where(res.x)[0] + 1).tolist()
    chunks = [
        "".join(sentences[i:j])
        for i, j in zip([0, *partition_indices], [*partition_indices, len(sentences)], strict=True)
    ]
    chunk_embeddings = np.split(sentence_embeddings, partition_indices)
    return chunks, chunk_embeddings
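The chunking above solves a binary linear program: pick the split points that minimise total partition similarity, subject to every run of sentences that would overflow max_size containing at least one split. A toy instance with three candidate split points and a single window constraint shows the mechanics (the costs are made up):

import numpy as np
from scipy.optimize import linprog

cost = np.asarray([0.9, 0.2, 0.8])  # Partition similarities: low = a good place to split.
A_ub = -np.asarray([[1.0, 1.0, 1.0]])  # -(x0 + x1 + x2) <= -1: at least one split in the window.
res = linprog(cost, A_ub=A_ub, b_ub=[-1.0], bounds=(0, 1), integrality=[1, 1, 1])
print(np.where(res.x)[0])  # [1]: split at the least similar boundary.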
src/raglite/_split_sentences.py
ADDED
@@ -0,0 +1,76 @@
"""Sentence splitter."""

import re

import spacy
from markdown_it import MarkdownIt
from spacy.language import Language


@Language.component("_mark_additional_sentence_boundaries")
def _mark_additional_sentence_boundaries(doc: spacy.tokens.Doc) -> spacy.tokens.Doc:
    """Mark additional sentence boundaries in Markdown documents."""

    def get_markdown_heading_indexes(doc: str) -> list[tuple[int, int]]:
        """Get the indexes of the headings in a Markdown document."""
        md = MarkdownIt()
        tokens = md.parse(doc)
        headings = []
        lines = doc.splitlines(keepends=True)
        char_idx = [0]
        for line in lines:
            char_idx.append(char_idx[-1] + len(line))
        for token in tokens:
            if token.type == "heading_open":
                start_line, end_line = token.map  # type: ignore[misc]
                heading_start = char_idx[start_line]
                heading_end = char_idx[end_line]
                headings.append((heading_start, heading_end))
        return headings

    headings = get_markdown_heading_indexes(doc.text)
    for heading_start, heading_end in headings:
        # Mark the start of a heading as a new sentence.
        for token in doc:
            if heading_start <= token.idx:
                token.is_sent_start = True
                break
        # Mark the end of a heading as a new sentence.
        for token in doc:
            if heading_end <= token.idx:
                token.is_sent_start = True
                break
    return doc


def split_sentences(doc: str, max_len: int | None = None) -> list[str]:
    """Split a document into sentences."""
    # Split sentences with spaCy.
    try:
        nlp = spacy.load("xx_sent_ud_sm")
    except OSError as error:
        error_message = "Please install `xx_sent_ud_sm` with `pip install https://github.com/explosion/spacy-models/releases/download/xx_sent_ud_sm-3.7.0/xx_sent_ud_sm-3.7.0-py3-none-any.whl`."
        raise ImportError(error_message) from error
    nlp.add_pipe("_mark_additional_sentence_boundaries", before="senter")
    sentences = [sent.text_with_ws for sent in nlp(doc).sents if sent.text.strip()]
    # Apply additional splits on paragraphs and sentences because spaCy's splitting is not perfect.
    if max_len is not None:
        for pattern in (r"(?<=\n\n)", r"(?<=\.\s)"):
            sentences = [
                part
                for sent in sentences
                for part in ([sent] if len(sent) <= max_len else re.split(pattern, sent))
            ]
    # Recursively split long sentences in the middle if they are still too long.
    if max_len is not None:
        while any(len(sentence) > max_len for sentence in sentences):
            sentences = [
                part
                for sent in sentences
                for part in (
                    [sent]
                    if len(sent) <= max_len
                    else [sent[: len(sent) // 2], sent[len(sent) // 2 :]]
                )
            ]
    return sentences
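A minimal usage sketch of the splitter above (assuming the xx_sent_ud_sm model is installed as per the error message; the document text is made up):

from raglite._split_sentences import split_sentences

doc = "# Simultaneity\n\nTwo events are simultaneous if light from both reaches the midpoint together. That judgement depends on the frame."
print(split_sentences(doc, max_len=1440))  # The heading and each sentence should come back as separate items.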
src/raglite/_typing.py
ADDED
@@ -0,0 +1,145 @@
"""RAGLite typing."""

import io
import pickle
from collections.abc import Callable
from typing import Any, Protocol

import numpy as np
from sqlalchemy.engine import Dialect
from sqlalchemy.sql.operators import Operators
from sqlalchemy.types import Float, LargeBinary, TypeDecorator, TypeEngine, UserDefinedType

from raglite._config import RAGLiteConfig

FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]]
FloatVector = np.ndarray[tuple[int], np.dtype[np.floating[Any]]]
IntVector = np.ndarray[tuple[int], np.dtype[np.intp]]


class SearchMethod(Protocol):
    def __call__(
        self, query: str, *, num_results: int = 3, config: RAGLiteConfig | None = None
    ) -> tuple[list[str], list[float]]: ...


class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]):
    """A NumPy array column type for SQLAlchemy."""

    impl = LargeBinary

    def process_bind_param(
        self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect
    ) -> bytes | None:
        """Convert a NumPy array to bytes."""
        if value is None:
            return None
        buffer = io.BytesIO()
        np.save(buffer, value, allow_pickle=False, fix_imports=False)
        return buffer.getvalue()

    def process_result_value(
        self, value: bytes | None, dialect: Dialect
    ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None:
        """Convert bytes to a NumPy array."""
        if value is None:
            return None
        return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False)  # type: ignore[no-any-return]


class PickledObject(TypeDecorator[object]):
    """A pickled object column type for SQLAlchemy."""

    impl = LargeBinary

    def process_bind_param(self, value: object | None, dialect: Dialect) -> bytes | None:
        """Convert a Python object to bytes."""
        if value is None:
            return None
        return pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL, fix_imports=False)

    def process_result_value(self, value: bytes | None, dialect: Dialect) -> object | None:
        """Convert bytes to a Python object."""
        if value is None:
            return None
        return pickle.loads(value, fix_imports=False)  # type: ignore[no-any-return] # noqa: S301


class HalfVecComparatorMixin(UserDefinedType.Comparator[FloatVector]):
    """A mixin that provides comparison operators for halfvecs."""

    def cosine_distance(self, other: FloatVector) -> Operators:
        """Compute the cosine distance."""
        return self.op("<=>", return_type=Float)(other)

    def dot_distance(self, other: FloatVector) -> Operators:
        """Compute the dot product distance."""
        return self.op("<#>", return_type=Float)(other)

    def euclidean_distance(self, other: FloatVector) -> Operators:
        """Compute the Euclidean distance."""
        return self.op("<->", return_type=Float)(other)

    def l1_distance(self, other: FloatVector) -> Operators:
        """Compute the L1 distance."""
        return self.op("<+>", return_type=Float)(other)

    def l2_distance(self, other: FloatVector) -> Operators:
        """Compute the L2 distance."""
        return self.op("<->", return_type=Float)(other)


class HalfVec(UserDefinedType[FloatVector]):
    """A PostgreSQL half-precision vector column type for SQLAlchemy."""

    cache_ok = True  # HalfVec is immutable.

    def __init__(self, dim: int | None = None) -> None:
        super().__init__()
        self.dim = dim

    def get_col_spec(self, **kwargs: Any) -> str:
        return f"halfvec({self.dim})"

    def bind_processor(self, dialect: Dialect) -> Callable[[FloatVector | None], str | None]:
        """Process NumPy ndarray to PostgreSQL halfvec format for bound parameters."""

        def process(value: FloatVector | None) -> str | None:
            return f"[{','.join(str(x) for x in np.ravel(value))}]" if value is not None else None

        return process

    def result_processor(
        self, dialect: Dialect, coltype: Any
    ) -> Callable[[str | None], FloatVector | None]:
        """Process PostgreSQL halfvec format to NumPy ndarray."""

        def process(value: str | None) -> FloatVector | None:
            if value is None:
                return None
            return np.fromstring(value.strip("[]"), sep=",", dtype=np.float16)

        return process

    class comparator_factory(HalfVecComparatorMixin):  # noqa: N801
        ...


class Embedding(TypeDecorator[FloatVector]):
    """An embedding column type for SQLAlchemy."""

    cache_ok = True  # Embedding is immutable.

    impl = NumpyArray

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim

    def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[FloatVector]:
        if dialect.name == "postgresql":
            return dialect.type_descriptor(HalfVec(self.dim))
        return dialect.type_descriptor(NumpyArray())

    class comparator_factory(HalfVecComparatorMixin):  # noqa: N801
        ...
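For reference, NumpyArray above round-trips embeddings through NumPy's binary format; the same serialisation can be exercised standalone:

import io

import numpy as np

embedding = np.random.default_rng(0).standard_normal((4, 8)).astype(np.float16)
buffer = io.BytesIO()
np.save(buffer, embedding, allow_pickle=False, fix_imports=False)
restored = np.load(io.BytesIO(buffer.getvalue()), allow_pickle=False, fix_imports=False)
assert np.array_equal(embedding, restored) and restored.dtype == np.float16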
src/raglite/py.typed
ADDED
File without changes
tests/__init__.py
ADDED
@@ -0,0 +1 @@
"""RAGLite test suite."""
tests/conftest.py
ADDED
@@ -0,0 +1,102 @@
"""Fixtures for the tests."""

import os
import socket
import tempfile
from collections.abc import Generator
from pathlib import Path

import pytest
from sqlalchemy import create_engine, text

from raglite import RAGLiteConfig, insert_document

POSTGRES_URL = "postgresql+pg8000://raglite_user:raglite_password@postgres:5432/postgres"


def is_postgres_running() -> bool:
    """Check if PostgreSQL is running."""
    try:
        with socket.create_connection(("postgres", 5432), timeout=1):
            return True
    except OSError:
        return False


def is_openai_available() -> bool:
    """Check if an OpenAI API key is set."""
    return bool(os.environ.get("OPENAI_API_KEY"))


def pytest_sessionstart(session: pytest.Session) -> None:
    """Reset the PostgreSQL and SQLite databases."""
    if is_postgres_running():
        engine = create_engine(POSTGRES_URL, isolation_level="AUTOCOMMIT")
        with engine.connect() as conn:
            for variant in ["local", "remote"]:
                conn.execute(text(f"DROP DATABASE IF EXISTS raglite_test_{variant}"))
                conn.execute(text(f"CREATE DATABASE raglite_test_{variant}"))


@pytest.fixture(scope="session")
def sqlite_url() -> Generator[str, None, None]:
    """Create a temporary SQLite database file and return the database URL."""
    with tempfile.TemporaryDirectory() as temp_dir:
        db_file = Path(temp_dir) / "raglite_test.sqlite"
        yield f"sqlite:///{db_file}"


@pytest.fixture(
    scope="session",
    params=[
        pytest.param("sqlite", id="sqlite"),
        pytest.param(
            POSTGRES_URL,
            id="postgres",
            marks=pytest.mark.skipif(not is_postgres_running(), reason="PostgreSQL is not running"),
        ),
    ],
)
def database(request: pytest.FixtureRequest) -> str:
    """Get a database URL to test RAGLite with."""
    db_url: str = (
        request.getfixturevalue("sqlite_url") if request.param == "sqlite" else request.param
    )
    return db_url


@pytest.fixture(
    scope="session",
    params=[
        pytest.param(
            "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf",
            id="bge_m3",
        ),
        pytest.param(
            "text-embedding-3-small",
            id="openai_text_embedding_3_small",
            marks=pytest.mark.skipif(not is_openai_available(), reason="OpenAI API key is not set"),
        ),
    ],
)
def embedder(request: pytest.FixtureRequest) -> str:
    """Get an embedder model URL to test RAGLite with."""
    embedder: str = request.param
    return embedder


@pytest.fixture(scope="session")
def raglite_test_config(database: str, embedder: str) -> RAGLiteConfig:
    """Create a lightweight in-memory config for testing SQLite and PostgreSQL."""
    # Select the database based on the embedder.
    variant = "local" if embedder.startswith("llama-cpp-python") else "remote"
    if "postgres" in database:
        database = database.replace("/postgres", f"/raglite_test_{variant}")
    elif "sqlite" in database:
        database = database.replace(".sqlite", f"_{variant}.sqlite")
    # Create a RAGLite config for the given database and embedder.
    db_config = RAGLiteConfig(db_url=database, embedder=embedder)
    # Insert a document and update the index.
    doc_path = Path(__file__).parent / "specrel.pdf"  # Einstein's special relativity paper.
    insert_document(doc_path, config=db_config)
    return db_config
tests/specrel.pdf
ADDED
Binary file (178 kB)
tests/test_embed.py
ADDED
@@ -0,0 +1,26 @@
+"""Test RAGLite's embedding functionality."""
+
+from pathlib import Path
+
+import numpy as np
+
+from raglite import RAGLiteConfig
+from raglite._embed import embed_sentences
+from raglite._markdown import document_to_markdown
+from raglite._split_sentences import split_sentences
+
+
+def test_embed(embedder: str) -> None:
+    """Test embedding a document."""
+    raglite_test_config = RAGLiteConfig(embedder=embedder, embedder_normalize=True)
+    doc_path = Path(__file__).parent / "specrel.pdf"  # Einstein's special relativity paper.
+    doc = document_to_markdown(doc_path)
+    sentences = split_sentences(doc, max_len=raglite_test_config.chunk_max_size)
+    sentence_embeddings = embed_sentences(sentences, config=raglite_test_config)
+    assert isinstance(sentences, list)
+    assert isinstance(sentence_embeddings, np.ndarray)
+    assert len(sentences) == len(sentence_embeddings)
+    assert sentence_embeddings.shape[1] >= 128  # noqa: PLR2004
+    assert sentence_embeddings.dtype == np.float16
+    assert np.all(np.isfinite(sentence_embeddings))
+    assert np.allclose(np.linalg.norm(sentence_embeddings, axis=1), 1.0, rtol=1e-3)
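
Note: the test above exercises the full embedding pipeline: convert a document to Markdown, split it into sentences, then embed the sentences as a matrix of unit-norm float16 row vectors. A minimal sketch of the same pipeline outside pytest (the PDF path and embedder are assumptions):

    from pathlib import Path

    from raglite import RAGLiteConfig
    from raglite._embed import embed_sentences
    from raglite._markdown import document_to_markdown
    from raglite._split_sentences import split_sentences

    config = RAGLiteConfig(embedder="text-embedding-3-small", embedder_normalize=True)
    doc = document_to_markdown(Path("specrel.pdf"))  # Assumed local PDF path.
    sentences = split_sentences(doc, max_len=config.chunk_max_size)
    embeddings = embed_sentences(sentences, config=config)  # Shape: (len(sentences), dim).
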
tests/test_import.py
ADDED
@@ -0,0 +1,8 @@
+"""Test RAGLite."""
+
+import raglite
+
+
+def test_import() -> None:
+    """Test that the package can be imported."""
+    assert isinstance(raglite.__name__, str)
tests/test_markdown.py
ADDED
@@ -0,0 +1,22 @@
+"""Test Markdown conversion."""
+
+from pathlib import Path
+
+from raglite._markdown import document_to_markdown
+
+
+def test_pdf_with_missing_font_sizes() -> None:
+    """Test conversion of a PDF with missing font sizes."""
+    # Convert a PDF whose parsed font sizes are all equal to 1.
+    doc_path = Path(__file__).parent / "specrel.pdf"  # Einstein's special relativity paper.
+    doc = document_to_markdown(doc_path)
+    # Verify that we can reconstruct the font sizes and heading levels regardless of the missing
+    # font size data.
+    expected_heading = """
+# ON THE ELECTRODYNAMICS OF MOVING BODIES
+
+## By A. EINSTEIN June 30, 1905
+
+It is known that Maxwell
+""".strip()
+    assert doc.startswith(expected_heading)
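
Note: this is a regression test for PDFs whose parsed font sizes are unusable (all equal to 1), in which case `document_to_markdown` must reconstruct heading levels from other cues; the expected output pins the first two heading levels. A minimal sketch of the conversion call (the PDF path is an assumption):

    from pathlib import Path

    from raglite._markdown import document_to_markdown

    markdown = document_to_markdown(Path("specrel.pdf"))
    print(markdown[:200])  # Headings are emitted as Markdown '#' levels.
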
tests/test_rag.py
ADDED
@@ -0,0 +1,40 @@
+"""Test RAGLite's RAG functionality."""
+
+import os
+from typing import TYPE_CHECKING
+
+import pytest
+from llama_cpp import llama_supports_gpu_offload
+
+from raglite import RAGLiteConfig, hybrid_search, rag, retrieve_chunks
+
+if TYPE_CHECKING:
+    from raglite._database import Chunk
+    from raglite._typing import SearchMethod
+
+
+def is_accelerator_available() -> bool:
+    """Check if an accelerator is available."""
+    return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8  # noqa: PLR2004
+
+
+@pytest.mark.skipif(not is_accelerator_available(), reason="No accelerator available")
+def test_rag(raglite_test_config: RAGLiteConfig) -> None:
+    """Test Retrieval-Augmented Generation."""
+    # Assemble different types of search inputs for RAG.
+    prompt = "What does it mean for two events to be simultaneous?"
+    search_inputs: list[SearchMethod | list[str] | list[Chunk]] = [
+        hybrid_search,  # A search method as input.
+        hybrid_search(prompt, config=raglite_test_config)[0],  # Chunk ids as input.
+        retrieve_chunks(  # Chunks as input.
+            hybrid_search(prompt, config=raglite_test_config)[0], config=raglite_test_config
+        ),
+    ]
+    # Answer a question with RAG.
+    for search_input in search_inputs:
+        stream = rag(prompt, search=search_input, config=raglite_test_config)
+        answer = ""
+        for update in stream:
+            assert isinstance(update, str)
+            answer += update
+        assert "simultaneous" in answer.lower()
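
Note: as the test shows, `rag` accepts three kinds of `search` input: a search method (retrieval is deferred to `rag`), a list of chunk ids, or a list of pre-retrieved `Chunk`s, and it yields the answer as a stream of string deltas. A minimal sketch of consuming the stream (the `db_url` is an assumed, pre-populated database):

    from raglite import RAGLiteConfig, hybrid_search, rag

    config = RAGLiteConfig(db_url="sqlite:///raglite.sqlite")  # Assumed example database.
    stream = rag("What does it mean for two events to be simultaneous?", search=hybrid_search, config=config)
    answer = "".join(stream)  # Joining the string deltas reconstructs the full answer.
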
tests/test_rerank.py
ADDED
@@ -0,0 +1,55 @@
+"""Test RAGLite's reranking functionality."""
+
+import pytest
+from rerankers.models.ranker import BaseRanker
+
+from raglite import RAGLiteConfig, hybrid_search, rerank_chunks, retrieve_chunks
+from raglite._database import Chunk
+from raglite._flashrank import PatchedFlashRankRanker as FlashRankRanker
+
+
+@pytest.fixture(
+    params=[
+        pytest.param(None, id="no_reranker"),
+        pytest.param(FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0), id="flashrank_english"),
+        pytest.param(
+            (
+                ("en", FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0)),
+                ("other", FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0)),
+            ),
+            id="flashrank_multilingual",
+        ),
+    ],
+)
+def reranker(
+    request: pytest.FixtureRequest,
+) -> BaseRanker | tuple[tuple[str, BaseRanker], ...] | None:
+    """Get a reranker to test RAGLite with."""
+    reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None = request.param
+    return reranker
+
+
+def test_reranker(
+    raglite_test_config: RAGLiteConfig,
+    reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None,
+) -> None:
+    """Test inserting a document, updating the indexes, and searching for a query."""
+    # Update the config with the reranker.
+    raglite_test_config = RAGLiteConfig(
+        db_url=raglite_test_config.db_url, embedder=raglite_test_config.embedder, reranker=reranker
+    )
+    # Search for a query.
+    query = "What does it mean for two events to be simultaneous?"
+    chunk_ids, _ = hybrid_search(query, num_results=3, config=raglite_test_config)
+    # Retrieve the chunks.
+    chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
+    assert all(isinstance(chunk, Chunk) for chunk in chunks)
+    assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
+    # Rerank the chunks given an inverted chunk order.
+    reranked_chunks = rerank_chunks(query, chunks[::-1], config=raglite_test_config)
+    if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder:
+        assert reranked_chunks[0] == chunks[0]
+    # Test that we can also rerank given the chunk_ids only.
+    reranked_chunks = rerank_chunks(query, chunk_ids[::-1], config=raglite_test_config)
+    if reranker is not None and "text-embedding-3-small" not in raglite_test_config.embedder:
+        assert reranked_chunks[0] == chunks[0]
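
Note: `RAGLiteConfig.reranker` accepts three shapes, mirroring the fixture params above: `None` (no reranking), a single `BaseRanker`, or a tuple of `(language, ranker)` pairs with an `"other"` fallback for multilingual corpora. A minimal sketch of the multilingual shape (the `db_url` is an assumed, pre-populated database; the model names are those used in the test):

    from raglite import RAGLiteConfig, hybrid_search, rerank_chunks
    from raglite._flashrank import PatchedFlashRankRanker as FlashRankRanker

    config = RAGLiteConfig(
        db_url="sqlite:///raglite.sqlite",  # Assumed example database.
        reranker=(
            ("en", FlashRankRanker("ms-marco-MiniLM-L-12-v2", verbose=0)),
            ("other", FlashRankRanker("ms-marco-MultiBERT-L-12", verbose=0)),
        ),
    )
    chunk_ids, _ = hybrid_search("simultaneity", config=config)
    reranked = rerank_chunks("simultaneity", chunk_ids, config=config)  # Ids or Chunks both work.
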
tests/test_search.py
ADDED
@@ -0,0 +1,48 @@
+"""Test RAGLite's search functionality."""
+
+import pytest
+
+from raglite import (
+    RAGLiteConfig,
+    hybrid_search,
+    keyword_search,
+    retrieve_chunks,
+    retrieve_segments,
+    vector_search,
+)
+from raglite._database import Chunk
+from raglite._typing import SearchMethod
+
+
+@pytest.fixture(
+    params=[
+        pytest.param(keyword_search, id="keyword_search"),
+        pytest.param(vector_search, id="vector_search"),
+        pytest.param(hybrid_search, id="hybrid_search"),
+    ],
+)
+def search_method(
+    request: pytest.FixtureRequest,
+) -> SearchMethod:
+    """Get a search method to test RAGLite with."""
+    search_method: SearchMethod = request.param
+    return search_method
+
+
+def test_search(raglite_test_config: RAGLiteConfig, search_method: SearchMethod) -> None:
+    """Test searching for a query."""
+    # Search for a query.
+    query = "What does it mean for two events to be simultaneous?"
+    num_results = 5
+    chunk_ids, scores = search_method(query, num_results=num_results, config=raglite_test_config)
+    assert len(chunk_ids) == len(scores) == num_results
+    assert all(isinstance(chunk_id, str) for chunk_id in chunk_ids)
+    assert all(isinstance(score, float) for score in scores)
+    # Retrieve the chunks.
+    chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
+    assert all(isinstance(chunk, Chunk) for chunk in chunks)
+    assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
+    assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks)
+    # Extend the chunks with their neighbours and group them into contiguous segments.
+    segments = retrieve_segments(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)
+    assert all(isinstance(segment, str) for segment in segments)
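
Note: all three search methods share a common `SearchMethod` signature (a query plus `num_results` and `config`, returning `(chunk_ids, scores)`), which is what lets the fixture swap them freely; `retrieve_segments` then extends the hits with neighbouring chunks and groups them into contiguous segments. A minimal sketch (the `db_url` is an assumed, pre-populated database):

    from raglite import RAGLiteConfig, retrieve_segments, vector_search

    config = RAGLiteConfig(db_url="sqlite:///raglite.sqlite")  # Assumed example database.
    chunk_ids, scores = vector_search("simultaneity", num_results=5, config=config)
    segments = retrieve_segments(chunk_ids, neighbors=(-1, 1), config=config)
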
tests/test_split_chunks.py
ADDED
@@ -0,0 +1,56 @@
+"""Test RAGLite's chunk splitting functionality."""
+
+import numpy as np
+import pytest
+
+from raglite._split_chunks import split_chunks
+
+
+@pytest.mark.parametrize(
+    "sentences",
+    [
+        pytest.param([], id="one_chunk:no_sentences"),
+        pytest.param(["Hello world"], id="one_chunk:one_sentence"),
+        pytest.param(["Hello world"] * 2, id="one_chunk:two_sentences"),
+        pytest.param(["Hello world"] * 3, id="one_chunk:three_sentences"),
+        pytest.param(["Hello world"] * 100, id="one_chunk:many_sentences"),
+        pytest.param(["Hello world", "X" * 1000], id="n_chunks:two_sentences_a"),
+        pytest.param(["X" * 1000, "Hello world"], id="n_chunks:two_sentences_b"),
+        pytest.param(["Hello world", "X" * 1000, "X" * 1000], id="n_chunks:three_sentences_a"),
+        pytest.param(["X" * 1000, "Hello world", "X" * 1000], id="n_chunks:three_sentences_b"),
+        pytest.param(["X" * 1000, "X" * 1000, "Hello world"], id="n_chunks:three_sentences_c"),
+        pytest.param(["X" * 1000] * 100, id="n_chunks:many_sentences_a"),
+        pytest.param(["X" * 100] * 1000, id="n_chunks:many_sentences_b"),
+    ],
+)
+def test_edge_cases(sentences: list[str]) -> None:
+    """Test chunk splitting edge cases."""
+    sentence_embeddings = np.ones((len(sentences), 768)).astype(np.float16)
+    chunks, chunk_embeddings = split_chunks(
+        sentences, sentence_embeddings, sentence_window_size=3, max_size=1440
+    )
+    assert isinstance(chunks, list)
+    assert isinstance(chunk_embeddings, list)
+    assert len(chunk_embeddings) == (len(chunks) if sentences else 1)
+    assert all(isinstance(chunk, str) for chunk in chunks)
+    assert all(isinstance(chunk_embedding, np.ndarray) for chunk_embedding in chunk_embeddings)
+    assert all(ce.dtype == sentence_embeddings.dtype for ce in chunk_embeddings)
+    assert sum(ce.shape[0] for ce in chunk_embeddings) == sentence_embeddings.shape[0]
+    assert all(ce.shape[1] == sentence_embeddings.shape[1] for ce in chunk_embeddings)
+
+
+@pytest.mark.parametrize(
+    "sentences",
+    [
+        pytest.param(["Hello world" * 1000] + ["X"] * 100, id="first"),
+        pytest.param(["X"] * 50 + ["Hello world" * 1000] + ["X"] * 50, id="middle"),
+        pytest.param(["X"] * 100 + ["Hello world" * 1000], id="last"),
+    ],
+)
+def test_long_sentence(sentences: list[str]) -> None:
+    """Test chunking on sentences that are too long."""
+    sentence_embeddings = np.ones((len(sentences), 768)).astype(np.float16)
+    with pytest.raises(
+        ValueError, match="Sentence with length larger than chunk max_size detected."
+    ):
+        _ = split_chunks(sentences, sentence_embeddings, sentence_window_size=3, max_size=1440)
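
Note: the assertions encode `split_chunks`'s key invariant: the returned chunk embeddings partition the sentence embeddings row-for-row, so their row counts sum back to the input, while any single sentence longer than `max_size` raises a `ValueError`. A minimal sketch with synthetic inputs (shapes and parameters as in the tests):

    import numpy as np

    from raglite._split_chunks import split_chunks

    sentences = ["Hello world. "] * 10 + ["Y" * 500] * 4  # Synthetic sentences.
    sentence_embeddings = np.ones((len(sentences), 768)).astype(np.float16)
    chunks, chunk_embeddings = split_chunks(
        sentences, sentence_embeddings, sentence_window_size=3, max_size=1440
    )
    # The chunk embeddings partition the sentence embeddings row-for-row.
    assert sum(ce.shape[0] for ce in chunk_embeddings) == len(sentences)
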