---
license: apache-2.0
tags:
  - generated_from_trainer
metrics:
  - accuracy
model-index:
  - name: output_v3
    results: []
widget:
  - text: <|endoftext|>MAADGYLPDWLEDNLSEGIREWWALKPGAPQPKANQQHQDNARGLVLPGYKYLGPGNGL
---

output_v3

This model is a fine-tuned version of avuhong/ParvoGPT2, trained on capsid protein sequences from members of the order Piccovirales. It achieves the following results on the evaluation set:

  • Loss: 0.4775
  • Accuracy: 0.9290
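The evaluation loss is, presumably, the mean per-token cross-entropy (in nats) computed by the Hugging Face `Trainer`, so it converts to a per-token perplexity by exponentiation, the same convention the Transformers language-modeling examples use. A minimal sketch under that assumption:

```python
import math

# Final evaluation loss reported above (assumed: mean token-level cross-entropy, natural log)
eval_loss = 0.4775

# Perplexity is the exponential of the mean cross-entropy
eval_perplexity = math.exp(eval_loss)
print(f"{eval_perplexity:.3f}")  # prints 1.612
```

This is a perplexity over tokens, so it is only directly comparable to values computed with the same tokenizer and input formatting.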

Model description

This model is a GPT-2-like model for generating capsid amino acid sequences. It was trained exclusively on capsid amino acid sequences from members of the order Piccovirales.

Intended uses & limitations

Like other GPT-style models, it can be used either to generate new sequences or to evaluate the perplexity of existing sequences.

Generate novel sequences for viral capsid proteins

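```python
from transformers import pipeline
# Load the fine-tuned model as a text-generation pipeline
protgpt2 = pipeline('text-generation', model="avuhong/PiccoviralesGPT")
# Sample 10 sequences from the start token; max_length counts tokens, and eos_token_id=0 ends generation at token id 0
sequences = protgpt2("<|endoftext|>", max_length=750, do_sample=True, top_k=950, repetition_penalty=1.2, num_return_sequences=10, eos_token_id=0)
```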
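Each element returned by a `text-generation` pipeline is a dict with a `generated_text` field, and that text still contains the `<|endoftext|>` marker(s) and any line breaks the model emits. A minimal post-processing sketch; the helper name `clean_generated` and the choice to reject anything outside the 20 standard amino acids are our own assumptions, not part of this repository:

```python
# Standard 20 amino-acid one-letter codes (assumption: non-standard residues are rejected)
AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")

def clean_generated(generated_text):
    """Hypothetical helper: turn raw pipeline output into a plain amino-acid string."""
    # Split on the special token and keep the first non-empty segment
    segments = [s for s in generated_text.split("<|endoftext|>") if s.strip()]
    if not segments:
        return None
    # Remove the line breaks and any other whitespace used to wrap the sequence
    seq = "".join(segments[0].split())
    # Keep the sequence only if it uses standard residues exclusively
    return seq if set(seq) <= AMINO_ACIDS else None

# Usage with the pipeline output above:
# proteins = [clean_generated(s["generated_text"]) for s in sequences]
# proteins = [p for p in proteins if p is not None]
```

Sequences that hit `max_length` before emitting the end token may be truncated; this sketch does not detect that case.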

Calculate the perplexity of a protein sequence

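```python
import math
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer, on GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("avuhong/PiccoviralesGPT")
model = AutoModelForCausalLM.from_pretrained("avuhong/PiccoviralesGPT").to(device).eval()

def calculatePerplexity(sequence, model, tokenizer):
    # Tokenize the formatted sequence and add a batch dimension
    input_ids = torch.tensor(tokenizer.encode(sequence)).unsqueeze(0).to(device)
    with torch.no_grad():
        # labels=input_ids makes the model return the mean next-token cross-entropy
        outputs = model(input_ids, labels=input_ids)
    return math.exp(outputs.loss.item())  # perplexity = exp(mean loss)
```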

```python
def split_sequence(sequence, line_length=60):
    # Wrap the residues into lines of 60 characters (FASTA-style), keeping every residue
    lines = [sequence[i:i + line_length] for i in range(0, len(sequence), line_length)]
    # Mark the start and end of the sequence with the <|endoftext|> token
    return '<|endoftext|>' + '\n'.join(lines) + '<|endoftext|>'
```

```python
# Format the example capsid sequence and print its perplexity (lower means more likely under the model)
seq = "MAADGYLPDWLEDNLSEGIREWWALKPGAPQPKANQQHQDNARGLVLPGYKYLGPGNGLDKGEPVNAADAAALEHDKAYDQQLKAGDNPYLKYNHADAEFQERLKEDTSFGGNLGRAVFQAKKRLLEPLGLVEEAAKTAPGKKRPVEQSPQEPDSSAGIGKSGAQPAKKRLNFGQTGDTESVPDPQPIGEPPAAPSGVGSLTMASGGGAPVADNNEGADGVGSSSGNWHCDSQWLGDRVITTSTRTWALPTYNNHLYKQISNSTSGGSSNDNAYFGYSTPWGYFDFNRFHCHFSPRDWQRLINNNWGFRPKRLNFKLFNIQVKEVTDNNGVKTIANNLTSTVQVFTDSDYQLPYVLGSAHEGCLPPFPADVFMIPQYGYLTLNDGSQAVGRSSFYCLEYFPSQMLRTGNNFQFSYEFENVPFHSSYAHSQSLDRLMNPLIDQYLYYLSKTINGSGQNQQTLKFSVAGPSNMAVQGRNYIPGPSYRQQRVSTTVTQNNNSEFAWPGASSWALNGRNSLMNPGPAMASHKEGEDRFFPLSGSLIFGKQGTGRDNVDADKVMITNEEEIKTTNPVATESYGQVATNHQSAQAQAQTGWVQNQGILPGMVWQDRDVYLQGPIWAKIPHTDGNFHPSPLMGGFGMKHPPPQILIKNTPVPADPPTAFNKDKLNSFITQYSTGQVSVEIEWELQKENSKRWNPEIQYTSNYYKSNNVEFAVNTEGVYSEPRPIGTRYLTRNL"
seq = split_sequence(seq)
print(f"{calculatePerplexity(seq, model, tokenizer):.2f}")
```
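`calculatePerplexity` relies on the model's built-in loss: when `labels=input_ids` is passed, a Hugging Face GPT-2 causal LM shifts the labels by one position and returns the mean negative log-likelihood of each token given the tokens before it, and perplexity is the exponential of that mean. The pure-Python sketch below spells out the same arithmetic on made-up next-token probabilities (illustrative numbers only, not model outputs; `perplexity_from_probs` is a hypothetical helper):

```python
import math

def perplexity_from_probs(token_probs):
    """Perplexity from the probabilities a model assigned to each observed next token."""
    # Negative log-likelihood of every predicted token
    nlls = [-math.log(p) for p in token_probs]
    # The mean NLL plays the role of the `loss` a causal LM returns when labels=input_ids
    mean_nll = sum(nlls) / len(nlls)
    return math.exp(mean_nll)

# Illustrative probabilities only; the result is approximately 4.0
print(perplexity_from_probs([0.5, 0.25, 0.125]))
```

Equivalently, perplexity is the inverse geometric mean of the per-token probabilities, so a model that always assigned probability 1/4 to the correct token would score 4.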

Training and evaluation data

The training script is included as a bash file in this repository.

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 5e-05
  • train_batch_size: 1
  • eval_batch_size: 1
  • seed: 42
  • distributed_type: multi-GPU
  • num_devices: 2
  • gradient_accumulation_steps: 4
  • total_train_batch_size: 8
  • total_eval_batch_size: 2
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • num_epochs: 32.0
  • mixed_precision_training: Native AMP
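The totals above follow from the per-device settings: each optimizer step accumulates gradients from 4 micro-batches of size 1 on each of 2 GPUs. Combined with the 220 steps per epoch visible in the results table, this gives the total step count, a rough training-set size, and, assuming zero warmup steps (no warmup is listed), the linear learning-rate decay. A back-of-the-envelope sketch; `linear_lr` is a hypothetical helper, not code from the training script:

```python
per_device_train_batch_size = 1
num_devices = 2
gradient_accumulation_steps = 4
num_epochs = 32
steps_per_epoch = 220  # read off the "Step" column at epoch 1.0
base_lr = 5e-05

# Sequences contributing to one optimizer step
total_train_batch_size = per_device_train_batch_size * num_devices * gradient_accumulation_steps
# Optimizer steps over the full run (matches the last row of the results table)
total_steps = steps_per_epoch * num_epochs
# Rough training-set size; partial batches and multi-GPU sampler padding shift it slightly
approx_train_sequences = steps_per_epoch * total_train_batch_size

def linear_lr(step, base_lr=base_lr, total_steps=total_steps):
    """Linear decay from base_lr to 0, assuming zero warmup steps."""
    return base_lr * max(0.0, 1 - step / total_steps)

print(total_train_batch_size, total_steps, approx_train_sequences)
print(linear_lr(0), linear_lr(3520), linear_lr(7040))
```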

Training results

| Training Loss | Epoch | Step | Validation Loss | Accuracy |
|:-------------:|:-----:|:----:|:---------------:|:--------:|
| No log | 1.0 | 220 | 1.1623 | 0.8225 |
| No log | 2.0 | 440 | 0.9566 | 0.8539 |
| 1.1942 | 3.0 | 660 | 0.8456 | 0.8709 |
| 1.1942 | 4.0 | 880 | 0.7719 | 0.8801 |
| 0.7805 | 5.0 | 1100 | 0.7224 | 0.8872 |
| 0.7805 | 6.0 | 1320 | 0.6895 | 0.8928 |
| 0.6257 | 7.0 | 1540 | 0.6574 | 0.8972 |
| 0.6257 | 8.0 | 1760 | 0.6289 | 0.9014 |
| 0.6257 | 9.0 | 1980 | 0.6054 | 0.9045 |
| 0.5385 | 10.0 | 2200 | 0.5881 | 0.9077 |
| 0.5385 | 11.0 | 2420 | 0.5709 | 0.9102 |
| 0.4778 | 12.0 | 2640 | 0.5591 | 0.9121 |
| 0.4778 | 13.0 | 2860 | 0.5497 | 0.9143 |
| 0.427 | 14.0 | 3080 | 0.5385 | 0.9161 |
| 0.427 | 15.0 | 3300 | 0.5258 | 0.9180 |
| 0.394 | 16.0 | 3520 | 0.5170 | 0.9195 |
| 0.394 | 17.0 | 3740 | 0.5157 | 0.9212 |
| 0.394 | 18.0 | 3960 | 0.5038 | 0.9221 |
| 0.363 | 19.0 | 4180 | 0.4977 | 0.9234 |
| 0.363 | 20.0 | 4400 | 0.4976 | 0.9236 |
| 0.3392 | 21.0 | 4620 | 0.4924 | 0.9247 |
| 0.3392 | 22.0 | 4840 | 0.4888 | 0.9255 |
| 0.33 | 23.0 | 5060 | 0.4890 | 0.9262 |
| 0.33 | 24.0 | 5280 | 0.4856 | 0.9268 |
| 0.3058 | 25.0 | 5500 | 0.4803 | 0.9275 |
| 0.3058 | 26.0 | 5720 | 0.4785 | 0.9277 |
| 0.3058 | 27.0 | 5940 | 0.4813 | 0.9281 |
| 0.2973 | 28.0 | 6160 | 0.4799 | 0.9282 |
| 0.2973 | 29.0 | 6380 | 0.4773 | 0.9285 |
| 0.2931 | 30.0 | 6600 | 0.4778 | 0.9286 |
| 0.2931 | 31.0 | 6820 | 0.4756 | 0.9290 |
| 0.2879 | 32.0 | 7040 | 0.4775 | 0.9290 |
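The validation loss flattens out over the last few epochs. As a quick check on the table, the sketch below picks the epoch with the lowest validation loss and converts it to perplexity; the values are copied from the table, and reading `exp(loss)` as perplexity assumes the loss is the mean token-level cross-entropy.

```python
import math

# (epoch, validation loss) pairs copied from the training results table
val_losses = [
    (1, 1.1623), (2, 0.9566), (3, 0.8456), (4, 0.7719), (5, 0.7224),
    (6, 0.6895), (7, 0.6574), (8, 0.6289), (9, 0.6054), (10, 0.5881),
    (11, 0.5709), (12, 0.5591), (13, 0.5497), (14, 0.5385), (15, 0.5258),
    (16, 0.5170), (17, 0.5157), (18, 0.5038), (19, 0.4977), (20, 0.4976),
    (21, 0.4924), (22, 0.4888), (23, 0.4890), (24, 0.4856), (25, 0.4803),
    (26, 0.4785), (27, 0.4813), (28, 0.4799), (29, 0.4773), (30, 0.4778),
    (31, 0.4756), (32, 0.4775),
]

# Epoch with the smallest validation loss, and its per-token perplexity
best_epoch, best_loss = min(val_losses, key=lambda pair: pair[1])
print(best_epoch, best_loss, round(math.exp(best_loss), 3))
```

Epoch 31 has the lowest validation loss (0.4756), slightly below the epoch-32 value of 0.4775 reported at the top of this card.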

Framework versions

  • Transformers 4.26.1
  • Pytorch 1.13.1+cu117
  • Datasets 2.9.0
  • Tokenizers 0.13.2
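To reproduce the training environment, it can help to compare installed packages against the versions listed above. A minimal sketch using the standard-library `importlib.metadata`; the helper name `compare_versions` is illustrative, and `torch` is assumed to be the installed distribution name for PyTorch (its reported version may or may not carry the `+cu117` local suffix):

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed in this model card
LISTED_VERSIONS = {
    "transformers": "4.26.1",
    "torch": "1.13.1+cu117",
    "datasets": "2.9.0",
    "tokenizers": "0.13.2",
}

def compare_versions(expected):
    """Map each package to (listed version, installed version or None if missing)."""
    report = {}
    for package, listed in expected.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None
        report[package] = (listed, installed)
    return report

for package, (listed, installed) in compare_versions(LISTED_VERSIONS).items():
    status = "ok" if installed == listed else "differs"
    print(f"{package}: listed {listed}, installed {installed} ({status})")
```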