from textgenrnn import textgenrnn
from datetime import datetime

model_cfg = {
    'word_level': False,        # set to True to train a word-level model (requires more data and a smaller max_length)
    'rnn_size': 256,            # number of LSTM cells in each layer (128/256 recommended)
    'rnn_layers': 3,            # number of LSTM layers (>=2 recommended)
    'rnn_bidirectional': True,  # consider text both forwards and backwards; can give a training boost
    'max_length': 30,           # number of tokens to consider before predicting the next (20-40 for characters, 5-10 for words recommended)
    'max_words': 10000,         # maximum number of words to model; the rest will be ignored (word-level model only)
}

train_cfg = {
    'line_delimited': True,  # set to True if each text has its own line in the source file
    'num_epochs': 20,        # set higher to train the model for longer
    'gen_epochs': 5,         # generate sample text from the model after this many epochs
    'train_size': 0.8,       # proportion of input data to train on; setting < 1.0 keeps the model from learning the input perfectly
    'dropout': 0.0,          # ignore a random proportion of source tokens each epoch, allowing the model to generalize better
    'validation': False,     # if train_size < 1.0, test on the holdout dataset; makes overall training slower
    'is_csv': False          # set to True if the file is a CSV exported from Excel/BigQuery/pandas
}

file_name = "classes.txt"
model_name = '_classes'  # change to set the file name of the resulting trained models/texts

textgen = textgenrnn(name=model_name)

# line-delimited files use train_from_file; a single large text uses train_from_largetext_file
train_function = textgen.train_from_file if train_cfg['line_delimited'] else textgen.train_from_largetext_file

train_function(
    file_path=file_name,
    new_model=True,
    num_epochs=train_cfg['num_epochs'],
    gen_epochs=train_cfg['gen_epochs'],
    batch_size=1024,
    train_size=train_cfg['train_size'],
    dropout=train_cfg['dropout'],
    validation=train_cfg['validation'],
    is_csv=train_cfg['is_csv'],
    rnn_layers=model_cfg['rnn_layers'],
    rnn_size=model_cfg['rnn_size'],
    rnn_bidirectional=model_cfg['rnn_bidirectional'],
    max_length=model_cfg['max_length'],
    dim_embeddings=100,
    word_level=model_cfg['word_level'])

# This temperature schedule cycles between 1 very unexpected token, 1 unexpected token,
# and 2 expected tokens, then repeats. Changing the schedule can result in wildly different output!
temperature = [1.0, 0.5, 0.2, 0.2]
prefix = None  # set to a string if you want each generated text to start with a given seed text

if train_cfg['line_delimited']:
    n = 100
    max_gen_length = 60 if model_cfg['word_level'] else 300
else:
    n = 1
    max_gen_length = 2000 if model_cfg['word_level'] else 10000

timestring = datetime.now().strftime('%Y%m%d_%H%M%S')
gen_file = '{}_gentext_{}.txt'.format(model_name, timestring)

textgen.generate_to_file(gen_file,
                         temperature=temperature,
                         prefix=prefix,
                         n=n,
                         max_gen_length=max_gen_length)

# Record the name of the generated-text file (not its contents) so other scripts can find it.
with open("class_out.txt", "w") as out_file:
    out_file.write(gen_file)
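
# Optional: reload the trained model in a later session. This is a sketch that
# assumes textgenrnn's usual behaviour of saving '<name>_weights.hdf5',
# '<name>_vocab.json', and '<name>_config.json' into the working directory when
# a name is passed at construction; check your directory for the actual file
# names before relying on them.
#
# from textgenrnn import textgenrnn
# textgen_reloaded = textgenrnn(
#     weights_path='{}_weights.hdf5'.format(model_name),
#     vocab_path='{}_vocab.json'.format(model_name),
#     config_path='{}_config.json'.format(model_name),
#     name=model_name)
# textgen_reloaded.generate(5, temperature=0.5)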