kye commited on
Commit
bd01535
·
1 Parent(s): d15e424

Upload 79 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. Andromeda/.DS_Store +0 -0
  3. Andromeda/.env +6 -0
  4. Andromeda/.github/ISSUE_TEMPLATE/---bug-report.md +36 -0
  5. Andromeda/.github/ISSUE_TEMPLATE/---feature-request.md +25 -0
  6. Andromeda/.github/ISSUE_TEMPLATE/---model-questions.md +17 -0
  7. Andromeda/.github/mcp/mcp_pytest.py +139 -0
  8. Andromeda/.github/workflows/FUNDING.md +13 -0
  9. Andromeda/.github/workflows/code-quality.yaml +44 -0
  10. Andromeda/.github/workflows/codeql-analysis.yml +70 -0
  11. Andromeda/.github/workflows/coverage.yaml +32 -0
  12. Andromeda/.github/workflows/docker.yaml +62 -0
  13. Andromeda/.github/workflows/pr-cpu.yaml +43 -0
  14. Andromeda/.github/workflows/pr-gpu.yaml +40 -0
  15. Andromeda/.github/workflows/pytest-cpu.yaml +48 -0
  16. Andromeda/.github/workflows/pytest-gpu.yaml +80 -0
  17. Andromeda/.github/workflows/python-publish.yml +39 -0
  18. Andromeda/.github/workflows/release.yaml +60 -0
  19. Andromeda/.gitignore +2 -0
  20. Andromeda/Andromeda/README.md +121 -0
  21. Andromeda/Andromeda/__init__.py +3 -0
  22. Andromeda/Andromeda/configs.py +128 -0
  23. Andromeda/Andromeda/core/__init__.py +8 -0
  24. Andromeda/Andromeda/core/attend.py +252 -0
  25. Andromeda/Andromeda/core/autoregressive_wrapper.py +150 -0
  26. Andromeda/Andromeda/core/flash.py +289 -0
  27. Andromeda/Andromeda/core/transformer.py +1376 -0
  28. Andromeda/Andromeda/dataset_prep/__init__.py +0 -0
  29. Andromeda/Andromeda/dataset_prep/books.py +12 -0
  30. Andromeda/Andromeda/inference.py +198 -0
  31. Andromeda/Andromeda/model.py +118 -0
  32. Andromeda/Andromeda/old/__init__.py +0 -0
  33. Andromeda/Andromeda/old/sophia.py +200 -0
  34. Andromeda/Andromeda/old/training.py +294 -0
  35. Andromeda/Andromeda/old/training_1.py +350 -0
  36. Andromeda/Andromeda/old/training_sophia.py +369 -0
  37. Andromeda/Andromeda/train.py +700 -0
  38. Andromeda/Andromeda/utils/__init__.py +0 -0
  39. Andromeda/Andromeda/utils/decoupled_optimizer.py +147 -0
  40. Andromeda/Andromeda/utils/helpers.py +17 -0
  41. Andromeda/Andromeda/utils/rf_utils.py +186 -0
  42. Andromeda/Andromeda/utils/stable_adamw.py +96 -0
  43. Andromeda/DOCs/Corporation/MONETIZATION.md +51 -0
  44. Andromeda/DOCs/Design/Dyson.md +26 -0
  45. Andromeda/DOCs/Design/MODEL_ARCHITECTURE.md +57 -0
  46. Andromeda/DOCs/Design/SPEED.md +11 -0
  47. Andromeda/DOCs/Design/Specs.md +196 -0
  48. Andromeda/DOCs/Docs/DOCUMENTATION.md +145 -0
  49. Andromeda/DOCs/Docs/TRAINING.md +82 -0
  50. Andromeda/DOCs/Docs/Training/DATASET_STRATEGY.md +100 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ Andromeda/images/andromeda-banner.png filter=lfs diff=lfs merge=lfs -text
Andromeda/.DS_Store ADDED
Binary file (6.15 kB). View file
 
