Transformer-DeID: Deidentification of free-text clinical notes with transformers 1.0.0

File: <base>/eval_each_epoch.py (5,813 bytes)
import argparse
import math

from datetime import datetime
import logging
from pathlib import Path
import os
from tqdm import tqdm

import numpy as np

from transformers import AutoModelForTokenClassification
from transformers import Trainer, TrainingArguments
from datasets import load_metric

# local packages
from transformer_deid.evaluation import compute_metrics

from transformer_deid.train import which_transformer_arch
from transformer_deid.model_evaluation_functions import load_data

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO
)
logger = logging.getLogger(__name__)

# Column names for the per-entity (multi-class) evaluation results,
# matching the keys produced by flatten_dict() on compute_metrics()
# output: <ENTITY><metric> per PHI category plus overall scores.
multi_class_fields = [
    'AGEprecision', 'AGErecall', 'AGEf1', 'AGEnumber', 'CONTACTprecision',
    'CONTACTrecall', 'CONTACTf1', 'CONTACTnumber', 'DATEprecision',
    'DATErecall', 'DATEf1', 'DATEnumber', 'IDprecision', 'IDrecall', 'IDf1',
    'IDnumber', 'LOCATIONprecision', 'LOCATIONrecall', 'LOCATIONf1',
    'LOCATIONnumber', 'NAMEprecision', 'NAMErecall', 'NAMEf1', 'NAMEnumber',
    'PROFESSIONprecision', 'PROFESSIONrecall', 'PROFESSIONf1',
    'PROFESSIONnumber', 'overall_precision', 'overall_recall', 'overall_f1',
    'overall_accuracy'
]
# Column names for the binary (PHI vs. non-PHI) evaluation results.
binary_fields = [
    'PHIprecision', 'PHIrecall', 'PHIf1', 'PHInumber', 'overall_precision',
    'overall_recall', 'overall_f1', 'overall_accuracy'
]


def flatten_dict(d):
    """Flatten a nested evaluation-result dict into a single level.

    Nested keys are concatenated parent + child, matching the field
    names in ``multi_class_fields`` / ``binary_fields``, e.g.
    ``{'AGE': {'precision': 0.9}}`` -> ``{'AGEprecision': 0.9}``.
    numpy integer values (e.g. support counts) are converted to plain
    ``int`` so they serialize cleanly to CSV.

    Args:
        d: possibly nested dict of metric values.

    Returns:
        A flat dict with no dict values.
    """
    out = {}
    for key, value in d.items():
        if isinstance(value, dict):
            # Recurse, prefixing each child key with the parent key.
            for child_key, child_val in flatten_dict(value).items():
                out[key + child_key] = child_val
        else:
            out[key] = value
    # np.integer covers int64/int32/...; applied at every level, not
    # just to nested values as before.
    return {
        k: int(v) if isinstance(v, np.integer) else v
        for k, v in out.items()
    }


def add_row(
    path, epochs, results_multiclass, results_binary, multi_class_fields,
    binary_fields, test_loss
):
    """Append one row of evaluation metrics to ``training_eval.csv``.

    The CSV lives in the parent directory of the checkpoint ``path``.
    Row layout: [epochs] + multi_class_fields + binary_fields +
    [test_loss]; missing fields are written as ``None``.

    Args:
        path: checkpoint directory; its parent holds the CSV.
        epochs: (possibly fractional) epoch number of this checkpoint.
        results_multiclass: nested metric dict from multi-class eval.
        results_binary: nested metric dict from binary (PHI) eval.
        multi_class_fields: flattened field names to pull from
            ``results_multiclass``, in column order.
        binary_fields: flattened field names for ``results_binary``.
        test_loss: loss on the test set.
    """
    root = Path(path).parent

    # Flatten each result dict once, not once per field as before.
    flat_multi = flatten_dict(results_multiclass)
    flat_binary = flatten_dict(results_binary)

    row = (
        [epochs]
        + [flat_multi.get(field) for field in multi_class_fields]
        + [flat_binary.get(field) for field in binary_fields]
        + [test_loss]
    )

    text_metrics = ','.join(map(str, row)) + '\n'

    with open(root / 'training_eval.csv', 'at') as f:
        f.write(text_metrics)


