|
|
|
|
|
#include <stdio.h> |
|
#include <stdlib.h> |
|
#include <ctype.h> |
|
#include <stdint.h> |
|
#include <time.h> |
|
#include <math.h> |
|
#include <string.h> |
|
#include <fcntl.h> |
|
#if defined _WIN32 |
|
#include "win.h" |
|
#else |
|
#include <unistd.h> |
|
#include <sys/mman.h> |
|
#endif |
|
|
|
|
|
/*
 * Quantization group size shared by quantize(), dequantize(), matmul() and
 * init_quantized_tensors(). It starts at 0 and is overwritten by
 * read_checkpoint() with the value stored in the checkpoint header.
 */
int GS = 0;
|
|
|
|
|
|
|
|
|
/*
 * Model hyper-parameters. read_checkpoint() fills this struct with a single
 * fread() of sizeof(Config) bytes, so the field order and types mirror the
 * on-disk header and must not be changed.
 */
typedef struct {
    int dim;         /* length of the activation vector x */
    int hidden_dim;  /* length of the feed-forward buffers hb / hb2 */
    int n_layers;    /* number of layers iterated by forward() */
    int n_heads;     /* number of query heads */
    int n_kv_heads;  /* number of key/value heads (may be fewer than n_heads) */
    int vocab_size;  /* number of logits produced per step */
    int seq_len;     /* number of positions held by the key/value caches */
} Config;
|
|
|
/*
 * An int8-quantized tensor: element i is reconstructed as q[i] * s[i / GS],
 * i.e. one float scale is shared by every group of GS consecutive values.
 */
typedef struct {
    int8_t *q;  /* quantized values */
    float *s;   /* per-group scaling factors */
} QuantizedTensor;
|
|
|
typedef struct { |
|
|
|
QuantizedTensor *q_tokens; |
|
float* token_embedding_table; |
|
|
|
|
|
float* rms_att_weight; |
|
float* rms_ffn_weight; |
|
|
|
QuantizedTensor *wq; |
|
QuantizedTensor *wk; |
|
QuantizedTensor *wv; |
|
QuantizedTensor *wo; |
|
|
|
QuantizedTensor *w1; |
|
QuantizedTensor *w2; |
|
QuantizedTensor *w3; |
|
|
|
float* rms_final_weight; |
|
|
|
QuantizedTensor *wcls; |
|
} TransformerWeights; |
|
|
|
typedef struct { |
|
|
|
float *x; |
|
float *xb; |
|
float *xb2; |
|
float *hb; |
|
float *hb2; |
|
QuantizedTensor xq; |
|
QuantizedTensor hq; |
|
float *q; |
|
float *k; |
|
float *v; |
|
float *att; |
|
float *logits; |
|
|
|
float* key_cache; |
|
float* value_cache; |
|
} RunState; |
|
|
|
typedef struct { |
|
Config config; |
|
TransformerWeights weights; |
|
RunState state; |
|
|
|
int fd; |
|
float* data; |
|
ssize_t file_size; |
|
} Transformer; |
|
|
|
/*
 * Allocate and zero every buffer of the run state `s` for the model described
 * by `p`. Prints "malloc failed!" and exits if any allocation fails.
 */
void malloc_run_state(RunState* s, Config* p) {
    int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;

    /* Multiply in size_t so that large models cannot overflow int. */
    size_t att_len = (size_t)p->n_heads * (size_t)p->seq_len;
    size_t cache_len = (size_t)p->n_layers * (size_t)p->seq_len * (size_t)kv_dim;

    s->x = calloc(p->dim, sizeof(float));
    s->xb = calloc(p->dim, sizeof(float));
    s->xb2 = calloc(p->dim, sizeof(float));
    s->hb = calloc(p->hidden_dim, sizeof(float));
    s->hb2 = calloc(p->hidden_dim, sizeof(float));
    s->xq = (QuantizedTensor) { .q = calloc(p->dim, sizeof(int8_t)), .s = calloc(p->dim, sizeof(float)) };
    s->hq = (QuantizedTensor) { .q = calloc(p->hidden_dim, sizeof(int8_t)), .s = calloc(p->hidden_dim, sizeof(float)) };
    s->q = calloc(p->dim, sizeof(float));
    s->k = calloc(kv_dim, sizeof(float));
    s->v = calloc(kv_dim, sizeof(float));
    s->att = calloc(att_len, sizeof(float));
    s->logits = calloc(p->vocab_size, sizeof(float));
    s->key_cache = calloc(cache_len, sizeof(float));
    s->value_cache = calloc(cache_len, sizeof(float));

    /* The quantized scratch tensors are checked too; they were previously
     * left out and would have been dereferenced as NULL in quantize(). */
    if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2
        || !s->xq.q || !s->xq.s || !s->hq.q || !s->hq.s
        || !s->q || !s->k || !s->v || !s->att || !s->logits
        || !s->key_cache || !s->value_cache) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }
}
|
|
|
/* Release every buffer that malloc_run_state() allocated for `s`. */
void free_run_state(RunState* s) {
    void *owned[] = {
        s->x, s->xb, s->xb2, s->hb, s->hb2,
        s->xq.q, s->xq.s, s->hq.q, s->hq.s,
        s->q, s->k, s->v, s->att, s->logits,
        s->key_cache, s->value_cache,
    };
    const size_t count = sizeof(owned) / sizeof(owned[0]);
    for (size_t idx = 0; idx < count; idx++) {
        free(owned[idx]);
    }
}
|
|
|
|
|
|
|
|
|
/*
 * Expand the first `n` values of the quantized tensor `qx` into floats:
 * x[i] = q[i] * s[i / GS].
 */
