• About Us
  • Privacy Policy
  • Disclaimer
  • Contact Us
AimactGrow
  • Home
  • Technology
  • AI
  • SEO
  • Coding
  • Gaming
  • Cybersecurity
  • Digital marketing
No Result
View All Result
  • Home
  • Technology
  • AI
  • SEO
  • Coding
  • Gaming
  • Cybersecurity
  • Digital marketing
No Result
View All Result
AimactGrow
No Result
View All Result

Train Your Large Model on Multiple GPUs with Tensor Parallelism

Admin by Admin
January 18, 2026
Home AI
Share on FacebookShare on Twitter


import dataclasses
import datetime
import os

import datasets
import tokenizers
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import tqdm
from torch import Tensor
from torch.distributed.checkpoint import load, save
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.fsdp import FSDPModule, fully_shard
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    loss_parallel,
    parallelize_module,
)
from torch.utils.data.distributed import DistributedSampler

# Default every floating-point tensor created without an explicit dtype to bfloat16
torch.set_default_dtype(torch.bfloat16)
print("NCCL version:", torch.cuda.nccl.version())

# Build the model
@dataclasses.dataclass
class LlamaConfig:
    """Define Llama model hyperparameters."""

    vocab_size: int = 50000  # Size of the tokenizer vocabulary
    max_position_embeddings: int = 2048  # Maximum sequence length
    hidden_size: int = 768  # Dimension of hidden layers
    intermediate_size: int = 4 * 768  # Dimension of the MLP's hidden layer
    num_hidden_layers: int = 12  # Number of transformer layers
    num_attention_heads: int = 12  # Number of attention heads
    num_key_value_heads: int = 3  # Number of key-value heads for GQA


class RotaryPositionEncoding(nn.Module):
    """Rotary position encoding (RoPE).

    Precomputes cosine/sine tables of shape (max_position_embeddings, dim) and
    stores them as buffers, not parameters. The tables are always computed in
    float32 and cast to the input's dtype when applied.

    ``reset_parameters`` recomputes the tables in place. This lets a module that
    was created on the meta device be restored after ``nn.Module.to_empty``,
    which allocates uninitialized storage, by code that calls
    ``reset_parameters`` on every submodule.
    """

    def __init__(self, dim: int, max_position_embeddings: int) -> None:
        """Initialize the RotaryPositionEncoding module.

        Args:
            dim: The per-head dimension of the input tensor to which RoPE is
                applied. It must be even, because the vector is split into two halves.
            max_position_embeddings: The maximum sequence length of the input tensor.

        Raises:
            ValueError: If ``dim`` is odd.
        """
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE dimension must be even, got {dim}")
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        cos, sin = self._compute_tables()
        # save cosine and sine matrices as buffers, not parameters
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)

    def _compute_tables(self, device: "torch.device | None" = None) -> tuple[Tensor, Tensor]:
        """Return float32 (cos, sin) tables of position * theta_i, shape (max_pos, dim)."""
        base = 10_000.0
        # float32 explicitly: the default dtype may be bfloat16 (this file sets it), and
        # bfloat16 cannot represent large position * frequency products accurately
        exponents = torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim
        inv_freq = 1.0 / (base ** exponents)
        # duplicate so that index i and i + dim/2 share a frequency (rotate-half layout)
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
        position = torch.arange(self.max_position_embeddings, dtype=torch.float32, device=device)
        angles = torch.outer(position, inv_freq)
        return angles.cos(), angles.sin()

    @torch.no_grad()
    def reset_parameters(self) -> None:
        """Recompute the cos/sin buffers in place, on the device they currently live on."""
        cos, sin = self._compute_tables(device=self.cos.device)
        self.cos.copy_(cos)
        self.sin.copy_(sin)

    def forward(self, x: Tensor) -> Tensor:
        """Apply RoPE to tensor x.

        Args:
            x: Input tensor of shape (batch_size, seq_length, num_heads, head_dim).

        Returns:
            Output tensor of shape (batch_size, seq_length, num_heads, head_dim).

        Raises:
            ValueError: If seq_length exceeds max_position_embeddings.
        """
        _, seq_len, _, _ = x.shape
        if seq_len > self.max_position_embeddings:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_position_embeddings "
                f"{self.max_position_embeddings}"
            )
        # slice first, then convert, so only seq_len rows are moved/cast;
        # view as (1, seq, 1, dim) to broadcast over batch and heads
        cos = self.cos[:seq_len].to(device=x.device, dtype=x.dtype).view(1, seq_len, 1, -1)
        sin = self.sin[:seq_len].to(device=x.device, dtype=x.dtype).view(1, seq_len, 1, -1)
        # rotate half: (x1, x2) -> (-x2, x1)
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return (x * cos) + (rotated * sin)