def eval_checkpoints(
    path, deid_task, train_dataset, val_dataset, test_dataset, training_args
):
    """Evaluate one saved checkpoint on the test set and log its metrics.

    Loads the model from ``path`` (a ``checkpoint-<step>`` directory),
    predicts over ``test_dataset``, computes both multi-class and
    binary (PHI vs. non-PHI) metrics, and appends them as one CSV row
    via ``add_row``.
    """
    # Recover the (possibly fractional) epoch from the checkpoint step.
    global_step = int(path.split('-')[-1])
    batches_per_epoch = math.ceil(
        len(train_dataset) / training_args.per_device_train_batch_size
    )
    epoch = global_step / batches_per_epoch

    model = AutoModelForTokenClassification.from_pretrained(
        path, num_labels=len(deid_task.labels)
    )
    model.eval()

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset
    )

    logits, true_labels, metrics = trainer.predict(test_dataset)
    pred_labels = np.argmax(logits, axis=2)

    # Custom seqeval-style metric shipped with the package.
    metric = load_metric("transformer_deid/token_evaluation.py")

    results_multiclass = compute_metrics(
        pred_labels, true_labels, deid_task.labels, metric=metric
    )
    results_binary = compute_metrics(
        pred_labels,
        true_labels,
        deid_task.labels,
        metric=metric,
        binary_evaluation=True
    )

    add_row(
        path, epoch, results_multiclass, results_binary, multi_class_fields,
        binary_fields, metrics['test_loss']
    )


def parse_args():
    """Parse command-line options for checkpoint evaluation.

    Returns:
        argparse.Namespace with ``task_name`` (data folder, default
        ``i2b2_2014``) and ``model`` (checkpoint folder, default
        ``bert``) attributes.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate transformer-based model at each checkpoint.'
    )

    parser.add_argument(
        '-n',
        '--task_name',
        type=str,
        default='i2b2_2014',
        help=
        'name of folder containing train and test data; defaults to i2b2_2014'
    )

    parser.add_argument(
        '-m',
        '--model',
        type=str,
        default='bert',
        help='folder containing checkpoint files'
    )

    return parser.parse_args()


def main():
    """Evaluate every saved checkpoint under the model directory.

    Expects ``--model`` to name a directory ``<arch>results<epochs>``
    (e.g. ``bertresults10``) containing ``checkpoint-<step>``
    subdirectories, and ``--task_name`` a folder holding the train and
    test data.  Writes one metrics row per checkpoint (in step order)
    to ``<model>/training_eval.csv``.
    """
    args = parse_args()
    root = args.model
    # Directory name encodes architecture and epoch count,
    # e.g. 'bertresults10' -> arch='bert', epochs=10.
    arch = args.model.split('results')[0].lower()
    epochs = int(args.model.split('results')[1])
    task_name = args.task_name

    _, tokenizer, _ = which_transformer_arch(arch)

    dataDir = task_name
    testDir = f'{task_name}/test'

    deid_task, train_dataset, val_dataset, test_dataset = load_data(
        task_name, dataDir, testDir, tokenizer
    )

    train_batch_size = 8

    # Only per_device_train_batch_size matters here (used to recover the
    # epoch from the step count); no training is performed.
    training_args = TrainingArguments(
        output_dir=root,
        num_train_epochs=epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=8,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        save_strategy='steps',
        eval_steps=1155
    )

    # Write the CSV header once; rows are appended later by add_row().
    csv_path = os.path.join(root, 'training_eval.csv')
    if not os.path.exists(csv_path):
        header = 'epoch,' + ','.join(
            multi_class_fields + binary_fields + ['test_loss']
        ) + '\n'
        with open(csv_path, 'wt') as f:
            f.write(header)

    checkpoints = [
        item for item in os.listdir(root)
        if 'checkpoint' in item and os.path.isdir(os.path.join(root, item))
    ]

    # Evaluate in ascending step order so CSV rows are chronological.
    for item in tqdm(sorted(checkpoints, key=lambda x: int(x.split('-')[1]))):
        eval_checkpoints(
            os.path.join(root, item), deid_task, train_dataset, val_dataset,
            test_dataset, training_args
        )


# Script entry point: parse CLI args and evaluate all checkpoints.
if __name__ == '__main__':
    main()