void dequantize(QuantizedTensor *qx, float* x, int n) {
    int idx = 0;
    while (idx < n) {
        const int group = idx / GS;
        x[idx] = qx->q[idx] * qx->s[group];
        idx++;
    }
}
|
|
|
/*
 * Quantize the first n floats of `x` into `qx`, one group of GS values at a
 * time. For each group the scale is max|x| / 127 and every value is stored as
 * round(x / scale). Values beyond the last complete group are not touched.
 *
 * A group whose scale is zero (all values zero, or so small that the scale
 * underflows) is stored as all zeros. Dividing by that scale would produce
 * NaN/inf, and converting those to int8_t is undefined behaviour.
 */
void quantize(QuantizedTensor *qx, float* x, int n) {
    const float Q_MAX = 127.0f;
    int num_groups = n / GS;

    for (int group = 0; group < num_groups; group++) {
        const float *src = x + group * GS;
        int8_t *dst = qx->q + group * GS;

        /* largest magnitude in this group */
        float wmax = 0.0f;
        for (int i = 0; i < GS; i++) {
            float val = fabsf(src[i]);
            if (val > wmax) {
                wmax = val;
            }
        }

        float scale = wmax / Q_MAX;
        qx->s[group] = scale;

        if (scale == 0.0f) {
            for (int i = 0; i < GS; i++) {
                dst[i] = 0;
            }
            continue;
        }

        for (int i = 0; i < GS; i++) {
            dst[i] = (int8_t) roundf(src[i] / scale);
        }
    }
}
|
|
|
|
|
QuantizedTensor *init_quantized_tensors(void **ptr, int n, int size_each) { |
|
void *p = *ptr; |
|
QuantizedTensor *res = malloc(n * sizeof(QuantizedTensor)); |
|
for(int i=0; i<n; i++) { |
|
|
|
res[i].q = (int8_t*)p; |
|
p = (int8_t*)p + size_each; |
|
|
|
res[i].s = (float*)p; |
|
p = (float*)p + size_each / GS; |
|
} |
|
*ptr = p; |
|
return res; |
|
} |
|
|
|
/*
 * Point every field of `w` at its data inside the checkpoint payload `ptr`.
 *
 * Layout consumed, in order: rms_att_weight, rms_ffn_weight and
 * rms_final_weight as floats, then the quantized tensors q_tokens, wq, wk,
 * wv, wo, w1, w2, w3 and, unless shared_classifier is non-zero, wcls.
 * The token embedding table is additionally dequantized into a freshly
 * allocated float buffer. Exits if that buffer cannot be allocated.
 */
void memory_map_weights(TransformerWeights *w, Config* p, void* ptr, uint8_t shared_classifier) {
    int head_size = p->dim / p->n_heads;

    /* float (unquantized) rmsnorm parameters come first */
    float* fptr = (float*) ptr;
    w->rms_att_weight = fptr;
    fptr += p->n_layers * p->dim;
    w->rms_ffn_weight = fptr;
    fptr += p->n_layers * p->dim;
    w->rms_final_weight = fptr;
    fptr += p->dim;

    /* quantized tensors follow */
    ptr = (void*)fptr;
    w->q_tokens = init_quantized_tensors(&ptr, 1, p->vocab_size * p->dim);

    /* float copy of the embeddings; size computed in size_t to avoid int overflow */
    size_t n_embed = (size_t)p->vocab_size * (size_t)p->dim;
    w->token_embedding_table = malloc(n_embed * sizeof(float));
    if (w->token_embedding_table == NULL) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }
    dequantize(w->q_tokens, w->token_embedding_table, p->vocab_size * p->dim);

    w->wq = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_heads * head_size));
    w->wk = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
    w->wv = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
    w->wo = init_quantized_tensors(&ptr, p->n_layers, (p->n_heads * head_size) * p->dim);

    w->w1 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
    w->w2 = init_quantized_tensors(&ptr, p->n_layers, p->hidden_dim * p->dim);
    w->w3 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);

    /* a shared classifier reuses the embedding tensor instead of its own data */
    w->wcls = shared_classifier ? w->q_tokens : init_quantized_tensors(&ptr, 1, p->dim * p->vocab_size);
}
|
|
|
/*
 * Load a version-2 checkpoint.
 *
 * Reads and validates the header (magic 0x616b3432, version 2, Config,
 * shared-classifier flag, group size), stores the group size in GS, then
 * memory-maps the whole file read-only and wires `weights` to the payload
 * that starts 256 bytes into the file.
 *
 * Outputs: *config, *weights, *fd (open descriptor), *data (mapping start),
 * *file_size (mapping length). Any failure prints a message and exits.
 */
void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
                     int* fd, float** data, ssize_t* file_size) {
    const int header_size = 256; /* payload offset for this checkpoint version */

    FILE *file = fopen(checkpoint, "rb");
    if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }

    uint32_t magic_number;
    if (fread(&magic_number, sizeof(uint32_t), 1, file) != 1) {
        fprintf(stderr, "Couldn't read header of %s\n", checkpoint);
        exit(EXIT_FAILURE);
    }
    if (magic_number != 0x616b3432) { fprintf(stderr, "Bad magic number\n"); exit(EXIT_FAILURE); }

    int version;
    if (fread(&version, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "Couldn't read header of %s\n", checkpoint);
        exit(EXIT_FAILURE);
    }
    if (version != 2) { fprintf(stderr, "Bad version %d, need version 2\n", version); exit(EXIT_FAILURE); }

    uint8_t shared_classifier;
    int group_size;
    if (fread(config, sizeof(Config), 1, file) != 1
        || fread(&shared_classifier, sizeof(uint8_t), 1, file) != 1
        || fread(&group_size, sizeof(int), 1, file) != 1) {
        fprintf(stderr, "Couldn't read header of %s\n", checkpoint);
        exit(EXIT_FAILURE);
    }
    /* GS is used as a divisor throughout; reject values that would divide by zero. */
    if (group_size <= 0) {
        fprintf(stderr, "Bad group size %d\n", group_size);
        exit(EXIT_FAILURE);
    }
    GS = group_size;

    /* determine the file size, which is also the mapping length */
    if (fseek(file, 0, SEEK_END) != 0) {
        fprintf(stderr, "Couldn't seek in %s\n", checkpoint);
        exit(EXIT_FAILURE);
    }
    long end_offset = ftell(file);
    fclose(file);
    if (end_offset < header_size) {
        fprintf(stderr, "Checkpoint %s is too small\n", checkpoint);
        exit(EXIT_FAILURE);
    }
    *file_size = end_offset;

    *fd = open(checkpoint, O_RDONLY);
    if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
    *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
    if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }

    void* weights_ptr = ((char*)*data) + header_size;
    memory_map_weights(weights, config, weights_ptr, shared_classifier);
}
|
|
|
/*
 * Initialise `t` from the checkpoint at `checkpoint_path`: load config and
 * weights, then allocate the run-state buffers sized for that config.
 */