class LlamaAttention(nn.Module):

Ā Ā Ā Ā “”“Grouped-query consideration with rotary embeddings.”“”

Ā 

Ā Ā Ā Ā def __init__(self, config: LlamaConfig) -> None:

Ā Ā Ā Ā Ā Ā Ā Ā tremendous().__init__()

Ā Ā Ā Ā Ā Ā Ā Ā self.hidden_size = config.hidden_size

Ā Ā Ā Ā Ā Ā Ā Ā self.num_heads = config.num_attention_heads

Ā Ā Ā Ā Ā Ā Ā Ā self.head_dim = self.hidden_size // self.num_heads

Ā Ā Ā Ā Ā Ā Ā Ā self.num_kv_heads = config.num_key_value_headsĀ Ā # GQA: H_kv < H_q

Ā 

Ā Ā Ā Ā Ā Ā Ā Ā # hidden_size have to be divisible by num_heads

Ā Ā Ā Ā Ā Ā Ā Ā assert (self.head_dim * self.num_heads) == self.hidden_measurement

Ā 

Ā Ā Ā Ā Ā Ā Ā Ā # Linear layers for Q, Ok, V projections

Ā Ā Ā Ā Ā Ā Ā Ā self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)

Ā Ā Ā Ā Ā Ā Ā Ā self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)

Ā Ā Ā Ā Ā Ā Ā Ā self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)

Ā Ā Ā Ā Ā Ā Ā Ā self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

Ā 

Ā Ā Ā Ā def ahead(self, hidden_states: Tensor, rope: RotaryPositionEncoding, attn_mask: Tensor) -> Tensor:

Ā Ā Ā Ā Ā Ā Ā Ā bs, seq_len, dim = hidden_states.measurement()

Ā 

Ā Ā Ā Ā Ā Ā Ā Ā # Challenge inputs to Q, Ok, V

Ā Ā Ā Ā Ā Ā Ā Ā query_states = self.q_proj(hidden_states).view(bs, seq_len, self.num_heads, self.head_dim)

Ā Ā Ā Ā Ā Ā Ā Ā key_states = self.k_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)

Ā Ā Ā Ā Ā Ā Ā Ā value_states = self.v_proj(hidden_states).view(bs, seq_len, self.num_kv_heads, self.head_dim)

Ā 

Ā Ā Ā Ā Ā Ā Ā Ā # Apply rotary place embeddings

Ā Ā Ā Ā Ā Ā Ā Ā query_states = rope(query_states)

Ā Ā Ā Ā Ā Ā Ā Ā key_states = rope(key_states)

Ā 

Ā Ā Ā Ā Ā Ā Ā Ā # Transpose tensors from BSHD to BHSD dimension for scaled_dot_product_attention

Ā Ā Ā Ā Ā Ā Ā Ā query_states = query_states.transpose(1, 2)

Ā Ā Ā Ā Ā Ā Ā Ā key_states = key_states.transpose(1, 2)

Ā Ā Ā Ā Ā Ā Ā Ā value_states = value_states.transpose(1, 2)

Ā 

Ā Ā Ā Ā Ā Ā Ā Ā # Use PyTorch’s optimized consideration implementation

Ā Ā Ā Ā Ā Ā Ā Ā # setting is_causal=True is incompatible with setting express consideration masks

Ā Ā Ā Ā Ā Ā Ā Ā attn_output = F.scaled_dot_product_attention(

Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā query_states,

Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā key_states,

Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā value_states,

Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā attn_mask=attn_mask,

Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā dropout_p=0.0,

Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā Ā enable_gqa=True,

Ā Ā Ā Ā Ā Ā Ā Ā )

