from textgenrnn import textgenrnn
from datetime import datetime
import os
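# model architecture settings, passed to textgenrnn when building a new model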
model_cfg = {
    'word_level': False,  # set to True to train a word-level model (requires more data and a smaller max_length)
    'rnn_size': 256,  # number of LSTM cells in each layer (128/256 recommended)
    'rnn_layers': 3,  # number of LSTM layers (>=2 recommended)
    'rnn_bidirectional': True,  # consider text both forwards and backwards; can give a training boost
    'max_length': 30,  # number of tokens to consider before predicting the next (20-40 recommended for characters, 5-10 for words)
    'max_words': 10000,  # maximum number of words to model; the rest will be ignored (word-level model only)
}
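# training settings: data handling, epoch count, and how often to generate sample text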
train_cfg = {
    'line_delimited': True,  # set to True if each text has its own line in the source file
    'num_epochs': 20,  # set higher to train the model for longer
    'gen_epochs': 5,  # generate sample text from the model after this many epochs
    'train_size': 0.8,  # proportion of input data to train on; setting it below 1.0 keeps the model from learning the input perfectly
    'dropout': 0.0,  # ignore a random proportion of source tokens each epoch, allowing the model to generalize better
    'validation': False,  # if train_size < 1.0, test on the holdout dataset; makes overall training slower
    'is_csv': False  # set to True if the file is a CSV exported from Excel/BigQuery/pandas
}
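# source file: with line_delimited=True, each line of classes.txt is treated as a separate training text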
file_name = "classes.txt"
model_name = '_classes' # change to set file name of resulting trained models/texts
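# create a new textgenrnn instance; model_name is used as the prefix for the saved model/text files noted above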
textgen = textgenrnn(name=model_name)
train_function = textgen.train_from_file if train_cfg['line_delimited'] else textgen.train_from_largetext_file
train_function(
    file_path=file_name,
    new_model=True,
    num_epochs=train_cfg['num_epochs'],
    gen_epochs=train_cfg['gen_epochs'],
    batch_size=1024,
    train_size=train_cfg['train_size'],
    dropout=train_cfg['dropout'],
    validation=train_cfg['validation'],
    is_csv=train_cfg['is_csv'],
    rnn_layers=model_cfg['rnn_layers'],
    rnn_size=model_cfg['rnn_size'],
    rnn_bidirectional=model_cfg['rnn_bidirectional'],
    max_length=model_cfg['max_length'],
    dim_embeddings=100,
    word_level=model_cfg['word_level'])
# this temperature schedule cycles between 1 very unexpected token, 1 unexpected token, 2 expected tokens, repeat.
# changing the temperature schedule can result in wildly different output!
temperature = [1.0, 0.5, 0.2, 0.2]
prefix = None # if you want each generated text to start with a given seed text
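# how many texts to generate and how long each may be depends on line-delimited mode and word- vs. character-level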
if train_cfg['line_delimited']:
    n = 100
    max_gen_length = 60 if model_cfg['word_level'] else 300
else:
    n = 1
    max_gen_length = 2000 if model_cfg['word_level'] else 10000
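# timestamp the output file name so repeated runs do not overwrite earlier generations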
timestring = datetime.now().strftime('%Y%m%d_%H%M%S')
gen_file = '{}_gentext_{}.txt'.format(model_name, timestring)
textgen.generate_to_file(gen_file,
                         temperature=temperature,
                         prefix=prefix,
                         n=n,
                         max_gen_length=max_gen_length)
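# write the name of the generated-text file (not its contents) to class_out.txt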
with open("class_out.txt", "w") as file1:
    file1.write(gen_file)