void build_transformer(Transformer *t, char* checkpoint_path) {
    Config *cfg = &t->config;
    read_checkpoint(checkpoint_path, cfg, &t->weights, &t->fd, &t->data, &t->file_size);
    malloc_run_state(&t->state, cfg);
}
|
|
|
/*
 * Tear down everything build_transformer() created: the QuantizedTensor
 * descriptor arrays, the dequantized embedding table, the file mapping, the
 * file descriptor and the run-state buffers.
 */
void free_transformer(Transformer* t) {
    TransformerWeights *w = &t->weights;

    /* wcls may alias q_tokens; decide before anything is freed */
    const int owns_classifier = (w->wcls != w->q_tokens);

    free(w->q_tokens);
    free(w->token_embedding_table);

    QuantizedTensor *per_layer[] = { w->wq, w->wk, w->wv, w->wo, w->w1, w->w2, w->w3 };
    for (size_t idx = 0; idx < sizeof(per_layer) / sizeof(per_layer[0]); idx++) {
        free(per_layer[idx]);
    }

    if (owns_classifier) {
        free(w->wcls);
    }

    if (t->data != MAP_FAILED) {
        munmap(t->data, t->file_size);
    }
    if (t->fd != -1) {
        close(t->fd);
    }

    free_run_state(&t->state);
}
|
|
|
|
|
|
|
|
|
void rmsnorm(float* o, float* x, float* weight, int size) { |
|
|
|
float ss = 0.0f; |
|
for (int j = 0; j < size; j++) { |
|
ss += x[j] * x[j]; |
|
} |
|
ss /= size; |
|
ss += 1e-5f; |
|
ss = 1.0f / sqrtf(ss); |
|
|
|
for (int j = 0; j < size; j++) { |
|
o[j] = weight[j] * (ss * x[j]); |
|
} |
|
} |
|
|
|
void softmax(float* x, int size) { |
|
|
|
float max_val = x[0]; |
|
for (int i = 1; i < size; i++) { |
|
if (x[i] > max_val) { |
|
max_val = x[i]; |
|
} |
|
} |
|
|
|
float sum = 0.0f; |
|
for (int i = 0; i < size; i++) { |
|
x[i] = expf(x[i] - max_val); |
|
sum += x[i]; |
|
} |
|
|
|
for (int i = 0; i < size; i++) { |
|
x[i] /= sum; |
|
} |
|
} |
|
|
|
/*
 * Quantized matrix-vector product: xout[i] = sum_j w[i*n + j] * x[j] for
 * i in [0, d). The integer dot product is taken one group of GS values at a
 * time and each group's sum is rescaled by the product of the two group
 * scales before being accumulated as a float. Values of a trailing partial
 * group (when n is not a multiple of GS) are ignored.
 */
void matmul(float* xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d) {
    int i;
    #pragma omp parallel for private(i)
    for (i = 0; i < d; i++) {
        const int row = i * n;  /* offset of row i inside w */
        float acc = 0.0f;

        for (int j = 0; j <= n - GS; j += GS) {
            int32_t dot = 0;
            for (int k = 0; k < GS; k++) {
                dot += ((int32_t) x->q[j + k]) * ((int32_t) w->q[row + j + k]);
            }
            acc += ((float) dot) * w->s[(row + j) / GS] * x->s[j / GS];
        }

        xout[i] = acc;
    }
}
|
|
|
/*
 * Run the model for one step: feed `token` at sequence position `pos` through
 * every layer and return the logits.
 *
 * The returned pointer is the run state's logits buffer (vocab_size floats);
 * it is overwritten by the next call. The key/value caches are updated at row
 * `pos` for every layer. Neither `token` nor `pos` is range-checked here.
 */