Ā 

Ā Ā Ā Ā Ā Ā Ā Ā # Transpose output tensor from BHSD to BSHD dimension, reshape to 3D, after which venture output

Ā Ā Ā Ā Ā Ā Ā Ā attn_output = attn_output.transpose(1, 2).reshape(bs, seq_len, self.hidden_size)

Ā Ā Ā Ā Ā Ā Ā Ā attn_output = self.o_proj(attn_output)

Ā Ā Ā Ā Ā Ā Ā Ā return attn_output

Ā 

Ā 

class LlamaMLP(nn.Module):
    """Feed-forward network with SwiGLU activation."""

    def __init__(self, config: "LlamaConfig") -> None:
        """Create the gate, up and down projections from ``hidden_size``/``intermediate_size``."""
        super().__init__()
        # Two parallel projections for SwiGLU: a SiLU-activated gate and a linear "up" branch
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.act_fn = F.silu  # SiLU; multiplying it with up_proj's output forms SwiGLU
        # Project back to the hidden size
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Return down_proj(silu(gate_proj(x)) * up_proj(x))."""
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)


class LlamaDecoderLayer(nn.Module):
    """Single transformer layer for a Llama model.

    The layer is pre-norm: RMSNorm is applied before attention and before the
    MLP, and each sub-block adds its input back as a residual.
    """

    def __init__(self, config: "LlamaConfig") -> None:
        super().__init__()
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.self_attn = LlamaAttention(config)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=1e-5)
        self.mlp = LlamaMLP(config)

    def forward(self, hidden_states: Tensor, rope: "RotaryPositionEncoding", attn_mask: Tensor) -> Tensor:
        """Run attention and MLP residual blocks; the output has the input's shape."""
        # First residual block: self-attention on the normalized input
        residual = hidden_states
        normed = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(normed, rope=rope, attn_mask=attn_mask) + residual

        # Second residual block: MLP on the normalized input
        residual = hidden_states
        normed = self.post_attention_layernorm(hidden_states)
        return self.mlp(normed) + residual


class LlamaModel(nn.Module):
    """The full Llama model without any pretraining heads."""

    def __init__(self, config: "LlamaConfig") -> None:
        super().__init__()
        # RoPE rotates per-head vectors, so its dimension is the head dimension
        self.rotary_emb = RotaryPositionEncoding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings,
        )

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.RMSNorm(config.hidden_size, eps=1e-5)

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        """Embed ``input_ids``, run every decoder layer, and return the final-normed hidden states."""
        # Convert input token IDs to embeddings
        hidden_states = self.embed_tokens(input_ids)
        # Process through all transformer layers, sharing one RoPE module and mask
        for layer in self.layers:
            hidden_states = layer(hidden_states, rope=self.rotary_emb, attn_mask=attn_mask)
        # Final norm layer
        return self.norm(hidden_states)


class LlamaForPretraining(nn.Module):
    """Llama base model plus a bias-free linear language-modeling head over the vocabulary."""

    def __init__(self, config: "LlamaConfig") -> None:
        super().__init__()
        self.base_model = LlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids: Tensor, attn_mask: Tensor) -> Tensor:
        """Return next-token logits whose last dimension is ``vocab_size``.

        For (batch_size, seq_len) ``input_ids`` the logits have shape
        (batch_size, seq_len, vocab_size).
        """
        hidden_states = self.base_model(input_ids, attn_mask)
        return self.lm_head(hidden_states)


