import os

import requests
import torch
import transformers

# Sequence length and training hyper-parameters
MAX_LEN = 150  # 256
TRAIN_BATCH_SIZE = 8
VALID_BATCH_SIZE = 4
EPOCHS = 5
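# Sketch (assumption, not part of the original script): these constants are
# typically consumed by PyTorch DataLoaders and the training loop, e.g.:
#   train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
#   valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=VALID_BATCH_SIZE)
#   for epoch in range(EPOCHS):
#       ...  # one pass over train_loader per epoch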
# Alternative: fetch the weights with huggingface_hub instead of requests
# from huggingface_hub import hf_hub_download
# hf_hub_download(repo_id="thak123/bert-emoji-latvian-twitter-classifier", filename="model.bin", local_dir="./")
# from huggingface_hub import snapshot_download
# snapshot_download(repo_id="thak123/bert-emoji-latvian-twitter-classifier", allow_patterns="*.bin")

# Download the fine-tuned classifier weights from the Hugging Face Hub
URL = "https://huggingface.co/thak123/bert-emoji-latvian-twitter-classifier/resolve/main/model.bin"
response = requests.get(URL)
response.raise_for_status()
with open("model.bin", "wb") as f:
    f.write(response.content)

# Folder that contains all the datasets
DATASET_LOCATION = ""
MODEL_PATH = "model.bin"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
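# Sketch (assumption, not from the original script): the downloaded weights in
# MODEL_PATH would later be restored onto `device` with something like:
#   state_dict = torch.load(MODEL_PATH, map_location=device)
#   model.load_state_dict(state_dict)  # assuming a `model` object is defined later in the script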
# 7 EPOCH version
BERT_PATH = "FFZG-cleopatra/bert-emoji-latvian-twitter"

# TODO: check whether lower-casing is required
TOKENIZER = transformers.BertTokenizer.from_pretrained(
    BERT_PATH,
    do_lower_case=True
)
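# Illustrative sketch (not in the original script): encoding a sample Latvian
# tweet with the tokenizer, truncated/padded to MAX_LEN:
#   sample = "Šodien ir lieliska diena 😊"
#   encoded = TOKENIZER.encode_plus(
#       sample,
#       add_special_tokens=True,
#       max_length=MAX_LEN,
#       padding="max_length",
#       truncation=True,
#   )
#   # encoded["input_ids"], encoded["attention_mask"], encoded["token_type_ids"]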
####################################################################################################################################