float* forward(Transformer* transformer, int token, int pos) {
    Config* p = &transformer->config;
    TransformerWeights* w = &transformer->weights;
    RunState* s = &transformer->state;

    const int dim = p->dim;
    const int hidden_dim = p->hidden_dim;
    const int head_size = dim / p->n_heads;
    const int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
    const int kv_mul = p->n_heads / p->n_kv_heads;  /* query heads per kv head */
    float *x = s->x;

    /* start from the embedding row of the input token */
    memcpy(x, w->token_embedding_table + token*dim, dim * sizeof(float));

    for (unsigned long long l = 0; l < p->n_layers; l++) {

        /* attention pre-norm */
        rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);

        /* q, k, v projections of the normalised activation */
        quantize(&s->xq, s->xb, dim);
        matmul(s->q, &s->xq, w->wq + l, dim, dim);
        matmul(s->k, &s->xq, w->wk + l, dim, kv_dim);
        matmul(s->v, &s->xq, w->wv + l, dim, kv_dim);

        /* rotary position encoding: rotate each (even, odd) pair of q, and of
         * k for the first n_kv_heads heads, by the angle pos * freq */
        for (int head = 0; head < p->n_heads; head++) {
            const int base = head * head_size;
            const int rotate_k = head < p->n_kv_heads;
            for (int j = 0; j < head_size; j += 2) {
                float freq = 1.0f / powf(500000.0f, (float)j / (float)head_size);
                float angle = pos * freq;
                float cos_a = cosf(angle);
                float sin_a = sinf(angle);

                float q0 = s->q[base + j];
                float q1 = s->q[base + j + 1];
                s->q[base + j] = q0 * cos_a - q1 * sin_a;
                s->q[base + j + 1] = q0 * sin_a + q1 * cos_a;

                if (rotate_k) {
                    float k0 = s->k[base + j];
                    float k1 = s->k[base + j + 1];
                    s->k[base + j] = k0 * cos_a - k1 * sin_a;
                    s->k[base + j + 1] = k0 * sin_a + k1 * cos_a;
                }
            }
        }

        /* store this position's key and value in the layer's cache */
        int loff = l * p->seq_len * kv_dim;
        float* key_cache_row = s->key_cache + loff + pos * kv_dim;
        float* value_cache_row = s->value_cache + loff + pos * kv_dim;
        memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
        memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));

        /* attention, one head at a time */
        int h;
        #pragma omp parallel for private(h)
        for (h = 0; h < p->n_heads; h++) {
            float* q = s->q + h * head_size;
            float* att = s->att + h * p->seq_len;

            /* scaled dot product of q with every cached key up to pos */
            for (int t = 0; t <= pos; t++) {
                float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
                float score = 0.0f;
                for (int i = 0; i < head_size; i++) {
                    score += q[i] * k[i];
                }
                score /= sqrtf(head_size);
                att[t] = score;
            }

            /* scores -> attention weights */
            softmax(att, pos + 1);

            /* weighted sum of the cached values, written into this head's slice of xb */
            float* xb = s->xb + h * head_size;
            memset(xb, 0, head_size * sizeof(float));
            for (int t = 0; t <= pos; t++) {
                float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
                float a = att[t];
                for (int i = 0; i < head_size; i++) {
                    xb[i] += a * v[i];
                }
            }
        }

        /* attention output projection */
        quantize(&s->xq, s->xb, dim);
        matmul(s->xb2, &s->xq, w->wo + l, dim, dim);

        /* residual connection */
        for (int i = 0; i < dim; i++) {
            x[i] += s->xb2[i];
        }

        /* feed-forward pre-norm */
        rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);

        /* feed-forward: w2( silu(w1(x)) * w3(x) ) */
        quantize(&s->xq, s->xb, dim);
        matmul(s->hb, &s->xq, w->w1 + l, dim, hidden_dim);
        matmul(s->hb2, &s->xq, w->w3 + l, dim, hidden_dim);

        for (int i = 0; i < hidden_dim; i++) {
            float gated = s->hb[i];
            /* silu(v) = v * sigmoid(v) */
            gated *= (1.0f / (1.0f + expf(-gated)));
            /* elementwise product with the w3 branch */
            gated *= s->hb2[i];
            s->hb[i] = gated;
        }

        quantize(&s->hq, s->hb, hidden_dim);
        matmul(s->xb, &s->hq, w->w2 + l, hidden_dim, dim);

        /* residual connection */
        for (int i = 0; i < dim; i++) {
            x[i] += s->xb[i];
        }
    }

    /* final norm, in place */
    rmsnorm(x, x, w->rms_final_weight, dim);

    /* classifier -> logits */
    quantize(&s->xq, x, dim);
    matmul(s->logits, &s->xq, w->wcls, dim, p->vocab_size);
    return s->logits;
}
|
|
|
|
|
|
|
|
|
/* One entry of the sorted vocabulary: a token string and its id. */
typedef struct {
    char *str;  /* token text; points into Tokenizer.vocab, not owned */
    int id;     /* index of that token in Tokenizer.vocab */
} TokenIndex;
|
|
|
typedef struct { |
|
char** vocab; |
|
float* vocab_scores; |
|
TokenIndex *sorted_vocab; |
|
int vocab_size; |
|
unsigned int max_token_length; |
|
unsigned char byte_pieces[512]; |
|
} Tokenizer; |
|
|
|
int compare_tokens(const void *a, const void *b) { |
|
return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str); |
|
} |
|
|
|
/*
 * Load the tokenizer from `tokenizer_path`.
 *
 * File layout as read here: an int max_token_length, then for each of the
 * vocab_size tokens a float score, an int byte length and that many bytes of
 * token text (no terminator on disk; one is appended in memory).
 * Any allocation or read failure prints a message and exits.
 */
void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
    t->vocab_size = vocab_size;

    t->vocab = (char**)malloc(vocab_size * sizeof(char*));
    t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
    t->sorted_vocab = NULL; /* built on first encode() */
    if (vocab_size > 0 && (t->vocab == NULL || t->vocab_scores == NULL)) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }

    /* one single-character string per possible byte value */
    for (int i = 0; i < 256; i++) {
        t->byte_pieces[i * 2] = (unsigned char)i;
        t->byte_pieces[i * 2 + 1] = '\0';
    }

    FILE *file = fopen(tokenizer_path, "rb");
    if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
    if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }

    for (int i = 0; i < vocab_size; i++) {
        int len;
        if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
        if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
        /* a negative length would turn into a huge allocation request */
        if (len < 0) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }

        t->vocab[i] = (char *)malloc((size_t)len + 1);
        if (t->vocab[i] == NULL) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); }

        /* fread() with a zero size returns 0, so an empty token must not be
         * treated as a failed read */
        if (len > 0 && fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
        t->vocab[i][len] = '\0';
    }
    fclose(file);
}
|
|
|
/* Release the token strings and the three tables owned by `t`. */
void free_tokenizer(Tokenizer* t) {
    char **entry = t->vocab;
    for (int remaining = t->vocab_size; remaining > 0; remaining--) {
        free(*entry);
        entry++;
    }
    free(t->vocab);
    free(t->vocab_scores);
    free(t->sorted_vocab);
}
|
|
|
/*
 * Return the text of `token`. A token whose text has the form "<0xNN>" is
 * mapped to the single-byte string for byte NN from t->byte_pieces.
 * The result points into tokenizer storage and must not be freed.
 * `prev_token` is not used.
 */