def create_causal_mask(batch: Tensor, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create a causal mask for self-attention.

    Args:
        batch: Batch of sequences, shape (batch_size, seq_len). Only its
            sequence length (last dimension) and device are used.
        dtype: Data type of the mask.

    Returns:
        Additive causal mask of shape (seq_len, seq_len) on ``batch``'s device.
        It is 0 where position i may attend to position j (j <= i) and -inf
        strictly above the diagonal.
    """
    seq_len = batch.shape[-1]
    mask = torch.full((seq_len, seq_len), float("-inf"), device=batch.device, dtype=dtype)
    # keep -inf only strictly above the diagonal (future positions); zero the rest
    return mask.triu(diagonal=1)


def create_padding_mask(batch: Tensor, padding_token_id: int, dtype: torch.dtype = torch.float32) -> Tensor:
    """Create a padding mask for a batch of sequences for self-attention.

    Padding tokens are masked as keys only: no query may attend to a padding
    position, but the query rows of padding positions stay open. Masking those
    rows as well would make them entirely -inf, and a softmax over a row that is
    all -inf evaluates to NaN (exp(-inf - (-inf))). When combined with a causal
    mask, a row is fully masked only if every position up to and including it
    is padding.

    Args:
        batch: Batch of sequences, shape (batch_size, seq_len).
        padding_token_id: ID of the padding token.
        dtype: Data type of the mask.

    Returns:
        Additive padding mask of shape (batch_size, 1, seq_len, seq_len) on
        ``batch``'s device. It is -inf in every column whose key is a padding
        token and 0 elsewhere.
    """
    seq_len = batch.shape[-1]
    key_mask = torch.zeros(batch.shape, device=batch.device, dtype=dtype)
    key_mask = key_mask.masked_fill(batch == padding_token_id, float("-inf"))
    # (batch, seq) -> (batch, 1, 1, seq) -> (batch, 1, seq, seq); contiguous() turns
    # the broadcast view into an independent tensor
    return key_mask[:, None, None, :].expand(-1, -1, seq_len, -1).contiguous()


# Map-style dataset that yields fixed-length (input, target) token pairs
class PretrainingDataset(torch.utils.data.Dataset):
    """Tokenize text records into shifted next-token-prediction pairs.

    Each item wraps the record's ``"text"`` in [BOT] ... [EOT], then clips or
    pads it with [PAD] to ``seq_length + 1`` tokens. It returns inputs
    ``tokens[:-1]`` and targets ``tokens[1:]``, both int64 tensors of length
    ``seq_length``.
    """

    def __init__(self, dataset: "datasets.Dataset", tokenizer: "tokenizers.Tokenizer",
                 seq_length: int):
        """Store the data source and resolve the special-token IDs.

        Raises:
            ValueError: If the tokenizer has no ID for [BOT], [EOT] or [PAD].
        """
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.bot = tokenizer.token_to_id("[BOT]")
        self.eot = tokenizer.token_to_id("[EOT]")
        self.pad = tokenizer.token_to_id("[PAD]")
        # token_to_id returns None for unknown tokens; fail here instead of inside a worker
        missing = [name for name, tid in (("[BOT]", self.bot), ("[EOT]", self.eot), ("[PAD]", self.pad))
                   if tid is None]
        if missing:
            raise ValueError(f"Tokenizer has no id for special token(s): {', '.join(missing)}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index: int) -> tuple[Tensor, Tensor]:
        """Get a sequence of token ids from the dataset.

        [BOT] and [EOT] tokens are added, and the sequence is clipped and padded
        to the sequence length.
        """
        seq = self.dataset[index]["text"]
        tokens: list[int] = [self.bot, *self.tokenizer.encode(seq).ids, self.eot]
        # need seq_length + 1 tokens because targets are the inputs shifted by one
        target_len = self.seq_length + 1
        tokens = tokens[:target_len]
        tokens += [self.pad] * (target_len - len(tokens))
        x = torch.tensor(tokens[:-1], dtype=torch.int64)
        y = torch.tensor(tokens[1:], dtype=torch.int64)
        return x, y


def load_checkpoint(mannequin: nn.Module, optimizer: torch.optim.Optimizer, scheduler: lr_scheduler.SequentialLR) -> None:
    """Restore model, optimizer and LR-scheduler state from ``checkpoint-dist``.

    This is a collective call: every rank must enter it, and it starts and ends
    with a barrier. Must be paired with ``save_checkpoint``, which writes the
    same keys.
    """
    dist.barrier()
    # get_state_dict returns FQN-keyed state dicts and initializes the optimizer's
    # per-parameter state. A freshly built optimizer's own state_dict() has no
    # per-parameter tensors yet, so there would be nothing for DCP to load into.
    model_state, optim_state = get_state_dict(mannequin, optimizer)
    state = {"model": model_state, "optimizer": optim_state}
    load(
        state,  # filled in place by DCP
        checkpoint_id="checkpoint-dist",
        # tolerate keys missing from the checkpoint (intended for the RoPE buffers)
        planner=DefaultLoadPlanner(allow_partial_load=True),
    )
    set_state_dict(
        mannequin,
        optimizer,
        model_state_dict=state["model"],
        optim_state_dict=state["optimizer"],
    )
    # Rank 0 saved the scheduler with plain torch.save; every rank reads the same file
    scheduler.load_state_dict(
        torch.load("checkpoint-dist/lrscheduler.pt", map_location="cpu"),
    )
    dist.barrier()


def save_checkpoint(mannequin: nn.Module, optimizer: torch.optim.Optimizer, scheduler: lr_scheduler.SequentialLR) -> None:
    """Write model, optimizer and LR-scheduler state to ``checkpoint-dist``.

    This is a collective call: every rank must enter it. Model and optimizer
    shards are written by every rank through DCP. The scheduler state is written
    once, by rank 0.
    """
    dist.barrier()
    model_state, optim_state = get_state_dict(mannequin, optimizer)
    save(
        {"model": model_state, "optimizer": optim_state},
        checkpoint_id="checkpoint-dist",
    )
    if dist.get_rank() == 0:
        torch.save(scheduler.state_dict(), "checkpoint-dist/lrscheduler.pt")
    dist.barrier()


# Load the tokenizer and dataset
tokenizer = tokenizers.Tokenizer.from_file("bpe_50K.json")
dataset = datasets.load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train")

# Initialize the distributed environment (60 s collective timeout). LOCAL_RANK must be
# provided by the launcher, e.g. torchrun.
dist.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=60))
local_rank = int(os.environ["LOCAL_RANK"])
gadget = torch.device(f"cuda:{local_rank}")  # this rank's CUDA device
# Bind the process to its GPU so NCCL collectives such as dist.barrier() do not
# fall back to cuda:0 on every rank
torch.cuda.set_device(gadget)
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"World size {world_size}, rank {rank}, local rank {local_rank}. Using {gadget}")

# Initialize the 2D device mesh: "dp" (data parallel, FSDP) x "tp" (tensor parallel)
n_tensor_parallel = 2
if world_size % n_tensor_parallel != 0:
    raise ValueError("Expect world size to be divisible by number of tensor parallel GPUs")
mesh = dist.device_mesh.init_device_mesh(
    "cuda",
    (world_size // n_tensor_parallel, n_tensor_parallel),
    mesh_dim_names=("dp", "tp"),
)
print(
    f"({rank}) Mesh: {mesh}, DP size: {mesh['dp'].size()}, TP size: {mesh['tp'].size()}, "
    f"DP local rank: {mesh['dp'].get_local_rank()}, TP local rank: {mesh['tp'].get_local_rank()}"
)

# Create the pretraining model on the meta device (no storage allocated), on all ranks
with torch.device("meta"):
    model_config = LlamaConfig()
    mannequin = LlamaForPretraining(model_config)

# Set up tensor parallelism on every transformer block in the base model.
# Activations between blocks are sequence-sharded (Shard(1)); attention and the MLP
# first gather the full sequence (Replicate).
tp_plan = {
    "input_layernorm": SequenceParallel(),
    "self_attn": PrepareModuleInput(
        input_layouts=Shard(dim=1),  # applies to the only positional arg, hidden_states
        desired_input_layouts=Replicate(),
    ),
    # Q/K projection outputs are used with RoPE, so they must be replicated
    # Q/K/V outputs are used with GQA, so they must also be replicated
    "self_attn.q_proj": ColwiseParallel(output_layouts=Replicate()),
    "self_attn.k_proj": ColwiseParallel(output_layouts=Replicate()),
    "self_attn.v_proj": ColwiseParallel(output_layouts=Replicate()),
    "self_attn.o_proj": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1)),
    "post_attention_layernorm": SequenceParallel(),
    "mlp": PrepareModuleInput(
        input_layouts=Shard(dim=1),
        desired_input_layouts=Replicate(),
    ),
    "mlp.gate_proj": ColwiseParallel(),
    "mlp.up_proj": ColwiseParallel(),
    "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
}
for layer in mannequin.base_model.layers:
    parallelize_module(layer, mesh["tp"], tp_plan)

# Set up tensor parallelism on the embedding and final norm layers in the base model
# and on the prediction head in the top-level model
tp_plan = {
    "base_model.embed_tokens": RowwiseParallel(
        input_layouts=Replicate(),
        output_layouts=Shard(1),
    ),
    "base_model.norm": SequenceParallel(),
    "lm_head": ColwiseParallel(
        input_layouts=Shard(1),
        # Output stays sharded over the vocabulary for loss_parallel; pass
        # output_layouts=Replicate() instead only when not using loss parallel.
        use_local_output=False,  # keep DTensor output for loss parallel
    ),
}
parallelize_module(mannequin, mesh["tp"], tp_plan)

# Convert the tensor-parallelized model to FSDP2: shard each decoder layer, then the
# base model, then the root, across the "dp" dimension of the mesh
for layer in mannequin.base_model.layers:
    fully_shard(layer, mesh=mesh["dp"])
fully_shard(mannequin.base_model, mesh=mesh["dp"])
fully_shard(mannequin, mesh=mesh["dp"])

def reset_all_weights(mannequin: nn.Module) -> None:
    """Initialize all weights of the model after moving it off the meta device.

    Calls ``reset_parameters()`` (under ``torch.no_grad()``) on every submodule,
    and on the model itself, that defines a callable ``reset_parameters``.
    """

    @torch.no_grad()
    def weight_reset(m: nn.Module) -> None:
        reset_parameters = getattr(m, "reset_parameters", None)
        if callable(reset_parameters):
            reset_parameters()

    # nn.Module.apply visits every submodule recursively (children first), then the model itself
    mannequin.apply(fn=weight_reset)


# Materialize the model on this rank's GPU (to_empty allocates uninitialized storage),
# then initialize every parameter from a fixed seed
torch.manual_seed(42)
mannequin.to_empty(device=gadget)
reset_all_weights(mannequin)
if not isinstance(mannequin, FSDPModule):
    raise TypeError(f"Expected FSDPModule, got {type(mannequin)}")

# Training parameters
epochs = 3
learning_rate = 1e-3
batch_size = 64 // mesh["dp"].size()  # per-replica batch: a global batch of 64 split across "dp"
seq_length = 512
num_warmup_steps = 1000
PAD_TOKEN_ID = tokenizer.token_to_id("[PAD]")
mannequin.train()

# DataLoader, optimizer, scheduler, and loss function.
# The sampler shards the dataset across the "dp" mesh dimension only, so ranks in the
# same tensor-parallel group receive the same batch.
train_dataset = PretrainingDataset(dataset, tokenizer, seq_length)
sampler = DistributedSampler(
    train_dataset, shuffle=False, drop_last=True,
    num_replicas=mesh["dp"].size(),
    rank=mesh["dp"].get_local_rank(),
)
dataloader = torch.utils.data.DataLoader(
    train_dataset,
    sampler=sampler,
    batch_size=batch_size,
    pin_memory=True,  # optional; page-locked host memory speeds up copies to the GPU
    shuffle=False,
    num_workers=2,
    prefetch_factor=2,
)
num_training_steps = len(dataloader) * epochs

optimizer = torch.optim.AdamW(
    mannequin.parameters(), lr=learning_rate, betas=(0.9, 0.99), eps=1e-8, weight_decay=0.1,
)
# Linear warmup from 10% of the learning rate, then cosine decay to zero
warmup_scheduler = lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.1, end_factor=1.0, total_iters=num_warmup_steps,
)
cosine_scheduler = lr_scheduler.CosineAnnealingLR(
    optimizer,
    # clamp so a run shorter than the warmup cannot produce T_max <= 0
    T_max=max(1, num_training_steps - num_warmup_steps),
    eta_min=0,
)
scheduler = lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[num_warmup_steps],
)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_ID)

# If the checkpoint-dist dir exists, restore model, optimizer and scheduler state.
# NOTE: the position in the data is not checkpointed, so the epoch loop restarts at 0.
if os.path.exists("checkpoint-dist"):
    load_checkpoint(mannequin, optimizer, scheduler)

# Start training
print(f"({rank}) Starting training")
for epoch in range(epochs):
    pbar = tqdm.tqdm(dataloader, desc=f"({rank}) Epoch {epoch+1}/{epochs}")
    for batch_id, batch in enumerate(pbar):
        # Periodic checkpoint; this is collective, so every rank must reach it
        if batch_id % 1000 == 0:
            save_checkpoint(mannequin, optimizer, scheduler)
        # Explicit prefetching: unshard FSDP parameters before sending any data to the model
        mannequin.unshard()
        # Get batched data, move from (pinned) CPU memory to GPU
        input_ids, target_ids = batch
        input_ids = input_ids.to(gadget, non_blocking=True)
        target_ids = target_ids.to(gadget, non_blocking=True)
        # Attention mask: causal (seq, seq) + padding (batch, 1, seq, seq), broadcast together
        attn_mask = create_causal_mask(input_ids) + create_padding_mask(input_ids, PAD_TOKEN_ID)
        # Forward pass; lm_head keeps its output as a DTensor for loss_parallel
        logits = mannequin(input_ids, attn_mask)
        optimizer.zero_grad()
        with loss_parallel():
            # Compute loss: cross-entropy between logits and targets, ignoring padding tokens
            loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1))
            # Backward with the loss on DTensor, inside the same loss_parallel context
            loss.backward()
        torch.nn.utils.clip_grad_norm_(mannequin.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        # No manual pbar.update(): tqdm already advances once per batch it yields
        pbar.set_postfix(loss=loss.item())
    pbar.close()

# Save the model
save_checkpoint(mannequin, optimizer, scheduler)

# Clean up the distributed environment
dist.destroy_process_group()

Tags: GPUsLargemodelmultipleParallelismTensorTrain
Admin

Admin

Next Post
9 Best Free Node.js Hosting 2026

9 Best Free Node.js Hosting 2026

Leave a Reply Cancel reply

Your email address will not be published. Required fields are marked *

Recommended.

Witcher 4 Is So Far Away It Sounds Like It May Be A PS6 Game

Witcher 4 Is So Far Away It Sounds Like It May Be A PS6 Game

March 26, 2025
Cursor AI Rockets to $9.9 Billion Valuation with Huge $900 Million Raise

Cursor AI Rockets to $9.9 Billion Valuation with Huge $900 Million Raise

June 6, 2025

Trending.

10 tricks to begin getting ready! • Yoast

10 tricks to begin getting ready! • Yoast

July 21, 2025
AI-Assisted Threat Actor Compromises 600+ FortiGate Devices in 55 Countries

AI-Assisted Threat Actor Compromises 600+ FortiGate Devices in 55 Countries

February 23, 2026
Design Has Never Been More Important: Inside Shopify's Acquisition of Molly

Design Has Never Been More Important: Inside Shopify's Acquisition of Molly

September 8, 2025
Exporting a Cloth Simulation from Blender to an Interactive Three.js Scene

Exporting a Cloth Simulation from Blender to an Interactive Three.js Scene

August 20, 2025
Alibaba Team Open-Sources CoPaw: A High-Performance Personal Agent Workstation for Developers to Scale Multi-Channel AI Workflows and Memory

Alibaba Team Open-Sources CoPaw: A High-Performance Personal Agent Workstation for Developers to Scale Multi-Channel AI Workflows and Memory

March 1, 2026

AimactGrow

Welcome to AimactGrow, your ultimate source for all things technology! Our mission is to provide insightful, up-to-date content on the latest advancements in technology, coding, gaming, digital marketing, SEO, cybersecurity, and artificial intelligence (AI).

Categories

  • AI
  • Coding
  • Cybersecurity
  • Digital marketing
  • Gaming
  • SEO
  • Technology

Recent News

Slay the Spire 2 Review

Slay the Spire 2 Review

March 14, 2026
Key Features and Pricing Explained

Key Features and Pricing Explained

March 14, 2026
  • About Us
  • Privacy Policy
  • Disclaimer
  • Contact Us

© 2025 https://blog.aimactgrow.com/ - All Rights Reserved

No Result
View All Result
  • Home
  • Technology
  • AI
  • SEO
  • Coding
  • Gaming
  • Cybersecurity
  • Digital marketing

© 2025 https://blog.aimactgrow.com/ - All Rights Reserved