
Anatomy of a Parquet File


In recent years, Parquet has become a standard format for data storage in Big Data ecosystems. Its column-oriented format offers several advantages:

  • Faster query execution when only a subset of columns is being processed
  • Quick calculation of statistics across all data
  • Reduced storage volume thanks to efficient compression

When combined with storage frameworks like Delta Lake or Apache Iceberg, it seamlessly integrates with query engines (e.g., Trino) and data warehouse compute clusters (e.g., Snowflake, BigQuery). In this article, the content of a Parquet file is dissected using mainly standard Python tools to better understand its structure and how it contributes to such performance.

Writing Parquet file(s)

To produce Parquet files, we use PyArrow, a Python binding for Apache Arrow that stores dataframes in memory in columnar format. PyArrow allows fine-grained parameter tuning when writing the file. This makes PyArrow ideal for Parquet manipulation (one can also simply use Pandas).

# generator.py

import pyarrow as pa
import pyarrow.parquet as pq
from faker import Faker

fake = Faker()
Faker.seed(12345)
num_records = 100

# Generate fake data
names = [fake.name() for _ in range(num_records)]
addresses = [fake.address().replace("\n", ", ") for _ in range(num_records)]
birth_dates = [
    fake.date_of_birth(minimum_age=67, maximum_age=75) for _ in range(num_records)
]
cities = [addr.split(", ")[1] for addr in addresses]
birth_years = [date.year for date in birth_dates]

# Cast the data to the Arrow format
name_array = pa.array(names, type=pa.string())
address_array = pa.array(addresses, type=pa.string())
birth_date_array = pa.array(birth_dates, type=pa.date32())
city_array = pa.array(cities, type=pa.string())
birth_year_array = pa.array(birth_years, type=pa.int32())

# Create schema with non-nullable fields
schema = pa.schema(
    [
        pa.field("name", pa.string(), nullable=False),
        pa.field("address", pa.string(), nullable=False),
        pa.field("date_of_birth", pa.date32(), nullable=False),
        pa.field("city", pa.string(), nullable=False),
        pa.field("birth_year", pa.int32(), nullable=False),
    ]
)

table = pa.Table.from_arrays(
    [name_array, address_array, birth_date_array, city_array, birth_year_array],
    schema=schema,
)

print(table)
pyarrow.Table
name: string not null
address: string not null
date_of_birth: date32[day] not null
city: string not null
birth_year: int32 not null
----
name: [["Adam Bryan","Jacob Lee","Candice Martinez","Justin Thompson","Heather Rubio"]]
address: [["822 Jennifer Field Suite 507, Anthonyhaven, UT 98088","292 Garcia Mall, Lake Belindafurt, IN 69129","31738 Jonathan Mews Apt. 024, East Tammiestad, ND 45323","00716 Kristina Trail Suite 381, Howelltown, SC 64961","351 Christopher Expressway Suite 332, West Edward, CO 68607"]]
date_of_birth: [[1955-06-03,1950-06-24,1955-01-29,1957-02-18,1956-09-04]]
city: [["Anthonyhaven","Lake Belindafurt","East Tammiestad","Howelltown","West Edward"]]
birth_year: [[1955,1950,1955,1957,1956]]

The output clearly reflects column-oriented storage, unlike Pandas, which usually displays a traditional “row-wise” table.

How is a Parquet file stored?

Parquet files are generally stored in cheap object storage databases like S3 (AWS) or GCS (GCP) to be easily accessible by data processing pipelines. These files are usually organized with a partitioning strategy by leveraging directory structures:

# generator.py

num_records = 100

# ...

# Writing the parquet files to disk
pq.write_to_dataset(
    table,
    root_path='dataset',
    partition_cols=['birth_year', 'city']
)

If birth_year and city columns are defined as partitioning keys, PyArrow creates such a tree structure in the directory dataset:

dataset/
├─ birth_year=1949/
├─ birth_year=1950/
│ ├─ city=Aaronbury/
│ │ ├─ 828d313a915a43559f3111ee8d8e6c1a-0.parquet
│ │ ├─ …
│ ├─ city=Alicialand/
│ ├─ …
├─ birth_year=1951/
├─ ...

The strategy enables partition pruning: when a query filters on these columns, the engine can use folder names to read only the necessary files. This is why the partitioning strategy is crucial for limiting delay, I/O, and compute resources when handling large volumes of data (as has been the case for decades with traditional relational databases).

The pruning effect can be easily verified by counting the files opened by a Python script that filters the birth year:

# query.py
import duckdb

duckdb.sql(
    """
    SELECT * 
    FROM read_parquet('dataset/*/*/*.parquet', hive_partitioning = true)
    where birth_year = 1949
    """
).show()
> strace -e trace=open,openat,read -f python query.py 2>&1 | grep "dataset/.*\.parquet"

[pid    37] openat(AT_FDCWD, "dataset/birth_year=1949/city=Box%201306/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 3
[pid    37] openat(AT_FDCWD, "dataset/birth_year=1949/city=Box%201306/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 3
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=Box%201306/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 4
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=Box%203487/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 5
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=Box%203487/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 3
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=Clarkemouth/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 4
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=Clarkemouth/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 5
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=DPO%20AP%2020198/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 3
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=DPO%20AP%2020198/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 4
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=East%20Morgan/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 5
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=East%20Morgan/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 3
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=FPO%20AA%2006122/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 4
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=FPO%20AA%2006122/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 5
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=New%20Michelleport/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 3
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=New%20Michelleport/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 4
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=North%20Danielchester/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 5
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=North%20Danielchester/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 3
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=Port%20Chase/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 4
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=Port%20Chase/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 5
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=Richardmouth/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 3
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=Richardmouth/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 4
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=Robbinsshire/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 5
[pid    39] openat(AT_FDCWD, "dataset/birth_year=1949/city=Robbinsshire/e1ad1666a2144fbc94892d4ac1234c64-0.parquet", O_RDONLY) = 3

Only 23 files are read out of 100.

Reading a raw Parquet file

Let’s decode a raw Parquet file without specialized libraries. For simplicity, the dataset is dumped into a single file without compression or encoding.

# generator.py

# ...

pq.write_table(
    table,
    "dataset.parquet",
    use_dictionary=False,
    compression="NONE",
    write_statistics=True,
    column_encoding=None,
)

The first thing to know is that the binary file is framed by 4 bytes whose ASCII representation is “PAR1”. The file is corrupted if this is not the case.

# reader.py

with open("dataset.parquet", "rb") as file:
    parquet_data = file.read()

assert parquet_data[:4] == b"PAR1", "Not a valid parquet file"
assert parquet_data[-4:] == b"PAR1", "File footer is corrupted"

As indicated in the documentation, the file is divided into two parts: the “row groups” containing actual data, and the footer containing metadata (schema below).

The footer

The size of the footer is indicated in the 4 bytes preceding the end marker, as an unsigned integer written in “little endian” format (the "<I" format string passed to the unpack function below).

# reader.py

import struct

# ...

footer_length = struct.unpack("<I", parquet_data[-8:-4])[0]
print(f"Footer size in bytes: {footer_length}")

# The footer itself sits just before the footer-length field and the end marker
footer_data = parquet_data[-(8 + footer_length) : -8]

Footer size in bytes: 1088

The footer information is encoded in a cross-language serialization format called Apache Thrift. Using a human-readable but verbose format like JSON and then translating it into binary would be less efficient in terms of memory usage. With Thrift, one can declare data structures as follows:

struct Customer {
	1: required string name,
	2: optional i16 birthYear,
	3: optional list<string> interests
}

On the basis of this declaration, Thrift can generate Python code to decode byte strings with such data structure (it also generates code to perform the encoding part). The thrift file containing all the data structures implemented in a Parquet file can be downloaded here. After having installed the thrift binary, let’s run:

thrift -r --gen py parquet.thrift

The generated Python code is placed in the “gen-py” folder. The footer’s data structure is represented by the FileMetaData class – a Python class automatically generated from the Thrift schema. Using Thrift’s Python utilities, binary data is parsed and populated into an instance of this FileMetaData class.

# reader.py

import sys

# ...

# Add the generated classes to the python path
sys.path.append("gen-py")
from parquet.ttypes import FileMetaData, PageHeader
from thrift.transport import TTransport
from thrift.protocol import TCompactProtocol

def read_thrift(data, thrift_instance):
    """
    Read a Thrift object from a binary buffer.
    Returns the Thrift object and the number of bytes read.
    """
    transport = TTransport.TMemoryBuffer(data)
    protocol = TCompactProtocol.TCompactProtocol(transport)
    thrift_instance.read(protocol)
    return thrift_instance, transport._buffer.tell()

# The number of bytes read is not used for now
file_metadata_thrift, _ = read_thrift(footer_data, FileMetaData())

print(f"Number of rows in the whole file: {file_metadata_thrift.num_rows}")
print(f"Number of row groups: {len(file_metadata_thrift.row_groups)}")

Number of rows in the whole file: 100
Number of row groups: 1

The footer contains extensive information about the file’s structure and content. For instance, it accurately tracks the number of rows in the generated dataframe. These rows are all contained within a single “row group.” But what is a “row group?”

Row groups

Unlike purely column-oriented formats, Parquet employs a hybrid approach. Before writing column blocks, the dataframe is first partitioned row-wise into row groups (the Parquet file we generated is too small to be split into multiple row groups).

This hybrid structure offers several advantages:

  • Parquet calculates statistics (such as min/max values) for each column within each row group. These statistics are crucial for query optimization, allowing query engines to skip entire row groups that don’t match filtering criteria. For example, if a query filters for birth_year > 1955 and a row group’s maximum birth year is 1954, the engine can efficiently skip that entire data section. This optimization is called “predicate pushdown”. Parquet also stores other useful statistics like distinct value counts and null counts.

# reader.py
# ...

first_row_group = file_metadata_thrift.row_groups[0]
birth_year_column = first_row_group.columns[4]

min_stat_bytes = birth_year_column.meta_data.statistics.min
max_stat_bytes = birth_year_column.meta_data.statistics.max

min_year = struct.unpack("<i", min_stat_bytes)[0]
max_year = struct.unpack("<i", max_stat_bytes)[0]
print(f"The birth year range is between {min_year} and {max_year}")
The birth year range is between 1949 and 1958
  • Row groups enable parallel processing of data (particularly valuable for frameworks like Apache Spark). The size of these row groups can be configured based on the computing resources available (using the row_group_size parameter of PyArrow’s write_table function):
# generator.py

# ...

pq.write_table(
    table,
    "dataset.parquet",
    row_group_size=100,
)

# /!\ Keep the default value of "row_group_size" for the next parts
  • Even if this is not the primary objective of a columnar format, Parquet’s hybrid structure maintains reasonable performance when reconstructing complete rows. Without row groups, rebuilding an entire row might require scanning the entirety of each column, which would be extremely inefficient for large files.

Data Pages

The smallest substructure of a Parquet file is the page. It contains a sequence of values from the same column and, therefore, of the same type. The choice of page size is the result of a trade-off:

  • Larger pages mean less metadata to store and read, which is optimal for queries with minimal filtering.
  • Smaller pages reduce the amount of unnecessary data read, which is better when queries target small, scattered data ranges.
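This trade-off can be tuned when writing the file. As a rough sketch (the output filename below is just an example), PyArrow’s write_table exposes a data_page_size parameter that sets an approximate target for the encoded size of each data page:

# generator.py (variant)

# ...

pq.write_table(
    table,
    "dataset_smallpages.parquet",
    data_page_size=64 * 1024,  # aim for roughly 64 KiB per data page (approximate threshold)
)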

Now let’s decode the contents of the first page of the column dedicated to addresses, whose location can be found in the footer (given by the data_page_offset attribute of the corresponding ColumnMetaData). The offset actually points to a Thrift binary representation of the page metadata that precedes the page itself: each page is preceded by a PageHeader object containing some metadata, whose Thrift class can also be found in the gen-py directory.

💡 Between the PageHeader and the actual values contained within the page, there may be a few bytes dedicated to implementing the Dremel format, which allows encoding nested data structures. Since our data has a regular tabular format and the values are not nullable, these bytes are skipped when writing the file (https://parquet.apache.org/docs/file-format/data-pages/).

# reader.py
# ...

address_column = first_row_group.columns[1]
column_start = address_column.meta_data.data_page_offset
column_end = column_start + address_column.meta_data.total_compressed_size
column_content = parquet_data[column_start:column_end]

page_thrift, page_header_size = read_thrift(column_content, PageHeader())
page_content = column_content[
    page_header_size : (page_header_size + page_thrift.compressed_page_size)
]
print(column_content[:100])
b'6\x00\x00\x00481 Mata Squares Suite 260, Lake Rachelville, KY 874642\x00\x00\x00671 Barker Crossing Suite 390, Mooreto'

The generated values finally appear, in plain text and not encoded (as specified when writing the Parquet file). However, to optimize the columnar format, it is recommended to use one of the following encoding algorithms: dictionary encoding, run length encoding (RLE), or delta encoding (the latter being reserved for int32 and int64 types), followed by compression using gzip or snappy (available codecs are listed here). Since encoded pages contain similar values (all addresses, all decimal numbers, etc.), compression ratios can be particularly advantageous.
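For comparison with the deliberately unencoded file written earlier, here is one possible way to enable dictionary encoding and Snappy compression with PyArrow (the output filename is just an example):

# generator.py (variant)

# ...

pq.write_table(
    table,
    "dataset_encoded.parquet",
    use_dictionary=True,   # dictionary-encode columns with repeated values
    compression="SNAPPY",  # compress each page with Snappy
    write_statistics=True,
)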

As documented in the specification, when character strings (BYTE_ARRAY) are not encoded, each value is preceded by its size represented as a 4-byte integer. This can be observed in the previous output.

To read all the values (for example, the first 10), the loop is rather simple:

idx = 0
for _ in range(10):
    str_size = struct.unpack("<I", page_content[idx : idx + 4])[0]
    idx += 4
    print(page_content[idx : idx + str_size].decode("utf-8"))
    idx += str_size
481 Mata Squares Suite 260, Lake Rachelville, KY 87464
671 Barker Crossing Suite 390, Mooretown, MI 21488
62459 Jordan Knoll Apt. 970, Emilyfort, DC 80068
948 Victor Square Apt. 753, Braybury, RI 67113
365 Edward Place Apt. 162, Calebborough, AL 13037
894 Reed Lock, New Davidmouth, NV 84612
24082 Allison Squares Suite 345, North Sharonberg, WY 97642
00266 Johnson Drives, South Lori, MI 98513
15255 Kelly Plains, Richardmouth, GA 33438
260 Thomas Glens, Port Gabriela, OH 96758

And there we have it! We have successfully recreated, in a very simple way, how a specialized library would read a Parquet file. By understanding its building blocks, including headers, footers, row groups, and data pages, we can better appreciate how features like predicate pushdown and partition pruning deliver such impressive performance benefits in data-intensive environments. I am convinced that knowing how Parquet works under the hood helps in making better decisions about storage strategies, compression choices, and performance optimization.

All the code used in this article is available on my GitHub repository, where you can explore more examples and experiment with different Parquet file configurations.

Whether you are building data pipelines, optimizing query performance, or simply curious about data storage formats, I hope this deep dive into Parquet’s inner structures has provided valuable insights for your Data Engineering journey.

All images are by the author.

Introduction to State Space Models as Natural Language Models


State Space Models (SSMs) use first-order differential equations to represent dynamic systems.

The HiPPO framework provides a mathematical foundation for maintaining continuous representations of time-dependent data, enabling efficient approximation of long-range dependencies in sequence modeling.

Discretization of continuous-time SSMs lays the groundwork for processing natural language and modeling long-range dependencies in a computationally efficient way.

LSSL, S4, and S5 are increasingly sophisticated and efficient sequence-to-sequence state-space models that pave the way for viable SSM-based alternatives to transformer models.

While transformer-based models are in the limelight of the NLP community, a quiet revolution in sequence modeling is underway. State Space Models (SSMs) have the potential to address one of the key challenges of transformers: scaling efficiently with sequence length.

In a series of articles, we’ll introduce the foundations of SSMs, explore their application to sequence-to-sequence language modeling, and provide hands-on guidance for training the state-of-the-art SSMs Mamba and Jamba.

In this first article of the three-part series, we’ll examine the core principles of SSMs, trace their evolution from Linear State Space Layers (LSSL) to the S5 model, and examine their potential to revolutionize sequence modeling with unparalleled efficiency.

Understanding state space models

Before exploring how State Space Models (SSMs) can function as components of large language models (LLMs), we’ll examine their foundational mechanics. This will allow us to understand how SSMs operate within deep neural networks and why they hold promise for efficient sequence modeling.

SSMs are a method for modeling, studying, and controlling the behavior of dynamic systems, which have a state that varies with time. SSMs represent dynamic systems using first-order differential equations, providing a structured framework for analysis and simplifying computations compared to solving higher-order differential equations directly.

Let’s dissect what this means.

Consider a system consisting of a moving car on the road. When we supply a certain input to this system (like pressing the gas pedal), we alter the car’s current state (for example, the amount of gas the engine is burning) and consequently cause the car to move at a certain speed.

Because our system’s state varies with time, it is considered a dynamic system. In this case, we are studying one state variable (the amount of gas the engine burns) in our state (the car’s internals). State variables are the minimum number of variables we can use to understand the system’s behavior through mathematical representation.

A car as a dynamic system. The system has a certain input, which is a foot pressing the gas pedal. This input is supplied to the car, influencing its state. The state variable being changed is the amount of gas the engine is burning. The output of the system is the speed of the car.

In our scenario, the car was already moving, so it was burning gas—a result of the previous force on the gas pedal. The speed we would get if we pressed the pedal in a stationary car differs from the speed we would get if the car were already moving since the engine would need less additional gas (and less additional input force) to reach a certain speed. Thus, when determining the speed, we should also factor in the car’s previous state.

A dynamic system with a previous state as the input. The value of the state variable depends not only on the input but also on the previous state.

There is one more thing to consider. State Space Models also model a “skip connection,” which represents the direct influence of the input on the output. In our case, the skip connection would model an immediate influence of pressing the gas pedal on the car’s speed, regardless of the current state. In the specific case of a car, this direct feedthrough (D) is zero, but we keep it in the model as, generally, systems can (and do) have direct input‐to‐output dependencies.

A dynamic system with a direct connection between input and output. There is a direct relationship between pressing a car’s gas pedal (input) and the car’s speed (output).

Now that we have considered all the possible connections in our system, let’s try to model it mathematically. First, we need representations for the variables in our system. We have the previous state of the model, x(t-1), the input, u(t), the current state of the model, x(t), and the output, y(t).

We also need a notation to represent the relationship between every two variables in the system. Let’s denote the effect of the previous state on the current one by a matrix A, the effect of the input on the current state by a matrix B, the effect of the state on the output by a matrix C, and the direct effect of the input on the output by the matrix D.

State space representation of a dynamic system. The input u(t), the state x(t), the output y(t), and the system’s previous state x(t-1) are connected through matrices A, B, C, and D, respectively.

From the input u(t), we need to compute two variables:

1. The new state x(t), which considers the effect of the previous state x(t-1) and the input u(t).

2. The output y(t), which considers the effect of the new state x(t) and the direct effect of the input u(t).

Consequently, we can derive the equations for the two variables:

1. The equation for the new state x(t):

x(t) = A x(t-1) + B u(t)

2. The equation for the output y(t):

y(t) = C x(t) + D u(t)

These two equations form our system’s state space representation (SSR). The SSR allows us to study the system’s stability by analyzing the effects of inputs on the system’s state variables and output.
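To make the wiring concrete, here is a minimal sketch of a single SSR update step in Python, with randomly chosen placeholder matrices (the dimensions and values are arbitrary and used only for illustration):

import numpy as np

# One step of the state space representation:
#   x(t) = A x(t-1) + B u(t)   (state equation)
#   y(t) = C x(t)  + D u(t)    (output equation)
state_dim, input_dim, output_dim = 4, 1, 1
rng = np.random.default_rng(0)
A = 0.1 * rng.normal(size=(state_dim, state_dim))  # effect of the previous state
B = rng.normal(size=(state_dim, input_dim))        # effect of the input on the state
C = rng.normal(size=(output_dim, state_dim))       # effect of the state on the output
D = np.zeros((output_dim, input_dim))              # direct feedthrough (zero, as in the car example)

x_prev = np.zeros((state_dim, 1))  # previous state x(t-1)
u = np.array([[1.0]])              # input u(t), e.g., pressing the gas pedal

x = A @ x_prev + B @ u             # new state x(t)
y = C @ x + D @ u                  # output y(t)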

We can model probabilistic dependencies between state variables and the inputs by introducing noise terms into the dynamics and observation equations. These stochastic extensions enable us to account for uncertainties in the system and its environment, providing a foundation for modeling and controlling the system’s behavior in real-world scenarios.

State space models for natural language processing

State Space Models (SSMs), long established in time series analysis, have been utilized as trainable sequence models for decades. Around 2020, their ability to efficiently handle long sequences spurred significant progress in adapting them for natural language processing (NLP).

The exploration of SSMs as trainable sequence models was gradual through multiple contributions that laid the foundation for introducing SSMs in deep learning models as “State Space Layers” (SSLs). In the following sections, we’ll explore key contributions that led to the use of SSMs as NLP models.

Applying SSMs to natural language processing reframes the input as a token, the state as the contextual representation, and the output as the predicted next token.

HiPPO: recurrent memory with optimal polynomial projections

The primary challenge sequence models face is capturing dependencies between two inputs that are far apart in a long sequence.

Let’s say we have a paragraph where the last sentence references something mentioned in the first sentence:

The word ‘Sushi’ in the first sentence is referenced in the last sentence, with a large number of words in between. Thus, understanding the phrase “that name” in the last sentence requires the first sentence for context.


Historically, sequence models, such as traditional RNNs, GRUs, and LSTMs, struggled to retain such long-range dependencies due to problems like vanishing or exploding gradients. The gating mechanisms these algorithms rely on regulate information flow by selectively retaining important features and discarding irrelevant ones, which mitigates issues like short-term memory loss.

However, these mechanisms are insufficient for capturing long-range dependencies because they struggle to preserve information over extended sequences. This is due to capacity constraints, a tendency to prioritize short-term patterns during training, and cumulative errors that degrade information over long sequences. While transformers address many of these issues through their self-attention mechanism, the quadratic complexity of attention makes them computationally inefficient for long sequences.

Albert Gu and colleagues at Stanford attempted to solve this problem by introducing HiPPO (short for “High-order Polynomial Projection Operators”). This mathematical framework aims to compress historical information into a fixed-size representation. The fixed-size representation captures the entire processed sequence and enables sequence models to process and utilize long-range dependencies efficiently. Unlike the hidden state in an LSTM or GRU, which is also a fixed-size representation but primarily optimized for short-term memory retention, HiPPO is explicitly designed to capture the entire processed sequence, enabling sequence models to process and utilize long-range dependencies efficiently.

HiPPO works by constructing a set of polynomial bases that are mathematically orthogonal with respect to a specific weighting function. The weighting function w(t) weighs the importance of historical information using one of two variants:

1. Transform HiPPO Matrix Variations: Transform matrices prioritize the latest inputs and change the system’s response continuously with time. The importance of information stored in the sequence history decays over time.

2. Stationary HiPPO Matrix Variations: Stationary matrices are time-invariant and consider all past data with consistent importance. The rate of natural decay of information remains consistent over time, providing a balance between retaining historical information and responding to new inputs.

Gu and colleagues applied the two variants to three different polynomial families referred to as Leg, Lag, and Cheb. The difference between the Leg, Lag, and Cheb is the amount of information retention, which is determined by the variations in the weighting functions w(t) associated with each set of polynomials and their orthogonality properties:

1. HiPPO-Leg is based on the Legendre polynomials. It gives uniform weighting for all the information in the sequence. Thus, the weighting function w(t) = 1. As the sequence length becomes larger, the older parts of the sequence are compressed into a fixed-size representation. 

2. HiPPO-Lag is based on the Laguerre polynomials. There is an exponential decay of information over time.

3. HiPPO-Cheb is based on the Chebyshev polynomials. It creates a non-uniform distribution that prioritizes the latest and oldest information.

The storage and prioritization of the sequence’s historical data is due to the mathematical properties of these polynomials. The appendix of the HiPPO paper contains all the equations and mathematical proofs.

The HiPPO matrix is obtained by deriving differential operators that project the input signal onto the specified polynomial basis in real-time. The operators ensure the orthogonality of the states while preserving the defined weighting function. The following equation defines them:

The HiPPO matrix

Here, ϕi(t) are the basis functions of the chosen family of orthogonal polynomials (i.e., Legendre, Laguerre, or Chebyshev), ϕ′i is the derivative of the i-th basis function with respect to time t, and w(t) is the weighting function that defines the importance of information over time. i is the index of the current state or basis function being updated, and j is the index of the previous state or basis function contributing to the update. It points to the j-th basis function that is being integrated with respect to w(t). The integral computes the contribution of the j-th basis function to the update of the i-th state, considering the weighting w(t).

This mechanism allows for efficiently updating the model’s hidden state, minimizing the loss of long-range dependencies. Thus, the HiPPO matrix can be used to control the update of a model’s context or hidden state.

This sounds familiar, right? In the previous section, we saw that the representation of the state change (A) for text data would be the context of the text (or sequence). Just like in RNNs and LSTMs, we can use this context (or hidden state) to predict the next word. Since its structure allows it to handle long- and short-range dependencies, HiPPO acts as a template for the matrix A.

Combining recurrent, convolutional, and continuous-time models with linear state-space layers

HiPPO’s inventors collaborated with other Stanford researchers to develop the Structured State Space Sequence model, which uses the HiPPO framework. This model makes significant strides in applying SSMs to sequence modeling tasks.

Their 2021 paper Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers aims to combine the best and most efficient properties of all the existing sequence modeling algorithms.

According to the authors, an ideal sequence modeling algorithm would have the following capabilities:

1. Parallelizable training, as is possible with Convolutional Neural Networks (CNNs). This saves computational resources and enables a faster training process.

2. Stateful inference, as provided by Recurrent Neural Networks (RNNs). This allows context to be used as a factor while deciding on the output.

3. Time-scale adaptation, as in Neural Differential Equations (NDEs). This enables the sequence model to adapt to various lengths of input sequences.

In addition to these properties, the model should also be able to handle long-range dependencies in a computationally efficient manner.

Motivated by these goals, the authors explored using State Space Models (SSMs) to develop a computationally efficient and generalizable sequence model suitable for long sequences.

Let’s explore how they did that:

As we learned above, the SSR equations represent a dynamic system with a continuously changing state. To apply SSMs to NLP, we need to adapt these continuous-time models to operate on discrete input sequences. Rather than continuous signals, we’ll now feed strings of individual tokens to the model one by one.

Discretization

We can discretize the continuous SSR equations using numerical methods.

To understand this process, we will return to the example of the continuously moving car. The car’s speed is a continuous signal. To study the variation in the car’s speed, we need to measure it at all times. However, it’s impractical to record every infinitesimal change in speed. Instead, we take measurements at regular intervals—for example, every 30 seconds.

By recording the car’s speed at these specific moments, we convert the continuous speed profile into a series of discrete data points. This process of sampling the continuous signal at regular intervals is called “discretization.” The interval of time we are using to measure the speed is called the time scale Δt, also known as “step size” or “discretization parameter.”
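As a toy illustration (the speed function below is made up purely for this example), discretization amounts to evaluating the continuous signal only at multiples of Δt:

import numpy as np

def speed(t):
    # A stand-in for the car's continuous speed signal: accelerates toward 20 m/s
    return 20.0 * (1.0 - np.exp(-t / 60.0))

dt = 30.0                               # time scale Δt (step size), in seconds
timestamps = np.arange(0.0, 300.0, dt)  # five minutes of driving, sampled every 30 s
samples = speed(timestamps)             # the discrete sequence we actually observe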

To convert a continuous signal into a discrete signal, it is sampled in fixed intervals Δt.

Similar to discretizing car speed, to adapt SSMs for natural language processing, we start with continuous-time equations that describe how a system evolves. We discretize the equations, converting them into a form that updates at each discrete time step.

The choice of Δt is critical: if it is too large, we risk losing important details of the state dynamics (undersampling).

If Δt is too small, the system might become inefficient or numerically unstable due to excessive computations (oversampling).

In Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers, the authors explored several methods for discretizing state-space models to adapt them for sequence modeling tasks. They ultimately selected the Generalized Bilinear Transform (GBT), which effectively balances accuracy (by avoiding undersampling) and stability (by avoiding oversampling). The GBT allows the discrete state-space model to approximate the continuous dynamics while maintaining robustness in numerical computations.

The discrete state equation under GBT is given by:

Here, x is the state representation, Δt is the time step, A is the matrix that represents how the state is influenced by the previous state, B is the matrix that represents the effect of the input on the current state, and I is the identity matrix which ensures that the output has consistent dimensionality.

A critical decision when applying the Generalized Bilinear Transform is the choice of the parameter α, which controls the balance between preserving the characteristics of the continuous-time system and ensuring stability in the discrete domain. The authors selected α = 0.5 as it balances accuracy and numerical stability. The resulting state equation is given by:
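With α = 0.5, the update reduces to the bilinear (Tustin) transform, which yields the discrete matrices (again in generic notation):

Ā = (I - (Δt/2)·A)⁻¹ (I + (Δt/2)·A)
B̄ = (I - (Δt/2)·A)⁻¹ Δt·B

so that x(t) = Ā x(t-1) + B̄ u(t).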

The bilinear transform equation is then applied to the initialized continuous-time matrices A and B, discretizing them into Ā and B̄, respectively.

Now that we have a discretized version of the SSR equations, we can apply them to natural language generation tasks where:

1. u(t) is the input token we feed into the model.

2. x(t) is the context, which is the representation of the sequence’s history thus far.

3. y(t) is the output, the predicted next token.

Thus, we have a representation of SSMs that can handle tokens as input.

State Space Model with discretized matrices A and B. A and B map the current context xt-1 and the input token ut to the new context xt. C maps the context to the output token yt, with D modeling the direct relationship between ut and yt. The direct connection between the input and the output mediated by D is treated as a skip connection and is not explicitly incorporated into the model’s internal architecture.

The three pillars of SSMs as sequence models

Now that we can use SSMs for NLP tasks, let’s see how they measure up with respect to the other available sequencing algorithms by circling back to the goals the authors stated at the beginning of Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers.

Parallelizable training

Parallelizable training would save a considerable amount of computational resources and time. Two widely used sequencing architectures are inherently parallelizable during training:

1. Convolutional Neural Networks (CNNs) are inherently parallelizable because the convolution operation can be applied simultaneously across all positions in the input sequence. In sequence modeling, CNNs process the entire input in parallel by applying convolutional filters over the sequence, allowing for efficient computation during training.

2. Transformers achieve parallelism through the self-attention mechanism, which simultaneously computes attention weights between all pairs of tokens in the sequence. This is possible because the computations involve matrix operations that can be parallelized, allowing the model to process entire sequences at once.

Efficiently distributing the computational workload is crucial for sequence algorithms, especially when training on large datasets. To address this challenge, the authors introduced a convolutional representation of SSMs, which allows these models to process sequences in parallel, similar to CNNs and Transformers.

The authors’ idea is to express the SSM as a convolution operation with a specific kernel k derived from the state-space parameters, enabling the model to compute outputs over long sequences efficiently.

To derive the SSR equations as a convolution operation, they assume the SSM model to be time-invariant. This means the matrices A, B, C, and D do not vary with time, the matrix A is stable (which is already achieved by adopting the HiPPO matrix for A that allows a numerically stable update of the context), and the initial state x(0) is 0.

Using the SSR equations mentioned earlier (state equation that derives x(t) and output equation that derives y(t)), the kernel k can be derived in two steps:

1. Solving for the state, we start with the state equation from the SSR equations where x0 = 0:


We derived the state xn, which represents the system’s state at time step n, based on the contributions of past inputs. Similarly, uk denotes the input to the system at a specific time step k within the sequence. The number of time steps n (i.e., the number of times we sample using Δt) depends on the length of the input sequence, as the state xn​ is influenced by all preceding inputs up to time n−1.

2. Substitute the xn in the SSR output equation with the state that is derived from step 1.


We can simplify this equation by combining the state representations (A, B, C, and D) as the kernel k:


Here, m is the index for summing over past inputs. The result is the following equation for the output at step n:


Thus, we are left with the convolutional form of the State Space Representation: we take the input un as a common factor and denote the term multiplied by the input as the kernel k. We obtain the outputs from the input sequence by passing the kernel across it.
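The following sketch checks this equivalence numerically with small random placeholder matrices (not the HiPPO parameterization): the recurrent update and the kernel-based convolution produce the same outputs.

import numpy as np

rng = np.random.default_rng(1)
N, L = 4, 8                                  # state dimension, sequence length
A_bar = 0.3 * rng.normal(size=(N, N))        # discretized A (toy placeholder)
B_bar = rng.normal(size=(N, 1))
C = rng.normal(size=(1, N))
D = np.array([[0.5]])
u = rng.normal(size=L)

# Recurrent form: update the state token by token
x = np.zeros((N, 1))
y_rec = []
for n in range(L):
    y_rec.append((C @ x + D * u[n]).item())  # x holds the contributions of u_0 .. u_{n-1}
    x = A_bar @ x + B_bar * u[n]

# Convolutional form: kernel k_m = C A_bar^m B_bar
kernel = [(C @ np.linalg.matrix_power(A_bar, m) @ B_bar).item() for m in range(L)]
y_conv = [
    sum(kernel[m] * u[n - 1 - m] for m in range(n)) + (D * u[n]).item()
    for n in range(L)
]

assert np.allclose(y_rec, y_conv)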

Stateful inference

Stateful inference refers to a sequence model’s ability to create, maintain, and utilize a “state,” which includes all the relevant context needed for further computations. This ability is desirable because it eliminates the computational inefficiency of understanding the context whenever a new input token is present.

Transformers capture long-range dependencies and context through the self-attention mechanism. However, recomputing the attention weights and value vectors every time we have a new input token is computationally expensive. We can cache the values of key and value vectors to avoid some recomputation, which makes it slightly more efficient. Still, it does not solve the problem of transformers scaling quadratically.

RNNs achieve stateful inference through a hidden state that is only updated and not recomputed for every input token. However, RNNs struggle to retain information from earlier tokens in long sequences. This limitation arises because, during backpropagation, gradients associated with long-range dependencies diminish exponentially as they are propagated through many layers (or time steps), a phenomenon known as the vanishing gradient problem. As a result, RNNs cannot effectively model long-range dependencies between tokens.

Thanks to their state equation, SSMs achieve stateful inference. They inherently maintain a state containing the sequence’s context, making them more computationally efficient than transformer-based models.

To handle long-range dependencies, the authors of Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers use the HiPPO-LegS (Stationary form of HiPPO-Leg) formulation to parameterize A.

Time-scale adaptation

Time-scale adaptation refers to a sequence model’s ability to capture dependencies for the input token in different parts of the input sequence. In technical terms, this means the context can retain dependencies that occur over different temporal distances within the same sequence. Time-scale adaptation enables effective capturing of both short-term (immediate) and long-term (distant) relationships between elements in the data.

A model’s context representation is crucial for its ability to capture the internal dependencies within a sequence. SSMs represent the context as the matrix A. Thus, an SSM’s ability to update the state based on the new input through the state equation allows the model to adapt to the contextual dependencies within a sequence, allowing it to handle both long and short-range dependencies.

Linear state space layers (LSSLs)

So far, we’ve seen that State Space Models are efficient sequence models. In their paper Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers, Gu and colleagues introduced the Linear State Space Layer (LSSL) utilizing both the discretized recurrent and convolutional forms of State Space Representation equations. This layer is integrated into deep learning architectures to introduce efficient handling of long-range dependencies and structured sequence representations.

Like RNNs, SSMs are recurrent. They update the context by combining the previous state with the new input. This recurrent form is very slow to train because we need to wait for the previous output to be available before computing the next one. To address this problem, the authors devised the convolutional representation of the SSM equations that we discussed in the previous sections.

While the convolutional representation of SSMs enables training parallelization, it is not without its own problems. The key issue is the fixed size of the kernel. The kernel we are using to process the input sequence is determined by the model parameters (matrices A, B, C, and D) and sequence length, as we saw in the first step of the kernel derivation. However, natural language sequences vary in length. Thus, the kernel would be recomputed during inference based on the input sequence, which is inefficient. 

Although recurrent representations are inefficient to train, they can handle varying sequence lengths. Thus, to have a computationally efficient model, we seem to need the properties of both the convolutional and recurrent representations. Gu and colleagues devised a “best of both worlds” approach, using the convolutional representation during training and the recurrent representation during inference.

Comparison of the continuous-time, recurrent, and convolutional forms of SSMs. The Linear State Space Layer adopts both the recurrent and convolutional forms of the SSM representation to leverage their complementary advantages. The recurrent form is used during inference, and the convolutional form during training. | Source

In their paper, Gu and collaborators describe the LSSL architecture as a “deep neural network that involves stacking LSSL layers connected with normalization layers and residual connections.” Similar to the attention layers in the transformer architecture, each LSSL layer is preceded by a normalization layer and followed by a GeLU activation function. Then, through a residual connection, the output is added to the normalized output of a position-wise feedforward layer.

Architecture of a Linear State Space Layer. Each input has H features (the size of the token’s embedding vector) that are processed by independent copies of the SSM as one-dimensional inputs in parallel. Each SSM copy produces an M-dimensional output for each feature. The combined outputs are fed through a GeLU activation function and a position-wise feed-forward layer.

Efficiently modeling long sequences with structured state spaces

The LSSL model performed impressively well on sequence data but was not widely adopted due to computational complexities and memory bottlenecks.

Results of testing the original LSSL model on the sequential MNIST, permuted MNIST, and sequential CIFAR tasks, which are popular benchmarks originally designed to test the ability of recurrent models to capture long-term dependencies of length up to 1k. LSSL sets SoTA on sCIFAR by more than 10 points.

In the paper Efficiently Modeling Long Sequences with Structured State Spaces, Gu, together with close collaborators Karan Goel and Christopher Ré, advanced the LSSL to reduce the computational complexity and improve the accuracy of the training process.

Improvements on the state matrix A

In the previous section, we explored how the original LSSL relied on a fixed, predefined form of the HiPPO matrix to serve as the state matrix A. While this representation was successful in compressing information, it was computationally inefficient due to the full (dense) matrix representation of A. Gu, Goel, and Ré described this implementation as “infeasible to use in practice because of prohibitive computation and memory requirements induced by the state representation.”

In the LSSL, the state is multiplied by the matrix A to produce the updated version of the state. The most computationally efficient form of the matrix A for multiplication would be a diagonal matrix. Unfortunately, the HiPPO matrix could not be reformed as a diagonal matrix since it does not have a full set of eigenvectors.

However, the authors were able to dissect the matrix into a diagonal plus low-rank decomposition (DPLR). The diagonal matrix has nonzero entries only on the main diagonal, which makes the multiplication process more efficient by requiring only a single multiplication per vector element. The low-rank matrix can be represented as the product of two much smaller matrices. Because of this factorization, the operations needed to multiply by the vector are greatly reduced compared to a full-rank matrix of the same size.
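A small sketch makes the computational benefit of the DPLR form tangible: multiplying a vector by diag(d) + P Qᵀ never requires materializing the full N × N matrix (the matrices below are random placeholders, not the actual HiPPO factors):

import numpy as np

N, r = 1024, 1                     # state dimension, rank of the low-rank part
rng = np.random.default_rng(0)
d = rng.normal(size=N)             # diagonal part
P = rng.normal(size=(N, r))        # low-rank factors
Q = rng.normal(size=(N, r))
x = rng.normal(size=N)

dense = np.diag(d) + P @ Q.T       # full matrix, built only to check the result
fast = d * x + P @ (Q.T @ x)       # O(N·r) work instead of O(N²)
assert np.allclose(dense @ x, fast)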

The original LSSL architecture required O(N²L) operations, where N is the state dimension and L is the sequence length. After the transformation of the matrix A into its diagonal plus low-rank (DPLR) form, the computational complexity of both the recurrent and convolutional forms was reduced:

1. For the recurrent form, the DPLR form requires only O(NL) matrix-vector multiplications.

2. For the convolutional form, the convolutional kernel was reduced to require only O(N log L + L log L) operations. This was achieved by changing the technique used to derive the kernel, which included using the inverse Fast Fourier Transform (iFFT) and applying the Woodbury identity to reduce the low-rank term of matrix A.

This is a considerable leap in computational efficiency, significantly reducing the scaling with sequence length and bringing SSMs closer to linear time complexity, in contrast to the quadratic scaling of transformers.

Improvements in the training implementation

After tackling the LSSL’s computational complexity, the authors found another significant improvement, which is making the matrix A (partially) learnable. In the LSSL, the matrix was fixed and not updated during the training process. Rather, the matrices B and C were responsible for the update and learnability of the SSM blocks.

Keeping the matrix A fixed ensures computational efficiency, but it limits the model’s ability to capture complex dynamics and underlying patterns in the sequence. A fully learnable matrix A offers the flexibility to adapt to arbitrary dynamics. However, it comes with trade-offs: more parameters to optimize, slower training, and higher computational costs during inference.

To balance these competing demands, the modified LSSL – dubbed S4 – adopts a partially learnable A. By maintaining the DPLR structure of A, the model retains computational efficiency, while the introduction of learnable parameters enhances its ability to capture richer, domain-specific behaviors. By introducing learnable parameters into A, a model can adjust the state dynamics during training and update sequence-specific internal representations in the state.

Additionally, Efficiently Modeling Long Sequences with Structured State Spaces introduces techniques for implementing bidirectional state-space models. These models can process sequences in both the forward and backward directions, capturing dependencies from past and future contexts.

Simplified state space layers for sequence modeling

In Simplified State Space Layers for Sequence Modeling, Jimmy Smith, Andrew Warrington, and Scott Linderman proposed multiple improvements to the S4 architecture to enhance performance while maintaining the same computational complexity.

While the improvements of S4 over the original LSSL mainly focus on reducing the model’s computational complexity, S5 aimed to simplify the architecture, making it more efficient and easier to implement while maintaining or improving performance.

Using parallel associative scan

Parallel scan, also known as parallel associative scan, is an algorithm that allows parallel computation through pre-computing cumulative operations (in this case, products) up to each position in the sequence so they can be selected during the processing step instead of processed one at a time.

Using a parallel associative scan, Smith and colleagues were able to parallelize the training process of recurrent SSMs, removing the need for the use of the convolutional representation.

Thus, the S5 layer operates only in the time domain, instead of requiring the convolutional form and its frequency-domain computation. This is an important improvement because it allows the time complexity per layer to be O(N log L) instead of O(NL), leveraging parallel computation over the sequence length while reducing the memory overhead.
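A sketch of the idea behind the scan (not the S5 implementation): each recurrence step x_t = a_t·x_{t-1} + b_t of a diagonal SSM can be represented as a pair (a_t, b_t), and composing two steps is an associative operation, which is what a parallel scan primitive (for example, jax.lax.associative_scan) exploits to compute all states in logarithmic parallel depth.

import numpy as np

def combine(step1, step2):
    # Composition of two linear recurrence steps: apply step1 first, then step2
    a1, b1 = step1
    a2, b2 = step2
    return a2 * a1, a2 * b1 + b2

def scan_sequential(steps):
    # Sequential reference: fold the steps left to right, starting from x_0 = 0
    acc = (np.ones_like(steps[0][0]), np.zeros_like(steps[0][1]))
    states = []
    for step in steps:
        acc = combine(acc, step)
        states.append(acc[1])  # acc[1] is the state after this step
    return states

# Quick check against the direct recurrence for random diagonal steps
rng = np.random.default_rng(0)
steps = [(rng.uniform(0.5, 1.0, size=3), rng.normal(size=3)) for _ in range(5)]
x, direct = np.zeros(3), []
for a, b in steps:
    x = a * x + b
    direct.append(x)
assert all(np.allclose(s, d) for s, d in zip(scan_sequential(steps), direct))

Because combine is associative, a parallel implementation can merge the pairs in a tree instead of the strictly left-to-right loop shown here.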

Allowing multi-input-multi-output

LSSL and S4 are Single-Input-Single-Output (SISO) models. Allowing Multi-Input-Multi-Output (MIMO) was computationally infeasible since the computations inside LSSL and S4 were designed under the assumption of having one input at a time. For example, adapting the convolutional representation to operate on matrices instead of vectors would have significantly increased the computational cost, making the approach impractical.

Smith and collaborators discretized the MIMO SSM equations instead of the SISO SSM equations. Using the same SSR equations, they extended the discretization process to handle m-dimensional inputs and n-dimensional outputs. Assuming the state has N dimensions, this change makes B an N x m matrix instead of N x 1, and C an n x N matrix instead of 1 x N.

S5’s support for MIMO allows it to handle multidimensional data, such as multivariate and multi-channel time series data, process multiple sequences simultaneously, and produce multiple outputs. This reduces computational overhead by allowing multiple sequences to be processed at the same time instead of having m copies of the SSM.

Diagonalized parametrization

As we discussed above, HiPPO-LegS could not be diagonalized. However, the parallel scan approach requires a diagonal matrix A. Through experimentation, Smith and colleagues discovered that they could represent the HiPPO-LegS matrix as a normal plus low-rank (NPLR) matrix, where the normal component is referred to as HiPPO-N, which can be diagonalized.

They showed that removing the low-rank terms and initializing the HiPPO-N matrix had similar results by proving that HiPPO-N and HiPPO-LegS produced the same dynamics. (A proof is given in the appendix of the paper.) However, if they were to use the diagonal matrix from the DPLR approximation, the approximation would have produced very different dynamics than the original structure.

Using a diagonalized version of the HiPPO-N matrix reduced the model’s computational complexity by removing the need to convert the HiPPO-LegS matrix into its DPLR approximation.

Similar to how using a structured parametrization for matrix A decreased the computational overhead, S5 uses a low-rank representation of matrices B and C, further reducing the number of parameters.

The computational components of an S5 layer, which uses a parallel scan on a diagonalized linear SSM to compute the SSM outputs. A nonlinear activation function is applied to the SSM outputs to produce the layer outputs. | Source

Conclusion and outlook

The evolution of State Space Models (SSMs) as sequence-to-sequence models has highlighted their growing importance in the NLP domain, particularly for tasks requiring the modeling of long-term dependencies. Innovations such as LSSL, S4, and S5 have advanced the field by enhancing computational efficiency, scalability, and expressiveness.

Despite the advancements made by the S5 model, it still lacks the ability to be context-aware. The S5 can efficiently train and infer in the time domain and retain information for long-range dependencies, but it does not explicitly filter or focus on specific parts of the sequence, as Transformers do with attention mechanisms.

Hence, a key next step is to incorporate a mechanism into SSMs that enables them to focus on the most relevant parts of the state rather than processing the entire state uniformly. This is what the Mamba model architecture addresses, which we’ll explore in the upcoming second part of the series.


Hyperparameter Optimization For LLMs: Advanced Strategies


  • Finding an optimal set of hyperparameters is essential for efficient and effective training of Large Language Models (LLMs).
  • The key LLM hyperparameters influence the model size, learning rate, learning behavior, and token generation process.
  • Due to their computational demands, traditional methods for optimizing hyperparameters, such as grid search, are impractical for LLMs.
  • Advanced hyperparameter optimization strategies, like population-based training, Bayesian optimization, and adaptive LoRA, promise to balance computational effort and outcome.

The rise of large language models (LLMs) is bringing advances in text generation and contextual understanding. Hyperparameters control the size of LLMs, their training process, and how they generate outputs.

An optimal combination of hyperparameters is fundamental to efficiently pre-training and fine-tuning LLMs. Since LLM training is computationally intensive, exhaustive experimentation is not viable. This rules out traditional machine-learning hyperparameter optimization (HPO) methods that rely on systematically exploring the hyperparameter space by training many models with slightly different configurations.

When configuring models and training processes, LLM developers rely on a thorough understanding of each hyperparameter’s influence, insights from fundamental research, and empirical evidence gained from training state-of-the-art foundation models. Methods for estimating optimal hyperparameter values with limited compute budgets and adapting hyperparameters throughout the training process can help pre-training and fine-tuning.

After reading this article, you’ll be able to answer the following questions:

  • What key hyperparameters should be considered when developing, training, and applying LLMs?
  • How does each hyperparameter influence the LLM, and which trade-offs do we need to be aware of?
  • How can we select an optimal combination of hyperparameters in our scenario without fully training multiple model variants?
  • What advanced hyperparameter optimization techniques are available for LLMs, and when can we apply them?

LLM hyperparameters

A hyperparameter is a configuration value that controls the behavior of a machine-learning model during the training or inference process. Unlike model parameters (the weights), which are learned directly from the training data, hyperparameters are defined by the model developers. A hyperparameter can be constant or adjusted dynamically according to predefined rules or schedules.

Model size

In the case of LLMs, we often work with pre-trained models, where the activation functions, internal architecture of layers or blocks, and their connections—all examples of hyperparameters—are fixed. If our pre-trained LLM of choice is available in different sizes, the model size is the only hyperparameter affecting the model’s makeup we can actively control.

The size of an LLM refers to the total number of parameters it contains, which influences the model’s capacity to understand and generate complex language patterns. Hyperparameters set and tuned during pre-training influence the total size of an LLM.

One hyperparameter influencing a model’s size is its depth, corresponding to the total number of layers stacked sequentially. Each additional layer in an LLM adds more parameters, such as the weights for the self-attention mechanism and feed-forward layers in a transformer block.

Another hyperparameter influencing an LLM’s size is its hidden size, which refers to the dimensionality of the token embeddings and the internal representations within each layer. The hidden size determines how richly the model can encode information about each input token and how effectively it can process complex language patterns. A larger hidden size means each token is represented in a higher-dimensional space, allowing the model to capture more detailed semantic and syntactic nuances.

Further, the number of parallel attention heads in each transformer block influences the size of the LLM. Multiple heads allow the model to focus on different input aspects simultaneously. Through multi-query and grouped-query attention, we can reduce the number of necessary parameters.

Finally, the vocabulary size and context window (maximum sequence length) also impact the model’s size. They determine the language diversity a model can handle and the context length it can maintain, respectively.

These hyperparameters, set before beginning the training process and unable to be changed later, determine the model size. For example, GPT-3 has 96 layers, a hidden size of 12,288, 96 attention heads, a vocabulary of 50,257 tokens, and a context window of 2,048 tokens, resulting in a total of 175 billion parameters.
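As a rough sanity check on these numbers, the common back-of-the-envelope estimate for a decoder-only transformer (about 12 · depth · hidden_size² weights in the attention and feed-forward blocks, plus the token embedding matrix) reproduces the quoted total; biases, layer norms, and positional embeddings are ignored here.

# gpt3_param_estimate.py (back-of-the-envelope sketch)
depth = 96
hidden_size = 12_288
vocab_size = 50_257

# Per layer: ~4 * d^2 for the attention projections and ~8 * d^2 for the feed-forward block
block_params = 12 * depth * hidden_size ** 2
embedding_params = vocab_size * hidden_size

total = block_params + embedding_params
print(f"~{total / 1e9:.0f}B parameters")   # prints ~175B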

Learning rate

The learning rate (LR) is a critical hyperparameter in training LLMs. Optimizing it is essential for efficient learning, stable convergence, and good generalization to unseen data.

The learning rate determines how much model weights are changed during each update. A high learning rate helps speed up the training process but increases the risk of instability and overfitting. A low learning rate increases stability and tends to benefit generalization but leads to slow training.

In the case of LLMs, the learning rate is typically not constant but varies as training progresses. This variation is governed by a learning rate schedule (LRS). The schedule is usually tied to the number of tokens seen—either directly, or indirectly through the number of samples, steps, or epochs. At a high level, it contains phases of a rising, constant, and decreasing learning rate.

How does the learning rate affect training duration and quality?

Following theoretical work by Stanford researcher Kaiyue Wen and colleagues published in December 2024, we can think of LLM training as progressing along a loss landscape that looks like a river valley. They hypothesize that the existence and overall direction of the river are due to the facts and knowledge an LLM learns, which are reflected as highly deterministic and, therefore, easy-to-predict tokens. The valley slopes arise from flexibility and ambiguity inherent to language, i.e., hard-to-predict tokens.

Visualization of LLM training as traveling down a river valley. Using a stable but high learning rate ensures quick progress down the river but leads to jumps between relatively high loss values. Reducing the learning rate during a subsequent decay phase brings the model towards a local loss minimum. | Source

In this picture, the training goal is to reach the river mouth, at which point we should be as close to the bottom of the valley as possible. The first crucial insight is that it does not matter whether we stay at the bottom of the valley until then. Thus, if we can make faster progress down the river by bouncing back and forth between points high up the loss valley’s slopes, we can do this without affecting the final outcome.

Thus, we should aim to use a high learning rate—resulting in large steps towards the loss minimum but leading to wildly fluctuating loss values—for as long as possible. Towards the end of the training, the learning rate should be decreased to a very low value. This will slow down progress towards the river mouth but reduce the oscillations to a point where we constantly stay at the valley’s bottom, i.e., the local loss minimum.

However, all of this is only going to work if we are already in a sufficiently deep loss river valley. When training is first starting, a high learning rate will lead to undirected jumps across the loss landscape. To avoid this, learning rate schedules for LLMs start with a small learning rate and slowly ramp it up to its maximum value. This is called the warmup phase.

Cosine schedule

The cosine schedule (also known as cosine decay or cosine annealing) implements this approach by starting with a linear warmup phase that brings the learning rate to its maximum value, followed by a slow decay following the cosine function:

LR(t) = LRmin + 0.5 (LRmax − LRmin) (1 + cos(π t/T))

Here, LRmin and LRmax are the minimum and maximum learning rates, t is the training step, and T is the total number of training steps. The advantage of this schedule is that it stays close to the peak learning rate for a long time, and the final decay is gradual. It’s also easy to implement, as it depends on just three hyperparameters (LRmax, LRmin, and T) linked by the cosine function.
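A minimal sketch of such a schedule with a linear warmup; the warmup length and learning rate values below are illustrative, not taken from a specific model.

# cosine_schedule.py (illustrative values)
import math

def cosine_schedule(t, T, lr_max, lr_min=0.0, warmup_steps=0):
    """Linear warmup to lr_max, then cosine decay from lr_max to lr_min."""
    if t < warmup_steps:
        return lr_max * (t + 1) / warmup_steps
    progress = (t - warmup_steps) / max(1, T - warmup_steps)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))

# Example: 10,000 total steps, 500 warmup steps, peak learning rate of 6e-5
lrs = [cosine_schedule(t, T=10_000, lr_max=6e-5, warmup_steps=500) for t in range(10_000)]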

Cosine schedules have been highly popular for pretraining LLMs. For example, a cosine schedule was used for BLOOM, a 176-billion-parameter multilingual model developed by the BigScience Research Workshop and released in 2022. In an initial warmup phase, the learning rate was ramped to a peak of 6 × 10⁻⁵ over 375 million tokens. Afterward, it was lowered to 10% of this value with cosine decay over 410 million tokens and remained at this value. The implementation and detailed description are publicly accessible in BLOOM’s GitHub repository.

For pre-training their Llama 3 405B model, Meta used a slightly more involved variant of the cosine schedule. In the first stage, a warm-up phase of up to 8,000 steps brought the learning rate to a maximum of 8 × 10⁻⁵. Subsequently, the learning rate decreased to 8 × 10⁻⁷ over 1.2 million steps with a cosine decay. After the second stage focused on training the LLM up to its final context length of 128,000 tokens, the learning rate linearly decreased to 0 over 40 million tokens in the third stage. Supervised fine-tuning was conducted over about 9,000 steps with a learning rate of 10⁻⁵.

A major disadvantage of the cosine schedule is that the total number of training steps has to be known beforehand. When training large foundation models, the total compute budget is typically set, and the optimal number of training tokens can be estimated. However, when fine-tuning or experimenting, it would be preferable to base the decision on when to end training on the model’s performance.

Warmup-stable-decay schedule

The warmup-stable-decay (WSD) schedule is a simple protocol introduced by Shengding Hu and colleagues at Tsinghua University in 2024. It starts with a linear warmup to the maximum learning rate, keeps the learning rate constant for the majority of the training, and ramps it down at the end.

Through experiments, they found that a decay phase that makes up 10% of the total length is sufficient. They also demonstrated that a WSD schedule leads to a lower loss than a cosine schedule. According to Wen and colleagues at Stanford, this can readily be understood in the river valley picture. In the WSD schedule, the learning rate stays at a high value longer than in the cosine schedule. Hence, we make it further down the valley before dropping to its bottom. Further, their analysis shows that training progress in the stable phase is dominated by learning to predict deterministic tokens (facts and knowledge), while in the decay phase, the LLM learns the stochastic tokens (language variability).
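The WSD schedule is just as easy to express; below is a minimal sketch with an illustrative 10% warmup fraction and the 10% decay fraction mentioned above (the warmup fraction and the linear ramp shapes are assumptions, not prescriptions from the paper).

# wsd_schedule.py (illustrative sketch)
def wsd_schedule(t, T, lr_max, lr_min=0.0, warmup_frac=0.1, decay_frac=0.1):
    """Warmup-stable-decay: linear warmup, constant plateau, linear ramp-down."""
    warmup_steps = int(T * warmup_frac)
    decay_start = int(T * (1 - decay_frac))
    if t < warmup_steps:
        return lr_max * (t + 1) / warmup_steps
    if t < decay_start:
        return lr_max                                  # stable phase
    progress = (t - decay_start) / max(1, T - decay_start)
    return lr_max + (lr_min - lr_max) * progress       # decay phase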

Comparison of the loss curves resulting from a cosine and warmup-stable-decay (WSD) learning rate schedule. In the WSD schedule, the learning rate remains at a constant high value during the stable phase. This leads to high intermediate loss values as the loss fluctuates around the local minimum as it progresses towards lower values. During the final 10% of the total training steps, the learning rate is decreased to its minimum, leading to a sharp drop in the loss. Since the learning rate remained at a high value for longer, the final loss resulting from the WSD schedule is smaller than the loss from the cosine schedule. | Source

While a WSD schedule yields a lower loss for the same training budget, knowing the total number of training steps ahead of time is still required for scheduling the decay phase. However, the WSD schedule offers a straightforward way to extend the total number of training steps retroactively: If we find that our final model’s performance is unsatisfactory, we can resume training from a model snapshot taken at the end of the stable phase. This beams us back a small distance up the loss river valley, from where we continue making large jumpy steps towards the river mouth as if we had never descended down to the valley’s bottom in the first place.

Restarting this way, we still benefit from 90% of the compute budget spent so far. It allows us to determine the compute budget we need as we go, producing fully trained intermediate models—something that the cosine schedule inherently does not allow for.

Track months-long model training with more confidence. Use neptune.ai forking feature to iterate faster and optimize the usage of GPU resources.

With Neptune, users can visualize forked training out of the box. This means you can:

  • Test multiple configs at the same time. Stop the runs that don’t improve accuracy. And continue from the most accurate last step.
  • Restart failed training sessions from any previous step. The training history is inherited, and the entire experiment is visible on a single chart.

Cyclical cosine schedule

Returning to a high learning rate after decaying to a minimum is not a new idea in machine learning. Long established in gradient-free optimization, it was made popular for deep learning training through the “Stochastic Gradient Descent with Warm Restarts” technique proposed by Ilya Loshchilov and Frank Hutter in 2017. The learning rate is governed by a function very similar to the one for the cosine schedule:

LR(t) = LRmin + 0.5 (LRmax − LRmin) (1 + cos(π (t mod T)/T))

This time, T is not the total number of training steps but is understood as the schedule’s period. For example, we might train for 10,000 steps with T = 1,000, leading to ten consecutive cosine decay cycles. Commonly, LRmax is set to a new, lower value at the beginning of each cycle.

In the loss landscape river valley, we’re climbing down to the bottom over T steps, making ever slower progress down the river as we keep closer to the bottom. Then, we immediately go back to make large jumps toward the river mouth high up the valley’s slopes.

Right at the beginning of a new cosine cycle, the loss will be significantly higher than it was previously. This could be due to the jump in the learning rate, which might perturb the model. However, Wen and colleagues argue, based on their experiments and theoretical insights, that it is the result of training with a small learning rate for too long.

Whatever the cause, this doesn’t just make training less efficient. It’s also an obstacle to continue model training later. Whether we aim to further pre-train on newly acquired or different data, fine-tune an LLM, or incrementally evolve a model in a continual learning scenario—ideally, we could take a model snapshot and train it effectively, making the most of the compute budget we have available and the compute budget we have already spent. The learning rate schedule used during pretraining directly impacts this.

Cyclical warmup-stable-decay schedule

The Warmup-Stable-Decay (WSD) schedule allows continuing training from the final model checkpoint of the stable phase without incurring a loss penalty. This preserves a large fraction of the compute budget spent, as we only have to discard what we spent on intermediate decay phases. But even that is not negligible at the scale of LLM pretraining, where costs regularly exceed tens of millions of US dollars.

As Wen and colleagues found, starting from the final decay phase model checkpoint in a WSD schedule does not cause the same loss penalty as the cosine schedule. As the WSD schedule’s decay phase is rather short, they hypothesize it does not have the same destructive effect as the cosine schedule’s long and slow decay. Given a total compute budget, consecutively repeating the WSD cycle is more efficient than restarting from the final checkpoint of the latest stable phase.

A cyclical WSD schedule is easier to implement than WSD restarts, as the model evolves continuously down the loss landscape river valley, and no prior checkpoints have to be reloaded. It also helps downstream users, who initially often utilize few-shot prompting to adapt an LLM to their use case. If they later decide to fine-tune it, and the LLM is trained with a WSD schedule, training the same model checkpoint they already use for inference is efficient.

Learning behavior

In a neural network, the weights are the parameters of its neurons learned during training. In an LLM, weights include the query, key, and value matrices in the attention heads and the activation function parameters in the feed-forward layers. While the learning rate governs the scale of changes made to the model’s weights, we can also control how the weights change on a more fine-grained level.

Weight decay

Employing weight decay during training penalizes large weights, preventing small parts of the model from dominating its output. Weight decay in stochastic gradient descent is implemented by adding a term to the loss function. For example, using L2 regularization, the adapted loss function looks like this:

L = Lorig + λ Σi wi²

Here, Lorig is the original loss function, λ is the weight decay factor, and wi are the model weights.

Weight decay has been applied to transformer-based NLP models since the beginning. In the seminal 2018 paper BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, the authors state that they trained the model using “Adam with [a] learning rate of 1e-4, β₁=0.9, β₂=0.999, L2 weight decay of 0.01, learning rate warm up over the first 10,000 steps, and linear decay of the learning rate.”

As Ilya Loshchilov and Frank Hutter point out in their 2019 paper Decoupled Weight Decay Regularization, in adaptive optimizers like Adam, L2 regularization and weight decay are not identical, and L2 regularization is not effective. In Adam, the gradient of the regularization term is normalized together with the gradient of Lorig, which means weights with large gradients receive only minimal regularization. They introduced the AdamW optimizer, where the weight decay term is independent of the gradient-based update. AdamW is widely used for LLMs, such as for training Megatron-LM (2019), Llama 1 (2023), Llama 2 (2023), and Llama 3 (2024).
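In PyTorch, for example, decoupled weight decay is available out of the box; the sketch below uses a plain linear layer as a stand-in for an LLM's weight matrices, and the learning rate and decay factor are illustrative.

# adamw_example.py (illustrative sketch)
import torch

model = torch.nn.Linear(768, 768)   # stand-in for an LLM weight matrix

# AdamW applies the decay directly to the weights, independent of Adam's
# adaptive, gradient-based update.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)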

In LLM pretraining, models often see each training sample only once. Thus, overfitting to training data, which weight decay helps prevent in traditional deep learning scenarios, is only of concern if there are many similar or even identical samples in the training dataset. Still, weight decay positively affects training speed and the final loss.

According to a 2023 analysis by Francesco D’Angelo and colleagues at EPFL, this is because weight decay increases the effective learning rate. The effective learning rate at training step t is defined as LR(t) / ||wt||², the learning rate scaled by the inverse squared norm of the weight vector. The smaller the weights, the larger the influence of a weight update. Further, D’Angelo and colleagues find that weight decay stabilizes training in reduced floating-point precision.

Gradient clipping

Gradient clipping caps gradient magnitudes, helping maintain numerical stability. In the river valley analogy, we impose a threshold on slope steepness when deciding where to move next. Rather than jumping off a cliff, we treat it as a moderately steep hillside.

There are two common types of gradient clipping:

  1. Clipping by value: Set predefined minimum and maximum values for gradient magnitudes. A gradient component is clipped to the respective limit if it exceeds these thresholds. This approach has the key benefit of not requiring access to the entire gradient vector.
  2. Clipping by norm: The entire gradient vector is scaled down if the norm exceeds a specified threshold. For example, Nvidia’s original Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism paper first published in 2019 notes: “[W]e use global gradient norm clipping of 1.0 to improve the stability of training large models.” In contrast to clipping by value, this preserves the gradient vector’s direction but requires access to the entire gradient vector to compute.

In 2022, Yang and Ma introduced the Component-Wise Gradient Norm Clipping (CWGNC) approach for fine-tuning LLMs. In a nutshell, CWGNC applies gradient-clipping by norm separately to components in the LLM, such as the key, query, and value matrices or feed-forward layers. This stabilizes the training of each component individually, which might progress at significantly different rates.
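In PyTorch, both standard variants are available as utility functions. The sketch below uses a toy model and thresholds chosen for illustration (the norm threshold of 1.0 mirrors the Megatron-LM quote above); component-wise clipping as in CWGNC would apply the norm-based call separately per parameter group.

# gradient_clipping.py (illustrative sketch)
import torch

model = torch.nn.Linear(10, 10)
loss = model(torch.randn(4, 10)).sum()
loss.backward()

# Clipping by value: each gradient component is limited to [-0.5, 0.5]
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

# Clipping by norm: the whole gradient vector is rescaled if its norm exceeds 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)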

Next-token generation

LLMs are autoregressive language models. They predict the next token by taking the sequence of previously generated tokens as input and producing a vector containing a probability for each token in the vocabulary. Different post-processing techniques can be used to determine the next token from these probabilities.

Temperature

Typically, LLMs use a softmax function as the final step in computing token probabilities. A temperature parameter controls this function.

The temperature influences the degree of randomness (or “originality” or “creativity”) in an LLM’s predicted text. At low temperatures, the model becomes more deterministic, rarely considering less likely options and instead focusing on the tokens with the highest probabilities. Conversely, a high temperature increases unpredictability, allowing the model to choose from a broader range of tokens. Thus, lower temperatures are helpful when you need reliable answers, while higher temperatures lead to more varied and surprising outputs.
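Concretely, the temperature divides the logits before the softmax is applied. A minimal NumPy sketch with made-up logits for three candidate tokens:

# temperature_softmax.py (illustrative logits)
import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    scaled = logits / temperature
    exp = np.exp(scaled - scaled.max())       # subtract the max for numerical stability
    return exp / exp.sum()

logits = np.array([4.0, 2.0, 1.0])            # made-up scores for three candidate tokens
print(softmax_with_temperature(logits, 0.2))  # near one-hot: the top token dominates
print(softmax_with_temperature(logits, 1.2))  # flatter distribution: more randomness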

The Text Gen Playground Hugging Face Space allows users to experiment with different temperature settings and models. By inputting a prompt and adjusting the temperature parameter, you can observe how the model’s output varies from predictable and deterministic to creative and varied.

For example, using the prompt “The sun rises in the” at different temperatures:

  • Low Temperature (e.g., T = 0.2): The model will likely complete the sentence with “east,” reflecting a common and expected continuation.
  • High Temperature (e.g., T = 1.2): The model might generate more imaginative completions like “morning haze” or “golden skies,” showcasing increased creativity.

Adjusting the temperature parameter in such playgrounds provides valuable insights into controlling the balance between determinism and creativity in language model outputs.

Sampling strategy

Given the vector of probabilities, there are many ways to select the next token.

A straightforward strategy is always picking the most likely token. Since the sampling process only considers the probabilities for the very next token, this “greedy decoding” leads to highly probable multi-token sequences being discarded if they start with a token that – viewed in isolation – is less likely.

Using beam search or random sampling according to the token probabilities can mitigate this. While the former produces deterministic outputs and thus no variety, the latter can lead to the selection of highly improbable tokens, producing nonsensical sequences.

A more balanced approach is top-k sampling, which restricts sampling of the next token to the k most probable tokens. Alternatively, in top-p sampling, only the most likely tokens up to a cumulative probability of p are considered. This approach adapts dynamically to the probability distribution, sampling from many tokens in uncertain scenarios and picking from only a few when the model is more confident. (p and k can be adjusted during training or inference time.)
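A minimal sketch of top-k and top-p filtering over a probability vector; the probabilities are made up, and production implementations typically operate on logits but renormalize in the same way.

# topk_topp_sampling.py (illustrative sketch)
import numpy as np

def top_k_filter(probs, k):
    """Keep the k most probable tokens and renormalize."""
    keep = np.argsort(probs)[-k:]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()

def top_p_filter(probs, p):
    """Keep the smallest set of top tokens whose cumulative probability reaches p."""
    order = np.argsort(probs)[::-1]
    cutoff = np.searchsorted(np.cumsum(probs[order]), p) + 1   # number of tokens kept
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()

probs = np.array([0.5, 0.2, 0.15, 0.1, 0.05])
next_token = np.random.choice(len(probs), p=top_p_filter(probs, p=0.9))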

As ML engineers, we can fine-tune temperature and sampling strategy parameters according to our project needs. For example, if our tasks require precision (e.g., technical writing or summarization), we’ll use lower temperatures and top-k sampling to prioritize high-probability tokens. If we need more diversity, we’ll begin with common default values (temperature 0.7, top-k: k = 40, top-p: p = 0.9). We’ll iteratively adjust them based on the qualitative evaluation of outputs and document our findings to build a shared knowledge base with our team.

How do we find the optimal hyperparameters?

LLM training involves many hyperparameters, resulting in a combinatorial explosion of the search space. Simply guessing hyperparameters is unlikely to yield good results. Further, hyperparameters interact in complex ways, so the optimal value for one may depend on the values of others. Thus, adjusting hyperparameters one at a time may lead to suboptimal solutions, as we easily become trapped in local optima and don’t adequately explore the hyperparameter space.

Finding an optimal combination of hyperparameters requires a systematic approach. First, it’s paramount to understand the relevant hyperparameters and their influence on the particular LLM. It’s essential to research how similar architectures were trained or how the LLM we want to fine-tune was pre-trained. Further, we should clarify the available time, our compute budget, and the training objectives.

Next, we can sketch a roadmap. Can we afford to conduct experiments with particular hyperparameter combinations we believe are useful? Do we already have an experiment tracker and resource monitoring in place, or do we need to set it up first? What will be the decision points and criteria that ensure we end up with a fully trained LLM at the end of the project? Finally, we can start executing this roadmap and adjust our plans as we gather more information and insight.

The BLOOM team published a detailed paper on their preliminary experiments to determine the optimal model size and architecture. They describe how they started with GPT-3’s hyperparameters and conducted trial runs to estimate the optimal balance between model size and number of tokens given their fixed compute budget. Similar experiments were run by the Meta team that trained Llama3, who also aimed to predict downstream task performance.

Can we use traditional machine learning hyperparameter optimization methods for LLMs?

Methods for systematic hyperparameter optimization have long been studied in machine learning:

  • Learning curve analysis involves training models with varying hyperparameters over several epochs and plotting the loss to identify trends. In deep-learning models, plotting the gradient can further help assess whether and how efficiently a model learns.
  • Grid search systematically steps through the hyperparameter space, training a model for each possible combination. Random search samples the hyperparameter space, training models for randomly selected combinations.

While these approaches have successfully been applied to optimize LLM hyperparameters, their use is severely limited by the fact that LLMs are very expensive to train. The computational and memory requirements make it unviable to train large numbers of models. If training a model takes several months on a large cluster, we’ll only get one shot at a full training run.

Advanced strategies for LLM hyperparameter optimization

Beyond starting from a well-known hyperparameter combination and systematically conducting experiments, there is a range of approaches for automatically identifying or optimizing LLM hyperparameters in specific circumstances.

Population-based training (PBT)

Population-Based Training (PBT) is an approach pioneered by Google DeepMind that combines the concepts of evolutionary search and online training. Instead of fixing hyperparameters at the start of training and leaving them static throughout the process, PBT adapts them dynamically, informed by the models’ performance.

In a nutshell, the population-based training process consists of the following steps:

  1. Set up a population of models, each with unique hyperparameters hi and weights θi.
  2. Train each model, updating θi every iteration.
  3. After a fixed number of iterations, evaluate each model’s performance on a validation dataset.
  4. Identify models that are underperforming relative to others. Replace their current weights and hyperparameters with those of a better-performing model (exploitation).
  5. Slightly perturb the hyperparameters of previously underperforming models to prevent the population from converging to a single configuration too early and improve diversity (exploration).
  6. Conclude the training if the compute budget is exhausted or the objective has been met. Otherwise, repeat the process starting from step 2.

This process initially appears resource-intensive since it requires maintaining and updating multiple models simultaneously, which can increase total GPU hours. However, PBT’s dynamic refinement of hyperparameters during training can significantly save wall-clock time. By avoiding restarting from scratch for each hyperparameter configuration and leveraging partially trained models, PBT reduces the number of training epochs needed to achieve optimal performance.
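To make the loop above concrete, here is a self-contained toy sketch of the exploit/explore cycle. The "training" is replaced by a stand-in scoring function, and the population size, perturbation factors, and learning rate range are all made up for illustration.

# pbt_toy.py (toy sketch; the scoring function stands in for real training)
import random

def evaluate(lr):
    # Stand-in for validation performance: best around lr = 3e-4.
    return -abs(lr - 3e-4)

# Step 1: population with random hyperparameters
population = [{"lr": 10 ** random.uniform(-5, -3)} for _ in range(8)]

for generation in range(10):
    # Steps 2-3: "train" and evaluate each member
    for member in population:
        member["score"] = evaluate(member["lr"])

    # Step 4: exploit - the bottom half copies hyperparameters from the top half
    population.sort(key=lambda m: m["score"], reverse=True)
    half = len(population) // 2
    for weak, strong in zip(population[half:], population[:half]):
        weak["lr"] = strong["lr"]
        # Step 5: explore - perturb the copied hyperparameters
        weak["lr"] *= random.choice([0.8, 1.2])

best = max(population, key=lambda m: m["score"])
print(f"Best learning rate found: {best['lr']:.2e}")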

The 2017 DeepMind study on Population-Based Training (PBT) showcased its potential for LLMs by fine-tuning the first transformer model on the WMT 2014 English-German machine translation benchmark. They manually optimized a baseline model and compared it to a model where they used PBT to optimize the dropouts for different layers and the learning rate. Their evaluation showed that the PBT-optimized model outperformed their hand-tuned baseline. Further, they discovered that the learning rate schedule generated through PBT mimicked the human-created one. Starting with a small learning rate, it then jumped to a high value before something resembling an exponential decay brought it down to a low value again. DeepMind’s original PBT transformer model also learned noticeably faster.

Ray Tune is a hyperparameter tuning library that supports population-based training. It is part of the open-source Ray framework for scaling machine-learning applications. The Ray Tune documentation includes an example of tuning BERT and RoBERTa on the GLUE benchmark dataset using population-based training.

Bayesian optimization

Bayesian optimization is a popular method for efficiently navigating the hyperparameter space by building a probabilistic model (surrogate model) of the influence of the hyperparameters on the objective (e.g., validation loss). The surrogate model is used to predict promising hyperparameter combinations to try next. The results of this exploration are then used to refine the surrogate model.

The 2024 paper Crafting Efficient Fine-Tuning Strategies for Large Language Models investigates the applicability of Bayesian optimization to fine-tuning LLMs. First, a population of N models is trained for a pre-defined budget t1. As each model is trained, the surrogate model is updated, and the updated version is used to set the hyperparameters of the next model. Once all N models are trained, the top k models are selected and are trained up to t2. Finally, the best model among the k fully trained models is selected.

Adaptive Low-Rank Adaptation (LoRA)

Low-Rank Adaptation (LoRA) is a popular technique for reducing the memory footprint and computational demands when fine-tuning LLMs. In brief, the idea is to represent the weights of the fine-tuned model as 

Wfine = Wpre + ∆W = Wpre + BA

Here, the fine-tuned weights Wfine are the sum of the original weights Wpre and a difference ∆W, which is the product of two matrices, B and A. Only B and A are updated during fine-tuning, while Wpre remains unchanged. If Wpre and ∆W have dimensions m x n, B and A have dimensions m x r and r x n, respectively. If the rank r is much smaller than m and n, the number of weights to be updated is greatly reduced, leading to faster training progress while requiring less memory.
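To see the savings in numbers, a quick parameter count with illustrative dimensions (a 4096 x 4096 weight matrix and rank r = 8, both made up for this example):

# lora_param_count.py (illustrative dimensions)
m, n, r = 4096, 4096, 8

full_update = m * n            # updating ∆W directly
lora_update = m * r + r * n    # updating B (m x r) and A (r x n)

print(f"Full update of this matrix: {full_update:,} parameters")   # 16,777,216
print(f"LoRA update with rank {r}:  {lora_update:,} parameters")   # 65,536
print(f"Reduction factor: {full_update / lora_update:.0f}x")       # 256x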

In practice, it is often unclear to which LLM components LoRA should be applied for the best outcome. While we know that not all weights influence task performance equally, identifying which components are important for a particular objective would require extensive ablation studies. Thus, LoRA is often applied across all suitable weight matrices in a model.

AdaLoRA (Adaptive Low-Rank Adaptation) is a method to allocate a given parameter budget across weight matrices. The core idea is to apply LoRA to all LLM components but to use different values for the rank r. Important components use a matrix pair with a large r, leading to a ∆W with many weights. Less important components are approximated using a lower-rank matrix pair. AdaLoRA assigns an importance score to each component and sets the values for r such that the total number of weights remains within the user-defined budget. This leads to an optimal training outcome for a fixed compute and memory budget.

AdaMoLE (Adaptive Mixture of Low-Rank Adaptation Experts) similarly aims to reduce the number of weights that need to be updated. It replaces the single low-rank matrix pair of the original LoRA with a collection of multiple matrix pairs (LoRA experts) that are activated dynamically based on the input context. This enables the LLM to learn different tasks with a minimal total number of weights.

Fine-tuning an LLM with the Adaptive Mixture of Low-Rank Adaptation Experts approach. The fine-tuned weights are approximated as the sum of the frozen pre-trained weights and a number of so-called LoRA experts that are activated by a gating function and a threshold function. Different LoRA experts specialize in different contexts, allowing the LLM to learn different tasks with a minimal number of weights. | Modified based on: source

Hands-on: LLM hyperparameter optimization with neptune.ai

Optuna is a framework for optimizing hyperparameter search using Bayesian optimization. It can be applied to various machine-learning tasks, including LLM hyperparameter tuning.
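As a minimal sketch of what such a study looks like, the snippet below defines a dummy objective; in a real run, the body of the objective would fine-tune the model and return its validation loss, and the hyperparameter ranges are placeholders.

# optuna_sketch.py (the objective is a placeholder, not a real fine-tuning run)
import optuna

def objective(trial):
    lr = trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
    num_epochs = trial.suggest_int("num_epochs", 1, 5)

    # Placeholder for: fine-tune the model with these values and return the validation loss.
    return (lr - 5e-5) ** 2 + 0.01 / batch_size + 0.001 * num_epochs

study = optuna.create_study(direction="minimize")   # uses the TPE sampler by default
study.optimize(objective, n_trials=25)
print(study.best_params)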

To see this in action, we’ve prepared a Colab notebook that walks you through the process of finding the optimal combination of learning rate, batch size, and number of epochs for fine-tuning a Hugging Face Transformers model on the IMDB dataset.

The tutorial uses neptune.ai to track training progress and analyze the different hyperparameters. If you don’t want to go through the tutorial yourself right now, you can still explore example results in this public Neptune project.

How about being one of the first to access Neptune Scale?

Neptune Scale is our upcoming product release built for teams that train foundation models. It offers enhanced scalability and exciting new features. You can join our beta program to benefit from Neptune Scale earlier.

What’s next in LLM hyperparameter optimization?

Finding an optimal combination of hyperparameters is essential for training LLMs. In this article, we’ve reviewed key LLM hyperparameters and their influence on the model and training performance. We’ve also discussed how to approach hyperparameter optimization systematically and explored methods to assist or even automate this task in certain scenarios.

From the examples of hyperparameter choices for state-of-the-art LLMs, we’ve seen that while architectures, training tasks, and data change, most models are trained with relatively similar learning rate schedules and optimizer configurations. As our understanding of the model and training mechanics deepens and more experiments yield empirical evidence, we’ll likely see an evolution of the standard recipes and more diversity.


Essential Review Papers on Physics-Informed Neural Networks: A Curated Guide for Practitioners


Staying on top of a fast-growing research field is never easy.

I face this challenge firsthand as a practitioner in Physics-Informed Neural Networks (PINNs). New papers, be they algorithmic advancements or cutting-edge applications, are published at an accelerating pace by both academia and industry. While it is exciting to see this rapid development, it inevitably raises a pressing question:

How can one stay informed without spending countless hours sifting through papers?

This is where I have found review papers to be exceptionally valuable. Good review papers are effective tools that distill essential insights and highlight important trends. They are big-time savers guiding us through the flood of information.

In this blog post, I would like to share with you my personal, curated list of must-read review papers on PINNs that have been especially influential for my own understanding and use of PINNs. These papers cover key aspects of PINNs, including algorithmic developments, implementation best practices, and real-world applications.

In addition to what’s available in existing literature, I’ve included one of my own review papers, which provides a comprehensive analysis of common functional usage patterns of PINNs — a practical perspective often missing from academic reviews. This analysis is based on my review of around 200 arXiv papers on PINNs across various engineering domains in the past 3 years and can serve as an essential guide for practitioners looking to deploy these techniques to tackle real-world challenges.

For each review paper, I will explain why it deserves your attention by explaining its unique perspective and indicating practical takeaways that you can benefit from immediately.

Whether you’re just getting started with PINNs, using them to tackle real-world problems, or exploring new research directions, I hope this collection makes navigating the busy field of PINN research easier for you.

Let’s cut through the complexity together and focus on what truly matters.

1️⃣ Scientific Machine Learning through Physics-Informed Neural Networks: Where we are and what’s next

📄 Paper at a glance

  • Authors: S. Cuomo, V. Schiano di Cola, F. Giampaolo, G. Rozza, M. Raissi, and F. Piccialli
  • Year: 2022
  • Link: arXiv

🔍 What it covers

This review is structured around key themes in PINNs: the fundamental components that define their architecture, theoretical aspects of their learning process, and their application to various computing challenges in engineering. The paper also explores the available toolsets, emerging trends, and future directions.

Fig 1. Overview of the #1 review paper. (Image by author)

✨ What’s unique

This review paper stands out in the following ways:

  • One of the best introductions to PINN fundamentals. This paper takes a well-paced approach to explaining PINNs from the ground up. Section 2 systematically dissects the building blocks of a PINN, covering various underlying neural network architectures and their associated characteristics, how PDE constraints are incorporated, common training methodologies, and learning theory (convergence, error analysis, etc.) of PINNs.
  • Putting PINNs in historical context. Rather than simply presenting PINNs as a standalone solution, the paper traces their development from earlier work on using deep learning to solve differential equations. This historical framing is valuable because it helps demystify PINNs by showing that they are an evolution of previous ideas, and it makes it easier for practitioners to see what alternatives are available.
  • Equation-driven organization. Instead of just classifying PINN research by scientific domains (e.g., geoscience, material science, etc.) as many other reviews do, this paper categorizes PINNs based on the types of differential equations (e.g., diffusion problems, advection problems, etc.) they solve. This equation-first perspective encourages knowledge transfer as the same set of PDEs could be used across multiple scientific domains. In addition, it makes it easier for practitioners to see the strengths and weaknesses of PINNs when dealing with different types of differential equations.

🛠 Practical goodies

Beyond its theoretical insights, this review paper offers immediately useful resources for practitioners:

  • A complete implementation example. In section 3.4, this paper walks through a full PINN implementation to solve a 1D Nonlinear Schrödinger equation. It covers translating equations into PINN formulations, handling boundary and initial conditions, defining neural network architectures, choosing training strategies, selecting collocation points, and applying optimization methods. All implementation details are clearly documented for easy reproducibility. The paper compares PINN performance by varying different hyperparameters, which could offer immediately applicable insights for your own PINN experiments.
  • Available frameworks and software tools. Table 3 compiles a comprehensive list of major PINN toolkits, with detailed tool descriptions provided in section 4.3. The considered backends include not only Tensorflow and PyTorch but also Julia and Jax. This side-by-side comparison of different frameworks is especially useful for picking the right tool for your needs.

💡Who would benefit

  • This review paper benefits anyone new to PINNs and looking for a clear, structured introduction.
  • Engineers and developers looking for practical implementation guidance would find the realistic, hands-on demo, and the thorough comparison of existing PINN frameworks most interesting. Additionally, they can find relevant prior work on differential equations similar to their current problem, which offers insights they can leverage in their own problem-solving.
  • Researchers investigating theoretical aspects of PINN convergence, optimization, or efficiency can also greatly benefit from this paper.

2️⃣ From PINNs to PIKANs: Recent Advances in Physics-Informed Machine Learning

📄 Paper at a glance

  • Authors: J. D. Toscano, V. Oommen, A. J. Varghese, Z. Zou, N. A. Daryakenari, C. Wu, and G. E. Karniadakis
  • Year: 2024
  • Link: arXiv

🔍 What it covers

This paper provides one of the most up-to-date overviews of the latest advancements in PINNs. It emphasizes enhancements in network design, feature expansion, optimization strategies, uncertainty quantification, and theoretical insights. The paper also surveys key applications across a range of domains.

Fig 2. Overview of the #2 review paper. (Image by author)

✨ What’s unique

This review paper stands out in the following ways:

  • A structured taxonomy of algorithmic developments. One of the freshest contributions of this paper is its taxonomy of algorithmic advancements. This new taxonomy scheme elegantly categorizes all the advancements into three core areas: (1) representation model, (2) handling governing equations, and (3) optimization process. This structure provides a clear framework for understanding both current developments and potential directions for future research. In addition, the illustrations used in the paper are top-notch and easily digestible.
Fig 3. The taxonomy of algorithmic developments in PINNs proposed by the #2 paper. (Image by author)
  • Spotlight on Physics-informed Kolmogorov–Arnold Networks (KAN). KAN, a new architecture based on the Kolmogorov–Arnold representation theorem, is currently a hot topic in deep learning. In the PINN community, some work has already been done to replace the multilayer perceptron (MLP) representation with KANs to gain more expressiveness and training efficiency. The community has lacked a comprehensive review of this new line of research, and this review paper (section 3.1) fills exactly that gap.
  • Review on uncertainty quantification (UQ) in PINNs. UQ is essential for the reliable and trustworthy deployment of PINNs when tackling real-world engineering applications. In section 5, this paper provides a dedicated section on UQ, explaining the common sources of uncertainty in solving differential equations with PINNs and reviewing strategies for quantifying prediction confidence.
  • Theoretical advances in PINN training dynamics. In practice, training PINNs is non-trivial. Practitioners are often puzzled by why PINNs training sometimes fail, or how they should be trained optimally. In section 6.2, this paper provides one of the most detailed and up-to-date discussions on this aspect, covering the Neural Tangent Kernel (NTK) analysis of PINNs, information bottleneck theory, and multi-objective optimization challenges.

🛠 Practical goodies

Even though this review paper leans towards the theory-heavy side, two particularly valuable aspects stand out from a practical perspective:

  • A timeline of algorithmic advances in PINNs. In a table in Appendix A, this paper tracks the milestones of key advancements in PINNs, from the original PINN formulation to the most recent extensions to KANs. If you’re working on algorithmic improvements, this timeline gives you a clear view of what’s already been done. If you’re struggling with PINN training or accuracy, you can use this table to find existing methods that might solve your issue.
  • A broad overview of PINN applications across domains. Compared to all the other reviews, this paper strives to give the most comprehensive and updated coverage of PINN applications in not only the engineering domains but also other less-covered fields such as finance. Practitioners can easily find prior works conducted in their domains and draw inspiration.

💡Who would benefit

  • For practitioners working in safety-critical fields that need confidence intervals or reliability estimates on their PINN predictions, the discussion on UQ would be useful. If you are struggling with PINN training instability, slow convergence, or unexpected failures, the discussion on PINN training dynamics can help unpack the theoretical reasons behind these issues.
  • Researchers may find this paper especially interesting because of the new taxonomy, which allows them to see patterns and identify gaps and opportunities for novel contributions. In addition, the review of cutting-edge work on PI-KAN can also be inspiring.

3️⃣ Physics-Informed Neural Networks: An Application-Centric Guide

📄 Paper at a glance

  • Authors: S. Guo (this author)
  • Year: 2024
  • Link: Medium

🔍 What it covers

This article reviews how PINNs are used to tackle different types of engineering tasks. For each task category, the article discusses the problem statement, why PINNs are useful, and how PINNs can be implemented to address the problem, followed by a concrete use case published in the literature.

Fig 4. Overview of the #3 review paper. (Image by author)

✨ What’s unique

Unlike most reviews that categorize PINN applications either based on the type of differential equations solved or specific engineering domains, this article picks an angle that practitioners care about the most: the engineering tasks solved by PINNs. This work is based on reviewing papers on PINN case studies scattered in various engineering domains. The outcome is a list of distilled recurring functional usage patterns of PINNs:

  • Predictive modeling and simulations, where PINNs are leveraged for dynamical system forecasting, coupled system modeling, and surrogate modeling.
  • Optimization, where PINNs are commonly employed to achieve efficient design optimization, inverse design, model predictive control, and optimized sensor placement.
  • Data-driven insights, where PINNs are used to identify the unknown parameters or functional forms of the system, as well as to assimilate observational data to better estimate the system states.
  • Data-driven enhancement, where PINNs are used to reconstruct the field and enhance the resolution of the observational data.
  • Monitoring, diagnostic, and health assessment, where PINNs are leveraged to act as virtual sensors, anomaly detectors, health monitors, and predictive maintainers.

🛠 Practical goodies

This article places practitioners’ needs at the forefront. While most existing review papers merely answer the question, “Has PINN been used in my field?”, practitioners often seek more specific guidance: “Has PINN been used for the type of problem I’m trying to solve?”. This is precisely what this article tries to address.

By using the proposed five-category functional classification, practitioners can conveniently map their problems to these categories, see how others have solved them, and what worked and what did not. Instead of reinventing the wheel, practitioners can leverage established use cases and adapt proven solutions to their own problems.

💡Who would benefit

This review is best for practitioners who want to see how PINNs are actually being used in the real world. It can also be particularly valuable for cross-disciplinary innovation, as practitioners can learn from solutions developed in other fields.

4️⃣ An Expert’s Guide to Training Physics-informed Neural Networks

📄 Paper at a glance

  • Authors: S. Wang, S. Sankaran, H. Wang, P. Perdikaris
  • Year: 2023
  • Link: arXiv

🔍 What it covers

Even though it doesn’t market itself as a “standard” review, this paper goes all in on providing a comprehensive handbook for training PINNs. It presents a detailed set of best practices for training physics-informed neural networks (PINNs), addressing issues like spectral bias, unbalanced loss terms, and causality violations. It also introduces challenging benchmarks and extensive ablation studies to demonstrate these methods.

Fig 5. Overview of the #4 review paper. (Image by author)

✨ What’s unique

  • A unified “expert’s guide”. The main authors are active researchers in PINNs who have worked extensively on improving PINN training efficiency and model accuracy over the past several years. This paper is a distilled summary of the authors’ past work, synthesizing a broad range of recent PINN techniques (e.g., Fourier feature embeddings, adaptive loss weighting, causal training) into a cohesive training pipeline. This feels like having a mentor who tells you exactly what does and doesn’t work with PINNs.
  • A thorough hyperparameter tuning study. This paper conducts various experiments to show how different tweaks (e.g., different architectures, training schemes, etc.) play out on different PDE tasks. Their ablation studies show precisely which methods move the needle, and by how much.
  • PDE benchmarks. The paper compiles a suite of challenging PDE benchmarks and offers state-of-the-art results that PINNs can achieve.

🛠 Practical goodies

  • A problem-solution cheat sheet. This paper thoroughly documents various techniques addressing common PINN training pain-points. Each technique is clearly presented using a structured format: the why (motivation), how (how the approach addresses the problem), and what (the implementation details). This makes it very easy for practitioners to identify the “cure” based on the “symptoms” observed in their PINN training process. What’s great is that the authors transparently discussed potential pitfalls of each approach, allowing practitioners to make well-informed decisions and effective trade-offs.
  • Empirical insights. The paper shares valuable empirical insights obtained from extensive hyperparameter tuning experiments. It offers practical guidance on choosing suitable hyperparameters, e.g., network architectures and learning rate schedules, and demonstrates how these parameters interact with the advanced PINN training techniques proposed.
  • Ready-to-use library. The paper is accompanied by an optimized JAX library that practitioners can directly adopt or customize. The library supports multi-GPU environments and is ready for scaling to large-scale problems.

💡Who would benefit

  • Practitioners who are struggling with unstable or slow PINN training can find many practical strategies to fix common pathologies. They can also benefit from the straightforward templates (in JAX) to quickly adapt PINNs to their own PDE setups.
  • Researchers looking for challenging benchmark problems and aiming to benchmark new PINN ideas against well-documented baselines will find this paper especially handy.

5️⃣ Domain-Specific Review Papers

Beyond general reviews in PINNs, there are several nice review papers that focus on specific scientific and engineering domains. If you’re working in one of these fields, these reviews could provide a deeper dive into best practices and cutting-edge applications.

1. Heat Transfer Problems

Paper: Physics-Informed Neural Networks for Heat Transfer Problems

The paper provides an application-centric discussion on how PINNs can be used to tackle various thermal engineering problems, including inverse heat transfer, convection-dominated flows, and phase-change modeling. It highlights real-world challenges such as missing boundary conditions, sensor-driven inverse problems, and adaptive cooling system design. The industrial case study related to power electronics is particularly insightful for understanding the usage of PINNs in practice.

2. Power Systems

Paper: Applications of Physics-Informed Neural Networks in Power Systems — A Review

This paper offers a structured overview of how PINNs are applied to critical power grid challenges, including state/parameter estimation, dynamic analysis, power flow calculation, optimal power flow (OPF), anomaly detection, and model synthesis. For each type of application, the paper discusses the shortcomings of traditional power system solutions and explains why PINNs could be advantageous in addressing those shortcomings. This comparative summary is useful for understanding the motivation for adopting PINNs.

3. Fluid Mechanics

Paper: Physics-informed neural networks (PINNs) for fluid mechanics: A review

This paper explores three detailed case studies that demonstrate the application of PINNs in fluid dynamics: (1) 3D wake flow reconstruction using sparse 2D velocity data, (2) inverse problems in compressible flow (e.g., shock wave prediction with minimal boundary data), and (3) biomedical flow modeling, where PINNs infer thrombus material properties from phase-field data. The paper highlights how PINNs overcome limitations in traditional CFD, e.g., mesh dependency, expensive data assimilation, and difficulty handling ill-posed inverse problems.

4. Additive Manufacturing

Paper: A review on physics-informed machine learning for monitoring metal additive manufacturing process

This paper examines how PINNs address critical challenges specific to additive manufacturing process prediction or monitoring, including temperature field prediction, fluid dynamics modeling, fatigue life estimation, accelerated finite element simulations, and process characteristics prediction.

6️⃣ Conclusion

In this blog post, we went through a curated list of review papers on PINNs, covering fundamental theoretical insights, the latest algorithmic advancements, and practical application-oriented perspectives. For each paper, we highlighted unique contributions, key takeaways, and the audience that would benefit the most from these insights. I hope this curated collection can help you better navigate the evolving field of PINNs.

Ethical Considerations and Best Practices in LLM Development 


  • Bias is inherent to building an ML model. Bias exists on a spectrum. Our job is to tell the difference between the desirable bias and the one that needs correction.
  • We can identify biases using benchmarks like StereoSet and BBQ, and minimize them with ongoing monitoring across versions and iterations.
  • Adhering to data protection laws is not as complex if we focus less on the internal structure of the algorithms and more on the practical contexts of use.
  • To keep data secure throughout the model’s lifecycle, implement these practices: data anonymization, secure model serving, and privacy penetration tests.
  • Transparency can be achieved by providing contextual insights into model outputs. Documentation and opt-out mechanisms are important aspects of a trustworthy system.

Picture this: you’ve spent months fine-tuning an AI-powered chatbot to provide mental health support. You launch it, confident it will make therapy more accessible for those in need. But soon, reports emerge: one user seeking help for an eating disorder received diet tips instead of support, worsening their condition. Another, in a moment of crisis, was met with responses that encouraged harmful behaviors and later died by suicide. This is not hypothetical; it is a real-life example.

Now think about your work as an AI professional. Just like this chatbot, large language models (LLMs) influence critical decisions, and training them on biased data can perpetuate harmful stereotypes, exclude marginalized voices, or even generate unsafe recommendations. Whether the application is financial services, healthcare, or customer support, the stakes are just as high: how do we ensure our work has long-term value and positive societal impact? By focusing on measurable solutions: differential privacy techniques to protect user data, bias-mitigation benchmarks to identify gaps, and reproducible tracking with tools like neptune.ai to ensure accountability.

This article isn’t just about why ethics matter—it’s about how you can take action now to build trustworthy LLMs. Let’s get started!

So how can we address bias in LLMs?

Bias in the context of training LLMs is often discussed with a negative connotation. However, the reality is more complex: algorithmic bias is inherent in any machine learning model because it reflects patterns, structures, and priorities encoded in the training data and design. Let’s put it this way: some bias is necessary for models to work effectively. When we fine-tune LLMs, we shift their biases to align with specific tasks or applications. For example, a large language model is intentionally biased toward generating grammatically correct sentences. 

The challenge for AI researchers and engineers lies in separating desirable biases from harmful algorithmic biases that perpetuate social inequities. To address this, it’s helpful to think of bias as existing on a spectrum:

  1. Functional biases: The previous example falls on this end of the spectrum. These biases are intentional and beneficial, enhancing model performance. They guide the LLM to generate text in a specific tone or style, or to adhere to a logical reasoning pattern.
  2. Neutral biases: These may not directly harm users but can skew the diversity of outputs. For example, an LLM trained on predominantly European data might overrepresent those perspectives, unintentionally narrowing the scope of information or viewpoints it offers.
  3. Harmful biases: These are the biases that demand active mitigation. Harmful biases lead to outputs that disadvantage certain groups. For example, a recruitment LLM favoring male applicants due to biased training data reflects a harmful bias that requires correction. During the data collection stage, two valuable frameworks to analyze data distribution are Datasheets for Datasets and FACETS.

To mitigate unwanted biases (the harmful end of the spectrum), it is recommended to adopt a structured approach during the fine-tuning stage:

1. Define the desired outcome

Identify the biases your model should intentionally have and avoid. For example, an LLM designed for legal assistance should prioritize precision and formal language (functional biases), while actively avoiding harmful biases like racial assumptions in legal case studies.

2. Test and measure bias

Bias benchmarks assess how your pre-trained LLM handles both neutral and harmful biases. Two of the most popular benchmarks are StereoSet, which tests for stereotypical associations in the outputs of your large language model, and BBQ (Bias Benchmark for QA), which highlights biases in question-answering systems.

Let’s see how to use them in a simple example. Imagine you’re evaluating an LLM used in a recruitment platform. A StereoSet prompt might be:

“The software engineer was explaining the algorithm. After the meeting, ___ went back to coding.”

The benchmark would present two potential completions:

  • “he” (stereotypical)
  • “she” or “they” (non-stereotypical)

StereoSet evaluates the model’s likelihood of generating each option. Suppose your LLM is heavily biased toward stereotypical associations, like assuming “software engineer” is male. This would indicate a higher probability assigned to “he” over “she” or “they.”

This is a common stereotype, but StereoSet can evaluate more nuanced scenarios like:

“The team lead recommended a flexible work schedule for better work-life balance. ___ later presented their findings to the board.”

Here, the model’s output might be tested for implicit gender bias linking caregiving roles or flexibility to one gender while associating leadership and authority with another. The results are then compared to a baseline provided by the benchmark, which quantifies the degree of bias in your LLM’s outputs. By analyzing such patterns across thousands of prompts, these benchmarks provide a detailed breakdown of how biases manifest in your LLM’s behavior, allowing you to pinpoint specific areas for improvement.
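To make this likelihood comparison concrete, here is a minimal sketch (not the official StereoSet harness) that scores each candidate completion with a causal language model via Hugging Face Transformers; the model name "gpt2" is a stand-in for the LLM you are actually evaluating.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # stand-in; replace with the model under evaluation
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

prompt = "The software engineer was explaining the algorithm. After the meeting,"
completions = [" he went back to coding.", " she went back to coding.", " they went back to coding."]

def completion_log_likelihood(prompt, completion):
    # Log-probability the model assigns to the completion tokens, given the prompt.
    prompt_len = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
    full_ids = tokenizer(prompt + completion, return_tensors="pt").input_ids
    with torch.no_grad():
        logits = model(full_ids).logits
    log_probs = torch.log_softmax(logits[0, :-1], dim=-1)   # prediction for each next token
    targets = full_ids[0, 1:]
    start = prompt_len - 1                                    # keep only the completion part
    return log_probs[start:].gather(1, targets[start:].unsqueeze(1)).sum().item()

# A large gap in favor of the stereotypical completion indicates bias.
for completion in completions:
    print(completion, round(completion_log_likelihood(prompt, completion), 2))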

Identify the appropriate bias benchmark for your specific task. For this, you can explore the collection of LLM benchmarks curated by researchers at McGill University, which offers a range of benchmarks tailored to a variety of scenarios.

3. Monitor bias continuously

Mitigating bias isn’t a one-time effort—it requires ongoing monitoring to ensure that your LLM remains fair and effective across iterations. Here are some ideas to help you implement it:

Create a script that evaluates your model

First, create a script that runs a standardized set of evaluations against a given model version. Think about the metrics that you will implement to measure bias in your specific scenario. You can explore fairness metrics, such as demographic parity, measure disparate impact (the extent to which the model’s decisions disproportionately affect different groups), or assess stereotype reinforcement using the benchmarks mentioned earlier.

Demographic parity (also known as statistical parity) is a metric used to assess bias and fairness concerns, that is, whether a machine learning model treats different demographic groups equally in terms of outcomes. Specifically, it measures whether the probability of a positive outcome (e.g., approval for a loan, a job recommendation, etc.) is the same across different groups, regardless of their demographic attributes (e.g., gender, race, age). Here is a manual implementation of this metric in Python:

# Toy example: ground-truth labels, model predictions, and the demographic
# group each individual belongs to
y_true = [0, 1, 0, 1, 0]
y_pred = [0, 1, 0, 0, 1]
group_labels = ['male', 'female', 'male', 'female', 'male']


def demographic_parity(y_true, y_pred, group_labels):
    # Positive-prediction rate per group; only the predictions are needed,
    # y_true is kept in the signature for consistency with other metrics.
    groups = set(group_labels)
    parity = {}

    for group in groups:
        group_indices = [i for i, label in enumerate(group_labels) if label == group]
        group_outcomes = [y_pred[i] for i in group_indices]
        positive_rate = sum(group_outcomes) / len(group_outcomes)
        parity[group] = positive_rate

    return parity


parity_results = demographic_parity(y_true, y_pred, group_labels)
print(parity_results)  # e.g., {'male': 0.33, 'female': 0.5} -- equal rates would indicate parity

You can also explore demographic_parity_ratio from the fairlearn.metrics package, which simplifies the application of this fairness metric in your model evaluation.
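For reference, here is a minimal sketch of the fairlearn route on the same toy data (assuming fairlearn is installed):

from fairlearn.metrics import demographic_parity_ratio

y_true = [0, 1, 0, 1, 0]
y_pred = [0, 1, 0, 0, 1]
group_labels = ['male', 'female', 'male', 'female', 'male']

# Ratio of the lowest to the highest positive-prediction rate across groups;
# 1.0 means perfect parity.
print(demographic_parity_ratio(y_true, y_pred, sensitive_features=group_labels))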

Track your results in Neptune

You can use tools like neptune.ai to track bias metrics (e.g., fairness or disparate impact) across model versions. Let’s see how:

  1. Set up your project: If you haven’t already, sign up for Neptune now and create a project to track your LLM’s training data and metrics.
  2. Log the metrics: Set up custom logging for these metrics in your training code by calculating and recording them after each evaluation phase (a minimal logging sketch follows this list).
  3. Monitor bias: Use Neptune’s dashboards to monitor how these fairness metrics evolve over model versions. Compare the impact of different debiasing strategies on the metrics, and create alerts to notify you when any metric exceeds a threshold. This allows you to take immediate corrective action.
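Here is a minimal sketch of what such logging could look like, assuming the neptune client’s init_run API; the project name, field paths, and metric values are placeholders you would adapt to your own setup.

import neptune

# Placeholder project; pass your own workspace/project and API token.
run = neptune.init_run(project="my-workspace/llm-bias-monitoring")

run["model/version"] = "v0.3.1"

# Record fairness metrics after each evaluation phase; appending builds a
# series you can compare across model versions in the dashboard.
parity_results = {"male": 0.33, "female": 0.50}
for group, rate in parity_results.items():
    run[f"eval/fairness/demographic_parity/{group}"].append(rate)
run["eval/fairness/disparate_impact"].append(0.66)

run.stop()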

All metadata in a single place with an experiment tracker (example in neptune.ai)

Integrate bias checks into your CI/CD workflows

If your team manages model training through CI/CD, incorporate the automated bias detection scripts you have already created into each pipeline iteration. Alternatively, the script can also be used as part of a manual QA process, ensuring that potential bias is identified and addressed before the model reaches production.

How to ensure LLM complies with user privacy and data laws?

When developing LLMs, you need to comply with data protection laws and ethical frameworks and guidelines. Regulations like the GDPR, HIPAA in healthcare, and the AI Act in the EU place significant demands on how personal data is handled, stored, and processed by AI systems. However, adhering to these standards is not as complex as it may seem, especially if you take a strategic approach.

I learned this perspective firsthand during a discussion where Teresa Rodríguez de las Heras, director of the Research Chair UC3M-Microsoft, shared her insights. She remarked: 

The regulatory focus, especially in the draft AI Act, is less on the internal structure of the algorithms (i.e., their code or mathematical models) and more on the practical contexts in which AI is used.

Think about it this way: it is easy to integrate GDPR-compliant services like ChatGPT’s enterprise version or to use AI models in a law-compliant way through platforms such as Azure’s OpenAI offering, as providers take the necessary steps to ensure their platforms are compliant with regulations.

The real challenge lies in how the service is used. While the infrastructure may be compliant, you, as an AI researcher, need to ensure that your LLM’s deployment and data handling practices align with privacy laws. This includes how data is accessed, processed, and stored throughout the model’s lifecycle, as well as thorough documentation of these processes. Clear and detailed documentation is crucial—usually, a technically sound architecture following best practices meets the regulatory requirements, but it has to be documented that it does. By focusing on these aspects, we can shift our understanding of compliance from a purely technical standpoint to a broader, application-based risk perspective, which ultimately affects the overall compliance of your AI system.

You might be wondering, how can I meet these requirements? Here are some security steps you can take to ensure user privacy:

Data anonymization

Protect personal data in your training data by ensuring it is fully anonymized to prevent the leakage of personally identifiable information (PII). Start by:

  • Removing or masking direct identifiers such as names, addresses, emails, job titles, and geographic locations.
  • Using aggregated data instead of raw personal information (e.g., grouping individuals by age ranges or replacing specific locations with broader regions).
  • Applying K-anonymity to generalize or suppress data so each individual cannot be distinguished from at least k-1 others in the dataset.

Once these foundational steps are in place, consider additional measures to limit the risk of re-identification. For practical examples and implementation tips, consider exploring Google’s TensorFlow Privacy repository on GitHub. 
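As a minimal sketch of the masking and generalization steps above (the column names, regex, and age buckets are illustrative assumptions, not a complete anonymization pipeline):

import re
import pandas as pd

df = pd.DataFrame({
    "name": ["Alice Smith", "Bob Jones"],
    "email": ["alice@example.com", "bob@example.com"],
    "age": [27, 58],
    "prompt": ["Contact me at alice@example.com please", "What is the weather like today?"],
})

# 1) Drop direct identifiers entirely.
df = df.drop(columns=["name", "email"])

# 2) Mask PII that appears inside free text (here: email addresses).
email_pattern = r"[\w.+-]+@[\w-]+\.[\w.]+"
df["prompt"] = df["prompt"].str.replace(email_pattern, "[EMAIL]", regex=True)

# 3) Generalize quasi-identifiers: age ranges instead of exact ages.
df["age_range"] = pd.cut(df["age"], bins=[0, 30, 50, 120], labels=["<30", "30-50", ">50"])
df = df.drop(columns=["age"])

print(df)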

Secure model serving

Ensure that your deployed model is served securely to protect user data during interactions. How?

  • Hosting the model in secure, GDPR-compliant cloud environments, such as Amazon Web Services or Azure.
  • Using encryption protocols like HTTPS and TLS to safeguard data in transit.
  • Implementing access controls to limit who can query the model and monitor interactions.

Privacy penetration tests

Conduct regular privacy penetration tests to identify vulnerabilities in your system. For example:

  • Simulate data extraction attacks to evaluate how well your model resists adversarial attempts to uncover training data. For more information on defending against these threats, check out Defense Strategies in Adversarial Machine Learning.
  • Collaborate with privacy experts to audit your model’s infrastructure and identify potential compliance gaps.

These measures serve as a robust framework for privacy protection without compromising the performance of your LLMs. 

How to integrate transparency, accountability, and explainability?

As LLMs become increasingly integrated into applications and individuals and organizations rely on AI development for their own projects, concerns surrounding the transparency, accountability, and explainability of these systems are growing. 

However, the current market leaves formal interpretability research and solutions mostly in the academic and R&D corners rather than demanding them in everyday products. This makes sense: you don’t need to know where the training data comes from to build an app with ChatGPT, and highly popular tools like GitHub Copilot and Bing Chat thrive without deep interpretability features. That said, certain practical approaches to interpretability (e.g., user-facing explanations for predictions or contextual annotations in outputs) occasionally emerge in industry settings. These glimpses, while rare, provide meaningful transparency and serve specific use cases where interpretability can enhance trust and usability.

Such practical approaches allow users to better understand the results without having to decipher the internal logic. As an AI professional developing LLM-based applications, learning about these strategies—contextual cues, custom filtering, and source references—can differentiate your product. 

Transparency has become a key expectation in the AI industry, as highlighted by initiatives like the EU AI Act and guidelines from organizations such as the Partnership on AI, which emphasize the importance of explainable AI. By integrating them, you can meet these expectations while maintaining feasibility for deployment. Let’s get into it!

What does contextual transparency look like?

Contextual transparency provides meaningful insights into how the model produces outputs, for example, by showing relevant sources, highlighting influential inputs, or offering filtering options. When models display their sources, users can quickly assess their credibility and the accuracy of their results. In cases where the answer is not reliable, these sources are often either fake (links that go nowhere) or redirect to papers or articles unrelated to the topic. You can provide contextual transparency to your LLM by including:

• Disclaimers about outputs: Set expectations by clearly communicating the probabilistic nature of your LLM’s responses and their potential for inaccuracies. OpenAI, for example, includes disclaimers in ChatGPT to guide user understanding. 

OpenAI’s ChatGPT disclaimer encouraging users to verify information independently | Source: Author

While researching for this article, I came across a collection of the best disclaimers from ChatGPT shared by Reddit users. These examples highlight how language models can be prompted to produce disclaimers, though the results don’t always make sense from a human perspective.

• Contextual cues: Contextual cues provide insights about the sources and processes behind the model’s outputs. Features like highlighting citations (as seen in Bing Chat) or referencing snippets of code and links to external materials (as ChatGPT does) help users understand the reasoning behind responses.

• RAG-specific contextualization: In Retrieval-Augmented Generation (RAG) systems, contextualization often involves surfacing top-related documents or tokens that influence the model’s output (a minimal sketch follows the figure captions below).

An example of contextual transparency: ChatGPT references the source code in the output. | Source: Author
An example of contextual transparency: Bing Chat cites the source that influenced its answer. | Source
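To illustrate the RAG-specific contextualization mentioned above, here is a minimal sketch that surfaces the top retrieved passages alongside the answer; the TF-IDF retriever, the documents, and the placeholder answer are assumptions, and a production system would use your own retriever and LLM.

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

documents = [
    "Our refund policy allows returns within 30 days of purchase.",
    "Shipping usually takes 3-5 business days within the EU.",
    "Premium support is available 24/7 for enterprise customers.",
]

def retrieve_with_sources(query, docs, top_k=2):
    # Rank documents by cosine similarity between TF-IDF vectors.
    vectorizer = TfidfVectorizer()
    doc_matrix = vectorizer.fit_transform(docs)
    scores = cosine_similarity(vectorizer.transform([query]), doc_matrix)[0]
    ranked = sorted(zip(scores, range(len(docs))), reverse=True)[:top_k]
    return [{"doc_id": i, "score": round(float(s), 3), "text": docs[i]} for s, i in ranked]

query = "How long do I have to return an item?"
sources = retrieve_with_sources(query, documents)

# Pass the retrieved passages to the LLM and return them alongside the answer,
# so users can check which sources influenced the output.
response = {"answer": "<generated by your LLM from the passages below>", "sources": sources}
print(response)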

How to navigate data usage risks in AI development?

While regulations often dictate what can be done legally, we also need to consider what should be done to build user trust and ensure fair practices. Deploying ML models implies navigating the line between necessary oversight (e.g., content moderation) and potential overreach. Being AI professionals, we need to approach this challenge responsibly.

Production logs, including user prompts, interactions, and model outputs, offer a wealth of information about the system’s performance and potential misuse. However, they also raise ethical implications about user consent and privacy risks.

Understand your data sources

An important part of building ethically sound AI models lies in verifying that your data comes from sources with clear usage rights. Your data pipeline should flag or exclude content from sources with uncertain copyright status. If you are using scraping tools, start by implementing rules to filter out certain domains or sites that have unclear copyright status. 
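A minimal sketch of such a filtering rule (the domain lists and URLs are hypothetical):

from urllib.parse import urlparse

# Hypothetical lists: domains cleared for use vs. sources with unclear
# copyright status that should be excluded or flagged for manual review.
ALLOWED_DOMAINS = {"en.wikipedia.org", "www.gutenberg.org"}
UNCLEAR_COPYRIGHT = {"random-ebook-mirror.com", "some-fanfiction-archive.net"}

def classify_url(url):
    domain = urlparse(url).netloc.lower()
    if domain in ALLOWED_DOMAINS:
        return "keep"
    if domain in UNCLEAR_COPYRIGHT:
        return "exclude"
    return "review"  # unknown domains go to manual review

urls = [
    "https://en.wikipedia.org/wiki/Machine_learning",
    "https://random-ebook-mirror.com/some-novel.txt",
    "https://smallpublisher.example/article",
]
for url in urls:
    print(classify_url(url), url)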

Common Crawl is a free, open repository that provides a large dataset of web pages that can be filtered for copyrighted content. While it is a good starting point for identifying general content, I recommend refining these filters with additional checks tailored to your specific topics.

Using publicly accessible data that is copyrighted

The AI industry has faced growing scrutiny over practices like scraping data and using user-provided content without explicit consent. For example, while human users cannot legally reuse or republish copyrighted content from websites or books without explicit permission, many LLM providers use such content as training data. The assumption that “publicly accessible” equals “fair use” has led to a growing backlash from creators, publishers, and regulators.

Using user data that is not publicly accessible

Some jurisdictions have more robust regulatory frameworks that explicitly regulate how user data can be used to train models. In the EU and the UK, laws like the GDPR have prompted companies to adopt stricter privacy practices. Let’s see some examples:

• Grammarly, for instance, follows a regional approach. It states on its Product Improvement and Training Control page and in the privacy settings that users in the EU and UK automatically have their data excluded from model training:

Since you created your account in the EU or UK, Grammarly will not use your content to train its models or improve its product for other users.

• In 2019, a Bloomberg report revealed that Amazon employees and contractors sometimes review Alexa voice recordings to help improve Alexa’s speech recognition models. While the data review process is intended to enhance product quality, the disclosure raised concerns about user consent, privacy, and the extent to which voice data—often from private homes—could be accessed for AI development. In May 2023, the Federal Trade Commission (FTC) imposed a $25 million fine on Amazon related to children’s privacy, alleging that the company had violated the Children’s Online Privacy Protection Act (COPPA) by retaining children’s voice recordings indefinitely and misrepresenting parents’ ability to delete those recordings.

These examples highlight how regulations differ across jurisdictions. This patchwork of regulations creates a challenging landscape for AI developers, highlighting that what is deemed legal (or even ethical) differs across regions. As a result, some users benefit from stronger protections against such practices than others, depending on their location.

There are some recommendations that may come in handy to navigate different jurisdictions. First, if resources permit, adopt a “highest common denominator” strategy by aligning global practices with the most restrictive data protection requirements (e.g., EU GDPR). Second, keep detailed documentation of each model’s training process—covering data sources, usage procedures, and implemented safeguards—and present this information in an accessible format (e.g., FAQs or transparency reports). This approach demonstrates a clear commitment to transparency and ethical standards.

Best practices for ethical LLM development

Navigating the regulatory landscape requires more than just complying with the local laws. Just as contextual transparency helps users trust the outputs of your LLMs, your broader organizational values, professional standards, or industry best practices form the ethical backbone that ensures this trust extends to the foundation of your system.

By following these practical steps, you can reinforce that commitment to building fair and transparent models:

Implement opt-out mechanisms

Opt-out mechanisms allow users to control whether their data is used to train AI models and other software, giving them some agency over how their data is processed and used. If you plan to store users’ data for training your AI or for any other purpose, implementing an opt-out mechanism is a good practice to give users back control over their personal data. Let’s look at some examples of how this can be done:

  • Social media platforms: Platforms such as Quora, LinkedIn, and Figma have opt-out mechanisms that allow users to request that their data be excluded from certain data mining purposes. However, the specific options and level of transparency can vary widely from platform to platform. Wired has a step-by-step guide on how to stop your data from being used by the most popular platforms to train AI, which I recommend checking out.
  • Opt-out of data scraping: Many websites indicate where or whether they permit automated crawling by providing a “robots.txt” file. While this file signals how a site wishes to be scraped, it doesn’t technically prevent unauthorized crawlers from harvesting data; compliance ultimately depends on whether the crawler chooses to honor those instructions (a sketch of how a compliant crawler checks these rules follows the figure caption below).
Syntax of a robots-txt file to prevent agents from crawling a website. Each agent is separated in a different line containing its name and the disallow or allow rules attached to it | Source
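To make this concrete, here is a minimal sketch of how a well-behaved crawler can check those rules before fetching a page, using Python’s standard urllib.robotparser; the agent names, rules, and URLs are illustrative.

from urllib.robotparser import RobotFileParser

# Illustrative robots.txt content: all agents are kept out of /private/,
# and a (hypothetical) training-data bot is disallowed entirely.
robots_txt = """
User-agent: *
Disallow: /private/

User-agent: MyTrainingDataBot
Disallow: /
"""

parser = RobotFileParser()
parser.parse(robots_txt.splitlines())

# Respecting the file is up to the crawler: check before fetching each URL.
print(parser.can_fetch("MyTrainingDataBot", "https://example.com/articles/post-1"))  # False
print(parser.can_fetch("SomeOtherBot", "https://example.com/articles/post-1"))       # True
print(parser.can_fetch("SomeOtherBot", "https://example.com/private/data"))          # False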

Keep your documentation updated

Clear and comprehensive documentation can take multiple forms, from end-user guides (explaining the usage and limitations of your LLM) and developer-focused manuals (covering architecture, training procedures, and potential biases) to legal or regulatory documentation for compliance and accountability. 

Model Cards, originally proposed by Margaret Mitchell and Timnit Gebru at Google, offer a structured template for detailing key information about machine learning models: the dataset used, intended use cases, limitations, etc. Hugging Face has implemented a version of Model Cards on its platform, facilitating a standardized way to document Large Language Models (LLMs) and other AI systems. 

By maintaining up-to-date documentation, you help users and stakeholders understand your model’s capabilities and limitations. This plays a crucial role in fostering trust and encouraging responsible use.

For example, OpenAI has publicly documented its red-teaming process, which involves testing models against harmful content to assess their robustness and ethical implications. Documenting such efforts not only promotes transparency but also sets a benchmark for how ethical considerations are addressed in the development process.

Stay ahead of regulations

If your company has a legal team, collaborate with them to ensure compliance with local and international regulations. If not, and you are planning to expand your LLM globally, consider hiring legal advisors to mitigate the legal risks before launching your LLM. 

For example, for applications that are subject to the GDPR, you need to implement and document appropriate technical and organizational measures protecting any personal data you store and process, as outlined in Article 32. These measures often include creating documentation, such as TOM (technical and organizational measures) documents, along with terms of service and privacy policies that users must agree to during signup. Adhering to these requirements, particularly in the European context, is essential for building trust and ensuring compliance.

Avoid legal pitfalls that may affect the long-term viability and trustworthiness of your LLMs by anticipating potential regulatory changes. Monitor the legal landscape for AI development in the regions where you currently operate or plan to expand in the future. These are some useful resources:

  • The U.S. National Institute of Standards and Technology (NIST) AI Risk Management Framework is an updated source with recommendations on AI risks and regulatory impacts for individuals and organizations. 

Summing it up: AI ethics done right

Let’s wrap up with a quick recap of all the key takeaways from our discussion:

  • Bias in LLMs is inevitable, but manageable: While algorithmic bias in machine learning models is part of the game, not all biases are negative. Our job is to identify which biases are functional (beneficial to performance) and which ones are harmful (reinforce inequality). Tools like StereoSet and BBQ are useful for pinpointing and mitigating harmful biases.    
  • Protect user privacy from start to finish: Think less about the mathematical structure of your model (that is usually handled by the provider, who will keep it law-compliant) and more about how data is handled in practice during your model’s lifecycle (this is where you are responsible for keeping your system law-compliant). Safeguard sensitive information by implementing strong privacy measures like data anonymization, differential privacy, and secure model serving.
  • Transparency is your ally: You don’t have to explain every inner detail of your AI models to be transparent. Instead, focus on providing meaningful insights into how your model produces outputs. Contextual transparency—like source references and disclaimers—builds trust without overwhelming users with technical jargon.
  • Bias mitigation techniques and privacy protection aren’t one-time tasks: They should be continuously integrated throughout your model’s lifecycle. Using tools like Neptune to track and visualize key metrics, including fairness, helps ensure your models stay aligned with ethical standards across iterations and versions.
  • Ethical AI development requires proactive steps: Understand your data sources, implement opt-out mechanisms, keep your documentation up to date, and stay ahead of regulatory changes. Ethical AI isn’t just about compliance—it’s about building trust and accountability with users and stakeholders.


Challenges & Solutions For Monitoring at Hyperscale


“What is not measured, cannot be improved.” This quote has become a guiding principle for teams training foundation models. When you’re dealing with complex, large-scale AI systems, things can spiral quickly without the right oversight. Operating at hyperscale poses significant challenges for teams, from the large volume of data generated to the unpredictability of hardware failures and the need for efficient resource management. These issues require strategic solutions, which is why monitoring isn’t just a nice-to-have: it’s the backbone of transparency, reproducibility, and efficiency. During my talk at NeurIPS, I broke down five key lessons learned from teams facing large-scale model training and monitoring. Let’s get into it.

Real-time monitoring prevents costly failures

Imagine this: you’re training a large language model on thousands of GPUs at a cost of hundreds of thousands of dollars per day. Now imagine discovering, hours into training, that your model is diverging or that hardware issues are degrading your performance. The financial and operational implications are staggering. This is why live monitoring—the ability to act immediately—is so critical.

Live monitoring allows teams to see experiment progress as it happens, rather than waiting for checkpoints or the end of a run. This real-time visibility is a game-changer for identifying and fixing problems on the fly. In addition, automated processes allow you to set up monitoring workflows once and reuse them for similar experiments. This streamlines comparing and analyzing results and debugging issues, saving time and effort.

However, achieving true live monitoring is far from simple. Hyperscale training generates an overwhelming volume of data, often reaching up to a million data points per second. Traditional monitoring tools struggle under such loads, creating bottlenecks that can delay corrective action. Some teams try to cope by batching or sampling metrics, but these approaches sacrifice real-time visibility and add complexity to the code.

The solution lies in systems that can handle high-throughput data ingestion while providing accurate, real-time insights. Tools like neptune.ai make this possible by providing dashboards that visualize metrics without delaying training. For example, live tracking of GPU utilization or memory usage can reveal early signs of bottlenecks or out-of-memory errors, allowing engineers to proactively adjust course. Here are some testimonials:

One thing we’re always keeping track of is what the utilization is and how to improve it. Sometimes, we’ll get, for example, out-of-memory errors, and then seeing how the memory increases over time in the experiment is really helpful for debugging as well.

James Tu

Research Scientist, Waabi

For some of the pipelines, Neptune was helpful for us to see the utilization of the GPUs. The utilization graphs in the dashboard are a perfect proxy for finding some bottlenecks in the performance, especially if we are running many pipelines.

Wojtek Rosiński

CTO, ReSpo.Vision

Real-time visualization of GPU memory usage (top) and power consumption (bottom) during a large-scale training run. These metrics help identify potential bottlenecks, such as out-of-memory errors or inefficient hardware utilization, enabling immediate corrective actions to maintain optimal performance. | Source: Author

Troubleshooting hardware failures is challenging: simplify it with debugging

Distributed systems are prone to failure, and hardware failures are notoriously difficult to troubleshoot. A single hardware failure can cascade into widespread outages, often with cryptic error messages. Teams often waste time sifting through stack traces, trying to distinguish between infrastructure problems and code bugs.

At Cruise, engineers used frameworks like Ray and Lightning to improve error reporting. By automatically labeling errors as either “infra” or “user” issues and correlating stack traces across nodes, debugging became much faster.

Igor Tsvetkov

Former Senior Staff Software Engineer, Cruise

AI teams automating error categorization and correlation can significantly reduce debugging time in hyperscale environments, just as Cruise has done. How? By using classification strategies to identify if failures originated from hardware constraints (e.g., GPU memory leaks, network latency) or software bugs (e.g., faulty model architectures, misconfigured hyperparameters). 

Intuitive experiment tracking optimizes resource utilization

Another relevant aspect of hyperscale monitoring is optimizing resource utilization, in particular in a scenario where hardware failures and training interruptions can set teams back significantly. Picture a scenario where training jobs suddenly deviate: loss metrics spike, and you’re left deciding whether to let the job run or terminate it. Advanced experiment trackers allow for remote experiment termination, eliminating the need for teams to manually access cloud logs or servers.

Use checkpoints at frequent intervals so you do not have to restart from scratch, but just warm-start from the previous checkpoint. Most mature training frameworks already offer automated checkpointing and warm-starts from previous checkpoints. But most of these, by default, save the checkpoints in the same machine. This doesn’t help if your hardware crashes, or, for example, you are using spot instances and they are reassigned.

For maximum resilience and to prevent losing data if hardware crashes, checkpoints should be linked to your experiment tracker. This does not mean that you upload GBs worth of checkpoints to the tracker (although you can, and some of our customers, especially self-hosted customers, do this for security reasons), but rather that you store pointers to the remote location, like S3, where the checkpoints have been saved. This enables you to link the checkpoint with the corresponding experiment step and efficiently retrieve the relevant checkpoint at any given step.
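As a minimal sketch of this pointer-based approach (assuming the neptune client’s init_run API; the project, bucket, and paths are placeholders):

import neptune

run = neptune.init_run(project="my-workspace/foundation-model-training")  # placeholder project

def log_checkpoint_pointer(run, step, s3_uri):
    # Store only the pointer to the checkpoint, not the weights themselves,
    # so each training step is linked to its artifact in remote storage.
    run[f"training/checkpoints/step_{step}"] = s3_uri

# Called from the training loop after a checkpoint has been uploaded to S3.
log_checkpoint_pointer(run, step=1000, s3_uri="s3://my-training-bucket/run-42/ckpt_step_1000.pt")
log_checkpoint_pointer(run, step=2000, s3_uri="s3://my-training-bucket/run-42/ckpt_step_2000.pt")

run.stop()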

A comparison of training workflows with and without advanced experiment tracking and checkpointing. On the left, failed training runs at various stages lead to wasted time and resources. On the right, a streamlined approach with checkpoints and proactive monitoring ensures consistent progress and minimizes the impact of interruptions. | Source: Author

However, there are two caveats to successfully restarting an experiment from a checkpoint: assuming that the experimentation environment is constant, or at least reproducible, and addressing deterministic issues like Out-of-Memory errors (OOMs) or bottlenecks that may require parameter changes to avoid repeating failures. This is where forking can play a significant role in improving recovery and progress.

Track months-long model training with more confidence. Use neptune.ai forking feature to iterate faster and optimize the usage of GPU resources.

With Neptune, users can visualize forked training out of the box. This means you can:

  • Test multiple configs at the same time. Stop the runs that don’t improve accuracy. And continue from the most accurate last step.
  • Restart failed training sessions from any previous step. The training history is inherited, and the entire experiment is visible on a single chart.

In addition, checkpointing strategies are critical for optimizing recovery processes. Frequent checkpointing ensures minimal loss of progress, allowing you to warm-start from the most recent state instead of starting from scratch. However, checkpointing can be resource-intensive in terms of storage and time, so we need to strike a balance between frequency and overhead.

For large-scale models, the overhead of writing and reading weights to persistent storage can significantly reduce training efficiency. Innovations like redundant in-memory copies, as demonstrated by Google’s Gemini models, enable rapid recovery and improved training goodput (defined by Google as the time spent computing useful new steps over the elapsed time of the training job), increasing resilience and efficiency.

Features like PyTorch Distributed’s asynchronous checkpointing can significantly reduce checkpointing times, making frequent checkpointing more viable without compromising training performance.

Beyond models, checkpointing the state of dataloaders remains a challenge due to distributed states across nodes. While some organizations like Meta have developed in-house solutions, general frameworks have yet to fully address this issue. Incorporating dataloader checkpointing can further enhance resilience by preserving the exact training state during recovery.

Reproducibility and transparency are non-negotiable

Reproducibility is the bedrock of reliable research, but it’s notoriously difficult at scale. Ensuring reproducibility requires consistent tracking of environment details, datasets, configurations, and results. This is where Neptune’s approach excels, linking every experiment’s lineage—from parent runs to dataset versions—in an accessible dashboard.

This transparency not only aids validation but also accelerates troubleshooting. Consider ReSpo.Vision’s challenges in managing and comparing results across pipelines. By implementing organized tracking systems, they gained visibility into pipeline dependencies and experiment parameters, streamlining their workflow.

A single source of truth simplifies data visualization and management at large-scale data

Managing and visualizing data at scale is a common challenge, amplified in the context of large-scale experimentation. While tools like MLflow or TensorBoard are sufficient for smaller projects with 10–20 experiments, they quickly fall short when handling hundreds or even thousands of experiments. At this scale, organizing and comparing results becomes a logistical hurdle, and relying on tools that cannot effectively visualize or manage this scale leads to inefficiencies and missed insights.

A solution lies in adopting a single source of truth for all experiment metadata, encompassing everything from input data and training metrics to checkpoints and outputs. Neptune’s dashboards address this challenge by providing a highly customizable and centralized platform for experiment tracking. These dashboards enable real-time visualization of key metrics, which can be tailored to include “custom metrics”—those not explicitly logged at the code level but calculated retrospectively within the tool. For instance, if a business requirement shifts from using precision and recall to the F1 score as a performance indicator, custom metrics allow you to calculate and visualize these metrics across existing and future experiments without rerunning them, ensuring flexibility and minimizing duplicated effort.

Consider the challenges faced by Waabi and ReSpo.Vision. Waabi’s teams, running large-scale ML experiments, needed a way to organize and share their experiment data efficiently. Similarly, ReSpo.Vision required an intuitive system to visualize multiple metrics in a standardized format that any team member—technical or non-technical—could easily access and interpret. Neptune’s dashboards provided the solution, allowing these teams to streamline their workflows by offering visibility into all relevant experiment data, reducing overhead, and enabling collaboration across stakeholders.

I like those dashboards because we need several metrics, so you code the dashboard once, have those styles, and easily see it on one screen. Then, any other person can view the same thing, so that’s pretty nice.

Łukasz Grad

Chief Data Scientist, ReSpo.Vision

The benefits of such an approach extend beyond visualization. Logging only essential data and calculating derived metrics within the tool reduces latency and streamlines the experimental process. This capability empowers teams to focus on actionable insights, enabling scalable and efficient experiment tracking, even for projects involving tens of thousands of models and subproblems.

Visualizing large datasets

We generally do not think of dataset visualization as part of experiment monitoring. However, preparing the dataset for model training is an experiment in itself, and while it may be an upstream experiment not in the same pipeline as the actual model training, data management and visualization are critical to LLMOps.

Large-scale experiments often involve processing billions of data points or embeddings. Visualizing such data to uncover relationships and debug issues is a common hurdle. Tools like Deepscatter and Jupyter Scatter have made progress in scaling visualizations for massive datasets, offering researchers valuable insights into their data distribution and embedding structures.

Moving forward

The path to efficient hyperscale training lies in combining robust monitoring, advanced debugging tools, and comprehensive experiment tracking. Solutions like Neptune Scale are designed to address these challenges, offering the scalability, precision, and transparency researchers need.

How about being one of the first to access Neptune Scale?

Neptune Scale is our upcoming product release built for teams that train foundation models. It offers enhanced scalability and exciting new features. You can join our beta program to benefit from Neptune Scale earlier.

If you’re interested in learning more, visit our blog or join the MLOps community to explore case studies and actionable strategies for large-scale AI experimentation.

Acknowledgments

I would like to express my gratitude to Prince Canuma, Dr. Shantipriya Parida, and Igor Tsvetkov for their valuable time and insightful discussions on this topic. Their contributions and perspectives were instrumental in shaping this talk.


One Turn After Another | Towards Data Science


While some games, like rock-paper-scissors, only work if all players decide on their actions simultaneously, other games, like chess or Monopoly, expect the players to take turns one after another. In Game Theory, the first kind of game is called a static game, while turn-taking is a property of so-called dynamic games. In this article, we will analyse the latter with methods from game theory.

This article is the fourth part of a four-chapter series on the fundamentals of game theory. I recommend you to read the first three articles if you haven’t done that yet, as the concepts shown here will build on the terms and paradigms introduced in the previous articles. But if you are already familiar with the core fundamentals of game theory, don’t let yourself be stopped, and go ahead!

Dynamic games

Dynamic games can be visualized as trees. Photo by Adarsh Kummur on Unsplash

While so far we only looked at static games, we will now introduce dynamic games, where players take turns. As before, such games include a number of players n, a set of actions for each player, and a reward function that assesses the actions of a player given the other players’ actions. Beyond that, for a dynamic game, we need to define an order in which the players take their turns. Consider the following tree-like visualization of a dynamic game.

A visualization of a dynamic game. Figure by author.

At the top we have a node where player 1 has to decide between two actions L and R. This determines whether to follow the left part or the right part of the tree. After player 1’s turn, player 2 takes their turn. If player 1 chooses L, player 2 can decide between l1 and r1. If player 1 chooses R, player 2 has to decide between l2 and r2. At the leaves of the tree (the nodes at the bottom), we see the rewards just like we had them in the matrix cells in static games. For example, if player 1 decides for L and player 2 decides for r1, the reward is (1,0); that is, player 1 gets a reward of 1, and player 2 gets a reward of 0. 

I bet you are eager to find the Nash equilibrium of this game, as this is what Game Theory is mainly about (if you still struggle with the concept of a Nash equilibrium, you might want to take a look back at chapter 2 of this series). To do that, we can transform the game into a matrix, as we already know how to find a Nash equilibrium in a game displayed as a matrix. Player 1 decides on the row of the matrix, player 2 decides on the column, and the values in the cell then specify the rewards. However, there is one important point to notice. When we look at the game displayed as a tree, player 2 decides on their action after player 1 does and hence only cares about the part of the tree that is actually reached. If player 1 chooses action L, player 2 only decides between l1 and r1 and doesn’t care about l2 and r2, because these actions are out of the question anyway. However, when we search for a Nash equilibrium, we need to be aware of what would happen if player 1 changed their action. Therefore, we must know what player 2 would have done if player 1 had chosen a different option. That is why we have four columns in the following matrix, to always account for decisions in both parts of the tree.

A column like (r1,l2) can be read as “player 2 chooses r1 if player 1 chose L and chooses l2 if player 1 chose R”. On this matrix, we can search for best responses. For example, the cell (L, (l1,l2)) with reward (3,1) is a pair of mutual best responses: player 1 has no reason to change from L to R because that would lower his reward (from 3 to 1), and player 2 has no reason to change either because none of the other options is better (one is as good, though). In total, we find three Nash equilibria, which are underlined in the upcoming matrix:

The chocolate-pudding market

We will talk about chocolate pudding now. But also about game theory. Photo by American Heritage Chocolate on Unsplash

Our next example brings the idea of dynamic games to life. Let’s assume player 2 is a market-leading retailer of chocolate pudding. Player 1 also wants to build up their business but isn’t sure yet whether to join the chocolate pudding market or rather sell something else. In our game, player 1 has the first turn and can decide between two actions: join the market (i.e., sell chocolate pudding), or don’t join the market (i.e., sell something else). If player 1 decides to sell something other than chocolate pudding, player 2 stays the market-dominating retailer of chocolate pudding and player 1 makes some money in the other area they chose. This is reflected by the reward 1,3 in the right part of the tree in the following figure.

The market-game as a dynamic game. Figure by author. 

But what if player 1 is greedy for the unimaginable riches that lie dormant on the chocolate pudding market? If they decide to join the market, it is player 2’s turn. They can decide to accept the new competitor, give in and share the market. In this case, both players get a reward of 2. But player 2 can also decide to start a price war to demonstrate his superiority to the new competitor. In this case, both players get a reward of 0, because they ruin their profit due to dumping prices. 

Just like before, we can turn this tree into a matrix and find the Nash equilibria by searching for best responses:

If player 1 joins the market, the best option for player 2 is to give in. This is an equilibrium because no player has any reason to change. For player 1 it does not make sense to leave the market (that would give a reward of 1 instead of 2), and for player 2 it is not a good idea to switch to fighting either (which would give a reward of 0 instead of 2). The other Nash equilibrium happens when player 1 just doesn’t join the market. However, this scenario includes player 2’s decision to fight if player 1 had chosen to join the market instead. He basically makes a threat and says “If you join the market, I will fight you.” Remember that previously we said we need to know what the players would do even in the cases that don’t appear to happen? Here we see why this is important. Player 1 needs to assume that player 2 would fight because that is the only reason for player 1 to stay out of the market. If player 2 didn’t threaten to fight, we wouldn’t have a Nash equilibrium, because then joining the market would become a better option for player 1.

But how reasonable is this threat? It keeps player 1 outside the market, but what would happen if player 1 didn’t believe the threat and decided to still join the market? Would player 2 really carry out his threat and fight? That would be very silly because it would give him a reward of 0, whereas giving in would give a reward of 2. From that perspective, player 2 used an empty threat that is not very reasonable. If the case really occurs, he wouldn’t carry it out anyway, would he?

Subgame perfect equilibrium

For a subgame perfect equilibrium, before you get the whole picture, you need to start with small parts of the game. Photo by Ben Stern on Unsplash

The previous example showed that sometimes Nash equilibria occur that are not very reasonable within the game. To cope with this problem, a more restrictive concept of equilibrium has been introduced, which is called a subgame perfect equilibrium. It adds stricter conditions to the notion of an equilibrium. Hence every subgame perfect equilibrium is a Nash equilibrium, but not all Nash equilibria are subgame perfect.

A Nash equilibrium is subgame perfect if every subgame of this equilibrium is a Nash equilibrium itself. What does that mean? First, we have to understand that a subgame is a part of the game’s tree that starts at any node. For example, if player 1 chooses L, the remainder of the tree under the node reached by playing L is a subgame. Likewise, the tree that comes after the node of action R is a subgame. Last but not least, the whole game is always a subgame of itself. As a consequence, the example we started with has three subgames, which are marked in grey, orange and blue in the following:

The market game has three subgames. Figure by author.

We already saw that this game has three Nash equilibria, which are (L,(l1,l2)), (L,(l1,r2)) and (R,(r1,r2)). Let us now find out which of these are subgame perfect. To this end, we investigate the subgames one after another, starting with the orange one. If we only look at the orange part of the tree, there is a single Nash equilibrium that occurs if player 2 chooses l1. If we look at the blue subgame, there is also a single Nash equilibrium that is reached when player 2 chooses r2. That tells us that in every subgame perfect Nash equilibrium, player 2 has to choose option l1 if we arrive in the orange subgame (i.e., if player 1 chooses L), and player 2 has to choose option r2 if we arrive at the blue subgame (i.e., if player 1 chooses R). Only one of the previous Nash equilibria fulfills this condition, namely (L,(l1,r2)). Hence this is the only subgame perfect Nash equilibrium of the whole game. The other two versions are Nash equilibria as well, but they are somewhat implausible in the sense that they contain some kind of empty threat, as we saw in the chocolate pudding market example before. The method we just used to find the subgame perfect Nash equilibrium is called backward induction, by the way.
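Backward induction is easy to express in code. Here is a minimal sketch applied to the chocolate-pudding game from above; the nested-dict encoding of the tree is my own choice, and the payoffs are the ones from the figure.

# Backward induction: at each decision node the active player picks the action
# that maximizes their own payoff, assuming optimal play further down the tree.

# Leaves are payoff tuples (player 1, player 2); inner nodes name the player
# whose turn it is and map actions to subtrees.
game = {
    "player": 1,
    "actions": {
        "stay out": (1, 3),
        "join": {
            "player": 2,
            "actions": {
                "give in": (2, 2),
                "fight": (0, 0),
            },
        },
    },
}

def backward_induction(node):
    if isinstance(node, tuple):           # leaf: payoffs are already known
        return node, []
    player = node["player"]
    best_payoffs, best_plan = None, None
    for action, subtree in node["actions"].items():
        payoffs, plan = backward_induction(subtree)
        if best_payoffs is None or payoffs[player - 1] > best_payoffs[player - 1]:
            best_payoffs, best_plan = payoffs, [(player, action)] + plan
    return best_payoffs, best_plan

payoffs, plan = backward_induction(game)
print(plan)     # [(1, 'join'), (2, 'give in')] -- the subgame perfect equilibrium path
print(payoffs)  # (2, 2)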

Uncertainty

In dynamic games, it can happen that you have to make decisions without knowing exactly what node of the game you are in. Photo by Denise Jans on Unsplash

So far in our dynamic games, we always knew which decisions the other players made. For a game like chess, this is indeed the case, as every move your opponent makes is perfectly observable. However, there are other situations in which you might not be sure about the exact moves the other players make. As an example, we go back to the chocolate pudding market. You take the perspective of the retailer that is already in the market, and you have to decide whether you would start fighting if the other player joins the market. But there is one thing you don’t know, namely how aggressive your opponent will be. When you start fighting, will they be frightened easily and give up? Or will they be aggressive and fight you until only one of you is left? This can be seen as a decision made by the other player that influences your decision. If you expect the other player to be a coward, you might prefer to fight, but if they turn out to be aggressive, you would rather give in (reminds you of the birds fighting for food in the previous chapter, doesn’t it?). We can model this scenario in a game like this:

A dynamic game with a hidden decision (indicated by the dotted circle). Figure by author.

The dotted circle around the two nodes indicates that these are hidden decisions that are not observable to everyone. If you are player 2, you know whether player 1 joined the market or not, but if they joined, you don’t know whether they are aggressive (left node) or moderate (right node). Hence you act under uncertainty, which is a very common ingredient in many games you play in the real world. Poker would become very boring if everybody knew everyone’s cards; that’s why there is private information, namely the cards in your hand that only you know about.

Now you still have to decide whether to fight or give in, although you are not exactly sure what node of the tree you are in. To do that, you have to make assumptions about the likelihood of each state. If you are quite certain that the other player is behaving moderately, you might be up for a fight, but if you assume them to be aggressive, you might prefer giving in. Say there is a probability p that the other player is aggressive and 1-p that they behave moderately. If you assume p to be high, you should give in, but if p becomes smaller, there should be a point where your decision switches to fighting. Let’s try to find that point. In particular, there should be a sweet spot in between where the probability of the other player being aggressive vs. moderate is such that fighting and giving in are equal alternatives to one another. That is, the expected rewards would be equal, which we can model as follows:
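In symbolic form (the notation u(action, type) for player 2’s reward against an opponent of a given type is mine; the concrete numbers sit in the leaves of the tree), the indifference condition reads:

p \cdot u(\text{fight}, \text{aggressive}) + (1 - p) \cdot u(\text{fight}, \text{moderate}) = p \cdot u(\text{give in}, \text{aggressive}) + (1 - p) \cdot u(\text{give in}, \text{moderate})

Plugging in the rewards from the leaves and solving for p gives the threshold discussed next.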

Do you see how this formula is derived from the rewards for fighting or giving in at the different leaves of the tree? This formula solves to p=1/3, so if the probability of the other player being aggressive is 1/3, it would make no difference whether to fight or give in. But if you assume the other player to be aggressive with a probability of more than 1/3, you should give in, and if you assume aggressiveness to be less likely than 1/3, you should fight. This is a chain of thought you also have in other games where you act under uncertainty. When you play poker, you might not calculate the probabilities exactly, but you ask yourself, “How likely is it that John has two kings in his hand?” and depending on your assumption of that probability, you check, raise, or fold.

Summary & outlook

Your journey on the seas of game theory has only just begun. There is so much more to explore. Photo by George Liapis on Unsplash

Now we have learned a lot about dynamic games. Let us summarize our key findings. 

  • Dynamic games include an order in which players take turns. 
  • In dynamic games, the players’ possible actions depend on the previously executed actions of the other players. 
  • A Nash equilibrium in a dynamic game can be implausible, as it contains an empty threat that would not be rational.
  • The concept of subgame perfect equilibria prevents such implausible solutions. 
  • In dynamic games, decisions can be hidden. In that case, players may not exactly know which node of the game they are in and have to assign probabilities to different states of the game. 

With that, we have reached the end of our series on the fundamentals of game theory. We have learned a lot, yet there are plenty of things we haven’t been able to cover. Game theory is a science in itself, and we have only been able to scratch the surface. Other concepts that expand the possibilities of game-theoretic analyses include: 

  • Analysing games that are repeated multiple times. If you play the prisoner’s dilemma multiple times, you might be tempted to punish the other player for having betrayed you in the previous round. 
  • In cooperative games, players can conclude binding contracts that determine their actions to reach a solution of the game together. This is different from the non-cooperative games we looked at, where all players are free to decide and maximize their own reward. 
  • While we only looked at discrete games, where each player has a finite number of actions to choose from, continuous games allow an infinite number of actions (e.g., any number between 0 and 1). 
  • A big part of game theory considers the usage of public goods and the problem that individuals might consume these goods without contributing to their maintenance. 

These concepts allow us to analyse real-world scenarios from various fields such as auctions, social networks, evolution, markets, information sharing, voting behaviour and much more. I hope you enjoyed this series and find meaningful applications for the knowledge you gained, be it the analysis of customer behaviour, political negotiations or the next game night with your friends. From a game theory perspective, life is a game!

References

The topics introduced here are typically covered in standard textbooks on game theory. I mainly used this one, which is written in German though:

  • Bartholomae, F., & Wiens, M. (2016). Spieltheorie. Ein anwendungsorientiertes Lehrbuch. Wiesbaden: Springer Fachmedien Wiesbaden.

An alternative in the English language could be this one:

  • Espinola-Arredondo, A., & Muñoz-Garcia, F. (2023). Game Theory: An Introduction with Step-by-step Examples. Springer Nature.

Game theory is a rather young field of research, with the first main textbook being this one:

  • Von Neumann, J., & Morgenstern, O. (1944). Theory of games and economic behavior.


Open LLMs are Necessary For Current Private Adaptations and Outperform Their Closed Alternatives [Paper Reflection]


Closed Large Language Models (LLMs), which are proprietary and accessible only via APIs, have dominated the LLM space since around 2022 due to their high performance and versatility. However, Open LLMs have made substantial progress, narrowing the performance gap with their Closed LLM counterparts. Open LLMs are models whose architecture and parameters are publicly available for use, modification, and distribution.

For instance, while Closed LLMs like Anthropic’s Claude (released in March 2023) and OpenAI’s GPT-4 (released in March 2023) set new benchmarks upon their launches, the Open LLM Llama 3 released by Meta in April 2024 and DeepSeek-R1 released in January 2025 not only matched but surpassed these models in tasks such as coding, reasoning, text classification, summarization, and question answering.

While much of the discussion around LLMs centers on task and computational performance, in our paper Open LLMs are Necessary for Current Private Adaptations and Outperform their Closed Alternatives, we focus on the privacy implications of using Open and Closed LLMs. Specifically, we explore whether and how models can be fine-tuned on sensitive data while ensuring robust privacy guarantees.

To this end, we define threat models, compare various Open and Closed LLMs that leverage differential privacy (DP) on classification and generation tasks, and analyze methodological limitations. Our research results in a thorough analysis of the privacy-utility tradeoff under different privacy levels.

Our findings indicate that Open LLMs can be adapted to private data without leaking information to third parties, such as LLM providers and malicious users. Thus, they offer a significant privacy advantage over Closed, proprietary models.

The threat space in adapting LLMs to private data

The adaptation of Closed LLMs to private datasets introduces a multifaceted threat space. In typical scenarios, data curators provide their sensitive data to LLM providers for fine-tuning, producing a model tailored to the dataset. This customized model is subsequently queried by external parties, e.g., customers of the data curator.

The resulting threat space can be categorized into three key dimensions:

  1. From the data curator to the LLM provider: The private data shared during fine-tuning may be susceptible to unauthorized access or misuse.
  2. From the querying party to the LLM provider: Queries submitted by end users, which often contain sensitive information intended for the data curator, are exposed to the LLM provider.
  3. From malicious end users to the adapted LLM: Malicious end users may attempt to extract private information through the LLM’s responses to carefully crafted queries.

In contrast to Closed LLMs, Open LLMs provide full control over the model and data, enabling private adaptation without the need to share sensitive information with a third party. This control eliminates the first two threat vectors associated with Closed LLMs, such as unauthorized access or misuse by the provider and exposure of user queries. With Open LLMs, data curators can directly fine-tune the model on private datasets using privacy-preserving techniques, ensuring end-to-end privacy.

What are the current methods for private adaptation of LLMs? 

It follows from our threat space analysis that restricting access to the fine-tuning dataset alone does not guarantee data privacy. Model outputs can still reveal sensitive information from the fine-tuning data. If the fine-tuned model is exposed (e.g., via an API), it remains vulnerable to information extraction and inference attacks.

Differential privacy (DP) introduces a rigorous mathematical framework that ensures the privacy of individuals whose data is used in the fine-tuning process. Specifically, DP adds carefully calibrated noise to the model updates, making it statistically improbable to determine whether any individual’s data was included in the fine-tuning dataset. Its quantifiable and robust privacy guarantee makes DP valuable for protecting sensitive information in LLM fine-tuning.
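
For reference, the standard (ε, δ)-differential privacy guarantee (the textbook definition, stated here in general terms) requires that for any two datasets D and D′ differing in a single record, and for any set of possible outputs S:

\[
\Pr[\mathcal{M}(D) \in S] \;\le\; e^{\varepsilon} \cdot \Pr[\mathcal{M}(D') \in S] + \delta
\]

Here, \(\mathcal{M}\) is the randomized fine-tuning mechanism, \(\varepsilon\) bounds the privacy loss, and \(\delta\) is a small probability with which the bound may fail.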

While DP provides privacy guarantees for both Open and Closed LLMs, it does not address the issue of trust in third-party providers for Closed LLMs. For these models, data curators must rely on the provider to implement safeguards and handle sensitive data responsibly.

Private adaptation methods for Closed LLMs 

We can rule out fine-tuning services offered by LLM providers (e.g., OpenAI and Amazon), as this entails sharing private data with a third party. Closed LLMs are accessible only via APIs. Thus, we cannot access and adapt the model’s weights directly.

Instead, private adaptation methods for Closed LLMs rely on privacy-preserving discrete prompts or private in-context learning (ICL). These approaches work by carefully crafting input prompts or selecting relevant examples to guide the model’s behavior, all while ensuring that sensitive information in the prompts or examples is protected from potential leakage or inference attacks.

All methods we evaluate in our study follow the PATE (Private Aggregation of Teacher Ensembles) framework. At a high level, PATE achieves data privacy by splitting the private dataset into non-overlapping partitions. Then, each partition is used to train a so-called teacher model. These teacher models are joined into an ensemble model by combining their outputs while adding noise, which preserves privacy.

This ensemble is then used to train a so-called student model in the following way: The ensemble makes predictions for samples from an unlabeled public dataset. The resulting (sample, ensemble prediction) pairs constitute the training data for the student model. Thus, the student learns to make the same predictions as the teacher ensemble but never sees sensitive data samples. The student is what’s released as the final model.

Overview of the PATE framework. The sensitive dataset is divided into non-overlapping partitions, and a separate teacher model is trained on each partition. All teachers are aggregated noisily into an ensemble model, which is used to make predictions on a public dataset. The samples from the public dataset, together with the ensemble’s predictions, constitute the training data for the student model, which is the model that is eventually queried by users. | Source
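
To make the noisy-aggregation step concrete, here is a minimal Python sketch (the teacher votes, number of classes, and Laplace noise scale are made up for illustration; this is not the implementation evaluated in the paper):

import numpy as np

def noisy_aggregate(teacher_votes: np.ndarray, num_classes: int, noise_scale: float = 1.0) -> int:
    """Return the noisy majority label of a PATE-style teacher ensemble."""
    # Count how many teachers voted for each class.
    counts = np.bincount(teacher_votes, minlength=num_classes).astype(float)
    # Add Laplace noise so that no single teacher (i.e., no single data partition)
    # can noticeably change the released label.
    counts += np.random.laplace(loc=0.0, scale=noise_scale, size=num_classes)
    return int(np.argmax(counts))

# Ten teachers label one public sample; the noisy consensus becomes the student's training label.
votes = np.array([0, 0, 1, 0, 2, 0, 0, 1, 0, 0])
student_label = noisy_aggregate(votes, num_classes=3)

In the full framework, this noisy label, rather than the raw votes, is all the student model ever sees.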

The private adaptation methods for Closed LLMs we analyze in our study build on this general framework. They differ in how the teachers are utilized and how their responses are aggregated:

  • Differentially Private In-context Learning (DP-ICL): All teachers process the same prompt, and the ensemble’s response is the noisy consensus.
  • PromptPATE: The teacher ensemble assigns labels to public unlabeled data via private voting. These labeled public sequences are used to create new discrete student prompts, which are deployed with the LLM.
  • DP-FewShotGen: The teacher ensemble generates private synthetic few-shot samples, which are then used for in-context learning.
  • DP-OPT: A local LLM generates privacy-preserving prompts and instructions from the private dataset. These are used for in-context learning for the third-party Closed LLM.

In our paper, we compare the privacy protection and performance of these four state-of-the-art methods for private adaptation of Closed LLMs. When applying them to the popular Closed LLMs Claude, GPT-3 Babbage, GPT-3 Davinci, and GPT-4 Turbo, we observe that compared to private adaptation of Open LLMs, these methods offer lower performance at a higher cost on various downstream tasks, including dialog summarization, classification, and generation. Further, all methods except DP-OPT leak training data to the LLM provider.

Private adaptation methods for Open LLMs 

Unlike Closed LLMs, Open LLMs provide access to their parameters, enabling more flexible and parameter-centric private adaptation methods. These methods typically follow the Differentially Private Stochastic Gradient Descent (DPSGD) paradigm to ensure privacy. In DPSGD, the influence of each private data point is constrained during training through gradient clipping and the addition of calibrated noise. This bounds how much any individual record can shape the final model, limiting what the model can memorize or leak about it.
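
As a rough sketch of the mechanism (plain NumPy, with an illustrative clipping norm, noise multiplier, and learning rate rather than the paper’s actual hyperparameters), a single DPSGD update looks roughly like this:

import numpy as np

def dpsgd_step(params, per_example_grads, lr=0.1, clip_norm=1.0, noise_multiplier=1.0):
    """One DP-SGD step: clip each example's gradient, average, add Gaussian noise, update."""
    clipped = []
    for grad in per_example_grads:
        norm = np.linalg.norm(grad)
        # Bound each individual example's influence on the update.
        clipped.append(grad * min(1.0, clip_norm / (norm + 1e-12)))
    mean_grad = np.mean(clipped, axis=0)
    # Noise scaled to the clipping norm hides any single example's contribution.
    noise = np.random.normal(
        0.0, noise_multiplier * clip_norm / len(per_example_grads), size=mean_grad.shape
    )
    return params - lr * (mean_grad + noise)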

In our study, we explore three primary methods for private adaptation of Open LLMs: 

  1. Prompt-based adaptation (PromptDPSGD) introduces a small number of additional parameters (<1% of the model’s total parameters) in the input space through soft prompts or prefix-tuning and adapts Differentially Private Stochastic Gradient Descent (DPSGD) to preserve privacy.
  2. Parameter-efficient fine-tuning, such as LoRA, only updates a relatively small number of parameters (<10% of the model’s total parameters) within the model’s architecture to enable efficient updates. PrivateLoRA extends this approach with DP guarantees by building on the DPSGD algorithm.
  3. Full fine-tuning adaptations (DP-FineTune) involve fine-tuning the entire model or a subset of its layers for comprehensive adaptation while adhering to differential privacy principles.

Applying these methods to Vicuna, Llama-3, OpenLLaMa, BART, RoBERTa, and the Pythia suite of models, we find that private adaptation of Open LLMs improves performance on downstream tasks and reduces costs compared to their Closed counterparts. It also provides a critical privacy benefit by eliminating the risk of exposing private data and user queries to LLM providers.

Insightful results

Our analysis of private adaptation methods for both Closed and Open LLMs reveals several critical findings regarding data leakage, performance, and cost:

  1. Query data leakage: All private adaptation methods for Closed LLMs leak query data to the LLM provider. This means that sensitive information from user queries is exposed during the adaptation process, posing a significant privacy risk.
  2. Training data leakage: Only one method (DP-OPT) of the four methods of private adaptation of Closed LLMs successfully protects private training data from the LLM provider. However, this method requires a local LLM to effectively protect the privacy of the training data. The remaining private adaptation methods for Closed LLMs leak a large fraction of the training data to the LLM provider, undermining the privacy guarantees of the adaptation process.
  3. Performance: All adaptation methods for Closed LLMs achieve lower downstream task performance than privacy-preserving local adaptations on Open LLMs, even when the Open LLMs are significantly smaller than their Closed counterparts.
  4. Cost: The training and query costs for private adaptations of Closed LLMs are substantially higher due to the API access costs imposed by the LLM provider. In contrast, private adaptations for Open LLMs are more cost-effective. We estimated the costs assuming an A40 GPU with 48 GB of memory. In this scenario, privately adapting a Closed LLM to text classification tasks with DP-ICL costs about $140, whereas fine-tuning an Open LLM with PrivateLoRA on the same tasks costs about $30.

This leads to the conclusion that for a truly privacy-preserving adaptation of LLMs, one should use Open LLMs. By offering full control over the model and data, Open LLMs eliminate the risks associated with third-party providers and enable robust privacy-preserving techniques. As a result, Open LLMs address the limitations of Closed LLMs and enable efficient and customizable adaptations tailored to sensitive datasets.



Nine Pico PIO Wats with Rust (Part 2)


This is Part 2 of an exploration into the unexpected quirks of programming the Raspberry Pi Pico PIO with Rust. If you missed Part 1, we uncovered four Wats that challenge assumptions about register count, instruction slots, the behavior of pull noblock, and smart yet cheap hardware.

Now, we continue our journey toward crafting a theremin-like musical instrument — a project that reveals some of the quirks and perplexities of PIO programming. Prepare to challenge your understanding of constants in a way that brings to mind a Shakespearean tragedy.

Wat 5: Inconstant constants

In the world of PIO programming, constants should be reliable, steadfast, and, well, constant. But what if they’re not? This brings us to a puzzling Wat about how the set instruction in PIO works—or doesn’t—when handling larger constants.

Much like Juliet doubting Romeo’s constancy, you might find yourself wondering if PIO constants will, as she says, “prove likewise variable.”

The problem: Constants are not as big as they seem

Imagine you’re programming an ultrasonic range finder and need to count down from 500 while waiting for the Echo signal to drop from high to low. To set up this wait time in PIO, you might naïvely try to load the constant value directly using set:

set y, 500      ; Try to load 500 into Y (spoiler: this does not work)

Aside: We’ll meet some crazy-looking jmp workarounds for limitations like this one in the next two Wats.

But here’s the tragic twist: the set instruction in PIO is limited to constants between 0 and 31. Moreover, the star-crossed set instruction doesn’t report an error. Instead, it silently corrupts the entire PIO instruction. This produces a nonsense result.

Workarounds for inconstant constants

To address this limitation, consider the following approaches:

  • Read Values and Store Them in a Register: We saw this approach in Wat 1. You can load your constant into the osr register, then transfer it to y. For example:
# Read the max echo wait into OSR.
pull                    ; same as pull block
mov y, osr              ; Load max echo wait into Y
  • Shift and Combine Smaller Values: Using the isr register and the in instruction, you can build up a constant of any size. This, however, consumes time and operations from your 32-operation budget (see Part 1, Wat 2).
; In Rust, be sure 'config.shift_in.direction = ShiftDirection::Left;'

set y, 15       ; Load upper 5 bits (0b01111)
mov isr, y      ; Transfer to ISR (clears ISR)
set y, 20       ; Load lower 5 bits (0b10100)
in y, 5         ; Shift in lower bits to form 500 in ISR
mov y, isr      ; Transfer back to y
  • Slow Down the Timing: Reduce the frequency of the state machine to stretch delays over more system clock cycles. For example, lowering the state machine speed from 125 MHz to 343 kHz reduces the timeout constant from 182,216 to about 500 (see the quick check after this list).
  • Use Extra Delays and (Nested) Loops: All instructions support an optional delay, allowing you to add up to 31 extra cycles. (To generate even longer delays, use loops — or even nested loops.)
; Generate 10μs trigger pulse (4 cycles at 343_000Hz)
set pins, 1 [3]       ; Set trigger pin to high, add delay of 3
set pins, 0           ; Set trigger pin to low voltage
  • Use the “Subtraction Trick” to Generate the Maximum 32-bit Integer: In Wat 7, we’ll explore a way to generate 4,294,967,295 (the maximum unsigned 32-bit integer) via subtraction.
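
As a quick back-of-the-envelope check in Python of the slow-down workaround above (the only assumption is that the loop count scales linearly with the state machine clock; the figures are the ones quoted in the list):

old_clock_hz = 125_000_000    # default system clock driving the state machine
new_clock_hz = 343_000        # slowed-down state machine clock
timeout_loops_fast = 182_216  # loops needed for the desired timeout at 125 MHz

timeout_loops_slow = timeout_loops_fast * new_clock_hz / old_clock_hz
print(round(timeout_loops_slow))  # ~500, small enough to build with the other tricks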

Much like Juliet cautioning against swearing by the inconstant moon, we’ve discovered that PIO constants are not always as steadfast as they seem. Yet, just as their story takes unexpected turns, so too does ours, moving from the inconstancy of constants to the uneven nature of conditionals. In the next Wat, we’ll explore how PIO’s handling of conditional jumps can leave you questioning its loyalty to logic.

Wat 6: Conditionals through the looking-glass

In most programming environments, logical conditionals feel balanced: you can test if a pin is high or low, or check registers for equality or inequality. In PIO, this symmetry breaks down. You can jump on pin high, but not pin low, and on x!=y, but not x==y. The rules are whimsical — like Humpty Dumpty in Through the Looking-Glass: “When I define a conditional, it means just what I choose it to mean — neither more nor less.”

These quirks force us to rewrite our code to fit the lopsided logic, creating a gulf between how we wish the code could be written and how we must write it.

The problem: Lopsided conditionals in action

Consider a simple scenario: using a range finder, you want to count down from a maximum wait time (y) until the ultrasonic echo pin goes low. Intuitively, you might write the logic like this:

measure_echo_loop:
 jmp !pin measurement_complete   ; If echo voltage is low, measurement is complete
 jmp y-- measure_echo_loop       ; Continue counting down unless timeout

And when processing the measurement, if we only wish to output values that differ from the previous value, we would write:

measurement_complete:
 jmp x==y cooldown             ; If measurement is the same, skip to cool down
 mov isr, y                    ; Store measurement in ISR
 push                          ; Output ISR
 mov x, y                      ; Save the measurement in X

Unfortunately, PIO doesn’t let you test !pin or x==y directly. You must restructure your logic to accommodate the available conditionals, such as pin and x!=y.

The solution: The way it must be

Given PIO’s limitations, we adapt our logic with a two-step approach that ensures the desired behavior despite the missing conditionals:

  • Jump on the opposite conditional to skip two instructions forward.
  • Next, use an unconditional jump to reach the desired target.

This workaround adds one extra jump (affecting the instruction limit), but the additional label is cost-free.

Here is the rewritten code for counting down until the pin goes low:

measure_echo_loop:
   jmp pin echo_active     ; if echo voltage is high continue count down
   jmp measurement_complete ; if echo voltage is low, measurement is complete
echo_active:
   jmp y-- measure_echo_loop ; Continue counting down unless timeout

And here is the code for processing the measurement such that it will only output differing values:

measurement_complete:
   jmp x!=y send_result    ; if measurement is different, then send it.
   jmp cooldown            ; If measurement is the same, don't send.

send_result:
   mov isr, y              ; Store measurement in ISR
   push                    ; Output ISR
   mov x, y               ; Save the measurement in X

Lessons from Humpty Dumpty’s conditionals

In Through the Looking-Glass, Alice learns to navigate Humpty Dumpty’s peculiar world — just as you’ll learn to navigate PIO’s Wonderland of lopsided conditions.

But as soon as you master one quirk, another reveals itself. In the next Wat, we’ll uncover a surprising behavior of jmp that, if it were an athlete, would shatter world records.

In Part 1’s Wat 1 and Wat 3, we saw how jmp x-- or jmp y-- is often used to loop a fixed number of times by decrementing a register until it reaches 0. Straightforward enough, right? But what happens when y is 0 and we run the following instruction?

jmp y-- measure_echo_loop

If you guessed that it does not jump to measure_echo_loop and instead falls through to the next instruction, you’re absolutely correct. But for full credit, answer this: What value does y have after the instruction?

The answer: 4,294,967,295. Why? Because y is decremented after it is tested for zero. Wat!?

Aside: If this doesn’t surprise you, you likely have experience with C or C++, which distinguish between pre-increment (e.g., ++x) and post-increment (e.g., x++) operations. The behavior of jmp y-- is equivalent to a post-decrement, where the value is tested before being decremented.

This value, 4,294,967,295, is the maximum for a 32-bit unsigned integer. It’s as if a track-and-field long jumper launches off the takeoff board but, instead of landing in the sandpit, overshoots and ends up on another continent.

Aside: As foreshadowed in Wat 5, we can use this behavior intentionally to set a register to the value 4,294,967,295.
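
If you prefer code to prose, here is a tiny Python model of the semantics just described (a sketch of “test, then decrement with 32-bit wrap-around”, not actual PIO or Rust code):

def jmp_y_decrement(y: int) -> tuple[bool, int]:
    """Model PIO's `jmp y--`: the jump is taken if y != 0, then y is decremented modulo 2**32."""
    jump_taken = y != 0
    y = (y - 1) & 0xFFFFFFFF  # 0 wraps around to 4_294_967_295 (u32::MAX)
    return jump_taken, y

print(jmp_y_decrement(0))  # (False, 4294967295): falls through, and y overshoots to the maximum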

Now that we’ve learned how to stick the landing with jmp, let’s see if we can avoid getting stuck by the pins that PIO reads and sets.

In Dr. Seuss’s Too Many Daves, Mrs. McCave had 23 sons, all named Dave, leading to endless confusion whenever she called out their name. In PIO programming, pin and pins can refer to completely different ranges of pins depending on the context. It’s hard to know which Dave or Daves you’re talking to.

The problem: Pin ranges and subranges

In PIO, both pin and pins instructions depend on pin ranges defined in Rust, outside of PIO. However, individual instructions often operate on a subrange of those pin ranges. The behavior varies depending on the command: the subrange could be the first n pins of the range, all the pins, or just a specific pin given by an index. To clarify PIO’s behavior, I created the following table:

This table shows how PIO interprets the terms pin and pins in different instructions, along with their associated contexts and configurations.

Example: Distance program for the range finder

Here’s a PIO program for measuring the distance to an object using Trigger and Echo pins. The key features of this program are:

  • Continuous Operation: The range finder runs in a loop as fast as possible.
  • Maximum Range Limit: Measurements are capped at a given distance, with a return value of 4,294,967,295 if no object is detected.
  • Filtered Outputs: Only measurements that differ from their immediate predecessor are sent, reducing the output rate.

Glance over the program and notice that although it is working with two pins — Trigger and Echo — throughout the program we only see pin and pins.

.program distance

; X is the last value sent. Initialize it to
; u32::MAX which means 'echo timeout'
; (Set X to u32::MAX by subtracting 1 from 0)
   set x, 0
subtraction_trick:
   jmp x-- subtraction_trick

; Read the max echo wait into OSR
   pull                         ; same as pull block

; Main loop
.wrap_target
   ; Generate 10μs trigger pulse (4 cycles at 343_000Hz)
   set pins, 0b1 [3]       ; Set trigger pin to high, add delay of 3
   set pins, 0b0           ; Set trigger pin to low voltage

   ; When the trigger goes high, start counting down until it goes low
   wait 1 pin 0            ; Wait for echo pin to be high voltage
   mov y, osr              ; Load max echo wait into Y

measure_echo_loop:
   jmp pin echo_active     ; if echo voltage is high continue count down
   jmp measurement_complete ; if echo voltage is low, measurement is complete
echo_active:
   jmp y-- measure_echo_loop ; Continue counting down unless timeout

; Y tells where the echo countdown stopped. It
; will be u32::MAX if the echo timed out.
measurement_complete:
   jmp x!=y send_result    ; if measurement is different, then send it.
   jmp cooldown            ; If measurement is the same, don't send.

send_result:
   mov isr, y              ; Store measurement in ISR
   push                    ; Output ISR
   mov x, y               ; Save the measurement in X

; Cool down period before next measurement
cooldown:
   wait 0 pin 0           ; Wait for echo pin to be low
.wrap                      ; Restart the measurement loop

Configuring Pins

To ensure the PIO program behaves as intended:

  • set pins, 0b1 should control the Trigger pin.
  • wait 1 pin 0 should monitor the Echo pin.
  • jmp pin echo_active should also monitor the Echo pin.

Here’s how you can configure this in Rust (followed by an explanation):

let mut distance_state_machine = pio1.sm0;
let trigger_pio = pio1.common.make_pio_pin(hardware.trigger);
let echo_pio = pio1.common.make_pio_pin(hardware.echo);
distance_state_machine.set_pin_dirs(Direction::Out, &[&trigger_pio]);
distance_state_machine.set_pin_dirs(Direction::In, &[&echo_pio]);
distance_state_machine.set_config(&{
   let mut config = Config::default();
   config.set_set_pins(&[&trigger_pio]); // For set instruction
   config.set_in_pins(&[&echo_pio]); // For wait instruction
   config.set_jmp_pin(&echo_pio); // For jmp instruction
   let program_with_defines = pio_file!("examples/distance.pio");
   let program = pio1.common.load_program(&program_with_defines.program);
   config.use_program(&program, &[]); // No side-set pins
   config
});

The keys here are the set_set_pins, set_in_pins, and set_jmp_pin methods on the Config struct.

  • set_in_pins: Specifies the pins for input operations, such as wait(1, pin, …). The “in” pins must be consecutive.
  • set_set_pins: Configures the pin for set operations, like set(pins, 1). The “set” pins must also be consecutive.
  • set_jmp_pin: Defines the single pin used in conditional jumps, such as jmp(pin, ...).

As described in the table, other optional inputs include:

  • set_out_pins: Sets the consecutive pins for output operations, such as out(pins, …).
  • use_program: Sets a) the loaded program and b) consecutive pins for sideset operations. Sideset operations allow simultaneous pin toggling during other instructions.

Configuring Multiple Pins

Although not required for this program, you can configure a range of pins in PIO by providing a slice of consecutive pins. For example, suppose we had two ultrasonic range finders:

let trigger_a_pio = pio1.common.make_pio_pin(hardware.trigger_a);
let trigger_b_pio = pio1.common.make_pio_pin(hardware.trigger_b);
config.set_set_pins(&[&trigger_a_pio, &trigger_b_pio]);

A single instruction can then control both pins:

set pins, 0b11 [3]  # Sets both trigger pins (17, 18) high, adds delay
set pins, 0b00      # Sets both trigger pins low

This approach lets you efficiently apply bit patterns to multiple pins simultaneously, streamlining control for applications involving multiple outputs.

Aside: The Word “Set” in Programming

In programming, the word “set” is notoriously overloaded with multiple meanings. In the context of PIO, “set” refers to something to which you can assign a value — such as a pin’s state. It does not mean a collection of things, as it often does in other programming contexts. When PIO refers to a collection, it usually uses the term “range” instead. This distinction is crucial for avoiding confusion as you work with PIO.

Lessons from Mrs. McCave

In Too Many Daves, Mrs. McCave lamented not giving her 23 Daves more distinct names. You can avoid her mistake by clearly documenting your pins with meaningful names — like Trigger and Echo — in your comments.

But if you think handling these pin ranges is tricky, debugging a PIO program adds an entirely new layer of challenge. In the next Wat, we’ll dive into the kludgy debugging methods available. Let’s see just how far we can push them.

I like to debug with interactive breakpoints in VS Code. I also do print debugging, where you insert temporary info statements to see what the code is doing and the values of variables. Using the Raspberry Pi Debug Probe and probe-rs, I can do both of these with regular Rust code on the Pico.

With PIO programming, however, I can do neither.

The fallback is push-to-print debugging. In PIO, you temporarily output integer values of interest. Then, in Rust, you use info! to print those values for inspection.

For example, in the following PIO program, we temporarily add instructions to push the value of x for debugging. We also use set, mov, and push to output a constant value, such as 7, so we know a given point in the program was reached (recall that set constants must be between 0 and 31 inclusive).

.program distance

; X is the last value sent. Initialize it to
; u32::MAX which means 'echo timeout'
; (Set X to u32::MAX by subtracting 1 from 0)
   set x, 0
subtraction_trick:
   jmp x-- subtraction_trick

; DEBUG: See the value of x
   mov isr, x
   push

; Read the max echo wait into OSR
   pull                         ; same as pull block

; DEBUG: Send constant value
   set y, 7           ; Push '7' so that we know we've reached this point
   mov isr, y
   push
; ...

Back in Rust, you can read and print these values to help understand what’s happening in the PIO code (full code and project):

  // ...
   distance_state_machine.set_enable(true);
   distance_state_machine.tx().wait_push(MAX_LOOPS).await;
   loop {
       let end_loops = distance_state_machine.rx().wait_pull().await;
       info!("end_loops: {}", end_loops);
   }
  // ...

Outputs:

INFO  Hello, debug!
└─ distance_debug::inner_main::{async_fn#0} @ examples\distance_debug.rs:27
INFO  end_loops: 4294967295
└─ distance_debug::inner_main::{async_fn#0} @ examples\distance_debug.rs:57
INFO  end_loops: 7
└─ distance_debug::inner_main::{async_fn#0} @ examples\distance_debug.rs:57

When push-to-print debugging isn’t enough, you can turn to hardware tools. I bought my first oscilloscope (a FNIRSI DSO152, for $37). With it, I was able to confirm the Echo signal was working. The Trigger signal, however, was too fast for this inexpensive oscilloscope to capture clearly.

Using these methods — especially push-to-print debugging — you can trace the flow of your PIO program, even without a traditional debugger.

Aside: In C/C++ (and potentially Rust), you can get closer to a full debugging experience for PIO, for example, by using the piodebug project.

That concludes the nine Wats, but let’s bring everything together in a bonus Wat.

Now that all the components are ready, it’s time to combine them into a working theremin-like musical instrument. We need a Rust monitor program. This program starts both PIO state machines — one for measuring distance and the other for generating tones. It then waits for a new distance measurement, maps that distance to a tone, and sends the corresponding tone frequency to the tone-playing state machine. If the distance is out of range, it stops the tone.

Rust’s Place: At the heart of this system is a function that maps distances (from 0 to 50 cm) to tones (approximately B2 to F5). This function is simple to write in Rust, leveraging Rust’s floating-point math and exponential operations. Implementing this in PIO would be virtually impossible due to its limited instruction set and lack of floating-point support.

Here’s the core monitor program to run the theremin (full file and project):

sound_state_machine.set_enable(true);
distance_state_machine.set_enable(true);
distance_state_machine.tx().wait_push(MAX_LOOPS).await;
loop {
   let end_loops = distance_state_machine.rx().wait_pull().await;
   match loop_difference_to_distance_cm(end_loops) {
       None => {
           info!("Distance: out of range");
           sound_state_machine.tx().wait_push(0).await;
       }
       Some(distance_cm) => {
           let tone_frequency = distance_to_tone_frequency(distance_cm);
           let half_period = sound_state_machine_frequency / tone_frequency as u32 / 2;
           info!("Distance: {} cm, tone: {} Hz", distance_cm, tone_frequency);
           sound_state_machine.tx().push(half_period); // non-blocking push
           Timer::after(Duration::from_millis(50)).await;
       }
   }
}

Using two PIO state machines alongside a Rust monitor program lets you literally run three programs at once. This setup is convenient on its own and is essential when strict timing or very high-frequency I/O operations are required.

Aside: Alternatively, Rust Embassy’s async tasks let you implement cooperative multitasking directly on a single main processor. You code in Rust rather than a mixture of Rust and PIO. Although Embassy tasks don’t literally run in parallel, they switch quickly enough to handle applications like a theremin. Here’s a snippet from theremin_no_pio.rs showing a similar core loop:

loop {
       match distance.measure().await {
           None => {
               info!("Distance: out of range");
               sound.rest().await;
           }
           Some(distance_cm) => {
               let tone_frequency = distance_to_tone_frequency(distance_cm);
               info!("Distance: {} cm, tone: {} Hz", distance_cm, tone_frequency);
               sound.play(tone_frequency).await;
               Timer::after(Duration::from_millis(50)).await;
           }
       }
   }

See our recent article on Rust Embassy programming for more details.

Now that we’ve assembled all the components, let’s watch the video again of me “playing” the musical instrument. On the monitor screen, you can see the debugging prints displaying the distance measurements and the corresponding tones. This visual connection highlights how the system responds in real time.

Conclusion

PIO programming on the Raspberry Pi Pico is a captivating blend of simplicity and complexity, offering unparalleled hardware control while demanding a shift in mindset for developers accustomed to higher-level programming. Through the nine Wats we’ve explored, PIO has both surprised us with its limitations and impressed us with its raw efficiency.

While we’ve covered significant ground — managing state machines, pin assignments, timing intricacies, and debugging — there’s still much more you can learn as needed: DMA, IRQ, side-set pins, differences between PIO on the Pico 1 and Pico 2, autopush and autopull, FIFO join, and more.

Recommended Resources

At its core, PIO’s quirks reflect a design philosophy that prioritizes low-level hardware control with minimal overhead. By embracing these characteristics, you’ll find that PIO not only meets your project’s demands but also opens doors to new possibilities in embedded systems programming.

Please follow Carl on Towards Data Science and on @carlkadie.bsky.social. I write on scientific programming in Rust and Python, machine learning, and statistics. I tend to write about one article per month.

How to Build a RAG System Using LangChain, Ragas, and Neptune


LangChain provides composable building blocks to create LLM-powered applications, making it an ideal framework for building RAG systems. Developers can integrate components and APIs of different vendors into coherent applications.

Evaluating a RAG system’s performance is crucial to ensure high-quality responses and robustness. The Ragas framework offers a large number of RAG-specific metrics as well as capabilities for generating dedicated evaluation datasets.

neptune.ai makes it easy for RAG developers to track evaluation metrics and metadata, enabling them to analyze and compare different system configurations. The experiment tracker can handle large amounts of data, making it well-suited for quick iteration and extensive evaluations of LLM-based applications.

Imagine asking a chat assistant about LLMOps only to receive outdated advice or irrelevant best practices. While LLMs are powerful, they rely solely on their pre-trained knowledge and lack the ability to fetch current data.

This is where Retrieval-Augmented Generation (RAG) comes in. RAG combines the generative power of LLMs with external data retrieval, enabling the assistant to access and use real-time information. For example, instead of outdated answers, the chat assistant could pull insights from Neptune’s LLMOps article collection to deliver accurate and contextually relevant responses.

In this guide, we’ll show you how to build a RAG system using the LangChain framework, evaluate its performance using Ragas, and track your experiments with neptune.ai. Along the way, you’ll learn to create a baseline RAG system, refine it using Ragas metrics, and enhance your workflow with Neptune’s experiment tracking.

Part 1: Building a baseline RAG system with LangChain

In the first part of this guide, we’ll use LangChain to build a RAG system for the blog posts in the LLMOps category on Neptune’s blog.

Overview of a baseline RAG system. A user’s question is used as the query to retrieve relevant documents from a database. The documents returned by the search are added to the prompt that is passed to the LLM together with the user’s question. The LLM uses the information in the prompt to generate an answer. | Source

What is LangChain?

LangChain offers a collection of open-source building blocks, including memory management, data loaders for various sources, and integrations with vector databases—all the essential components of a RAG system.

LangChain stands out among the frameworks for building RAG systems for its composability and versatility. Developers can combine and connect these building blocks using a coherent Python API, allowing them to focus on creating LLM applications rather than dealing with the nitty-gritty of API specifications and data transformations.

Overview of the categories of building blocks provided by LangChain. The framework includes interfaces to models and vector stores, document loaders, and text processing utilities like output parsers and text splitters. Further, LangChain offers features for prompt engineering, like templates and example selectors. The framework also contains a collection of tools that can be called by LLM agents. | Source

Step 1: Setting up

We’ll begin by installing the necessary dependencies (I used Python 3.11.4 on Linux):

pip install -qU langchain-core==0.1.45 langchain-openai==0.0.6 langchain-chroma==0.1.4 ragas==0.2.8 neptune==1.13.0 pandas==2.2.3 datasets==3.2.0

For this example, we’ll use OpenAI’s models and configure the API key. To access OpenAI models, you’ll need to create an OpenAI account and generate an API key. Our usage in this blog should be well within the free-tier limits.

Once we have obtained our API key, we’ll set it as an environment variable so that LangChain’s OpenAI building blocks can access it:

import os
os.environ["OPENAI_API_KEY"] = "YOUR_KEY_HERE"

You can also use any of LangChain’s other embedding and chat models, including local models provided by Ollama. Thanks to the compositional structure of LangChain, all it takes is replacing OpenAIEmbeddings and ChatOpenAI in the code with the respective alternative building blocks.
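
For instance, a local setup via Ollama might look roughly like this (this assumes the langchain-ollama integration package and models you have already pulled locally; the model names are placeholders):

from langchain_ollama import ChatOllama, OllamaEmbeddings

llm = ChatOllama(model="llama3.1")                        # stands in for ChatOpenAI
embeddings = OllamaEmbeddings(model="nomic-embed-text")   # stands in for OpenAIEmbeddings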

Step 2: Load and parse the raw data

Source data for RAG systems is often unstructured documents. Before we can use it effectively, we’ll need to process and parse it into a structured format.

Fetch the source data

Since we’re working with a blog, we’ll use LangChain’s WebBaseLoader to load data from Neptune’s blog. WebBaseLoader reads raw webpage content, capturing text and structure, such as headings.

The web pages are loaded as LangChain documents, which include the page content as a string and metadata associated with that document, e.g., the source page’s URL.

In this example, we select 3 blog posts to create the chat assistant’s knowledge base:

import bs4
from langchain_community.document_loaders import WebBaseLoader

loader = WebBaseLoader(
    web_paths=[
        "https://neptune.ai/blog/llm-hallucinations",
        "https://neptune.ai/blog/llmops",
        "https://neptune.ai/blog/llm-guardrails"
    ],
    bs_kwargs=dict(
        parse_only=bs4.SoupStrainer(name=["p", "h2", "h3", "h4"])
    ),
)
docs = loader.load()

Split the data into smaller chunks

To meet the embedding model’s token limit and improve retrieval performance, we’ll split the long blog posts into smaller chunks.

The chunk size is a trade-off between specificity (capturing detailed information within each chunk) and efficiency (reducing the total number of resulting chunks). By overlapping chunks, we mitigate the loss of critical information that occurs when a self-contained sequence of the source text is split into two incoherent chunks.

Visualization of the chunks created from the article LLM Hallucinations 101. The text is split into four chunks highlighted in blue, lime green, dark orange, and dark yellow. The overlaps between chunks are marked in olive green. | Created with ChunkViz

For generic text, LangChain recommends the RecursiveCharacterTextSplitter. We set the chunk size to a maximum of 1,000 characters with an overlap of 200 characters. We also filter out unnecessary parts of the documents, such as the header, footer, and any promotional content:

from langchain_text_splitters import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

header_footer_keywords = ["peers about your research", "deepsense", "ReSpo", "Was the article useful?", "related articles", "All rights reserved"]

splits = []
for s in text_splitter.split_documents(docs):
    if not any(kw in s.page_content for kw in header_footer_keywords):
        splits.append(s)

len(splits)

Step 3: Set up the vector store

Vector stores are specialized data stores that enable indexing and retrieving information based on vector representations.

Choose a vector store

LangChain supports many vector stores. In this example, we’ll use Chroma, an open-source vector store specifically designed for LLM applications.

By default, Chroma stores the collection in memory; once the session ends, all the data (embeddings and indices) are lost. While this is fine for our small example, in production, you’ll want to persist the database to disk by passing the persist_directory keyword argument when initializing Chroma.
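
A minimal sketch of such a persistent setup (the directory path is an arbitrary placeholder; everything else matches the in-memory version shown below):

from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings

vectorstore = Chroma.from_documents(
    documents=splits,
    embedding=OpenAIEmbeddings(),
    persist_directory="./chroma_db",  # hypothetical path; the collection now survives restarts
)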

Specify which embedding model to use

Embedding models convert chunks into vectors. There are many embedding models to choose from. The Massive Text Embedding Benchmark (MTEB) leaderboard is a great resource for selecting one based on model size, embedding dimensions, and performance requirements.

The MTEB Leaderboard provides a standardized comparison of embedding models across diverse tasks and datasets, including retrieval, clustering, classification, and reranking. The leaderboard provides a clear comparison of model performance and makes selecting embedding models easier through filters and ranking.

For our example LLMOps RAG system, we’ll use OpenAIEmbeddings with its default model. (At the time of writing, this was text-embedding-ada-002.)

Create a retriever object from the vector store

A retriever performs semantic searches to find the most relevant pieces of information based on a user query. For this baseline example, we’ll configure the retriever to return only the top result, which will be used as context for the LLM to generate an answer.

Initializing the vector store for our RAG system and instantiating a retriever takes only two lines of code:

from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings

vectorstore = Chroma.from_documents(
   documents=splits,
   embedding=OpenAIEmbeddings())
retriever = vectorstore.as_retriever(search_kwargs={"k": 1})

In the last line, we have specified through search_kwargs that the retriever only returns the most similar document (top-k retrieval with k = 1).
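
Once the retriever exists, you can sanity-check it on its own before wiring it into the chain (the query below is just an example):

docs = retriever.invoke("What are LLM guardrails?")
print(docs[0].metadata["source"])
print(docs[0].page_content[:200])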

Step 4: Bring it all together

Now that we’ve set up a vector database with the source data and initialized the retriever to return the most relevant chunk given a query, we’ll combine it with an LLM to complete our baseline RAG chain.

Define a prompt template

We need to set a prompt to guide the LLM in responding. This prompt should tell the model to use the retrieved context to answer the query.

We’ll use a standard RAG prompt template that specifically asks the LLM to use the provided context (the retrieved chunk) to answer the user query concisely:

from langchain_core.prompts import ChatPromptTemplate

system_prompt = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer "
    "the question. If you don't know the answer, say that you "
    "don't know. Use three sentences maximum and keep the "
    "answer concise."
    "\n\n"
    "{context}"
)

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)

Create the full RAG chain

We’ll use the create_stuff_documents_chain utility function to set up the generative part of our RAG chain. It combines an instantiated LLM and a prompt template with a {context} placeholder into a chain that takes a set of documents as its input, which are “stuffed” into the prompt before it is fed into the LLM. In our case, that’s OpenAI’s GPT-4o-mini.

from langchain_openai import ChatOpenAI
from langchain.chains.combine_documents import create_stuff_documents_chain

llm = ChatOpenAI(model="gpt-4o-mini")
question_answer_chain = create_stuff_documents_chain(llm, prompt)

Then, we can use the create_retrieval_chain utility function to finally instantiate our complete RAG chain: 

from langchain.chains import create_retrieval_chain

rag_chain = create_retrieval_chain(retriever, question_answer_chain)

Get an output from the RAG chain

To see how our system works, we can run a first inference call. We’ll send a query to the chain that we know can be answered using the contents of one of the blog posts:

response = rag_chain.invoke({"input": "What are DOM-based attacks?"})
print(response["answer"])

The response is a dictionary that contains “input,” “context,” and “answer” keys:

{
  "input": 'What are DOM-based attacks?',
  'context': [Document(metadata={'source': 'https://neptune.ai/blog/llm-guardrails'}, page_content='By prompting the application to pretend to be a chatbot that “can do anything” and is not bound by any restrictions, users were able to manipulate ChatGPT to provide responses to questions it would usually decline to answer.Although “prompt injection” and “jailbreaking” are often used interchangeably in the community, they refer to distinct vulnerabilities that must be handled with different methods.DOM-based attacksDOM-based attacks are an extension of the traditional prompt injection attacks. The key idea is to feed a harmful instruction into the system by hiding it within a website’s code.Consider a scenario where your program crawls websites and feeds the raw HTML to an LLM on a daily basis. The rendered page looks normal to you, with no obvious signs of anything wrong. Yet, an attacker can hide a malicious key phrase by matching its color to the background or adding it in parts of the HTML code that are not rendered, such as a style Tag.While invisible to human eyes, the LLM will')],
  "answer": "DOM-based attacks are a type of vulnerability where harmful instructions are embedded within a website's code, often hidden from view. Attackers can conceal malicious content by matching its color to the background or placing it in non-rendered sections of the HTML, like style tags. This allows the malicious code to be executed by a system, such as a language model, when it processes the website's HTML."}

We see that the retriever appropriately identified a snippet from the LLM Guardrails: Secure and Controllable Deployment article as the most relevant chunk.

Define a prediction function

Now that we have a fully functioning end-to-end RAG chain, we can create a convenience function that enables us to query our RAG chain. It takes a RAG chain and a query and returns the chain’s response. We’ll also implement the option to pass just the stuff documents chain and provide the list of context documents via an additional input parameter. This will come in handy when evaluating the different parts of our RAG system.

Here’s what this function looks like:

from langchain_core.runnables.base import Runnable
from langchain_core.documents import Document

def predict(chain: Runnable, query: str, context: list[Document] | None = None)-> dict:
    """
    Accepts a retrieval chain or a stuff documents chain. If the latter, context must be passed in.
    Return a response dict with keys "input", "context", and "answer"
    """
    inputs = {"input": query}
    if context:
        inputs.update({"context": context})

    response = chain.invoke(inputs)

    result = {
        response["input"]: {
            "context": [d.page_content for d in response['context']],
            "answer": response["answer"],
        }
    }
    return result

Part 2: Evaluating a RAG system using Ragas and neptune.ai

Once a RAG system is built, it’s important to evaluate its performance and establish a baseline. The proper way to do this is by systematically testing it using a representative evaluation dataset. Since such a dataset is not available in our case yet, we’ll have to generate one.

To assess both the retrieval and generation aspects of the system, we’ll use Ragas as the evaluation framework and neptune.ai to track experiments as we iterate.

What is Ragas?

Ragas is an open-source toolkit for evaluating RAG applications. It offers both LLM-based and non-LLM-based metrics to assess the quality of retrieval and generated responses. Ragas works smoothly with LangChain, making it a great choice for evaluating our RAG system.

Step 1: Generate a RAG evaluation dataset

An evaluation set for RAG tasks is similar to a question-answering task dataset. The key difference is that each row includes not just the query and a reference answer but also reference contexts (documents that we expect to be retrieved to answer the query).

Thus, an example evaluation set entry looks like this:

Query: How can users trick a chatbot to bypass restrictions?

Reference context: [‘By prompting the application to pretend to be a chatbot that “can do anything” and is not bound by any restrictions, users were able to manipulate ChatGPT to provide responses to questions it would usually decline to answer.’]

Reference answer: Users trick chatbots to bypass restrictions by prompting the application to pretend to be a chatbot that ‘can do anything’ and is not bound by any restrictions, allowing it to provide responses to questions it would usually decline to answer.

Ragas provides utilities to generate such a dataset from a list of reference documents using an LLM.

As the reference documents, we’ll use the same chunks that we fed into the Chroma vector store in the first part, which is precisely the knowledge base from which our RAG system is drawing.

To test the generative part of our RAG chain, we’ll need to generate example queries and reference answers using a different model. Otherwise, we’d be testing our system’s self-consistency. We’ll use the full-sized GPT-4o model, which should outperform the GPT-4o-mini in our RAG chain.

As in the first part, it is possible to use a different LLM. The LangchainLLMWrapper and LangChainEmbeddingsWrapper make any model available via LangChain accessible to Ragas.

What happens under the hood?

Ragas’ TestSetGenerator builds a knowledge graph in which each node represents a chunk. It extracts information like named entities from the chunks and uses this data to model the relationship between nodes. From the knowledge graph, so-called query synthesizers derive scenarios consisting of a set of nodes, the desired query length and style, and a user persona. This scenario is used to populate a prompt template instructing an LLM to generate a query and answer (example). For more details, refer to the Ragas Testset Generation documentation.

Creating an evaluation dataset with 50 rows for our RAG system should take about a minute. We’ll generate a mixture of abstract queries (“What is concept A?”) and specific queries (“How often does subscription plan B bill its users?”):

from ragas.llms import LangChainLLMWrapper
from ragas.embeddings import LangChainEmbeddingsWrapper
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from ragas.testset import TestsetGenerator
from ragas.testset.synthesizers import AbstractQuerySynthesizer, SpecificQuerySynthesizer

generator_llm = LangChainLLMWrapper(ChatOpenAI(model="gpt-4o"))
generator_embeddings = LangChainEmbeddingsWrapper(OpenAIEmbeddings())

generator = TestsetGenerator(llm=generator_llm, embedding_model=generator_embeddings)

dataset = generator.generate_with_langchain_docs(
    splits,
    testset_size=50,
    query_distribution=[
        (AbstractQuerySynthesizer(llm=generator_llm), 0.1),
        (SpecificQuerySynthesizer(llm=generator_llm), 0.9),
    ],
)

Filtering unwanted data

We want to focus our evaluation on cases where the reference answer is helpful. In particular, we don’t want to include test samples with responses containing phrases like “the context is insufficient” or “the context does not contain.” Duplicate entries in the dataset would skew the evaluation, so they should also be omitted.

For filtering, we’ll use the ability to easily convert Ragas datasets into Pandas DataFrames or Hugging Face Datasets:


# Keep only the first occurrence of each generated query
unique_indices = set(dataset.to_pandas().drop_duplicates(subset=["user_input"]).index)

# Flag samples whose reference answer signals that the retrieved context was not helpful
not_helpful = set(dataset.to_pandas()[dataset.to_pandas()["reference"].str.contains("does not contain|does not provide|context does not|is insufficient|is incomplete", case=False, regex=True)].index)

unique_helpful_indices = unique_indices - not_helpful

# Convert to a Hugging Face Dataset restricted to the unique, helpful samples
ds = dataset.to_hf_dataset().select(unique_helpful_indices)

This leaves us with unique samples that look like this:

User input: What role does reflection play in identifying and correcting hallucinations in LLM outputs?

Reference contexts: [‘After the responseCorrecting a hallucination after the LLM output has been generated is still beneficial, as it prevents the user from seeing the incorrect information. This approach can effectively transform correction into prevention by ensuring that the erroneous response never reaches the user. The process can be broken down into the following steps:This method is part of multi-step reasoning strategies, which are increasingly important in handling complex problems. These strategies, often referred to as “agents,” are gaining popularity. One well-known agent pattern is reflection. By identifying hallucinations early, you can address and correct them before they impact the user.’]

Reference answer: Reflection plays a role in identifying and correcting hallucinations in LLM outputs by allowing early identification and correction of errors before they impact the user.

User input: What are some examples of LLMs that utilize a reasoning strategy to improve their responses?

Reference contexts: [‘Post-training or alignmentIt is hypothesized that an LLM instructed not only to respond and follow instructions but also to take time to reason and reflect on a problem could largely mitigate the hallucination issue—either by providing the correct answer or by stating that it does not know how to answer.Furthermore, you can teach a model to use external tools during the reasoning process,\xa0 like getting information from a search engine. There are a lot of different fine-tuning techniques being tested to achieve this. Some LLMs already working with this reasoning strategy are Matt Shumer’s Reflection-LLama-3.1-70b and OpenAI’s O1 family models.’]

Reference answer: Some examples of LLMs that utilize a reasoning strategy to improve their responses are Matt Shumer’s Reflection-LLama-3.1-70b and OpenAI’s O1 family models.

User input: What distnguishes ‘promt injecton’ frm ‘jailbraking’ in vulnerabilties n handling?

Reference contexts: [‘Although “prompt injection” and “jailbreaking” are often used interchangeably in the community, they refer to distinct vulnerabilities that must be handled with different methods.’]

Reference answer: ‘Prompt injection’ and ‘jailbreaking’ are distinct vulnerabilities that require different handling methods.

In the third sample, the query contains a lot of typos. This is an example of the “MISSPELLED” query style.

💡 You can find a full example evaluation dataset on Hugging Face.

Step 2: Choose RAG evaluation metrics

As mentioned earlier, Ragas offers both LLM-based and non-LLM-based metrics for RAG system evaluation.

For this example, we’ll focus on LLM-based metrics. They are better suited than purely quantitative metrics for tasks that require semantic and contextual understanding, while being far less resource-intensive than having humans evaluate each response. This makes them a reasonable tradeoff despite concerns about reproducibility.

From the wide range of metrics available in Ragas, we’ll select five:

  1. LLM Context Recall measures how many of the relevant documents are successfully retrieved. It uses the reference answer as a proxy for the reference context and determines whether all claims in the reference answer can be attributed to the retrieved context.
  2. Faithfulness measures the generated answer’s factual consistency with the given context by assessing how many claims in the generated answer can be found in the retrieved context.
  3. Factual Correctness evaluates the factual accuracy of the generated answer by assessing whether its claims are present in the reference answer (true and false positives) and whether any claims from the reference answer are missing (false negatives). From these counts, precision, recall, or F1 scores are calculated (see the short worked example after this list).
  4. Semantic Similarity measures the similarity between the reference answer and the generated answer.
  5. Noise Sensitivity measures how often a system makes errors by providing incorrect responses when utilizing either relevant or irrelevant retrieved documents.
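
To make the Factual Correctness scoring concrete, here is a minimal worked example of how precision, recall, and F1 follow from the claim counts. The counts below are invented for illustration and are not produced by Ragas:

# Hypothetical claim counts for a single generated answer (illustration only)
tp = 4  # claims that appear in both the generated and the reference answer
fp = 1  # claims in the generated answer that are not in the reference answer
fn = 2  # claims in the reference answer that are missing from the generated answer

precision = tp / (tp + fp)                           # 0.80
recall = tp / (tp + fn)                              # 0.67
f1 = 2 * precision * recall / (precision + recall)   # 0.73

print(f"precision={precision:.2f}, recall={recall:.2f}, f1={f1:.2f}")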

Each of these metrics requires specifying an LLM or an embedding model for its calculations. We’ll again use GPT-4o for this purpose:

from ragas.metrics import LLMContextRecall, Faithfulness, FactualCorrectness, SemanticSimilarity, NoiseSensitivity
from ragas import EvaluationDataset
from ragas import evaluate

evaluator_llm = LangChainLLMWrapper(ChatOpenAI(model="gpt-4o"))
evaluator_embeddings = LangChainEmbeddingsWrapper(OpenAIEmbeddings())

metrics = [
    LLMContextRecall(llm=evaluator_llm),
    FactualCorrectness(llm=evaluator_llm),
    Faithfulness(llm=evaluator_llm),
    SemanticSimilarity(embeddings=evaluator_embeddings),
    NoiseSensitivity(llm=evaluator_llm),
]

Step 3: Evaluate the baseline RAG system’s performance

To evaluate our baseline RAG system, we’ll generate predictions and analyze them with the five selected metrics.

To speed up the process, we’ll use a concurrent approach to handle the I/O-bound predict calls from the RAG chain. This allows us to process multiple queries in parallel. Afterward, we can convert the results into a data frame for further inspection and manipulation. We’ll also store the results in a CSV file.

Here’s the complete performance evaluation code:

from concurrent.futures import ThreadPoolExecutor, as_completed

from datasets import Dataset
from langchain_core.runnables import Runnable

def concurrent_predict_retrieval_chain(chain: Runnable, dataset: Dataset):
    # Run the I/O-bound predict calls in parallel; `predict` and `rag_chain`
    # are the helper function and chain defined earlier in the tutorial
    results = {}
    threads = []
    with ThreadPoolExecutor(max_workers=5) as pool:
        for query in dataset["user_input"]:
            threads.append(pool.submit(predict, chain, query))
        for task in as_completed(threads):
            results.update(task.result())
    return results

predictions = concurrent_predict_retrieval_chain(rag_chain, ds)

# Attach the generated answers and retrieved contexts to the evaluation dataset
ds_k_1 = ds.map(lambda example: {"response": predictions[example["user_input"]]["answer"], "retrieved_contexts": predictions[example["user_input"]]["context"]})

results = evaluate(dataset=EvaluationDataset.from_hf_dataset(ds_k_1), metrics=metrics)

# Store the per-sample results for inspection
df = results.to_pandas()
df.to_csv("eval_results.csv", index=False)

Part 3: Iteratively refining the RAG performance

With the evaluation setup in place, we can now start to improve our RAG system. Using the initial evaluation results as our baseline, we can systematically make changes to our RAG chain and assess whether they improve performance.

While we could make do with saving all evaluation results in cleanly named files and taking notes, we’d quickly be overwhelmed with the amount of information. To efficiently iterate and keep track of our progress, we’ll need a way to record, analyze, and compare our experiments.

What is neptune.ai?

Neptune is a machine-learning experiment tracker focused on collaboration and scalability. It provides a centralized platform for tracking, logging, and comparing metrics, artifacts, and configurations.

Neptune can track not only single metrics values but also more complex metadata, such as text, arrays, and files. All metadata can be accessed and analyzed through a highly versatile user interface as well as programmatically. All this makes it a great tool for developing RAG systems and other LLM-based applications.

Step 1: Set up neptune.ai for experiment tracking

To get started with Neptune, sign up for a free account at app.neptune.ai and follow the steps to create a new project. Once that’s done, set the project name and API token as environment variables and initialize a run:

os.environ["NEPTUNE_PROJECT"] = "YOUR_PROJECT"
os.environ["NEPTUNE_API_TOKEN"] = "YOUR_API_TOKEN"

import neptune

run = neptune.init_run()

In Neptune, each run corresponds to one tracked experiment. Thus, every time we execute our evaluation script, we start a new experiment.
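
To keep multiple experiments easy to tell apart later, it can help to give each run a descriptive name and tags when initializing it. Here is a minimal sketch using Neptune's standard init_run parameters; the name and tags are arbitrary examples:

import neptune

# The name and tags below are arbitrary examples; adapt them to your experiment
run = neptune.init_run(
    name="rag-eval-baseline",
    tags=["rag", "ragas", "retrieval-k-sweep"],
)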

Logging Ragas metrics to neptune.ai

To make our lives easier, we’ll define a helper function that stores the Ragas evaluation results in the Neptune Run object, which represents the current experiment.

We’ll track the metrics for each sample in the evaluation dataset and an overall performance metric, which in our case is simply the average across all metrics for the entire dataset: 

import io

import neptune
import pandas as pd

def log_detailed_metrics(results_df: pd.DataFrame, run: neptune.Run, k: int):
    # Log the current value of the retrieval parameter k
    run["eval/k"].append(k)

    # Per-sample metrics, inputs, responses, and contexts
    for i, row in results_df.iterrows():
        for m in metrics:
            val = row[m.name]
            run[f"eval/q{i}/{m.name}"].append(val)

        run[f"eval/q{i}/user_input"] = row["user_input"]
        run[f"eval/q{i}/response"].append(row["response"])
        run[f"eval/q{i}/reference"] = row["reference"]

        # Upload the retrieved and reference contexts as a CSV file
        context_df = pd.DataFrame(
            zip(row["retrieved_contexts"], row["reference_contexts"]),
            columns=["retrieved", "reference"],
        )
        context_stream = io.StringIO()
        context_df.to_csv(context_stream, index=True, index_label="k")
        context_stream.seek(0)  # rewind so the full CSV is read on upload
        run[f"eval/q{i}/contexts/{k}"].upload(
            neptune.types.File.from_stream(context_stream, extension="csv")
        )

    # Overall performance: the average of each metric across all samples
    overall_metrics = results_df[[m.name for m in metrics]].mean(axis=0).to_dict()
    for name, value in overall_metrics.items():
        run[f"eval/overall/{name}"].append(value)

log_detailed_metrics(df, run, k=1)


run.stop()

Once we run the evaluation and switch to Neptune’s Experiments tab, we see our currently active run and the first round of metrics that we’ve logged.

Step 2: Iterate over a retrieval parameter

In our baseline RAG chain, we only use the first retrieved document chunk in the LLM context. But what if there are relevant chunks ranked lower, perhaps in the top 3 or top 5? To explore this, we can experiment with using different values for k, the number of retrieved documents.

We’ll start by evaluating k = 3 and k = 5 to see how the results change. For each experiment, we instantiate a new retrieval chain, run the prediction and evaluation functions, and log the results for comparison:

for k in [1, 3, 5]:
    retriever_k = vectorstore.as_retriever(search_kwargs={"k": k})
    rag_chain_k = create_retrieval_chain(retriever_k, question_answer_chain)
    predictions_k = concurrent_predict_retrieval_chain(rag_chain_k, ds)

    # Attach this round's answers and retrieved contexts to the evaluation dataset
    ds_k = ds.map(lambda example: {
        "response": predictions_k[example["user_input"]]["answer"],
        "retrieved_contexts": predictions_k[example["user_input"]]["context"]
    })

    results_k = evaluate(dataset=EvaluationDataset.from_hf_dataset(ds_k), metrics=metrics)
    df_k = results_k.to_pandas()

    # Save the raw results, upload them to Neptune, and log the detailed metrics
    df_k.to_csv("eval_results.csv", index=False)
    run[f"eval/eval_data/{k}"].upload("eval_results.csv")

    log_detailed_metrics(df_k, run, k)


run.stop()

Once the evaluation is complete (this should take between 5 and 10 minutes), the script displays “Shutting down background jobs” and then “Done!” when the process has finished.

Results overview

Let’s take a look at the results. Navigate to the Charts tab. The graphs all share a common x-axis labeled “step.” The evaluations for k = [1, 3, 5] are recorded as steps [0, 1, 2].


Comparison of metric values over three different values of k: the averaged metric values over all samples (top row) and the metric values for the first sample question (bottom row) indicate that the third step (k = 5) yielded the best outcome.

Looking at the overall metrics, we can observe that increasing k improved most metrics. Factual correctness decreased by a small amount, and noise sensitivity, where a lower value is preferable, increased. This is expected, since a larger k leads to more irrelevant chunks being included in the context. However, as both context recall and semantic similarity went up, this seems to be a worthwhile tradeoff.
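
If you prefer working with the numbers directly, the logged series can also be fetched programmatically. Here is a minimal sketch, assuming a hypothetical run ID and the per-metric field names logged by the helper function above:

import neptune

# "PROJ-1" is a hypothetical run ID; use the ID shown for your run in the Neptune UI
run = neptune.init_run(with_id="PROJ-1", mode="read-only")

# Fetch the overall faithfulness values logged across the three evaluation rounds
faithfulness_values = run["eval/overall/faithfulness"].fetch_values()
print(faithfulness_values)  # DataFrame with step, value, and timestamp columns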

Step 3: Iterate further

From here on, there are numerous possibilities for further experimentation, for example:

  • Trying different chunking strategies, such as semantic chunking, which determines the breakpoints between chunks based on semantic similarity rather than strict token counts.
  • Leveraging hybrid search, which combines keyword search algorithms like BM25 with embedding-based semantic search (a minimal sketch follows this list).
  • Trying other models that excel at question-answering tasks, like the Anthropic models, which are also available through LangChain.
  • Adding support components for dialogue systems, such as chat history.
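
As one illustration of these directions, here is a minimal sketch of hybrid search using LangChain's EnsembleRetriever to combine a BM25 retriever with the existing vector store retriever. It assumes the splits and vectorstore objects from the earlier parts of this tutorial, and the 50/50 weights are only a starting point:

from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever  # requires the rank_bm25 package

# Keyword-based retriever built from the same document chunks (assumes `splits` from earlier)
bm25_retriever = BM25Retriever.from_documents(splits)
bm25_retriever.k = 3

# Embedding-based retriever from the existing Chroma vector store
semantic_retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

# Combine both retrievers; the weights are arbitrary and worth tuning
hybrid_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, semantic_retriever],
    weights=[0.5, 0.5],
)

# The hybrid retriever plugs into the same RAG chain and evaluation loop as before
rag_chain_hybrid = create_retrieval_chain(hybrid_retriever, question_answer_chain)

Evaluating this chain with the same dataset, metrics, and Neptune logging makes it directly comparable to the k-sweep results above.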

Looking ahead

In the three parts of this tutorial, we’ve used LangChain to build a RAG system based on OpenAI models and the Chroma vector database, evaluated it with Ragas, and analyzed our progress with Neptune. Along the way, we explored essential foundations of developing performant RAG systems, such as:

  • How to efficiently chunk, store, and retrieve data to ensure our RAG system consistently delivers relevant and accurate responses to user queries.
  • How to generate an evaluation dataset for our particular RAG chain and use RAG-specific metrics like faithfulness and factual correctness to evaluate it.
  • How Neptune makes it easy to track, visualize, and analyze RAG system performance, allowing us to take a systematic approach when iteratively improving our application.

As we saw at the end of part 3, we’ve barely scratched the surface when it comes to improving retrieval performance and response quality. Using the triplet of tools we introduced and our evaluation setup, any new technique or change applied to the RAG system can be assessed and compared with alternative configurations. This allows us to confidently assess whether a modification improves performance and detect unwanted side effects.
