Build a (Neural Network) Model over MEDS data
In this tutorial, we'll use the dataset and task labels we extracted previously and build a neural network model for that prediction task. See below for the Jupyter notebook tutorial, rendered here, or check it out online on Google Colab or in our GitHub Repository.
Building a Model
In this part of the tutorial, we'll walk through building a (neural network) model with MEDS.
We will continue exactly where we left off in the previous part of the tutorial. Our task is to predict a short ICU stay (of < 3 days) given the patient's input data. We have already transformed the dataset into the MEDS format, extracted a set of label files, and built a tabular baseline model -- now, let's go a bit further!
Let's download the files we produced in the prior sections from the tutorial resources:
%%bash
wget -q -c https://github.com/Medical-Event-Data-Standard/MEDS_KDD_2025_Tutorial/raw/refs/heads/main/MEDS_data.zip
unzip -q -o MEDS_data.zip
apt-get -qq install tree
tree MEDS_data
MEDS_data
├── data
│   ├── held_out
│   │   └── 0.parquet
│   ├── train
│   │   └── 0.parquet
│   └── tuning
│       └── 0.parquet
├── labels
│   └── short_LOS
│       ├── held_out
│       │   └── 0.parquet
│       ├── train
│       │   └── 0.parquet
│       └── tuning
│           └── 0.parquet
└── metadata
    ├── codes.parquet
    ├── dataset.json
    └── subject_splits.parquet

10 directories, 9 files
from pathlib import Path
import pandas as pd
data_root = Path("MEDS_data/data")
labels_root = Path("MEDS_data/labels/short_LOS")
metadata_root = Path("MEDS_data/metadata")
train_data = pd.read_parquet(data_root / "train")[["subject_id", "time", "code", "numeric_value", "text_value"]]
train_labels = pd.read_parquet(labels_root / "train")[["subject_id", "prediction_time", "boolean_value"]]
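As a quick sanity check -- an added snippet, not part of the original notebook -- let's peek at what we just loaded:
# Quick look at the loaded measurements and labels (sanity check).
print(f"{len(train_data)} measurements across {train_data['subject_id'].nunique()} subjects")
print(f"{len(train_labels)} task labels; positive rate: {train_labels['boolean_value'].mean():.2f}")
train_data.head()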
A Neural Network Model
Building a neural network model is inherently very different from building a tabular baseline. There are lots of questions we'll need to answer to do this effectively; questions like:
- What kind of neural network do we want to build?
- How will it need its data pre-processed?
- How will we train this neural network?
All of these will dictate what we need to do with our data to build the right kind of model.
For this tutorial, we'll build a pretty simple model. We'll take the patient's sequence of measurements, ordered by time (note this ordering is a bit ambiguous -- many measurements happen at the same time point), embed each measurement via a simple strategy that combines an embedding of the code with the numeric value, if present, and apply a simple 1D convolution over those embedded measurements. We'll pool this convolution down to a single output, and use that to make our prediction.
This modeling strategy is reasonable, but it definitely has some issues -- it doesn't take time into account (or the fact that many measurements happen at the same time), we aren't using static data, and we don't know if convolution is the right way to model patient data here. But for this tutorial, it's sufficient.
How is (this kind of) NN different from our tabularized baseline?
In our tabular baseline, our primary challenge was identifying how to take a dataset and extract a fixed-size set of tabular features that explicitly summarized the longitudinal MEDS data. This approach was necessary because tabular baselines can only take in fixed-size data -- not sequential data.
But, neural networks (or at least the kind we're building here, as well as many other architectures) can work with sequential data directly. This means that here, our challenge is not one of manual featurization, but instead is one of transforming our raw data into a format suitable for ingestion by this kind of neural network.
To do so, we have to meet a few constraints:
- Neural networks can't take in text features directly. Instead, for our codes (as well as for the `text_value` column, if we wanted to use that), we need to put them in a form suitable for embedding.
- Neural networks can take in variable-length sequences, but still have a limit. This means we need a way to handle both too-short sequences (e.g., a patient who has less data in total than the maximum sequence length our model can process) and too-long sequences (a sequence of patient data longer than this maximum limit). We also need to make sure that when we try to predict a label for a task sample with a given prediction time, we don't use data from beyond that time.
- We need to make sure we can easily and efficiently load data in our model, without wasting valuable compute resources on repeated, non-modeling work.
Here's how we'll solve these challenges:
- Embedding: We'll do this in the standard way -- by assigning an integer index to each unique code, then passing those indices into the model to be looked up in an embedding matrix (see the toy sketch after this list).
- Sequence Lengths: We'll use a padding index to represent sequence elements (i.e., codes) that are not actually present (e.g., if a patient's data is too short), and we'll truncate sequences that are too long. There are many ways we could truncate sequences -- in this case, as we're trying to make a prediction as of a specific time, we'll use the longest contiguous sequence allowed that ends on or before the specified prediction time.
- Efficient Processing: We'll follow best practices and perform as many calculations as we can in advance of modeling time, and store just the data we need (in a format suitable for efficient subsequence slicing) on disk before we start modeling. Then, each batch of modeling, we'll just need to load the data we need and we'll be good to go! For this, we'll use existing external packages, and most of the details will be outside the scope of this tutorial.
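To make the first two ideas concrete, here's a toy sketch of assigning codes integer indices and padding/truncating to a fixed length. (This is an illustrative addition with made-up codes, not how MEDS TorchData implements it internally.)
import numpy as np
# Toy example: assign each unique code an integer index, reserving 0 for padding.
toy_codes = ["HR", "SBP", "HR", "LAB//GLUCOSE", "HR"]
vocab = {code: i + 1 for i, code in enumerate(sorted(set(toy_codes)))}
indices = np.array([vocab[c] for c in toy_codes])
# Truncate (keeping the most recent elements) or right-pad with the padding index (0).
max_seq_len = 8
if len(indices) >= max_seq_len:
    window = indices[-max_seq_len:]
else:
    window = np.concatenate([indices, np.zeros(max_seq_len - len(indices), dtype=int)])
print(vocab)
print(window)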
MEDS TorchData
To help us out, we'll use a tool from the MEDS ecosystem that makes it easier to get the data in the right format. This tool is MEDS TorchData -- we won't go through all the details of how it works, but suffice it to say it has a couple of capabilities we can use:
- It comes with a script that will pre-process our dataset, ensuring that: (a) every code is assigned an integer index suitable for an embedding matrix, (b) numeric values are centered and scaled to have zero mean and unit variance, and (c) data are written to disk in a manner that permits easy loading in a PyTorch dataset.
- It comes with a PyTorch dataset class we can inherit and instantiate that will let us work with the data in a straightforward manner.
- It lets us integrate natively with Lightning if we want to use that to help streamline training (which we will here).
Let's install it and see it in action!
Note: Installation may take some time. Additionally, if you see any pip dependency errors about polars versions, this is not something to worry about.
%%bash
pip install --quiet meds-torch-data[lightning]
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. cudf-polars-cu12 25.6.0 requires polars<1.29,>=1.25, but you have polars 1.30.0 which is incompatible.
Data Pre-processing
After installation, the next step is to pre-process our dataset. Following the documentation, all we have to do is run the following command.
%%bash
MTD_preprocess MEDS_dataset_dir="MEDS_data" output_dir="tensorized_MEDS_data"
[2025-07-31 18:13:40,721][meds_torchdata.preprocessing.__main__][INFO] - Running in serial mode as N_WORKERS is not set.
[2025-07-31 18:13:40,722][meds_torchdata.preprocessing.__main__][INFO] - Running command: INPUT_DIR=/content/MEDS_data OUTPUT_DIR=/content/tensorized_MEDS_data MEDS_transform-pipeline --config-path=/usr/local/lib/python3.11/dist-packages/meds_torchdata/preprocessing/configs --config-name=runner pipeline_config_fp=/usr/local/lib/python3.11/dist-packages/meds_torchdata/preprocessing/configs/_MTD_preprocess.yaml ~parallelize ++do_overwrite=False
Let's see what that command produced!
%%bash
tree tensorized_MEDS_data
tensorized_MEDS_data
├── data
│   ├── held_out
│   │   └── 0.nrt
│   ├── train
│   │   └── 0.nrt
│   └── tuning
│       └── 0.nrt
├── fit_normalization
│   ├── codes.parquet
│   └── train
│       └── 0.parquet
├── fit_vocabulary_indices
├── metadata
│   └── codes.parquet
├── normalization
│   ├── held_out
│   │   └── 0.parquet
│   ├── train
│   │   └── 0.parquet
│   └── tuning
│       └── 0.parquet
└── tokenization
    ├── event_seqs
    │   ├── held_out
    │   │   └── 0.parquet
    │   ├── train
    │   │   └── 0.parquet
    │   └── tuning
    │       └── 0.parquet
    └── schemas
        ├── held_out
        │   └── 0.parquet
        ├── train
        │   └── 0.parquet
        └── tuning
            └── 0.parquet

21 directories, 15 files
There are a number of files in this folder. Luckily, we don't need to understand them all to use them -- we can let MEDS TorchData handle that for us if we set up our dataset properly. Let's check it out!
Note: If you're interested, check out this part of the documentation and the documentation for the nested-ragged-tensors package for more information!
A PyTorch Dataset
To build a PyTorch dataset, we first construct a configuration object that defines the parameters of the dataset we want to create. MEDS TorchData supports a number of options -- how to handle static data, how you want your data represented, and so on. For now, we'll use the following configuration to kick off our modeling exploration. To explore different options, check out the documentation.
from meds_torchdata import MEDSTorchDataConfig
tensorized_cohort_dir = Path("tensorized_MEDS_data")
config = MEDSTorchDataConfig(
    tensorized_cohort_dir,           # Our tensorized data directory
    max_seq_len=256,                 # How many measurements we want
    seq_sampling_strategy="to_end",  # What part of the patient's data to grab
    static_inclusion_mode="omit",    # What to do with static data
    task_labels_dir=labels_root,     # Our labels
)
Now, with our config defined, let's build a dataset and see some data!
from meds_torchdata import MEDSPytorchDataset
pyd = MEDSPytorchDataset(config, split="train")
print(f"Dataset constructed with {len(pyd)} samples!")
Dataset constructed with 74 samples!
Inspecting data in a PyTorch dataset can be a bit hard. We can simplify this by looking at full batches, which have a special print function. We can use the built-in dataloader helper to get one:
dataloader = pyd.get_dataloader(batch_size=16)
batch = next(iter(dataloader))
print(batch)
MEDSTorchBatch:
│ Mode: Subject-Measurement (SM)
│ Static data? ✗
│ Labels? ✓
│
│ Shape:
│ │ Batch size: 16
│ │ Sequence length: 256
│ │
│ │ All dynamic data: (16, 256)
│ │ Labels: torch.Size([16])
│
│ Data:
│ │ Dynamic:
│ │ │ time_delta_days (torch.float32):
│ │ │ │ [[0.00, 0.00, ..., 0.00, 0.00],
│ │ │ │ [0.00, 0.00, ..., 0.00, 0.00],
│ │ │ │ ...,
│ │ │ │ [0.00, 0.00, ..., 0.00, 0.00],
│ │ │ │ [0.00, 0.00, ..., 0.00, 0.01]]
│ │ │ code (torch.int64):
│ │ │ │ [[1980, 2150, ..., 2692, 6529],
│ │ │ │ [1985, 2007, ..., 2009, 6529],
│ │ │ │ ...,
│ │ │ │ [1983, 1995, ..., 2162, 2134],
│ │ │ │ [2770, 1982, ..., 1997, 1704]]
│ │ │ numeric_value (torch.float32):
│ │ │ │ [[-0.49, 0.00, ..., -0.10, -0.80],
│ │ │ │ [-0.71, 1.12, ..., 0.34, -0.46],
│ │ │ │ ...,
│ │ │ │ [ 0.33, -0.19, ..., 0.00, 0.00],
│ │ │ │ [-1.01, 0.00, ..., 0.50, -0.31]]
│ │ │ numeric_value_mask (torch.bool):
│ │ │ │ [[ True, False, ..., True, True],
│ │ │ │ [ True, True, ..., True, True],
│ │ │ │ ...,
│ │ │ │ [ True, True, ..., False, False],
│ │ │ │ [ True, False, ..., True, True]]
│ │
│ │ Labels:
│ │ │ boolean_value (torch.bool):
│ │ │ │ [False, False, ..., True, False]
This will be the structure we can work with in our model -- we'll have access to 2D tensors of time deltas, codes, and numeric values (with a mask indicating when numeric values were present).
Note: The time-deltas here are not normalized, and are often zero, as many measurements occur at the same point in time. This is normal, but may or may not be desired for your particular model application.
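For example, we can check how often the time delta is zero in this batch (an added snippet, not in the original notebook):
# Most time deltas are zero because many measurements share a timestamp.
nonzero_frac = (batch.time_delta_days > 0).float().mean()
print(f"Fraction of non-zero time deltas: {nonzero_frac:.3f}")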
print(batch.PAD_INDEX)
print(batch.code)
print(batch.numeric_value)
print(batch.numeric_value_mask)
0
tensor([[1980, 2150, 2708,  ..., 1982, 2692, 6529],
        [1985, 2007, 2000,  ..., 2878, 2009, 6529],
        [2885, 2407, 2203,  ..., 2216, 2206, 2207],
        ...,
        [2692, 2235, 2198,  ..., 3042, 3031, 3023],
        [1983, 1995, 1979,  ..., 2157, 2162, 2134],
        [2770, 1982, 2003,  ..., 1998, 1997, 1704]])
tensor([[-0.4875,  0.0000,  0.0000,  ...,  0.0000, -0.1005, -0.8004],
        [-0.7122,  1.1156,  0.7586,  ...,  0.0000,  0.3388, -0.4600],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0227, -0.0556, -0.3237,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3286, -0.1904, -0.5563,  ...,  0.0000,  0.0000,  0.0000],
        [-1.0095,  0.0000,  0.5741,  ..., -0.6803,  0.4985, -0.3070]])
tensor([[ True, False, False,  ..., False,  True,  True],
        [ True,  True,  True,  ..., False,  True,  True],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True, False,  True,  ...,  True,  True,  True]])
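Since PAD_INDEX marks positions with no real measurement, we can recover a per-sample padding mask directly from the codes (an added snippet, not in the original notebook):
# Positions whose code equals the padding index carry no real measurement.
pad_mask = batch.code == batch.PAD_INDEX
print(f"Padded positions per sample: {pad_mask.sum(dim=1).tolist()}")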
Building the model
Using this, let's go ahead and build our model! Given we know what the batch will contain, all that's left is to put together our model's structure.
import torch.nn as nn
import torch
from meds_torchdata import MEDSTorchBatch
class SimpleConv(nn.Module):
    """1D-convolutional model over a patient's measurement sequence.

    Pipeline
    --------
    1) Embed code indices (with padding_idx).
    2) Concatenate [code_embedding, numeric_value_mask, numeric_value] per measurement.
    3) Pass through a Conv1d stack over time -> global pooling -> fixed-size sequence embedding.
    4) Pass the sequence embedding through an MLP head to predict the label.

    Inputs (forward)
    ----------------
    batch.code: LongTensor of shape (B, T)
        Tokenized code indices, padded with `pad_index`.
    batch.numeric_value: FloatTensor of shape (B, T)
        Numeric values (e.g., lab values) aligned with `batch.code`. Should be pre-normalized.
    batch.numeric_value_mask: Bool/FloatTensor of shape (B, T)
        True/1 if numeric_value is present/valid for that measurement, else False/0.

    Output
    ------
    logits: FloatTensor of shape (B,)
        Raw (unnormalized) predictions. Apply sigmoid/softmax outside as appropriate.

    Notes
    -----
    - Padding handling: The embedding uses `padding_idx` so padded positions produce a zero vector.
      With ReLU activations and global max pooling, padded zeros will not dominate non-pad activations.
    """

    def __init__(
        self,
        *,
        code_vocab_size: int,
        pad_index: int = MEDSTorchBatch.PAD_INDEX,
        code_embed_dim: int = 128,
        conv_channels=(128, 128, 128),
        conv_kernel_sizes=(5, 5, 3),
        conv_dropout: float = 0.1,
        head_hidden_dim: int = 128,
        act: nn.Module = nn.ReLU,
    ):
        super().__init__()

        assert len(conv_channels) == len(conv_kernel_sizes), (
            "conv_channels and conv_kernel_sizes must be the same length"
        )

        self.pad_index = pad_index

        # 1) Code embedding
        self.code_emb = nn.Embedding(
            num_embeddings=code_vocab_size,
            embedding_dim=code_embed_dim,
            padding_idx=pad_index,
        )

        # 2-3) Conv stack over time
        in_ch = code_embed_dim + 2  # [code_emb || numeric_value_mask || numeric_value]
        conv_blocks = []
        for out_ch, k in zip(conv_channels, conv_kernel_sizes):
            conv_blocks += [
                nn.Conv1d(in_channels=in_ch, out_channels=out_ch, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(out_ch),
                act(),
                nn.Dropout(conv_dropout),
            ]
            in_ch = out_ch
        self.conv = nn.Sequential(*conv_blocks)
        self.pool = nn.AdaptiveMaxPool1d(1)  # -> (B, C, 1), then squeeze to (B, C)

        # 4) Prediction head
        head_in = conv_channels[-1]
        self.head = nn.Sequential(
            nn.Linear(head_in, head_hidden_dim),
            act(),
            nn.Dropout(conv_dropout),
            nn.Linear(head_hidden_dim, 1),
        )

        self._reset_parameters()

    def _reset_parameters(self):
        # Kaiming init for conv/linear; the embedding's default init is fine (padding row already handled).
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
                nn.init.zeros_(m.bias)

    def forward(self, batch: MEDSTorchBatch) -> torch.Tensor:
        # (B, T, E)
        code_e = self.code_emb(batch.code)

        # Cast the boolean mask to float so it can be concatenated as an input channel
        # (use a local variable rather than mutating the batch in place).
        numeric_value_mask = batch.numeric_value_mask.float()

        # (B, T, E+2)
        x = torch.cat(
            [code_e, numeric_value_mask.unsqueeze(-1), batch.numeric_value.unsqueeze(-1)],
            dim=-1,
        )

        # Conv1d expects (B, C, T)
        x = x.transpose(1, 2)          # -> (B, C=E+2, T)
        x = self.conv(x)               # -> (B, C_out, T)
        x = self.pool(x).squeeze(-1)   # -> (B, C_out)

        logits = self.head(x).squeeze()  # (B,)
        return logits
Let's see if it runs!
M = SimpleConv(code_vocab_size=config.vocab_size)
M(batch)
tensor([ 7.8502, 10.1843, 4.1388, 6.4581, 8.9709, 11.3923, 6.9392, 7.1074, 7.2635, 9.4971, 8.3914, 7.9199, 4.2960, 7.1975, 8.3680, 8.2609], grad_fn=<SqueezeBackward0>)
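These are raw logits; to read them as probabilities of a short ICU stay, we'd apply a sigmoid (a small usage sketch added here, not in the original notebook):
# Convert raw logits to probabilities for the positive (short LOS) class.
probs = torch.sigmoid(M(batch))
print(probs)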
Training
What about training this model? For our purposes, we'll make this step a bit simpler using Lightning. To use Lightning, we'll need to do two things:
- Convert our PyTorch dataset into a Lightning DataModule (this is handled for us by MEDS TorchData, as we'll see in a moment)
- Build a Lightning Module for our model.
For step 1, as indicated, we can use MEDS TorchData for this; let's see it in action:
from meds_torchdata.extensions.lightning_datamodule import Datamodule
dm = Datamodule(config, batch_size=32, num_workers=0, pin_memory=False)
print(next(iter(dm.train_dataloader())))
MEDSTorchBatch:
│ Mode: Subject-Measurement (SM)
│ Static data? ✗
│ Labels? ✓
│
│ Shape:
│ │ Batch size: 32
│ │ Sequence length: 256
│ │
│ │ All dynamic data: (32, 256)
│ │ Labels: torch.Size([32])
│
│ Data:
│ │ Dynamic:
│ │ │ time_delta_days (torch.float32):
│ │ │ │ [[0.00, 0.00, ..., 0.00, 0.00],
│ │ │ │ [0.00, 0.00, ..., 0.00, 0.00],
│ │ │ │ ...,
│ │ │ │ [0.00, 0.00, ..., 0.00, 0.00],
│ │ │ │ [0.00, 0.00, ..., 0.00, 0.00]]
│ │ │ code (torch.int64):
│ │ │ │ [[2885, 2059, ..., 2083, 2708],
│ │ │ │ [2739, 2666, ..., 6566, 1795],
│ │ │ │ ...,
│ │ │ │ [2278, 2057, ..., 2007, 6529],
│ │ │ │ [6529, 2198, ..., 2249, 2250]]
│ │ │ numeric_value (torch.float32):
│ │ │ │ [[ 0.00, 0.00, ..., -0.44, 0.00],
│ │ │ │ [ 0.00, 0.00, ..., 2.24, 0.18],
│ │ │ │ ...,
│ │ │ │ [ 0.60, 0.00, ..., 0.77, -0.72],
│ │ │ │ [-0.74, -0.32, ..., 0.00, 0.00]]
│ │ │ numeric_value_mask (torch.bool):
│ │ │ │ [[False, False, ..., True, False],
│ │ │ │ [False, False, ..., True, True],
│ │ │ │ ...,
│ │ │ │ [ True, False, ..., True, True],
│ │ │ │ [ True, True, ..., False, False]]
│ │
│ │ Labels:
│ │ │ boolean_value (torch.bool):
│ │ │ │ [False, False, ..., True, True]
To build our Lightning module, we just need to add a few methods to our class. To keep things clean, we'll have our Lightning module wrap our PyTorch model, constructing it from the keyword arguments it receives.
import lightning as L
from torchmetrics import Accuracy
class SimpleConvL(L.LightningModule):
    def __init__(self, lr, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        self.lr = lr
        self.model = SimpleConv(**kwargs)
        self.loss = nn.BCEWithLogitsLoss()
        self.accuracy = Accuracy(task="binary")

    def forward(self, batch: MEDSTorchBatch):
        return self.model(batch)

    def _step(self, batch: MEDSTorchBatch, split: str):
        # Shared logic for the train/val/test steps: compute loss and accuracy, log both.
        logits = self(batch)
        labels = batch.boolean_value.float()

        loss = self.loss(logits, labels)
        self.log(f"{split}_loss", loss)

        preds = (logits > 0).float()
        acc = self.accuracy(preds, labels)
        self.log(
            f"{split}_acc",
            acc,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )

        return loss

    def training_step(self, batch, _):
        return self._step(batch, "train")

    def validation_step(self, batch, _):
        return self._step(batch, "val")

    def test_step(self, batch, _):
        return self._step(batch, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
With these defined, let's try training our model!
LM = SimpleConvL(lr=1e-3, code_vocab_size=config.vocab_size)
trainer = L.Trainer(max_epochs=10, check_val_every_n_epoch=1, log_every_n_steps=1)
trainer.fit(model=LM, datamodule=dm)
INFO:pytorch_lightning.utilities.rank_zero:💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.callbacks.model_summary:
  | Name     | Type              | Params | Mode
-------------------------------------------------------
0 | model    | SimpleConv        | 1.1 M  | train
1 | loss     | BCEWithLogitsLoss | 0      | train
2 | accuracy | BinaryAccuracy    | 0      | train
-------------------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.318     Total estimated model params size (MB)
23        Modules in train mode
0         Modules in eval mode
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
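As an aside -- a sketch added to the tutorial, not part of the original -- if you wanted to keep the best checkpoint by validation loss rather than just the last epoch, you could attach Lightning's ModelCheckpoint callback (it monitors the val_loss metric we log in _step):
from lightning.pytorch.callbacks import ModelCheckpoint
# Alternative trainer setup that keeps the checkpoint with the lowest validation loss.
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
ckpt_trainer = L.Trainer(max_epochs=10, check_val_every_n_epoch=1, log_every_n_steps=1, callbacks=[checkpoint_cb])
# ckpt_trainer.fit(model=LM, datamodule=dm)  # would re-run training with checkpointing enabled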
With Lightning, we can easily test our model too:
trainer.test(model=LM, datamodule=dm)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_acc          │           0.375           │
│         test_loss         │    1.6943222284317017     │
└───────────────────────────┴───────────────────────────┘
[{'test_loss': 1.6943222284317017, 'test_acc': 0.375}]
Unsurprisingly, for our very small dataset of only ~70 samples, this model has grossly overfit. You should see (from the Lightning progress bar description) that the train accuracy has gotten very high, but the test accuracy is still quite poor. But were you to use this model on real data -- who knows!
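Accuracy at a fixed threshold is also a blunt metric for a small, imbalanced label set. A threshold-free alternative is AUROC; here's a sketch (an addition to the tutorial; it assumes the datamodule's test_dataloader() serves the held-out split with labels, as trainer.test used above):
from torchmetrics.classification import BinaryAUROC
# Score the held-out split with AUROC instead of thresholded accuracy.
auroc = BinaryAUROC()
LM.eval()
with torch.no_grad():
    for test_batch in dm.test_dataloader():
        probs = torch.sigmoid(LM(test_batch))
        auroc.update(probs, test_batch.boolean_value.long())
print(f"Held-out AUROC: {auroc.compute():.3f}")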
Key Takeaways
In this part of the tutorial, you've learned how to train a neural network model over MEDS data: how to convert the data into a representation suitable for a sequential neural network architecture using tools in the MEDS ecosystem, and how to train a simple model for a real task.