
Commit b67cabb

Tensor parallel Llama3 tutorial illustrating use of torch.distributed and nccl ops
1 parent 543bc9b commit b67cabb

File tree: 2 files changed (+67 -2 lines)

docsrc/index.rst (+2)

@@ -68,6 +68,7 @@ Tutorials
 * :ref:`mutable_torchtrt_module_example`
 * :ref:`weight_streaming_example`
 * :ref:`pre_allocated_output_example`
+* :ref:`tensor_parallel_llama`

 .. toctree::
    :caption: Tutorials
@@ -87,6 +88,7 @@ Tutorials
    tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
    tutorials/_rendered_examples/dynamo/weight_streaming_example
    tutorials/_rendered_examples/dynamo/pre_allocated_output_example
+   tutorials/_rendered_examples/dynamo/tensor_parallel_llama

 Dynamo Frontend
 ----------------

examples/distributed_inference/tensor_parallel_llama3.py (+65 -2)
@@ -1,10 +1,26 @@
 # Taken and modified from PyTorch Lightning
 # https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
+"""
+.. _tensor_parallel_llama:
+
+Torch distributed example for the Llama3-7B model
+==================================================
+
+As model sizes grow, models with billions of parameters are trained on many GPUs, and regular data-parallel training is no longer sufficient. In this example, we illustrate Llama3-7B model inference with the Torch-TensorRT backend, split across multiple GPUs using a form of model parallelism called Tensor Parallelism, built on the PyTorch Distributed Tensor Parallel module. Please refer to these tutorials: https://pytorch.org/tutorials/intermediate/TP_tutorial.html and https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning?section=featured"""
+
+
+# %%
+# Imports and Model Definition
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 import logging
 import os
 import time

 import torch
+
+# %%
+# PyTorch Tensor Parallel APIs offer a set of module-level primitives (ParallelStyle) to configure the sharding of tensors in each layer of the model.
+# ParallelTransformer creates the parallelize_plan for the FeedForward layer of the model.
 from llama3_model import ModelArgs, ParallelTransformer
 from tensor_parallel_initialize_dist import initialize_distributed_env
 from torch.distributed._composable.fsdp import MixedPrecisionPolicy
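For readers new to these primitives, here is a minimal, illustrative sketch (not part of the commit) of how a ParallelStyle plan is applied to a toy feed-forward block with `parallelize_module`; the module, dimensions, and mesh below are assumptions chosen only to show the mechanics:

    import os

    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )


    class ToyFeedForward(nn.Module):
        def __init__(self, dim=1024, hidden=4096):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden, bias=False)
            self.w2 = nn.Linear(hidden, dim, bias=False)

        def forward(self, x):
            return self.w2(torch.relu(self.w1(x)))


    # One mesh dimension spanning all ranks (run one process per GPU, e.g. via torchrun)
    tp_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

    # w1 is sharded column-wise and w2 row-wise, so the intermediate activation
    # stays sharded and only w2's output needs an allreduce to become replicated.
    ffn = parallelize_module(
        ToyFeedForward().cuda(),
        tp_mesh,
        {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
    )

The ParallelTransformer class used in this example builds an analogous, but larger, plan for every transformer block.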
@@ -14,11 +30,24 @@
     checkpoint_wrapper,
 )

+# %%
+# Initialize the distributed environment
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# Depending on the sharded DTensor layouts specified for the inputs and outputs, communication operations are required to transform between layouts,
+# e.g. allreduce, allgather, and reduce_scatter.
+# NCCL provides these collective operations.
+# The API below does the following:
+#   * Initializes the communicators and the distributed environment
+#   * Sets the path of the TRT-LLM plugin .so, which is required for the NCCL operations in the Torch-TensorRT backend. Note that in a Python 3.10 environment, `import tensorrt_llm` should be enough
+#   * Initializes the logger. E.g., with 2 GPUs, the log files are `./tensor_parallel_llama3_0.log` and `./tensor_parallel_llama3_1.log`
 device_mesh, _world_size, _rank, logger = initialize_distributed_env(
     "./tensor_parallel_llama3"
 )
-# Import should be after initialization of the TRT-LLM plugin .so path
-import tensorrt_llm
+
+# %%
+# Model initialization with torch distributed parallel plan
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 logger.info(f"Starting PyTorch TP example on rank {_rank}.")
 assert (
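The `initialize_distributed_env` helper itself is not part of this diff. The following is a rough sketch, under stated assumptions, of the steps the comment describes (communicator setup, device mesh creation, per-rank logging), using only standard `torch.distributed` APIs; the function name and the handling of the TRT-LLM plugin path are placeholders, not the actual implementation in `tensor_parallel_initialize_dist.py`:

    import logging
    import os

    import torch
    from torch.distributed.device_mesh import init_device_mesh


    def initialize_distributed_env_sketch(log_prefix: str):
        # torchrun (or an equivalent launcher) sets RANK/WORLD_SIZE/LOCAL_RANK per process
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])

        # NCCL backend creates the communicators used by allreduce/allgather/reduce_scatter
        torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(local_rank)

        # 1-D device mesh over all ranks, used for the tensor-parallel sharding
        device_mesh = init_device_mesh("cuda", (world_size,))

        # Per-rank log file, e.g. ./tensor_parallel_llama3_0.log for rank 0
        logger = logging.getLogger("tensor_parallel_llama3")
        logger.setLevel(logging.INFO)
        logger.addHandler(logging.FileHandler(f"{log_prefix}_{rank}.log"))

        # The real helper also registers the TRT-LLM plugin .so path needed by the
        # NCCL converters in the Torch-TensorRT backend; that step is environment
        # specific and is omitted from this sketch.
        return device_mesh, world_size, rank, logger

Such a script is typically launched with one process per GPU, for example via torchrun or mpirun.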
@@ -36,7 +65,39 @@
 )

 with torch.no_grad():
+    # The plan is
+    # plan = {
+    #     "attention": PrepareModuleInput(
+    #         input_layouts=(Shard(1), None),
+    #         desired_input_layouts=(Replicate(), None),
+    #     ),
+    #     "attention.wq": ColwiseParallel(),
+    #     "attention.wk": ColwiseParallel(),
+    #     "attention.wv": ColwiseParallel(),
+    #     "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
+    #     "attention_norm": SequenceParallel(),
+    #     "feed_forward": PrepareModuleInput(
+    #         input_layouts=(Shard(1),),
+    #         desired_input_layouts=(Replicate(),),
+    #     ),
+    #     "feed_forward.w1": ColwiseParallel(),
+    #     "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+    #     "feed_forward.w3": ColwiseParallel(),
+    #     "ffn_norm": SequenceParallel(),
+    # }
+
     model = ParallelTransformer(model_args, device_mesh)
+
+# %%
+# Model inference with Torch-TensorRT backend
+# -------------------------------------------
+# When we compile the distributed model with the Torch-TensorRT backend, the PyTorch distributed libraries create the sharded model
+# on multiple GPUs, and the collective operations handle the required communication. In the plan above,
+# `ColwiseParallel` and `RowwiseParallel` shard the attention and feed-forward linear layers column-wise or row-wise,
+# `SequenceParallel` performs sharded computation of the normalization layers, and
+# `PrepareModuleInput` configures the model inputs with the proper communication operations.
+# The NCCL operations used in the distributed backend are handled by the TensorRT-LLM NCCL plugins, so they no longer cause graph breaks.
+
 torch.manual_seed(0)
 inp = torch.randint(32000, (8, 256), device="cuda")
 python_result = model(inp)
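The compilation call itself sits outside this hunk. As a hedged illustration only, a sharded module like this is typically handed to the Torch-TensorRT dynamo backend through `torch.compile`; the option values below are assumptions for the sketch, not necessarily the ones used in `tensor_parallel_llama3.py`:

    import torch
    import torch_tensorrt  # noqa: F401  (importing registers the "torch_tensorrt" backend)

    # `model` is the sharded ParallelTransformer created above. Each rank compiles
    # its own shard, and the collectives inserted by the tensor-parallel plan are
    # lowered through the TensorRT-LLM NCCL plugins instead of causing graph breaks.
    model = torch.compile(
        model,
        backend="torch_tensorrt",
        dynamic=False,
        options={
            "enabled_precisions": {torch.float32},
            "use_python_runtime": True,
        },
    )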
@@ -62,9 +123,11 @@
     output = model(inp)
     end = time.time()
     if i == 0:
+        # Logging the compilation time
         logger.info(f"Compilation time is {end-start}")
         assert (
             python_result - output
         ).std() < 0.01, "Compilation result is not correct."
     elif _rank == 0:
+        # Logging the inference time
         logger.info(f"Inference time is {end-start}")
