
An Introductory Guide to AI/ML Engineering

Josh Merrill

Like any other technology, the AI/ML engineering stack spans from the low-level optimization of hardware to the application layer, where developers produce highly abstracted end products. Generally, the stack is divided into two categories: 1) AI Engineering and 2) ML Engineering.

ML Engineering focuses on the lower, model-level approach to development. ML engineers train and test their models from scratch. Although model creation is a complex process—more of an art than a science—there are common themes and techniques that can be extremely helpful.

Conversely, AI Engineering operates further up the stack, where developers leverage pretrained models to solve specific application problems. AI engineers rarely build proprietary models from scratch; instead, they enhance existing models to achieve their objectives.

Building Models - ML Engineering
#

First, let’s start by taking a look at the basics of ML engineering. Our goal here is to understand the core concepts of when and why someone would opt to train their own model, the considerations behind this decision, and the tools they would use to complete this task.

Common Python Libraries
#

Nearly all AI/ML development is done in Python (for better or for worse). One of the benefits is the variety of available tools: there is typically a robust, well-maintained library for whatever task you need, so let's take a look at some of the major players.

Prior to the hype surrounding LLMs, the predecessors to ChatGPT were essentially statistical analysis tools. As a result, there are many libraries for shaping, labeling, visualizing, and otherwise handling data.

Data Analysis
#

NumPy
#

NumPy is the original scientific computing package for Python. It provides powerful tools to efficiently handle and manipulate large, multi-dimensional arrays, along with a comprehensive suite of built-in functions.

In this discussion, we’ll explore how to leverage NumPy’s capabilities to manipulate different matrices.

Let’s dive in and see how NumPy can simplify complex array operations.

Example 1: Creating and Manipulating Arrays
#
import numpy as np

# Create a matrix with the shape 3x3 (2D)
matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f"Original matrix: {matrix}")

# Reshape the matrix to a 1D array
reshaped = matrix.reshape(-1)
print(f"Reshaped matrix: {reshaped}")

# Create a zeros array with the same shape as the original array
zeros = np.zeros(matrix.shape)
print(f"Zeros matrix: {zeros}")
Example 2: Vector Ops and Feature Scaling
#
import numpy as np

# Define a sample dataset where each row is a sample and each column is a feature
data = np.array([[50, 200], [30, 180], [20, 160], [60, 220]])

# Normalizing the data between 0 and 1
min_vals = np.min(data, axis=0)
max_vals = np.max(data, axis=0)
scaled_data = (data - min_vals) / (max_vals - min_vals)

print(f"Original data: {data}")
print(f"Scaled data: {scaled_data}")
Example 3: Linear Algebra Operations in Model Implementation
#
import numpy as np

# Define a 3x2 weights matrix
weights = np.array([[0.2, 0.8], [0.5, 0.1], [0.9, 0.7]])
# Define a feature vector with shape (3,)
features = np.array([1.0, 2.0, 3.0])

# Dot product of a (3,) vector and a 3x2 matrix yields a (2,) output
output = np.dot(features, weights)
print(f"Output of linear transformation: {output}, {output.shape}")
Example 4: Filtering Data on Conditionals
#
import numpy as np

predictions = np.array([0.1, 0.4, 0.35, 0.8])
ground_truth = np.array([0, 1, 0, 1])

threshold = 0.5
high_confidence = np.where(predictions > threshold)[0]

print(f"Indices of high confidence predictions: {high_confidence}")

NumPy goes well beyond these simple examples, but the core concept stays the same: we have a series of matrices, and we need to manipulate them into shapes and values that are useful to us.

Pandas
#

Similar to NumPy, pandas is a Python module commonly used across data science and ML. However, where NumPy is aimed at performing operations on high-dimensional arrays (ndarrays), pandas is built on top of NumPy and focuses on DataFrames, tabular structures whose columns can hold different data types.

Example 1: Loading and Inspecting Data
#
import pandas as pd

# Create the dataframe object
df = pd.read_csv('data.csv')

# Display the first 5 rows of the dataframe
print(df.head())

# Get a summary of the data
print(df.info())
Example 2: Data Cleaning and Handling Missing Values
#
import pandas as pd

# Create the dataframe object
df = pd.read_csv('data.csv')

# Check for any missing values in each column
print(df.isnull().sum())

# Drop any rows with missing values
df_clean = df.dropna()

# Fill missing values with each numeric column's mean
df_filled = df.fillna(df.mean(numeric_only=True))

print(df_clean.info())
print(df_filled.info())
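
Because DataFrames mix column types, a common early step is selecting and transforming features by dtype. Here is a short sketch, assuming the same hypothetical data.csv as above contains a mix of numeric and string columns:

import pandas as pd

df = pd.read_csv('data.csv')

# Separate numeric and categorical columns by dtype
numeric_cols = df.select_dtypes(include='number').columns
categorical_cols = df.select_dtypes(include='object').columns

# Simple feature engineering: z-score the numeric columns,
# one-hot encode the categorical ones
features = pd.concat(
    [
        (df[numeric_cols] - df[numeric_cols].mean()) / df[numeric_cols].std(),
        pd.get_dummies(df[categorical_cols]),
    ],
    axis=1,
)

print(features.head())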

Data Visualization
#

Sadly, humans are very poor at reasoning over large amounts of data, especially in high-dimensional space. Instead, we need ways of breaking down findings in a form that is easy to conceptualize. One of the best ways to get useful insights from data is to visualize it. Visualizations help uncover trends and intuitions that would be hard to spot when simply staring at a large spreadsheet of numbers.

Matplotlib
#

Matplotlib is a utility for creating a range of static, animated, or interactive visualizations. When applied to machine learning, Matplotlib is great at visualizing training runs and loss history, data clusters, or statistical distributions over a dataset.

Example 1: Basic Line Plot
#
import matplotlib.pyplot as plt
import numpy as np

# Sine data
x = np.linspace(0,10,100)
y = np.sin(x)

# Create the plot
plt.plot(x, y)
plt.title("Sine Wave")
plt.xlabel("X Axis")
plt.ylabel("Y Axis")

# Show the plot
plt.show()
Example 2: Plotting Training and Validation Loss
#
import matplotlib.pyplot as plt
import numpy as np

# Example data for training and validation loss
epochs = np.arange(1, 21)
train_loss = np.exp(-epochs / 10) + np.random.normal(0, 0.05, len(epochs))
val_loss = np.exp(-epochs / 10) + 0.1 + np.random.normal(0, 0.05, len(epochs))

plt.figure(figsize=(8, 5))
plt.plot(epochs, train_loss, label='Training Loss', marker='o')
plt.plot(epochs, val_loss, label='Validation Loss', marker='s')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training vs Validation Loss')
plt.legend()
plt.grid(True)
plt.show()
Example 3: Visualizing a Confusion Matrix
#
import matplotlib.pyplot as plt
import numpy as np

# Example confusion matrix
conf_matrix = np.array([[50, 2, 1],
                        [5, 45, 5],
                        [0, 3, 47]])

fig, ax = plt.subplots(figsize=(6, 6))
cax = ax.matshow(conf_matrix, cmap=plt.cm.Blues)
plt.title('Confusion Matrix', pad=20)
fig.colorbar(cax)

# Set axis labels and ticks
classes = ['Class 1', 'Class 2', 'Class 3']
ax.set_xticks(np.arange(len(classes)))
ax.set_yticks(np.arange(len(classes)))
ax.set_xticklabels(classes)
ax.set_yticklabels(classes)
plt.xlabel('Predicted')
plt.ylabel('Actual')

# Annotate the matrix cells with counts
for (i, j), value in np.ndenumerate(conf_matrix):
    ax.text(j, i, f'{value}', ha='center', va='center', color='white' if value > 40 else 'black')

plt.show()
Example 4: Plotting Decision Boundaries for a Classifier
#
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.svm import SVC

# Generate synthetic data
X, y = make_blobs(n_samples=300, centers=3, random_state=42, cluster_std=1.0)

# Train a simple classifier
clf = SVC(kernel='linear', decision_function_shape='ovo')
clf.fit(X, y)

# Create a grid to evaluate the classifier
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 300),
                     np.linspace(y_min, y_max, 300))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

plt.figure(figsize=(8, 6))
plt.contourf(xx, yy, Z, alpha=0.3, cmap=plt.cm.coolwarm)
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolor='k', cmap=plt.cm.coolwarm)
plt.title('Decision Boundaries of SVM Classifier')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()

Building Models
#

So far, we have gone over some of the building blocks of scientific computing in Python. These libraries are foundational and enable a developer to create small-scale machine learning models. However, machine learning models are compute hungry: the more operations they can perform, the better. Moreover, since Python is an interpreted language, its runtime speed is slow, which limits how far standard libraries like NumPy can be pushed when powering large-scale ML models.

To solve this issue, there have been multiple efforts to develop "autograd" libraries that automatically perform gradient computations, the mathematical operation underpinning the training of neural networks. There are tons of frameworks to choose from but, in my eyes, the two best are PyTorch and JAX.

At the end of the day, these two libraries serve the same purpose. However, the common refrain in the industry is that PyTorch is used for research while JAX is used for production.

PyTorch
#

PyTorch is an open-source machine learning framework developed by Meta, with an extremely robust open-source community supporting it and sub-packages for different data modalities such as torchvision and torchaudio. Notably, PyTorch builds and maintains a computation graph at runtime that records the operations performed on tensors. In practice, this means PyTorch keeps a graph representation of your neural network's forward pass. Whenever .backward() is called on a target tensor, PyTorch traverses this graph in reverse order to compute the gradients during backpropagation.
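
To make that autograd behavior concrete, here is a minimal sketch (not tied to any particular model) showing PyTorch recording operations on a tensor and computing gradients when .backward() is called:

import torch

# Create a tensor with gradient tracking enabled
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Build a small computation; PyTorch records these ops in the autograd graph
y = (x ** 2).sum()

# Traverse the recorded graph in reverse to compute dy/dx
y.backward()

# dy/dx = 2x, so this prints tensor([2., 4., 6.])
print(x.grad)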

Example: CIFAR Classification
#

Frameworks such as PyTorch have widespread applications, from training vision to audio to language models. However, the underlying principle is the same: perform lots of matrix multiplications quickly and easily. In this example we are going to look at a simple proof of concept: building a small neural network to classify images from the CIFAR-100 dataset.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import numpy as np

# Set hardcoded hyperparameters
batch_size = 64
epochs = 10
lr = 0.01
momentum = 0.9
seed = 1

# Set random seed for reproducibility and device
torch.manual_seed(seed)

# Get the device to use
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

# Data augmentation and normalization for training
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4), # Crop a region of the image
    transforms.RandomHorizontalFlip(), # Randomly flip the image, probability = 0.5
    transforms.ToTensor(), # Convert a PIL image or np ndarray to torch.tensor
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), # Use normalization values Ex. each red pixel, subtract 0.5071 and divide by 0.2675
])

# Normalization for validation
transform_test = transforms.Compose([
    transforms.ToTensor(), # Convert to tensor
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), # Apply same normalization
])

# Load CIFAR-100 dataset
trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Define the model (ResNet18 adjusted for 100 classes)
model = models.resnet18(num_classes=100)
model = model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

# Lists to store loss values for each epoch
train_losses = []
val_losses = []

# Main training and validation loop
for epoch in range(1, epochs + 1):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        num_batches += 1
        
    train_loss_epoch = running_loss / num_batches
    train_losses.append(train_loss_epoch)
    
    # Evaluate on validation set
    model.eval()
    val_loss = 0.0
    num_batches_val = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            num_batches_val += 1
            
            # Compute accuracy
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    val_loss_epoch = val_loss / num_batches_val
    val_losses.append(val_loss_epoch)
    accuracy = 100. * correct / total
    
    print(f"Epoch {epoch}/{epochs}: Train Loss: {train_loss_epoch:.4f} | Val Loss: {val_loss_epoch:.4f} | Val Acc: {accuracy:.2f}%")

# Plot the training and validation losses
plt.figure(figsize=(10, 5))
plt.plot(np.arange(1, epochs+1), train_losses, label='Training Loss')
plt.plot(np.arange(1, epochs+1), val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss per Epoch')
plt.legend()
plt.grid(True)
plt.show()

JAX
#

JAX is a production-oriented (super fast) machine learning framework that follows a functional programming design and automatically handles details such as JIT compilation (speeding up your Python code) and sharding data/models across different nodes. JAX requires its functions to be pure and its variables to be immutable to enable this functional style.
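
As a small illustration of that functional style (a sketch independent of the larger example below), gradients and JIT compilation in JAX are applied to pure functions rather than to stateful objects:

import jax
import jax.numpy as jnp

# A pure function: the output depends only on its inputs, nothing is mutated
def loss_fn(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# jax.grad returns a new function computing d(loss)/dw; jax.jit compiles it
grad_fn = jax.jit(jax.grad(loss_fn))

w = jnp.array([0.5, -0.5])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])

print(grad_fn(w, x, y))  # Gradient of the loss with respect to w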

Example: CIFAR Classification
#
import tensorflow_datasets as tfds  # TFDS for dataset loading
import tensorflow as tf  # TensorFlow for data processing

# Set random seed for reproducibility
tf.random.set_seed(0)

# Hyperparameters for training
num_epochs = 10
batch_size = 32

# Load the CIFAR-100 dataset from TFDS for training and testing splits
train_ds: tf.data.Dataset = tfds.load('cifar100', split='train')
test_ds: tf.data.Dataset = tfds.load('cifar100', split='test')

# Preprocessing: Normalize image pixel values to the [0, 1] range.
train_ds = train_ds.map(
  lambda sample: {
    'image': tf.cast(sample['image'], tf.float32) / 255.0,
    'label': sample['label'],
  }
)
test_ds = test_ds.map(
  lambda sample: {
    'image': tf.cast(sample['image'], tf.float32) / 255.0,
    'label': sample['label'],
  }
)

# Prepare the training dataset:
# - Repeat the dataset for the number of epochs.
# - Shuffle with a buffer of 1024 for randomness.
# - Batch the data and drop the last batch if it's not full.
# - Prefetch one batch to optimize latency.
train_ds = train_ds.repeat(num_epochs).shuffle(1024)
train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)

# Prepare the test dataset:
# - Shuffle, batch, and prefetch similarly to training dataset.
test_ds = test_ds.shuffle(1024)
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)

from flax.experimental import nnx  # NNX API from Flax
from functools import partial

class CNN(nnx.Module):
  """A simple Convolutional Neural Network model for CIFAR‑100 classification."""
  
  def __init__(self, *, rngs: nnx.Rngs):
    # First convolution: Input has 3 channels (RGB) -> 32 filters, 3x3 kernel
    self.conv1 = nnx.Conv(3, 32, kernel_size=(3, 3), rngs=rngs)
    # Second convolution: 32 channels -> 64 filters, 3x3 kernel
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    # Average pooling layer with a 2x2 window and stride of 2
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    # Fully connected layer:
    # With two rounds of pooling, image size reduces from 32x32 to 8x8.
    # Thus, the flattened input size is 8*8*64 = 4096.
    self.linear1 = nnx.Linear(4096, 256, rngs=rngs)
    # Final layer outputs 100 logits (one per CIFAR‑100 class)
    self.linear2 = nnx.Linear(256, 100, rngs=rngs)

  def __call__(self, x):
    # Apply first convolution and ReLU activation
    x = nnx.relu(self.conv1(x))
    # Apply average pooling to reduce spatial dimensions
    x = self.avg_pool(x)
    # Apply second convolution and ReLU activation
    x = nnx.relu(self.conv2(x))
    # Apply second average pooling
    x = self.avg_pool(x)
    # Flatten the tensor for the fully connected layers
    x = x.reshape(x.shape[0], -1)
    # Apply the first linear layer and ReLU activation
    x = nnx.relu(self.linear1(x))
    # Final linear layer to produce logits for each class
    x = self.linear2(x)
    return x

# Instantiate the model with a random seed
model = CNN(rngs=nnx.Rngs(0))
# Display the model architecture
nnx.display(model)

import jax.numpy as jnp  # JAX NumPy for numerical operations

# Test a forward pass with a dummy input (batch size 1, 32x32 RGB image)
y = model(jnp.ones((1, 32, 32, 3)))
nnx.display(y)

import optax  # Optimizer library

# Define optimizer hyperparameters
learning_rate = 0.005
momentum = 0.9

# Set up the optimizer (AdamW in this case) with our model parameters
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
# Define metrics to track accuracy and average loss
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(), 
  loss=nnx.metrics.Average('loss'),
)

# Display optimizer details
nnx.display(optimizer)

# Define a loss function that computes softmax cross entropy and returns logits
def loss_fn(model: CNN, batch):
  logits = model(batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']
  ).mean()
  return loss, logits

# Define a training step (compiled with JIT for performance)
@nnx.jit
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """Perform a single training step: forward pass, loss calculation, gradient computation, and parameter update."""
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  # Update metrics with current loss and predictions
  metrics.update(loss=loss, logits=logits, labels=batch['label'])
  # Update model parameters using the computed gradients
  optimizer.update(grads)

# Define an evaluation step (also compiled with JIT)
@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
  """Perform a single evaluation step: compute loss and update evaluation metrics."""
  loss, logits = loss_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])

# Reset TensorFlow random seed for consistent shuffling
tf.random.set_seed(0)

# Determine the number of steps per epoch from the training dataset's cardinality
num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs

# Initialize a dictionary to record metrics over training epochs
metrics_history = {
  'train_loss': [],
  'train_accuracy': [],
  'test_loss': [],
  'test_accuracy': [],
}

# Training loop: iterate over the training dataset
for step, batch in enumerate(train_ds.as_numpy_iterator()):
  # Perform a single training step
  train_step(model, optimizer, metrics, batch)

  # Check if one epoch has passed
  if (step + 1) % num_steps_per_epoch == 0:
    # Compute and record training metrics for the current epoch
    for metric, value in metrics.compute().items():
      metrics_history[f'train_{metric}'].append(value)
    # Reset metrics before evaluation
    metrics.reset()

    # Evaluate the model on the test dataset
    for test_batch in test_ds.as_numpy_iterator():
      eval_step(model, metrics, test_batch)

    # Record test metrics for the current epoch
    for metric, value in metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)
    # Reset metrics for the next epoch
    metrics.reset()

    # Print training and test results for the current epoch
    print(
      f"train epoch: {(step+1) // num_steps_per_epoch}, "
      f"loss: {metrics_history['train_loss'][-1]}, "
      f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
    )
    print(
      f"test epoch: {(step+1) // num_steps_per_epoch}, "
      f"loss: {metrics_history['test_loss'][-1]}, "
      f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
    )

import matplotlib.pyplot as plt  # Library for plotting

# Plot loss and accuracy over epochs in subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')
for dataset in ('train', 'test'):
  ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
  ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
ax1.legend()
ax2.legend()
plt.show()

Checkpointing
#

If you have looked over the code from the previous examples, you may have noticed that at the end of the scripts, we are not saving our model anywhere. This means all that time we spent training the model has been wasted since we don’t have any way of using it! This is where checkpointing comes into play.

Checkpointing is the functionality of reading and writing a representation of a model to disk. Different frameworks store models in different formats (which can also introduce vulnerabilities!). The frequency of checkpointing is up to you: depending on how long your training run is, you might want to save a version of the model every 1, 5, or 10 epochs.

PyTorch
#

In PyTorch, models expose a state dictionary (state_dict), a dictionary object that maps each layer (usually a custom or PyTorch class) to its parameters (usually torch.Tensor objects).

import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 2)
    
    def forward(self, x):
        return self.fc(x)

# Instantiate the model
model = SimpleModel()

# Save the model's state_dict to disk
torch.save(model.state_dict(), 'simple_model.pth')
print("Model saved!")

# To load the model, first create an instance and then load the state_dict
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load('simple_model.pth'))
loaded_model.eval()  # Set the model to evaluation mode
print("Model loaded and ready for inference!")

JAX
#

JAX stores its models in a tree-like representation called a pytree, but at the end of the day the concept stays the same: JAX maps each layer and class to its data. However, with pytrees, JAX can store a wider range of data, such as arbitrary arrays, dictionaries, metadata, and configs, rather than solely the layers and their parameters. Another library, Orbax, is the standard for handling checkpointing with JAX. For a comprehensive guide on checkpointing with Orbax and NNX, check out the docs here!

# Imports needed for this example (JAX, Flax Linen, Optax, and Orbax)
import os

import numpy as np
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
import optax
import orbax.checkpoint
from flax.training import train_state

# A simple model with one linear layer.
key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))      # A simple JAX array.
model = nn.Dense(features=3)
variables = model.init(key2, x1)

# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
# Define your class with `@flax.struct.dataclass` decorator to make it compatible.
tx = optax.sgd(learning_rate=0.001)      # An Optax SGD optimizer.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)
# Perform a simple gradient update similar to the one during a normal training workflow.
state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))

# Some arbitrary nested pytree with a dictionary and a NumPy array.
config = {'dimensions': np.array([5, 3])}

# Bundle everything together.
ckpt = {'model': state, 'config': config, 'data': [x1]}
print(ckpt)

# ----------------

from flax.training import orbax_utils

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save('/tmp/flax_ckpt/orbax/single_save', ckpt, save_args=save_args)

options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(
    '/tmp/flax_ckpt/orbax/managed', orbax_checkpointer, options)

# Inside a training loop
for step in range(5):
    # ... do your training
    checkpoint_manager.save(step, ckpt, save_kwargs={'save_args': save_args})

os.listdir('/tmp/flax_ckpt/orbax/managed')  # Because max_to_keep=2, only step 3 and 4 are retained

# ---------------- Restoring Checkpoints

raw_restored = orbax_checkpointer.restore('/tmp/flax_ckpt/orbax/single_save')
step = checkpoint_manager.latest_step()  # step = 4
checkpoint_manager.restore(step)

Datasets
#

Data, and datasets, are a crucial aspect of the machine learning pipeline. As the old adage goes: "Garbage in, garbage out." Curating, cleaning, and processing data is a hefty task commonly referred to as "data pipelining." For example, unsupervised training for image classification is not nearly as effective as training on a labeled dataset. However, it takes real effort to find images of all the classes you want to classify, make sure they are all the same size with normalization applied, and then have someone (usually a person) go in and manually label each of the images. Really just a big pain.

In addition, the type or modality of data you use will change based on the context of the problem. Image classification uses images, text models need text, and other specialized models may need a niche data type, for example, weather or geographical data. Having a handle on what type of data a model is well suited to analyzing will prove a useful skill when developing your tools. Given our context of offensive security, we may collect information such as traffic captures, command and tool output, LDAP data, etc. Formatting this data so it is easily digested by the model of choice is a primary task before diving into developing the tools.
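
As a hedged illustration of that formatting step (the field names, process names, and labels here are all hypothetical), a raw process listing might be converted into labeled JSON Lines records that a classifier or LLM can consume:

import json

# Hypothetical raw tool output: one process per line as "pid,name"
raw_output = """1234,MsMpEng.exe
5678,chrome.exe
9012,elastic-endpoint.exe"""

# Hypothetical label map built from prior knowledge of EDR process names
edr_processes = {"MsMpEng.exe": "Defender", "elastic-endpoint.exe": "Elastic"}

records = []
for line in raw_output.splitlines():
    pid, name = line.split(",")
    records.append({
        "pid": int(pid),
        "process": name,
        "label": edr_processes.get(name, "none"),  # weak label for training
    })

# Write JSON Lines, a format most data loaders ingest easily
with open("processes.jsonl", "w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")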

Collection
#

Hopefully, a researcher has already compiled a useful dataset for you to use. Yet offensive security does not have a lively academic community, so much of the data collection may have to be done manually. The internals of data collection and labeling are quite intricate and a bit out of scope for this post. However, some helpful resources on the topic can be seen here:

  • Data Labeling: The Authoritative Guide
  • Text Classification For Machine Learning: https://developers.google.com/machine-learning/guides/text-classification/step-1

Evals
#

Once you finish training your model, you're going to need a method of testing how good it actually is. For simple tasks this could be straightforward: how good is my model at classifying pictures of hot dogs? How well does my model prevent malware from executing? However, for more complex tasks, such as reasoning and math, the method of evaluation is not so simple. Hence the need for a set of standardized benchmarks that different models/agents can be graded against. The following is a set of common LLM evals and their evaluation category:

  • MMLU: General knowledge & reasoning
  • GSM8K: Math (grade school arithmetic)
  • SWE-bench: Real-world coding tasks
  • HellaSwag: Commonsense reasoning
  • ARC: Reading comprehension & reasoning
  • GPQA: Graduate-level science questions
  • HumanEval: Coding (program synthesis)
  • MBPP: Coding (Python programming)
  • TruthfulQA: Truthfulness & factual accuracy
  • MT-Bench: Conversational/chatbot evaluation

Evals in cybersecurity
#

Again, cybersecurity is an under-explored area of application for ML, let alone offensive security, so the ecosystem of security-specific evaluations is lackluster. There are a few available evals, most notably Cybench. However, the main barometer for model performance along the security axis seems to be a model's ability to solve CTF or lab (e.g., PortSwigger) challenges.

A good unified offensive security eval is a needed area of research, but maybe just not by us :).

Weights & Biases (wandb)
#

Weights & Biases is a software package and web UI for extracting and displaying logs from training runs. It records important details such as metrics, parameters, and outputs, and ships them to a web instance for visualization. The implementation is straightforward: simply import the package and include a call to the logging function during your training steps. Here is an example of a simple dashboard using W&B:

[Image: example Weights & Biases dashboard]
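
For reference, a minimal logging sketch looks roughly like the following (the project name, config values, and metric names are placeholders):

import wandb

# Start a run; config values show up in the dashboard for comparison
wandb.init(project="cifar-demo", config={"lr": 0.01, "epochs": 10})

for epoch in range(10):
    train_loss = 1.0 / (epoch + 1)  # stand-in for a real training loop
    wandb.log({"epoch": epoch, "train_loss": train_loss})

wandb.finish()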

When Would ML Engineering be Applicable?
#

The use cases for ML engineering vary. On one hand, there are mega labs with near-infinite budgets and resources building the modern SOTA models with varied and generalizable use cases. On the other hand, there are many cases where a small, easily trained model can get a specific task done within the context of a bigger application. Are we looking to train and reproduce a SOTA LLM from scratch? No, not even close. Could we design a small model that classifies whether a machine is running EDR, and if so which EDR, based on the available process list? Yes we could!

In short, ML engineering is the process of making a machine learning model from scratch, ranging from collecting and cleaning the data to implementing the model architecture and training, storing, and testing it. There are many tools to get used to in ML engineering, but sticking to popular, well-supported tools is a great place to be.

It’s also important to realize the scope of an issue. We aren’t looking to re-invent the wheel. We want to stand on the shoulders of giants and take advantage of the explosion of research and innovation that labs, universities, and independent researchers have contributed to the open community.

Building on top of Models - AI Engineering
#

Now that we have a basic understanding of the traditional ML side, let's shift our focus to the recently blooming field of AI engineering. AI engineering is simply working our way up the stack and trying to answer the question: "We have these really great and powerful models, how can we get them to solve real-world problems?" Some people tend to dismiss AI engineering as nothing more than crafting simple wrappers for models, underestimating its true value. I believe this perspective is both short-sighted and misinformed. While it may not be as glamorous as building your own model from scratch to solve a specific problem, it's unrealistic, and perhaps a bit vain, to expect an individual to compete with companies that have hundreds of billions of dollars in funding behind them. Fully discounting a field of technology out of an aversion to leaning on someone else's product is going to put you behind the curve of innovation. Embracing the new changes and paradigms is crucial to staying fast, lean, and competitive.

In its modern state, the art of AI engineering is focused on pulling as much efficiency out of a model as possible by enhancing, finetuning, or augmenting its capabilities. There are many ways to make an efficient system, so let's look at some of the highlights.

RAG
#

What is RAG, why is it useful?
#

Retrieval Augmented Generation, or RAG, is a technique for increasing the knowledge base of an LLM without the need for finetuning. The reason RAG works so effectively is fairly straightforward. During training, a model learns some attributes about each token it sees in its dataset, but those specifics get blended across all the tokens and weights, so pulling precise information back out of the learned weights can be difficult. Alternatively, data placed into the context window during a call to the model is directly accessible and useful for generating output, thanks to the transformer's attention mechanism. To give an analogy, recalling specific facts about the life of your favorite president can be difficult if done purely from memory. However, if you can look up the facts in their Wikipedia article, your performance at answering Abraham Lincoln related questions is going to increase significantly. This is the core of RAG: we let a model look up information related to its current task to inform it on the specifics.

The Practical Bits
#

Implementing RAG is fairly straightforward in the simplest case, though there are many advanced variants at your disposal. At the end of the day, each RAG implementation is going to need the same fundamental pieces (a minimal retrieval sketch follows the list below):

  • Data (you bring this)
  • Chunking mechanism
  • Vector database
  • Search mechanism
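
To tie these pieces together, here is a deliberately minimal retrieval sketch using plain NumPy, with a placeholder embed() function standing in for a real embedding model and vector database (so the scores here are meaningless; the point is the shape of the pipeline):

import numpy as np

def embed(text: str) -> np.ndarray:
    # Placeholder: a real system would call an embedding model here
    rng = np.random.default_rng(abs(hash(text)) % (2**32))
    return rng.normal(size=384)

# 1. Chunked documents (you bring the data and the chunking strategy)
chunks = [
    "Milvus is a vector database used for semantic search.",
    "LoRA is a parameter-efficient finetuning method.",
    "nmap is a network scanning tool.",
]

# 2. "Vector database": a matrix of chunk embeddings
index = np.stack([embed(c) for c in chunks])

# 3. Search: cosine similarity between the query and every chunk
query = embed("How do I scan a network?")
scores = index @ query / (np.linalg.norm(index, axis=1) * np.linalg.norm(query))

# 4. The top chunk would be prepended to the LLM prompt as context
best = chunks[int(np.argmax(scores))]
print(best)
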
Choosing your data
#

The data you bring is just as important as the data you train on. If the data is unstructured and messy, it's going to be harder for the search mechanism to extract the relevant information to send to the model. Generally speaking, consistently formatted markdown with clear natural-language explanations is the best and most consistent option.

Choosing a database
#

There are many options for vector databases. In my experience, Milvus has been easy to set up, scale, and interact with.

Milvus
#

Milvus is a vector database with built-in features for semantic search, data tagging, and processing, along with a strong Python SDK. Milvus natively offers APIs for interacting with a series of embedding models, reducing the pain of wrangling different libraries. For getting started with Milvus, check out this guide: https://milvus.io/docs/quickstart.md

Other Options
#

Although I don't have personal experience with other vector DB options, I have heard good things about several of the alternatives.

Choosing chunking method
#

Chunking data for RAG is often the most difficult piece to get right. Chunking is the process of determining how to split your documents into smaller sections that get vectorized and stored in the database. There are a couple of heuristics that can be used for chunking, such as splitting on paragraphs (\n\n) or on markdown headers (#). However, there is no silver bullet, and it will require some experimentation to best fit your data and RAG method.
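
For example, a simple heuristic chunker that splits markdown on headers and then falls back to paragraphs might look like the following (a sketch; real implementations usually add overlap and smarter size limits):

def chunk_markdown(text: str, max_chars: int = 1000) -> list[str]:
    chunks = []
    # First split on markdown headers, then fall back to paragraph splits
    for section in text.split("\n# "):
        if len(section) <= max_chars:
            chunks.append(section.strip())
            continue
        for paragraph in section.split("\n\n"):
            if paragraph.strip():
                chunks.append(paragraph.strip())
    return [c for c in chunks if c]

doc = "# Intro\nSome text.\n\n# Usage\nFirst paragraph.\n\nSecond paragraph."
for chunk in chunk_markdown(doc, max_chars=20):
    print(repr(chunk))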

Finetuning
#

Finetuning is the current hotness in the AI community and the underpinning of breakthroughs dating from the original release of ChatGPT and GPT-3 to the recent advancements of DeepSeek R1. Reinforcement Learning (RL) in particular has been used to improve a model's performance at a given niche task, such as reasoning, math, coding, or even extremely specialized tasks such as front-end generation. The current meta for finetuning is techniques such as LoRA (parameter-efficient finetuning) and GRPO (an RL algorithm) to induce the ability to reason about a given task.
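
As a rough sketch of what a parameter-efficient finetuning setup looks like in practice, here is LoRA configured with Hugging Face's peft library; the base model, target module names, and hyperparameters are illustrative and model-dependent:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Load a pretrained base model (gated repo; any causal LM works for the sketch)
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")

# LoRA: freeze the base weights and train small low-rank adapter matrices
config = LoraConfig(
    r=8,                                   # rank of the adapter matrices
    lora_alpha=16,                         # scaling factor
    target_modules=["q_proj", "v_proj"],   # which layers get adapters (model-dependent)
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

model = get_peft_model(base, config)
model.print_trainable_parameters()  # only a tiny fraction of weights are trainable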

Tool Use
#

In order to enhance the capabilities of a model, tool use has emerged as an essential strategy for bridging the gap between a model’s inherent knowledge and real-world tasks.

Tool use is a behavior instilled in a model through reinforcement learning, where the model learns to structure its output in the format of a tool call. Tools can be defined in a variety of ways depending on the framework in use. However, at the end of the day, tools are typically described as JSON similar to the following:

{
    "type": "function",
    "function": {
        "name": "get_current_temperature",
        "description": "Gets the temperature at a given location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The location to get the temperature for"
                }
            },
            "required": [
                "location"
            ]
        }
    }
}
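
On the application side, the model emits a structured call that your code parses and dispatches. Here is a minimal sketch; the exact shape of the model's output varies by provider, so treat this format as illustrative:

import json

def get_current_temperature(location: str) -> str:
    # Stand-in implementation; a real tool would call a weather API
    return f"It is 21 C in {location}."

TOOLS = {"get_current_temperature": get_current_temperature}

# Hypothetical tool call emitted by the model as JSON
model_output = '{"name": "get_current_temperature", "arguments": {"location": "Berlin"}}'

call = json.loads(model_output)
result = TOOLS[call["name"]](**call["arguments"])

# The result is then fed back to the model as a tool/function message
print(result)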

Agents
#

What are agents?
#

I ❤️ agents. Agents allow an LLM to power a system that goes and performs actions out in the world. There is great research into determining the best structure for designing agents. However, at the end of the day, an agent is a system that allows a model to keep track of its actions, plan what it wants to do, and access tools to get those actions done.

For an overview of what agents are, check out the Agents paper from Google.

What are good agents and when should I use them?
#

In my eyes, all agents are good agents :). However, in reality, there are design patterns that allow agents to achieve better success. Anthropic released a great paper detailing effective design patterns for building agents. The keys to making strong agents are as follows:

  1. Know when to use an agent and when not to. Agents are best used when there is a complex problem that is hard to model statically with code. If you are going to need a huge web of if/else statements, then consider using an agent.
  2. Give your agent guardrails. Don’t give your agent unrestricted access to bash. Not only would this be a huge security risk, but without defined uses and tools, agent performance will decrease. Instead, consider giving your agent access to specific tools with templates on how to use them. For example, instead of having a model run an nmap scan from bash, create a run_nmap_scan tool that templates the command for the agent.
  3. Keep it as simple as possible. You might not need an agent for your use case. Try to keep your solution absolutely as simple as possible. Take the following section from Anthropic’s post:
When building applications with LLMs, we recommend finding the simplest solution possible, and only increasing complexity when needed. This might mean not building agentic systems at all. Agentic systems often trade latency and cost for better task performance, and you should consider when this tradeoff makes sense.

When more complexity is warranted, workflows offer predictability and consistency for well-defined tasks, whereas agents are the better option when flexibility and model-driven decision-making are needed at scale. For many applications, however, optimizing single LLM calls with retrieval and in-context examples is usually enough.

Notable Agent Frameworks
#

  • smolagents (https://github.com/huggingface/smolagents): Barebones library with automatic tool creation. Implements ideas from the ReAct paper by using a Python interpreter at each step of execution.
  • CrewAI (https://www.crewai.com/): Allows for orchestration of multiple agents (hence "crew") with designated roles.
  • browser-use (https://github.com/browser-use/browser-use): Dedicated agent framework for interacting with websites through a browser, like a human would.

Developer Environments
#

Luckily, creating a developer environment is very straightforward. There is little setup needed since 99% of machine learning is done in Python. As long as you can install Python, you should be good to go.

Of course, the main constraint on development is access to GPUs. Typically, during development, a researcher might have access to a single card in a local device and perform quick-and-dirty POCs locally, while a machine (or pod of machines) with many chips provides large-scale compute. These machines often run a remote execution environment, such as Jupyter, that a developer can connect to from their client to access the extended compute.

Local Development
#

Local setup is quite simple and is up to the developer to implement. Personally, I like to keep all my venvs in my home directory.

mkdir ~/.venvs
python -m venv ~/.venvs/machine-learning
source ~/.venvs/machine-learning/bin/activate

All that is left is installing your dependencies (hopefully in a requirements.txt) and using your code editor of choice.

Remote/ Shared Development
#

A remote development environment is typically going to be a Jupyter Notebook/Lab environment. In my home lab, I have a Debian box that runs a notebook server.

# 1. Update package lists and install Python 3, pip, and venv
sudo apt update
sudo apt install -y python3 python3-pip python3-venv

# 2. Create a virtual environment (optional but recommended)
python3 -m venv jupyter-env
source jupyter-env/bin/activate

# 3. Upgrade pip and install Jupyter along with ipykernel
pip install --upgrade pip
pip install jupyter ipykernel

# 4. Create a new IPython kernel named "machine-learning"
python -m ipykernel install --user --name machine-learning --display-name "machine-learning"

# 5. Start the Jupyter server (notebook or lab as preferred)
jupyter notebook
# or, if you prefer JupyterLab:
# jupyter lab

Once the server is running, forward the port of the server to your local machine to access it.

ssh -L 8888:localhost:8888 user@machine

You can either access the server through the browser at localhost:8888 (Jupyter's default port) or through a client of your choice. I prefer to use VSCode/Cursor with the Jupyter extension.

Useful Commands
#

The following is a series of useful commands when working with CUDA devices or machine learning systems in general. I often find myself needing to monitor the amount of memory allocated on my GPU and fighting with pip errors :,).

  • GPU and CUDA Tools:
    • nvcc --version: Displays the version of the CUDA compiler, ensuring your CUDA toolkit is correctly installed.
    • nvidia-settings: Opens the NVIDIA settings GUI for more detailed configuration and monitoring of your GPU.
    • nvidia-smi: Provides a real-time, more granular monitoring of GPU performance metrics.
  • System Resource Monitoring:
    • htop or top: Monitor CPU and memory usage in real time.
    • free: Quickly check available system memory.
    • lscpu: Get detailed information about your CPU architecture.
    • df: Check disk usage and available space.
  • Profiling and Debugging Tools:
  • Package and Environment Management:
    • pip list or conda list: Verify installed Python packages and their versions.
    • Environment-specific commands like python -m pip freeze help capture dependencies for reproducibility.
  • Network and I/O Monitoring (for distributed training scenarios):
    • iftop or nload: Monitor network bandwidth usage.
    • iotop: Monitor disk I/O, which can be critical if you’re working with large datasets.

Running a model locally
#

There are many instances when you’re going to want to run a model locally, either on your personal machine, or on a beefier machine that you own. Luckily there are many options to do this! However, keep in mind the constraints to run a model. In order to get a model to run, we need to have enough VRAM (for running on GPU) or RAM (for running on CPU).

Ollama
#

Ollama is by far the easiest solution for running a model locally. It only takes a few commands to get the server started.

Ollama + Docker
#

The easiest way to run Ollama cross-platform is by using the prepackaged Docker container:

sudo apt update
sudo apt install -y docker.io

sudo docker pull ollama/ollama:latest

sudo docker run -d \
  --name ollama-server \
  -v ollama:/root/.ollama \
  -p 11434:11434 \
  ollama/ollama:latest

# Pull and run a model inside the container
sudo docker exec -it ollama-server ollama run llama3.1:8b
Bare Metal
#

You can also run Ollama on bare metal; the install page can be found here: https://ollama.com/download. Once installed, you can run the server as follows:

ollama serve
ollama pull llama3.1:8b
ollama run llama3.1:8b

Once the server is running, either from Docker or bare metal, you can interact with it using your client of choice. Here is a simple Python script using the openai package:

from openai import OpenAI

# Point the client at the local Ollama server's OpenAI-compatible endpoint
client = OpenAI(
    base_url="http://localhost:11434/v1",
    api_key="ollama",  # Ollama does not check the key, but the client requires one
)

try:
    response = client.chat.completions.create(
        model="llama3.1:8b",  # Must match the model name pulled with `ollama pull`
        messages=[{"role": "user", "content": "Hello, world!"}],
        max_tokens=100,
    )
    print("Response from Ollama server:")
    print(response.choices[0].message.content)
except Exception as e:
    print(f"An error occurred: {e}")
Remote Ollama server
#

By default, Ollama serves its API on port 11434. If the Ollama server is remote, it's recommended to forward the port locally:

ssh -L 11434:localhost:11434 user@host

Hugging Face Transformers
#

Hugging Face's transformers library is a great tool for getting different transformer models spun up quickly. Removing a full abstraction layer, such as Ollama, allows more granular access to the model, tokenizer, inputs, and outputs. We won't go over these features here, but please feel free to experiment on your own.

Sample script to run inference on Llama 3.1 8B:

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

def main():
    model_name = "meta-llama/llama3.1-8b"  

    print(f"Loading tokenizer and model from '{model_name}' ...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

    # Create a text-generation pipeline using the loaded model and tokenizer
    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

    prompt = "Hello, world!"
    print(f"Generating text for prompt: {prompt}")
    
    # Generate text with a maximum length of 100 tokens and sampling enabled
    results = generator(prompt, max_length=100, do_sample=True, temperature=0.7)
    
    for i, result in enumerate(results):
        print(f"\n--- Generated Output {i+1} ---")
        print(result["generated_text"])

if __name__ == "__main__":
    main()

Custom code!
#

Of course, you can write your own implementation for your model. However, this requires looots of dev work to convert the paper and technical report of a model to use the open weights. If you’re feeling up for a challenge, I’m not stopping you :^).

A Quick Note on Quantization
#

Quantization is the process of reducing the amount of memory needed per parameter in a network. Specifically, lower-precision data types such as BF16 or FP8, even down to 4- and 2-bit types, are used to store the value of each parameter in the network. Simply put, the research community found empirically that model performance stays relatively constant even when less precise data types are used. Quantized models are able to achieve similar performance to their higher-precision counterparts while dramatically reducing both memory usage and computational requirements. This efficiency boost comes from storing model parameters in lower-precision formats, which in turn leads to faster inference times and reduced power consumption. As a result, quantized models are particularly advantageous for deployment on resource-constrained devices like mobile phones, embedded systems, and edge computing platforms.

How much compute do I need to do X?
#

If you’re as GPU poor as I am, you will be constantly asking yourself this question. Typically the answer is “more than I have access to”, but with tricks like quantization and suffering through low token throughput, you can run different size models.

VRAM
#

VRAM is the main constraint when it comes to running a model. If you don’t have enough VRAM, you can’t run the model. Here is a helpful calculator from Hugging Face to determine if running a model is possible: https://huggingface.co/spaces/NyxKrage/LLM-Model-VRAM-Calculator
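
As a rough rule of thumb, the weights alone take (parameter count × bytes per parameter), plus overhead for activations and the KV cache. A quick back-of-the-envelope sketch (the 20% overhead factor is an assumption, not a precise figure):

def approx_vram_gb(num_params_billion: float, bytes_per_param: float, overhead: float = 1.2) -> float:
    # Weights only, with a ~20% fudge factor for activations and KV cache
    return num_params_billion * bytes_per_param * overhead

for name, bytes_per_param in [("FP16", 2), ("INT8", 1), ("4-bit", 0.5)]:
    print(f"8B model @ {name}: ~{approx_vram_gb(8, bytes_per_param):.1f} GB")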

Token Throughput
#

Once a model is successfully loaded into memory, the speed of the hardware dictates how many tokens you can generate per second. With consumer-grade GPUs, you can probably expect 20-30 tokens per second for small (~8B) quantized models, about 3-5 tokens per second for medium (30-70B) quantized models, and <1 token per second for large (400B+) parameter models.

Glossary
#

  • vector: A one dimensional list.
    • Ex. [0, 1, 2]
  • matrix: A two dimensional list.
    • Ex. [[0, 1, 2], [3, 4, 5]]
  • tensor: A three+ dimensional list.
    • Ex. [[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]
  • backprop: Backwards propagation. The process of computing the derivative of the loss with respect to each parameter in a neural network. This is the main operation when training a neural network.
  • logits: Raw output of a model before conversion to: a distribution, output token, etc.
  • SOTA/SoTA/sota: “State of the art”.
  • RAG (Retrieval Augmented Generation): A technique for extending an LLM’s knowledge base by retrieving and incorporating relevant external documents during inference.
  • Chunking: The process of splitting documents into smaller segments (e.g., by paragraphs or markdown headers) so they can be vectorized and stored in a database for efficient retrieval.
  • Checkpointing: The practice of periodically saving a model’s state (including parameters, configurations, and optimizer state) to disk during training for later restoration or inference.
  • Agent: A system that empowers an LLM to perform actions by planning, managing outputs, and interfacing with tools, often using reinforcement learning to decide on actions.
  • Tool Use: The capability of a model to leverage external resources—such as search engines, calculators, or custom APIs—by structuring its output (often in JSON) to invoke specific functions.
  • Fine Tuning: The process of taking a pretrained model and further training it on a specific, often smaller, dataset (using methods like LoRA or GRPO) to specialize its performance for a particular task.
  • Vector Database: A database optimized for storing and retrieving high-dimensional vectors, enabling efficient semantic search and similarity comparisons.
  • Embedding: A numerical representation of data (such as text or images) in a continuous vector space that captures its semantic properties for tasks like similarity search or classification.
  • Pytree: A nested, tree-like data structure (commonly used in JAX) to represent model parameters, configurations, or other hierarchical data.
  • JIT (Just-In-Time Compilation): A performance optimization technique where code is compiled on the fly to speed up execution, as seen in frameworks like JAX.
  • Quantization: The method of reducing the precision of a model’s parameters (using formats like BF16, FP8, or lower bit representations) to decrease memory usage and accelerate computation without significantly sacrificing performance.
  • VRAM: Video Random Access Memory; the dedicated memory on GPUs used to store model parameters and intermediate computations during training and inference.
  • Token Throughput: The rate at which a language model can generate tokens during inference, typically measured in tokens per second.
  • State Dictionary: In frameworks like PyTorch, a mapping of model layer names to their corresponding parameters (tensors), used for saving and loading models.
  • Auto Grad: The automatic differentiation mechanism that computes gradients for model parameters during training, facilitating backpropagation.