
Build a (Neural Network) Model over MEDS data

In this tutorial, we'll use the dataset and task labels we extracted previously to build a neural network model for that task. The Jupyter notebook tutorial is rendered below; you can also open it online on Google Colab or in our GitHub Repository.

Building a Model

In this part of the tutorial, we'll walk through building a (neural network) model with MEDS.

We'll pick up exactly where we left off in the previous part of the tutorial. Our task is to predict a short ICU stay (< 3 days) given the patient's input data. We have already transformed the dataset into the MEDS format, extracted a set of label files, and built a baseline model. Now, let's go a bit further!

Let's download the files we produced in the prior sections from the tutorial resources:

%%bash
wget -q -c https://github.com/Medical-Event-Data-Standard/MEDS_KDD_2025_Tutorial/raw/refs/heads/main/MEDS_data.zip

unzip -q -o MEDS_data.zip

apt-get -qq install tree
tree MEDS_data
MEDS_data
├── data
│   ├── held_out
│   │   └── 0.parquet
│   ├── train
│   │   └── 0.parquet
│   └── tuning
│       └── 0.parquet
├── labels
│   └── short_LOS
│       ├── held_out
│       │   └── 0.parquet
│       ├── train
│       │   └── 0.parquet
│       └── tuning
│           └── 0.parquet
└── metadata
    ├── codes.parquet
    ├── dataset.json
    └── subject_splits.parquet

10 directories, 9 files
from pathlib import Path
import pandas as pd

data_root = Path("MEDS_data/data")
labels_root = Path("MEDS_data/labels/short_LOS")
metadata_root = Path("MEDS_data/metadata")

train_data = pd.read_parquet(data_root / "train")[["subject_id", "time", "code", "numeric_value", "text_value"]]
train_labels = pd.read_parquet(labels_root / "train")[["subject_id", "prediction_time", "boolean_value"]]

A Neural Network Model

Building a neural network model is fundamentally different from building a tabular baseline. To do it effectively, we'll need to answer questions like:

  1. What kind of neural network do we want to build?
  2. How will it need its data pre-processed?
  3. How will we train this neural network?

The answers to these questions dictate what we need to do with our data to build the right kind of model.

For this tutorial, we'll build a fairly simple model. We'll take the patient's sequence of measurements, ordered by time (this ordering is somewhat ambiguous, since many measurements share the same timestamp), and embed each measurement by combining an embedding of its code with its numeric value, if present. We'll then apply a simple 1D convolution over the embedded sequence, pool the result into a single fixed-size vector, and use that to make our prediction.

This modeling strategy is reasonable, but it also has clear limitations: it ignores time (including the fact that many measurements happen at the same time), it doesn't use static data, and we don't know whether convolution is the right way to model patient data here. For this tutorial, though, it's sufficient.
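To make "the patient's sequence of measurements, ordered by time" concrete, here is a minimal pandas sketch on a few made-up MEDS-style rows (the column names follow the MEDS schema loaded above). The tie-breaking rule, and the helper column `row` it relies on, are assumptions for illustration only; they make the ambiguity explicit, since measurements that share a timestamp have no inherent order.

```python
import pandas as pd

# Toy MEDS-style rows (hypothetical data, for illustration only).
df = pd.DataFrame(
    {
        "subject_id": [2, 1, 1, 1, 2],
        "time": pd.to_datetime(
            ["2020-01-02", "2020-01-01", "2020-01-01", "2020-01-03", "2020-01-01"]
        ),
        "code": ["LAB//HR", "ADMIT", "LAB//HR", "LAB//SBP", "ADMIT"],
        "numeric_value": [88.0, None, 72.0, 120.0, None],
    }
)

# Break ties between simultaneous measurements by their original row position.
# This is an arbitrary but deterministic choice: the data itself does not say
# which of two measurements with the same timestamp "came first".
df["row"] = range(len(df))
df = df.sort_values(["subject_id", "time", "row"])

# One time-ordered list of (code, numeric_value) pairs per subject.
sequences = {
    int(sid): list(zip(g["code"], g["numeric_value"]))
    for sid, g in df.groupby("subject_id")
}
```

Any other tie-breaking rule would be equally valid, which is one reason a model that treats position as a strict ordering may not be ideal for this kind of data.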

How is this kind of NN different from our tabular baseline?

In our tabular baseline, our primary challenge was figuring out how to extract, from a dataset, a fixed-size set of tabular features that explicitly summarized the longitudinal MEDS data. This was necessary because tabular models can only take in fixed-size inputs, not sequences.
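For contrast, here is a minimal pandas sketch (made-up data) of the two representations side by side: a fixed-size tabular summary, using simple per-code counts as one of many possible summaries (not necessarily what the baseline used), versus the raw variable-length sequence that a sequential neural network can consume.

```python
import pandas as pd

# Toy MEDS-style rows (hypothetical data, for illustration only).
df = pd.DataFrame(
    {
        "subject_id": [1, 1, 1, 2, 2],
        "code": ["ADMIT", "LAB//HR", "LAB//HR", "ADMIT", "LAB//SBP"],
    }
)

# Tabular view: one fixed-size row per subject (here, code counts).
# Every subject gets the same columns, no matter how much data they have.
tabular = pd.crosstab(df["subject_id"], df["code"])

# Sequential view: one variable-length list of codes per subject.
# Nothing is summarized away, but lengths differ across subjects.
sequences = df.groupby("subject_id")["code"].apply(list)
```

The tabular view discards ordering (subject 1's two heart-rate codes collapse into a single count of 2), while the sequential view keeps it at the cost of variable lengths, which is exactly what the constraints below have to deal with.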

Neural networks, however (at least the kind we're building here, along with many other architectures), can work with sequential data directly. So here, our challenge is not manual featurization but transforming our raw data into a format this kind of neural network can ingest.

To do so, we have to meet a few constraints:

  1. Neural networks can't take in text features directly. Instead, for our codes (as well as for the text_value column, if we wanted to use that), we need to put them in a form suitable for embedding.
  2. Neural networks can take in variable-length sequences, but there is still a maximum length. This means we need a way to handle both too-short sequences (e.g., a patient who has less data in total than the maximum sequence length our model can process) and too-long sequences (a patient whose data exceeds that maximum). We also need to make sure that when we predict a label for a task sample with a given prediction time, we don't use data from after that time.
  3. We need to make sure we can easily and efficiently load data in our model, without wasting valuable compute resources on repeated, non-modeling work.

Here's how we'll solve these challenges:

  1. Embedding: We'll do this in the standard way, by assigning an integer index to each unique code and passing those indices into the model, which maps them to vectors with an embedding matrix.
  2. Sequence Lengths: We'll use a padding index to represent sequence elements (i.e., codes) that are not actually present (e.g., if a patient's data is too short), and we'll truncate sequences that are too long. There are many ways we could truncate sequences -- in this case, as we're trying to make a prediction as of a specific time, we'll use the longest contiguous sequence allowed that ends on or before the specified prediction time.
  3. Efficient Processing: Following best practices, we'll perform as many calculations as we can ahead of modeling time and store just the data we need on disk, in a format suitable for efficient subsequence slicing. Then, for each batch, we only need to load the data it uses. We'll rely on existing external packages for this, and most of the details are outside the scope of this tutorial.
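Here is a minimal pure-Python sketch of the first two solutions, under stated assumptions: `build_vocab` and `to_end_slice` are hypothetical helpers (not MEDS TorchData functions), index 0 is assumed to be the padding index, times are plain numbers sorted in ascending order, and short sequences are padded on the right. MEDS TorchData does this work for us, and its exact conventions may differ.

```python
from bisect import bisect_right

PAD_INDEX = 0  # assumption: index 0 is reserved for padding


def build_vocab(codes):
    """Hypothetical helper: map each unique code to an integer index starting at 1."""
    return {code: i for i, code in enumerate(sorted(set(codes)), start=1)}


def to_end_slice(times, codes, prediction_time, max_seq_len, vocab):
    """Hypothetical helper: index the longest contiguous window of events ending
    on or before `prediction_time`, then right-pad it to `max_seq_len`.

    `times` must be sorted ascending, and `times`/`codes` must be aligned.
    """
    # One past the last event with time <= prediction_time, so no future data leaks in.
    end = bisect_right(times, prediction_time)
    # Keep at most `max_seq_len` events, preferring the most recent ones.
    start = max(0, end - max_seq_len)
    idx = [vocab[c] for c in codes[start:end]]
    # Fill the remaining slots with the padding index.
    return idx + [PAD_INDEX] * (max_seq_len - len(idx))


# Usage on a toy, already time-sorted event stream (times in days).
vocab = build_vocab(["ADMIT", "LAB//HR", "LAB//SBP", "DISCHARGE"])
times = [1, 2, 3, 4, 5]
codes = ["ADMIT", "LAB//HR", "LAB//SBP", "LAB//HR", "DISCHARGE"]

# The event at time 5 is after the prediction time and is excluded;
# of the four remaining events, only the three most recent fit.
seq = to_end_slice(times, codes, prediction_time=4, max_seq_len=3, vocab=vocab)
print(seq)  # [3, 4, 3]
```

Computing the vocabulary and sorting once, ahead of time, is what makes the per-sample slicing cheap: at load time each sample only needs a binary search and a slice.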

MEDS TorchData

To help us out, we'll use MEDS TorchData, a tool from the MEDS ecosystem that makes it easier to get the data into the right format. We won't go through all the details of how it works, but it has a few capabilities we can use:

  1. It comes with a script that will pre-process our dataset, ensuring that: (a) every code is assigned an integer index suitable for an embedding matrix, (b) numeric values are centered and scaled to have zero mean and unit variance, and (c) data are written to disk in a manner that permits easy loading in a PyTorch dataset.
  2. It comes with a PyTorch dataset class we can inherit and instantiate that will let us work with the data in a straightforward manner.
  3. It lets us integrate natively with Lightning if we want to use that to help streamline training (which we will here).
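To illustrate item (b), here is a pandas sketch of z-scoring numeric values, assuming the statistics are computed separately for each code and fit on the training split only before being applied to other splits. The toy data are made up, and the choice of sample standard deviation (`ddof=1`, pandas' default) is also an assumption; `MTD_preprocess` performs the real normalization, and its exact estimator isn't covered here.

```python
import numpy as np
import pandas as pd

# Toy data (hypothetical values) for two splits.
train = pd.DataFrame(
    {"code": ["HR", "HR", "HR", "SBP", "SBP"],
     "numeric_value": [60.0, 80.0, 100.0, 110.0, 130.0]}
)
held_out = pd.DataFrame({"code": ["HR", "SBP"], "numeric_value": [80.0, 140.0]})

# Fit per-code statistics on the training split only, so nothing from
# tuning/held-out data leaks into preprocessing. pandas' std uses ddof=1.
stats = train.groupby("code")["numeric_value"].agg(["mean", "std"])


def normalize(df: pd.DataFrame, stats: pd.DataFrame) -> np.ndarray:
    """Z-score each numeric value using the training statistics of its own code."""
    per_row = stats.reindex(df["code"])  # one (mean, std) row per input row
    centered = df["numeric_value"].to_numpy() - per_row["mean"].to_numpy()
    return centered / per_row["std"].to_numpy()


train_z = normalize(train, stats)        # each code now has mean 0 in train
held_out_z = normalize(held_out, stats)  # reuses the train statistics
```

Normalizing per code matters because different codes live on very different scales (a heart rate and a blood pressure here), so a single global mean and variance would not put them on comparable footing.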

Let's install it and see it in action!

Note: Installation may take some time. Additionally, if you see any pip dependency errors about polars versions, this is not something to worry about.

%%bash
pip install --quiet meds-torch-data[lightning]
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf-polars-cu12 25.6.0 requires polars<1.29,>=1.25, but you have polars 1.30.0 which is incompatible.

Data Pre-processing

After installation, the next step is to pre-process our dataset. Following the documentation, all we have to do is run the following command.

%%bash
MTD_preprocess MEDS_dataset_dir="MEDS_data" output_dir="tensorized_MEDS_data"
[2025-07-31 18:13:40,721][meds_torchdata.preprocessing.__main__][INFO] - Running in serial mode as N_WORKERS is not set.
[2025-07-31 18:13:40,722][meds_torchdata.preprocessing.__main__][INFO] - Running command: INPUT_DIR=/content/MEDS_data OUTPUT_DIR=/content/tensorized_MEDS_data MEDS_transform-pipeline --config-path=/usr/local/lib/python3.11/dist-packages/meds_torchdata/preprocessing/configs --config-name=runner pipeline_config_fp=/usr/local/lib/python3.11/dist-packages/meds_torchdata/preprocessing/configs/_MTD_preprocess.yaml ~parallelize ++do_overwrite=False

Let's see what that command produced!

%%bash
tree tensorized_MEDS_data
tensorized_MEDS_data
├── data
│   ├── held_out
│   │   └── 0.nrt
│   ├── train
│   │   └── 0.nrt
│   └── tuning
│       └── 0.nrt
├── fit_normalization
│   ├── codes.parquet
│   └── train
│       └── 0.parquet
├── fit_vocabulary_indices
├── metadata
│   └── codes.parquet
├── normalization
│   ├── held_out
│   │   └── 0.parquet
│   ├── train
│   │   └── 0.parquet
│   └── tuning
│       └── 0.parquet
└── tokenization
    ├── event_seqs
    │   ├── held_out
    │   │   └── 0.parquet
    │   ├── train
    │   │   └── 0.parquet
    │   └── tuning
    │       └── 0.parquet
    └── schemas
        ├── held_out
        │   └── 0.parquet
        ├── train
        │   └── 0.parquet
        └── tuning
            └── 0.parquet

21 directories, 15 files

This folder contains a number of files. Luckily, we don't need to understand them all to use them: if we set up our dataset properly, MEDS TorchData handles them for us. Let's check it out!

Note: If you're interested, check out this part of the documentation and the documentation for the nested-ragged-tensors package for more information!

A PyTorch Dataset

To build a PyTorch dataset, we first construct a configuration object that defines the parameters of the dataset we want to create. MEDS TorchData supports a number of options, such as how to handle static data and how the data should be represented. For now, we'll use the following configuration to kick off our modeling exploration; to explore other options, check out the documentation.

from meds_torchdata import MEDSTorchDataConfig

tensorized_cohort_dir = Path("tensorized_MEDS_data")

config = MEDSTorchDataConfig(
    tensorized_cohort_dir,            # Our tensorized data directory
    max_seq_len=256,                  # How many measurements we want
    seq_sampling_strategy="to_end",   # What part of the patient's data to grab
    static_inclusion_mode="omit",     # What to do with static data
    task_labels_dir=labels_root,      # Our labels
)

Now, with our config defined, let's build a dataset and see some data!

from meds_torchdata import MEDSPytorchDataset

pyd = MEDSPytorchDataset(config, split="train")
print(f"Dataset constructed with {len(pyd)} samples!")
Dataset constructed with 74 samples!

Inspecting data in a PyTorch dataset can be a bit hard. We can simplify this by looking at full batches, which have a special print function. The dataset's built-in dataloader function helps with this:

dataloader = pyd.get_dataloader(batch_size=16)
batch = next(iter(dataloader))
print(batch)
MEDSTorchBatch:
│ Mode: Subject-Measurement (SM)
│ Static data? ✗
│ Labels? ✓
│
│ Shape:
│ │ Batch size: 16
│ │ Sequence length: 256
│ │
│ │ All dynamic data: (16, 256)
│ │ Labels: torch.Size([16])
│
│ Data:
│ │ Dynamic:
│ │ │ time_delta_days (torch.float32):
│ │ │ │ [[0.00, 0.00,  ..., 0.00, 0.00],
│ │ │ │  [0.00, 0.00,  ..., 0.00, 0.00],
│ │ │ │  ...,
│ │ │ │  [0.00, 0.00,  ..., 0.00, 0.00],
│ │ │ │  [0.00, 0.00,  ..., 0.00, 0.01]]
│ │ │ code (torch.int64):
│ │ │ │ [[1980, 2150,  ..., 2692, 6529],
│ │ │ │  [1985, 2007,  ..., 2009, 6529],
│ │ │ │  ...,
│ │ │ │  [1983, 1995,  ..., 2162, 2134],
│ │ │ │  [2770, 1982,  ..., 1997, 1704]]
│ │ │ numeric_value (torch.float32):
│ │ │ │ [[-0.49,  0.00,  ..., -0.10, -0.80],
│ │ │ │  [-0.71,  1.12,  ...,  0.34, -0.46],
│ │ │ │  ...,
│ │ │ │  [ 0.33, -0.19,  ...,  0.00,  0.00],
│ │ │ │  [-1.01,  0.00,  ...,  0.50, -0.31]]
│ │ │ numeric_value_mask (torch.bool):
│ │ │ │ [[ True, False,  ...,  True,  True],
│ │ │ │  [ True,  True,  ...,  True,  True],
│ │ │ │  ...,
│ │ │ │  [ True,  True,  ..., False, False],
│ │ │ │  [ True, False,  ...,  True,  True]]
│ │
│ │ Labels:
│ │ │ boolean_value (torch.bool):
│ │ │ │ [False, False,  ...,  True, False]

This will be the structure we can work with in our model -- we'll have access to 2D tensors of time deltas, codes, and numeric values (with a mask indicating when numeric values were present).
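Because padded positions use a dedicated index, a boolean mask of real measurements falls out of a single comparison. Below is a minimal sketch with toy tensors standing in for `batch.code`, `batch.numeric_value`, and `batch.numeric_value_mask`; the values and the right-padding layout are made up for illustration, and `PAD_INDEX = 0` is assumed here.

```python
import torch

PAD_INDEX = 0  # assumption: the padding index (the batch exposes it as batch.PAD_INDEX)

# Toy stand-ins for the batch fields (hypothetical values; shape (B=2, T=4)).
code = torch.tensor([[5, 7, 9, 0], [3, 0, 0, 0]])
numeric_value = torch.tensor([[0.5, 0.0, -1.2, 0.0], [2.0, 0.0, 0.0, 0.0]])
numeric_value_mask = torch.tensor([[True, False, True, False], [True, False, False, False]])

# Positions that hold a real measurement (not padding).
present = code != PAD_INDEX

# Number of real measurements per sample.
seq_len = present.sum(dim=1)

# Mean of the observed numeric values per sample, ignoring missing values and padding.
observed = numeric_value_mask & present
mean_numeric = (numeric_value * observed).sum(dim=1) / observed.sum(dim=1).clamp(min=1)
```

The `clamp(min=1)` guards against division by zero for a sample with no observed numeric values.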

Note: The time-deltas here are not normalized, and are often zero, as many measurements occur at the same point in time. This is normal, but may or may not be desired for your particular model application.
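If you do want to use time in a model, a few simple transforms are common. The sketch below uses a toy tensor standing in for `batch.time_delta_days` and assumes each entry is the gap (in days) since the previous measurement; check the MEDS TorchData documentation for the exact semantics. The `log1p` scaling is one option among many, not something MEDS TorchData applies for you.

```python
import torch

# Toy stand-in for batch.time_delta_days (hypothetical values; shape (B=1, T=5)).
# Many deltas are zero because several measurements share a timestamp.
time_delta_days = torch.tensor([[0.0, 0.0, 0.5, 0.0, 2.0]])

# Elapsed time since the start of the window, per position.
elapsed_days = torch.cumsum(time_delta_days, dim=1)

# log1p compresses large gaps while mapping zero deltas to exactly zero.
log_delta = torch.log1p(time_delta_days)

# A binary feature marking "same timestamp as the previous measurement".
same_time = time_delta_days == 0
```

Any of these could be concatenated to the per-measurement features alongside the code embedding and numeric value.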

print(batch.PAD_INDEX)
print(batch.code)
print(batch.numeric_value)
print(batch.numeric_value_mask)
0
tensor([[1980, 2150, 2708,  ..., 1982, 2692, 6529],
        [1985, 2007, 2000,  ..., 2878, 2009, 6529],
        [2885, 2407, 2203,  ..., 2216, 2206, 2207],
        ...,
        [2692, 2235, 2198,  ..., 3042, 3031, 3023],
        [1983, 1995, 1979,  ..., 2157, 2162, 2134],
        [2770, 1982, 2003,  ..., 1998, 1997, 1704]])
tensor([[-0.4875,  0.0000,  0.0000,  ...,  0.0000, -0.1005, -0.8004],
        [-0.7122,  1.1156,  0.7586,  ...,  0.0000,  0.3388, -0.4600],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0227, -0.0556, -0.3237,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3286, -0.1904, -0.5563,  ...,  0.0000,  0.0000,  0.0000],
        [-1.0095,  0.0000,  0.5741,  ..., -0.6803,  0.4985, -0.3070]])
tensor([[ True, False, False,  ..., False,  True,  True],
        [ True,  True,  True,  ..., False,  True,  True],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True, False,  True,  ...,  True,  True,  True]])

Building the model

Using this, let's go ahead and build our model! Given we know what the batch will contain, all that's left is to put together our model's structure.

import torch.nn as nn
import torch
from meds_torchdata import MEDSTorchBatch

class SimpleConv(nn.Module):
    """
    1D-convolutional model for sequence + static inputs.

    Pipeline
    --------
    1) Embed code indices (with padding_idx).
    2) Concatenate [code_embedding, numeric_value_mask, numeric_value] per timestep.
    3) Pass through a Conv1d stack over time -> global pooling -> fixed-size sequence embedding.
    4) Embed static features (linear MLP projection).
    5) Concatenate [sequence_embedding, static_embedding] -> predict labels.

    Inputs (forward)
    ----------------
    batch.code: LongTensor of shape (B, T)
        Tokenized code indices, padded with `pad_index`.
    batch.numeric_value: FloatTensor of shape (B, T)
        Numeric values (e.g., lab values) aligned with `code_idx`. Should be pre-normalized.
    batch.numeric_value_mask: Bool/FloatTensor of shape (B, T)
        1 if numeric_value is present/valid for that timestep, else 0.

    Output
    ------
    logits: FloatTensor of shape (B,)
        Raw (unnormalized) predictions. Apply sigmoid/softmax outside as appropriate.

    Notes
    -----
    - Padding handling: The embedding uses `padding_idx` so padded positions produce a zero vector.
      With ReLU activations and global max pooling, padded zeros will not dominate non-pad activations.
    """

    def __init__(
        self,
        *,
        code_vocab_size: int,
        pad_index: int = MEDSTorchBatch.PAD_INDEX,
        code_embed_dim: int = 128,
        conv_channels=(128, 128, 128),
        conv_kernel_sizes=(5, 5, 3),
        conv_dropout: float = 0.1,
        head_hidden_dim: int = 128,
        act: type[nn.Module] = nn.ReLU,
    ):
        super().__init__()
        assert len(conv_channels) == len(conv_kernel_sizes), "conv_channels and conv_kernel_sizes must be same length"
        self.pad_index = pad_index

        # 1) Code embedding
        self.code_emb = nn.Embedding(
            num_embeddings=code_vocab_size,
            embedding_dim=code_embed_dim,
            padding_idx=pad_index,
        )

        # 2–3) Conv stack over time
        in_ch = code_embed_dim + 2  # [code_emb || numeric_value_mask || numeric_value]
        conv_blocks = []
        for out_ch, k in zip(conv_channels, conv_kernel_sizes):
            conv_blocks += [
                nn.Conv1d(in_channels=in_ch, out_channels=out_ch, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(out_ch),
                act(),
                nn.Dropout(conv_dropout),
            ]
            in_ch = out_ch
        self.conv = nn.Sequential(*conv_blocks)
        self.pool = nn.AdaptiveMaxPool1d(1)  # -> (B, C, 1) then squeeze to (B, C)

        # 5) Head
        head_in = conv_channels[-1]
        self.head = nn.Sequential(
            nn.Linear(head_in, head_hidden_dim),
            act(),
            nn.Dropout(conv_dropout),
            nn.Linear(head_hidden_dim, 1),
        )

        self._reset_parameters()

    def _reset_parameters(self):
        # Kaiming init for conv/linear; embeddings default init is fine (padding row already handled)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
                nn.init.zeros_(m.bias)

    def forward(
        self,
        batch: MEDSTorchBatch
    ) -> torch.Tensor:
        # (B, T, E)
        code_e = self.code_emb(batch.code)

        # Cast the mask to float without modifying the batch in place
        mask = batch.numeric_value_mask.float()

        # (B, T, E+2)
        x = torch.cat(
            [code_e, mask.unsqueeze(-1), batch.numeric_value.unsqueeze(-1)],
            dim=-1,
        )

        # Conv1d expects (B, C, T)
        x = x.transpose(1, 2)  # -> (B, C=E+2, T)
        x = self.conv(x)       # -> (B, C_out, T)
        x = self.pool(x).squeeze(-1)  # -> (B, C_out)

        # squeeze(-1) rather than squeeze() keeps shape (B,) even when B == 1
        logits = self.head(x).squeeze(-1)  # (B,)
        return logits
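`SimpleConv`'s global max pool runs over all `T` positions, including padding. If you want padded positions excluded, one option (not part of the model above) is a masked max pool. The sketch below is a standalone, hypothetical helper; because every conv layer uses an odd kernel with `padding=k // 2`, the sequence length is preserved, so a `(B, T)` mask such as `batch.code != self.pad_index` would line up with the conv output.

```python
import torch


def masked_max_pool(x: torch.Tensor, present: torch.Tensor) -> torch.Tensor:
    """Max-pool over time, ignoring padded positions.

    x: (B, C, T) conv features; present: (B, T) bool, True for real measurements.
    Returns (B, C). Samples with no real positions get zeros.
    """
    # Broadcast the (B, T) mask over channels and push padded positions to the dtype minimum.
    fill = torch.finfo(x.dtype).min
    masked = x.masked_fill(~present.unsqueeze(1), fill)
    pooled = masked.max(dim=-1).values
    # Guard against all-padding rows, which would otherwise return the fill value.
    has_any = present.any(dim=1, keepdim=True)
    return torch.where(has_any, pooled, torch.zeros_like(pooled))


# Usage: the last position is padding, so its (large) values are ignored.
x = torch.tensor([[[1.0, 5.0, 9.0], [2.0, -1.0, 7.0]]])  # (B=1, C=2, T=3)
present = torch.tensor([[True, True, False]])
pooled = masked_max_pool(x, present)  # [[5.0, 2.0]] instead of [[9.0, 7.0]]
```

Inside `forward`, this would replace `self.pool(x).squeeze(-1)` with `masked_max_pool(x, batch.code != self.pad_index)`.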

Let's see if it runs!

M = SimpleConv(code_vocab_size=config.vocab_size)
M(batch)
tensor([ 7.8502, 10.1843,  4.1388,  6.4581,  8.9709, 11.3923,  6.9392,  7.1074,
         7.2635,  9.4971,  8.3914,  7.9199,  4.2960,  7.1975,  8.3680,  8.2609],
       grad_fn=<SqueezeBackward0>)

Training

What about training this model? To keep this step simple, we'll use Lightning. To use Lightning, we need to do two things:

  1. Convert our PyTorch dataset into a Lightning DataModule (this is handled for us by MEDS TorchData, as we'll see in a moment)
  2. Build a Lightning Module for our model.

For step 1, as indicated, we can use MEDS TorchData for this; let's see it in action:

from meds_torchdata.extensions.lightning_datamodule import Datamodule

dm = Datamodule(config, batch_size=32, num_workers=0, pin_memory=False)

print(next(iter(dm.train_dataloader())))
MEDSTorchBatch:
│ Mode: Subject-Measurement (SM)
│ Static data? ✗
│ Labels? ✓
│
│ Shape:
│ │ Batch size: 32
│ │ Sequence length: 256
│ │
│ │ All dynamic data: (32, 256)
│ │ Labels: torch.Size([32])
│
│ Data:
│ │ Dynamic:
│ │ │ time_delta_days (torch.float32):
│ │ │ │ [[0.00, 0.00,  ..., 0.00, 0.00],
│ │ │ │  [0.00, 0.00,  ..., 0.00, 0.00],
│ │ │ │  ...,
│ │ │ │  [0.00, 0.00,  ..., 0.00, 0.00],
│ │ │ │  [0.00, 0.00,  ..., 0.00, 0.00]]
│ │ │ code (torch.int64):
│ │ │ │ [[2885, 2059,  ..., 2083, 2708],
│ │ │ │  [2739, 2666,  ..., 6566, 1795],
│ │ │ │  ...,
│ │ │ │  [2278, 2057,  ..., 2007, 6529],
│ │ │ │  [6529, 2198,  ..., 2249, 2250]]
│ │ │ numeric_value (torch.float32):
│ │ │ │ [[ 0.00,  0.00,  ..., -0.44,  0.00],
│ │ │ │  [ 0.00,  0.00,  ...,  2.24,  0.18],
│ │ │ │  ...,
│ │ │ │  [ 0.60,  0.00,  ...,  0.77, -0.72],
│ │ │ │  [-0.74, -0.32,  ...,  0.00,  0.00]]
│ │ │ numeric_value_mask (torch.bool):
│ │ │ │ [[False, False,  ...,  True, False],
│ │ │ │  [False, False,  ...,  True,  True],
│ │ │ │  ...,
│ │ │ │  [ True, False,  ...,  True,  True],
│ │ │ │  [ True,  True,  ..., False, False]]
│ │
│ │ Labels:
│ │ │ boolean_value (torch.bool):
│ │ │ │ [False, False,  ...,  True,  True]

To build our Lightning module, we just need to add a few methods around our model. To keep things clean, the module wraps our PyTorch model class, constructing it from the keyword arguments we pass in.

import lightning as L
from torchmetrics import Accuracy

class SimpleConvL(L.LightningModule):
    def __init__(self, lr, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.model = SimpleConv(**kwargs)
        self.loss = nn.BCEWithLogitsLoss()
        self.accuracy = Accuracy(task="binary")

    def forward(self, batch: MEDSTorchBatch):
        return self.model(batch)

    def _step(self, batch: MEDSTorchBatch, split: str):
        logits = self(batch)
        labels = batch.boolean_value.float()

        loss = self.loss(logits, labels)
        self.log(f"{split}_loss", loss)

        preds = (logits > 0).float()
        acc = self.accuracy(preds, labels)
        self.log(
            f"{split}_acc",
            acc,
            on_step=False,
            on_epoch=True,
            prog_bar=True
        )

        return loss

    def training_step(self, batch, _):
        return self._step(batch, "train")

    def validation_step(self, batch, _):
        return self._step(batch, "val")

    def test_step(self, batch, _):
        return self._step(batch, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
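Two details in `_step` are worth checking: the model returns raw logits, which `nn.BCEWithLogitsLoss` turns into a loss by applying the sigmoid internally, and predictions are made by thresholding logits at 0. This standalone sketch uses toy values to confirm that both are equivalent to working with sigmoid probabilities and a 0.5 threshold.

```python
import torch
import torch.nn as nn

# Toy logits and binary labels (hypothetical values).
logits = torch.tensor([2.0, -0.5, 0.0, -3.0])
labels = torch.tensor([1.0, 0.0, 1.0, 1.0])

# BCEWithLogitsLoss applies the sigmoid internally (in a numerically stable way).
loss = nn.BCEWithLogitsLoss()(logits, labels)

# The same mean binary cross-entropy computed by hand from probabilities.
probs = torch.sigmoid(logits)
manual = -(labels * torch.log(probs) + (1 - labels) * torch.log(1 - probs)).mean()

# sigmoid(0) == 0.5 and sigmoid is monotonic, so these thresholds agree.
preds_from_logits = (logits > 0).float()
preds_from_probs = (probs > 0.5).float()
```

Keeping the sigmoid inside the loss, rather than in the model, avoids taking `log` of probabilities that have saturated to 0 or 1.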

With these defined, let's try training our model!

LM = SimpleConvL(lr=1e-3, code_vocab_size=config.vocab_size)

trainer = L.Trainer(max_epochs=10, check_val_every_n_epoch=1, log_every_n_steps=1)
trainer.fit(model=LM, datamodule=dm)
INFO:pytorch_lightning.utilities.rank_zero:💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: 
  | Name     | Type              | Params | Mode 
-------------------------------------------------------
0 | model    | SimpleConv        | 1.1 M  | train
1 | loss     | BCEWithLogitsLoss | 0      | train
2 | accuracy | BinaryAccuracy    | 0      | train
-------------------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.318     Total estimated model params size (MB)
23        Modules in train mode
0         Modules in eval mode
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.

With Lightning, we can easily test our model too:

trainer.test(model=LM, datamodule=dm)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_acc                     0.375           │
│         test_loss             1.6943222284317017     │
└───────────────────────────┴───────────────────────────┘
[{'test_loss': 1.6943222284317017, 'test_acc': 0.375}]

Unsurprisingly, with a very small dataset of only ~70 training samples, this model has grossly overfit. The Lightning progress bar during training should show that the training accuracy has climbed very high, while the test accuracy remains poor (0.375 above). How the same model would perform on the full, real dataset is an open question!
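To put a number like 0.375 in context, it helps to compare against a trivial baseline that always predicts the most common training label. `majority_baseline_accuracy` is a hypothetical helper, not part of any library; you could call it with, e.g., `train_labels["boolean_value"]` loaded at the start of this section and the `boolean_value` column of the held-out label file.

```python
def majority_baseline_accuracy(train_labels, test_labels):
    """Accuracy on `test_labels` of always predicting the most common training label."""
    train_labels = [bool(y) for y in train_labels]
    # Majority class in the training labels (ties go to True).
    majority = sum(train_labels) * 2 >= len(train_labels)
    return sum(bool(y) == majority for y in test_labels) / len(test_labels)


# Usage with toy labels (hypothetical values): the majority training label is False,
# and 5 of the 8 test labels are False.
acc = majority_baseline_accuracy(
    [True, False, False, False],
    [False, True, False, False, False, True, False, True],
)
print(acc)  # 0.625
```

A model that scores below this baseline on held-out data has learned nothing useful beyond the label frequencies.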

Key Takeaways

In this part of the tutorial, you've learned how to train a neural network over MEDS data: how to use tools from the MEDS ecosystem to convert the data into a representation suitable for a sequential neural network architecture, and how to train a simple model for a real task.