char* decode(Tokenizer* t, int prev_token, int token) {
    unsigned char raw_byte;
    char *text = t->vocab[token];

    if (sscanf(text, "<0x%02hhX>", &raw_byte) != 1) {
        return text;
    }
    return (char*)t->byte_pieces + raw_byte * 2;
}
|
|
|
/*
 * Append `piece` to the NUL-terminated string in `out_buffer`, unless it is
 * NULL, empty, or a single byte that is neither printable nor whitespace.
 * The caller must guarantee that `out_buffer` has room; no bound is checked.
 */
void safe_printf(char *piece, char *out_buffer) {
    if (piece == NULL || piece[0] == '\0') {
        return;
    }

    const int is_single_byte = (piece[1] == '\0');
    if (is_single_byte) {
        unsigned char ch = piece[0];
        if (!isprint(ch) && !isspace(ch)) {
            return; /* drop raw control bytes */
        }
    }

    strcat(out_buffer, piece);
}
|
|
|
/*
 * Binary-search `sorted_vocab` (vocab_size entries, ordered by
 * compare_tokens) for `str`. Returns the token id, or -1 when absent.
 */
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
    TokenIndex key = { .str = str };
    TokenIndex *hit = bsearch(&key, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
    if (hit == NULL) {
        return -1;
    }
    return hit->id;
}
|
|
|
/*
 * Tokenize `text` into `tokens`, storing the count in *n_tokens.
 *
 * Steps: optionally emit token 128000 when `bos` is set; turn every UTF-8
 * code point (or, failing a vocabulary match, each of its bytes as id
 * byte + 3) into a token; then repeatedly merge the best-scoring adjacent
 * pair, or, when no pair merges, the best adjacent triple, until nothing
 * merges; finally emit token 128001 when `eos` is set.
 *
 * `tokens` must have room for one token per input byte plus the optional
 * bos/eos tokens. The sorted vocabulary is built on first use.
 * Exits on NULL text or allocation failure.
 */
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
    if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }

    if (t->sorted_vocab == NULL) {
        /* lazily build the sorted lookup table */
        t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
        if (t->sorted_vocab == NULL && t->vocab_size > 0) {
            fprintf(stderr, "malloc failed!\n");
            exit(EXIT_FAILURE);
        }
        for (int i = 0; i < t->vocab_size; i++) {
            t->sorted_vocab[i].str = t->vocab[i];
            t->sorted_vocab[i].id = i;
        }
        qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
    }

    /* Scratch buffer for candidate strings. The triple merge below joins
     * three tokens, so the buffer must hold 3 * max_token_length bytes; the
     * previous 2 * max_token_length sizing could be overrun. */
    size_t buf_size = (size_t)t->max_token_length * 3 + 1 + 2;
    if (buf_size < 8) { buf_size = 8; } /* room for a 4-byte code point + NUL */
    char* str_buffer = malloc(buf_size);
    if (str_buffer == NULL) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }
    size_t str_len = 0;

    *n_tokens = 0;

    if (bos) tokens[(*n_tokens)++] = 128000;

    /* first pass: one token per UTF-8 code point, or per byte as a fallback */
    for (char *c = text; *c != '\0'; c++) {
        /* a byte that is not a continuation byte (10xxxxxx) starts a new code point */
        if ((*c & 0xC0) != 0x80) {
            str_len = 0;
        }

        str_buffer[str_len++] = *c;
        str_buffer[str_len] = '\0';

        /* keep collecting while the next byte continues this code point (max 4 bytes) */
        if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
            continue;
        }

        int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
        if (id != -1) {
            tokens[(*n_tokens)++] = id;
        } else {
            /* byte fallback: the +3 offset presumably skips three reserved ids.
             * TODO confirm against the tokenizer export */
            for (size_t i = 0; i < str_len; i++) {
                tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
            }
        }
        str_len = 0;
    }

    /* merge passes */
    while (1) {
        float best_score = -1e10;
        int best_id = -1;
        int best_idx = -1;
        int best_len = 2; /* number of tokens replaced by the merge */

        /* try every adjacent pair */
        for (int i = 0; i < (*n_tokens - 1); i++) {
            int written = snprintf(str_buffer, buf_size, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
            if (written < 0 || (size_t)written >= buf_size) { continue; } /* truncated: cannot be a token */
            int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
            if (id != -1 && t->vocab_scores[id] > best_score) {
                best_score = t->vocab_scores[id];
                best_id = id;
                best_idx = i;
            }
        }

        /* no pair merged: try every adjacent triple */
        if (best_idx == -1) {
            for (int i = 0; i < (*n_tokens - 2); i++) {
                int written = snprintf(str_buffer, buf_size, "%s%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]], t->vocab[tokens[i+2]]);
                if (written < 0 || (size_t)written >= buf_size) { continue; }
                int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
                if (id != -1 && t->vocab_scores[id] > best_score) {
                    best_score = t->vocab_scores[id];
                    best_id = id;
                    best_idx = i;
                    best_len = 3;
                }
            }
        }

        if (best_idx == -1) {
            break; /* nothing left to merge */
        }

        /* replace the merged run by the single new token and close the gap */
        tokens[best_idx] = best_id;
        for (int i = best_idx + 1; i < (*n_tokens - best_len + 1); i++) {
            tokens[i] = tokens[i + best_len - 1];
        }
        (*n_tokens) -= (best_len - 1);
    }

    if (eos) tokens[(*n_tokens)++] = 128001;

    free(str_buffer);
}
|
|
|
|
|
|
|
|
|
|
|
/* A probability paired with the index it came from; sorted by sample_topp(). */
typedef struct {
    float prob;  /* probability value */
    int index;   /* position of that value in the probabilities array */
} ProbIndex;
|
|
|
typedef struct { |
|
int vocab_size; |
|
ProbIndex* probindex; |
|
float temperature; |
|
float topp; |
|
unsigned long long rng_state; |
|
} Sampler; |
|
|
|
/* Return the index of the largest of the n values; the first one wins ties. */
int sample_argmax(float* probabilities, int n) {
    int best = 0;
    for (int idx = 1; idx < n; idx++) {
        if (probabilities[idx] > probabilities[best]) {
            best = idx;
        }
    }
    return best;
}
|
|
|
/*
 * Return the first index whose running sum of probabilities exceeds `coin`.
 * If rounding keeps the sum from ever exceeding it, return n - 1.
 */
int sample_mult(float* probabilities, int n, float coin) {
    float running = 0.0f;
    int idx = 0;
    while (idx < n) {
        running += probabilities[idx];
        if (coin < running) {
            return idx;
        }
        idx++;
    }
    return n - 1;
}
|
|
|
int compare(const void* a, const void* b) { |
|
ProbIndex* a_ = (ProbIndex*) a; |
|
ProbIndex* b_ = (ProbIndex*) b; |
|
if (a_->prob > b_->prob) return -1; |
|
if (a_->prob < b_->prob) return 1; |
|
return 0; |
|
} |
|
|
|
/*
 * Top-p sampling: draw from the smallest set of highest-probability entries
 * whose cumulative probability exceeds `topp`.
 *
 * `probindex` is caller-provided scratch with room for n entries. `coin` is
 * expected to be a random number in [0, 1); it is rescaled to the retained
 * probability mass. Returns the chosen index into `probabilities`.
 */
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
    /* Entries below (1 - topp) / (n - 1) are dropped up front so that fewer
     * candidates have to be sorted. */
    const float cutoff = (1.0f - topp) / (n - 1);
    int kept = 0;
    for (int i = 0; i < n; i++) {
        if (!(probabilities[i] >= cutoff)) {
            continue;
        }
        probindex[kept].index = i;
        probindex[kept].prob = probabilities[i];
        kept++;
    }
    qsort(probindex, kept, sizeof(ProbIndex), compare);

    /* find where the cumulative probability first exceeds topp */
    float cumulative_prob = 0.0f;
    int last_idx = kept - 1; /* fallback when rounding keeps the sum at or below topp */
    for (int i = 0; i < kept; i++) {
        cumulative_prob += probindex[i].prob;
        if (cumulative_prob > topp) {
            last_idx = i;
            break;
        }
    }

    /* sample within the truncated list */
    const float target = coin * cumulative_prob;
    float running = 0.0f;
    for (int i = 0; i <= last_idx; i++) {
        running += probindex[i].prob;
        if (target < running) {
            return probindex[i].index;
        }
    }
    return probindex[last_idx].index;
}
|
|
|
/*
 * Initialise `sampler` with the given parameters and allocate its top-p
 * scratch buffer (vocab_size ProbIndex entries). Exits if that allocation
 * fails.
 */
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
    sampler->vocab_size = vocab_size;
    sampler->temperature = temperature;
    sampler->topp = topp;
    sampler->rng_state = rng_seed;

    sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
    if (sampler->probindex == NULL && vocab_size > 0) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }
}
|
|
|
/* Release the scratch buffer allocated by build_sampler(). */
void free_sampler(Sampler* sampler) {
    ProbIndex *scratch = sampler->probindex;
    free(scratch);
}
|
|
|
/*
 * Advance the 64-bit generator state with three xor-shifts (12, 25, 27) and
 * return the upper 32 bits of the state multiplied by 0x2545F4914F6CDD1D.
 */