Andromeda/.env ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ MASTER_ADDR=""
2
+ MASTER_PORT=""
3
+ RANK=""
4
+ WORLD_SIZE=""
5
+ # export TORCH_CPP_LOG_LEVEL=INFO NCCL_DEBUG=INFO
6
+
Andromeda/.github/ISSUE_TEMPLATE/---bug-report.md ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "\U0001F41B Bug report"
3
+ about: Submit a bug report to improve our library!
4
+ title: ''
5
+ labels: bug
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ <!-- Please check for related issues (both open and closed) before filing this issue. -->
11
+
12
+ ## Environment
13
+ <!-- Please copy paste the output of running `composer_collect_env` below-->
14
+ <!--
15
+ If you can't install composer for some reason, you can also use the PyTorch collect env script
16
+
17
+ wget https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py
18
+ # For security purposes, please check the contents of collect_env.py before running it.
19
+ python collect_env.py
20
+ -->
21
+
22
+ ## To reproduce
23
+
24
+ Steps to reproduce the behavior:
25
+
26
+ 1.
27
+ 2.
28
+ 3.
29
+
30
+ ## Expected behavior
31
+
32
+ <!-- A clear and concise description of what you would expect to happen. -->
33
+
34
+ ## Additional context
35
+
36
+ <!-- Please provide any additional context. -->
Andromeda/.github/ISSUE_TEMPLATE/---feature-request.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "\U0001F680 Feature request"
3
+ about: Suggest an idea for this project
4
+ title: ''
5
+ labels: enhancement
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ <!-- Please check for related feature requests (both open and closed) before filing this request. -->
11
+
12
+ ## 🚀 Feature Request
13
+ <!-- A clear and concise description of the feature proposal -->
14
+
15
+ ## Motivation
16
+
17
+ <!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
18
+
19
+ ## [Optional] Implementation
20
+
21
+ <!-- Optionally, sketch out an implementation or interface needed. -->
22
+
23
+ ## Additional context
24
+
25
+ <!-- Add any other context or screenshots about the feature request here. -->
Andromeda/.github/ISSUE_TEMPLATE/---model-questions.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: "\U00002753 Model-related question"
3
+ about: Ask a question about using our released models
4
+ title: ''
5
+ labels: question
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ <!-- Please check for related question (both open and closed) before filing this question. -->
11
+
12
+ ## ❓ Question
13
+ <!-- A clear and concise description of the question -->
14
+
15
+ ## Additional context
16
+
17
+ <!-- Add any other context or screenshots about the feature request here. -->
Andromeda/.github/mcp/mcp_pytest.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Run pytest using MCP."""
5
+
6
+ import argparse
7
+ import time
8
+
9
+ from mcli.sdk import (RunConfig, RunStatus, create_run, follow_run_logs,
10
+ stop_run, wait_for_run_status)
11
+
12
+ if __name__ == '__main__':
13
+
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument('--name',
16
+ type=str,
17
+ default='mcp-pytest',
18
+ help='Base name of run')
19
+ parser.add_argument('--cluster',
20
+ type=str,
21
+ default='r1z4',
22
+ help='Cluster to use')
23
+ parser.add_argument('--gpu_type',
24
+ type=str,
25
+ default='a100_40gb',
26
+ help='Type of GPU to use')
27
+ parser.add_argument('--gpu_num',
28
+ type=int,
29
+ default=2,
30
+ help='Number of the GPU to use')
31
+ parser.add_argument('--image',
32
+ type=str,
33
+ default='mosaicml/pytorch:latest',
34
+ help='Docker image to use')
35
+ parser.add_argument('--git_branch',
36
+ type=str,
37
+ help='Git branch to check out')
38
+ parser.add_argument(
39
+ '--git_commit',
40
+ type=str,
41
+ help='Git commit to check out. Overrides git_branch if specified')
42
+ parser.add_argument(
43
+ '--pr_number',
44
+ type=int,
45
+ help=
46
+ 'PR number to check out. Overrides git_branch/git_commit if specified')
47
+ parser.add_argument('--pytest_markers',
48
+ type=str,
49
+ help='Markers to pass to pytest')
50
+ parser.add_argument('--pytest_command',
51
+ type=str,
52
+ help='Command to run pytest')
53
+ parser.add_argument('--timeout',
54
+ type=int,
55
+ default=1800,
56
+ help='Timeout for run (in seconds)')
57
+ args = parser.parse_args()
58
+
59
+ name = args.name
60
+ git_integration = {
61
+ 'integration_type': 'git_repo',
62
+ 'git_repo': 'mosaicml/llm-foundry',
63
+ 'ssh_clone': 'False',
64
+ }
65
+ if args.git_branch is not None and args.git_commit is None:
66
+ name += f'-branch-{args.git_branch}'
67
+ git_integration['git_branch'] = args.git_branch
68
+ if args.git_commit is not None:
69
+ name += f'-commit-{args.git_commit}'
70
+ git_integration['git_commit'] = args.git_commit
71
+
72
+ command = 'cd llm-foundry'
73
+
74
+ # Checkout a specific PR if specified
75
+ if args.pr_number is not None:
76
+ name += f'-pr-{args.pr_number}'
77
+ command += f'''
78
+
79
+ git fetch origin pull/{args.pr_number}/head:pr_branch
80
+
81
+ git checkout pr_branch
82
+
83
+ '''
84
+
85
+ # Shorten name if too long
86
+ if len(name) > 56:
87
+ name = name[:56]
88
+
89
+ command += f'''
90
+
91
+ pip install --upgrade --user .[all]
92
+
93
+ export COMMON_ARGS="-v --durations=20 -m '{args.pytest_markers}'"
94
+
95
+ make test PYTEST='{args.pytest_command}' EXTRA_ARGS="$COMMON_ARGS --codeblocks"
96
+
97
+ make test-dist PYTEST='{args.pytest_command}' EXTRA_ARGS="$COMMON_ARGS" WORLD_SIZE=2
98
+
99
+ python -m coverage combine
100
+
101
+ python -m coverage report
102
+ '''
103
+
104
+ config = RunConfig(
105
+ name=name,
106
+ cluster=args.cluster,
107
+ gpu_type=args.gpu_type,
108
+ gpu_num=args.gpu_num,
109
+ image=args.image,
110
+ integrations=[git_integration],
111
+ command=command,
112
+ )
113
+
114
+ # Create run
115
+ run = create_run(config)
116
+ print(f'[GHA] Run created: {run.name}')
117
+
118
+ # Wait until run starts before fetching logs
119
+ run = wait_for_run_status(run, status='running')
120
+ start_time = time.time()
121
+ print('[GHA] Run started. Following logs...')
122
+
123
+ # Print logs
124
+ for line in follow_run_logs(run):
125
+ print(line, end='')
126
+ # Check if args.timeout seconds have elapsed
127
+ if time.time() - start_time > args.timeout:
128
+ print(
129
+ f'[GHA] Run timed out and did not complete in {args.timeout/60} minutes.'
130
+ )
131
+ run = stop_run(run)
132
+ print('[GHA] Run stopped.')
133
+ break
134
+
135
+ print('[GHA] Run completed. Waiting for run to finish...')
136
+ run = wait_for_run_status(run, status='completed')
137
+
138
+ # Fail if command exited with non-zero exit code or timed out
139
+ assert run.status == RunStatus.COMPLETED
Andromeda/.github/workflows/FUNDING.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # These are supported funding model platforms
2
+
3
+ github: [kyegomez]
4
+ patreon: # Replace with a single Patreon username
5
+ open_collective: # Replace with a single Open Collective username
6
+ ko_fi: # Replace with a single Ko-fi username
7
+ tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8
+ community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9
+ liberapay: # Replace with a single Liberapay username
10
+ issuehunt: # Replace with a single IssueHunt username
11
+ otechie: # Replace with a single Otechie username
12
+ lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13
+ custom: #Nothing
Andromeda/.github/workflows/code-quality.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Code Quality Checks
2
+ on:
3
+ push:
4
+ branches:
5
+ - main
6
+ - release/**
7
+ pull_request:
8
+ branches:
9
+ - main
10
+ - release/**
11
+ workflow_call:
12
+ workflow_dispatch:
13
+ # Cancel old runs when a new commit is pushed to the same branch if not on main or dev
14
+ concurrency:
15
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
16
+ cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
17
+ defaults:
18
+ run:
19
+ working-directory: .
20
+ jobs:
21
+ code-quality:
22
+ runs-on: ubuntu-20.04
23
+ timeout-minutes: 10
24
+ strategy:
25
+ matrix:
26
+ python_version:
27
+ - '3.8'
28
+ - '3.9'
29
+ - '3.10'
30
+ pip_deps:
31
+ - '[dev]'
32
+ steps:
33
+ - uses: actions/checkout@v3
34
+ - uses: actions/setup-python@v4
35
+ with:
36
+ python-version: ${{ matrix.python_version }}
37
+ - name: Setup
38
+ run: |
39
+ set -ex
40
+ python -m pip install --upgrade 'pip<23' wheel
41
+ python -m pip install --upgrade .${{ matrix.pip_deps }}
42
+ - name: Run checks
43
+ run: |
44
+ pre-commit run --all-files
Andromeda/.github/workflows/codeql-analysis.yml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # For most projects, this workflow file will not need changing; you simply need
2
+ # to commit it to your repository.
3
+ #
4
+ # You may wish to alter this file to override the set of languages analyzed,
5
+ # or to provide custom queries or build logic.
6
+ #
7
+ # ******** NOTE ********
8
+ # We have attempted to detect the languages in your repository. Please check
9
+ # the `language` matrix defined below to confirm you have the correct set of
10
+ # supported CodeQL languages.
11
+ #
12
+ name: 'CodeQL'
13
+
14
+ on:
15
+ push:
16
+ branches: [main]
17
+ pull_request:
18
+ # The branches below must be a subset of the branches above
19
+ branches: [main]
20
+ schedule:
21
+ - cron: '0 9 * * 1' # Every Monday at 09:00 (9:00 AM)
22
+
23
+ jobs:
24
+ analyze:
25
+ name: Analyze
26
+ runs-on: ubuntu-latest
27
+ permissions:
28
+ actions: read
29
+ contents: read
30
+ security-events: write
31
+
32
+ strategy:
33
+ fail-fast: false
34
+ matrix:
35
+ language: ['python']
36
+ # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
37
+ # Learn more about CodeQL language support at https://git.io/codeql-language-support
38
+
39
+ steps:
40
+ - name: Checkout repository
41
+ uses: actions/checkout@v2
42
+
43
+ # Initializes the CodeQL tools for scanning.
44
+ - name: Initialize CodeQL
45
+ uses: github/codeql-action/init@v2
46
+ with:
47
+ languages: ${{ matrix.language }}
48
+ # If you wish to specify custom queries, you can do so here or in a config file.
49
+ # By default, queries listed here will override any specified in a config file.
50
+ # Prefix the list here with "+" to use these queries and those in the config file.
51
+ # queries: ./path/to/local/query, your-org/your-repo/queries@main
52
+
53
+ # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
54
+ # If this step fails, then you should remove it and run the build manually (see below)
55
+ - name: Autobuild
56
+ uses: github/codeql-action/autobuild@v2
57
+
58
+ # ℹ️ Command-line programs to run using the OS shell.
59
+ # 📚 https://git.io/JvXDl
60
+
61
+ # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
62
+ # and modify them (or add more) to build your code if your project
63
+ # uses a compiled language
64
+
65
+ # - run: |
66
+ # make bootstrap
67
+ # make release
68
+
69
+ - name: Perform CodeQL Analysis
70
+ uses: github/codeql-action/analyze@v2
Andromeda/.github/workflows/coverage.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: PyTest Coverage
2
+ on:
3
+ workflow_call:
4
+ inputs:
5
+ download-path:
6
+ required: true
7
+ type: string
8
+ jobs:
9
+ coverage:
10
+ timeout-minutes: 5
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - name: Checkout Repo
14
+ uses: actions/checkout@v3
15
+ - name: Setup
16
+ run: |
17
+ set -ex
18
+ python -m pip install --upgrade 'pip<23' wheel
19
+ pip install coverage[toml]==6.5.0
20
+ - name: Download artifacts
21
+ uses: actions/download-artifact@v3
22
+ with:
23
+ path: ${{ inputs.download-path }}
24
+ - name: Generate coverage report
25
+ run: |
26
+ set -ex
27
+
28
+ # Flatten the coverage files
29
+ ls ${{ inputs.download-path }} | while read x; do mv ${{ inputs.download-path }}/$x/.coverage .coverage.$x; done
30
+
31
+ python -m coverage combine
32
+ python -m coverage report
Andromeda/.github/workflows/docker.yaml ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Docker
2
+ on:
3
+ push:
4
+ branches:
5
+ - main
6
+ workflow_dispatch: {}
7
+ jobs:
8
+ docker-build:
9
+ runs-on: ubuntu-latest
10
+ if: github.repository_owner == 'mosaicml'
11
+ strategy:
12
+ matrix:
13
+ include:
14
+ - name: '1.13.1_cu117'
15
+ base_image: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
16
+ - name: '2.0.1_cu118'
17
+ base_image: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04
18
+
19
+ steps:
20
+ - name: Maximize Build Space on Worker
21
+ uses: easimon/maximize-build-space@v4
22
+ with:
23
+ overprovision-lvm: true
24
+ remove-dotnet: true
25
+ remove-android: true
26
+ remove-haskell: true
27
+
28
+ - name: Checkout
29
+ uses: actions/checkout@v3
30
+
31
+ - name: Setup QEMU
32
+ uses: docker/setup-qemu-action@v2
33
+
34
+ - name: Setup Docker Buildx
35
+ uses: docker/setup-buildx-action@v2
36
+
37
+ - name: Login to DockerHub
38
+ uses: docker/login-action@v2
39
+ with:
40
+ username: ${{ secrets.DOCKER_HUB_USERNAME }}
41
+ password: ${{ secrets.DOCKER_HUB_PASSWORD }}
42
+
43
+ - name: Calculate Docker Image Variables
44
+ run: |
45
+ set -euxo pipefail
46
+
47
+ ###################
48
+ # Calculate the tag
49
+ ###################
50
+ GIT_SHA=$(echo ${{ github.sha }} | cut -c1-7)
51
+ echo "IMAGE_TAG=${GIT_SHA}" >> ${GITHUB_ENV}
52
+
53
+ - name: Build and Push the Docker Image
54
+ uses: docker/build-push-action@v3
55
+ with:
56
+ context: .
57
+ tags: mosaicml/llm-foundry:${{ matrix.name }}-latest,
58
+ mosaicml/llm-foundry:${{ matrix.name }}-${{ env.IMAGE_TAG }}
59
+ push: true
60
+ cache-from: type=registry,ref=mosaicml/llm-foundry:${{ matrix.name }}-buildcache
61
+ cache-to: type=registry,ref=mosaicml/llm-foundry:${{ matrix.name }}-buildcache,mode=max
62
+ build-args: BASE_IMAGE=${{ matrix.base_image }}
Andromeda/.github/workflows/pr-cpu.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: PR CPU tests
2
+ on:
3
+ push:
4
+ branches:
5
+ - main
6
+ - release/*
7
+ pull_request:
8
+ branches:
9
+ - main
10
+ - release/*
11
+ workflow_dispatch:
12
+ # Cancel old runs when a new commit is pushed to the same branch if not on main or dev
13
+ concurrency:
14
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
15
+ cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
16
+ jobs:
17
+ pytest-cpu:
18
+ uses: ./.github/workflows/pytest-cpu.yaml
19
+ strategy:
20
+ matrix:
21
+ include:
22
+ - name: 'cpu-latest'
23
+ container: mosaicml/pytorch:latest_cpu # mosaicml/pytorch:1.13.1_cpu-python3.10-ubuntu20.04
24
+ markers: 'not gpu'
25
+ pytest_command: 'coverage run -m pytest'
26
+ - name: 'cpu-2.0.1'
27
+ container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
28
+ markers: 'not gpu'
29
+ pytest_command: 'coverage run -m pytest'
30
+ name: ${{ matrix.name }}
31
+ if: github.repository_owner == 'mosaicml'
32
+ with:
33
+ container: ${{ matrix.container }}
34
+ name: ${{ matrix.name }}
35
+ pytest-command: ${{ matrix.pytest_command }}
36
+ pytest-markers: ${{ matrix.markers }}
37
+ coverage:
38
+ uses: ./.github/workflows/coverage.yaml
39
+ name: Coverage Results
40
+ if: github.repository_owner == 'mosaicml'
41
+ needs: [pytest-cpu]
42
+ with:
43
+ download-path: artifacts
Andromeda/.github/workflows/pr-gpu.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: PR GPU tests
2
+ on:
3
+ push:
4
+ branches:
5
+ - main
6
+ - release/*
7
+ pull_request_target:
8
+ branches:
9
+ - main
10
+ - release/**
11
+ workflow_dispatch:
12
+ # Cancel old runs when a new commit is pushed to the same branch if not on main or dev
13
+ concurrency:
14
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
15
+ cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
16
+ jobs:
17
+ pytest-gpu:
18
+ uses: ./.github/workflows/pytest-gpu.yaml
19
+ strategy:
20
+ matrix:
21
+ include:
22
+ - name: 'gpu-latest'
23
+ container: mosaicml/pytorch:latest # mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
24
+ markers: 'gpu'
25
+ pytest_command: 'coverage run -m pytest'
26
+ - name: 'gpu-2.0.1'
27
+ container: mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04
28
+ markers: 'gpu'
29
+ pytest_command: 'coverage run -m pytest'
30
+ name: ${{ matrix.name }}
31
+ if: github.repository_owner == 'mosaicml'
32
+ with:
33
+ container: ${{ matrix.container }}
34
+ mcloud-timeout: 1200
35
+ name: ${{ matrix.name }}
36
+ pytest-command: ${{ matrix.pytest_command }}
37
+ pytest-markers: ${{ matrix.markers }}
38
+ python-version: 3.9
39
+ secrets:
40
+ mcloud-api-key: ${{ secrets.MCLOUD_API_KEY }}
Andromeda/.github/workflows/pytest-cpu.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Pytest CPU
2
+ on:
3
+ workflow_call:
4
+ inputs:
5
+ container:
6
+ required: true
7
+ type: string
8
+ name:
9
+ required: true
10
+ type: string
11
+ pytest-command:
12
+ required: true
13
+ type: string
14
+ pytest-markers:
15
+ required: true
16
+ type: string
17
+ jobs:
18
+ pytest-cpu:
19
+ timeout-minutes: 30
20
+ runs-on: ubuntu-latest
21
+ container: ${{ inputs.container }}
22
+ steps:
23
+ - name: Checkout Repo
24
+ uses: actions/checkout@v3
25
+ - name: Setup
26
+ run: |
27
+ set -ex
28
+ export PATH=/composer-python:$PATH
29
+ python -m pip install --upgrade 'pip<23' wheel
30
+ python -m pip install --upgrade .[dev]
31
+ - name: Run Tests
32
+ id: tests
33
+ run: |
34
+ set -ex
35
+ export PATH=/composer-python:$PATH
36
+ export COMMON_ARGS="-v --durations=20 -m '${{ inputs.pytest-markers }}'"
37
+
38
+ # Necessary to run git diff for doctests
39
+ git config --global --add safe.directory /__w/llm-foundry/llm-foundry
40
+
41
+ make test PYTEST='${{ inputs.pytest-command }}' EXTRA_ARGS="$COMMON_ARGS --codeblocks"
42
+ # make test-dist PYTEST='${{ inputs.pytest-command }}' EXTRA_ARGS="$COMMON_ARGS" WORLD_SIZE=2
43
+
44
+ python -m coverage combine
45
+ - uses: actions/upload-artifact@v3
46
+ with:
47
+ name: coverage-${{ github.sha }}-${{ inputs.name }}
48
+ path: .coverage
Andromeda/.github/workflows/pytest-gpu.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Pytest GPU
2
+ on:
3
+ workflow_call:
4
+ inputs:
5
+ container:
6
+ required: true
7
+ type: string
8
+ mcloud-timeout:
9
+ required: false
10
+ type: number
11
+ default: 1800
12
+ name:
13
+ required: true
14
+ type: string
15
+ pytest-command:
16
+ required: true
17
+ type: string
18
+ pytest-markers:
19
+ required: true
20
+ type: string
21
+ python-version:
22
+ required: false
23
+ type: string
24
+ default: 3.9
25
+ secrets:
26
+ mcloud-api-key:
27
+ required: true
28
+ jobs:
29
+ pytest-gpu:
30
+ timeout-minutes: 60 # ${{ inputs.gha-timeout }} for some reason not able to turn this into an input
31
+ runs-on: ubuntu-latest
32
+ env:
33
+ MOSAICML_API_KEY: ${{ secrets.mcloud-api-key }}
34
+ steps:
35
+ - name: Checkout Repo
36
+ uses: actions/checkout@v3
37
+ - name: Setup Python
38
+ uses: actions/setup-python@v4
39
+ with:
40
+ python-version: ${{ inputs.python-version }}
41
+ - name: Cache pip
42
+ uses: actions/cache@v3
43
+ with:
44
+ # This path is specific to Ubuntu
45
+ path: ~/.cache/pip
46
+ # Look to see if there is a cache hit for the corresponding requirements file
47
+ key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
48
+ restore-keys: |
49
+ ${{ runner.os }}-pip-
50
+ ${{ runner.os }}-
51
+ - name: Setup MCLI
52
+ run: |
53
+ set -ex
54
+ python -m pip install mosaicml-cli
55
+ mcli init --mcloud
56
+ mcli version
57
+ - name: Submit Run
58
+ id: tests
59
+ run: |
60
+ set -ex
61
+
62
+ PR_NUMBER="$(jq --raw-output .pull_request.number "$GITHUB_EVENT_PATH")"
63
+ REF_ARGS=""
64
+
65
+ # Use the PR number if it exists, commit SHA for protected branches and the branch name otherwise
66
+ if [ -z "$PR_NUMBER" ] || [ "$PR_NUMBER" = "null" ]; then
67
+ if [[ "$GITHUB_REF" =~ "refs/heads/main" || "$GITHUB_REF" =~ "refs/heads/release" ]]; then
68
+ REF_ARGS="--git_commit $GITHUB_SHA"
69
+ else
70
+ REF_ARGS="--git_branch $GITHUB_REF_NAME"
71
+ fi
72
+ else
73
+ REF_ARGS="--pr_number $PR_NUMBER"
74
+ fi
75
+
76
+ python .github/mcp/mcp_pytest.py \
77
+ --image '${{ inputs.container }}' \
78
+ --pytest_markers '${{ inputs.pytest-markers }}' \
79
+ --pytest_command '${{ inputs.pytest-command }}' \
80
+ --timeout ${{ inputs.mcloud-timeout }} ${REF_ARGS}
Andromeda/.github/workflows/python-publish.yml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This workflow will upload a Python Package using Twine when a release is created
2
+ # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3
+
4
+ # This workflow uses actions that are not certified by GitHub.
5
+ # They are provided by a third-party and are governed by
6
+ # separate terms of service, privacy policy, and support
7
+ # documentation.
8
+
9
+ name: Upload Python Package
10
+
11
+ on:
12
+ release:
13
+ types: [published]
14
+
15
+ permissions:
16
+ contents: read
17
+
18
+ jobs:
19
+ deploy:
20
+
21
+ runs-on: ubuntu-latest
22
+
23
+ steps:
24
+ - uses: actions/checkout@v3
25
+ - name: Set up Python
26
+ uses: actions/setup-python@v3
27
+ with:
28
+ python-version: '3.x'
29
+ - name: Install dependencies
30
+ run: |
31
+ python -m pip install --upgrade pip
32
+ pip install build
33
+ - name: Build package
34
+ run: python -m build
35
+ - name: Publish package
36
+ uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37
+ with:
38
+ user: __token__
39
+ password: ${{ secrets.PYPI_API_TOKEN }}
Andromeda/.github/workflows/release.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Release
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - 'v*'
7
+ workflow_dispatch:
8
+
9
+ jobs:
10
+ code-quality:
11
+ uses: ./.github/workflows/code-quality.yaml
12
+
13
+ pypi-packaging:
14
+ name: Build and Publish llm-foundry PyPI Package
15
+ needs:
16
+ - code-quality
17
+ runs-on: ubuntu-latest
18
+ steps:
19
+ - name: Checkout source
20
+ uses: actions/checkout@v3
21
+
22
+ - name: Set up Python
23
+ uses: actions/setup-python@v3
24
+ with:
25
+ python-version: '3.9'
26
+
27
+ - name: Build source and wheel distributions
28
+ run: |
29
+ if [[ "${{ github.ref }}" =~ refs\/tags\/v ]]; then
30
+ PYPI_PACKAGE_NAME="llm-foundry"
31
+ else
32
+ PYPI_PACKAGE_NAME="llm-foundry-test-$(date +%Y%m%d%H%M%S)"
33
+ fi
34
+
35
+ # Remove the peft, xentropy-cuda-lib and triton-pre-mlir dependencies as PyPI does not
36
+ # support direct installs. The error message for importing PEFT, FusedCrossEntropy,
37
+ # and flash_attn_triton gives instructions on how to install if a user tries to use it
38
+ # without this dependency.
39
+ sed '/xentropy-cuda-lib@git+https:\/\/github.com\/HazyResearch\/flash-attention.git@.*/d' -i setup.py
40
+ sed '/triton-pre-mlir@git+https:\/\/github.com\/vchiley\/triton.git@.*/d' -i setup.py
41
+ sed '/peft@git+https:\/\/github.com\/huggingface\/peft.git.*/d' -i setup.py
42
+
43
+ python -m pip install --upgrade build twine
44
+ python -m build
45
+ twine check --strict dist/*
46
+
47
+ - name: Publish 📦 to PyPI
48
+ uses: pypa/gh-action-pypi-publish@release/v1
49
+ if: contains(github.ref, 'refs/tags/v')
50
+ with:
51
+ user: __token__
52
+ password: ${{ secrets.PROD_PYPI_API_TOKEN }}
53
+
54
+ - name: Publish distribution 📦 to Test PyPI
55
+ uses: pypa/gh-action-pypi-publish@release/v1
56
+ if: contains(github.ref, 'refs/heads/') || contains(github.ref, 'refs/pull/')
57
+ with:
58
+ user: __token__
59
+ password: ${{ secrets.TEST_PYPI_API_TOKEN }}
60
+ repository_url: https://test.pypi.org/legacy/
Andromeda/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .DS_Store
2
+ dist
Andromeda/Andromeda/README.md ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Transformer Model Technical Research Analysis
2
+
3
+ This document provides an analysis of the hyperparameters and configurations of the given Transformer model, focusing on dimensions, depth, and heads, as well as an architectural overview of their meanings and use cases.
4
+
5
+ ## Model Configuration
6
+
7
+ ```python
8
+ model = Transformer(
9
+ num_tokens=20000,
10
+ max_seq_len=8192,
11
+ use_abs_pos_emb = False,
12
+ attn_layers = Decoder(
13
+ dim=512,
14
+ depth=6,
15
+ heads=8,
16
+ alibi_pos_bias=True,
17
+ alibi_num_heads=4,
18
+ rotary_xpos=True,
19
+ attn_flash = True,
20
+ deepnorm=True,
21
+ shift_tokens=1,
22
+ attn_one_kv_head = True,
23
+ )
24
+ )
25
+ ```
26
+
27
+ ### Hyperparameters
28
+
29
+ 1. **num_tokens**: The number of unique tokens in the input vocabulary. In this case, the model is configured to handle 20,000 unique tokens.
30
+
31
+ 2. **max_seq_len**: The maximum sequence length that the model can handle. The current configuration supports sequences of up to 8,192 tokens.
32
+
33
+ 3. **use_abs_pos_emb**: A boolean flag indicating whether to use absolute positional embeddings. The model is configured not to use absolute positional embeddings (`False`).
34
+
35
+ 4. **dim**: The dimensionality of the input embeddings and the internal representations within the Transformer layers. The model uses a dimensionality of 512.
36
+
37
+ 5. **depth**: The number of Transformer layers (or blocks) in the model. This model has a depth of 6, meaning it has 6 layers.
38
+
39
+ 6. **heads**: The number of attention heads in the multi-head self-attention mechanism. This model uses 8 attention heads.
40
+
41
+ ### Additional Configurations
42
+
43
+ - **alibi_pos_bias**: A boolean flag indicating whether to use the Alibi position bias mechanism. The model is configured to use Alibi position bias (`True`).
44
+
45
+ - **alibi_num_heads**: The number of Alibi attention heads to use. The model is configured to use 4 Alibi attention heads.
46
+
47
+ - **rotary_xpos**: A boolean flag indicating whether to use the rotary positional encoding mechanism. The model is configured to use rotary positional encoding (`True`).
48
+
49
+ - **attn_flash**: A boolean flag indicating whether to use the Flash attention mechanism. The model is configured to use Flash attention (`True`).
50
+
51
+ - **deepnorm**: A boolean flag indicating whether to use deep normalization. The model is configured to use deep normalization (`True`).
52
+
53
+ - **shift_tokens**: The number of tokens to shift during training to form the target sequence. The model is configured to shift by 1 token (`1`).
54
+
55
+ - **attn_one_kv_head**: A boolean flag indicating whether to use one key-value head for attention instead of multiple heads. The model is configured to use one key-value head (`True`).
56
+
57
+ ## Architectural Overview
58
+
59
+ ### Dimensions
60
+
61
+ - **Input Embedding Dimension (dim)**: This hyperparameter defines the size of the input embeddings and the internal representations within the Transformer layers. A larger dimensionality can capture more complex relationships between tokens but may require more computational resources.
62
+
63
+ ### Depth
64
+
65
+ - **Number of Transformer Layers (depth)**: This hyperparameter defines the number of Transformer layers (or blocks) in the model. Each layer consists of a multi-head self-attention mechanism followed by a position-wise feed-forward network. Increasing the depth allows the model to capture more complex and hierarchical relationships between tokens but may also increase the risk of overfitting and require more computational resources.
66
+
67
+ ### Heads
68
+
69
+ - **Number of Attention Heads (heads)**: This hyperparameter defines the number of attention heads in the multi-head self-attention mechanism. Each head processes the input sequence independently and captures different aspects of the relationships between tokens. The outputs of all heads are then concatenated and transformed to produce the final output. Increasing the number of attention heads can help the model capture more diverse and fine-grained relationships between tokens but may also increase computational complexity and memory requirements.
70
+
71
+ ## Benefits and Consequences of Increasing Hyperparameters
72
+
73
+ ### Dimensions
74
+
75
+ **Benefits:**
76
+
77
+ - Better representation: Increasing the dimensionality of the input embeddings and internal representations allows the model to capture more complex relationships between tokens.
78
+
79
+ - Improved model expressiveness: A higher dimensionality may enable the model to learn more expressive features, leading to better performance on complex tasks.
80
+
81
+ **Consequences:**
82
+
83
+ - Computational complexity: Increasing the dimensionality will increase the computational complexity of the model, which may lead to longer training and inference times.
84
+
85
+ - Memory requirements: A higher dimensionality will increase the memory requirements of the model, potentially limiting its applicability on resource-constrained hardware.
86
+
87
+ - Risk of overfitting: Models with a higher dimensionality may be more prone to overfitting, especially if the size of the training dataset is small.
88
+
89
+ ### Depth
90
+
91
+ **Benefits:**
92
+
93
+ - Hierarchical representation: Increasing the depth of the model allows it to capture more complex and hierarchical relationships between tokens, which can lead to improved performance on tasks that require understanding long-range dependencies.
94
+
95
+ - Enhanced feature extraction: Deeper models can extract features at different levels of abstraction, potentially improving their ability to generalize to new data.
96
+
97
+ **Consequences:**
98
+
99
+ - Computational complexity: Increasing the depth will increase the computational complexity of the model, leading to longer training and inference times.
100
+
101
+ - Memory requirements: A deeper model will require more memory, potentially limiting its applicability on resource-constrained hardware.
102
+
103
+ - Risk of overfitting: Deeper models may be more prone to overfitting, especially if the size of the training dataset is small.
104
+
105
+ - Vanishing/exploding gradients: Deeper models may suffer from vanishing or exploding gradients during training, making it harder to optimize the model. Techniques such as layer normalization or skip connections can help mitigate this issue.
106
+
107
+ ### Heads
108
+
109
+ **Benefits:**
110
+
111
+ - Diverse attention: Increasing the number of attention heads allows the model to capture more diverse and fine-grained relationships between tokens, which can improve its ability to understand the input data.
112
+
113
+ - Robustness: Multi-head attention can make the model more robust, as each head can focus on different aspects of the input data.
114
+
115
+ **Consequences:**
116
+
117
+ - Computational complexity: Increasing the number of attention heads will increase the computational complexity of the model, leading to longer training and inference times.
118
+
119
+ - Memory requirements: A model with more attention heads will require more memory, potentially limiting its applicability on resource-constrained hardware.
120
+
121
+ - Diminishing returns: There may be diminishing returns when increasing the number of attention heads beyond a certain point, as the model may already be capturing most of the relevant information with fewer heads.
Andromeda/Andromeda/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # from Andromeda.train import Train
2
+ from Andromeda.model import AndromedaTokenizer, Andromeda
3
+ from Andromeda.train import Train, train
Andromeda/Andromeda/configs.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from Andromeda.model import AndromedaEmbedding, Andromeda
2
+
3
+
4
+ Andromeda1Billion = Andromeda(
5
+ num_tokens=25000,
6
+ max_seq_len=4192,
7
+ dim=2048,
8
+ depth=16,
9
+ dim_head=128,
10
+ heads=8,
11
+ use_abs_pos_emb=False,
12
+ alibi_pos_bias=True,
13
+ alibi_num_heads=4,
14
+ rotary_xpos=True,
15
+ attn_flash=True,
16
+ # shift_tokens=1,
17
+ attn_one_kv_head=True,
18
+ qk_norm=True,
19
+ attn_qk_norm=True,
20
+ attn_qk_norm_dim_scale=True,
21
+ embedding_provider=AndromedaEmbedding()
22
+ )
23
+
24
+
25
+
26
+ Andromeda3Billion = Andromeda(
27
+ num_tokens=50432,
28
+ max_seq_len=8192,
29
+ dim=3072,
30
+ depth=24,
31
+ dim_head=128,
32
+ heads=12,
33
+ use_abs_pos_emb=False,
34
+ alibi_pos_bias=True,
35
+ alibi_num_heads=6,
36
+ rotary_xpos=True,
37
+ attn_flash=True,
38
+ shift_tokens=1,
39
+ attn_one_kv_head=True,
40
+ qk_norm=True,
41
+ attn_qk_norm=True,
42
+ attn_qk_norm_dim_scale=True,
43
+ embedding_provider=AndromedaEmbedding()
44
+ )
45
+
46
+
47
+
48
+ Andromeda7Billion = Andromeda(
49
+ num_tokens=50432,
50
+ max_seq_len=8192,
51
+ dim=4096,
52
+ depth=32,
53
+ dim_head=128,
54
+ heads=16,
55
+ use_abs_pos_emb=False,
56
+ alibi_pos_bias=True,
57
+ alibi_num_heads=8,
58
+ rotary_xpos=True,
59
+ attn_flash=True,
60
+ shift_tokens=1,
61
+ attn_one_kv_head=True,
62
+ qk_norm=True,
63
+ attn_qk_norm=True,
64
+ attn_qk_norm_dim_scale=True,
65
+ embedding_provider=AndromedaEmbedding()
66
+ )
67
+
68
+ Andromeda10Billion = Andromeda(
69
+ num_tokens=50432,
70
+ max_seq_len=8192,
71
+ dim=5120,
72
+ depth=32,
73
+ dim_head=128,
74
+ heads=20,
75
+ use_abs_pos_emb=False,
76
+ alibi_pos_bias=True,
77
+ alibi_num_heads=4,
78
+ rotary_xpos=True,
79
+ attn_flash=True,
80
+ shift_tokens=1,
81
+ attn_one_kv_head=True,
82
+ qk_norm=True,
83
+ attn_qk_norm=True,
84
+ attn_qk_norm_dim_scale=True,
85
+ embedding_provider=AndromedaEmbedding()
86
+ )
87
+
88
+ Andromeda15Billion = Andromeda(
89
+ num_tokens=50432,
90
+ max_seq_len=8192,
91
+ dim=6144,
92
+ depth=40,
93
+ dim_head=128,
94
+ heads=24,
95
+ use_abs_pos_emb=False,
96
+ alibi_pos_bias=True,
97
+ alibi_num_heads=4,
98
+ rotary_xpos=True,
99
+ attn_flash=True,
100
+ shift_tokens=1,
101
+ attn_one_kv_head=True,
102
+ qk_norm=True,
103
+ attn_qk_norm=True,
104
+ attn_qk_norm_dim_scale=True,
105
+ embedding_provider=AndromedaEmbedding()
106
+ )
107
+
108
+ Andromeda20Billion = Andromeda(
109
+ num_tokens=50432,
110
+ max_seq_len=8192,
111
+ dim=7168,
112
+ depth=48,
113
+ dim_head=128,
114
+ heads=28,
115
+ use_abs_pos_emb=False,
116
+ alibi_pos_bias=True,
117
+ alibi_num_heads=4,
118
+ rotary_xpos=True,
119
+ attn_flash=True,
120
+ shift_tokens=1,
121
+ attn_one_kv_head=True,
122
+ qk_norm=True,
123
+ attn_qk_norm=True,
124
+ attn_qk_norm_dim_scale=True,
125
+ embedding_provider=AndromedaEmbedding()
126
+ )
127
+
128
+ #to GPT like 176Billion Parameters 122888 dimension, 96 depth, 96 heads, attn dim head 128
Andromeda/Andromeda/core/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from packaging import version
3
+
4
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
5
+ from einops._torch_specific import allow_ops_in_compiled_graph
6
+ allow_ops_in_compiled_graph()
7
+
8
+
Andromeda/Andromeda/core/attend.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ import torch.nn.functional as F
6
+
7
+ from collections import namedtuple
8
+ from functools import wraps
9
+ from packaging import version
10
+ from dataclasses import dataclass
11
+ from einops import rearrange
12
+
13
+ from Andromeda.core.flash import attention
14
+
15
+ # from flash import FlashAttention
16
+
17
+ # constants
18
+
19
+ EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
20
+
21
+ @dataclass
22
+ class Intermediates:
23
+ qk_similarities: Tensor = None
24
+ pre_softmax_attn: Tensor = None
25
+ post_softmax_attn: Tensor = None
26
+
27
+ # helpers
28
+
29
+ def exists(val):
30
+ return val is not None
31
+
32
+ def default(val, d):
33
+ return val if exists(val) else d
34
+
35
+ def once(fn):
36
+ called = False
37
+ @wraps(fn)
38
+ def inner(x):
39
+ nonlocal called
40
+ if called:
41
+ return
42
+ called = True
43
+ return fn(x)
44
+ return inner
45
+
46
+ print_once = once(print)
47
+
48
+ # main class
49
+
50
+ class Attend(nn.Module):
51
+ def __init__(
52
+ self,
53
+ *,
54
+ dropout = 0.,
55
+ causal = False,
56
+ heads = None,
57
+ talking_heads = False,
58
+ scale = None,
59
+ qk_norm = False,
60
+ flash = False,
61
+ triton = False,
62
+ ):
63
+ super().__init__()
64
+ self.scale = scale
65
+ self.qk_norm = qk_norm
66
+ self.causal = causal
67
+ self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax
68
+
69
+ self.dropout = dropout
70
+ self.attn_dropout = nn.Dropout(dropout)
71
+
72
+ # talking heads
73
+
74
+ assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
75
+
76
+ self.talking_heads = talking_heads
77
+ if talking_heads:
78
+ self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
79
+ self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
80
+
81
+ # flash attention
82
+ self.flash = flash
83
+ assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
84
+
85
+ # determine efficient attention configs for cuda and cpu
86
+ self.cpu_config = EfficientAttentionConfig(True, True, True)
87
+ self.cuda_config = None
88
+
89
+ if not torch.cuda.is_available() or not flash:
90
+ return
91
+
92
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
93
+
94
+ if device_properties.major == 8 and device_properties.minor == 0:
95
+ print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
96
+ self.cuda_config = EfficientAttentionConfig(True, False, False)
97
+ else:
98
+ print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
99
+ self.cuda_config = EfficientAttentionConfig(False, True, True)
100
+
101
+ def flash_attn(
102
+ self,
103
+ q, k, v,
104
+ mask = None,
105
+ attn_bias = None
106
+ ):
107
+ batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
108
+
109
+ # Recommended for multi-query single-key-value attention by Tri Dao
110
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
111
+
112
+ if k.ndim == 3:
113
+ k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
114
+
115
+ if v.ndim == 3:
116
+ v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
117
+
118
+ # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
119
+
120
+ if self.qk_norm:
121
+ default_scale = q.shape[-1] ** -0.5
122
+ q = q * (default_scale / self.scale)
123
+
124
+ # Check if mask exists and expand to compatible shape
125
+ # The mask is B L, so it would have to be expanded to B H N L
126
+
127
+ causal = self.causal
128
+
129
+ if exists(mask):
130
+ assert mask.ndim == 4
131
+ mask = mask.expand(batch, heads, q_len, k_len)
132
+
133
+ # manually handle causal mask, if another mask was given
134
+
135
+ if causal:
136
+ causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
137
+ mask = mask | causal_mask
138
+ causal = False
139
+
140
+ # handle alibi positional bias
141
+ # convert from bool to float
142
+
143
+ if exists(attn_bias):
144
+ attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, -1, -1, -1)
145
+
146
+ # if mask given, the mask would already contain the causal mask from above logic
147
+ # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
148
+
149
+ mask_value = -torch.finfo(q.dtype).max
150
+
151
+ if exists(mask):
152
+ attn_bias = attn_bias.masked_fill(mask, mask_value // 2)
153
+ elif causal:
154
+ causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
155
+ attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
156
+ causal = False
157
+
158
+ # scaled_dot_product_attention handles attn_mask either as bool or additive bias
159
+ # make it an additive bias here
160
+
161
+ mask = attn_bias
162
+
163
+ # Check if there is a compatible device for flash attention
164
+
165
+ config = self.cuda_config if is_cuda else self.cpu_config
166
+
167
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
168
+
169
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
170
+ out = F.scaled_dot_product_attention(
171
+ q, k, v,
172
+ attn_mask = mask,
173
+ dropout_p = self.dropout if self.training else 0.,
174
+ is_causal = causal
175
+ )
176
+
177
+ return out, Intermediates()
178
+
179
+ def forward(
180
+ self,
181
+ q, k, v,
182
+ mask = None,
183
+ attn_bias = None,
184
+ prev_attn = None
185
+ ):
186
+ """
187
+ einstein notation
188
+ b - batch
189
+ h - heads
190
+ n, i, j - sequence length (base sequence length, source, target)
191
+ d - feature dimension
192
+ """
193
+
194
+ n, device = q.shape[-2], q.device
195
+
196
+ scale = default(self.scale, q.shape[-1] ** -0.5)
197
+
198
+ if self.flash:
199
+ assert not exists(prev_attn), 'residual attention not compatible with flash attention'
200
+ return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
201
+ # return FlashAttention(q, k, v, mask=mask, attn_bias=attn_bias )
202
+
203
+ if self.triton:
204
+ return attention(q, k, v, self.casual, scale)
205
+
206
+ kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
207
+
208
+ dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
209
+
210
+ if exists(prev_attn):
211
+ dots = dots + prev_attn
212
+
213
+ qk_similarities = dots.clone()
214
+
215
+ if self.talking_heads:
216
+ dots = self.pre_softmax_talking_heads(dots)
217
+
218
+ if exists(attn_bias):
219
+ dots = dots + attn_bias
220
+
221
+ dtype = dots.dtype
222
+ pre_softmax_attn = dots.clone()
223
+
224
+ mask_value = -torch.finfo(dots.dtype).max
225
+
226
+ if exists(mask):
227
+ dots = dots.masked_fill(mask, mask_value)
228
+
229
+ if self.causal:
230
+ i, j = dots.shape[-2:]
231
+ causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
232
+ dots = dots.masked_fill(causal_mask, mask_value)
233
+
234
+ attn = self.attn_fn(dots, dim = -1)
235
+ attn = attn.type(dtype)
236
+
237
+ post_softmax_attn = attn.clone()
238
+
239
+ attn = self.attn_dropout(attn)
240
+
241
+ if self.talking_heads:
242
+ attn = self.post_softmax_talking_heads(attn)
243
+
244
+ out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
245
+
246
+ intermediates = Intermediates(
247
+ qk_similarities = qk_similarities,
248
+ pre_softmax_attn = pre_softmax_attn,
249
+ post_softmax_attn = post_softmax_attn
250
+ )
251
+
252
+ return out, intermediates
Andromeda/Andromeda/core/autoregressive_wrapper.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import ceil
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+
6
+ from einops import rearrange, pack, unpack
7
+
8
+ def exists(val):
9
+ return val is not None
10
+
11
+ def eval_decorator(fn):
12
+ def inner(self, *args, **kwargs):
13
+ was_training = self.training
14
+ self.eval()
15
+ out = fn(self, *args, **kwargs)
16
+ self.train(was_training)
17
+ return out
18
+ return inner
19
+
20
+ # nucleus
21
+
22
+ def top_p(logits, thres = 0.9):
23
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
24
+ cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
25
+
26
+ sorted_indices_to_remove = cum_probs > (1 - thres)
27
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
28
+ sorted_indices_to_remove[:, 0] = 0
29
+
30
+ sorted_logits[sorted_indices_to_remove] = float('-inf')
31
+ return sorted_logits.scatter(1, sorted_indices, sorted_logits)
32
+
33
+ # topk
34
+
35
+ def top_k(logits, thres = 0.9):
36
+ k = ceil((1 - thres) * logits.shape[-1])
37
+ val, ind = torch.topk(logits, k)
38
+ probs = torch.full_like(logits, float('-inf'))
39
+ probs.scatter_(1, ind, val)
40
+ return probs
41
+
42
+ # top_a
43
+
44
+ def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
45
+ probs = F.softmax(logits, dim=-1)
46
+ limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
47
+ logits[probs < limit] = float('-inf')
48
+ logits[probs >= limit] = 1
49
+ return logits
50
+
51
+ # autoregressive wrapper class
52
+
53
+ class AutoregressiveWrapper(nn.Module):
54
+ def __init__(
55
+ self,
56
+ net,
57
+ ignore_index = -100,
58
+ pad_value = 0,
59
+ mask_prob = 0.
60
+ ):
61
+ super().__init__()
62
+ self.pad_value = pad_value
63
+ self.ignore_index = ignore_index
64
+
65
+ self.net = net
66
+ self.max_seq_len = net.max_seq_len
67
+
68
+ # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
69
+ assert mask_prob < 1.
70
+ self.mask_prob = mask_prob
71
+
72
+ @torch.no_grad()
73
+ @eval_decorator
74
+ def generate(
75
+ self,
76
+ start_tokens,
77
+ seq_len,
78
+ eos_token = None,
79
+ temperature = 1.,
80
+ filter_logits_fn = top_k,
81
+ filter_thres = 0.9,
82
+ min_p_pow = 2.0,
83
+ min_p_ratio = 0.02,
84
+ **kwargs
85
+ ):
86
+
87
+ start_tokens, ps = pack([start_tokens], '* n')
88
+
89
+ b, t = start_tokens.shape
90
+
91
+ out = start_tokens
92
+
93
+ for _ in range(seq_len):
94
+ x = out[:, -self.max_seq_len:]
95
+
96
+ logits = self.net(x, **kwargs)[:, -1]
97
+
98
+ if filter_logits_fn in {top_k, top_p}:
99
+ filtered_logits = filter_logits_fn(logits, thres = filter_thres)
100
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
101
+
102
+ elif filter_logits_fn is top_a:
103
+ filtered_logits = filter_logits_fn(logits, min_p_pow = min_p_pow, min_p_ratio= min_p_ratio)
104
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
105
+
106
+ sample = torch.multinomial(probs, 1)
107
+
108
+ out = torch.cat((out, sample), dim=-1)
109
+
110
+ if exists(eos_token):
111
+ is_eos_tokens = (out == eos_token)
112
+
113
+ if is_eos_tokens.any(dim = -1).all():
114
+ # mask out everything after the eos tokens
115
+ shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
116
+ mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
117
+ out = out.masked_fill(mask, self.pad_value)
118
+ break
119
+
120
+ out = out[:, t:]
121
+
122
+ out, = unpack(out, ps, '* n')
123
+
124
+ return out
125
+
126
+ def forward(self, x, return_loss=True, **kwargs):
127
+ seq, ignore_index = x.shape[1], self.ignore_index
128
+
129
+ inp, target = x[:, :-1], x[:, 1:]
130
+
131
+ if self.mask_prob > 0.:
132
+ rand = torch.randn(inp.shape, device = x.device)
133
+ rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
134
+ num_mask = min(int(seq * self.mask_prob), seq - 1)
135
+ indices = rand.topk(num_mask, dim = -1).indices
136
+ mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
137
+ kwargs.update(self_attn_context_mask = mask)
138
+
139
+ logits = self.net(inp, **kwargs)
140
+
141
+ loss = F.cross_entropy(
142
+ rearrange(logits, 'b n c -> b c n'),
143
+ target,
144
+ ignore_index = ignore_index
145
+ )
146
+
147
+ if return_loss:
148
+ return logits, loss
149
+
150
+ return logits
Andromeda/Andromeda/core/flash.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import triton
4
+ import triton.language as tl
5
+
6
+
7
+ @triton.jit
8
+ def max_fn(x, y):
9
+ return tl.math.max(x, y)
10
+
11
+
12
+ @triton.jit
13
+ def _fwd_kernel(
14
+ Q, K, V, sm_scale,
15
+ L,
16
+ Out,
17
+ stride_qz, stride_qh, stride_qm, stride_qk,
18
+ stride_kz, stride_kh, stride_kn, stride_kk,
19
+ stride_vz, stride_vh, stride_vk, stride_vn,
20
+ stride_oz, stride_oh, stride_om, stride_on,
21
+ Z, H, N_CTX,
22
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
23
+ BLOCK_N: tl.constexpr,
24
+ IS_CAUSAL: tl.constexpr,
25
+ ):
26
+ start_m = tl.program_id(0)
27
+ off_hz = tl.program_id(1)
28
+ qvk_offset = off_hz * stride_qh
29
+ Q_block_ptr = tl.make_block_ptr(
30
+ base=Q + qvk_offset,
31
+ shape=(N_CTX, BLOCK_DMODEL),
32
+ strides=(stride_qm, stride_qk),
33
+ offsets=(start_m * BLOCK_M, 0),
34
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
35
+ order=(1, 0)
36
+ )
37
+ K_block_ptr = tl.make_block_ptr(
38
+ base=K + qvk_offset,
39
+ shape=(BLOCK_DMODEL, N_CTX),
40
+ strides=(stride_kk, stride_kn),
41
+ offsets=(0, 0),
42
+ block_shape=(BLOCK_DMODEL, BLOCK_N),
43
+ order=(0, 1)
44
+ )
45
+ V_block_ptr = tl.make_block_ptr(
46
+ base=V + qvk_offset,
47
+ shape=(N_CTX, BLOCK_DMODEL),
48
+ strides=(stride_vk, stride_vn),
49
+ offsets=(0, 0),
50
+ block_shape=(BLOCK_N, BLOCK_DMODEL),
51
+ order=(1, 0)
52
+ )
53
+ # initialize offsets
54
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
55
+ offs_n = tl.arange(0, BLOCK_N)
56
+ # initialize pointer to m and l
57
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
58
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
59
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
60
+ # scale sm_scale by log_2(e) and use
61
+ # 2^x instead of exp in the loop because CSE and LICM
62
+ # don't work as expected with `exp` in the loop
63
+ qk_scale = sm_scale * 1.44269504
64
+ # load q: it will stay in SRAM throughout
65
+ q = tl.load(Q_block_ptr)
66
+ q = (q * qk_scale).to(tl.float16)
67
+ # loop over k, v and update accumulator
68
+ lo = 0
69
+ hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
70
+ for start_n in range(lo, hi, BLOCK_N):
71
+ # -- load k, v --
72
+ k = tl.load(K_block_ptr)
73
+ v = tl.load(V_block_ptr)
74
+ # -- compute qk ---
75
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
76
+ if IS_CAUSAL:
77
+ qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
78
+ qk += tl.dot(q, k)
79
+ # -- compute scaling constant ---
80
+ m_i_new = tl.maximum(m_i, tl.max(qk, 1))
81
+ alpha = tl.math.exp2(m_i - m_i_new)
82
+ p = tl.math.exp2(qk - m_i_new[:, None])
83
+ # -- scale and update acc --
84
+ acc_scale = l_i * 0 + alpha # workaround some compiler bug
85
+ acc *= acc_scale[:, None]
86
+ acc += tl.dot(p.to(tl.float16), v)
87
+ # -- update m_i and l_i --
88
+ l_i = l_i * alpha + tl.sum(p, 1)
89
+ m_i = m_i_new
90
+ # update pointers
91
+ K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
92
+ V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
93
+ # write back l and m
94
+ acc = acc / l_i[:, None]
95
+ l_ptrs = L + off_hz * N_CTX + offs_m
96
+ tl.store(l_ptrs, m_i + tl.math.log2(l_i))
97
+ # write back O
98
+ O_block_ptr = tl.make_block_ptr(
99
+ base=Out + qvk_offset,
100
+ shape=(N_CTX, BLOCK_DMODEL),
101
+ strides=(stride_om, stride_on),
102
+ offsets=(start_m * BLOCK_M, 0),
103
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
104
+ order=(1, 0)
105
+ )
106
+ tl.store(O_block_ptr, acc.to(tl.float16))
107
+
108
+
109
+ @triton.jit
110
+ def _bwd_preprocess(
111
+ Out, DO,
112
+ Delta,
113
+ BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
114
+ ):
115
+ off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
116
+ off_n = tl.arange(0, D_HEAD)
117
+ # load
118
+ o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
119
+ do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
120
+ # compute
121
+ delta = tl.sum(o * do, axis=1)
122
+ # write-back
123
+ tl.store(Delta + off_m, delta)
124
+
125
+
126
+ @triton.jit
127
+ def _bwd_kernel(
128
+ Q, K, V, sm_scale, Out, DO,
129
+ DQ, DK, DV,
130
+ L,
131
+ D,
132
+ stride_qz, stride_qh, stride_qm, stride_qk,
133
+ stride_kz, stride_kh, stride_kn, stride_kk,
134
+ stride_vz, stride_vh, stride_vk, stride_vn,
135
+ Z, H, N_CTX,
136
+ num_block,
137
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
138
+ BLOCK_N: tl.constexpr,
139
+ CAUSAL: tl.constexpr,
140
+ ):
141
+ off_hz = tl.program_id(0)
142
+ off_z = off_hz // H
143
+ off_h = off_hz % H
144
+ qk_scale = sm_scale * 1.44269504
145
+ # offset pointers for batch/head
146
+ Q += off_z * stride_qz + off_h * stride_qh
147
+ K += off_z * stride_qz + off_h * stride_qh
148
+ V += off_z * stride_qz + off_h * stride_qh
149
+ DO += off_z * stride_qz + off_h * stride_qh
150
+ DQ += off_z * stride_qz + off_h * stride_qh
151
+ DK += off_z * stride_qz + off_h * stride_qh
152
+ DV += off_z * stride_qz + off_h * stride_qh
153
+ for start_n in range(0, num_block):
154
+ if CAUSAL:
155
+ lo = start_n * BLOCK_M
156
+ else:
157
+ lo = 0
158
+ # initialize row/col offsets
159
+ offs_qm = lo + tl.arange(0, BLOCK_M)
160
+ offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
161
+ offs_m = tl.arange(0, BLOCK_N)
162
+ offs_k = tl.arange(0, BLOCK_DMODEL)
163
+ # initialize pointers to value-like data
164
+ q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
165
+ k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
166
+ v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
167
+ do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
168
+ dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
169
+ # pointer to row-wise quantities in value-like data
170
+ D_ptrs = D + off_hz * N_CTX
171
+ l_ptrs = L + off_hz * N_CTX
172
+ # initialize dv amd dk
173
+ dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
174
+ dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
175
+ # k and v stay in SRAM throughout
176
+ k = tl.load(k_ptrs)
177
+ v = tl.load(v_ptrs)
178
+ # loop over rows
179
+ for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
180
+ offs_m_curr = start_m + offs_m
181
+ # load q, k, v, do on-chip
182
+ q = tl.load(q_ptrs)
183
+ # recompute p = softmax(qk, dim=-1).T
184
+ if CAUSAL:
185
+ qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
186
+ else:
187
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
188
+ qk += tl.dot(q, tl.trans(k))
189
+ qk *= qk_scale
190
+ l_i = tl.load(l_ptrs + offs_m_curr)
191
+ p = tl.math.exp2(qk - l_i[:, None])
192
+ # compute dv
193
+ do = tl.load(do_ptrs)
194
+ dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
195
+ # compute dp = dot(v, do)
196
+ Di = tl.load(D_ptrs + offs_m_curr)
197
+ dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
198
+ dp += tl.dot(do, tl.trans(v))
199
+ # compute ds = p * (dp - delta[:, None])
200
+ ds = p * dp * sm_scale
201
+ # compute dk = dot(ds.T, q)
202
+ dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
203
+ # compute dq
204
+ dq = tl.load(dq_ptrs)
205
+ dq += tl.dot(ds.to(Q.dtype.element_ty), k)
206
+ tl.store(dq_ptrs, dq)
207
+ # increment pointers
208
+ dq_ptrs += BLOCK_M * stride_qm
209
+ q_ptrs += BLOCK_M * stride_qm
210
+ do_ptrs += BLOCK_M * stride_qm
211
+ # write-back
212
+ dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
213
+ dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
214
+ tl.store(dv_ptrs, dv)
215
+ tl.store(dk_ptrs, dk)
216
+
217
+
218
+ empty = torch.empty(128, device="cuda")
219
+
220
+
221
+ class _attention(torch.autograd.Function):
222
+
223
+ @staticmethod
224
+ def forward(ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal, sm_scale):
225
+ # shape constraints
226
+ Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
227
+ assert Lq == Lk and Lk == Lv
228
+ assert Lk in {16, 32, 64, 128}
229
+ o = torch.empty_like(q)
230
+ BLOCK_M = 128
231
+ BLOCK_N = 64
232
+ grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
233
+ L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
234
+
235
+ num_warps = 4 if Lk <= 64 else 8
236
+ _fwd_kernel[grid](
237
+ q, k, v, sm_scale,
238
+ L,
239
+ o,
240
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
241
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
242
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
243
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
244
+ q.shape[0], q.shape[1], q.shape[2],
245
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
246
+ IS_CAUSAL=causal,
247
+ num_warps=num_warps,
248
+ num_stages=4)
249
+
250
+ ctx.save_for_backward(q, k, v, o, L)
251
+ ctx.grid = grid
252
+ ctx.sm_scale = sm_scale
253
+ ctx.BLOCK_DMODEL = Lk
254
+ ctx.causal = causal
255
+ return o
256
+
257
+ @staticmethod
258
+ def backward(ctx, do):
259
+ BLOCK = 128
260
+ q, k, v, o, L = ctx.saved_tensors
261
+ do = do.contiguous()
262
+ dq = torch.zeros_like(q, dtype=torch.float32)
263
+ dk = torch.empty_like(k)
264
+ dv = torch.empty_like(v)
265
+ delta = torch.empty_like(L)
266
+ _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
267
+ o, do,
268
+ delta,
269
+ BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
270
+ )
271
+ _bwd_kernel[(ctx.grid[1],)](
272
+ q, k, v, ctx.sm_scale,
273
+ o, do,
274
+ dq, dk, dv,
275
+ L, delta,
276
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
277
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
278
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
279
+ q.shape[0], q.shape[1], q.shape[2],
280
+ ctx.grid[0],
281
+ BLOCK_M=BLOCK, BLOCK_N=BLOCK,
282
+ BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
283
+ CAUSAL=ctx.causal,
284
+ num_stages=1,
285
+ )
286
+ return dq, dk, dv, None, None
287
+
288
+
289
+ attention = _attention.apply
Andromeda/Andromeda/core/transformer.py ADDED
@@ -0,0 +1,1376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from random import random
3
+
4
+ import torch
5
+ from torch import nn, einsum, Tensor
6
+ import torch.nn.functional as F
7
+
8
+ from functools import partial, wraps
9
+ from inspect import isfunction
10
+ from dataclasses import dataclass
11
+ from typing import List
12
+
13
+ from einops import rearrange, repeat
14
+
15
+ from Andromeda.core.attend import Attend, Intermediates
16
+ from Andromeda.core.autoregressive_wrapper import AutoregressiveWrapper
17
+
18
+ from abc import ABC, abstractmethod
19
+ # import bitsandbytes as bnb
20
+
21
+ # constants
22
+
23
+ DEFAULT_DIM_HEAD = 64
24
+
25
+ @dataclass
26
+ class LayerIntermediates:
27
+ hiddens: List[Tensor] = None
28
+ attn_intermediates: List[Intermediates] = None
29
+
30
+ # helpers
31
+
32
+ def exists(val):
33
+ return val is not None
34
+
35
+ def default(val, d):
36
+ if exists(val):
37
+ return val
38
+ return d() if isfunction(d) else d
39
+
40
+ def cast_tuple(val, depth):
41
+ return val if isinstance(val, tuple) else (val,) * depth
42
+
43
+ def maybe(fn):
44
+ @wraps(fn)
45
+ def inner(x, *args, **kwargs):
46
+ if not exists(x):
47
+ return x
48
+ return fn(x, *args, **kwargs)
49
+ return inner
50
+
51
+ class always():
52
+ def __init__(self, val):
53
+ self.val = val
54
+ def __call__(self, *args, **kwargs):
55
+ return self.val
56
+
57
+ class not_equals():
58
+ def __init__(self, val):
59
+ self.val = val
60
+ def __call__(self, x, *args, **kwargs):
61
+ return x != self.val
62
+
63
+ class equals():
64
+ def __init__(self, val):
65
+ self.val = val
66
+ def __call__(self, x, *args, **kwargs):
67
+ return x == self.val
68
+
69
+ # tensor helpers
70
+
71
+ def max_neg_value(tensor):
72
+ return -torch.finfo(tensor.dtype).max
73
+
74
+ def l2norm(t, groups = 1):
75
+ t = rearrange(t, '... (g d) -> ... g d', g = groups)
76
+ t = F.normalize(t, p = 2, dim = -1)
77
+ return rearrange(t, '... g d -> ... (g d)')
78
+
79
+ def pad_at_dim(t, pad, dim = -1, value = 0.):
80
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
81
+ zeros = ((0, 0) * dims_from_right)
82
+ return F.pad(t, (*zeros, *pad), value = value)
83
+
84
+ def or_reduce(masks):
85
+ head, *body = masks
86
+ for rest in body:
87
+ head = head | rest
88
+ return head
89
+
90
+ # init helpers
91
+
92
+ def init_zero_(layer):
93
+ nn.init.constant_(layer.weight, 0.)
94
+ if exists(layer.bias):
95
+ nn.init.constant_(layer.bias, 0.)
96
+
97
+ # keyword argument helpers
98
+
99
+ def pick_and_pop(keys, d):
100
+ values = list(map(lambda key: d.pop(key), keys))
101
+ return dict(zip(keys, values))
102
+
103
+ def group_dict_by_key(cond, d):
104
+ return_val = [dict(),dict()]
105
+ for key in d.keys():
106
+ match = bool(cond(key))
107
+ ind = int(not match)
108
+ return_val[ind][key] = d[key]
109
+ return (*return_val,)
110
+
111
+ def string_begins_with(prefix, str):
112
+ return str.startswith(prefix)
113
+
114
+ def group_by_key_prefix(prefix, d):
115
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
116
+
117
+ def groupby_prefix_and_trim(prefix, d):
118
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
119
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
120
+ return kwargs_without_prefix, kwargs
121
+
122
+ # initializations
123
+
124
+ def deepnorm_init(
125
+ transformer,
126
+ beta,
127
+ module_name_match_list = ['.ff.', '.to_v', '.to_out']
128
+ ):
129
+ for name, module in transformer.named_modules():
130
+ if type(module) != nn.Linear:
131
+ continue
132
+
133
+ needs_beta_gain = any(map(lambda substr: substr in name, module_name_match_list))
134
+ gain = beta if needs_beta_gain else 1
135
+ nn.init.xavier_normal_(module.weight.data, gain = gain)
136
+
137
+ if exists(module.bias):
138
+ nn.init.constant_(module.bias.data, 0)
139
+
140
+ # structured dropout, more effective than traditional attention dropouts
141
+
142
+ def dropout_seq(seq, mask, dropout):
143
+ b, n, *_, device = *seq.shape, seq.device
144
+ logits = torch.randn(b, n, device = device)
145
+
146
+ if exists(mask):
147
+ mask_value = max_neg_value(logits)
148
+ logits = logits.masked_fill(~mask, mask_value)
149
+
150
+ keep_prob = 1. - dropout
151
+ num_keep = max(1, int(keep_prob * n))
152
+ keep_indices = logits.topk(num_keep, dim = 1).indices
153
+
154
+ batch_indices = torch.arange(b, device = device)
155
+ batch_indices = rearrange(batch_indices, 'b -> b 1')
156
+
157
+ seq = seq[batch_indices, keep_indices]
158
+
159
+ if exists(mask):
160
+ seq_counts = mask.sum(dim = -1)
161
+ seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
162
+ keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
163
+
164
+ mask = mask[batch_indices, keep_indices] & keep_mask
165
+
166
+ return seq, mask
167
+
168
+ # activations
169
+
170
+ class ReluSquared(nn.Module):
171
+ def forward(self, x):
172
+ return F.relu(x) ** 2
173
+
174
+
175
+ #tokenization
176
+ class BaseTokenizer(ABC):
177
+ @abstractmethod
178
+ def tokenize(self, text: str) -> List[int]:
179
+ pass
180
+
181
+ class CustomTokenizer(BaseTokenizer):
182
+ def tokenize(self, text: str) -> List[int]:
183
+ # Your custom tokenization algorithm
184
+ tokens = ...
185
+ return tokens
186
+
187
+ # embedding
188
+
189
+ class BaseEmbedding(ABC):
190
+ @abstractmethod
191
+ def get_embedding(self, num_tokens: int, dim: int) -> nn.Module:
192
+ # Custom embedding function or model
193
+ embedding = ...
194
+
195
+ return embedding
196
+
197
+ class AndromedaEmbedding(BaseEmbedding):
198
+ def get_embedding(self, num_tokens: int, dim: int) -> nn.Module:
199
+ embedding = nn.Embedding(num_tokens, dim)
200
+
201
+ return embedding
202
+
203
+ # class AndromedaBnBEmbedding(BaseEmbedding):
204
+ # def get_embedding(self, num_tokens: int, dim: int, padding_idx: int = 0) -> bnb.nn.modules:
205
+ # embedding = bnb.nn.modules.Embedding(num_tokens, dim, padding_idx)
206
+
207
+ # return embedding
208
+
209
+ class TokenEmbedding(nn.Module):
210
+ def __init__(self, dim, num_tokens, embedding_provider: BaseEmbedding, l2norm_embed = False):
211
+ super().__init__()
212
+ self.l2norm_embed = l2norm_embed
213
+ self.emb = embedding_provider.get_embedding(num_tokens, dim)
214
+ # nn.Embedding(num_tokens, dim)
215
+
216
+ def forward(self, x):
217
+ token_emb = self.emb(x)
218
+ return l2norm(token_emb) if self.l2norm_embed else token_emb
219
+
220
+ # positional embeddings
221
+
222
+ class AbsolutePositionalEmbedding(nn.Module):
223
+ def __init__(self, dim, max_seq_len, l2norm_embed = False):
224
+ super().__init__()
225
+ self.scale = dim ** -0.5 if not l2norm_embed else 1.
226
+ self.max_seq_len = max_seq_len
227
+ self.l2norm_embed = l2norm_embed
228
+ self.emb = nn.Embedding(max_seq_len, dim)
229
+
230
+ def forward(self, x, pos = None):
231
+ seq_len, device = x.shape[1], x.device
232
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
233
+
234
+ if not exists(pos):
235
+ pos = torch.arange(seq_len, device = device)
236
+
237
+ pos_emb = self.emb(pos)
238
+ pos_emb = pos_emb * self.scale
239
+ return l2norm(pos_emb) if self.l2norm_embed else pos_emb
240
+
241
+ class ScaledSinusoidalEmbedding(nn.Module):
242
+ def __init__(self, dim, theta = 10000):
243
+ super().__init__()
244
+ assert (dim % 2) == 0
245
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
246
+
247
+ half_dim = dim // 2
248
+ freq_seq = torch.arange(half_dim).float() / half_dim
249
+ inv_freq = theta ** -freq_seq
250
+ self.register_buffer('inv_freq', inv_freq, persistent = False)
251
+
252
+ def forward(self, x, pos = None):
253
+ seq_len, device = x.shape[1], x.device
254
+
255
+ if not exists(pos):
256
+ pos = torch.arange(seq_len, device = device)
257
+
258
+ emb = einsum('i, j -> i j', pos, self.inv_freq)
259
+ emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
260
+ return emb * self.scale
261
+
262
+ class RelativePositionBias(nn.Module):
263
+ def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
264
+ super().__init__()
265
+ self.scale = scale
266
+ self.causal = causal
267
+ self.num_buckets = num_buckets
268
+ self.max_distance = max_distance
269
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
270
+
271
+ @staticmethod
272
+ def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
273
+ ret = 0
274
+ n = -relative_position
275
+ if not causal:
276
+ num_buckets //= 2
277
+ ret += (n < 0).long() * num_buckets
278
+ n = torch.abs(n)
279
+ else:
280
+ n = torch.max(n, torch.zeros_like(n))
281
+
282
+ max_exact = num_buckets // 2
283
+ is_small = n < max_exact
284
+
285
+ val_if_large = max_exact + (
286
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
287
+ ).long()
288
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
289
+
290
+ ret += torch.where(is_small, n, val_if_large)
291
+ return ret
292
+
293
+ @property
294
+ def device(self):
295
+ return next(self.parameters()).device
296
+
297
+ def forward(self, i, j):
298
+ device = self.device
299
+ q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
300
+ k_pos = torch.arange(j, dtype = torch.long, device = device)
301
+ rel_pos = k_pos[None, :] - q_pos[:, None]
302
+ rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
303
+ values = self.relative_attention_bias(rp_bucket)
304
+ bias = rearrange(values, 'i j h -> h i j')
305
+ return bias * self.scale
306
+
307
+ class DynamicPositionBias(nn.Module):
308
+ def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
309
+ super().__init__()
310
+ assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
311
+ self.log_distance = log_distance
312
+
313
+ self.mlp = nn.ModuleList([])
314
+
315
+ self.mlp.append(nn.Sequential(
316
+ nn.Linear(1, dim),
317
+ nn.LayerNorm(dim) if norm else nn.Identity(),
318
+ nn.SiLU()
319
+ ))
320
+
321
+ for _ in range(depth - 1):
322
+ self.mlp.append(nn.Sequential(
323
+ nn.Linear(dim, dim),
324
+ nn.LayerNorm(dim) if norm else nn.Identity(),
325
+ nn.SiLU()
326
+ ))
327
+
328
+ self.mlp.append(nn.Linear(dim, heads))
329
+
330
+ @property
331
+ def device(self):
332
+ return next(self.parameters()).device
333
+
334
+ def forward(self, i, j):
335
+ assert i == j
336
+ n, device = j, self.device
337
+
338
+ # get the (n x n) matrix of distances
339
+ seq_arange = torch.arange(n, device = device)
340
+ context_arange = torch.arange(n, device = device)
341
+ indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
342
+ indices += (n - 1)
343
+
344
+ # input to continuous positions MLP
345
+ pos = torch.arange(-n + 1, n, device = device).float()
346
+ pos = rearrange(pos, '... -> ... 1')
347
+
348
+ if self.log_distance:
349
+ pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
350
+
351
+ for layer in self.mlp:
352
+ pos = layer(pos)
353
+
354
+ # get position biases
355
+ bias = pos[indices]
356
+ bias = rearrange(bias, 'i j h -> h i j')
357
+ return bias
358
+
359
+ class AlibiPositionalBias(nn.Module):
360
+ def __init__(self, heads, total_heads, **kwargs):
361
+ super().__init__()
362
+ self.heads = heads
363
+ self.total_heads = total_heads
364
+
365
+ slopes = Tensor(self._get_slopes(heads))
366
+ slopes = rearrange(slopes, 'h -> h 1 1')
367
+ self.register_buffer('slopes', slopes, persistent = False)
368
+ self.register_buffer('bias', None, persistent = False)
369
+
370
+ def get_bias(self, i, j, device):
371
+ i_arange = torch.arange(j - i, j, device = device)
372
+ j_arange = torch.arange(j, device = device)
373
+ bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
374
+ return bias
375
+
376
+ @staticmethod
377
+ def _get_slopes(heads):
378
+ def get_slopes_power_of_2(n):
379
+ start = (2**(-2**-(math.log2(n)-3)))
380
+ ratio = start
381
+ return [start*ratio**i for i in range(n)]
382
+
383
+ if math.log2(heads).is_integer():
384
+ return get_slopes_power_of_2(heads)
385
+
386
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
387
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
388
+
389
+ @property
390
+ def device(self):
391
+ return next(self.buffers()).device
392
+
393
+ def forward(self, i, j):
394
+ h, device = self.total_heads, self.device
395
+
396
+ if exists(self.bias) and self.bias.shape[-1] >= j:
397
+ return self.bias[..., :i, :j]
398
+
399
+ bias = self.get_bias(i, j, device)
400
+ bias = bias * self.slopes
401
+
402
+ num_heads_unalibied = h - bias.shape[0]
403
+ bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
404
+ self.register_buffer('bias', bias, persistent = False)
405
+
406
+ return self.bias
407
+
408
+ class LearnedAlibiPositionalBias(AlibiPositionalBias):
409
+ def __init__(self, heads, total_heads):
410
+ super().__init__(heads, total_heads)
411
+ log_slopes = torch.log(self.slopes)
412
+ self.learned_logslopes = nn.Parameter(log_slopes)
413
+
414
+ def forward(self, i, j):
415
+ h, i, j, device = self.heads, self.device
416
+
417
+ def get_slopes(param):
418
+ return pad_at_dim(param.exp(), (0, h - param.shape[0]), dim = -2)
419
+
420
+ if exists(self.bias) and self.bias.shape[-1] >= j:
421
+ bias = self.bias[..., :i, :j]
422
+ else:
423
+ bias = self.get_bias(i, j, device)
424
+ self.register_buffer('bias', bias, persistent = False)
425
+
426
+ slopes = get_slopes(self.learned_logslopes)
427
+ bias = bias * slopes
428
+
429
+ return bias
430
+
431
+ class RotaryEmbedding(nn.Module):
432
+ def __init__(
433
+ self,
434
+ dim,
435
+ use_xpos = False,
436
+ scale_base = 512,
437
+ interpolation_factor=1.,
438
+ base=10000,
439
+ base_rescale_factor=1.
440
+ ):
441
+ super().__init__()
442
+ base *= base_rescale_factor ** (dim / (dim - 2))
443
+
444
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
445
+
446
+ self.register_buffer('inv_freq', inv_freq)
447
+
448
+ if not use_xpos:
449
+ self.register_buffer('scale', None)
450
+ return
451
+
452
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
453
+
454
+ self.scale_base = scale_base
455
+ self.register_buffer('scale', scale)
456
+
457
+ def forward(self, seq_len, device):
458
+ t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
459
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
460
+ freqs = torch.cat((freqs, freqs), dim = -1)
461
+
462
+ if not exists(self.scale):
463
+ return freqs, 1.
464
+
465
+ power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
466
+ scale = self.scale ** rearrange(power, 'n -> n 1')
467
+ scale = torch.cat((scale, scale), dim = -1)
468
+
469
+ return freqs, scale
470
+
471
+
472
+ def rotate_half(x):
473
+ x = rearrange(x, '... (j d) -> ... j d', j = 2)
474
+ x1, x2 = x.unbind(dim = -2)
475
+ return torch.cat((-x2, x1), dim = -1)
476
+
477
+ def apply_rotary_pos_emb(t, freqs, scale = 1):
478
+ seq_len = t.shape[-2]
479
+ freqs = freqs[-seq_len:, :]
480
+ return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
481
+
482
+ # norms
483
+
484
+ class Scale(nn.Module):
485
+ def __init__(self, value, fn):
486
+ super().__init__()
487
+ self.value = value
488
+ self.fn = fn
489
+
490
+ def forward(self, x, **kwargs):
491
+ out = self.fn(x, **kwargs)
492
+ def scale_fn(t):
493
+ return t * self.value
494
+
495
+ if not isinstance(out, tuple):
496
+ return scale_fn(out)
497
+
498
+ return (scale_fn(out[0]), *out[1:])
499
+
500
+ class ScaleNorm(nn.Module):
501
+ def __init__(self, dim, eps = 1e-5):
502
+ super().__init__()
503
+ self.eps = eps
504
+ self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
505
+
506
+ def forward(self, x):
507
+ norm = torch.norm(x, dim = -1, keepdim = True)
508
+ return x / norm.clamp(min = self.eps) * self.g
509
+
510
+ class RMSNorm(nn.Module):
511
+ def __init__(self, dim, eps = 1e-8):
512
+ super().__init__()
513
+ self.scale = dim ** -0.5
514
+ self.eps = eps
515
+ self.g = nn.Parameter(torch.ones(dim))
516
+
517
+ def forward(self, x):
518
+ norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
519
+ return x / norm.clamp(min = self.eps) * self.g
520
+
521
+ # residual and residual gates
522
+
523
+ class Residual(nn.Module):
524
+ def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
525
+ super().__init__()
526
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
527
+ self.scale_residual_constant = scale_residual_constant
528
+
529
+ def forward(self, x, residual):
530
+ if exists(self.residual_scale):
531
+ residual = residual * self.residual_scale
532
+
533
+ if self.scale_residual_constant != 1:
534
+ residual = residual * self.scale_residual_constant
535
+
536
+ return x + residual
537
+
538
+ class GRUGating(nn.Module):
539
+ def __init__(self, dim, scale_residual = False, **kwargs):
540
+ super().__init__()
541
+ self.gru = nn.GRUCell(dim, dim)
542
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
543
+
544
+ def forward(self, x, residual):
545
+ if exists(self.residual_scale):
546
+ residual = residual * self.residual_scale
547
+
548
+ gated_output = self.gru(
549
+ rearrange(x, 'b n d -> (b n) d'),
550
+ rearrange(residual, 'b n d -> (b n) d')
551
+ )
552
+
553
+ return gated_output.reshape_as(x)
554
+
555
+ # token shifting
556
+
557
+ def shift(t, amount, mask = None):
558
+ if amount == 0:
559
+ return t
560
+ else:
561
+ amount = min(amount, t.shape[1])
562
+
563
+ if exists(mask):
564
+ t = t.masked_fill(~mask[..., None], 0.)
565
+
566
+ return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
567
+
568
+ class ShiftTokens(nn.Module):
569
+ def __init__(self, shifts, fn):
570
+ super().__init__()
571
+ self.fn = fn
572
+ self.shifts = tuple(shifts)
573
+
574
+ def forward(self, x, **kwargs):
575
+ mask = kwargs.get('mask', None)
576
+ shifts = self.shifts
577
+ segments = len(shifts)
578
+ feats_per_shift = x.shape[-1] // segments
579
+ splitted = x.split(feats_per_shift, dim = -1)
580
+ segments_to_shift, rest = splitted[:segments], splitted[segments:]
581
+ segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
582
+ x = torch.cat((*segments_to_shift, *rest), dim = -1)
583
+ return self.fn(x, **kwargs)
584
+
585
+ # feedforward
586
+
587
+ class GLU(nn.Module):
588
+ def __init__(self, dim_in, dim_out, activation):
589
+ super().__init__()
590
+ self.act = activation
591
+ self.proj = nn.Linear(dim_in, dim_out * 2)
592
+
593
+ def forward(self, x):
594
+ x, gate = self.proj(x).chunk(2, dim = -1)
595
+ return x * self.act(gate)
596
+
597
+ class FeedForward(nn.Module):
598
+ def __init__(
599
+ self,
600
+ dim,
601
+ dim_out = None,
602
+ mult = 4,
603
+ glu = False,
604
+ swish = False,
605
+ relu_squared = False,
606
+ post_act_ln = False,
607
+ dropout = 0.,
608
+ no_bias = False,
609
+ zero_init_output = False
610
+ ):
611
+ super().__init__()
612
+ inner_dim = int(dim * mult)
613
+ dim_out = default(dim_out, dim)
614
+
615
+ if relu_squared:
616
+ activation = ReluSquared()
617
+ elif swish:
618
+ activation = nn.SiLU()
619
+ else:
620
+ activation = nn.GELU()
621
+
622
+ project_in = nn.Sequential(
623
+ nn.Linear(dim, inner_dim, bias = not no_bias),
624
+ activation
625
+ ) if not glu else GLU(dim, inner_dim, activation)
626
+
627
+ self.ff = nn.Sequential(
628
+ project_in,
629
+ nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
630
+ nn.Dropout(dropout),
631
+ nn.Linear(inner_dim, dim_out, bias = not no_bias)
632
+ )
633
+
634
+ # init last linear layer to 0
635
+ if zero_init_output:
636
+ init_zero_(self.ff[-1])
637
+
638
+ def forward(self, x):
639
+ return self.ff(x)
640
+
641
+ # attention. it is all we need
642
+
643
+ class Attention(nn.Module):
644
+ def __init__(
645
+ self,
646
+ dim,
647
+ dim_head = DEFAULT_DIM_HEAD,
648
+ heads = 8,
649
+ causal = False,
650
+ flash = False,
651
+ talking_heads = False,
652
+ head_scale = False,
653
+ sparse_topk = None,
654
+ num_mem_kv = 0,
655
+ dropout = 0.,
656
+ on_attn = False,
657
+ gate_values = False,
658
+ zero_init_output = False,
659
+ max_attend_past = None,
660
+ qk_norm = False,
661
+ qk_norm_groups = 1,
662
+ qk_norm_scale = 10,
663
+ qk_norm_dim_scale = False,
664
+ one_kv_head = False,
665
+ shared_kv = False,
666
+ value_dim_head = None,
667
+ tensor_product = False # https://arxiv.org/abs/2208.06061
668
+ ):
669
+ super().__init__()
670
+ self.scale = dim_head ** -0.5
671
+
672
+ self.heads = heads
673
+ self.causal = causal
674
+ self.max_attend_past = max_attend_past
675
+
676
+ value_dim_head = default(value_dim_head, dim_head)
677
+ q_dim = k_dim = dim_head * heads
678
+ v_dim = out_dim = value_dim_head * heads
679
+
680
+ self.one_kv_head = one_kv_head
681
+ if one_kv_head:
682
+ k_dim = dim_head
683
+ v_dim = value_dim_head
684
+ out_dim = v_dim * heads
685
+
686
+ self.to_q = nn.Linear(dim, q_dim, bias = False)
687
+ self.to_k = nn.Linear(dim, k_dim, bias = False)
688
+
689
+ # shared key / values, for further memory savings during inference
690
+ assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
691
+ self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None
692
+
693
+ # relations projection from tp-attention
694
+ self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
695
+
696
+ # add GLU gating for aggregated values, from alphafold2
697
+ self.to_v_gate = None
698
+ if gate_values:
699
+ self.to_v_gate = nn.Linear(dim, out_dim)
700
+ nn.init.constant_(self.to_v_gate.weight, 0)
701
+ nn.init.constant_(self.to_v_gate.bias, 1)
702
+
703
+ # cosine sim attention
704
+ self.qk_norm = qk_norm
705
+ self.qk_norm_groups = qk_norm_groups
706
+ self.qk_norm_scale = qk_norm_scale
707
+
708
+ # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
709
+ self.qk_norm_dim_scale = qk_norm_dim_scale
710
+
711
+ self.qk_norm_q_scale = self.qk_norm_k_scale = 1
712
+ if qk_norm and qk_norm_dim_scale:
713
+ self.qk_norm_q_scale = nn.Parameter(torch.ones(dim_head))
714
+ self.qk_norm_k_scale = nn.Parameter(torch.ones(dim_head))
715
+
716
+ assert (not qk_norm) or (dim_head % qk_norm_groups) == 0, 'dimension per attention head must be divisible by the qk norm groups'
717
+ assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
718
+
719
+ # attend class - includes core attention algorithm + talking heads
720
+
721
+ self.attend = Attend(
722
+ heads = heads,
723
+ causal = causal,
724
+ talking_heads = talking_heads,
725
+ dropout = dropout,
726
+ qk_norm = qk_norm,
727
+ scale = qk_norm_scale if qk_norm else self.scale,
728
+ flash = flash
729
+ )
730
+
731
+ # head scaling
732
+ self.head_scale = head_scale
733
+ if head_scale:
734
+ self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
735
+
736
+ # explicit topk sparse attention
737
+ self.sparse_topk = sparse_topk
738
+
739
+ # add memory key / values
740
+ self.num_mem_kv = num_mem_kv
741
+ if num_mem_kv > 0:
742
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
743
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
744
+
745
+ # attention on attention
746
+ self.attn_on_attn = on_attn
747
+ self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
748
+
749
+ # init output projection 0
750
+ if zero_init_output:
751
+ init_zero_(self.to_out)
752
+
753
+ def forward(
754
+ self,
755
+ x,
756
+ context = None,
757
+ mask = None,
758
+ context_mask = None,
759
+ attn_mask = None,
760
+ rel_pos = None,
761
+ rotary_pos_emb = None,
762
+ prev_attn = None,
763
+ mem = None
764
+ ):
765
+ b, n, _, h, head_scale, device, has_context = *x.shape, self.heads, self.head_scale, x.device, exists(context)
766
+ kv_input = default(context, x)
767
+
768
+ q_input = x
769
+ k_input = kv_input
770
+ v_input = kv_input
771
+ r_input = x
772
+
773
+ if exists(mem):
774
+ k_input = torch.cat((mem, k_input), dim = -2)
775
+ v_input = torch.cat((mem, v_input), dim = -2)
776
+
777
+ q = self.to_q(q_input)
778
+ k = self.to_k(k_input)
779
+ v = self.to_v(v_input) if exists(self.to_v) else k
780
+ r = self.to_r(r_input) if exists(self.to_r) else None
781
+
782
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
783
+
784
+ if not self.one_kv_head:
785
+ k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = h), (k, v, r))
786
+
787
+ if self.qk_norm:
788
+ qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
789
+ q, k = map(qk_l2norm, (q, k))
790
+
791
+ q = q * self.qk_norm_q_scale
792
+ k = k * self.qk_norm_k_scale
793
+
794
+ if exists(rotary_pos_emb) and not has_context:
795
+ freqs, xpos_scale = rotary_pos_emb
796
+ l = freqs.shape[-1]
797
+
798
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
799
+ (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
800
+
801
+ ql, kl, vl = map(lambda arg: apply_rotary_pos_emb(arg[0], freqs, arg[1]), ((ql, q_xpos_scale), (kl, k_xpos_scale), (vl, k_xpos_scale)))
802
+ q, k, v = map(lambda t: torch.cat(t, dim = -1), ((ql, qr), (kl, kr), (vl, vr)))
803
+
804
+ input_mask = default(context_mask, mask)
805
+
806
+ if self.num_mem_kv > 0:
807
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
808
+
809
+ if self.qk_norm:
810
+ mem_k = l2norm(mem_k)
811
+ mem_k = mem_k * self.qk_norm_k_scale
812
+
813
+ k = torch.cat((mem_k, k), dim = -2)
814
+ v = torch.cat((mem_v, v), dim = -2)
815
+
816
+ if exists(input_mask):
817
+ input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
818
+
819
+
820
+ i, j = map(lambda t: t.shape[-2], (q, k))
821
+
822
+ # determine masking
823
+
824
+ max_neg_value(q)
825
+ masks = []
826
+ final_attn_mask = None
827
+
828
+ if exists(input_mask):
829
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
830
+ masks.append(~input_mask)
831
+
832
+ if exists(attn_mask):
833
+ assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
834
+ if attn_mask.ndim == 2:
835
+ attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j')
836
+ elif attn_mask.ndim == 3:
837
+ attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j')
838
+ masks.append(~attn_mask)
839
+
840
+ if exists(self.max_attend_past):
841
+ range_q = torch.arange(j - i, j, device = device)
842
+ range_k = torch.arange(j, device = device)
843
+ dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
844
+ max_attend_past_mask = dist > self.max_attend_past
845
+ masks.append(max_attend_past_mask)
846
+
847
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
848
+ top, _ = dots.topk(self.sparse_topk, dim = -1)
849
+ vk = rearrange(top[..., -1], '... -> ... 1')
850
+ sparse_topk_mask = dots < vk
851
+ masks.append(sparse_topk_mask)
852
+
853
+ if len(masks) > 0:
854
+ final_attn_mask = or_reduce(masks)
855
+
856
+ # prepare relative positional bias, if needed
857
+
858
+ attn_bias = None
859
+ if exists(rel_pos):
860
+ attn_bias = rel_pos(i, j)
861
+
862
+ # attention is all we need
863
+
864
+ out, intermediates = self.attend(
865
+ q, k, v,
866
+ mask = final_attn_mask,
867
+ attn_bias = attn_bias,
868
+ prev_attn = prev_attn
869
+ )
870
+
871
+ # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
872
+
873
+ if exists(r):
874
+ out = out * r + out
875
+
876
+ # normformer scaling of heads
877
+
878
+ if head_scale:
879
+ out = out * self.head_scale_params
880
+
881
+ # merge heads
882
+
883
+ out = rearrange(out, 'b h n d -> b n (h d)')
884
+
885
+ # alphafold2 styled gating of the values
886
+
887
+ if exists(self.to_v_gate):
888
+ gates = self.to_v_gate(x)
889
+ out = out * gates.sigmoid()
890
+
891
+ # combine the heads
892
+
893
+ out = self.to_out(out)
894
+
895
+ if exists(mask):
896
+ mask = rearrange(mask, 'b n -> b n 1')
897
+ out = out.masked_fill(~mask, 0.)
898
+
899
+ return out, intermediates
900
+
901
+ class AttentionLayers(nn.Module):
902
+ def __init__(
903
+ self,
904
+ dim,
905
+ depth,
906
+ heads = None,
907
+ causal = False,
908
+ cross_attend = False,
909
+ only_cross = False,
910
+ use_scalenorm = False,
911
+ use_rmsnorm = False,
912
+ alibi_pos_bias = False,
913
+ alibi_num_heads = None,
914
+ alibi_learned = False,
915
+ rel_pos_bias = False,
916
+ rel_pos_num_buckets = 32,
917
+ rel_pos_max_distance = 128,
918
+ dynamic_pos_bias = False,
919
+ dynamic_pos_bias_log_distance = False,
920
+ dynamic_pos_bias_mlp_depth = 2,
921
+ dynamic_pos_bias_norm = False,
922
+ rotary_pos_emb = False,
923
+ rotary_emb_dim = None,
924
+ rotary_xpos = False,
925
+ rotary_interpolation_factor=1.,
926
+ rotary_xpos_scale_base = 512,
927
+ rotary_base_rescale_factor=1.,
928
+ custom_layers = None,
929
+ sandwich_coef = None,
930
+ par_ratio = None,
931
+ residual_attn = False,
932
+ cross_residual_attn = False,
933
+ macaron = False,
934
+ pre_norm = True,
935
+ gate_residual = False,
936
+ scale_residual = False,
937
+ scale_residual_constant = 1.,
938
+ deepnorm = False,
939
+ shift_tokens = 0,
940
+ sandwich_norm = False,
941
+ resi_dual = False,
942
+ zero_init_branch_output = False,
943
+ layer_dropout = 0.,
944
+ cross_attn_tokens_dropout = 0.,
945
+ **kwargs
946
+ ):
947
+ super().__init__()
948
+ rotary_pos_emb = rotary_pos_emb or rotary_xpos
949
+
950
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
951
+ attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
952
+
953
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
954
+
955
+ self.dim = dim
956
+ self.depth = depth
957
+ self.layers = nn.ModuleList([])
958
+
959
+ self.has_pos_emb = rel_pos_bias or rotary_pos_emb
960
+
961
+ rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
962
+
963
+ assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
964
+ self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor=rotary_interpolation_factor, base_rescale_factor=rotary_base_rescale_factor) if rotary_pos_emb else None
965
+
966
+ assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
967
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
968
+
969
+ # relative positional bias
970
+
971
+ flash_attn = attn_kwargs.get('flash', False)
972
+ assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias'
973
+
974
+ self.rel_pos = None
975
+ if rel_pos_bias:
976
+ assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
977
+ self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
978
+ elif dynamic_pos_bias:
979
+ assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
980
+ self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
981
+ elif alibi_pos_bias:
982
+ alibi_num_heads = default(alibi_num_heads, heads)
983
+ assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
984
+ alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned else AlibiPositionalBias
985
+ self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, total_heads = heads)
986
+
987
+ # determine deepnorm and residual scale
988
+
989
+ if deepnorm:
990
+ assert scale_residual_constant == 1, 'scale residual constant is being overridden by deep norm settings'
991
+ pre_norm = sandwich_norm = resi_dual = False
992
+ scale_residual = True
993
+ scale_residual_constant = (2 * depth) ** 0.25
994
+
995
+ assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both'
996
+ assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
997
+ assert not (not pre_norm and resi_dual), 'resiDualcannot be used when not using prenorm'
998
+
999
+ self.pre_norm = pre_norm
1000
+ self.sandwich_norm = sandwich_norm
1001
+ self.resi_dual = resi_dual
1002
+
1003
+ self.residual_attn = residual_attn
1004
+ self.cross_residual_attn = cross_residual_attn
1005
+ self.cross_attend = cross_attend
1006
+
1007
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
1008
+ norm_class = RMSNorm if use_rmsnorm else norm_class
1009
+ norm_fn = partial(norm_class, dim)
1010
+
1011
+ if cross_attend and not only_cross:
1012
+ default_block = ('a', 'c', 'f')
1013
+ elif cross_attend and only_cross:
1014
+ default_block = ('c', 'f')
1015
+ else:
1016
+ default_block = ('a', 'f')
1017
+
1018
+ if macaron:
1019
+ default_block = ('f',) + default_block
1020
+
1021
+ # zero init
1022
+
1023
+ if zero_init_branch_output:
1024
+ attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
1025
+ ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
1026
+
1027
+ # calculate layer block order
1028
+
1029
+ if exists(custom_layers):
1030
+ layer_types = custom_layers
1031
+ elif exists(par_ratio):
1032
+ par_depth = depth * len(default_block)
1033
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
1034
+ default_block = tuple(filter(not_equals('f'), default_block))
1035
+ par_attn = par_depth // par_ratio
1036
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
1037
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
1038
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
1039
+ par_block = default_block + ('f',) * (par_width - len(default_block))
1040
+ par_head = par_block * par_attn
1041
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
1042
+ elif exists(sandwich_coef):
1043
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
1044
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
1045
+ else:
1046
+ layer_types = default_block * depth
1047
+
1048
+ self.layer_types = layer_types
1049
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
1050
+
1051
+ # stochastic depth
1052
+
1053
+ self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
1054
+
1055
+ # structured dropout for cross attending
1056
+
1057
+ self.cross_attn_tokens_dropout = cross_attn_tokens_dropout
1058
+
1059
+ # calculate token shifting
1060
+
1061
+ shift_tokens = cast_tuple(shift_tokens, len(layer_types))
1062
+
1063
+ # iterate and construct layers
1064
+
1065
+ for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
1066
+ is_last_layer = ind == (len(self.layer_types) - 1)
1067
+
1068
+ if layer_type == 'a':
1069
+ layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
1070
+ elif layer_type == 'c':
1071
+ layer = Attention(dim, heads = heads, **attn_kwargs)
1072
+ elif layer_type == 'f':
1073
+ layer = FeedForward(dim, **ff_kwargs)
1074
+ layer = layer if not macaron else Scale(0.5, layer)
1075
+ else:
1076
+ raise Exception(f'invalid layer type {layer_type}')
1077
+
1078
+ if layer_shift_tokens > 0:
1079
+ shift_range_upper = layer_shift_tokens + 1
1080
+ shift_range_lower = -layer_shift_tokens if not causal else 0
1081
+ layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
1082
+
1083
+ residual_fn = GRUGating if gate_residual else Residual
1084
+ residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
1085
+
1086
+ pre_branch_norm = norm_fn() if pre_norm else None
1087
+ post_branch_norm = norm_fn() if sandwich_norm else None
1088
+ post_main_norm = norm_fn() if (resi_dual or not pre_norm) and not is_last_layer else None
1089
+
1090
+ norms = nn.ModuleList([
1091
+ pre_branch_norm,
1092
+ post_branch_norm,
1093
+ post_main_norm
1094
+ ])
1095
+
1096
+ self.layers.append(nn.ModuleList([
1097
+ norms,
1098
+ layer,
1099
+ residual
1100
+ ]))
1101
+
1102
+ self.layers_length = len(self.layers) # It doesn't work if called after
1103
+
1104
+ if deepnorm:
1105
+ init_gain = (8 * depth) ** -0.25
1106
+ deepnorm_init(self, init_gain)
1107
+
1108
+ def forward(
1109
+ self,
1110
+ x,
1111
+ context = None,
1112
+ mask = None,
1113
+ context_mask = None,
1114
+ attn_mask = None,
1115
+ self_attn_context_mask = None,
1116
+ mems = None,
1117
+ return_hiddens = False
1118
+ ):
1119
+ assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
1120
+
1121
+ hiddens = []
1122
+ intermediates = []
1123
+ prev_attn = None
1124
+ prev_cross_attn = None
1125
+
1126
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
1127
+
1128
+ rotary_pos_emb = None
1129
+ if exists(self.rotary_pos_emb):
1130
+ max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
1131
+ rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
1132
+
1133
+ outer_residual = x
1134
+
1135
+ for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(self.layer_types, self.layers, self.layer_dropouts)):
1136
+ ind == (self.layers_length - 1)
1137
+
1138
+ if self.training and layer_dropout > 0. and random() < layer_dropout:
1139
+ continue
1140
+
1141
+ if layer_type == 'a':
1142
+ if return_hiddens:
1143
+ hiddens.append(x)
1144
+ layer_mem = mems.pop(0) if mems else None
1145
+
1146
+ if layer_type == 'c':
1147
+ if self.training and self.cross_attn_tokens_dropout > 0.:
1148
+ context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)
1149
+
1150
+ inner_residual = x
1151
+
1152
+ pre_norm, post_branch_norm, post_main_norm = norm
1153
+
1154
+ if exists(pre_norm) and not self.resi_dual:
1155
+ x = pre_norm(x)
1156
+
1157
+ if layer_type == 'a':
1158
+ out, inter = block(x, mask = mask, context_mask = self_attn_context_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
1159
+ elif layer_type == 'c':
1160
+ out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
1161
+ elif layer_type == 'f':
1162
+ out = block(x)
1163
+
1164
+ if self.resi_dual:
1165
+ outer_residual = residual_fn(out, outer_residual)
1166
+
1167
+ if exists(post_branch_norm):
1168
+ out = post_branch_norm(out)
1169
+
1170
+ x = residual_fn(out, inner_residual)
1171
+
1172
+ if layer_type in ('a', 'c') and return_hiddens:
1173
+ intermediates.append(inter)
1174
+
1175
+ if layer_type == 'a' and self.residual_attn:
1176
+ prev_attn = inter.pre_softmax_attn
1177
+ elif layer_type == 'c' and self.cross_residual_attn:
1178
+ prev_cross_attn = inter.pre_softmax_attn
1179
+
1180
+ if exists(post_main_norm):
1181
+ x = post_main_norm(x)
1182
+
1183
+ if self.resi_dual:
1184
+ x = x + pre_norm(outer_residual)
1185
+
1186
+ if return_hiddens:
1187
+ intermediates = LayerIntermediates(
1188
+ hiddens = hiddens,
1189
+ attn_intermediates = intermediates
1190
+ )
1191
+
1192
+ return x, intermediates
1193
+
1194
+ return x
1195
+
1196
+
1197
+ class Decoder(AttentionLayers):
1198
+ def __init__(self, **kwargs):
1199
+ assert 'causal' not in kwargs, 'cannot set causality on decoder'
1200
+ super().__init__(causal = True, **kwargs)
1201
+
1202
+
1203
+
1204
+ class Transformer(nn.Module):
1205
+ def __init__(
1206
+ self,
1207
+ *,
1208
+ num_tokens,
1209
+ max_seq_len,
1210
+ attn_layers,
1211
+ # tokenizer: BaseTokenizer,
1212
+ embedding_provider: BaseEmbedding,
1213
+ emb_dim = None,
1214
+ max_mem_len = 0.,
1215
+ shift_mem_down = 0,
1216
+ emb_dropout = 0.,
1217
+ post_emb_norm = False,
1218
+ num_memory_tokens = None,
1219
+ tie_embedding = False,
1220
+ logits_dim = None,
1221
+ use_abs_pos_emb = True,
1222
+ scaled_sinu_pos_emb = False,
1223
+ l2norm_embed = False,
1224
+ emb_frac_gradient = 1. # GLM-130B and Cogview successfully used this, set at 0.1
1225
+ ):
1226
+ super().__init__()
1227
+
1228
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
1229
+
1230
+ dim = attn_layers.dim
1231
+ emb_dim = default(emb_dim, dim)
1232
+
1233
+ self.emb_dim = emb_dim
1234
+ self.num_tokens = num_tokens
1235
+ self.max_seq_len = max_seq_len
1236
+ self.max_mem_len = max_mem_len
1237
+ self.shift_mem_down = shift_mem_down
1238
+
1239
+ self.l2norm_embed = l2norm_embed
1240
+ self.token_emb = TokenEmbedding(emb_dim, num_tokens, embedding_provider, l2norm_embed=l2norm_embed)
1241
+
1242
+ if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
1243
+ self.pos_emb = always(0)
1244
+ elif scaled_sinu_pos_emb:
1245
+ self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
1246
+ else:
1247
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)
1248
+
1249
+ self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
1250
+
1251
+ self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
1252
+ self.emb_dropout = nn.Dropout(emb_dropout)
1253
+
1254
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
1255
+ self.attn_layers = attn_layers
1256
+ self.norm = nn.LayerNorm(dim)
1257
+
1258
+ self.init_()
1259
+
1260
+ logits_dim = default(logits_dim, num_tokens)
1261
+ self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
1262
+
1263
+ # memory tokens (like [cls]) from Memory Transformers paper
1264
+ num_memory_tokens = default(num_memory_tokens, 0)
1265
+ self.num_memory_tokens = num_memory_tokens
1266
+ if num_memory_tokens > 0:
1267
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
1268
+
1269
+ def init_(self):
1270
+ if self.l2norm_embed:
1271
+ nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
1272
+
1273
+ if not isinstance(self.pos_emb, always):
1274
+ nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
1275
+
1276
+ return
1277
+
1278
+ nn.init.kaiming_normal_(self.token_emb.emb.weight)
1279
+
1280
+ def forward(
1281
+ self,
1282
+ x,
1283
+ return_embeddings = False,
1284
+ return_logits_and_embeddings = False,
1285
+ return_intermediates = False,
1286
+ mask = None,
1287
+ return_mems = False,
1288
+ return_attn = False,
1289
+ mems = None,
1290
+ pos = None,
1291
+ prepend_embeds = None,
1292
+ sum_embeds = None,
1293
+ **kwargs
1294
+ ):
1295
+ b, n, device, num_mem, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.emb_frac_gradient
1296
+ return_hiddens = return_mems | return_attn
1297
+
1298
+ # absolute positional embedding
1299
+
1300
+ external_pos_emb = exists(pos) and pos.dtype != torch.long
1301
+ pos_emb = self.pos_emb(x, pos = pos) if not external_pos_emb else pos
1302
+ x = self.token_emb(x) + pos_emb
1303
+
1304
+ # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
1305
+
1306
+ if exists(sum_embeds):
1307
+ x = x + sum_embeds
1308
+
1309
+ # post embedding norm, purportedly leads to greater stabilization
1310
+
1311
+ x = self.post_emb_norm(x)
1312
+
1313
+ # whether to append embeds, as in PaLI, for image embeddings
1314
+
1315
+ if exists(prepend_embeds):
1316
+ prepend_seq, prepend_dim = prepend_embeds.shape[1:]
1317
+
1318
+ assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'
1319
+
1320
+ x = torch.cat((prepend_embeds, x), dim = -2)
1321
+
1322
+ # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
1323
+
1324
+ if emb_frac_gradient < 1:
1325
+ assert emb_frac_gradient > 0
1326
+
1327
+ x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
1328
+
1329
+ # embedding dropout
1330
+
1331
+ x = self.emb_dropout(x)
1332
+
1333
+ x = self.project_emb(x)
1334
+
1335
+ if num_mem > 0:
1336
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b = b)
1337
+ x = torch.cat((mem, x), dim = 1)
1338
+
1339
+ # auto-handle masking after appending memory tokens
1340
+ if exists(mask):
1341
+ mask = pad_at_dim(mask, (num_mem, 0), dim = -1, value = True)
1342
+
1343
+ if self.shift_mem_down and exists(mems):
1344
+ mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
1345
+ mems = [*mems_r, *mems_l]
1346
+
1347
+ if return_hiddens:
1348
+ x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
1349
+ else:
1350
+ x = self.attn_layers(x, mask = mask, mems = mems, **kwargs)
1351
+
1352
+ x = self.norm(x)
1353
+
1354
+ mem, x = x[:, :num_mem], x[:, num_mem:]
1355
+
1356
+ if return_logits_and_embeddings:
1357
+ out = (self.to_logits(x), x)
1358
+ elif return_embeddings:
1359
+ out = x
1360
+ else:
1361
+ out = self.to_logits(x)
1362
+
1363
+ if return_intermediates:
1364
+ return out, intermediates
1365
+
1366
+ if return_mems:
1367
+ hiddens = intermediates.hiddens
1368
+ new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
1369
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
1370
+ return out, new_mems
1371
+
1372
+ if return_attn:
1373
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
1374
+ return out, attn_maps
1375
+
1376
+ return out
Andromeda/Andromeda/dataset_prep/__init__.py ADDED
File without changes
Andromeda/Andromeda/dataset_prep/books.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from Andromeda.dataset_builder import DatasetBuilder
2
+ from build_dataset import DatasetBuilder
3
+
4
+ builder = DatasetBuilder(
5
+ dataset_name="the_pile_books3",
6
+ seq_len=8192,
7
+ num_cpu=4,
8
+ hf_account_repo="kye/the_pile_books3_GPTNeox-8192",
9
+ tokenizer="EleutherAI/gpt-neox-20b",
10
+ )
11
+
12
+ dataset = builder.build_dataset()
Andromeda/Andromeda/inference.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer
3
+ from einops._torch_specific import allow_ops_in_compiled_graph
4
+
5
+ import argparse
6
+
7
+ # class AndromedaEval:
8
+ # def __init__(self, path, seed=42, device=None):
9
+ # self.path = path
10
+ # self.seed = seed
11
+
12
+ # self.device = device
13
+
14
+ # if self.device is None:
15
+ # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+
17
+ # set_seed(self.seed)
18
+
19
+ # #tokenizer
20
+ # self.tokenizer = AndromedaTokenizer
21
+
22
+ # #model
23
+ # self.model = Andromeda
24
+
25
+ # #checkpoint
26
+ # self.model.load_state_dict(torch.load(self.path))
27
+ # self.model.eval()
28
+
29
+ # #device
30
+ # self.model = self.model.to(self.device)
31
+
32
+ # #metrics
33
+ # self.metrics = {}
34
+ # self.reset_metrics()
35
+
36
+ # def reset_metrics(self):
37
+ # self.metrics = {
38
+ # "generation_steps": None,
39
+ # "time_forward": [],
40
+ # "time_forward_average": None,
41
+
42
+ # "memory_usages": [],
43
+ # "memory_usage_average": None,
44
+ # "time_end_to_end": None,
45
+
46
+ # "throughput": None
47
+ # }
48
+
49
+ # def get_num_params(self):
50
+ # num_params = sum(param.numel() for param in self.model.parameters() if param.requires_grad)
51
+
52
+ # return num_params
53
+
54
+ # def generate(self, prompt, generation_steps=32):
55
+ # #make sure all of the metrics reset at every generation
56
+ # self.reset_metrics()
57
+
58
+ # self.metrics["generation_steps"] = generation_steps
59
+
60
+ # tokens = self.tokenizer.encode(prompt)
61
+ # tokens_new = []
62
+
63
+ # time_end_to_end = time.time()
64
+
65
+ # #generation loop
66
+ # for _ in range(generation_steps):
67
+ # tokens_tensor = torch.tensor([tokens], device=self.device)
68
+
69
+ # #forward pass
70
+ # tracemalloc.start()
71
+
72
+ # time_forward_0 = time.time()
73
+
74
+ # logits = self.model(tokens_tensor, return_loss=False)[:, -1] # no loss takes the output of the last tokens
75
+
76
+ # time_forward_1 = time.time()
77
+
78
+ # _, memory_usage = tracemalloc.get_traced_memory()
79
+ # tracemalloc.stop()
80
+
81
+ # self.metrics["memory_usages"].append(memory_usage)
82
+
83
+ # time_forward = time_forward_1 - time_forward_0
84
+ # self.metrics["times_forward"].append(time_forward)
85
+
86
+ # next_token = torch.armax(logits).item()
87
+
88
+ # #save the newly generated token
89
+ # tokens.append(next_token)
90
+ # tokens_new.append(next_token)
91
+
92
+ # time_end_to_end_1 = time.time()
93
+
94
+ # time_end_to_end = time_end_to_end_1 - time_end_to_end_0
95
+ # self.metrics["time_end_to_end"] = time_end_to_end
96
+
97
+ # decoded = self.tokenizer.decode(tokens)
98
+
99
+ # self.metrics["time_forward_average"] = np.mean(self.metrics["times_forward"])
100
+ # self.metrics["memory_usage_average"] = np.mean(self.metrics["memory_usage"])
101
+
102
+ # self.metrics['throughput'] = generation_steps / np.sum(self.metrics["times_forward"])
103
+
104
+ # return tokens_new, decoded
105
+
106
+
107
+
108
+
109
+ # def main():
110
+ # prompt = 'My name is'
111
+
112
+ # andromeda = EvalAndromeda(path='checkpoints/step_44927_6656/pytorch_model.bin')
113
+
114
+ # num_params = Andromeda.get_num_params()
115
+ # print(f'The model has {num_params} parameters')
116
+
117
+ # _, output = Andromeda.generate(prompt)
118
+
119
+ # for metric, value in Andromeda.metrics.items():
120
+ # print(f'{metric}: {value}\n')
121
+
122
+ # print('\n')
123
+
124
+ # print(output)
125
+
126
+
127
+
128
+
129
+
130
+
131
+ def main():
132
+ allow_ops_in_compiled_graph()
133
+
134
+ torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
135
+
136
+ parser = argparse.ArgumentParser(description="Generate text using Andromeda model")
137
+ parser.add_argument("prompt", type=str, help="Text prompt to generate text")
138
+ parser.add_argument(
139
+ "--seq_len", type=int, default=256, help="Sequence length for generated text"
140
+ )
141
+ parser.add_argument(
142
+ "--temperature", type=float, default=0.8, help="Sampling temperature"
143
+ )
144
+ parser.add_argument(
145
+ "--filter_thres", type=float, default=0.9, help="Filter threshold for sampling"
146
+ )
147
+ parser.add_argument(
148
+ "--model",
149
+ type=str,
150
+ default="andromeda-e-1",
151
+ help="Model to use for generation",
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--dtype",
156
+ type=str,
157
+ default="fp32",
158
+ help="Data type for the model: 'bf16', or 'fp32'",
159
+ )
160
+
161
+ args = parser.parse_args()
162
+
163
+
164
+ dtype = torch.float32
165
+ if args.dtype == 'bf16':
166
+ dtype = torch.bfloat16
167
+
168
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
169
+
170
+ #need to submit to torch hub
171
+ model = torch.hub.load("apacai/andromeda", args.model).to(device).to(dtype)
172
+
173
+ opt_model = torch.compile(model, backend="hidet")
174
+
175
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
176
+
177
+ encoded_text = tokenizer(args.prompt, return_tensors="pt")
178
+
179
+ output_tensor = opt_model.generate(
180
+ seq_len=args.seq_len,
181
+ prompt=encoded_text["input_ids"].to(device),
182
+ temperature=args.temperature,
183
+ filter_thres=args.filter_thres,
184
+ pad_value=0.0,
185
+ eos_token=tokenizer.eos_token_id,
186
+ return_seq_without_prompt=False,
187
+ use_tqdm=True,
188
+ )
189
+
190
+ decoded_output = tokenizer.batch_decode(output_tensor, skip_special_tokens=True)
191
+
192
+ return decoded_output
193
+
194
+
195
+ if __name__ == "__main__":
196
+ generated_text = main()
197
+ for text in generated_text:
198
+ print(f"{text}")
Andromeda/Andromeda/model.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn import Module
2
+ from Andromeda.core.transformer import Transformer, AutoregressiveWrapper, AndromedaEmbedding, Decoder
3
+ from transformers import AutoTokenizer
4
+
5
+ class AndromedaTokenizer:
6
+ def __init__(self):
7
+ self.tokenizer= AutoTokenizer.from_pretrained(
8
+ "EleutherAI/gpt-neox-20b",
9
+ eos_token="<eos>",
10
+ pad_token="<pad>",
11
+ extra_ids=0,
12
+ model_max_length=8192
13
+ )
14
+
15
+ def tokenize_texts(self, texts):
16
+ return self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
17
+
18
+ def decode(self, texts):
19
+ return self.tokenizer.decode(texts)
20
+
21
+ def __len__(self):
22
+ num_tokens = len(self.tokenizer)
23
+ return num_tokens
24
+
25
+
26
+
27
+ class Andromeda(Module):
28
+ """
29
+ Andromeda is a transformer-based model architecture. It initializes with
30
+ a Transformer and AutoregressiveWrapper with default or user-specified parameters.
31
+ """
32
+ def __init__(self,
33
+ num_tokens=50432,
34
+ max_seq_len=8192,
35
+ dim=2560,
36
+ depth=32,
37
+ dim_head=128,
38
+ heads=24,
39
+ use_abs_pos_emb=False,
40
+ alibi_pos_bias=True,
41
+ alibi_num_heads=12,
42
+ rotary_xpos=True,
43
+ attn_flash=True,
44
+ # shift_tokens=1,
45
+ attn_one_kv_head=True, # multiquery attention
46
+ qk_norm=True,
47
+ attn_qk_norm=True,
48
+ attn_qk_norm_dim_scale=True,
49
+ embedding_provider=AndromedaEmbedding()):
50
+ """
51
+ Initialize the model with specified or default parameters.
52
+ Args:
53
+ - num_tokens: Number of tokens in the vocabulary
54
+ - max_seq_len: Maximum sequence length
55
+ - dim: Dimension of the model
56
+ - depth: Depth of the model
57
+ - dim_head: Dimension of the model head
58
+ - heads: Number of heads
59
+ - use_abs_pos_emb: Whether to use absolute position embedding
60
+ - alibi_pos_bias: Alibi position bias
61
+ - alibi_num_heads: Number of alibi heads
62
+ - rotary_xpos: Rotary position
63
+ - attn_flash: Attention flash
64
+ - deepnorm: Deep normalization
65
+ - shift_tokens: Number of tokens to shift
66
+ - attn_one_kv_head: Attention one key/value head
67
+ - qk_norm: Query-key normalization
68
+ - attn_qk_norm: Attention query-key normalization
69
+ - attn_qk_norm_dim_scale: Attention query-key normalization dimension scale
70
+ - embedding_provider: Embedding provider module
71
+ """
72
+ super().__init__()
73
+
74
+ try:
75
+ self.Andromeda = Transformer(
76
+ num_tokens=num_tokens,
77
+ max_seq_len=max_seq_len,
78
+ use_abs_pos_emb=use_abs_pos_emb,
79
+ embedding_provider=embedding_provider,
80
+ attn_layers=Decoder(
81
+ dim=dim,
82
+ depth=depth,
83
+ dim_head=dim_head,
84
+ heads=heads,
85
+ alibi_pos_bias=alibi_pos_bias,
86
+ alibi_num_heads=alibi_num_heads,
87
+ rotary_xpos=rotary_xpos,
88
+ attn_flash=attn_flash,
89
+ # deepnorm=deepnorm,
90
+ # shift_tokens=shift_tokens,
91
+ attn_one_kv_head=attn_one_kv_head,
92
+ qk_norm=qk_norm,
93
+ attn_qk_norm=attn_qk_norm,
94
+ attn_qk_norm_dim_scale=attn_qk_norm_dim_scale
95
+ )
96
+ )
97
+
98
+ self.decoder = AutoregressiveWrapper(self.Andromeda)
99
+
100
+ except Exception as e:
101
+ print("Failed to initialize Andromeda: ", e)
102
+ raise
103
+
104
+ def forward(self, text_tokens, **kwargs):
105
+ """
106
+ Forward pass through the model. It expects the input text_tokens.
107
+ Args:
108
+ - text_tokens: Input tokens
109
+ - kwargs: Other arguments
110
+ Returns:
111
+ - output from the decoder
112
+ """
113
+ try:
114
+ model_input = self.decoder.forward(text_tokens)[0]
115
+ return self.decoder(model_input, padded_x=model_input[0])
116
+ except Exception as e:
117
+ print("Failed in forward method: ", e)
118
+ raise
Andromeda/Andromeda/old/__init__.py ADDED
File without changes
Andromeda/Andromeda/old/sophia.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+ from torch.optim.optimizer import Optimizer
4
+ from typing import List
5
+
6
+
7
+ class SophiaG(Optimizer):
8
+ def __init__(self, params, lr=1e-4, betas=(0.965, 0.99), rho = 0.04,
9
+ weight_decay=1e-1, *, maximize: bool = False,
10
+ capturable: bool = False):
11
+ if not 0.0 <= lr:
12
+ raise ValueError("Invalid learning rate: {}".format(lr))
13
+ if not 0.0 <= betas[0] < 1.0:
14
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
15
+ if not 0.0 <= betas[1] < 1.0:
16
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
17
+ if not 0.0 <= rho:
18
+ raise ValueError("Invalid rho parameter at index 1: {}".format(rho))
19
+ if not 0.0 <= weight_decay:
20
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
21
+ defaults = dict(lr=lr, betas=betas, rho=rho,
22
+ weight_decay=weight_decay,
23
+ maximize=maximize, capturable=capturable)
24
+ super(SophiaG, self).__init__(params, defaults)
25
+
26
+ def __setstate__(self, state):
27
+ super().__setstate__(state)
28
+ for group in self.param_groups:
29
+ group.setdefault('maximize', False)
30
+ group.setdefault('capturable', False)
31
+ state_values = list(self.state.values())
32
+ step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
33
+ if not step_is_tensor:
34
+ for s in state_values:
35
+ s['step'] = torch.tensor(float(s['step']))
36
+
37
+ @torch.no_grad()
38
+ def update_hessian(self):
39
+ for group in self.param_groups:
40
+ beta1, beta2 = group['betas']
41
+ for p in group['params']:
42
+ if p.grad is None:
43
+ continue
44
+ state = self.state[p]
45
+
46
+ if len(state) == 0:
47
+ state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
48
+ if self.defaults['capturable'] else torch.tensor(0.)
49
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
50
+ state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
51
+
52
+ if 'hessian' not in state.keys():
53
+ state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
54
+
55
+ state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)
56
+
57
+
58
+ @torch.no_grad()
59
+ def step(self, closure=None, bs=5120):
60
+ loss = None
61
+ if closure is not None:
62
+ with torch.enable_grad():
63
+ loss = closure()
64
+
65
+ for group in self.param_groups:
66
+ params_with_grad = []
67
+ grads = []
68
+ exp_avgs = []
69
+ state_steps = []
70
+ hessian = []
71
+ beta1, beta2 = group['betas']
72
+
73
+ for p in group['params']:
74
+ if p.grad is None:
75
+ continue
76
+ params_with_grad.append(p)
77
+
78
+ if p.grad.is_sparse:
79
+ raise RuntimeError('Hero does not support sparse gradients')
80
+ grads.append(p.grad)
81
+ state = self.state[p]
82
+ # State initialization
83
+ if len(state) == 0:
84
+ state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
85
+ if self.defaults['capturable'] else torch.tensor(0.)
86
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
87
+ state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
88
+
89
+ if 'hessian' not in state.keys():
90
+ state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format)
91
+
92
+ exp_avgs.append(state['exp_avg'])
93
+ state_steps.append(state['step'])
94
+ hessian.append(state['hessian'])
95
+
96
+ if self.defaults['capturable']:
97
+ bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs
98
+
99
+ sophiag(params_with_grad,
100
+ grads,
101
+ exp_avgs,
102
+ hessian,
103
+ state_steps,
104
+ bs=bs,
105
+ beta1=beta1,
106
+ beta2=beta2,
107
+ rho=group['rho'],
108
+ lr=group['lr'],
109
+ weight_decay=group['weight_decay'],
110
+ maximize=group['maximize'],
111
+ capturable=group['capturable'])
112
+
113
+ return loss
114
+
115
+ def sophiag(params: List[Tensor],
116
+ grads: List[Tensor],
117
+ exp_avgs: List[Tensor],
118
+ hessian: List[Tensor],
119
+ state_steps: List[Tensor],
120
+ capturable: bool = False,
121
+ *,
122
+ bs: int,
123
+ beta1: float,
124
+ beta2: float,
125
+ rho: float,
126
+ lr: float,
127
+ weight_decay: float,
128
+ maximize: bool):
129
+
130
+ if not all(isinstance(t, torch.Tensor) for t in state_steps):
131
+ raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
132
+
133
+
134
+ func = _single_tensor_sophiag
135
+
136
+ func(params,
137
+ grads,
138
+ exp_avgs,
139
+ hessian,
140
+ state_steps,
141
+ bs=bs,
142
+ beta1=beta1,
143
+ beta2=beta2,
144
+ rho=rho,
145
+ lr=lr,
146
+ weight_decay=weight_decay,
147
+ maximize=maximize,
148
+ capturable=capturable)
149
+
150
+ def _single_tensor_sophiag(params: List[Tensor],
151
+ grads: List[Tensor],
152
+ exp_avgs: List[Tensor],
153
+ hessian: List[Tensor],
154
+ state_steps: List[Tensor],
155
+ *,
156
+ bs: int,
157
+ beta1: float,
158
+ beta2: float,
159
+ rho: float,
160
+ lr: float,
161
+ weight_decay: float,
162
+ maximize: bool,
163
+ capturable: bool):
164
+
165
+ for i, param in enumerate(params):
166
+ grad = grads[i] if not maximize else -grads[i]
167
+ exp_avg = exp_avgs[i]
168
+ hess = hessian[i]
169
+ step_t = state_steps[i]
170
+
171
+ if capturable:
172
+ assert param.is_cuda and step_t.is_cuda and bs.is_cuda
173
+
174
+ if torch.is_complex(param):
175
+ grad = torch.view_as_real(grad)
176
+ exp_avg = torch.view_as_real(exp_avg)
177
+ hess = torch.view_as_real(hess)
178
+ param = torch.view_as_real(param)
179
+
180
+ # update step
181
+ step_t += 1
182
+
183
+ # Perform stepweight decay
184
+ param.mul_(1 - lr * weight_decay)
185
+
186
+ # Decay the first and second moment running average coefficient
187
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
188
+
189
+ if capturable:
190
+ step_size = lr
191
+ step_size_neg = step_size.neg()
192
+
193
+ ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1)
194
+ param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
195
+ else:
196
+ step_t.item()
197
+ step_size_neg = - lr
198
+
199
+ ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None,1)
200
+ param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg)
Andromeda/Andromeda/old/training.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #quantization + paralleism
2
+ import time
3
+
4
+ import torch
5
+ from accelerate.utils import set_seed
6
+ from datasets import load_dataset
7
+ from torch.nn import CrossEntropyLoss
8
+ from torch.utils.data import DataLoader
9
+ from transformers import default_data_collator, get_linear_schedule_with_warmup
10
+ from accelerate import Accelerator
11
+
12
+ from rich.progress import Progress
13
+
14
+
15
+ from lion_pytorch import Lion
16
+ # from x_transformers import Transformer, Decoder, AutoregressiveWrapper
17
+ from optimus_prim import Transformer, Decoder, AutoregressiveWrapper
18
+
19
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
20
+ import torch.distributed as dist
21
+
22
+ from torch.distributed.fsdp import (
23
+ FullyShardedDataParallel,
24
+ CPUOffload,
25
+ )
26
+
27
+ from torch.distributed.fsdp.wrap import (
28
+ default_auto_wrap_policy,
29
+ )
30
+
31
+ from transformers import AutoTokenizer
32
+
33
+ #logging
34
+ import boto3
35
+
36
+
37
+ #training
38
+ import wandb
39
+
40
+ from torch.utils.tensorboard import SummaryWriter
41
+
42
+ class CustomGPTNeoXTokenizer:
43
+ def __init__(self):
44
+ self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
45
+
46
+ def tokenize(self, text):
47
+ return self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)
48
+
49
+ custom_tokenizer = CustomGPTNeoXTokenizer()
50
+
51
+ Andromeda = Transformer(
52
+ num_tokens=64007,
53
+ max_seq_len=8192,
54
+ use_abs_pos_emb = False,
55
+ tokenizer=custom_tokenizer,
56
+ attn_layers = Decoder(
57
+ dim=2048,
58
+ depth=6,
59
+ heads=16,
60
+ alibi_pos_bias=True,
61
+ alibi_num_heads=8,
62
+ rotary_xpos=True,
63
+ attn_flash = True,
64
+ deepnorm=True,
65
+ shift_tokens=1,
66
+ attn_one_kv_head = True,
67
+ qk_norm=True
68
+ )
69
+ )
70
+
71
+ Andromeda = AutoregressiveWrapper(Andromeda)
72
+
73
+
74
+
75
+ AWS_ACCESS_KEY_ID=""
76
+ AWS_SECRET_ACCESS_KEY="d"
77
+
78
+
79
+ def save_model_to_s3(model, bucket_name, key_prefix, step):
80
+ s3 = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY)
81
+ model_path = f"checkpoint_at_step_{step}.pt"
82
+ torch.save(model.state_dict(), model_path)
83
+ s3.upload_file(model_path, bucket_name, f"{key_prefix}/{model_path}")
84
+
85
+
86
+
87
+ def count_number_of_parameters(model, only_trainable: bool = True) -> int:
88
+ if only_trainable:
89
+ num_params: int = sum(p.numel()
90
+ for p in model.parameters() if p.requires_grad)
91
+ else:
92
+ num_params: int = sum(p.numel() for p in model.parameters() if p)
93
+ return int(num_params)
94
+
95
+
96
+
97
+ def prep_sample(sample):
98
+ title = sample["title"]
99
+ text = sample["text"]
100
+ return {
101
+ "title": title,
102
+ "text": text
103
+ }
104
+
105
+
106
+ def train(args):
107
+
108
+ if args.use_ddp:
109
+ dist.init_process_group(backend="nccl")
110
+
111
+
112
+ accelerator = Accelerator(
113
+ mixed_precision="fp16",
114
+ gradient_accumulation_steps=1,
115
+ )
116
+
117
+ # If passed along, set the training seed now.
118
+ if args.seed is not None:
119
+ set_seed(args.seed)
120
+
121
+ #v1
122
+ model = Andromeda()
123
+ if args.use_ddp:
124
+ model = DistributedDataParallel(model)
125
+ else:
126
+ model = DataParallel(model)
127
+
128
+ fsdp_model = FullyShardedDataParallel(
129
+ model(),
130
+ fsdp_auto_wrap_policy=default_auto_wrap_policy,
131
+ cpu_offload=CPUOffload(offload_params=True),
132
+ )
133
+
134
+ fsdp_model = fsdp_model.to(accelerator.device)
135
+
136
+ #device count
137
+ if torch.cuda.device_count() > 1:
138
+ print(f"Let's use ${torch.cuda.device_count()} GPUS")
139
+
140
+
141
+
142
+
143
+ optimizer = Lion(model.parameters(), lr=args.learning_rate / 3, weight_decay=args.weight_decay * 3)
144
+
145
+ lr_scheduler = get_linear_schedule_with_warmup(
146
+ optimizer=optimizer,
147
+ num_warmup_steps=args.warmup_steps,
148
+ num_training_steps=args.max_steps,
149
+ )
150
+
151
+ # tokenizer = KosmosTokenizer()
152
+
153
+ #====================> load data #====================> load data #====================> load data
154
+
155
+
156
+ dataset = load_dataset("the_pile_books3")
157
+
158
+ # dataset = dataset.map(prep_sample, num_proc=8)
159
+ dataset = dataset.map(prep_sample, num_proc=8)
160
+
161
+
162
+ #new removed columns
163
+ remove_columns = ['title']
164
+
165
+
166
+ dataset = dataset.map(Andromeda.decoder.tokenizer, batched=True,
167
+ batch_size=128, remove_columns=remove_columns)
168
+
169
+ train_dataloader = DataLoader(
170
+ dataset, collate_fn=default_data_collator, batch_size=args.batch_size, pin_memory=True
171
+ )
172
+
173
+
174
+
175
+ #====================> load data #====================> load data #====================> load data #====================> load data
176
+
177
+ fsdp_model, train_dataloader, optimizer, lr_scheduler = accelerator.prepare(fsdp_model, train_dataloader, optimizer,
178
+ lr_scheduler)
179
+ fsdp_model.train()
180
+ accelerator.register_for_checkpointing(lr_scheduler)
181
+
182
+ accelerator.print(
183
+ f"Number of parameters: {count_number_of_parameters(model):,}")
184
+ accelerator.print(
185
+ f"Number of trainable parameters: {count_number_of_parameters(model, only_trainable=True):,}")
186
+
187
+ # Log model and optimizer parameters to wandb
188
+ accelerator.init_trackers(project_name="Andromeda")
189
+
190
+ #wandb
191
+ wandb.init(project="Andromeda", config=args)
192
+
193
+ #init tensorboard writer
194
+ tb_writer = SummaryWriter()
195
+
196
+
197
+ train_loader = iter(train_dataloader)
198
+ epoch_loss = 0
199
+ total_loss = 0
200
+ start_time = time.time()
201
+
202
+ with Progress() as progress:
203
+ task = progress.add_task("[red]Training...", total=args.max_steps)
204
+ for step in range(0, args.max_steps):
205
+ batch_start = time.time()
206
+ batch = next(train_loader)
207
+ outputs = fsdp_model(**batch, self_attn_padding_mask=batch["attention_mask"])
208
+ # Shift so that tokens < n predict n
209
+ outputs = torch.cat([outputs[:, :1], outputs[:, 67:]], dim=1).contiguous()
210
+ # shift_logits = outputs[..., :-1, :].contiguous()
211
+ # shift_labels = batch["labels"][..., 1:].contiguous()
212
+ # Flatten the tokens
213
+ loss_fct = CrossEntropyLoss()
214
+ one_hot_labels = torch.nn.functional.one_hot(batch["labels"][:, 1:], num_classes=32002).float()
215
+ loss = loss_fct(outputs[:,:-1], one_hot_labels)
216
+
217
+ epoch_loss += loss.detach().float()
218
+
219
+ accelerator.backward(loss)
220
+ optimizer.step()
221
+ optimizer.zero_grad()
222
+
223
+ batch_end = time.time()
224
+ logs = {
225
+ "loss": loss.item(),
226
+ "perplexity": torch.exp(loss).item(),
227
+ "lr": lr_scheduler.get_last_lr()[0],
228
+ "examples": args.batch_size * (step + 1),
229
+ "examples_per_second": args.batch_size / (batch_end - batch_start),
230
+ }
231
+ if step % args.log_every == args.log_every - 1:
232
+ #log metrics to wandb
233
+ wandb.log(logs, step=step)
234
+
235
+ #log metrics to tensorboard
236
+ # Log metrics to TensorBoard
237
+ tb_writer.add_scalar("loss", logs["loss"], step)
238
+ tb_writer.add_scalar("perplexity", logs["perplexity"], step)
239
+ tb_writer.add_scalar("lr", logs["lr"], step)
240
+ tb_writer.add_scalar("examples", logs["examples"], step)
241
+ tb_writer.add_scalar("examples_per_second", logs["examples_per_second"], step)
242
+
243
+ #accelerator
244
+ accelerator.log(logs, step=step)
245
+ progress.update(task, advance=1, description=f"Step Loss: {loss.item():.5f} "
246
+ f"| Mean Loss: {(total_loss + epoch_loss) / step:.5f} "
247
+ f"| Mean PPL: {torch.exp((total_loss + epoch_loss) / step):.2f} "
248
+ f"| Examples: {args.batch_size * (step + 1)} "
249
+ f"| Examples/s: {args.batch_size / (batch_end - batch_start):.2f} "
250
+ f"| Elapsed: {time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time))}")
251
+
252
+ if step % args.save_every == args.save_every - 1:
253
+ train_epoch_loss = epoch_loss / args.save_every
254
+ total_loss += epoch_loss
255
+ epoch_loss = 0
256
+
257
+ accelerator.log({
258
+ "train_ppl": torch.exp(train_epoch_loss),
259
+ "train_epoch_loss": train_epoch_loss,
260
+ }, step=step)
261
+
262
+ progress.print(f"Saving checkpoint at step {step}...")
263
+ accelerator.save_state(
264
+ f"{args.checkpoint_dir}/checkpoint_at_step_{step}/")
265
+
266
+ #save the model weights to s3
267
+ save_model_to_s3(model, "kosmostraining", "kosmosv1/checkpoints", step)
268
+ print(f"Saved to s3: {save_model_to_s3} ")
269
+
270
+ #finish tensorboard writer
271
+ tb_writer.close()
272
+
273
+ #finish wnabd run
274
+ wandb.finish()
275
+
276
+
277
+ if __name__ == "__main__":
278
+ import argparse
279
+
280
+ parser = argparse.ArgumentParser()
281
+ parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
282
+ parser.add_argument("--learning_rate", type=float, default=1e-5)
283
+ parser.add_argument("--weight_decay", type=float, default=0.01)
284
+ parser.add_argument("--warmup_steps", type=int, default=0)
285
+ parser.add_argument("--max_steps", type=int, default=100000)
286
+ parser.add_argument("--batch_size", type=int, default=4)
287
+ parser.add_argument("--log_every", type=int, default=1)
288
+ parser.add_argument("--save_every", type=int, default=100)
289
+ parser.add_argument("--seed", type=int, default=None)
290
+ parser.add_argument("--use_ddp", action="store_true", help="Use DistributedDataParallel")
291
+
292
+ args = parser.parse_args()
293
+
294
+ train(args)
Andromeda/Andromeda/old/training_1.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import multiprocessing
3
+ import os
4
+
5
+ from datetime import timedelta
6
+ from functools import partial
7
+ from itertools import chain
8
+
9
+
10
+ from accelerate import Accelerator
11
+ from accelerate.utils import InitProcessGroupKwargs
12
+
13
+ from datasets import concatenate_datasets, load_dataset
14
+
15
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
16
+ CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper)
17
+
18
+ from torch.utils.data import DataLoader
19
+
20
+ from tqdm import tqdm
21
+
22
+ from transformers import (AutoTokenizer, default_data_collator,
23
+ get_cosine_schedule_with_warmup,
24
+ get_linear_schedule_with_warmup, set_seed)
25
+
26
+
27
+ # from stable_adamw import StableAdamWUnfused
28
+ # sd
29
+
30
+ from optimus_prime import Transformer, Decoder, AutoregressiveWrapper
31
+ from optimus_prime import AndromedaEmbedding
32
+
33
+ from lion_pytorch import Lion
34
+
35
+
36
+ # constants
37
+
38
+ class CFG:
39
+ BATCH_SIZE: int = 3 # 3
40
+ GRADIENT_ACCUMULATE_EVERY: int = 1
41
+ SEED: int = 42
42
+ LEARNING_RATE: float = 1e-4
43
+ WEIGHT_DECAY: float = 1e-2
44
+ SEQ_LEN: int = 8192 # 8192
45
+ NUM_CPU: int = multiprocessing.cpu_count()
46
+ USE_PRETOKENIZED: bool = True
47
+ USE_ACTIVATION_CHECKPOINTING: bool = True
48
+ RESUME_FROM_CHECKPOINT: str = None
49
+ CHECKPOINTING_STEPS: int = 1000
50
+ OUTPUT_DIR: str = "output"
51
+ ENTITY_NAME: str = "wanb" # Put your wandb username here
52
+
53
+ # deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY)
54
+
55
+ # helpers
56
+
57
+ def print_num_params(model, accelerator: Accelerator):
58
+ n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
59
+ accelerator.print(f"Number of parameters in model: {n_params}")
60
+
61
+ def fsdp_activation_checkpointing(
62
+ model, accelerator: Accelerator, offload_to_cpu=False
63
+ ):
64
+
65
+ accelerator.print("Using FSDP activation checkpointing")
66
+
67
+ # check_fn = lambda submodule: isinstance(submodule, ParallelTransformerBlock)
68
+
69
+ non_reentrant_wrapper = partial(
70
+ checkpoint_wrapper,
71
+ offload_to_cpu=offload_to_cpu,
72
+ checkpoint_impl=CheckpointImpl.NO_REENTRANT,
73
+ )
74
+
75
+ apply_activation_checkpointing(
76
+ model, checkpoint_wrapper_fn=non_reentrant_wrapper)
77
+
78
+
79
+ def get_lr_scheduler_with_warmup(
80
+ optimizer, scheduler_type, num_warmup_steps, max_train_steps, grad_accumulate_every
81
+ ):
82
+ NUM_WARMUP_STEPS = num_warmup_steps
83
+ GRADIENT_ACCUMULATE_EVERY = grad_accumulate_every
84
+
85
+ if scheduler_type == "linear":
86
+ return get_linear_schedule_with_warmup(
87
+ optimizer=optimizer,
88
+ num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
89
+ num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY
90
+ )
91
+ elif scheduler_type == "cosine":
92
+ return get_cosine_schedule_with_warmup(
93
+ optimizer=optimizer,
94
+ num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
95
+ num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY
96
+ )
97
+ else:
98
+ raise ValueError(
99
+ "Invalid scheduler_type. Expected 'linear' or 'cosine', got: {}".format(
100
+ scheduler_type
101
+ )
102
+ )
103
+
104
+
105
+ def build_dataloaders():
106
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
107
+ dataset = load_dataset("openwebtext", split="train")
108
+
109
+ tokenized_dataset = dataset.map(
110
+ lambda example: tokenizer([t + tokenizer.eos_token for t in example["text"]]),
111
+ batched=True,
112
+ num_proc=CFG.NUM_CPU,
113
+ remove_columns=["text"],
114
+ )
115
+
116
+ block_size = CFG.SEQ_LEN
117
+
118
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
119
+ def group_texts(examples):
120
+ # Concatenate all texts.
121
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
122
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
123
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
124
+ # customize this part to your needs.
125
+ if total_length >= block_size:
126
+ total_length = (total_length // block_size) * block_size
127
+ # Split by chunks of max_len.
128
+ result = {
129
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
130
+ for k, t in concatenated_examples.items()
131
+ }
132
+ return result
133
+
134
+ train_dataset = tokenized_dataset.map(
135
+ group_texts, batched=True, num_proc=CFG.NUM_CPU,
136
+ )
137
+
138
+ return train_dataset
139
+
140
+ # main
141
+
142
+ def TrainAndromeda():
143
+ # accelerator
144
+
145
+ timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))
146
+
147
+ accelerator = Accelerator(
148
+ gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY,
149
+ mixed_precision="fp16",
150
+ log_with="wandb",
151
+ kwargs_handlers=[timeout],
152
+ deepspeed_plugin=deepspeed_plugin
153
+ )
154
+
155
+ accelerator.init_trackers(
156
+ project_name="andromeda",
157
+ config={
158
+ "batch_size": CFG.BATCH_SIZE,
159
+ "gradient_accumulate_every": CFG.GRADIENT_ACCUMULATE_EVERY,
160
+ "learning_rate": CFG.LEARNING_RATE,
161
+ "seq_len": CFG.SEQ_LEN,
162
+ },
163
+ init_kwargs={"wandb": {"entity": CFG.ENTITY_NAME}}
164
+ )
165
+
166
+ accelerator.print(f"Total GPUS: {accelerator.num_processes}")
167
+
168
+ # set seed
169
+
170
+ set_seed(CFG.SEED)
171
+
172
+ # Create the tokenizer
173
+
174
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
175
+
176
+ # instantiate andromeda
177
+
178
+ model = Transformer(
179
+ num_tokens=64007,
180
+ max_seq_len=8192,
181
+ use_abs_pos_emb=False,
182
+ tokenizer=tokenizer, # !
183
+ embedding_provider=AndromedaEmbedding(),
184
+ attn_layers = Decoder(
185
+ dim=128, # 2048
186
+ depth=8, # 16
187
+ dim_head=128,
188
+ heads=8,
189
+ alibi_pos_bias=True,
190
+ alibi_num_heads=4,
191
+ rotary_xpos=True,
192
+ attn_flash = True,
193
+ deepnorm=True,
194
+ shift_tokens=1,
195
+ attn_one_kv_head = True,
196
+ qk_norm=True,
197
+ attn_qk_norm=True,
198
+ attn_qk_norm_dim_scale=True # set this to True, in addition to `attn_qk_norm = True`
199
+ )
200
+ ).to(accelerator.device)
201
+
202
+ model = AutoregressiveWrapper(model).to(accelerator.device)
203
+
204
+ optim = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2, use_triton=True)
205
+
206
+ print_num_params(model, accelerator)
207
+
208
+ if CFG.USE_ACTIVATION_CHECKPOINTING:
209
+ fsdp_activation_checkpointing(model, accelerator)
210
+
211
+ # dataloaders
212
+
213
+ if CFG.USE_PRETOKENIZED:
214
+ d0 = load_dataset("conceptofmind/c4_0-to-20_neox_with_eos_8k", split="train")
215
+ d1 = load_dataset("conceptofmind/c4_21-to-40_neox_with_eos_8k", split="train")
216
+ d2 = load_dataset("conceptofmind/c4_41-to-60_neox_with_eos_8k", split="train")
217
+ d3 = load_dataset("conceptofmind/c4_61-to-80_neox_with_eos_8k", split="train")
218
+ d4 = load_dataset("conceptofmind/c4_81-to-100_neox_with_eos_8k", split="train")
219
+
220
+ train_dataset = concatenate_datasets([d0, d1, d2, d3, d4])
221
+ else:
222
+ train_dataset = build_dataloaders()
223
+
224
+ train_loader = DataLoader(
225
+ train_dataset, batch_size=CFG.BATCH_SIZE, collate_fn=default_data_collator,
226
+ )
227
+
228
+ max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
229
+ accelerator.print(f"Max train steps: {max_train_steps}")
230
+
231
+ # lr scheduler
232
+ # We cant decide on an actual number
233
+
234
+ NUM_WARMUP_STEPS = int(max_train_steps * 0.01)
235
+ accelerator.print(f"Num warmup steps: {NUM_WARMUP_STEPS}")
236
+
237
+ lr_scheduler = get_lr_scheduler_with_warmup(
238
+ optimizer=optim,
239
+ scheduler_type="cosine",
240
+ num_warmup_steps=NUM_WARMUP_STEPS,
241
+ max_train_steps=max_train_steps,
242
+ grad_accumulate_every=CFG.GRADIENT_ACCUMULATE_EVERY
243
+ )
244
+
245
+ # prepare
246
+
247
+ model, optim, train_loader, lr_scheduler = accelerator.prepare(
248
+ model, optim, train_loader, lr_scheduler
249
+ )
250
+
251
+ # checkpoint scheduler
252
+
253
+ accelerator.register_for_checkpointing(lr_scheduler)
254
+
255
+ # I do not know why Huggingface recommends recalculation of max_train_steps
256
+
257
+ max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
258
+ accelerator.print(f"Max train steps recalculated: {max_train_steps}")
259
+
260
+ # Total batch size for logging
261
+
262
+ total_batch_size = (
263
+ CFG.BATCH_SIZE * accelerator.num_processes * CFG.GRADIENT_ACCUMULATE_EVERY
264
+ )
265
+ accelerator.print(f"Total batch size: {total_batch_size}")
266
+
267
+ # resume training
268
+
269
+ progress_bar = tqdm(
270
+ range(max_train_steps), disable=not accelerator.is_local_main_process
271
+ )
272
+ completed_steps = 0
273
+
274
+ if CFG.RESUME_FROM_CHECKPOINT:
275
+ if CFG.RESUME_FROM_CHECKPOINT is not None or CFG.RESUME_FROM_CHECKPOINT != "":
276
+ accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
277
+ accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
278
+ path = os.path.basename(CFG.RESUME_FROM_CHECKPOINT)
279
+
280
+ training_difference = os.path.splitext(path)[0]
281
+
282
+ # need to multiply `gradient_accumulation_steps` to reflect real steps
283
+ resume_step = (
284
+ int(training_difference.replace("step_", ""))
285
+ * CFG.GRADIENT_ACCUMULATE_EVERY
286
+ )
287
+
288
+ if CFG.RESUME_FROM_CHECKPOINT and resume_step is not None:
289
+ train_loader = accelerator.skip_first_batches(train_loader, resume_step)
290
+ completed_steps += resume_step
291
+ progress_bar.update(resume_step)
292
+
293
+ # training
294
+
295
+ model.train()
296
+
297
+ for step, batch in enumerate(train_loader):
298
+ with accelerator.accumulate(model):
299
+ inputs = batch["input_ids"].to(accelerator.device)
300
+ _, loss = model(inputs, return_loss=True)
301
+ accelerator.backward(loss)
302
+
303
+ # print(loss.item())
304
+
305
+ accelerator.log({"loss": loss.item()}, step=step)
306
+
307
+ if accelerator.sync_gradients:
308
+ accelerator.clip_grad_norm_(model.parameters(), 0.5)
309
+
310
+ optim.step()
311
+ lr_scheduler.step()
312
+ optim.zero_grad()
313
+
314
+ if accelerator.sync_gradients:
315
+ progress_bar.update(1)
316
+ completed_steps += 1
317
+
318
+ if isinstance(CFG.CHECKPOINTING_STEPS, int):
319
+ if completed_steps % CFG.CHECKPOINTING_STEPS == 0:
320
+ output_dir = f"step_{completed_steps }"
321
+ if CFG.OUTPUT_DIR is not None:
322
+ output_dir = os.path.join(CFG.OUTPUT_DIR, output_dir)
323
+ accelerator.save_state(output_dir)
324
+
325
+ if completed_steps >= max_train_steps:
326
+ break
327
+
328
+ # end training
329
+
330
+ accelerator.print("Training Finished")
331
+ accelerator.end_training()
332
+
333
+ # save final model
334
+
335
+ # accelerator.print(f"Saving model to {CFG.OUTPUT_DIR}")
336
+ if CFG.OUTPUT_DIR is not None:
337
+ base_path = f'{CFG.OUTPUT_DIR}/final'
338
+
339
+ if not os.path.exists(base_path):
340
+ os.makedirs(base_path)
341
+
342
+ accelerator.wait_for_everyone()
343
+ unwrapped_model = accelerator.unwrap_model(model)
344
+ with accelerator.main_process_first():
345
+ accelerator.save(
346
+ unwrapped_model.state_dict(), os.path.join(base_path, 'final_model.pt')
347
+ )
348
+
349
+ if __name__ == "__main__":
350
+ TrainAndromeda()
Andromeda/Andromeda/old/training_sophia.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import multiprocessing
3
+ import os
4
+
5
+ from datetime import timedelta
6
+ from functools import partial
7
+ from itertools import chain
8
+
9
+
10
+ from accelerate import Accelerator
11
+ from accelerate.utils import InitProcessGroupKwargs
12
+
13
+ from datasets import concatenate_datasets, load_dataset
14
+
15
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
16
+ CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper)
17
+
18
+ from torch.utils.data import DataLoader
19
+
20
+ from tqdm import tqdm
21
+
22
+ from transformers import (AutoTokenizer, default_data_collator,
23
+ get_cosine_schedule_with_warmup,
24
+ get_linear_schedule_with_warmup, set_seed)
25
+
26
+
27
+ # from stable_adamw import StableAdamWUnfused
28
+ # sd
29
+
30
+ from optimus_prime import Transformer, Decoder, AutoregressiveWrapper
31
+ from optimus_prime import AndromedaEmbedding
32
+
33
+ from sophia import SophiaG
34
+
35
+ # constants
36
+
37
+ class CFG:
38
+ BATCH_SIZE: int = 3 # 3
39
+ GRADIENT_ACCUMULATE_EVERY: int = 1
40
+ SEED: int = 42
41
+ LEARNING_RATE: float = 1e-4
42
+ WEIGHT_DECAY: float = 1e-2
43
+ SEQ_LEN: int = 8192 # 8192
44
+ NUM_CPU: int = multiprocessing.cpu_count()
45
+ USE_PRETOKENIZED: bool = True
46
+ USE_ACTIVATION_CHECKPOINTING: bool = True
47
+ RESUME_FROM_CHECKPOINT: str = None
48
+ CHECKPOINTING_STEPS: int = 1000
49
+ OUTPUT_DIR: str = "output"
50
+ ENTITY_NAME: str = "nicolo" # Put your wandb username here
51
+
52
+ # helpers
53
+
54
+ def print_num_params(model, accelerator: Accelerator):
55
+ n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
56
+ accelerator.print(f"Number of parameters in model: {n_params}")
57
+
58
+ def fsdp_activation_checkpointing(
59
+ model, accelerator: Accelerator, offload_to_cpu=False
60
+ ):
61
+
62
+ accelerator.print("Using FSDP activation checkpointing")
63
+
64
+ # check_fn = lambda submodule: isinstance(submodule, ParallelTransformerBlock)
65
+
66
+ non_reentrant_wrapper = partial(
67
+ checkpoint_wrapper,
68
+ offload_to_cpu=offload_to_cpu,
69
+ checkpoint_impl=CheckpointImpl.NO_REENTRANT,
70
+ )
71
+
72
+ apply_activation_checkpointing(
73
+ model, checkpoint_wrapper_fn=non_reentrant_wrapper)
74
+
75
+
76
+ def get_lr_scheduler_with_warmup(
77
+ optimizer, scheduler_type, num_warmup_steps, max_train_steps, grad_accumulate_every
78
+ ):
79
+ NUM_WARMUP_STEPS = num_warmup_steps
80
+ GRADIENT_ACCUMULATE_EVERY = grad_accumulate_every
81
+
82
+ if scheduler_type == "linear":
83
+ return get_linear_schedule_with_warmup(
84
+ optimizer=optimizer,
85
+ num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
86
+ num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY
87
+ )
88
+ elif scheduler_type == "cosine":
89
+ return get_cosine_schedule_with_warmup(
90
+ optimizer=optimizer,
91
+ num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
92
+ num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY
93
+ )
94
+ else:
95
+ raise ValueError(
96
+ "Invalid scheduler_type. Expected 'linear' or 'cosine', got: {}".format(
97
+ scheduler_type
98
+ )
99
+ )
100
+
101
+
102
+ def build_dataloaders():
103
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
104
+
105
+ content_column = 'text'
106
+
107
+ dataset = load_dataset("sentiment140", split="train")
108
+ dataset = dataset.remove_columns([col for col in dataset.column_names if col != content_column])
109
+
110
+ tokenized_dataset = dataset.map(
111
+ lambda example: tokenizer([t + tokenizer.eos_token for t in example[content_column]]),
112
+ batched=True,
113
+ num_proc=CFG.NUM_CPU,
114
+ remove_columns=[content_column]
115
+ )
116
+
117
+ block_size = CFG.SEQ_LEN
118
+
119
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
120
+ def group_texts(examples):
121
+ # Concatenate all texts.
122
+ concatenated_examples = {}
123
+
124
+ for k in examples.keys():
125
+ concatenated_examples[k] = list(chain(*examples[k]))
126
+
127
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
128
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
129
+ # customize this part to your needs.
130
+ if total_length >= block_size:
131
+ total_length = (total_length // block_size) * block_size
132
+ # Split by chunks of max_len.
133
+ result = {
134
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
135
+ for k, t in concatenated_examples.items()
136
+ }
137
+
138
+ return result
139
+
140
+ train_dataset = tokenized_dataset.map(
141
+ group_texts, batched=True, num_proc=CFG.NUM_CPU
142
+ )
143
+
144
+ return train_dataset
145
+
146
+ # main
147
+
148
+ def TrainAndromeda():
149
+ # accelerator
150
+
151
+ timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))
152
+
153
+ accelerator = Accelerator(
154
+ gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY,
155
+ mixed_precision="fp16", # Switch to bf16
156
+ log_with="wandb",
157
+ kwargs_handlers=[timeout]
158
+ )
159
+
160
+ accelerator.init_trackers(
161
+ project_name="andromeda",
162
+ config={
163
+ "batch_size": CFG.BATCH_SIZE,
164
+ "gradient_accumulate_every": CFG.GRADIENT_ACCUMULATE_EVERY,
165
+ "learning_rate": CFG.LEARNING_RATE,
166
+ "seq_len": CFG.SEQ_LEN,
167
+ },
168
+ init_kwargs={"wandb": {"entity": CFG.ENTITY_NAME}}
169
+ )
170
+
171
+ accelerator.print(f"Total GPUS: {accelerator.num_processes}")
172
+
173
+ # set seed
174
+
175
+ set_seed(CFG.SEED)
176
+
177
+ # Create the tokenizer
178
+
179
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
180
+
181
+ # instantiate andromeda
182
+
183
+ model = Transformer(
184
+ num_tokens=64007,
185
+ max_seq_len=8192,
186
+ use_abs_pos_emb=False,
187
+ tokenizer=tokenizer, # !
188
+ embedding_provider=AndromedaEmbedding(),
189
+ attn_layers = Decoder(
190
+ dim=128, # 2048
191
+ depth=8, # 16
192
+ dim_head=128,
193
+ heads=8,
194
+ alibi_pos_bias=True,
195
+ alibi_num_heads=4,
196
+ rotary_xpos=True,
197
+ attn_flash = True,
198
+ # deepnorm=True,
199
+ shift_tokens=1,
200
+ attn_one_kv_head = True,
201
+ qk_norm=True,
202
+ attn_qk_norm=True,
203
+ attn_qk_norm_dim_scale=True # set this to True, in addition to `attn_qk_norm = True`
204
+ )
205
+ ).to(accelerator.device)
206
+
207
+ model = AutoregressiveWrapper(model).to(accelerator.device)
208
+
209
+ #optim = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)
210
+ optim = SophiaG(model.parameters(), lr=1e-5, weight_decay=1e-1)
211
+
212
+ print_num_params(model, accelerator)
213
+
214
+ if CFG.USE_ACTIVATION_CHECKPOINTING:
215
+ fsdp_activation_checkpointing(model, accelerator)
216
+
217
+ # dataloaders
218
+
219
+ if CFG.USE_PRETOKENIZED:
220
+ d0 = load_dataset("conceptofmind/c4_0-to-20_neox_with_eos_8k", split="train")
221
+ d1 = load_dataset("conceptofmind/c4_21-to-40_neox_with_eos_8k", split="train")
222
+ d2 = load_dataset("conceptofmind/c4_41-to-60_neox_with_eos_8k", split="train")
223
+ d3 = load_dataset("conceptofmind/c4_61-to-80_neox_with_eos_8k", split="train")
224
+ d4 = load_dataset("conceptofmind/c4_81-to-100_neox_with_eos_8k", split="train")
225
+
226
+ train_dataset = concatenate_datasets([d0, d1, d2, d3, d4])
227
+ else:
228
+ train_dataset = build_dataloaders()
229
+
230
+ train_loader = DataLoader(
231
+ train_dataset, batch_size=CFG.BATCH_SIZE, collate_fn=default_data_collator,
232
+ )
233
+
234
+ # optimizer
235
+
236
+ # optim = decoupled_optimizer(
237
+ # model,
238
+ # learning_rate=CFG.LEARNING_RATE,
239
+ # weight_decay=CFG.WEIGHT_DECAY,
240
+ # beta_1=0.9,
241
+ # beta_2=0.95,
242
+ # use_adamw=False,
243
+ # )
244
+
245
+ # Determine number of training steps
246
+
247
+ max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
248
+ accelerator.print(f"Max train steps: {max_train_steps}")
249
+
250
+ # lr scheduler
251
+ # We cant decide on an actual number
252
+
253
+ NUM_WARMUP_STEPS = int(max_train_steps * 0.01)
254
+ accelerator.print(f"Num warmup steps: {NUM_WARMUP_STEPS}")
255
+
256
+ lr_scheduler = get_lr_scheduler_with_warmup(
257
+ optimizer=optim,
258
+ scheduler_type="cosine",
259
+ num_warmup_steps=NUM_WARMUP_STEPS,
260
+ max_train_steps=max_train_steps,
261
+ grad_accumulate_every=CFG.GRADIENT_ACCUMULATE_EVERY
262
+ )
263
+
264
+ # prepare
265
+
266
+ model, optim, train_loader, lr_scheduler = accelerator.prepare(
267
+ model, optim, train_loader, lr_scheduler
268
+ )
269
+
270
+ # checkpoint scheduler
271
+
272
+ accelerator.register_for_checkpointing(lr_scheduler)
273
+
274
+ # I do not know why Huggingface recommends recalculation of max_train_steps
275
+
276
+ max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
277
+ accelerator.print(f"Max train steps recalculated: {max_train_steps}")
278
+
279
+ # Total batch size for logging
280
+
281
+ total_batch_size = (
282
+ CFG.BATCH_SIZE * accelerator.num_processes * CFG.GRADIENT_ACCUMULATE_EVERY
283
+ )
284
+ accelerator.print(f"Total batch size: {total_batch_size}")
285
+
286
+ # resume training
287
+
288
+ progress_bar = tqdm(
289
+ range(max_train_steps), disable=not accelerator.is_local_main_process
290
+ )
291
+ completed_steps = 0
292
+
293
+ if CFG.RESUME_FROM_CHECKPOINT:
294
+ if CFG.RESUME_FROM_CHECKPOINT is not None or CFG.RESUME_FROM_CHECKPOINT != "":
295
+ accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
296
+ accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
297
+ path = os.path.basename(CFG.RESUME_FROM_CHECKPOINT)
298
+
299
+ training_difference = os.path.splitext(path)[0]
300
+
301
+ # need to multiply `gradient_accumulation_steps` to reflect real steps
302
+ resume_step = (
303
+ int(training_difference.replace("step_", ""))
304
+ * CFG.GRADIENT_ACCUMULATE_EVERY
305
+ )
306
+
307
+ if CFG.RESUME_FROM_CHECKPOINT and resume_step is not None:
308
+ train_loader = accelerator.skip_first_batches(train_loader, resume_step)
309
+ completed_steps += resume_step
310
+ progress_bar.update(resume_step)
311
+
312
+ # training
313
+
314
+ model.train()
315
+
316
+ for step, batch in enumerate(train_loader):
317
+ with accelerator.accumulate(model):
318
+ inputs = batch["input_ids"].to(accelerator.device)
319
+ _, loss = model(inputs, return_loss=True)
320
+ accelerator.backward(loss)
321
+
322
+ # print(loss.item())
323
+
324
+ accelerator.log({"loss": loss.item()}, step=step)
325
+
326
+ if accelerator.sync_gradients:
327
+ accelerator.clip_grad_norm_(model.parameters(), 0.5)
328
+
329
+ optim.step()
330
+ lr_scheduler.step()
331
+ optim.zero_grad()
332
+
333
+ if accelerator.sync_gradients:
334
+ progress_bar.update(1)
335
+ completed_steps += 1
336
+
337
+ if isinstance(CFG.CHECKPOINTING_STEPS, int):
338
+ if completed_steps % CFG.CHECKPOINTING_STEPS == 0:
339
+ output_dir = f"step_{completed_steps }"
340
+ if CFG.OUTPUT_DIR is not None:
341
+ output_dir = os.path.join(CFG.OUTPUT_DIR, output_dir)
342
+ accelerator.save_state(output_dir)
343
+
344
+ if completed_steps >= max_train_steps:
345
+ break
346
+
347
+ # end training
348
+
349
+ accelerator.print("Training Finished")
350
+ accelerator.end_training()
351
+
352
+ # save final model
353
+
354
+ # accelerator.print(f"Saving model to {CFG.OUTPUT_DIR}")
355
+ if CFG.OUTPUT_DIR is not None:
356
+ base_path = f'{CFG.OUTPUT_DIR}/final'
357
+
358
+ if not os.path.exists(base_path):
359
+ os.makedirs(base_path)
360
+
361
+ accelerator.wait_for_everyone()
362
+ unwrapped_model = accelerator.unwrap_model(model)
363
+ with accelerator.main_process_first():
364
+ accelerator.save(
365
+ unwrapped_model.state_dict(), os.path.join(base_path, 'final_model.pt')
366
+ )
367
+
368
+ if __name__ == "__main__":
369
+ TrainAndromeda()
Andromeda/Andromeda/train.py ADDED
@@ -0,0 +1,700 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import multiprocessing
3
+ import os
4
+ from datetime import timedelta
5
+ from functools import partial
6
+ from itertools import chain
7
+
8
+ import torch
9
+
10
+ ########### SETUP CONFIG
11
+ import torch.distributed as dist
12
+ from accelerate import Accelerator
13
+ from accelerate.logging import get_logger
14
+ from accelerate.state import AcceleratorState
15
+ from accelerate.utils import DummyOptim, InitProcessGroupKwargs
16
+ from datasets import load_dataset
17
+ from lion_pytorch import Lion
18
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
19
+ CheckpointImpl,
20
+ apply_activation_checkpointing,
21
+ checkpoint_wrapper,
22
+ )
23
+
24
+ # import bitsandbytes as bnb
25
+ from torch.distributed.fsdp import (
26
+ BackwardPrefetch,
27
+ FullyShardedDataParallel,
28
+ MixedPrecision,
29
+ ShardingStrategy,
30
+ )
31
+ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
32
+ from torch.nn import LayerNorm
33
+ from torch.optim import AdamW
34
+ from torch.utils.data import DataLoader
35
+ from tqdm import tqdm
36
+ from transformers import (
37
+ AutoTokenizer,
38
+ default_data_collator,
39
+ get_cosine_schedule_with_warmup,
40
+ get_linear_schedule_with_warmup,
41
+ set_seed,
42
+ )
43
+
44
+ # from Andromeda.model import Andromeda
45
+ from Andromeda.configs import Andromeda1Billion
46
+ from Andromeda.core.transformer import Transformer
47
+ from Andromeda.utils.stable_adamw import StableAdamWUnfused
48
+
49
+ # state = AcceleratorState()
50
+
51
+
52
+ logger = get_logger(__name__, log_level="INFO")
53
+
54
+ class CFG:
55
+ BATCH_SIZE = 1
56
+ GRADIENT_ACCUMULATE_EVERY: int = 1
57
+ SEED: int = 42
58
+ LEARNING_RATE: float = 1e-4 #3e-4 # 1e-4 for lion
59
+ WEIGHT_DECAY: float = 0.1
60
+ SEQ_LEN: int = 8192
61
+ NUM_CPU: int = multiprocessing.cpu_count()
62
+ USE_DEEPSPEED: bool = True
63
+ USE_FSDP: bool = True
64
+ USE_PRETOKENIZED: bool = True
65
+ USE_ACTIVATION_CHECKPOINTING: bool = True
66
+ RESUME_FROM_CHECKPOINT: str = False
67
+ CHECKPOINTING_STEPS: int = 1000
68
+ OUTPUT_DIR: str = 'checkpoints/' # Folder
69
+ ENTITY_NAME: str = "Andromeda"
70
+ LOGGING_STEPS: int = 100
71
+
72
+
73
+ # helpers
74
+
75
+
76
+ def print_num_params(model, accelerator: Accelerator):
77
+ # n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
78
+ n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
79
+ accelerator.print(f"Number of parameters in model: {n_params}")
80
+
81
+
82
+ # activation checkpointing
83
+
84
+
85
+ def activation_checkpointing(
86
+ model: torch.nn.Module,
87
+ offload_to_cpu: bool = False,
88
+ accelerator: Accelerator = None,
89
+ ):
90
+ """
91
+ Apply activation checkpointing to a model.
92
+
93
+ Args:
94
+ model (Module): The model to which to apply activation checkpointing.
95
+ offload_to_cpu (bool, optional): Whether to offload the activations to CPU. Defaults to False.
96
+ accelerator (Accelerator, optional): The Accelerate library accelerator. Defaults to None.
97
+ """
98
+ if accelerator is not None:
99
+ accelerator.print("Using activation checkpointing")
100
+ def check_fn(submodule):
101
+ return isinstance(submodule, Transformer)
102
+ non_reentrant_wrapper = partial(
103
+ checkpoint_wrapper,
104
+ offload_to_cpu=offload_to_cpu,
105
+ checkpoint_impl=CheckpointImpl.NO_REENTRANT,
106
+ )
107
+ apply_activation_checkpointing(
108
+ model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
109
+ )
110
+
111
+
112
+ # FSDP
113
+
114
+
115
+ def fsdp(
116
+ model: torch.nn.Module,
117
+ auto_wrap: bool = False,
118
+ mp: str = "fp32",
119
+ shard_strat: str = "NO_SHARD",
120
+ ):
121
+ """
122
+ This function wraps a given PyTorch model with the FullyShardedDataParallel (FSDP) wrapper to enable efficient data parallelism and model sharding.
123
+
124
+ Args:
125
+ model (torch.nn.Module): The original PyTorch model to be wrapped with FSDP.
126
+ auto_wrap (bool, optional): If True, it enables automatic wrapping of the model's layers according to the transformer_auto_wrap_policy. Default is False.
127
+ mp (str, optional): The mixed precision mode to be used. Can be 'bf16' for BFloat16, 'fp16' for Float16 or 'fp32' for Float32 precision. Default is 'fp32'.
128
+ shard_strat (str, optional): The sharding strategy to be used. Can be 'SHARD_GRAD' for sharding at gradient computation, 'FULL_SHARD' for full model sharding or 'NO_SHARD' for no sharding. Default is 'NO_SHARD'.
129
+
130
+ Raises:
131
+ ValueError: If the provided mp (mixed precision mode) is not 'bf16', 'fp16' or 'fp32'.
132
+ ValueError: If the provided shard_strat (sharding strategy) is not 'SHARD_GRAD', 'FULL_SHARD' or 'NO_SHARD'.
133
+
134
+ Returns:
135
+ torch.nn.Module: The input model wrapped with FSDP.
136
+ """
137
+ if auto_wrap:
138
+ Andromeda_auto_wrap_policy = partial(
139
+ transformer_auto_wrap_policy,
140
+ transformer_layer_cls={
141
+ Transformer,
142
+ },
143
+ )
144
+ else:
145
+ Andromeda_auto_wrap_policy = None
146
+
147
+ if mp == "bf16":
148
+ mp_fsdp = MixedPrecision(
149
+ param_dtype=torch.bfloat16,
150
+ # Gradient communication precision.
151
+ reduce_dtype=torch.bfloat16,
152
+ # Buffer precision.
153
+ buffer_dtype=torch.bfloat16,
154
+ )
155
+ elif mp == "fp16":
156
+ mp_fsdp = MixedPrecision(
157
+ param_dtype=torch.float16,
158
+ # Gradient communication precision.
159
+ reduce_dtype=torch.float16,
160
+ # Buffer precision.
161
+ buffer_dtype=torch.float16,
162
+ )
163
+ elif mp == "fp32":
164
+ mp_fsdp = MixedPrecision(
165
+ param_dtype=torch.float32,
166
+ # Gradient communication precision.
167
+ reduce_dtype=torch.float32,
168
+ # Buffer precision.
169
+ buffer_dtype=torch.float32,
170
+ )
171
+ else:
172
+ raise ValueError(
173
+ "Invalid scheduler_type. Expected 'bf16', 'fp16' or 'fp32', got: {}".format(
174
+ mp
175
+ )
176
+ )
177
+
178
+ if shard_strat == "SHARD_GRAD":
179
+ sharding_strat_fsdp = ShardingStrategy.SHARD_GRAD_OP
180
+ elif shard_strat == "FULL_SHARD":
181
+ sharding_strat_fsdp = ShardingStrategy.FULL_SHARD
182
+ elif shard_strat == "NO_SHARD":
183
+ sharding_strat_fsdp = ShardingStrategy.NO_SHARD
184
+ else:
185
+ raise ValueError(
186
+ "Invalid scheduler_type. Expected 'SHARD_GRAD', 'FULL_SHARD' or 'NO_SHARD', got: {}".format(
187
+ shard_strat
188
+ )
189
+ )
190
+
191
+ model = FullyShardedDataParallel(
192
+ model,
193
+ auto_wrap_policy=Andromeda_auto_wrap_policy,
194
+ mixed_precision=mp_fsdp,
195
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
196
+ sharding_strategy=sharding_strat_fsdp,
197
+ forward_prefetch=True,
198
+ use_orig_params=True,
199
+ )
200
+
201
+ return model
202
+
203
+
204
+ # learning rate scheduler
205
+
206
+
207
+ def get_lr_scheduler_with_warmup(
208
+ optimizer: torch.optim.Optimizer,
209
+ scheduler_type: str,
210
+ num_warmup_steps: int,
211
+ max_train_steps: int,
212
+ grad_accumulate_every: int = 1,
213
+ accelerator: Accelerator = None,
214
+ ):
215
+ """
216
+ Get a learning rate scheduler with warmup.
217
+
218
+ Args:
219
+ optimizer (Optimizer): The optimizer for which to create the learning rate scheduler.
220
+ scheduler_type (str): The type of learning rate scheduler to create, either "linear" or "cosine".
221
+ num_warmup_steps (int): The number of warmup steps for the learning rate scheduler.
222
+ max_train_steps (int): The maximum number of training steps.
223
+ grad_accumulate_every (int, optional): The gradient accumulation factor. Defaults to 1.
224
+ accelerator (Accelerator, optional): The Accelerate library accelerator. Defaults to None.
225
+
226
+ Returns:
227
+ The learning rate scheduler with warmup.
228
+
229
+ Raises:
230
+ ValueError: If scheduler_type is not "linear" or "cosine".
231
+ """
232
+ NUM_WARMUP_STEPS = num_warmup_steps
233
+ GRADIENT_ACCUMULATE_EVERY = grad_accumulate_every
234
+ if accelerator is not None:
235
+ accelerator.print(f"Using {scheduler_type} lr scheduler")
236
+ if scheduler_type == "linear":
237
+ return get_linear_schedule_with_warmup(
238
+ optimizer=optimizer,
239
+ num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
240
+ num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY,
241
+ )
242
+ elif scheduler_type == "cosine":
243
+ return get_cosine_schedule_with_warmup(
244
+ optimizer=optimizer,
245
+ num_warmup_steps=NUM_WARMUP_STEPS * GRADIENT_ACCUMULATE_EVERY,
246
+ num_training_steps=max_train_steps * GRADIENT_ACCUMULATE_EVERY,
247
+ )
248
+ else:
249
+ raise ValueError(
250
+ "Invalid scheduler_type. Expected 'linear' or 'cosine', got: {}".format(
251
+ scheduler_type
252
+ )
253
+ )
254
+
255
+
256
+ # optimizers
257
+
258
+
259
+ def decoupled_optimizer(
260
+ model: torch.nn.Module,
261
+ learning_rate: float,
262
+ weight_decay: float,
263
+ beta_1: float,
264
+ beta_2: float,
265
+ optimizer_type: str,
266
+ use_fsdp: bool = True,
267
+ accelerator: Accelerator = None,
268
+ ):
269
+ """
270
+ Decouples the optimizer from the training process.
271
+
272
+ This function sets up the optimizer for the model by creating two groups of parameters:
273
+ one for weight decay and one without weight decay. Then, it initializes the optimizer
274
+ with these two groups of parameters.
275
+
276
+ Args:
277
+ model (Module): The model whose parameters are optimized.
278
+ learning_rate (float): The learning rate for the optimizer.
279
+ weight_decay (float): The weight decay for the optimizer.
280
+ beta_1 (float): The exponential decay rate for the 1st moment estimates.
281
+ beta_2 (float): The exponential decay rate for the 2nd moment estimates.
282
+ optimizer_type (str): The type of the optimizer. Can be 'lion', 'adamw', or 'stable_adamw'.
283
+ use_fsdp (bool, optional): If True, the optimizer will work with fully sharded data parallelism. Defaults to True.
284
+ accelerator (Accelerator, optional): The accelerator from HuggingFace's Accelerate library. Defaults to None.
285
+
286
+ Returns:
287
+ Optimizer: The initialized optimizer.
288
+
289
+ Raises:
290
+ ValueError: If the optimizer type is not 'lion', 'adamw' or 'stable_adamw'.
291
+ """
292
+ accelerator.print(f"Using {optimizer_type} optimizer")
293
+ # Create an empty dictionary called param_dict to store the model's named parameters.
294
+ param_dict = {}
295
+ # Iterate over the model's named parameters and populate the param_dict with key-value pairs.
296
+ for param_name, param in model.named_parameters():
297
+ param_dict[param_name] = param
298
+
299
+ # Separate the model's named modules into two groups: decay and no_decay.
300
+
301
+ # Create an empty list to store the names of the LayerNorm and Embedding layer weights with no weight decay.
302
+ no_decay = []
303
+
304
+ if use_fsdp:
305
+ exclude_module = "_fsdp_wrapped_module.token_emb"
306
+ else:
307
+ exclude_module = "token_emb"
308
+
309
+ # Iterate through the named modules of the model.
310
+ for module_name, module in model.named_modules():
311
+ # Check if the current module is an instance of any of the desired types (LayerNorm or torch.nn.Embedding).
312
+ for ndim in [LayerNorm, torch.nn.Embedding]:
313
+ if isinstance(module, ndim):
314
+ # If torch.nn.Embedding, append its name with a ".weight" suffix to the no_decay list.
315
+ if module_name == exclude_module:
316
+ no_decay.append(f"{module_name}.weight")
317
+ else:
318
+ # If the module is an instance of LayerNorm
319
+ no_decay.append(f"{module_name}.gamma")
320
+ # Exit the inner loop since the desired module has been found.
321
+ break
322
+
323
+ # Create an empty list to store the names of the Linear layer weights with weight decay.
324
+ decay = []
325
+
326
+ # Iterate through the named modules of the model.
327
+ for module_name, module in model.named_modules():
328
+ # Check if the current module is an instance of the desired type (torch.nn.Linear).
329
+ for ndim in [torch.nn.Linear]:
330
+ if isinstance(module, ndim):
331
+ # If the module is an instance of torch.nn.Linear, append its name with a ".weight" suffix to the decay list.
332
+ decay.append(f"{module_name}.weight")
333
+ # Exit the inner loop since the desired module has been found.
334
+ break
335
+
336
+ # Create two separate lists of model parameters: decay_param and no_decay_param.
337
+ # The decay_param list contains the parameters that should have weight decay applied.
338
+ # The no_decay_param list contains the parameters that should not have weight decay applied, excluding the 'to_logits.weight' parameter.
339
+
340
+ # Create an empty list called decay_param to store the parameters with weight decay.
341
+ decay_param = []
342
+
343
+ if use_fsdp:
344
+ exclude_param = "_fsdp_wrapped_module.to_logits.weight"
345
+ else:
346
+ exclude_param = "to_logits.weight"
347
+
348
+ # Iterate over the decay list, which contains the names of the parameters with weight decay.
349
+ for param in decay:
350
+ # Check if the current parameter is not 'to_logits.weight'.
351
+ # Append the corresponding parameter from param_dict to the decay_param list.
352
+
353
+ if param != exclude_param:
354
+ decay_param.append(param_dict[param])
355
+
356
+ # Create an empty list called no_decay_param to store the parameters without weight decay.
357
+ no_decay_param = []
358
+
359
+ # Iterate over the no_decay list, which contains the names of the parameters without weight decay.
360
+ for param in no_decay:
361
+ try:
362
+
363
+ # Append the corresponding parameter from param_dict to the no_decay_param list.
364
+ no_decay_param.append(param_dict[param])
365
+ except KeyError:
366
+ # print(f"Parameter {param_name} does not exist in the model")
367
+ pass
368
+
369
+ # Create a list called grouped_params that contains two dictionaries.
370
+ # The first dictionary has the decay_param list and the corresponding weight_decay value.
371
+ # The second dictionary has the no_decay_param list and a weight_decay value of 0.0.
372
+ grouped_params = [
373
+ {"params": decay_param, "weight_decay": weight_decay},
374
+ {"params": no_decay_param, "weight_decay": 0.0},
375
+ ]
376
+
377
+ # Create a variable called optimizer that stores an instance of the optimizer.
378
+ if optimizer_type == "lion":
379
+ optimizer = Lion(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
380
+ elif optimizer_type == "adamw":
381
+ optimizer = AdamW(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
382
+ elif optimizer_type == "deepspeed":
383
+ optimizer = DummyOptim(grouped_params, lr=learning_rate, betas=(beta_1, beta_2),)
384
+ elif optimizer_type == "stable_adamw":
385
+ optimizer = StableAdamWUnfused(
386
+ grouped_params, lr=learning_rate, betas=(beta_1, beta_2),
387
+ )
388
+ # elif optimizer_type=="Adam8bit":
389
+ # optimizer = bnb.optim.Adam8bit(grouped_params, lr=learning_rate, betas=(beta_1, beta_2))
390
+ # elif optimizer_type=="Lion8Bit":
391
+ # optimizer = bnb.optim.Lion8bit(grouped_params, lr=learning_rate, betas=(beta_1, beta_2))
392
+ else:
393
+ raise ValueError(
394
+ "Invalid optimizer_type. Expected 'lion', 'adamw', 'deepspeed' or 'stable_adamw', got: {}".format(
395
+ optimizer_type
396
+ )
397
+ )
398
+
399
+ # Return the optimizer.
400
+ return optimizer
401
+
402
+
403
+ # dataloaders
404
+
405
+
406
+ def build_dataloaders():
407
+ """
408
+ Build data loaders for training.
409
+
410
+ This function performs the following steps:
411
+ 1. Load the tokenizer from the pretrained "EleutherAI/gpt-neox-20b" model.
412
+ 2. Load the "openwebtext" dataset.
413
+ 3. Tokenize the dataset, adding the end-of-sentence token to each text.
414
+ 4. Process the tokenized dataset into chunks of a specified block size.
415
+
416
+ Returns:
417
+ Dataset: The processed dataset ready for training.
418
+ """
419
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
420
+ dataset = load_dataset("openwebtext", split="train")
421
+
422
+ tokenized_dataset = dataset.map(
423
+ lambda example: tokenizer([t + tokenizer.eos_token for t in example["text"]]),
424
+ batched=True,
425
+ num_proc=CFG.NUM_CPU,
426
+ remove_columns=["text"],
427
+ )
428
+
429
+ block_size = CFG.SEQ_LEN
430
+
431
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
432
+ def group_texts(examples):
433
+ # Concatenate all texts.
434
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
435
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
436
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
437
+ # customize this part to your needs.
438
+ if total_length >= block_size:
439
+ total_length = (total_length // block_size) * block_size
440
+ # Split by chunks of max_len.
441
+ result = {
442
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
443
+ for k, t in concatenated_examples.items()
444
+ }
445
+ return result
446
+
447
+ train_dataset = tokenized_dataset.map(
448
+ group_texts, batched=True, num_proc=CFG.NUM_CPU,
449
+ )
450
+
451
+ return train_dataset
452
+
453
+ #switch to falconwebdataset
454
+ def build_pre_tokenized():
455
+ d0 = load_dataset("conceptofmind/c4_0-to-20_neox_with_eos_8k", split="train[:10]")
456
+ # d1 = load_dataset("conceptofmind/c4_21-to-40_neox_with_eos_8k", split="train")
457
+ # d2 = load_dataset("conceptofmind/c4_41-to-60_neox_with_eos_8k", split="train")
458
+ # d3 = load_dataset("conceptofmind/c4_61-to-80_neox_with_eos_8k", split="train")
459
+ # d4 = load_dataset("conceptofmind/c4_81-to-100_neox_with_eos_8k", split="train")
460
+ # train_dataset = concatenate_datasets([d0, d1, d2, d3, d4])
461
+ return d0
462
+
463
+
464
+
465
+ def Train():
466
+ # accelerator
467
+
468
+ timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))
469
+
470
+ accelerator = Accelerator(
471
+ gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY,
472
+ mixed_precision="fp16",
473
+ log_with="wandb",
474
+ kwargs_handlers=[timeout],
475
+ )
476
+
477
+ state = AcceleratorState()
478
+
479
+ state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = CFG.BATCH_SIZE #??????
480
+
481
+ accelerator.init_trackers(
482
+ project_name="Andromeda",
483
+ config={
484
+ "batch_size": CFG.BATCH_SIZE,
485
+ "gradient_accumulate_every": CFG.GRADIENT_ACCUMULATE_EVERY,
486
+ "learning_rate": CFG.LEARNING_RATE,
487
+ "seq_len": CFG.SEQ_LEN,
488
+ },
489
+ # init_kwargs={"wandb": {"entity": CFG.ENTITY_NAME}},
490
+ )
491
+
492
+ accelerator.print(f"Total GPUS: {accelerator.num_processes}")
493
+
494
+ # set seed
495
+
496
+ set_seed(CFG.SEED)
497
+
498
+ # model = Andromeda(
499
+ # num_tokens=50432,
500
+ # max_seq_len=8192,
501
+ # dim=3072,
502
+ # depth=24,
503
+ # dim_head=128,
504
+ # heads=12,
505
+ # use_abs_pos_emb=False,
506
+ # alibi_pos_bias=True,
507
+ # alibi_num_heads=6,
508
+ # rotary_xpos=True,
509
+ # attn_flash=True,
510
+ # shift_tokens=1,
511
+ # attn_one_kv_head=True,
512
+ # qk_norm=True,
513
+ # attn_qk_norm=True,
514
+ # attn_qk_norm_dim_scale=True,
515
+ # embedding_provider=AndromedaEmbedding()
516
+ # )
517
+ model = Andromeda1Billion()
518
+
519
+ print_num_params(model, accelerator)
520
+
521
+ if CFG.USE_FSDP:
522
+ model = fsdp(
523
+ model,
524
+ mp="fp16",
525
+ shard_strat="SHARD_GRAD"
526
+ )
527
+
528
+ if CFG.USE_ACTIVATION_CHECKPOINTING:
529
+ activation_checkpointing(model, accelerator)
530
+
531
+ model = accelerator.prepare(model)
532
+
533
+ # dataloaders
534
+
535
+ if CFG.USE_PRETOKENIZED:
536
+ train_dataset = build_pre_tokenized()
537
+ else:
538
+ train_dataset = build_dataloaders()
539
+
540
+ train_loader = DataLoader(
541
+ train_dataset, batch_size=CFG.BATCH_SIZE, collate_fn=default_data_collator,
542
+ )
543
+
544
+
545
+ # optimizer
546
+ optim = decoupled_optimizer(
547
+ model=model,
548
+ learning_rate=CFG.LEARNING_RATE,
549
+ weight_decay=CFG.WEIGHT_DECAY,
550
+ beta_1=0.90,
551
+ beta_2=0.95,
552
+ optimizer_type='lion',
553
+ use_fsdp=True,
554
+ accelerator=accelerator
555
+ )
556
+
557
+ # Determine number of training steps
558
+
559
+ max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
560
+ accelerator.print(f"Max train steps: {max_train_steps}")
561
+
562
+ # lr scheduler
563
+
564
+ NUM_WARMUP_STEPS = int(max_train_steps * 0.01)
565
+ accelerator.print(f"Num warmup steps: {NUM_WARMUP_STEPS}")
566
+
567
+ # if False: # if CFG.USE_DEEPSPEED:
568
+ # lr_scheduler = DummyScheduler(
569
+ # optim,
570
+ # total_num_steps=max_train_steps * accelerator.num_processes,
571
+ # warmup_num_steps=NUM_WARMUP_STEPS
572
+ # )
573
+ # else:
574
+ lr_scheduler = get_lr_scheduler_with_warmup(
575
+ optimizer=optim,
576
+ scheduler_type="cosine",
577
+ num_warmup_steps=NUM_WARMUP_STEPS,
578
+ max_train_steps=max_train_steps,
579
+ grad_accumulate_every=CFG.GRADIENT_ACCUMULATE_EVERY,
580
+ )
581
+
582
+ # prepare
583
+
584
+ optim, train_loader, lr_scheduler = accelerator.prepare(
585
+ optim, train_loader, lr_scheduler
586
+ )
587
+
588
+ # checkpoint scheduler
589
+
590
+ accelerator.register_for_checkpointing(lr_scheduler)
591
+
592
+ # I do not know why Huggingface recommends recalculation of max_train_steps
593
+
594
+ max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
595
+ accelerator.print(f"Max train steps recalculated: {max_train_steps}")
596
+
597
+ # Total batch size for logging
598
+
599
+ total_batch_size = (
600
+ CFG.BATCH_SIZE * accelerator.num_processes * CFG.GRADIENT_ACCUMULATE_EVERY
601
+ )
602
+ accelerator.print(f"Total batch size: {total_batch_size}")
603
+
604
+ # resume training
605
+
606
+ progress_bar = tqdm(
607
+ range(max_train_steps), disable=not accelerator.is_local_main_process
608
+ )
609
+ completed_steps = 0
610
+
611
+ if CFG.RESUME_FROM_CHECKPOINT:
612
+ if CFG.RESUME_FROM_CHECKPOINT is not None or CFG.RESUME_FROM_CHECKPOINT != "":
613
+ accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
614
+ accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
615
+ path = os.path.basename(CFG.RESUME_FROM_CHECKPOINT)
616
+ training_difference = os.path.splitext(path)[0]
617
+
618
+ # need to multiply `gradient_accumulation_steps` to reflect real steps
619
+ resume_step = (
620
+ int(training_difference.replace("step_", ""))
621
+ * CFG.GRADIENT_ACCUMULATE_EVERY
622
+ )
623
+
624
+ if CFG.RESUME_FROM_CHECKPOINT and resume_step is not None:
625
+ train_loader = accelerator.skip_first_batches(train_loader, resume_step)
626
+ completed_steps += resume_step
627
+ progress_bar.update(resume_step)
628
+
629
+ # training
630
+
631
+ model.train()
632
+ for step, batch in enumerate(train_loader):
633
+ with accelerator.accumulate(model):
634
+ inputs = batch["input_ids"].to(accelerator.device)
635
+ loss = model(inputs, return_loss=True)
636
+ accelerator.backward(loss)
637
+
638
+ accelerator.log({"loss": loss.item()}, step=step)
639
+
640
+ if accelerator.sync_gradients:
641
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
642
+
643
+ optim.step()
644
+ lr_scheduler.step()
645
+ optim.zero_grad()
646
+
647
+ if accelerator.sync_gradients:
648
+ progress_bar.update(1)
649
+ completed_steps += 1
650
+
651
+ if isinstance(CFG.CHECKPOINTING_STEPS, int):
652
+ if completed_steps % CFG.CHECKPOINTING_STEPS == 0:
653
+ output_dir = f"step_{completed_steps }"
654
+ if CFG.OUTPUT_DIR is not None:
655
+ output_dir = os.path.join(CFG.OUTPUT_DIR, output_dir)
656
+ accelerator.save_state(output_dir)
657
+
658
+ if completed_steps >= max_train_steps:
659
+ break
660
+
661
+ #logging every CFG.LOGGING STEPS
662
+ if CFG.LOGGING_STEPS > 0 and step % CFG.LOGGING_STEPS == 0:
663
+ logger.info(
664
+ f"Step: {completed_steps}/{max_train_steps}, Loss: {loss.item():.5f}"
665
+ )
666
+
667
+ # end training
668
+
669
+ # accelerator.print(f"Training Finished")
670
+ accelerator.end_training()
671
+
672
+ # save final model
673
+
674
+ # accelerator.print(f"Saving model to {CFG.OUTPUT_DIR}")
675
+ if CFG.OUTPUT_DIR is not None:
676
+ accelerator.wait_for_everyone()
677
+ unwrapped_model = accelerator.unwrap_model(model)
678
+ with accelerator.main_process_first():
679
+ accelerator.save(
680
+ unwrapped_model.state_dict(), f"{CFG.OUTPUT_DIR}/final/final_model.pt"
681
+ )
682
+
683
+
684
+ def train():
685
+ os.environ['MASTER_ADDR'] #'localhost'
686
+ os.environ['MASTER_PORT'] #= '9994'
687
+
688
+ # # [CRITICAL] Pay attention to this when scaling to multiple GPUs and clusters
689
+
690
+ # # Pay attention to this, use "accelerate config"
691
+
692
+ os.environ['RANK'] #= str(0) # Number of nodes (servers)
693
+ os.environ['WORLD_SIZE'] # = str(torch.cuda.device_count())
694
+
695
+ dist.init_process_group(backend='nccl') #init_method="env://")
696
+
697
+ Train()
698
+
699
+ if __name__ == '__main__':
700
+ train()
Andromeda/Andromeda/utils/__init__.py ADDED
File without changes
Andromeda/Andromeda/utils/decoupled_optimizer.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ # from palm_rlhf_pytorch.palm import LayerNorm
3
+ from torch.nn import LayerNorm
4
+ from torch.optim import AdamW
5
+
6
+ # from palm.utils import print_main
7
+ from Andromeda.utils.helpers import print_main
8
+ from Andromeda.utils.stable_adamw import StableAdamWUnfused
9
+
10
+ # optimizers
11
+
12
+
13
+ def decoupled_optimizer(
14
+ model: torch.nn.Module,
15
+ learning_rate: float,
16
+ weight_decay: float = 0.1,
17
+ beta_1: float = 0.90,
18
+ beta_2: float = 0.95,
19
+ optimizer_type: str = "adamw",
20
+ use_fsdp: bool = True,
21
+ ):
22
+ """
23
+ Decouples the optimizer from the training process.
24
+
25
+ This function sets up the optimizer for the model by creating two groups of parameters:
26
+ one for weight decay and one without weight decay. Then, it initializes the optimizer
27
+ with these two groups of parameters.
28
+
29
+ Args:
30
+ model (Module): The model whose parameters are optimized.
31
+ learning_rate (float): The learning rate for the optimizer.
32
+ weight_decay (float): The weight decay for the optimizer.
33
+ beta_1 (float): The exponential decay rate for the 1st moment estimates.
34
+ beta_2 (float): The exponential decay rate for the 2nd moment estimates.
35
+ optimizer_type (str): The type of the optimizer. Can be 'lion', 'adamw', or 'stable_adamw'.
36
+ use_fsdp (bool, optional): If True, the optimizer will work with fully sharded data parallelism. Defaults to True.
37
+ accelerator (Accelerator, optional): The accelerator from HuggingFace's Accelerate library. Defaults to None.
38
+
39
+ Returns:
40
+ Optimizer: The initialized optimizer.
41
+
42
+ Raises:
43
+ ValueError: If the optimizer type is not 'lion', 'adamw' or 'stable_adamw'.
44
+ """
45
+ print_main(f"Using {optimizer_type} optimizer")
46
+ # Create an empty dictionary called param_dict to store the model's named parameters.
47
+ param_dict = {}
48
+ # Iterate over the model's named parameters and populate the param_dict with key-value pairs.
49
+ for param_name, param in model.named_parameters():
50
+ print_main(param_name)
51
+ param_dict[param_name] = param
52
+
53
+ # Separate the model's named modules into two groups: decay and no_decay.
54
+
55
+ # Create an empty list to store the names of the LayerNorm and Embedding layer weights with no weight decay.
56
+ no_decay = []
57
+
58
+ if use_fsdp:
59
+ exclude_module = "_fsdp_wrapped_module.token_emb"
60
+ else:
61
+ exclude_module = "token_emb"
62
+
63
+ # Iterate through the named modules of the model.
64
+ for module_name, module in model.named_modules():
65
+ # Check if the current module is an instance of any of the desired types (LayerNorm or torch.nn.Embedding).
66
+ for ndim in [LayerNorm, torch.nn.Embedding]:
67
+ if isinstance(module, ndim):
68
+ # If torch.nn.Embedding, append its name with a ".weight" suffix to the no_decay list.
69
+ if module_name == exclude_module:
70
+ no_decay.append(f"{module_name}.weight")
71
+ else:
72
+ # If the module is an instance of LayerNorm
73
+ no_decay.append(f"{module_name}.gamma")
74
+ # Exit the inner loop since the desired module has been found.
75
+ break
76
+
77
+ # Create an empty list to store the names of the Linear layer weights with weight decay.
78
+ decay = []
79
+
80
+ # Iterate through the named modules of the model.
81
+ for module_name, module in model.named_modules():
82
+ # Check if the current module is an instance of the desired type (torch.nn.Linear).
83
+ for ndim in [torch.nn.Linear]:
84
+ if isinstance(module, ndim):
85
+ # If the module is an instance of torch.nn.Linear, append its name with a ".weight" suffix to the decay list.
86
+ decay.append(f"{module_name}.weight")
87
+ # Exit the inner loop since the desired module has been found.
88
+ break
89
+
90
+ # Create two separate lists of model parameters: decay_param and no_decay_param.
91
+ # The decay_param list contains the parameters that should have weight decay applied.
92
+ # The no_decay_param list contains the parameters that should not have weight decay applied, excluding the 'to_logits.weight' parameter.
93
+
94
+ # Create an empty list called decay_param to store the parameters with weight decay.
95
+ decay_param = []
96
+
97
+ if use_fsdp:
98
+ exclude_param = "_fsdp_wrapped_module.to_logits.weight"
99
+ else:
100
+ exclude_param = "to_logits.weight"
101
+
102
+ # Iterate over the decay list, which contains the names of the parameters with weight decay.
103
+ for param in decay:
104
+ # Check if the current parameter is not 'to_logits.weight'.
105
+ # Append the corresponding parameter from param_dict to the decay_param list.
106
+
107
+ if param != exclude_param:
108
+ decay_param.append(param_dict[param])
109
+
110
+ # Create an empty list called no_decay_param to store the parameters without weight decay.
111
+ no_decay_param = []
112
+
113
+ # Iterate over the no_decay list, which contains the names of the parameters without weight decay.
114
+ for param in no_decay:
115
+ # Append the corresponding parameter from param_dict to the no_decay_param list.
116
+ no_decay_param.append(param_dict[param])
117
+
118
+ # Create a list called grouped_params that contains two dictionaries.
119
+ # The first dictionary has the decay_param list and the corresponding weight_decay value.
120
+ # The second dictionary has the no_decay_param list and a weight_decay value of 0.0.
121
+ grouped_params = [
122
+ {"params": decay_param, "weight_decay": weight_decay},
123
+ {"params": no_decay_param, "weight_decay": 0.0},
124
+ ]
125
+
126
+ # Create a variable called optimizer that stores an instance of the optimizer.
127
+ if optimizer_type == "adamw":
128
+ optimizer = AdamW(
129
+ grouped_params,
130
+ lr=learning_rate,
131
+ betas=(beta_1, beta_2),
132
+ )
133
+ elif optimizer_type == "stable_adamw":
134
+ optimizer = StableAdamWUnfused(
135
+ grouped_params,
136
+ lr=learning_rate,
137
+ betas=(beta_1, beta_2),
138
+ )
139
+ else:
140
+ raise ValueError(
141
+ "Invalid optimizer_type. Expected 'lion', 'adamw', 'deepspeed' or 'stable_adamw', got: {}".format(
142
+ optimizer_type
143
+ )
144
+ )
145
+
146
+ # Return the optimizer.
147
+ return optimizer
Andromeda/Andromeda/utils/helpers.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.distributed as dist # Add this line
2
+
3
+ def print_num_params(model):
4
+ n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
5
+
6
+ if dist.is_available():
7
+ if dist.get_rank() == 0:
8
+ print(f"Number of parameters in model: {n_params}")
9
+ else:
10
+ print(f"Number of parameters in model: {n_params}")
11
+
12
+ def print_main(msg):
13
+ if dist.is_available():
14
+ if dist.get_rank() == 0:
15
+ print(msg)
16
+ else:
17
+ print(msg)
Andromeda/Andromeda/utils/rf_utils.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import einsum, _nnpack_available
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from einops import rearrange
7
+ import copy
8
+ from pathlib import PurePath
9
+ from tqdm import tqdm_gui
10
+ from beartype import beartype
11
+ from beartype.typing import Tuple, Optional
12
+
13
+ from einops import rearrange, repeat, reduce, unpack
14
+ from einops.layers.torch import Rearrange, Reduce
15
+
16
+
17
+ #helpers
18
+ def exists(val):
19
+ return val is not None
20
+
21
+
22
+ #decorators
23
+ def eval_decorator(fn):
24
+ def inner(self, *args, **kwargs):
25
+ was_training = self.training
26
+ self.eval()
27
+ out = fn(self, *args, **kwargs)
28
+ self.train(was_training)
29
+ return out
30
+ return inner
31
+
32
+ def defaults(val, d):
33
+ return val if exists(val) else d
34
+
35
+ #tensor helpers
36
+
37
+ def log(t, eps=1e-20):
38
+ return torch.log(t.clamp(min = eps))
39
+
40
+ def masked_mean(seq, mask=None, dim=1, keepdim=True):
41
+ if not exists(mask):
42
+ return seq.mean(dim=dim)
43
+
44
+ if seq.ndim == 3:
45
+ mask = rearrange(mask, 'b n -> b n 1')
46
+
47
+ masked_seq = seq.masked_fill(~mask, 0.)
48
+ numer = masked_seq.sum(dim=dim, keepdim=keepdim)
49
+ denom = mask.sum(dim=dim, keepdim=keepdim)
50
+
51
+ masked_mean = numer / denom.clamp(min = 1e-3)
52
+ masked_mean = masked_mean.masked_fill(denom == 0, 0.)
53
+ return masked_mean
54
+
55
+
56
+ #sampling helpers
57
+
58
+ def gumbel_noise(t):
59
+ noise = torch.zeros_like(t).uniform(0, 1)
60
+ return -log(-log(noise))
61
+
62
+
63
+ def gumbel_sample(t, temperature = 1., dim=-1):
64
+ return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
65
+
66
+ def top_p(logits, thres=0.9):
67
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
68
+ cum_probs = torch.einsum(F.softmax(sorted_logits, dim=-1), dim=-1)
69
+
70
+ sorted_indices_to_remove = cum_probs > (1 - thres)
71
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
72
+ sorted_indices_to_remove[:, 0] = 0
73
+
74
+ sorted_logits[sorted_indices_to_remove] = float("-inf")
75
+ return sorted_logits.scatter(1, sorted_indices, sorted_logits)
76
+
77
+ def top_k(logits, thres=0.9):
78
+ k = math.ceil((1 - thres) * logits.shape[-1])
79
+ val, ind = torch.topk(logits, k)
80
+ probs = torch.full_like(logits, float('-inf'))
81
+ probs.scatter_(1, ind, val)
82
+ return probs
83
+
84
+
85
+ class LoRA(nn.Module):
86
+ def __init__(
87
+ self,
88
+ dim,
89
+ dim_out,
90
+ r=8,
91
+ alpha=None
92
+ ):
93
+ super().__init__()
94
+ alpha = defaults(alpha, r)
95
+ self.scale = alpha / r
96
+
97
+ self.A = nn.Parameter(torch.randn(dim, r))
98
+ self.B = nn.Parameter(torch.zeros(r, dim_out))
99
+
100
+
101
+
102
+ #reward model
103
+ @beartype
104
+
105
+ class RewardModel(nn.Module):
106
+ def __init__(
107
+ self,
108
+ model: Andromeda,
109
+ dropout=0.1,
110
+ num_binned_output = 0.,
111
+ use_lora = True,
112
+ lora_r = 8,
113
+ reward_lora_scope = 'reward',
114
+ ):
115
+ super().__init__()
116
+
117
+ self.model = copy.deepcopy(Andromeda)
118
+ self.model.set_dropout(dropout)
119
+
120
+ self.reward_lora_scope = reward_lora_scope is use_lora else None
121
+
122
+ if exists(self.reward_lora_scope):
123
+ self.model.add_finetune_params(reward_lora_scope, lora_r = lora_r)
124
+
125
+ dim = model.dim
126
+
127
+ self.binned_output = num_binned_output > 1
128
+
129
+ self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
130
+ self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
131
+
132
+
133
+ if self.binned_output:
134
+ self.to_pred = nn.Linear(dim, num_binned_output)
135
+ else:
136
+ self.to_pred = nn.Sequential(
137
+ nn.Linear(dim, 1, bias=False),
138
+ Rearrange('... 1 -> ...')
139
+ )
140
+
141
+ def load(self, path):
142
+ path = Path(path)
143
+ assert path.exists()
144
+ self.load_state_dict(torch.load(str(path)))
145
+
146
+ def finetune_parameters(self):
147
+ return (
148
+ *self.to_pred.parameters(),
149
+ *(self.model.finetune_parameters(self.reward_lora_scope) if exists(self.reward_lora_scope) else model.parameters())
150
+ )
151
+
152
+
153
+ def forward(
154
+ self,
155
+ x,
156
+ mask=None,
157
+ prompt_mask=None,
158
+ prompt_lengths=None,
159
+ labels=None,
160
+ sample=False,
161
+ sample_temperature=1.,
162
+ disable_lora=False
163
+ ):
164
+ assert not (exists(prompt_mask) and exists(prompt_lengths))
165
+
166
+ #derive prompt mask from prompt lengths
167
+
168
+ if exists(prompt_lengths):
169
+ batch, seq_len = x.shape
170
+ arange = torch.arange(seq_len, device = x.device)
171
+ prompt_mask = repeat(arange, 'n -> n n', b = batch) > rearrange(prompt_lengths, 'b -> b 1')
172
+
173
+ #rward model should have an understand of which section is prompt and which section is repsonse
174
+
175
+ extra_embed = None
176
+
177
+ if exists(prompt_mask):
178
+ extra_embed = torch.where(
179
+ rearrange(prompt_mask, 'b n -> b n 1'),
180
+ self.prompt_embed,
181
+ self.response_embed
182
+ )
183
+
184
+ embeds = self.model(
185
+ x,
186
+ )
Andromeda/Andromeda/utils/stable_adamw.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+
4
+ # This is the unfused version of StableAdamW. It is slower than the fused version (coming).
5
+
6
+
7
+ class StableAdamWUnfused(torch.optim.Optimizer):
8
+ def __init__(
9
+ self,
10
+ params,
11
+ lr=0.002,
12
+ weight_decay=0.2,
13
+ betas=(0.9, 0.99),
14
+ eps=1e-8,
15
+ clip_thresh=1.0,
16
+ precision="amp_bfloat16",
17
+ custom_scalar=65536,
18
+ ):
19
+ beta1, beta2 = betas[0], betas[1]
20
+ defaults = dict(lr=lr, weight_decay=weight_decay, beta1=beta1, beta2=beta2)
21
+ super(StableAdamWUnfused, self).__init__(params, defaults)
22
+
23
+ self.eps = eps
24
+ self.d = clip_thresh
25
+
26
+ # Set precision to "custom_fp16" if you want to use a fixed loss scalar, custom_scalar, which is divided out in the update step.
27
+ # If you do this, call (custom_scalar * loss).backward() instead of loss.backward().
28
+ self.precision = precision
29
+ self.custom_scaler = custom_scalar
30
+
31
+ for group in self.param_groups:
32
+ group["step"] = 1.0
33
+
34
+ print("Using StableAdamWUnfused-v1")
35
+
36
+ def __setstate__(self, state):
37
+ super(StableAdamWUnfused, self).__setstate__(state)
38
+
39
+ def step(self, closure=None):
40
+ if closure is not None:
41
+ closure()
42
+
43
+ for group in self.param_groups:
44
+ lr = group["lr"]
45
+ weight_decay = group["weight_decay"]
46
+ beta1 = group["beta1"]
47
+ beta2 = group["beta2"]
48
+ step = group["step"]
49
+
50
+ for p in group["params"]:
51
+ if p.grad is None:
52
+ continue
53
+ theta = p.data
54
+ param_state = self.state[p]
55
+
56
+ if self.precision == "custom_fp16":
57
+ g = p.grad.data / self.custom_scaler
58
+ if torch.any(torch.isnan(g) | torch.isinf(g)):
59
+ continue
60
+ else:
61
+ g = p.grad.data
62
+
63
+ if "exp_avg" not in param_state:
64
+ v = param_state["exp_avg"] = torch.zeros_like(theta)
65
+ u = param_state["exp_avg_sq"] = torch.zeros_like(theta)
66
+ else:
67
+ v = param_state["exp_avg"]
68
+ u = param_state["exp_avg_sq"]
69
+
70
+ beta1hat = beta1 * (1 - beta1 ** (step - 1)) / (1 - beta1**step)
71
+ beta2hat = beta2 * (1 - beta2 ** (step - 1)) / (1 - beta2**step)
72
+
73
+ v = v.mul_(beta1hat).add_(g, alpha=1.0 - beta1hat)
74
+ u = u.mul_(beta2hat).addcmul_(g, g, value=1.0 - beta2hat)
75
+
76
+ denominator = u.sqrt().add_(self.eps)
77
+
78
+ # StableAdamW = AdamW + update clipping (https://arxiv.org/abs/1804.04235) applied tensor-wise.
79
+ rms = (
80
+ torch.div(
81
+ g.pow(2), torch.maximum(u, (self.eps**2) * torch.ones_like(u))
82
+ )
83
+ .mean()
84
+ .sqrt()
85
+ .item()
86
+ )
87
+
88
+ theta = theta.mul_(1.0 - lr * weight_decay).addcdiv_(
89
+ v, denominator, value=-lr * (1.0 / max(1.0, rms / self.d))
90
+ )
91
+
92
+ # save current params
93
+ param_state["exp_avg"] = v
94
+ param_state["exp_avg_sq"] = u
95
+
96
+ group["step"] = step + 1
Andromeda/DOCs/Corporation/MONETIZATION.md ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Andromeda Product Brief and Monetization Strategy Document
2
+
3
+ ## Product Summary:
4
+
5
+ Andromeda is an innovative language model designed for high performance and efficiency. It utilizes advanced techniques that allow it to process and learn from multiple sources and adapt in real-time.
6
+
7
+ ## Monetization Strategies:
8
+
9
+ 1. **Usage-based API:** Provide Andromeda as a paid API service where users pay based on the amount of computation they use.
10
+ 2. **Consulting deals:** Offer expert consulting services to businesses looking to incorporate Andromeda's capabilities into their operations.
11
+ 3. **Dedicated capacity:** Sell dedicated computational power to businesses for exclusive usage of Andromeda's capabilities.
12
+ 4. **Licensing the technology:** Allow companies to license the Andromeda model for their proprietary use.
13
+ 5. **Subscription models:** Provide access to Andromeda's capabilities on a subscription basis.
14
+ 6. **Freemium model:** Offer basic usage of Andromeda for free, while charging for advanced features and capabilities.
15
+ 7. **Partnerships:** Form strategic partnerships with tech companies that can leverage Andromeda's capabilities in their products and services.
16
+ 8. **Sponsorships:** Sponsor research projects or tech events to get visibility and promote Andromeda's services.
17
+ 9. **Training and certifications:** Offer training programs and certifications on Andromeda usage and applications.
18
+ 10. **Custom development:** Offer custom development services for businesses that want specialized applications of Andromeda.
19
+
20
+ ## Potential Customers:
21
+
22
+ 1. **Tech companies:** Andromeda can be integrated into a wide array of tech products and services.
23
+ 2. **Educational institutions:** Universities and research institutions can use Andromeda for research purposes.
24
+ 3. **Government agencies:** Andromeda can assist in processing and analyzing large amounts of data.
25
+ 4. **Healthcare providers:** Andromeda can be used in data analysis and decision making in healthcare.
26
+ 5. **Media and entertainment industry:** Andromeda's language model can be used in content creation and curation.
27
+
28
+ ## Potential Cashflow Gains:
29
+
30
+ 1. **API usage revenues:** Charging per API call can generate substantial revenues with a high number of users.
31
+ 2. **Subscription fees:** A tier-based subscription model can ensure a steady income stream.
32
+ 3. **Licensing fees:** Companies willing to license the technology can provide a significant one-time or recurring revenue.
33
+ 4. **Consulting fees:** Consulting services can yield high-value contracts.
34
+ 5. **Sponsorship revenues:** Sponsoring events or projects can yield returns in the form of new business leads and customers.
35
+
36
+ ## Expenses:
37
+
38
+ 1. **Cloud infrastructure costs:** Major expense in maintaining and scaling the Andromeda model.
39
+ 2. **Research and development:** Continual improvement of Andromeda requires ongoing investment.
40
+ 3. **Marketing and sales:** Promoting Andromeda and closing sales deals will be a recurring expense.
41
+ 4. **Operational costs:** Expenses related to managing the company, including salaries, office space, utilities, and more.
42
+ 5. **Open-source contributors:** Andromeda is built on the contributions of numerous developers. Recognizing these contributors through a rewards program is an essential part of maintaining a healthy development ecosystem.
43
+
44
+ ### Open Source Contributors:
45
+
46
+ The following is a representative list of contributors who have helped make Agora what it is today:
47
+
48
+ 1. Kye
49
+ 2. Nicolo
50
+
51
+ Each contributor brings unique expertise and value to the project, helping to shape Andromeda into a powerful, efficient, and intelligent language model that will revolutionize the NLP landscape.
Andromeda/DOCs/Design/Dyson.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Insights and Techniques:
2
+
3
+ 1. Flops: The importance of considering the number of floating-point operations (FLOPs) when designing models.
4
+ 2. Flash Attention 2.0: The use of techniques like Flash Attention 2.0 cuda to enable more FLOPs in the model.
5
+ 3. Mixed Precision: Utilizing mixed precision training to improve training speed and memory efficiency.
6
+ 4. Deepspeed 3 with NVMe: Using Deepspeed 3 with NVMe for optimizing training performance.
7
+ 5. 8-bit Optimizer: Employing an 8-bit optimizer for further speed improvements.
8
+ 6. Gradient Clipping: Adding gradient clipping to achieve massive speedup during training.
9
+ 7. XPOS, ALIBI, QK Layernorm: Leveraging advanced techniques for extrapolation, interpolation, and training stabilization.
10
+ 8. Multi Query Attention: Using multi-query attention to boost decoding speed.
11
+ 9. Parallelized Transformer Blocks: Parallelizing transformer blocks to enhance overall model performance.
12
+ 10. Positional Embeddings and Shifted Tokens: The decision to not use positional embeddings and utilization of shifted tokens for sequence length advancement.
13
+ 11. Positional Interpolation: Incorporating positional interpolation for improved sequence handling.
14
+ 12. Optimized CUDA Embedding Function: Utilizing an optimized CUDA embedding function for better performance.
15
+ 13. Nebula Loss Function: Implementing the Nebula loss function, a polymorphic loss function for multi-task training.
16
+
17
+ Possible Improvements:
18
+
19
+ 1. Clearer Metrics: To validate the model's claims, it would be beneficial to establish specific metrics for monitoring across training, especially regarding reasoning capabilities.
20
+ 2. Validation and Testing Environment: Further development and description of the exhaustive testing environment to validate the model's performance and capabilities.
21
+ 3. Comprehensive Documentation: Provide detailed documentation of the model's architecture, training methodology, and testing procedures to ensure transparency and replicability.
22
+ 4. Benchmarking Against Competitors: Perform benchmarking against existing models to showcase the advantages and differentiation offered by the proposed architecture and training techniques.
23
+ 5. Real-World Applications: Highlight potential real-world applications or use cases where the proposed model can provide superior performance compared to existing solutions.
24
+ 6. Explainability and Interpretability: Consider incorporating methods for model explainability and interpretability, especially in applications where these aspects are crucial.
25
+ 7. Addressing Specific Niche Needs: Identify specific niches or use cases where the model can excel and tailor marketing and development efforts accordingly.
26
+ 8. Collaboration and Peer Review: Engage with the research community, participate in peer review, and seek collaboration opportunities to gain additional insights and validation.
Andromeda/DOCs/Design/MODEL_ARCHITECTURE.md ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ### Alibi Positional Bias
3
+
4
+ Alibi positional bias allows the model to learn relative positions between tokens, enabling it to better capture the relationships and dependencies between tokens in a sequence.
5
+
6
+ Usage example:
7
+
8
+ ```python
9
+ attn_layers = Decoder(
10
+ ...
11
+ alibi_pos_bias=True,
12
+ alibi_num_heads=4,
13
+ ...
14
+ )
15
+ ```
16
+
17
+ ### Rotary Position Encodings (xpos)
18
+
19
+ Rotary position encodings introduce a more efficient way to encode positions in the input sequence. They avoid the need for absolute positional embeddings, reducing the model's memory footprint and improving training speed.
20
+
21
+ Usage example:
22
+
23
+ ```python
24
+ attn_layers = Decoder(
25
+ ...
26
+ rotary_xpos=True,
27
+ ...
28
+ )
29
+ ```
30
+
31
+ ### Flash Attention
32
+
33
+ Flash attention speeds up the self-attention mechanism by reducing the number of attention computations. It accelerates training and inference while maintaining a high level of performance.
34
+
35
+ Usage example:
36
+
37
+ ```python
38
+ attn_layers = Decoder(
39
+ ...
40
+ attn_flash=True,
41
+ ...
42
+ )
43
+ ```
44
+
45
+ Usage example:
46
+
47
+ ```python
48
+ attn_layers = Decoder(
49
+ ...
50
+ deepnorm=True,
51
+ ...
52
+ )
53
+ ```
54
+
55
+ ### Deep Normalization (deepnorm)
56
+
57
+ Deep normalization is a technique that normalizes the activations within a layer, helping with training stability and convergence. It allows the model to better learn complex patterns and generalize to unseen data.
Andromeda/DOCs/Design/SPEED.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Increasing Speed
2
+
3
+ * Integrate Flash Attention 2.0 cuda, significant speed up
4
+
5
+ * Utilize 8BIT Optimizer from BNB, big speed up weakness => bnb isn't compatible with all gpus
6
+
7
+ * Use a better tokenizer TokenMonster?
8
+
9
+ * Parallelize the transformer blocks similar to that of [PALMS](https://github.com/conceptofmind/PaLM)
10
+
11
+ * Look into MPTS config for LION for pretraining, did they use high batch size?
Andromeda/DOCs/Design/Specs.md ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## **Andromeda Specs**: Unveiling Mastery
2
+
3
+ **Overview**
4
+ Elegantly marrying craftsmanship and technology, Andromeda is not just another step in AI evolution. It's a giant leap. Driven by precision, powered by innovation, and defined by excellence, Andromeda is the epitome of intelligence realized. Here, we detail the marvel that is Andromeda, in numbers, facts, and logic.
5
+
6
+ ---
7
+
8
+ ### **Specifications**
9
+
10
+ | **Feature** | **Specification** |
11
+ |----------------------------------------------|-----------------------------------------------|
12
+ | **Sequence Handling** | Ultra Long (32,000 - 200,000+ context lengths)|
13
+ | **Processing Speed** | Ultra Fast (32,000+ tokens in < 100ms) |
14
+ | **Reasoning Abilities** | Creativity, Quantitative |
15
+ | **Attention Mechanism** | Flash Attention 2.0 Triton |
16
+ | **Memory Consumption** (compared to GPT-3) | 100x Less |
17
+ | **Memory Consumption** (compared to LLAMA) | 30x Less |
18
+ | **Max Sequence Processing Speed** | 100,000+ sequences in < 300ms |
19
+ | **Dataset Strategy** | Books, Falcon, Redpajama, Math, Code |
20
+ | **Functionality** | FSDP, HF Accelerate, Poetry Composition, API Calls, and more |
21
+
22
+ ---
23
+
24
+ ### **Benchmarks**
25
+ **Speed**: At the heart of Andromeda's unparalleled capabilities is its raw speed. Leveraging the prowess of Flash Attention 2.0 Triton, it doesn't merely process data; it blazes through it. This power allows it to consume 50x less memory than its predecessor, GPT-3, and 10x less than LLAMA.
26
+
27
+ ---
28
+
29
+ ### **Why Andromeda?**
30
+ - **Performance**: Andromeda isn't about doing things faster; it's about doing them the best. Reliable processing of sequences, even as extensive as 100,000+ lengths, is realized in the blink of an eye, under 300ms.
31
+
32
+ - **Precision and Creativity**: The dataset strategy is no mere algorithm. It's a symphony, meticulously crafted to offer both creativity and quantitative reasoning.
33
+
34
+ - **Versatility**: Andromeda doesn't just compute; it contemplates. Whether you need the flair of a poet or the precision of an API call, Andromeda delivers, seamlessly.
35
+
36
+ ---
37
+
38
+ ### **Andromeda Principles**
39
+ - **Efficiency**: It's not just about doing more; it's about doing better. Techniques like attention flashing, rotary position encodings, and deep normalization ensure every cycle, every operation, every byte is optimized for performance.
40
+
41
+ - **Flexibility**: In the ever-evolving world of technology, adaptability is king. Andromeda is designed to mold, adapt, and excel, irrespective of the task or domain.
42
+
43
+ - **Scalability**: Grow with you, for you. Andromeda isn't static. It's dynamic, designed to scale, accommodating growing resources and expanding data sizes.
44
+
45
+ - **Community-Driven**: Behind Andromeda's machine brain is the human heart of the community. It doesn't just utilize open source; it thrives on it, constantly evolving, learning, and improving with contributions from around the world.
46
+
47
+
48
+ For enthusiasts, developers, and thinkers looking to dive deeper, the Model Architecture documentation offers an exhaustive, detailed view into the intricate marvel that is Andromeda. Dive in, and witness engineering and artistry in harmony.
49
+
50
+ ---
51
+
52
+ ### **Andromeda: A Detailed Technical Overview**
53
+
54
+ At the intersection of technological ingenuity and groundbreaking design principles, Andromeda emerges. Representing the zenith of years of research and development, it promises a transformative leap in AI performance, efficiency, and versatility. In this technical specifications document, we deconstruct the intricacies of Andromeda, presenting a meticulous overview of its structure, performance metrics, and underlying methodologies.
55
+
56
+ ## **Feature Insights**
57
+
58
+ ### **Alibi Positional Bias**
59
+ Empowering Andromeda to discern relative positions between tokens, this feature accentuates its ability to grasp intricate relationships within a sequence.
60
+
61
+ ### **Rotary Position Encodings (xpos)**
62
+ This is a revolutionary means of encoding positions, shrinking the model's memory demands and propelling training speeds.
63
+
64
+ ### **Flash Attention**
65
+ This is the linchpin of Andromeda's speed prowess, minimizing attention computations, thus boosting training and inference phases.
66
+
67
+ ### **Deep Normalization (deepnorm)**
68
+ By normalizing activations, deep normalization shores up training stability, allowing Andromeda to identify intricate patterns with finesse.
69
+
70
+ ## **Feature Insights (Contd.)**
71
+
72
+ ### **Attn One KV Head (Multiquery Attention)**
73
+ A breakthrough in attention mechanism design, this feature allows for simultaneous computation of multiple queries against the same set of key-values, fostering speed and efficiency.
74
+
75
+ ### **QK Norm & Attention QK Norm**
76
+ These two features introduce a normalization step in the query and key matrices. This step facilitates stabilization in the attention mechanism, rendering it more robust and enabling it to scale with larger input sizes.
77
+
78
+ ### **Attention QK Norm Dimension Scale**
79
+ A sophisticated adjustment to the attention mechanism, it modulates the normalization scale in accordance to the dimensions of the model. The result is a more adaptive and responsive attention framework.
80
+
81
+ ### **Embedding Provider**
82
+ At the foundation of Andromeda, this module facilitates the embedding process, converting token sequences into dense vectors. Tailored for Andromeda, it ensures rapid and efficient embedding processes.
83
+
84
+ ---
85
+
86
+ ## **Deeper Dive: Model Parameters**
87
+
88
+ Unpacking Andromeda means diving deep into the parameters that shape its capabilities. Here's a granular view:
89
+
90
+ | **Parameter** | **Description** | **Default Value** |
91
+ |-----------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------|
92
+ | **num_tokens** | Total number of tokens in the vocabulary. | 50432 |
93
+ | **max_seq_len** | Maximum sequence length the model can process. | 8192 |
94
+ | **dim** | Dimension size of the model. It represents the size of embeddings and general depth in neural layers. | 2560 |
95
+ | **depth** | Represents the number of transformer layers in the architecture. | 32 |
96
+ | **dim_head** | Dimension size of each head in multi-head attention mechanism. | 128 |
97
+ | **heads** | Total number of heads in multi-head attention. | 24 |
98
+ | **use_abs_pos_emb** | Boolean flag to determine if absolute positional embeddings are used. | False |
99
+ | **alibi_pos_bias** | Enables the alibi positional bias in attention mechanisms. | True |
100
+ | **alibi_num_heads** | Specifies the number of heads for the alibi positional bias. | 12 |
101
+ | **rotary_xpos** | Determines if rotary positional encodings are utilized. | True |
102
+ | **attn_flash** | Flag to activate the Flash Attention mechanism, minimizing computations in the attention phase. | True |
103
+ | **shift_tokens** | The number of tokens by which input sequences are shifted. Essential for certain sequence-to-sequence tasks. | 1 |
104
+ | **attn_one_kv_head** | Activates multiquery attention by computing multiple queries against a singular key-value pair. | True |
105
+ | **qk_norm** | Enables the query-key normalization mechanism in the attention phase. | True |
106
+ | **attn_qk_norm** | A more advanced version of query-key normalization that scales according to the model's dimensions. | True |
107
+ | **attn_qk_norm_dim_scale** | Modulates the scale of the aforementioned attention normalization based on the model's dimensionality. | True |
108
+ | **embedding_provider** | The module responsible for providing embeddings. Custom providers can be passed for tailored embedding processes. | AndromedaEmbedding|
109
+
110
+ ---
111
+
112
+
113
+ ## **Insights and Techniques**
114
+
115
+ #### **1. Floating-Point Operations (FLOPs)**
116
+ Considering the number of FLOPs is paramount. It provides a metric to gauge the computational intensity and, by extension, the potential speed of the model.
117
+
118
+ #### **2. Flash Attention 2.0 Triton**
119
+ Enhanced with CUDA, this method offers a significant surge in the number of FLOPs the model can handle, amplifying its overall efficiency.
120
+
121
+ #### **3. Mixed Precision Training**
122
+ By embracing mixed precision, Andromeda realizes a noteworthy uptick in training speed while achieving commendable memory efficiency.
123
+
124
+ #### **4. Deepspeed 3 with NVMe Integration**
125
+ This powerful combination paves the way for superlative optimization during the training phase.
126
+
127
+ #### **5. 8-bit Optimizer**
128
+ Further pushing the boundaries of speed, the 8-bit optimizer boosts processing times without compromising the integrity of results.
129
+
130
+ #### **6. Gradient Clipping**
131
+ This technique has been integrated into the training regimen, achieving a massive speedup and preventing undesirable spikes during the process.
132
+
133
+ #### **7. Advanced Techniques: XPOS, ALIBI, QK Layernorm**
134
+ These sophisticated techniques are harnessed for superior extrapolation, interpolation, and stabilization during training.
135
+
136
+ #### **8. Multi Query Attention**
137
+ This approach has been adopted to supercharge decoding speeds.
138
+
139
+ #### **9. Parallelized Transformer Blocks**
140
+ Ensuring that the model's performance is consistently high, these blocks run in tandem to provide a smooth and efficient operational experience.
141
+
142
+ #### **10. Shifted Tokens**
143
+ In a strategic move, Andromeda sidesteps traditional positional embeddings, relying instead on shifted tokens for sequence length progression.
144
+
145
+ #### **11. Positional Interpolation**
146
+ This innovative technique augments the model's ability to manage sequences more effectively.
147
+
148
+ #### **12. Optimized CUDA Embedding Function**
149
+ This function is tailored for peak performance, ensuring rapid and accurate computations.
150
+
151
+ #### **13. Nebula Loss Function**
152
+ Integrated into Andromeda, this polymorphic loss function is adept at handling multi-task training scenarios.
153
+
154
+ ## **A Word on Optimization and Future Iterations**
155
+
156
+ As with any state-of-the-art model, Andromeda's design is an ever-evolving tapestry. This means iterative refinement. As feedback streams in and technology progresses, expect advancements in:
157
+
158
+ - **Model Pruning**: Trimming redundancies, bolstering efficiency.
159
+ - **Knowledge Distillation**: Harnessing the wisdom of larger models in smaller, more agile architectures.
160
+ - **Zero-Shot and Few-Shot Learning**: Broadening adaptability horizons.
161
+ - **Enhanced Data Augmentation**: Fortifying the model's grasp on varied, nuanced contexts.
162
+ - **Decentralized Training**: Tapping into the global hive-mind, harnessing the collaborative power of the community.
163
+
164
+
165
+ ## **Potential Other Future Trajectories**
166
+
167
+ #### **1. Clearer Metrics**
168
+ There's always room to elevate the benchmarking rigor, especially concerning reasoning abilities.
169
+
170
+ #### **2. Robust Validation and Testing Environment**
171
+ Further fine-tuning of the testing environment can offer even more reliable validations of Andromeda's capabilities.
172
+
173
+ #### **3. Comprehensive Documentation**
174
+ To bolster transparency and replicability, detailed documentation covering every facet of Andromeda is on the horizon.
175
+
176
+ #### **4. Benchmarking Against Peers**
177
+ By juxtaposing Andromeda against its counterparts, its distinctive advantages can be spotlighted more effectively.
178
+
179
+ #### **5. Spotlight on Real-World Applications**
180
+ By highlighting tangible use-cases, the versatility and prowess of Andromeda can be showcased in palpable contexts.
181
+
182
+ #### **6. Model Interpretability**
183
+ Future iterations might delve deeper into model interpretability, especially for critical applications.
184
+
185
+ #### **7. Niche Customizations**
186
+ By tailoring Andromeda to meet specific niche needs, its adaptability and value proposition can be further enhanced.
187
+
188
+ #### **8. Collaborative Endeavors**
189
+ Engaging more intimately with the global research community could spawn collaborative projects, bringing diverse insights to the fore.
190
+
191
+
192
+ As we voyage further into the AI frontier, Andromeda stands as a beacon, illuminating the path forward, promising marvels yet to come. It's not just about machine intelligence; it's about the dance between human curiosity and machine capability.
193
+
194
+ ---
195
+
196
+ Join us on this journey. Dive deeper, ask questions, innovate, and let's redefine what's possible, together.
Andromeda/DOCs/Docs/DOCUMENTATION.md ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Documentation
2
+
3
+ ## `DatasetBuilder`
4
+
5
+ ### DatasetBuilder
6
+
7
+ DatasetBuilder provides a convenient way to build datasets for training the Andromeda model.
8
+
9
+ #### Constructor
10
+
11
+ ```python
12
+ def __init__(
13
+ self,
14
+ dataset_name,
15
+ seq_len=8192,
16
+ num_cpu=None,
17
+ hf_account_repo=None,
18
+ tokenizer="EleutherAI/gpt-neox-20b",
19
+ )
20
+ ```
21
+
22
+ Initialize the DatasetBuilder.
23
+
24
+ **Args:**
25
+
26
+ - `dataset_name` (str): Name of the dataset to process.
27
+ - `seq_len` (int): Maximum sequence length.
28
+ - `num_cpu` (int, optional): Number of CPU cores to use for multiprocessing. Defaults to None.
29
+ - `hf_account_repo` (str, optional): Hugging Face account name and repository to push the processed dataset. Defaults to None.
30
+ - `tokenizer` (str, optional): Tokenizer model to use. Defaults to "EleutherAI/gpt-neox-20b".
31
+
32
+ #### Methods
33
+
34
+ ##### build_dataset
35
+
36
+ ```python
37
+ def build_dataset(self) -> torch.utils.data.Dataset
38
+ ```
39
+
40
+ Build and process the dataset.
41
+
42
+ **Returns:**
43
+
44
+ - `torch.utils.data.Dataset`: The processed dataset ready for training.
45
+
46
+
47
+
48
+ ## AndromedaTokenizer
49
+
50
+ ### Purpose
51
+
52
+ The `AndromedaTokenizer` class provides tokenization functionality using the Hugging Face tokenizer. It allows you to tokenize texts using the specified tokenizer model.
53
+
54
+ ### Systems Understanding
55
+
56
+ The `AndromedaTokenizer` class initializes a tokenizer model from the Hugging Face library. It uses the `AutoTokenizer.from_pretrained` method to load the tokenizer model with specific parameters such as the EOS token, pad token, extra IDs, and model maximum length. The `tokenize_texts` method tokenizes input texts using the tokenizer model and returns the tokenized input IDs.
57
+
58
+ ### Usage Example
59
+
60
+ ```python
61
+ from Andromeda import AndromedaTokenizer
62
+
63
+ # Initialize the tokenizer
64
+ tokenizer = AndromedaTokenizer()
65
+
66
+ # Tokenize texts
67
+ texts = ["This is an example sentence.", "Another example sentence."]
68
+ tokenized_ids = tokenizer.tokenize_texts(texts)
69
+
70
+ print(tokenized_ids)
71
+ ```
72
+
73
+ ## Andromeda
74
+
75
+ ### Purpose
76
+
77
+ The `Andromeda` class is a transformer-based model architecture. It consists of a `Transformer` and `AutoregressiveWrapper` with default or user-specified parameters.
78
+
79
+ ### Systems Understanding
80
+
81
+ The `Andromeda` class initializes with a `Transformer` and `AutoregressiveWrapper`. The `Transformer` encapsulates the main transformer model, and the `AutoregressiveWrapper` enables autoregressive generation using the transformer model.
82
+
83
+ The constructor of the `Andromeda` class takes various parameters that define the architecture of the model, such as the number of tokens, maximum sequence length, model dimension, depth, number of heads, etc. These parameters are used to initialize the `Transformer` and `AutoregressiveWrapper` with the specified configuration.
84
+
85
+ The `forward` method performs a forward pass through the model. It takes the input `text_tokens` as input and passes it through the `Decoder` module inside the `Andromeda` model. The output from the decoder is returned as the result.
86
+
87
+ ### Usage Example
88
+
89
+ ```python
90
+ from Andromeda import Andromeda
91
+
92
+ # Create an instance of the Andromeda model
93
+ model = Andromeda()
94
+
95
+ # Define the input text tokens
96
+ text_tokens = [1, 2, 3, 4, 5] # Example input tokens
97
+
98
+ # Perform a forward pass through the model
99
+ output = model.forward(text_tokens)
100
+
101
+ print(output)
102
+ ```
103
+
104
+ ### Constructor
105
+
106
+ ```python
107
+ def __init__(self, num_tokens=50304, max_seq_len=8192, dim=2560, depth=32, dim_head=128, heads=24, use_abs_pos_emb=False, alibi_pos_bias=True, alibi_num_heads=12, rotary_xpos=True, attn_flash=True, deepnorm=True, shift_tokens=1, attn_one_kv_head=True, qk_norm=True, attn_qk_norm=True, attn_qk_norm_dim_scale=True, embedding_provider=AndromedaEmbedding())
108
+ ```
109
+
110
+ - `num_tokens` (optional): Number of tokens in the vocabulary.
111
+ - `max_seq_len` (optional): Maximum sequence length.
112
+ - `dim` (optional): Dimension of the model.
113
+ - `depth` (optional): Depth of the model.
114
+ - `dim_head` (optional): Dimension of the model head.
115
+ - `heads` (optional): Number of heads.
116
+ - `use_abs_pos_emb` (optional): Whether to use absolute position embedding.
117
+ - `alibi_pos_bias` (optional): Alibi position bias.
118
+ - `alibi_num_heads` (optional): Number of alibi heads.
119
+ - `rotary_xpos` (optional): Rotary position.
120
+ - `attn_flash` (optional): Attention flash.
121
+ - `deepnorm` (optional): Deep normalization.
122
+ - `shift_tokens` (optional): Number of tokens to shift.
123
+ - `attn_one_kv_head` (optional): Attention one key/value head.
124
+ - `qk_norm` (optional): Query-key normalization.
125
+ - `attn_qk_norm` (optional): Attention query-key normalization.
126
+ - `attn_qk_norm_dim_scale` (optional): Attention query-key normalization dimension scale.
127
+ - `embedding_provider` (optional): Embedding provider module.
128
+
129
+ ### Methods
130
+
131
+ - `forward(text_tokens, **kwargs)`: Performs a forward pass through the model.
132
+ - `text_tokens` (required): Input tokens.
133
+ - `kwargs` (optional): Other arguments.
134
+
135
+ ### Args
136
+
137
+ - `text_tokens` (list): Input tokens.
138
+
139
+ ### Returns
140
+
141
+ - Output from the decoder module.
142
+
143
+ ## Conclusion
144
+
145
+ The Andromeda module provides a transformer-based model architecture for text generation. The `AndromedaTokenizer` class allows you to tokenize texts using the specified tokenizer model. The `Andromeda` class initializes with a transformer and autoregressive wrapper, providing the functionality for text generation. By using the provided classes and methods, you can generate text using the Andromeda model.
Andromeda/DOCs/Docs/TRAINING.md ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Andromeda Model Training Standard Operating Procedure
2
+
3
+ This document provides instructions on how to train the Andromeda model end-to-end using the provided code. The training procedure consists of three main scripts: `build_dataset.py`, `model.py`, and `train_distributed.py`. Follow the steps below to train the Andromeda model.
4
+
5
+ ## Prerequisites
6
+
7
+ Before starting the training process, ensure that you have the following requirements:
8
+
9
+ - Python 3.7 or higher
10
+ - PyTorch 1.9 or higher
11
+ - Transformers library
12
+ - Datasets library
13
+ - Accelerate library
14
+ - Wandb library (optional, for logging)
15
+
16
+ ## Step 1: Building the Dataset
17
+
18
+ The first step is to build the dataset required for training. The `build_dataset.py` script processes the training data and prepares it for training. Follow the instructions below to build the dataset:
19
+
20
+ 1. Open the `build_dataset.py` script.
21
+ 2. Set the configuration parameters in the `CFG` class according to your requirements:
22
+ - `HF_ACCOUNT_REPO`: Replace with your Hugging Face API key.
23
+ - `TOKENIZER`: Choose the tokenizer model to use (e.g., "EleutherAI/gpt-neox-20b").
24
+ - `DATASET_NAME`: Choose the dataset to process (e.g., "tiiuae/falcon-refinedweb").
25
+ - `SEQ_LEN`: Set the desired sequence length.
26
+ 3. Save the changes to the script.
27
+ 4. Open a terminal or command prompt and navigate to the directory containing the `build_dataset.py` script.
28
+ 5. Run the following command to execute the script:
29
+ ```
30
+ python build_dataset.py
31
+ ```
32
+ 6. The script will process the dataset and push it to your Hugging Face account repository specified by `HF_ACCOUNT_REPO`.
33
+
34
+ ## Step 2: Defining the Andromeda Model
35
+
36
+ The second step is to define the Andromeda model architecture. The `model.py` script contains the model definition and configuration. Follow the instructions below to configure the Andromeda model:
37
+
38
+ 1. Open the `model.py` script.
39
+ 2. Set the configuration parameters in the `AndromedaTokenizer` and `Andromeda` classes according to your requirements:
40
+ - `tokenizer`: Configure the tokenizer with the desired parameters.
41
+ - `Andromeda`: Configure the Andromeda model with the desired architecture.
42
+ 3. Save the changes to the script.
43
+
44
+ ## Step 3: Training the Andromeda Model
45
+
46
+ The final step is to train the Andromeda model using the `train_distributed.py` script. Follow the instructions below to start the training process:
47
+
48
+ 1. Open the `train_distributed.py` script.
49
+ 2. Set the configuration parameters in the `TrainAndromeda.CFG` class according to your requirements:
50
+ - `BATCH_SIZE`: Set the batch size for training.
51
+ - `GRADIENT_ACCUMULATE_EVERY`: Set the number of gradient accumulation steps.
52
+ - `LEARNING_RATE`: Set the learning rate for the optimizer.
53
+ - `WEIGHT_DECAY`: Set the weight decay for the optimizer.
54
+ - `SEQ_LEN`: Set the desired sequence length.
55
+ - `USE_DEEPSPEED`: Set to `True` if using DeepSpeed for optimization.
56
+ - `USE_FSDP`: Set to `True` if using Fully Sharded Data Parallelism.
57
+ - `USE_PRETOKENIZED`: Set to `True` if using a pre-tokenized dataset.
58
+ - `USE_ACTIVATION_CHECKPOINTING`: Set to `True` if using activation checkpointing.
59
+ - `RESUME_FROM_CHECKPOINT`: Set to the path of a checkpoint to resume training from.
60
+ - `CHECKPOINTING_STEPS`: Set the number of steps between checkpoints.
61
+ - `OUTPUT_DIR`: Set the output directory for saving the model checkpoints and logs.
62
+ - `ENTITY_NAME`: Set the Wandb entity name for logging (optional).
63
+ 3. Save the changes to the script.
64
+ 4. Open a terminal or command prompt and navigate to the directory containing the `train_distributed.py` script.
65
+ 5. Run the following command to start the training:
66
+ ```
67
+ python train_distributed.py
68
+ ```
69
+ 6. The script will train the Andromeda model using the specified configuration and dataset.
70
+ 7. During training, the progress will be displayed in the terminal, and logs will be saved to the specified output directory.
71
+
72
+ # Other Training methods
73
+
74
+ First:
75
+
76
+ `Accelerate Config`
77
+
78
+ Enable Deepspeed 3:
79
+
80
+ `Accelerate launch train_distributed_accelerate.py`
81
+
82
+
Andromeda/DOCs/Docs/Training/DATASET_STRATEGY.md ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Andromeda
2
+
3
+ We should train an 100m param, 500m, 1billion parameters verisions with similiar hyperparameters from these 2 similiar models
4
+
5
+ [concept of mind's PALM](https://github.com/conceptofmind/PaLM)
6
+ Model Size Num Tokens Dim Depth Dim Head Heads Flash Attention Learning Rate
7
+ 150 M 50304 768 12 128 8 True 6e-4
8
+ 410 M 50304 1024 24 128 8 True 3e-4
9
+ 1 B 50304 2048 16 128 8 True 3e-4
10
+
11
+
12
+ [MPT HF](https://huggingface.co/mosaicml/mpt-7b)
13
+
14
+ Hyperparameter Value
15
+ n_parameters 6.7B
16
+ n_layers 32
17
+ n_heads 32
18
+ d_model 4096
19
+ vocab size 50432
20
+ sequence length 2048
21
+
22
+
23
+
24
+
25
+ ## Data prioritization: Prioritize datasets based on their relevance to the desired AI capabilities and the quality of the data.
26
+
27
+ High priority: C4, openwebtext, super_glue, piqa, Falcon-40B (RefinedWeb-English, RefinedWeb-Europe, Books, Conversations, Code, Technical), glue, tiiuae/falcon-refinedweb, math_dataset
28
+
29
+ Medium priority: bigcode/ta-prompt, bigcode/the-stack-dedup, OpenAssistant/oasst1, ehartford/wizard_vicuna_70k_unfiltered, tiiuae/falcon-refinedweb
30
+
31
+ Low priority: timdettmers/openassistant-guanaco, JosephusCheung/GuanacoDataset, JosephusCheung/GuanacoDataset, anon8231489123/ShareGPT_Vicuna_unfiltered, togethercomputer/RedPajama-Data, togethercomputer/RedPajama-Data-1T, Anthropic/hh-rlhf, databricks/databricks-dolly-15k, QingyiSi/Alpaca-CoT, alpaca,
32
+ distillation, timdettmers/openassistant-guanaco, OpenAssistant/oasst1, dmayhem93/toolformer-v0-postprocessed, openai_humaneval, yahma/alpaca-cleaned,
33
+
34
+ ## Data preprocessing: Clean, preprocess, and tokenize the datasets to ensure consistency and compatibility with the AI model.
35
+
36
+ Remove duplicates, irrelevant content, and low-quality data.
37
+
38
+ Tokenize the text using a suitable tokenizer, such as GPT Neox tokenizer or potentially falcon's tokenizer
39
+
40
+ Split the datasets into training, validation, and testing sets.
41
+
42
+
43
+ ## Training strategy: Train the AI model using the prioritized datasets in a multi-stage process.
44
+
45
+ Stage 1: Pretrain the model on high-priority datasets (openwebtext, super_glue, piqa, Falcon-40B, glue) to build a strong language understanding foundation.
46
+
47
+ Stage 2: Fine-tune the model on medium-priority datasets (bigcode/ta-prompt, bigcode/the-stack-dedup, OpenAssistant/oasst1, ehartford/wizard_vicuna_70k_unfiltered, tiiuae/falcon-refinedweb) to enhance its performance in specific domains and tasks.
48
+
49
+ Stage 3: Further fine-tune the model on low-priority datasets (JosephusCheung/GuanacoDataset, anon8231489123/ShareGPT_Vicuna_unfiltered, togethercomputer/RedPajama-Data, togethercomputer/RedPajama-Data-1T, Anthropic/hh-rlhf, databricks/databricks-dolly-15k, QingyiSi/Alpaca-CoT) to capture any additional knowledge and nuances. PRM800K: A Process Supervision Dataset
50
+
51
+
52
+
53
+ Evaluation and iteration: Continuously evaluate the model's performance on the validation and testing sets, and iterate the training process to improve its performance.
54
+
55
+ Monitor the model's performance using relevant metrics, such as perplexity, F1 score, or BLEU score, depending on the task.
56
+ Adjust hyperparameters, learning rate, and training duration as needed to optimize the model's performance.
57
+ If necessary, revisit the data prioritization and preprocessing steps to refine the training data.
58
+
59
+
60
+ # Evaluations and Benchmarks:
61
+
62
+ [Chain of thought hub](https://github.com/FranxYao/chain-of-thought-hub)
63
+ SFT stands for Style Fine-tuning and RLHF stands for Reinforcement Learning and Human Feedback. These are techniques used in natural language processing to improve the quality and accuracy of generated text. The statement suggests that if these techniques are applied correctly to the 65B LLaMA dataset, it is possible to recreate ChatGPT.
64
+
65
+
66
+ # Analysis of Existing Models
67
+
68
+ ### MPT-7B
69
+
70
+ ```python
71
+ Data Source Number of Tokens in Source Proportion Effective Number of Tokens Epochs
72
+ mC4 3.1.0 - English 417.99 B 0.33 330 B 0.14
73
+ C4 - English - SemDedup 80% 100.42 B 0.299 299 B 2.98
74
+ RedPajama - CommonCrawl 878.45 B 0.1 100 B 0.11
75
+ The Stack - Selected Languages 463.78 B 0.1 100 B 0.22
76
+ RedPajama - Wikipedia - En 4.87 B 0.04 40 B 8.21
77
+ The Stack - Markdown 107.07 B 0.035 35 B 0.33
78
+ S2ORC 48.85 B 0.033 33 B 0.68
79
+ RedPajama - Books 26.02 B 0.03 30B 1.15
80
+ RedPajama - arXiv 28.10 B 0.019 19 B 0.68
81
+ RedPajama - StackExchange 20.54 B 0.014 14 B 0.68
82
+ ```
83
+
84
+ # MPT-1B
85
+
86
+ ```
87
+ Training Data
88
+ The model was trained for 200B tokens (batch size 2200, sequence length 2048). It was trained on the following data mix:
89
+
90
+ 67% RedPajama Common Crawl
91
+ 15% C4
92
+ 4.5% RedPajama GitHub
93
+ 4.5% RedPajama Wikipedia
94
+ 4.5% RedPajama Books
95
+ 2.5% RedPajama Arxiv
96
+ 2% RedPajama StackExchange
97
+
98
+ Each sample was chosen from one of the datasets, with the dataset selected with the probability specified above. The examples were shuffled within each dataset. Each example was constructed from as many sequences from that dataset as were necessary to fill the 2048 sequence length.
99
+
100
+ ```