Attention-based Deep Multiple Instance Learning

This post was originally published by Jonathan Glaser at Medium [AI]

for prostate cancer diagnosis using PyTorch and AWS SageMaker data parallelism

Introduction

Before diving into code, let’s step back and consider why artificial intelligence is poised to transform healthcare.

The remarkable advancements in AI that we see today are greatly attributed to the success of deep neural networks. This new era would not be possible without a perfect storm of the following four driving forces:

  1. The increasing availability of massive datasets such as ImageNet’s 15 million labeled images, Facebook’s library of billions of images, YouTube’s video library which grows by three hundred hours of video per minute, and Tesla’s collection of driving data which adds 1 million miles of data per hour¹.
  2. The use of graphics processing units (GPUs), and later more AI-specialized hardware called tensor processing units (TPUs), which are optimized for training deep learning models: their large number of cores enables them to process large amounts of data and perform many computations in parallel. A 2018 report by OpenAI² found that prior to 2012, AI compute growth closely tracked Moore’s law, doubling every two years, and that post-2012, compute has been doubling every 3–4 months. Overall, since 2012, this compute metric has grown by a factor of more than 300,000, while a two-year doubling period would have yielded only a 16x increase.
  3. Cloud computing which made the ability to store large datasets and use them to train models more accessible and economical.
  4. Open-source algorithmic development modules such as Facebook’s PyTorch, Google’s TensorFlow, Microsoft’s Cognitive Toolkit, and others.

This comes at a time when physicians are in desperate need of improvements to their workflow. Since 1975, employment in the healthcare industry has increased from 4 million to 16 million, and spending per patient has increased from $550 per year to over $11,000 per year. Despite this, the average time allotted per office visit has dwindled from 60 minutes for new patients and 30 minutes for returning patients in 1975 to 12 minutes for new patients and 7 minutes for returning patients today. Doctors are also preoccupied with electronic health records, managed care, health maintenance organizations, and relative value units¹.

As Eric Topol, among others, asserts, the rise of AI offers an exciting opportunity to revolutionize the healthcare industry. Increased workflow efficiency can afford clinicians more time to connect with patients. Paradoxically, the rise of machines will restore the humanity in medicine, and allow medical professionals to get back in touch with their motivations for pursuing a medical career in the first place.

One of the most effective applications of AI in healthcare has been in medical imaging. Radiology, pathology, and dermatology are specialties which rely on visual pattern analysis, and are therefore positioned to undergo a rapid and dramatic transformation due to integration with AI. Here, we focus on what that might look like for pathology.

Pathologists play a crucial role in diagnosing cancer, and their report helps dictate a patient’s treatment strategy. Typically, pathologists look at H&E stained tissue samples under a microscope, and describe the types of cells they see, how they’re arranged, whether they’re abnormal, and any other features that are important for diagnosis. The practice of using microscopes to examine glass slides containing tissue samples has been largely unchanged for a century. In recent years, however, a new era has emerged in which these slides are digitized using digital slide scanners to produce whole slide images (WSIs) which can then be examined on a computer. Pathologists have been slow to adopt WSI and other digital techniques which in turn has caused the encroachment of AI into pathology to be slower than expected. Nevertheless, WSIs have laid the groundwork for incorporating neural network image processing in pathology thereby making a new AI-assisted era imminent¹ ³.

Typical setup for conversion of glass slides into digital slides [Source]

One set of AI technologies has been directed towards simplifying routine workflows that are typically performed by human pathologists: for example, detection of tumor tissue in biopsy samples and determination of tumor subtype based on morphology. This can help reduce costs, turnaround time, and misdiagnoses. A major milestone for AI in pathology was the 2016 CAMELYON Challenge which set the goal of developing algorithms to detect cancer metastases in lymph node images. The best algorithms performed on par with pathologists. Since then, it’s been discovered that deep learning algorithms can uncover novel, abstract features from WSIs that perform better than traditional features at determining outcomes such as survival, treatment response, and genetic defects. It is remarkable that we are able to derive these insights directly from H&E slides which are easily accessible from pathology labs rather than performing additional tests which may be expensive and time consuming⁴.

As discussed in my previous post, multiple instance learning (MIL) is a variation of supervised learning in which only a single class label is assigned to a given bag of instances. This formulation allows us to leverage weakly labeled data, and naturally fits various problems in diverse fields such as computer vision and document classification.

I also discussed how the bag label probability, modeled as a Bernoulli distribution, can be fully parametrized by neural networks. Furthermore, in contrast to typical aggregation operators such as mean and max, using a modified version of the attention mechanism as the aggregation operator allows for differentiability and a greater degree of interpretability.
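Concretely, the attention-based MIL pooling from Ilse et al.⁶ computes the bag representation z as a weighted average of the instance embeddings h_1, …, h_K, with the weights produced by a small neural network:

z = \sum_{k=1}^{K} a_k h_k, \qquad a_k = \frac{\exp\{\mathbf{w}^\top \tanh(\mathbf{V} h_k^\top)\}}{\sum_{j=1}^{K} \exp\{\mathbf{w}^\top \tanh(\mathbf{V} h_j^\top)\}}

where w and V are learnable parameters. Because every operation is differentiable, the pooling is trained end-to-end, and the learned weights a_k indicate how much each instance contributed to the bag-level prediction.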

Deep MIL approach with an attention mechanism for MIL pooling as proposed by Ilse et al.⁶

The MIL formulation is appropriate for computer-aided diagnosis for three main reasons:

  1. Processing whole slide images consisting of billions of pixels is computationally infeasible. For scale, roughly 470 pathology images contain approximately the same number of pixels as the entire ImageNet dataset⁵. It is therefore tempting to divide each medical image into smaller patches, which can then be treated together as a bag with a single label.
  2. Supervised approaches exist which require pixel-level annotations that outline abnormalities in a patient’s medical scan. However, these require pathologists to spend large amounts of time on data preparation, thereby interfering with their daily routines. Since MIL techniques only require weak labels (i.e., the overall diagnosis of a patient), they hold the promise of significantly reducing pathologists’ workloads.
  3. MIL naturally fits the task of imaging-based patient diagnosis: diseased tissue samples have both abnormal and healthy regions, while healthy tissue samples have only healthy regions. These can easily be represented as bags, each with a single label.

Implementation of attention-based deep MIL model using PyTorch

The goal of the Prostate cANcer graDe Assessment (PANDA) Kaggle Challenge is to develop a model for diagnosing prostate cancer using a training set of 11,000 whole slide images of digitized H&E-stained prostate biopsies. Since this dataset was easily accessible to the public, I decided to use it to train a deep MIL model based on the architecture outlined in Attention-based Deep Multiple Instance Learning⁶.

Each tissue sample in the PANDA Challenge dataset is assigned Gleason patterns based on the architectural growth patterns of the tumor, and a corresponding ISUP grade on a scale of 1–5 (with a grade of 0 indicating a benign sample). Gleason scores are determined by the extent to which white branched cavities, or glandular tissue, persist throughout the tissue sample: an increased loss of glandular tissue implies greater severity and corresponds to a higher Gleason score. If multiple Gleason patterns are present in a single biopsy, they can be broken down into the most and second most frequently occurring patterns (majority and minority, respectively) as judged by a pathologist.

Obtained from PANDA Kaggle Challenge description [Source]

In order to make use of the dataset in a deep MIL model, I referred to the following Kaggle notebook to divide each WSI into a collection of 16 tiles of 128×128 pixels.

Each WSI is divided into 16 equally-sized tiles which are fed as input into our deep MIL model

It is then easy to visualize this in the context of the MIL formulation:

Each collection of 16 tiles can be reframed as a bag consisting of 16 instances
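To make the connection to code concrete, a bag is simply the 16 tiles stacked along a new instance dimension. A minimal sketch (the tile tensors here are random placeholders rather than real data):

import torch

# 16 RGB tiles of 128x128 pixels for one slide (placeholders standing in for the
# tiles produced by the Kaggle tiling notebook)
tiles = [torch.rand(3, 128, 128) for _ in range(16)]

# Stack the tiles into one bag and add a leading batch dimension of 1, giving the
# (1, 16, 3, 128, 128) input expected by the model defined below
bag = torch.stack(tiles).unsqueeze(0)
print(bag.shape)  # torch.Size([1, 16, 3, 128, 128])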

As in Ilse et al.⁶, bags are labeled as either malignant or benign. Slides with ISUP grades of 1, 2, 3, 4, or 5 are assigned a label of 1 corresponding to malignant, and those with an ISUP grade of 0 are assigned a label of 0 corresponding to benign.
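In code, this binary bag labeling is a one-liner over the competition metadata. A minimal sketch, assuming the PANDA train.csv file with its isup_grade column:

import pandas as pd

# PANDA metadata: isup_grade is 0 for benign biopsies and 1-5 for increasing severity
labels = pd.read_csv('train.csv')

# Bag label: 1 (malignant) for any non-zero ISUP grade, 0 (benign) otherwise
labels['bag_label'] = (labels['isup_grade'] > 0).astype(int)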

Here we implement a modified version of the model used in Ilse et al.⁶ which takes into account the dataset described above. My previous post explains deep multiple instance learning in further detail.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.L = 512  # 512-node fully connected layer
        self.D = 128  # 128-node attention layer
        self.K = 1    # single attention branch

        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(3, 36, kernel_size=4),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(36, 48, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )

        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(16 * 48 * 30 * 30, self.L),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(self.L, self.L),
            nn.ReLU(),
            nn.Dropout()
        )

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.L * self.K, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.squeeze(0)                      # (16, 3, 128, 128): the 16 tiles of the bag

        H = self.feature_extractor_part1(x)   # (16, 48, 30, 30)
        H = H.view(-1, 16 * 48 * 30 * 30)     # concatenate the 16 tile feature maps per bag
        H = self.feature_extractor_part2(H)   # NxL embeddings

        A = self.attention(H)                 # NxK
        A = torch.transpose(A, 1, 0)          # KxN
        A = F.softmax(A, dim=1)               # softmax over N
        M = torch.mm(A, H)                    # KxL attention-weighted bag representation

        # The probability that a given bag is malignant
        Y_prob = self.classifier(M)
        # The prediction given the probability (Y_prob >= 0.5 returns a Y_hat of 1, meaning malignant)
        Y_hat = torch.ge(Y_prob, 0.5).float()

        return Y_prob, Y_hat, A.byte()
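As a quick sanity check (using a random bag rather than real tiles), the model accepts a (1, 16, 3, 128, 128) input and returns the bag probability, the thresholded prediction, and the attention weights:

model = Attention()
bag = torch.rand(1, 16, 3, 128, 128)        # one bag of 16 random 3x128x128 tiles
Y_prob, Y_hat, A = model(bag)
print(Y_prob.shape, Y_hat.shape, A.shape)   # each torch.Size([1, 1])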

Model training using AWS SageMaker data parallelism (SDP)

In general, we train neural networks by adjusting their parameters in a direction that reduces prediction error. A common technique is stochastic gradient descent, in which these parameter updates occur iteratively using equally sized samples called mini-batches. It is possible to speed up training by evenly distributing mini-batches across a collection of independent machines, each of which has its own copy of the model, optimizer, and other essentials. Here, we use AWS SageMaker’s data parallelism toolkit, which has been shown to achieve superior performance over PyTorch DistributedDataParallel⁷.

Schematic illustration of data parallelism [Source]
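To make the idea concrete, here is a toy single-process sketch (not the SageMaker API) of what data-parallel SGD does: each simulated worker computes gradients on its own shard of the mini-batch, the gradients are averaged across workers (the all-reduce step), and a single shared update is applied.

import torch
import torch.nn.functional as F

# A shared model "replica" and one mini-batch of toy data
model = torch.nn.Linear(10, 1)
data, target = torch.randn(8, 10), torch.randn(8, 1)

# Each simulated worker computes gradients on its own shard of the mini-batch
worker_grads = []
for x, y in zip(data.chunk(2), target.chunk(2)):
    model.zero_grad()
    F.mse_loss(model(x), y).backward()
    worker_grads.append([p.grad.clone() for p in model.parameters()])

# "All-reduce": average each parameter's gradient across workers, then update once
avg_grads = [torch.stack(g).mean(dim=0) for g in zip(*worker_grads)]
with torch.no_grad():
    for p, g in zip(model.parameters(), avg_grads):
        p -= 0.01 * g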

To prepare for SDP training, we can upload the aforementioned data into an Amazon S3 bucket, and launch a Jupyter notebook instance using SageMaker’s pre-built PyTorch container. For this project, training is initialized by calling the PyTorch estimator from the Amazon SageMaker Python SDK. Notably, we pass the training script, specify the instance count and type, and enable the SDP distribution method, as shown below:

import sagemaker
from sagemaker.pytorch import PyTorch

sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

estimator = PyTorch(base_job_name='pytorch-smdataparallel-histopathology-mil',
                    source_dir='code',
                    entry_point='train.py',
                    role=role,
                    framework_version='1.8.1',
                    py_version='py36',
                    instance_count=2,
                    instance_type='ml.p3.16xlarge',
                    sagemaker_session=sagemaker_session,
                    distribution={'smdistributed': {
                                      'dataparallel': {
                                          'enabled': True
                                      }
                                  }},
                    debugger_hook_config=False,
                    volume_size=40)

ml.p3.16xlarge is one of three instance types supported by the SageMaker data parallelism toolkit, and AWS recommends using at least 2 instances to get the best performance out of it⁸. One instance of this type contains 8 NVIDIA V100 GPUs, each with 16 GB of memory. Here, this amounts to running 16 replicas of our model, one per GPU.

We can then fit our PyTorch estimator by passing in the data we uploaded to S3. This imports our data into the local filesystem of the training cluster so that our train.py script can simply read the data from disk.

channels = {
    'training': 's3://sagemaker-us-east-1-318322629142/train/',
    'testing': 's3://sagemaker-us-east-1-318322629142/test/'
}
estimator.fit(inputs=channels)
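For reference, SageMaker exposes the local path of each channel to train.py through environment variables named after the channels; the defaults shown below are the standard container paths:

import os

# Paths where SageMaker has copied the S3 data inside the training container
train_dir = os.environ.get('SM_CHANNEL_TRAINING', '/opt/ml/input/data/training')
test_dir = os.environ.get('SM_CHANNEL_TESTING', '/opt/ml/input/data/testing')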

In our train.py entry point script, we define our train function as shown below:

from sklearn.metrics import accuracy_score  # imported at the top of train.py

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    train_loss = 0.
    train_error = 0.
    predictions = []
    labels = []

    for batch_idx, (data, label) in enumerate(train_loader):
        bag_label = label.float()
        data = torch.squeeze(data)
        data, bag_label = data.to(device), bag_label.to(device)

        # reset gradients
        optimizer.zero_grad()

        # calculate error
        Y_prob, Y_hat, _ = model(data)
        error = 1. - Y_hat.eq(bag_label).cpu().float().mean().item()
        train_error += error

        # calculate loss
        Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
        loss = -1. * (bag_label * torch.log(Y_prob) + (1. - bag_label) * torch.log(1. - Y_prob))
        train_loss += loss.item()

        # keep track of predictions and labels to calculate accuracy after each epoch
        predictions.append(int(Y_hat))
        labels.append(int(bag_label))

        # backward pass and parameter update
        loss.backward()
        optimizer.step()

    # calculate loss and error for the epoch
    train_loss /= len(train_loader)
    train_error /= len(train_loader)

    print('Train Set, Epoch: {}, Loss: {:.4f}, Error: {:.4f}, Accuracy: {:.2f}%'.format(
        epoch, train_loss, train_error, accuracy_score(labels, predictions) * 100))
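As an aside, the hand-written loss above is the negative log-likelihood of the Bernoulli bag-label model from Ilse et al.⁶, and (up to the clamping) it matches PyTorch’s built-in binary cross-entropy:

import torch
import torch.nn.functional as F

Y_prob = torch.tensor([0.8])      # predicted probability that the bag is malignant
bag_label = torch.tensor([1.0])   # true bag label

manual = -(bag_label * torch.log(Y_prob) + (1. - bag_label) * torch.log(1. - Y_prob))
builtin = F.binary_cross_entropy(Y_prob, bag_label)
print(manual.item(), builtin.item())  # both ≈ 0.2231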

We also create a function to save our model once training has completed:

def save_model(model, model_dir):
    with open(os.path.join(model_dir, 'model.pth'), 'wb') as f:
        torch.save(model.module.state_dict(), f)
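The model_dir argument used here is conventionally read from SageMaker’s SM_MODEL_DIR environment variable, so that whatever is written there is uploaded to S3 as the training job’s model artifact. A minimal sketch of the argument parsing in train.py:

import argparse
import os

parser = argparse.ArgumentParser()
# SM_MODEL_DIR points to /opt/ml/model inside the training container
parser.add_argument('--model-dir', type=str,
                    default=os.environ.get('SM_MODEL_DIR', '/opt/ml/model'))
args = parser.parse_args()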

In the main guard, we load our dataset (see repository for details), train over 10 epochs, and save our model:

device = torch.device("cuda")
model = DDP(Attention().to(device))
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0005)

print('Start Training')
for epoch in range(1, 10 + 1):
    train(model, device, train_loader, optimizer, epoch)

save_model(model, args.model_dir)
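For completeness, the DDP wrapper used above comes from the SageMaker data parallelism library rather than from torch.nn.parallel. The initialization at the top of train.py, following the AWS example notebook⁸, looks roughly like this (a sketch, not the full script):

import torch
import smdistributed.dataparallel.torch.distributed as dist
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP

dist.init_process_group()               # one process per GPU across both instances
local_rank = dist.get_local_rank()      # index of the GPU assigned to this process
torch.cuda.set_device(local_rank)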

After training is complete, we can create an endpoint which runs a SageMaker-provided PyTorch model server and hosts our trained model. In general, deployment is used to serve real-time predictions to a client application, but here we deploy for demonstrative purposes.

import sagemaker
from sagemaker.pytorch import PyTorchModel

role = sagemaker.get_execution_role()

model = PyTorchModel(model_data=model_data,
                     source_dir='code',
                     entry_point='inference.py',
                     role=role,
                     framework_version='1.6.0',
                     py_version='py3')
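The predictor used in the next snippet comes from deploying this model to an endpoint; the instance type below is just an illustrative choice:

predictor = model.deploy(initial_instance_count=1, instance_type='ml.m5.xlarge')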

We can now use our predictor to predict labels for our test data and determine our accuracy score:

from sklearn.metrics import accuracy_score

predictions = []
true_labels = []

for batch_idx, (data, label) in enumerate(test_loader):
    _, Y_hat, _ = predictor.predict(data)
    predictions.append(int(Y_hat))
    true_labels.append(int(label))

accuracy_score(true_labels, predictions)

Our accuracy here is 67.2%, which is approximately 7.5% lower than the accuracy reported in Attention-based Deep Multiple Instance Learning⁶. While the model in the literature was trained over 100 epochs, ours was trained for only 10 epochs in order to minimize costs. I anticipate that a longer training time would lead to results more similar to those reported in the paper.

References

  1. Topol, E. J. (2019). Deep medicine: how artificial intelligence can make healthcare human again. Basic Books.
  2. Amodei, D., & Hernandez, D. (2018, May 16). AI and Compute. OpenAI. https://openai.com/blog/ai-and-compute/.
  3. Campanella, G., Hanna, M. G., Geneslaw, L., Miraflor, A., Werneck Krauss Silva, V., Busam, K. J., Brogi, E., Reuter, V. E., Klimstra, D. S., & Fuchs, T. J. (2019). Clinical-grade computational pathology using weakly supervised deep learning on whole slide images. Nature Medicine, 25(8), 1301–1309. https://doi.org/10.1038/s41591-019-0508-1
  4. Rajpurkar, P., Saporta, A., & Banerjee, O. (2020, December 16). The AI Health Podcast, Episode 5.
  5. Prostate cANcer graDe Assessment (PANDA) Challenge. Kaggle. (n.d.). https://www.kaggle.com/c/prostate-cancer-grade-assessment/overview/description.
  6. Ilse, M., Tomczak, J. M., Welling, M. (2018). Attention-based Deep Multiple Instance Learning. Proceedings of the 35th International Conference on Machine Learning, Stockholm, Sweden, PMLR 80. https://arxiv.org/abs/1802.04712.
  7. Webber, E., & Cruchant, O. (2020, December 9). Scale deep learning with 2 new libraries for distributed training on Amazon SageMaker [web log]. https://towardsdatascience.com/scale-neural-network-training-with-sagemaker-distributed-8cf3aefcff51.
  8. AWS. (n.d.). aws/amazon-sagemaker-examples. GitHub. https://github.com/aws/amazon-sagemaker-examples/blob/35e2faf7d1cc48ccedf0b2ede1da9987a18727a5/training/distributed_training/pytorch/data_parallel/mnist/pytorch_smdataparallel_mnist_demo.ipynb.