unsigned int random_u32(unsigned long long *state) {
    unsigned long long s = *state;
    s ^= s >> 12;
    s ^= s << 25;
    s ^= s >> 27;
    *state = s;
    return (unsigned int)((s * 0x2545F4914F6CDD1Dull) >> 32);
}
|
/* Return a random float in [0, 1) built from the top 24 bits of random_u32(). */
float random_f32(unsigned long long *state) {
    const unsigned int bits24 = random_u32(state) >> 8;
    return bits24 / 16777216.0f;
}
|
|
|
/*
 * Choose the next token from `logits`.
 *
 * With temperature 0 the argmax is returned and `logits` is left untouched.
 * Otherwise `logits` is divided by the temperature and turned into
 * probabilities in place, and one index is drawn: from the full distribution
 * when topp is <= 0 or >= 1, else by top-p sampling.
 */
int sample(Sampler* sampler, float* logits) {
    const int n = sampler->vocab_size;

    if (sampler->temperature == 0.0f) {
        return sample_argmax(logits, n);
    }

    for (int q = 0; q < n; q++) {
        logits[q] /= sampler->temperature;
    }
    softmax(logits, n);

    const float coin = random_f32(&sampler->rng_state);

    const int topp_disabled = (sampler->topp <= 0 || sampler->topp >= 1);
    if (topp_disabled) {
        return sample_mult(logits, n, coin);
    }
    return sample_topp(logits, n, sampler->topp, sampler->probindex, coin);
}
|
|
|
|
|
|
|
|
|
long time_in_ms() { |
|
|
|
struct timespec time; |
|
clock_gettime(CLOCK_REALTIME, &time); |
|
return time.tv_sec * 1000 + time.tv_nsec / 1000000; |
|
} |
|
|
|
|
|
|
|
|
|
/*
 * Generate up to `steps` tokens, appending the decoded text and a trailing
 * newline to `out_buffer`.
 *
 * The prompt (NULL is treated as empty) is encoded with a leading token
 * 128000 and fed through the model first; after that tokens are sampled.
 * Generation stops early when token 128001 or 128009 is produced after the
 * prompt. A tokens-per-second figure is written to stderr.
 *
 * `out_buffer` must already hold a NUL-terminated string and be large enough
 * for everything appended; no bound is checked here.
 */
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps, char *out_buffer) {
    char *empty_prompt = "";
    if (prompt == NULL) { prompt = empty_prompt; }

    /* encode() emits at most one token per byte; +3 covers '\0', bos and eos */
    int num_prompt_tokens = 0;
    int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int));
    if (prompt_tokens == NULL) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }
    encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
    if (num_prompt_tokens < 1) {
        fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
        exit(EXIT_FAILURE);
    }

    long start = 0;               /* set after the first step so it is excluded from timing */
    int next;                     /* token chosen for the following position */
    int token = prompt_tokens[0]; /* token fed to the model this step */
    int pos = 0;

    while (pos < steps) {
        float* logits = forward(transformer, token, pos);

        if (pos < num_prompt_tokens - 1) {
            /* still consuming the prompt: force the next prompt token */
            next = prompt_tokens[pos + 1];
        } else {
            next = sample(sampler, logits);
        }
        pos++;

        /* stop on an end token once the prompt has been consumed */
        if ((next == 128001 || next == 128009) && pos > num_prompt_tokens) break;

        char* piece = decode(tokenizer, token, next);
        safe_printf(piece, out_buffer);
        fflush(stdout);
        token = next;

        if (start == 0) { start = time_in_ms(); }
    }
    strcat(out_buffer, "\n");

    /* report throughput; skipped when no time elapsed to avoid dividing by zero */
    if (pos > 1) {
        long end = time_in_ms();
        if (end > start) {
            fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
        }
    }

    free(prompt_tokens);
}
|
|
|
/*
 * Append the prompt text `guide` to `out_buffer`, then read one line from
 * stdin into `buffer` (at most bufsize - 1 characters) and strip a trailing
 * newline.
 *
 * On end-of-file or a read error `buffer` is set to the empty string, so
 * callers never see uninitialised data.
 */
