You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
38 lines
1.8 KiB
38 lines
1.8 KiB
|
3 years ago
|
"""Train a small GPT-2 model on D&D-style name lists with aitextgen.

Trains a custom BPE tokenizer, builds line-by-line token datasets from
three text files (character classes, spells, monsters), merges them,
trains a mini GPT-2 variant, and generates sample text.
"""

from aitextgen.TokenDataset import TokenDataset, merge_datasets
from aitextgen.tokenizers import train_tokenizer
from aitextgen.utils import GPT2ConfigCPU
from aitextgen import aitextgen

if __name__ == '__main__':
    # Training corpora: one entry per line (hence line_by_line=True below).
    file_name = "classes.txt"
    spell_file_name = "spelllist.txt"
    monster_file_name = "monsters.txt"

    # Train a custom BPE tokenizer on the class list.
    # This saves one file, `aitextgen.tokenizer.json`, which contains the
    # information needed to rebuild the tokenizer.
    # NOTE(review): the tokenizer only ever sees classes.txt, yet the
    # training data below also includes spells and monsters — consider
    # training the tokenizer on all three files so their vocabulary is
    # covered. TODO: confirm against aitextgen's train_tokenizer API.
    train_tokenizer(file_name)
    tokenizer_file = "aitextgen.tokenizer.json"

    # GPT2ConfigCPU is a mini variant of GPT-2 sized for cheap training —
    # e.g. 64 input tokens vs. 1024 for base GPT-2. Despite the "CPU" name,
    # the small model can also be trained on a GPU, which is what
    # to_gpu=True requests below.
    config = GPT2ConfigCPU()

    # Instantiate aitextgen using the created tokenizer and config.
    ai = aitextgen(tokenizer_file=tokenizer_file, config=config, to_gpu=True)

    # Build one TokenDataset per corpus. block_size=64 matches the mini
    # config's context length; line_by_line=True treats each line as an
    # independent training example.
    class_data = TokenDataset(file_name, tokenizer_file=tokenizer_file, block_size=64, line_by_line=True)
    spell_data = TokenDataset(spell_file_name, tokenizer_file=tokenizer_file, block_size=64, line_by_line=True)
    monster_data = TokenDataset(monster_file_name, tokenizer_file=tokenizer_file, block_size=64, line_by_line=True)

    # Combine the three corpora into a single training dataset.
    data = merge_datasets([class_data, spell_data, monster_data])

    # Train the model. pytorch_model.bin is saved every `save_every` steps
    # and after completion to the `trained_model` folder; a sample is
    # generated every `generate_every` steps as a progress check.
    ai.train(data, batch_size=8, num_steps=50000, generate_every=5000, save_every=5000)

    # Generate 10 text samples from the trained model.
    ai.generate(10)