I recently came across a post by Sebastian that caught my attention, and I wanted to dive deeper into its content. As models grow larger and more complex, efficiently managing memory during model loading becomes increasingly important, especially when working with limited GPU or CPU resources. In his post, Sebastian covers practical tips for loading larger pre-trained or fine-tuned models in constrained memory environments, which is particularly relevant when working with PyTorch.
This guide focuses on situations where models are saved using torch.save(model.state_dict(), "model.pth")
and later need to be loaded for continued pre-training or further fine-tuning. While the examples focus on a large language model (LLM), Sebastian's methods are broadly applicable to any PyTorch model. They also provide valuable insights into memory-efficient model weight loading in PyTorch, helping optimize memory usage during the loading process.
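For reference, here is a minimal, self-contained sketch of the save-and-reload pattern this article starts from; the tiny nn.Linear stand-in model and the "model.pth" file name are illustrative only, not from Sebastian's post:
import torch
import torch.nn as nn

# Save only the weights (the state_dict), not the whole model object.
model = nn.Linear(4, 2)
torch.save(model.state_dict(), "model.pth")

# Later: re-create the architecture and copy the saved weights back in.
reloaded = nn.Linear(4, 2)
reloaded.load_state_dict(torch.load("model.pth", weights_only=True))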
Overview
- Efficient memory management is crucial for loading large neural networks in PyTorch, especially on systems with limited GPU or CPU resources.
- Instead of loading the entire model at once, you can load weights incrementally. Normally, calling model.to(device) moves all of the model's parameters to the device (such as a GPU), which can consume significant memory.
- PyTorch introduced the "meta" device, which allows tensors to be created without allocating memory for their data.
- By using the meta device, you can load weights directly into GPU memory, bypassing the CPU and optimizing memory usage.
Initial Setup: Environment Check
Before diving into the specifics, let's make sure the required packages and versions are available. Here's a snippet that checks the version of PyTorch and other useful tools.
from importlib.metadata import version

pkgs = [
    "torch",
]
for p in pkgs:
    print(f"{p} version: {version(p)}")
Benchmark Utilities for Memory Tracking
The first step is to set up a utility to track GPU memory (VRAM). Monitoring memory usage helps in understanding how different methods affect memory load during model loading and inference. Later, we will also track the system's RAM (CPU memory).
Here's the utility code for GPU memory tracking:
import gc
import time
import torch


def start_memory_tracking():
    """Initialize GPU memory tracking."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    else:
        print("This notebook is intended for CUDA GPUs but CUDA is not available.")


def print_memory_usage():
    max_gpu_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)  # Convert bytes to GB
    print(f"Maximum GPU memory allocated: {max_gpu_memory:.1f} GB")


def cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(3)  # Allow time for memory to clear
    torch.cuda.reset_peak_memory_stats()
    max_memory_allocated = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Maximum GPU memory allocated: {max_memory_allocated:.1f} GB")
These functions help track GPU memory usage before, during, and after model operations. The cleanup() function is especially useful for clearing unused memory to avoid running out of VRAM.
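As a quick illustration of how these helpers fit together (the tensor allocation below is just a stand-in for the model operations shown later in this article):
if torch.cuda.is_available():
    start_memory_tracking()                      # reset the peak-memory counter
    x = torch.randn(1024, 1024, device="cuda")   # placeholder GPU work
    print_memory_usage()                         # report the peak since tracking began
    del x
    cleanup()                                    # free cached memory and reset the counter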
Model Setup
Next, we set up the model. For demonstration, we will use the "gpt2-xl" variant (you can adjust the model size to suit your memory constraints). By changing the configuration, the model size can range from "gpt2-small" (124M parameters) to "gpt2-xl" (1558M parameters).
Here's the configuration:
from previous_chapters import GPTModel

BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "drop_rate": 0.0,        # Dropout rate
    "qkv_bias": True         # Query-key-value bias
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}

CHOOSE_MODEL = "gpt2-xl (1558M)"
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
This configuration allows flexibility in choosing a model based on the available memory resources. For lower memory consumption, selecting a smaller variant (like gpt2-small) is advisable.
Once the model configuration is set up, the next steps dive into loading, managing, and optimizing the model weights for efficient memory usage.
Tracking GPU Memory During Model Loading
Let's now put the GPU memory tracking utilities into action. First, we initialize memory tracking and load the model to observe memory consumption. The code below tracks GPU memory usage as we load and run a GPT model.
start_memory_tracking()

model = GPTModel(BASE_CONFIG)
device = torch.device("cuda")
model.to(device)

print_memory_usage()
# Output: Maximum GPU memory allocated: 6.4 GB
This shows that loading the model and placing it on the GPU consumes around 6.4 GB of VRAM, which is typical for larger models like GPT-2. However, this is just the initial setup.
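As a rough sanity check on that figure, you can estimate the weight memory directly from the parameter count; float32 weights take 4 bytes each, with activations and CUDA overhead coming on top. This short estimate is an addition to the original walkthrough and reuses the model loaded above:
num_params = sum(p.numel() for p in model.parameters())
weight_gb = num_params * 4 / 1024**3  # 4 bytes per float32 parameter
print(f"{num_params / 1e6:.0f}M parameters ≈ {weight_gb:.1f} GB of float32 weights")
# For the gpt2-xl configuration this lands in the same ballpark as the measured 6.4 GB.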
Running the Model
To verify that everything works correctly, we pass a simple input tensor to the model. Although we aren't tracking memory during this step, it's important to check that the model operates as expected.
# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()

with torch.no_grad():
    model(test_input)
Saving the Model
Now, imagine we are pretraining the model (or finetuning it). For this example, we skip the actual pretraining process and directly save the initialized model. The following code saves the model's weights using torch.save().
# Training code would go here...

model.train()
torch.save(model.state_dict(), "model.pth")
Memory Cleanup
After saving the model, it's important to free GPU memory to ensure efficient resource management in subsequent operations. By deleting the model and the test input tensor, and then running our cleanup() function, we clear up VRAM.
del model, test_input
cleanup()
# Output: Maximum GPU memory allocated: 0.0 GB
At this point, GPU memory usage is reset to zero, as expected.
Loading Pretrained Model Weights
The next step involves reloading the saved weights to continue training or finetuning. However, loading pretrained weights this way requires more GPU memory than initializing a fresh model, because the weights end up on the GPU twice: once for the freshly initialized model that is moved to the device, and again for the state_dict that is loaded into memory before being copied in.
# Start tracking memory
start_memory_tracking()

# Recreate the model architecture
model = GPTModel(BASE_CONFIG)
model.to(device)

# Load the saved state_dict
model.load_state_dict(
    torch.load("model.pth", map_location=device, weights_only=True)
)
model.to(device)
model.eval()

print_memory_usage()
# Output: Maximum GPU memory allocated: 12.8 GB
GPU memory usage has now doubled compared to the initial load, peaking at 12.8 GB. This happens because, for a short period, both the original model and the newly loaded weights are held in memory. Eventually, the loaded weights are copied into the model and the temporary state_dict is discarded, but this memory spike can cause problems when working with limited resources.
Resetting GPU Memory
After loading the model weights and testing it, it's essential to reset GPU memory once again. Testing the model ensures it works as expected, and clearing memory is crucial for efficient resource utilization.
# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()

with torch.no_grad():
    model(test_input)

del model, test_input
cleanup()
# Output: Maximum GPU memory allocated: 0.0 GB
This reset brings GPU memory usage back to zero, ensuring a clean state for the following operations.
Loading Weights Sequentially
One effective workaround for the double memory usage when loading model weights is sequential loading. Instead of loading both the model and the weights into GPU memory at the same time, we load the model first, keep the weights in CPU memory, and then copy each parameter to the GPU one at a time. This method significantly reduces peak memory usage.
Here's how to implement sequential weight loading:
Step-by-Step Breakdown:
- Load the model onto the GPU: First, we load the model architecture into GPU memory, as usual.
- Load the weights onto the CPU: The model weights are loaded into CPU memory, avoiding the initial memory spike caused by moving both the model and the weights to the GPU.
- Copy weights parameter by parameter: Each weight is then copied sequentially from the CPU to the GPU, so at no point do we have both the model and the full state_dict in GPU memory.
The code below demonstrates this approach:
start_memory_tracking()

# Load the model into GPU memory
model = GPTModel(BASE_CONFIG).to(device)

# Load the model's saved state_dict onto the CPU
state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)

print_memory_usage()
# Output: Maximum GPU memory allocated: 6.4 GB

# Copy each parameter to GPU memory one at a time
with torch.no_grad():
    for name, param in model.named_parameters():
        if name in state_dict:
            param.copy_(state_dict[name].to(device))
        else:
            print(f"Warning: {name} not found in state_dict.")

print_memory_usage()
# Output: Maximum GPU memory allocated: 6.7 GB
Memory Comparison:
- Initially, the model alone occupies 6.4 GB of GPU memory.
- As we copy each parameter sequentially, memory increases only slightly, to 6.7 GB.
This is a much smaller peak than the 12.8 GB required when loading everything at once. By loading the weights sequentially, we avoid holding both the full model and the full set of weights in GPU memory at the same time.
Model Testing & Memory Reset:
After copying the weights, we test the model to ensure everything works as expected. Finally, we reset GPU memory to clear any lingering objects, just as we did in the earlier steps.
# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()

with torch.no_grad():
    model(test_input)

# Clean up GPU memory
del model, test_input, state_dict, param
cleanup()
# Output: Maximum GPU memory allocated: 0.0 GB
Loading the Model with Low CPU Memory
In the previous section, we reduced GPU memory usage by loading the model weights into CPU memory first and then sequentially copying them to the GPU. But what if the machine has limited CPU memory and larger GPU memory? To address this, we can use PyTorch's "meta" device approach, which is ideal for machines with constrained CPU resources.
Meta Device: A Smart Tradeoff
The "meta" device is a special device type in PyTorch that creates "meta" tensors. These tensors carry the shape and dtype of the data without allocating memory for the data itself. This lets us define models without consuming CPU or GPU memory until necessary.
Using the meta device, we can first initialize the model without any memory allocation, and then load the model weights directly into GPU memory, bypassing the CPU.
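To make this concrete, here is a small, self-contained sketch (not part of the original walkthrough) showing what meta tensors look like:
import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(1024, 1024)  # parameters have a shape and dtype but no storage

print(layer.weight.device)  # meta
print(layer.weight.shape)   # torch.Size([1024, 1024])

# Materialize real (uninitialized) storage on a target device before copying weights in:
layer = layer.to_empty(device="cpu")  # or device="cuda" on a GPU machine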
Tracking CPU Memory Usage
Before we dive into the meta device approach, we define a utility function to track CPU memory usage:
import os
import psutil
from threading import Thread


def memory_usage_in_gb(func, *args, **kwargs):
    process = psutil.Process(os.getpid())
    baseline_mem = process.memory_info().rss / 1024 ** 3  # in GB

    mem_usage = []
    done = False

    def monitor_memory():
        while not done:
            mem_usage.append(process.memory_info().rss / 1024 ** 3)  # Convert to GB
            time.sleep(0.1)

    t = Thread(target=monitor_memory)
    t.start()

    func(*args, **kwargs)
    done = True
    t.join()

    peak_mem_usage_gb = max(mem_usage) - baseline_mem
    return peak_mem_usage_gb
Now that we can measure CPU memory usage, let's track the memory used during the sequential weight loading approach from the previous section:
def load_sequentially():
    start_memory_tracking()

    model = GPTModel(BASE_CONFIG).to(device)
    state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)

    print_memory_usage()

    # Sequentially copy weights to the model's parameters
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name].to(device))

    print_memory_usage()


peak_memory_used = memory_usage_in_gb(load_sequentially)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
This approach outputs:
- Maximum GPU memory allocated: 6.7 GB
- Maximum CPU memory allocated: 6.3 GB
Meta Device Approach
To further reduce CPU memory usage, we can use the meta device to instantiate the model without allocating memory until we actually need it. Here's the implementation:
def load_sequentially_with_meta():
    start_memory_tracking()

    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model = model.to_empty(device=device)
    state_dict = torch.load("model.pth", map_location=device, weights_only=True)

    print_memory_usage()

    # Sequentially copy weights to the model's parameters
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name])

    print_memory_usage()


peak_memory_used = memory_usage_in_gb(load_sequentially_with_meta)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
Memory Usage with the Meta Device:
- Maximum GPU memory allocated: 12.8 GB
- Maximum CPU memory allocated: 1.3 GB
By using the meta device and loading the model weights directly into GPU memory, we drastically reduce CPU memory consumption from 6.3 GB to just 1.3 GB.
Comparison with the Baseline
Finally, let's compare this method with plain PyTorch weight loading, where no meta device or sequential loading is used:
def baseline():
    start_memory_tracking()

    model = GPTModel(BASE_CONFIG)
    model.to(device)

    model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
    model.to(device)
    model.eval()

    print_memory_usage()


peak_memory_used = memory_usage_in_gb(baseline)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
For this approach:
- Maximum GPU memory allocated: 12.8 GB
- Maximum CPU memory allocated: 4.4 GB
Using mmap=True for Efficient Model Loading
For more advanced PyTorch users, there is another approach to handling memory constraints when loading large models: the mmap=True setting in torch.load(). This setting leverages memory-mapped file I/O, which lets the model read data directly from disk without fully loading it into RAM. It is particularly useful on systems with limited CPU memory, as it minimizes the memory footprint during model loading.
What’s mmap=True?
Reminiscence-mapped I/O (mmap) is a mechanism that allows a file to be learn immediately from disk by mapping it into the digital tackle area. As a substitute of loading your complete mannequin into RAM, PyTorch can load elements of the mannequin on demand, successfully decreasing reminiscence utilization. This may be notably advantageous when coping with giant pretrained or finetuned fashions, equivalent to GPT-2 or GPT-3, on machines with restricted assets.
The mmap=True possibility might be added when calling torch.load() to attain this conduct.
Example Implementation of mmap=True
Let's see how the mmap=True option works in practice. Below is a sample implementation where we load a model using this setting:
def best_practices():
    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model.load_state_dict(
        torch.load("model.pth", map_location=device, weights_only=True, mmap=True),
        assign=True
    )

    print_memory_usage()


peak_memory_used = memory_usage_in_gb(best_practices)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
Results with mmap=True
- Maximum GPU memory allocated: 6.4 GB
- Maximum CPU memory allocated: 5.9 GB
Here, GPU memory usage remains efficient (6.4 GB), while CPU memory usage is fairly high because this machine has enough CPU RAM to support it. On a system with limited CPU RAM, however, the mmap=True approach would use less memory by avoiding loading the full checkpoint into RAM.
When to Use mmap=True
The mmap=True option is especially helpful in the following scenarios:
- Limited CPU RAM: memory mapping avoids materializing the full checkpoint in RAM, so machines with little free CPU memory can still load large models.
- Fast disk I/O: because data is paged in from disk on demand, the approach works best when the checkpoint sits on fast storage such as an SSD.
Performance Considerations
At first glance, the mmap=True approach may seem less efficient than the sequential weight loading approach. However, for machines with limited CPU memory, mmap=True can be a game-changer, providing an effective way to load large models without overwhelming the available CPU memory.
With mmap=True, you are trading disk access for memory availability, which helps in environments where memory is scarce but disk I/O is fast.
Other Methods for Model Weight Loading
So far, we've focused on simple, built-in methods for efficiently loading model weights in PyTorch, particularly when memory (either GPU or CPU) is constrained. The recommended method for managing limited CPU memory is the mmap=True approach, as explained above.
However, if you're dealing with extreme memory limitations or need more control over the process, there is another brute-force approach: saving and loading each weight tensor individually.
Saving Model Weights Individually
Instead of saving the entire state_dict as a single file, this method stores each model parameter (tensor) separately. That lets you load each parameter one at a time, so the entire model never has to be held in memory at once.
Here's how to save the model weights individually:
model = GPTModel(BASE_CONFIG)
# Assume `model` is your trained model
state_dict = model.state_dict()

# Create a directory to store individual parameter files
os.makedirs("model_parameters", exist_ok=True)

# Save each parameter tensor separately
for name, param in state_dict.items():
    torch.save(param.cpu(), f"model_parameters/{name}.pt")

del model  # Free up GPU memory
This breaks the model into individual components, saving each tensor to its own file in the "model_parameters" directory.
Loading Weights Individually
Now, let's see how to load these weights one by one to avoid overwhelming memory usage.
def load_individual_weights():
    start_memory_tracking()

    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model = model.to_empty(device=device)

    print_memory_usage()

    param_dir = "model_parameters"

    with torch.no_grad():
        for name, param in model.named_parameters():
            weight_path = os.path.join(param_dir, f"{name}.pt")
            if os.path.exists(weight_path):
                param_data = torch.load(weight_path, map_location="cpu", weights_only=True)
                param.copy_(param_data.to(device))  # Move tensor to GPU
                del param_data  # Free memory after copying
            else:
                print(f"Warning: {name} not found in {param_dir}.")

    print_memory_usage()
Results from Individual Weight Loading
- Maximum GPU memory allocated: 6.4 GB
- Maximum CPU memory allocated: 0.3 GB
The memory footprint here is significantly reduced, both on the GPU and the CPU. By loading weights individually, you ensure that no unnecessary memory is consumed at any stage, making this approach ideal for extremely memory-limited environments.
When to Use This Method
- Extreme memory limitations: when CPU and GPU memory are both highly constrained, this method offers precise control, ensuring that only one parameter tensor is loaded into memory at any given time.
On machines where you cannot afford to use more than minimal resources, this brute-force method makes it possible to load even the largest models.
Performance Considerations
The trade-off here is performance. Since each tensor is loaded separately, this method incurs extra disk I/O, which can slow down loading compared to approaches that load the entire model or larger chunks of data at once.
When working with large models, such as GPT variants or other deep learning models, memory efficiency is crucial. Techniques like sequential weight loading, using the meta device, and enabling mmap=True
help reduce memory usage on both the CPU and the GPU. These methods for memory-efficient model weight loading in PyTorch are highly versatile and can be adapted to the specific constraints of your hardware environment, whether you have limited CPU RAM, GPU VRAM, or both.
By applying these strategies, you can work with large models even on constrained hardware, ensuring smooth model training and fine-tuning workflows.
Hope you like the article! Memory-efficient model weight loading in PyTorch helps save resources; as a simple starting point, try torch.load()
with memory mapping (mmap=True) to lower RAM usage.
Frequently Asked Questions
Q1. Why is memory-efficient model loading important?
As deep learning models grow larger (especially models like GPT-2 and GPT-3), loading them efficiently becomes essential to avoid running out of GPU or CPU memory. Memory-efficient loading lets you work with large models even in constrained environments.
Q2. How can you track GPU memory usage?
You can use torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated() to track GPU memory usage before, during, and after loading or training models. The utility functions provided earlier in this article make this monitoring convenient.
Q3. What is sequential weight loading?
Sequential weight loading involves loading the model architecture onto the GPU and then transferring the weights one at a time from CPU to GPU. This reduces peak memory usage compared to loading both the model and its weights at once, helping manage limited GPU memory.
Q4. How can you reduce memory usage during model training? Several general techniques help (a brief mixed-precision and gradient-accumulation sketch follows this list):
- Use lower precision: float16 or mixed-precision training.
- Optimize tensor operations: avoid unnecessary copies; prefer efficient shapes and views.
- Gradient accumulation: update the weights less frequently by accumulating gradients over several micro-batches.
- Reduce model size: prune connections, quantize weights, or use smaller models.
- Optimize data loading: efficient data loaders, prefetching, memory-mapped files.
- GPU memory hygiene: monitor usage, free unused memory, and spread work across multiple GPUs.
- Advanced techniques: knowledge distillation and low-rank approximation.
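Here is a minimal sketch of the first two ideas, mixed precision and gradient accumulation, combined in one toy training loop; the nn.Linear stand-in model, batch size, and learning rate are illustrative assumptions, not from the original article:
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(512, 512).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")
accum_steps = 4  # update the weights only every 4 micro-batches

for step in range(8):
    x = torch.randn(16, 512, device=device)
    # Run the forward pass in reduced precision where supported.
    with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
        loss = model(x).pow(2).mean() / accum_steps  # scale the loss for accumulation
    scaler.scale(loss).backward()
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)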
The “meta” gadget lets you initialize fashions with out allocating reminiscence for his or her parameters. That is helpful when you might have restricted CPU reminiscence since you possibly can later load weights immediately into the GPU, bypassing the necessity for big reminiscence allocations on the CPU.