void read_stdin(const char* guide, char* buffer, size_t bufsize, char *out_buffer) {
    strcat(out_buffer, guide);

    if (bufsize == 0) {
        return; /* nowhere to store even a terminator */
    }
    if (fgets(buffer, (int)bufsize, stdin) == NULL) {
        buffer[0] = '\0';
        return;
    }

    size_t len = strlen(buffer);
    if (len > 0 && buffer[len - 1] == '\n') {
        buffer[len - 1] = '\0';
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
 * Run a multi-turn chat for at most `steps` positions, appending the
 * assistant's text to `out_buffer`.
 *
 * On the first turn the system prompt and user prompt come from the cli_*
 * arguments when they are non-NULL, otherwise from stdin; later user turns
 * always come from stdin, and the input "exit" ends the chat. Each turn is
 * wrapped in fixed special-token ids (128000, 128006, 128007, 128009, 9125,
 * 882, 78191, 271). These are presumably the model's chat-template tokens;
 * TODO confirm against the tokenizer.
 *
 * `out_buffer` must already hold a NUL-terminated string and be large enough
 * for everything appended; no bound is checked here.
 */
void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
          char *cli_user_prompt, char *cli_system_prompt, int steps, char *out_buffer) {
    /* PROMPT_CHARS bounds each prompt text, and therefore its token count.
     * A turn can hold the system tokens, the user tokens and 14 fixed ids,
     * so the combined buffer is sized for both prompts. */
    enum { PROMPT_CHARS = 32768, TURN_TOKENS = 2 * PROMPT_CHARS + 32 };

    char* system_prompt = (char*)malloc(PROMPT_CHARS * sizeof(char));
    char* user_prompt = (char*)malloc(PROMPT_CHARS * sizeof(char));
    int* prompt_tokens = (int*)malloc(TURN_TOKENS * sizeof(int));
    int* system_prompt_tokens = (int*)malloc(PROMPT_CHARS * sizeof(int));
    int* user_prompt_tokens = (int*)malloc(PROMPT_CHARS * sizeof(int));
    if (!system_prompt || !user_prompt || !prompt_tokens
        || !system_prompt_tokens || !user_prompt_tokens) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }

    int num_prompt_tokens = 0;
    int user_idx = 0;     /* next prompt token to feed */
    int8_t user_turn = 1; /* non-zero while it is the user's turn */
    int next = 0;         /* token sampled by the previous step */
    int token = 0;        /* token fed to the model this step */
    int pos = 0;

    while (pos < steps) {

        if (user_turn) {
            if (pos == 0) {
                /* very first turn: header plus system prompt */
                prompt_tokens[num_prompt_tokens++] = 128000;
                prompt_tokens[num_prompt_tokens++] = 128006;
                prompt_tokens[num_prompt_tokens++] = 9125;
                prompt_tokens[num_prompt_tokens++] = 128007;
                prompt_tokens[num_prompt_tokens++] = 271;
                if (cli_system_prompt == NULL) {
                    system_prompt[0] = '\0'; /* stays empty if stdin yields nothing */
                    read_stdin("Enter system prompt (optional): ", system_prompt, PROMPT_CHARS, out_buffer);
                } else {
                    /* bounded copy; an over-long prompt is truncated */
                    snprintf(system_prompt, PROMPT_CHARS, "%s", cli_system_prompt);
                }
                int num_system_prompt_tokens = 0;
                encode(tokenizer, system_prompt, 0, 0, system_prompt_tokens, &num_system_prompt_tokens);
                for (int i = 0; i < num_system_prompt_tokens; i++) {
                    prompt_tokens[num_prompt_tokens++] = system_prompt_tokens[i];
                }
                prompt_tokens[num_prompt_tokens++] = 128009;
            } else {
                num_prompt_tokens = 0;
            }

            /* user header */
            prompt_tokens[num_prompt_tokens++] = 128006;
            prompt_tokens[num_prompt_tokens++] = 882;
            prompt_tokens[num_prompt_tokens++] = 128007;
            prompt_tokens[num_prompt_tokens++] = 271;

            if (pos == 0 && cli_user_prompt != NULL) {
                snprintf(user_prompt, PROMPT_CHARS, "%s", cli_user_prompt);
            } else {
                user_prompt[0] = '\0';
                read_stdin("User (or exit): ", user_prompt, PROMPT_CHARS, out_buffer);
                if (strcmp(user_prompt, "exit") == 0) break;
            }

            int num_user_prompt_tokens = 0;
            encode(tokenizer, user_prompt, 0, 0, user_prompt_tokens, &num_user_prompt_tokens);
            for (int i = 0; i < num_user_prompt_tokens; i++) {
                prompt_tokens[num_prompt_tokens++] = user_prompt_tokens[i];
            }

            /* end of user turn, then the assistant header */
            prompt_tokens[num_prompt_tokens++] = 128009;
            prompt_tokens[num_prompt_tokens++] = 128006;
            prompt_tokens[num_prompt_tokens++] = 78191;
            prompt_tokens[num_prompt_tokens++] = 128007;
            prompt_tokens[num_prompt_tokens++] = 271;

            user_idx = 0;
            user_turn = 0;
            strcat(out_buffer, "Assistant: ");
        }

        /* feed prompt tokens first, then whatever was sampled last */
        if (user_idx < num_prompt_tokens) {
            token = prompt_tokens[user_idx++];
        } else {
            token = next;
        }

        /* an end token after the prompt hands the turn back to the user */
        if (user_idx >= num_prompt_tokens && (token == 128009 || token == 128001)) { user_turn = 1; }

        float* logits = forward(transformer, token, pos);
        next = sample(sampler, logits);
        pos++;

        if (user_idx >= num_prompt_tokens && next != 128009 && next != 128001 && next != 128006) {
            char* piece = decode(tokenizer, token, next);
            safe_printf(piece, out_buffer);
            fflush(stdout);
        }
        /* End of the assistant's reply: terminate the line in out_buffer.
         * Previously "a && b || c" fired for token 128001 during prompt
         * processing too, and printed to stdout instead of the buffer. */
        if (user_idx >= num_prompt_tokens && (next == 128009 || next == 128001)) {
            strcat(out_buffer, "\n");
        }
    }
    strcat(out_buffer, "\n");

    free(prompt_tokens);
    free(system_prompt_tokens);
    free(user_prompt_tokens);
    free(system_prompt);
    free(user_prompt);
}
|
|
|
typedef struct { |
|
char *checkpoint_path; |
|
char *tokenizer_path; |
|
float temperature; |
|
float topp; |
|
int steps; |
|
char *prompt; |
|
unsigned long long rng_seed; |
|
char *mode; |
|
char *system_prompt; |
|
char out_buffer[32768]; |
|
Transformer transformer; |
|
Tokenizer tokenizer; |
|
Sampler sampler; |
|
} Main; |
|
|
|
#define DEFAULT_CHECKPOINT_PATH "model.bin" |
|
#define DEFAULT_TOKENIZER_PATH "tokenizer.bin" |
|
#define DEFAULT_MAIN_MODE "generate" |
|
|
|
__declspec(dllexport) Main *build_main(char* checkpoint_path, char* tokenizer_path, float temperature, float topp, int steps, |
|
char* prompt, unsigned long long rng_seed, char* mode, char* system_prompt) { |
|
|
|
Main *ret = (Main *)calloc(1, sizeof(Main)); |
|
if (!ret) return ret; |
|
ret->checkpoint_path = checkpoint_path ? checkpoint_path : DEFAULT_CHECKPOINT_PATH; |
|
ret->tokenizer_path = tokenizer_path ? tokenizer_path : DEFAULT_TOKENIZER_PATH; |
|
ret->temperature = (temperature < 0.0) ? 0.0f : (temperature ? temperature : 1.0f); |
|
ret->topp = topp ? topp : 0.9f; |
|
ret->steps = (steps < 0) ? 0 : steps; |
|
ret->prompt = prompt ? system_prompt : NULL; |
|
ret->rng_seed = (rng_seed <= 0) ? (unsigned int)time(NULL) : rng_seed; |
|
ret->mode = mode ? mode : DEFAULT_MAIN_MODE; |
|
ret->system_prompt = system_prompt ? system_prompt : NULL; |
|
|
|
build_transformer(&ret->transformer, ret->checkpoint_path); |
|
ret->steps = (steps == 0 || steps > ret->transformer.config.seq_len) ? ret->transformer.config.seq_len : steps; |
|
|
|
build_tokenizer(&ret->tokenizer, ret->tokenizer_path, ret->transformer.config.vocab_size); |
|
|
|
build_sampler(&ret->sampler, ret->transformer.config.vocab_size, ret->temperature, ret->topp, ret->rng_seed); |
|
return ret; |
|
} |
|
|
|
__declspec(dllexport) void free_main(Main *m) { |
|
|
|
free_sampler(&m->sampler); |
|
free_tokenizer(&m->tokenizer); |
|
free_transformer(&m->transformer); |
|
free(m); |
|
} |
|
|
|
__declspec(dllexport) char *run_main(Main *m) { |
|
|
|
if (strcmp(m->mode, "generate") == 0) { |
|
generate(&m->transformer, &m->tokenizer, &m->sampler, m->prompt, m->steps, m->out_buffer); |
|
} else if (strcmp(m->mode, "chat") == 0) { |
|
chat(&m->transformer, &m->tokenizer, &m->sampler, m->prompt, m->system_prompt, m->steps, m->out_buffer); |
|
} else { |
|
fprintf(stderr, "unknown mode: %s\n", m->mode); |
|
} |
|
return m->out_buffer; |
|
} |
|
|