{"slug": "inside-fsdp-with-pytorch-and-ray-scaling-model-training-with-fully-sharded-data", "title": "Inside FSDP with PyTorch and Ray: Scaling Model Training with Fully Sharded Data Parallel", "summary": "Alibaba's 1.7B parameter Qwen3-TTS voice cloning model was fine-tuned using Fully Sharded Data Parallel (FSDP) with PyTorch and Ray, demonstrating memory-efficient distributed training across 4 GPUs. The implementation achieved near-optimal memory usage by sharding all model states (parameters, gradients, and optimizer states) across workers, avoiding the massive GPU idle time seen in naive pipeline parallelism.", "body_md": "# Inside FSDP with PyTorch and Ray: Scaling Model Training with Fully Sharded Data Parallel\n\n[Suman Debnath](/blog?author=suman-debnath) | June 12, 2026\n\nA deep dive into FSDP internals with visual walkthroughs, hands-on implementation with Ray, PyTorch and DeepSpeed, and finally fine-tuning the 1.7B parameter Qwen3-TTS model to clone your own voice.\n\n## **Introduction**\n\nA couple of months back, I tried to pen down my learnings on distributed training in [this post](https://debnsuma.github.io/my-blog/posts/distributed-training-from-scratch/), where we discussed the fundamentals, from single-GPU bottlenecks to Data Parallelism and the ZeRO optimization stages. We explored how **memory constraints** limit model sizes on single GPUs and how sharding strategies help overcome these limitations.\n\nNow, in this blog, we’ll take a **deep dive into Fully Sharded Data Parallelism (FSDP)**. We’ll walk through a complete training iteration step-by-step using a concrete example with 4 GPUs, tracing exactly what happens to our model parameters, gradients, and optimizer states at each stage. 
By the end, you’ll have a crystal-clear mental model of how FSDP achieves its remarkable memory efficiency.\n\nOnce we have a solid understanding of FSDP internals, we’ll put this knowledge into practice using PyTorch’s FSDP and Ray Train. We’ll start by training a Vision Transformer on FashionMNIST, and then move on to fine-tuning the 1.7B parameter [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) model, recently released by Alibaba and available on Hugging Face, to clone our own voice.\n\n**Note: Prerequisites**\n\nThis post builds on concepts from my previous blog on distributed training. Make sure you’re familiar with:\n\n- Static and dynamic memory constraints while training a model (parameters, gradients, optimizer states, activations, etc.)\n- `ZeRO-1`, `ZeRO-2`, and `ZeRO-3` sharding strategies\n- Communication primitives: All-Reduce, All-Gather, Reduce-Scatter\n\nIf any of these are unfamiliar, I recommend reading the [distributed training fundamentals post](https://debnsuma.github.io/my-blog/posts/distributed-training-from-scratch/) first.\n\n## **Why FSDP?**\n\nIn my previous post, we explored how **ZeRO** (Zero Redundancy Optimizer) progressively shards model state across GPUs:\n\n| Strategy | What’s Sharded | Memory per GPU |\n|---|---|---|\n| DDP | Nothing (full model copy on each GPU) | MP + GRD + OS |\n| ZeRO-1 | Optimizer states | MP + GRD + OS / N |\n| ZeRO-2 | Optimizer states + Gradients | MP + (GRD + OS) / N |\n| ZeRO-3 | Optimizer states + Gradients + Parameters | (MP + GRD + OS) / N |\n\nHere MP, GRD, and OS are the memory for model parameters, gradients, and optimizer states, and N is the number of GPUs.\n\n**FSDP** is PyTorch’s native implementation of fully sharded data parallel training, closely following the **ZeRO-3** stage. FSDP shards *all* model states (`parameters`, `gradients`, and `optimizer states`) across all data parallel workers, thereby achieving the theoretical minimum per-GPU memory usage for these tensors.\n\nBut how does it actually *work*? When `parameters` are scattered across different GPUs, how does each GPU run a forward pass that needs the *entire* model parameters? 
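\n\nThe per-GPU formulas in the ZeRO table can be sanity-checked with a few lines of Python. This is a minimal sketch, not part of any library. `zero_memory_per_gpu` is a hypothetical helper that counts only the static memory for model parameters (MP), gradients (GRD), and optimizer states (OS), ignoring activations and temporary buffers. The inputs are the numbers from the running example used later in this post (8 GB + 8 GB + 16 GB on 4 GPUs).\n\n``` python\ndef zero_memory_per_gpu(mp_gb, grd_gb, os_gb, n_gpus):\n    # Static memory per GPU (GB) for each strategy; activations are ignored\n    return {\n        'DDP': mp_gb + grd_gb + os_gb,\n        'ZeRO-1': mp_gb + grd_gb + os_gb / n_gpus,\n        'ZeRO-2': mp_gb + (grd_gb + os_gb) / n_gpus,\n        'ZeRO-3': (mp_gb + grd_gb + os_gb) / n_gpus,\n    }\n\n\nmemory = zero_memory_per_gpu(mp_gb=8, grd_gb=8, os_gb=16, n_gpus=4)\nfor strategy, gb in memory.items():\n    print(f'{strategy:<7} {gb:5.1f} GB per GPU')\n# Per-GPU totals: DDP 32.0, ZeRO-1 20.0, ZeRO-2 14.0, ZeRO-3 8.0\n```\n\nOn 16 GB GPUs, only ZeRO-2 and ZeRO-3 fit this example. ZeRO-3, which FSDP follows, leaves by far the most headroom for activations.\n\n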
Before diving into FSDP, let’s understand why the “obvious” solution to training large models across multiple GPUs fails miserably.\n\nSuppose we have a model that doesn’t fit on a single GPU but fits when split across 4 GPUs. Assume, for example, that it is a Transformer-style model with 12 layers. The naive approach is **pipeline-style sequential execution**: place layers 1-3 on GPU0, layers 4-6 on GPU1, layers 7-9 on GPU2, and layers 10-12 on GPU3. The forward pass and backward pass pipelines look like this:\n\nFor a single batch, only one GPU is active at a time during the forward pass (`T1` to `T4`), leaving the others idle.\n\nThe same is true for the backward pass (`T5` to `T8`).\n\n**Massive GPU Idle Time**\n\nThis is really bad. GPU0 finishes its forward work at `T1` and then sits idle for **6 time steps** (`T2` to `T7`) before its backward work at `T8`. More generally, each GPU is busy for only 2 of the 8 time steps, so it sits idle **75% of the time**. We’ve split our model among GPUs, but most of the time each GPU is just waiting around doing nothing. **This is a massive waste of compute resources**.\n\n**Note**\n\nYou might wonder: can’t we start Batch 2’s forward pass while Batch 1 is still propagating? Unfortunately, no. In this naive synchronous scheme, we cannot begin the next forward pass until the current batch’s weights are updated. GPU0 must therefore wait for the entire forward-backward cycle to complete before processing the next batch.\n\n**FSDP solves this problem elegantly**, enabling all GPUs to work simultaneously on different batches while still training a single coherent model. It does this by combining two orthogonal splitting strategies to achieve both memory efficiency and high GPU utilization:\n\n- **Vertical partitioning**: organizing the model into `units`\n- **Horizontal sharding**: sharding the model `parameters`, `gradients`, and `optimizer states` across all GPUs\n\n## **Our Running Example**\n\nLet’s set up a concrete scenario that we’ll trace throughout this entire walkthrough. 
The focus here is to understand the internals of FSDP, not the model training itself or the actual accuracy of the model.\n\n### **The Model**\n\nWe’ll use a simple `Transformer-style` model with `12 layers`, trained on `4 GPUs`, each with `16 GB` of memory.\n\nThe model has the following memory requirements (roughly):\n\n| Component | Memory Required |\n|---|---|\n| Model Parameters (MP) | 8 GB |\n| Gradients (GRD) | 8 GB |\n| Optimizer State (OS) | 16 GB |\n| **Total** | **32 GB** |\n\n**Note: Why 16 GB for Optimizer State?**\n\nHere we are considering the `Adam` optimizer. The optimizer state (`OS`) maintains **two FP32 tensors** per parameter: the first moment (mean of gradients, m) and the second moment (uncentered variance, v). For 8 GB of FP32 parameters, that’s 2 × 8 GB = 16 GB for optimizer states.\n\n### **The Hardware**\n\nWe have **4 GPUs, each with 16 GB of memory**.\n\n**The problem is clear:** 32 GB of static memory doesn’t fit on any single 16 GB GPU. But with 4 GPUs, we have 64 GB total, which is more than enough if we can distribute the load effectively.\n\nAnd remember, we’re not considering the activations here. Depending on batch size and sequence length, activations can be much larger than the model parameters and gradients. They fall under what we call [Dynamic Memory Constraints](https://debnsuma.github.io/my-blog/posts/distributed-training-from-scratch/#dynamic-memory).\n\nFor context, throughout the rest of this post you can treat the input data as activations too; the activation for the 1st layer is simply called `input`.\n\nLet’s see how we can distribute the training of the model effectively across 4 GPUs. Two things consume most of the GPU memory: the model state (parameters, gradients, and optimizer states) and the activations (including the input data). 
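\n\nThe numbers in the memory table follow from simple arithmetic. This is a minimal sketch that assumes FP32 storage (4 bytes per value), one gradient per parameter, and Adam’s two moment tensors as described in the note above. `static_memory_gb` is a hypothetical helper. The 2-billion-parameter count is simply what 8 GB of FP32 weights works out to with 1 GB = 10^9 bytes, not a specific model.\n\n``` python\nBYTES_PER_FP32 = 4\nGB = 10**9  # decimal gigabytes keep the arithmetic round\n\n\ndef static_memory_gb(num_params, bytes_per_value=BYTES_PER_FP32, adam_moments=2):\n    mp = num_params * bytes_per_value / GB   # model parameters\n    grd = num_params * bytes_per_value / GB  # one gradient per parameter\n    opt_state = adam_moments * mp            # Adam's m and v, stored like the params\n    return mp, grd, opt_state\n\n\nmp, grd, opt_state = static_memory_gb(2 * 10**9)\ntotal = mp + grd + opt_state\ngpu_memory_gb = 16\nprint(f'MP={mp:.0f} GB, GRD={grd:.0f} GB, OS={opt_state:.0f} GB, total={total:.0f} GB')\nprint('Fits on one 16 GB GPU:', total <= gpu_memory_gb)  # False\nprint('Fits if split evenly over 4 GPUs:', total / 4 <= gpu_memory_gb)  # True\n```\n\nFSDP’s horizontal sharding provides exactly this kind of even four-way split. Activations still need room on top of the resulting 8 GB per GPU.\n\n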
So, let’s handle the dataset first.\n\n### **The Dataset**\n\nWe split our training data into **4 different mini-batches**, one for each GPU:\n\nEach GPU will process its own batch **simultaneously**, using the **same model weights**. This is still data parallelism at its core; the key difference is **how** we store and manage those weights.\n\n## **FSDP’s Two-Dimensional Splitting Strategy**\n\nFSDP combines two orthogonal splitting strategies to achieve both memory efficiency and high GPU utilization.\n\n### **Vertical Partitioning (Units)**\n\nHere we organize the model’s layers into **units**, with each unit managing a specific range of layers. Since our model has 12 layers, we can organize it into 4 units, each managing 3 layers.\n\nThis is purely `organizational`: it determines the granularity at which **FSDP will gather and release parameters during computation**. We can choose to have more or fewer units, depending on the model size and the number of GPUs.\n\n**Note: Units ≠ GPUs**\n\nIn FSDP, defining **units** (vertical partitions of layers) is a modeling choice that controls parameter loading, checkpointing, and the granularity of sharding. It does **not** need to match the number of GPUs, and often does not. There is **no technical requirement** for the count or boundaries of units to correspond to your device topology. Units are simply logical groupings for parameter management and do not dictate how parameters are sharded across GPUs; the sharding is handled separately and horizontally. You can have any number of units regardless of your GPU count.\n\n### **Horizontal Sharding**\n\nHere’s where FSDP differs fundamentally from the naive approach. 
Instead of assigning different layers to different GPUs, we **shard each entity (parameters, gradients, optimizer states) horizontally across ALL GPUs**.\n\n**Before Sharding** (would need `32GB` per GPU):\n\n**After Sharding** (only `8GB` per GPU):\n\n**Important: Critical Insight**\n\nEach shard contains a **horizontal slice from ALL layers**, not just `one unit's` layers. For example, `GPU0's` shard has the first `1/4` of parameters from layers `1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, and 12`. This is fundamentally different from the naive approach, where GPU0 would have `all` parameters from layers `1-3` only.\n\nThis horizontal sharding is what allows every GPU to participate in processing *every* layer of the model.\n\n## **What Exactly Gets Sharded in FSDP?**\n\nNow that we’ve seen how FSDP splits the model both vertically (units) and horizontally (across GPUs), it’s important to clarify: *what* are we actually sharding?\n\nUnder FSDP’s default and most memory-efficient mode, `FULL_SHARD`, we shard all three entities (model parameters, gradients, and optimizer state) across all GPUs. 
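\n\nTo make horizontal sharding concrete, here is a toy simulation that compares it with the naive layer-per-GPU split. It assumes 12 layers of 8 equal-sized values each, with plain Python lists standing in for tensors, and cuts every layer into 4 contiguous, equal chunks. Real FSDP also pads uneven sizes and differs in detail (FSDP1 flattens each unit’s parameters, FSDP2 shards each parameter along dim-0), so treat this as an illustration only.\n\n``` python\nNUM_LAYERS, NUM_GPUS, VALUES_PER_LAYER = 12, 4, 8\n\n# Toy weights: layer i holds the values i*100 + 0..7\nlayers = [[i * 100 + j for j in range(VALUES_PER_LAYER)] for i in range(NUM_LAYERS)]\n\n\ndef horizontal_shard(layer, rank, world_size):\n    # Contiguous 1/world_size slice of ONE layer (size assumed divisible)\n    size = len(layer) // world_size\n    return layer[rank * size:(rank + 1) * size]\n\n\n# FSDP-style: every rank owns a slice of EVERY layer\nfsdp_shards = {\n    rank: [horizontal_shard(layer, rank, NUM_GPUS) for layer in layers]\n    for rank in range(NUM_GPUS)\n}\n\n# Naive model parallelism: every rank owns ALL of 3 consecutive layers\nlayers_per_gpu = NUM_LAYERS // NUM_GPUS\nnaive_shards = {\n    rank: layers[rank * layers_per_gpu:(rank + 1) * layers_per_gpu]\n    for rank in range(NUM_GPUS)\n}\n\nprint('FSDP  rank 0:', len(fsdp_shards[0]), 'layers x', len(fsdp_shards[0][0]), 'values')  # 12 layers x 2 values\nprint('Naive rank 0:', len(naive_shards[0]), 'layers x', len(naive_shards[0][0]), 'values')  # 3 layers x 8 values\n```\n\nBoth layouts put the same 24 of the 96 values on rank 0, but only the FSDP layout gives every rank a piece of every layer. Concatenating the four ranks’ slices of a layer rebuilds it, which is what the All-Gather step in the walkthrough does.\n\n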
Sharding all three is the critical difference that lets you scale to models larger than would ever fit on a single device.\n\nHere are a few notational conventions we’ll use moving forward (for clarity in diagrams and explanations):\n\n- **Shard**: A single partitioned chunk of parameters, gradients, or optimizer state, held by one specific GPU\n- **Unit**: A logical grouping of one or more layers (the result of our vertical split)\n- **ACT**: Activations calculated during the forward pass (and kept for backward)\n- **Total model state**: Total memory needed for parameters + gradients + optimizer state (e.g., 32 GB in our example scenario)\n\nWith those fundamentals in place, let’s see how FSDP actually operates step by step.\n\n## **Step-by-Step FSDP Walkthrough**\n\nNow let’s trace through one complete training iteration (forward pass, backward pass, and optimizer step).\n\n### **Phase 0: Initial Setup**\n\nBefore training begins, FSDP performs two setup operations:\n\n**Step 0.1: Split the Dataset**\n\nEach GPU gets assigned a different mini-batch of the dataset:\n\n**Step 0.2: Shard the Model**\n\nThen we divide each entity into 4 shards and distribute them, as we saw in the previous section:\n\n**Initial State After Sharding**:\n\n**Note**\n\nAt this stage, the GRD (gradient) shards are simply **placeholders**: they have been allocated on each GPU, but do **not** contain any meaningful data yet. No gradients have been calculated at this point because the forward and backward passes haven’t started. The actual gradient values will only be computed and filled in during the backward pass, once the loss is propagated backward through the network.\n\n### **Phase 1: Forward Pass**\n\nRecall that a “unit” is a logical grouping of one or more layers (from the vertical split earlier). In the forward pass, each unit is processed one after another, but all 4 GPUs operate **in parallel** on their respective mini-batches during each unit’s turn. 
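\n\nEach unit’s turn is built around two collectives: All-Gather, which rebuilds a unit’s full parameters, and Reduce-Scatter, which combines gradients in the backward pass. The sketch below simulates their semantics on 4 ranks with plain Python lists. The function names are illustrative only; PyTorch’s real collectives live in `torch.distributed` and run over NCCL on GPU tensors. All-Reduce is included for comparison with DDP.\n\n``` python\nWORLD_SIZE = 4\n\n\ndef all_gather(shards):\n    # Every rank receives the concatenation of all ranks' shards\n    full = [x for shard in shards for x in shard]\n    return [list(full) for _ in shards]\n\n\ndef all_reduce(local_grads):\n    # Every rank receives the full element-wise sum (what DDP uses)\n    summed = [sum(values) for values in zip(*local_grads)]\n    return [list(summed) for _ in local_grads]\n\n\ndef reduce_scatter(local_grads):\n    # Same element-wise sum, but rank r keeps only its own 1/world_size slice\n    summed = [sum(values) for values in zip(*local_grads)]\n    size = len(summed) // len(local_grads)\n    return [summed[r * size:(r + 1) * size] for r in range(len(local_grads))]\n\n\n# One unit with 8 parameters, sharded 2 per rank\nparam_shards = [[2 * r, 2 * r + 1] for r in range(WORLD_SIZE)]\ngathered = all_gather(param_shards)\nprint('Rank 3 after All-Gather:', gathered[3])  # [0, 1, 2, 3, 4, 5, 6, 7]\n\n# Each rank computed a full-size local gradient from its own mini-batch\nlocal_grads = [[r + 1] * 8 for r in range(WORLD_SIZE)]\nprint('Rank 1 after All-Reduce:', all_reduce(local_grads)[1])  # eight 10s\nprint('Rank 1 after Reduce-Scatter:', reduce_scatter(local_grads)[1])  # [10, 10]\n```\n\nAll-Reduce and Reduce-Scatter compute the same sums, but after Reduce-Scatter each rank holds only the 2 values it owns instead of all 8.\n\n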
We now have the following on **each GPU**:\n\n- Its own shard of the `model parameters` for **Unit 1** (layers 1-3)\n- Its own shard of the `optimizer state` for **Unit 1**\n- Its own shard of the `gradients` for **Unit 1** (placeholders)\n- Its own mini-batch of the `dataset`\n\n#### Step 1.1: All-Gather Parameters for Unit 1\n\nBefore we can run layers 1-3, each GPU needs the **complete parameters** for **Unit 1**. Since these parameters are sharded across all GPUs, we perform an **All-Gather** operation:\n\nEach GPU now temporarily holds the complete **Unit 1** parameters. The key word here is **temporarily**: we’ll discard the borrowed shards after use.\n\n#### Step 1.2: Forward Pass on Unit 1 (All GPUs in Parallel!)\n\nNow, **all 4 GPUs simultaneously** run the forward pass on **layers 1-3** (for **Unit 1**), each using its own mini-batch but the same model weights.\n\n**This is the magic of FSDP:** all 4 GPUs are working simultaneously! Same `model weights`, different `data`, different `activations`.\n\n#### Step 1.3: Save Activations\n\nEach GPU stores its computed activations; these are needed later for gradient computation during the backward pass.\n\n#### Step 1.4: Reshard (Free Temporary Memory)\n\nNow that the forward pass for **Unit 1** is complete, we can **delete the borrowed parameter shards** to free up GPU memory, keeping only our owned shard.\n\nMemory usage drops back down, but we’ve retained the **activations** we need for backward.\n\n#### Step 1.5: Repeat for Units 2, 3, and 4\n\nThe same **All-Gather** → **Forward** → **Save ACT** → **Reshard** cycle repeats for each of the remaining `units` (Unit 2, 3, and 4):\n\nIn practice, FSDP can skip the reshard for the last unit to run (PyTorch does this for the root module by default), because those parameters are needed again immediately at the start of the backward pass.\n\n#### Step 1.6: Compute Loss\n\nEach GPU computes the loss for its own mini-batch.\n\n**End of Forward Pass State:**\n\nAt this point, each GPU holds:\n\n- Its original 1/4 shard of ALL model parameters\n- Activations from ALL units 
(`ACT_unit1`, `ACT_unit2`, `ACT_unit3`, `ACT_unit4`)\n- Its computed loss value\n- Placeholder gradient shards (not yet filled)\n\n### **Phase 2: Backward Pass**\n\nNow it’s time to run everything in reverse: we’re going to calculate gradients and send them back through the network so our model can learn! We’ll start from the last set of layers (`Unit 4`) and work our way backward to the first set of layers (`Unit 1`).\n\n#### Step 2.1: All-Gather Parameters for Unit 4\n\nIn general, before running the backward pass for a unit, we need to gather its full parameters again, as we did in the forward pass.\n\n**No Need to All-Gather for Unit 4**\n\nGood news: for Unit 4, we **already** have the full parameters (`MP_unit4`) on every GPU from the forward pass, so there’s nothing to do for this step! Each GPU kept a full copy of Unit 4’s weights after the forward pass because they were just used. But when we move to earlier units (`Unit 3`, `Unit 2`, and `Unit 1`), the full parameters will need to be re-assembled on each GPU (using All-Gather), just like in the forward pass. That’s because, after each unit’s step is finished, we typically “reshard” and return to holding just our own shard to save memory.\n\n#### Step 2.2: Compute Local Gradients for Unit 4\n\nEach GPU computes gradients based on **its own batch’s loss** and **its own activations**.\n\nAt this point, each GPU has computed a **local gradient** based on only its portion of the data.\n\nBut for proper optimization, we need the **global gradient**: the sum of all local gradients (in practice, PyTorch averages them by default). So, we need to **reduce** the gradients across all GPUs.\n\n#### Step 2.3: Reduce-Scatter Gradients (The Key Operation!)\n\nThis is where the magic happens. 
We use a **Reduce-Scatter** operation to:\n\n1. **Reduce** (sum) all gradients across GPUs\n2. **Scatter** the result so each GPU gets only its `responsible shard`\n\n**Why Reduce-Scatter instead of All-Reduce?**\n\nIn DDP, we use **All-Reduce**, which gives every GPU the **full** summed gradient. But in FSDP, each GPU only **needs** the gradient for the parameters it owns.\n\n**Reduce-Scatter** is more efficient: it produces the same sum but leaves each GPU with only its own slice, saving both memory and communication bandwidth. A ring All-Reduce is essentially a Reduce-Scatter followed by an All-Gather, so Reduce-Scatter alone moves roughly half the data.\n\n#### Step 2.4: Free Memory\n\nAfter reduce-scatter, we can release:\n\n- The temporarily gathered `MP_unit4` (keep only the owned shard)\n- The `ACT_unit4` (no longer needed)\n\n#### Step 2.5: Repeat for Units 3, 2, and 1\n\nWorking backward through the network, for each of the remaining units (Unit 3, 2, and 1):\n\n- Unit 3: **All-Gather** `MP_unit3` → **Backward** → **Reduce-Scatter** `GRD` → **Free** gathered `MP_unit3` and `ACT_unit3`\n- Unit 2: **All-Gather** `MP_unit2` → **Backward** → **Reduce-Scatter** `GRD` → **Free** gathered `MP_unit2` and `ACT_unit2`\n- Unit 1: **All-Gather** `MP_unit1` → **Backward** → **Reduce-Scatter** `GRD` → **Free** gathered `MP_unit1` and `ACT_unit1`\n\n**End of Backward Pass State:**\n\nEach GPU now holds:\n\n- Its 1/4 shard of model parameters (`MP_shard`)\n- Its 1/4 shard of the reduced gradients (`GRD_shard`) ← Ready for optimization!\n- Its 1/4 shard of optimizer state (`OS_shard`)\n- All activations have been freed!\n\n### **Phase 3: Optimizer Step**\n\nNow comes the beautiful part: **each GPU can update its parameters independently**.\n\n#### Step 3.1: Local Optimizer Update\n\nEach GPU has everything it needs to update its portion of the model:\n\n- Its parameter shard\n- The reduced gradient for that shard (combined across all 4 mini-batches)\n- Its optimizer state shard\n\n**Tip: No Communication Needed!**\n\nEach GPU updates only its shard using only data it already has. 
This is perfectly parallel and requires zero inter-GPU communication.\n\n#### Step 3.2: Ready for Next Batch\n\nWe’re back to our initial state, but with updated parameters, and we can fetch the next set of mini-batches.\n\nWe then repeat the entire process for all the batches in the dataset.\n\nSo, to summarize, this is what happens during one full training iteration with FSDP:\n\n## **Memory Analysis: The Numbers**\n\nLet’s crunch the numbers for our 4-GPU example to really see the impact that FSDP makes on memory efficiency.\n\n### **Without FSDP (DDP)**\n\nWith classic Data Parallelism (DDP), each GPU holds a `full copy` of the model, its gradients, and the optimizer state:\n\nMP + GRD + OS = 8 GB + 8 GB + 16 GB = 32 GB\n\nHere:\n\n- **MP** = Model parameters (8 GB)\n- **GRD** = Gradients (8 GB)\n- **OS** = Optimizer state (16 GB)\n\nSo, **each GPU would need a whopping 32 GB** just for the model state.\n\n**Result**: Won’t fit on 16 GB GPUs! Most consumer and even many data center GPUs would run out of memory immediately.\n\n### **With FSDP (4 GPUs)**\n\nWith FSDP, the memory requirements are slashed: `parameters`, `gradients`, and `optimizer states` are sharded across all GPUs.\n\nEach GPU needs:\n\n(MP + GRD + OS) / N = (8 GB + 8 GB + 16 GB) / 4 = 8 GB\n\nwhere N = 4 is the number of GPUs in our example.\n\nEach GPU now stores only a `quarter` of the parameters, gradients, and optimizer states. It also temporarily holds the full parameters of the unit currently being computed.\n\n**Result**: Fits comfortably on 16 GB GPUs! 
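\n\nThe sharded steady state works out to 8 GB per GPU, but during computation one unit is also temporarily gathered in full. The sketch below gives a rough estimate of that peak by adding the borrowed shards of one unit. It assumes the running example’s setup (4 equal-sized units, FP32 states, no prefetching, activations ignored). `fsdp_peak_gb` is a hypothetical helper. Real peaks also depend on prefetching, which can keep a second unit gathered, and on the unit’s full-size gradients, which exist briefly before Reduce-Scatter.\n\n``` python\ndef fsdp_steady_gb(mp, grd, opt_state, n_gpus):\n    # Sharded model state per GPU: (MP + GRD + OS) / N\n    return (mp + grd + opt_state) / n_gpus\n\n\ndef fsdp_peak_gb(mp, grd, opt_state, n_gpus, n_units):\n    # Add the borrowed (non-owned) shards of ONE fully gathered unit\n    unit_params = mp / n_units\n    borrowed = unit_params * (n_gpus - 1) / n_gpus\n    return fsdp_steady_gb(mp, grd, opt_state, n_gpus) + borrowed\n\n\nddp = 8 + 8 + 16\nsteady = fsdp_steady_gb(8, 8, 16, n_gpus=4)\npeak = fsdp_peak_gb(8, 8, 16, n_gpus=4, n_units=4)\nprint(f'DDP: {ddp} GB | FSDP steady state: {steady} GB | FSDP with one unit gathered: {peak} GB')\n# DDP: 32 GB | FSDP steady state: 8.0 GB | FSDP with one unit gathered: 9.5 GB\n```\n\nEven with a unit gathered, the estimate stays well under 16 GB, which leaves room for activations.\n\n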
In our example, the sharded state uses only 8 GB of each 16 GB GPU. FSDP therefore lets you train a model roughly twice as large on the same hardware (before accounting for activations), or the same model with much larger batch sizes.\n\nBut there is no free lunch: we pay the price in communication overhead.\n\n### **Communication Cost**\n\nFor each training iteration:\n\n| Phase | Operation | Data Volume (approx.) |\n|---|---|---|\n| Forward (per unit) | All-Gather MP | Size of the unit’s parameters |\n| Backward (per unit) | All-Gather MP | Size of the unit’s parameters |\n| Backward (per unit) | Reduce-Scatter GRD | Size of the unit’s gradients |\n\nTotal communication per iteration is approximately 3Ψ, where Ψ is the total parameter size. DDP’s All-Reduce needs roughly 2Ψ.\n\n**Tip: Prefetching Optimization**\n\nIn practice, FSDP overlaps communication with computation. While computing the forward pass on Unit i, it can start all-gathering the parameters for Unit i+1 in the background. This significantly reduces the effective communication overhead.\n\n## **Implementing FSDP with PyTorch and Ray Train**\n\nNow that we have a solid understanding of *how* FSDP works, let’s implement it. The implementation is fairly straightforward.\n\nWe’ll use [PyTorch FSDP2](https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html) with [Ray Train](https://docs.ray.io/en/latest/train/train.html) to train a Vision Transformer on the FashionMNIST dataset. If you are new to **Ray Train**, you can check out my previous post [here](https://debnsuma.github.io/my-blog/posts/distributed-training-from-scratch/#introduction-to-ray---unified-ai-compute-engine).\n\n## **FSDP2: What’s New and Why Does It Matter?**\n\nPyTorch’s Fully Sharded Data Parallel (FSDP) module has undergone a significant evolution in its second major version, often called **FSDP2**. 
This new version introduces architectural, usability, and performance improvements over the original FSDP design, now sometimes retroactively referred to as **FSDP1**.\n\nLet’s examine the main improvements in detail:\n\n| Aspect | FSDP1 | FSDP2 |\n|---|---|---|\n| Sharding representation | Uses a large FlatParameter tensor (concatenates all params per group for sharding) | Each parameter is sharded independently across ranks (per-parameter sharding) |\n| Sharding granularity | Flattened groups of parameters, requiring explicit grouping | Individual parameters; native granularity for any param tensor |\n| DTensor support | Experimental and limited; not natively exposed | Native, built-in DTensor integration (for multi-dimensional and hybrid sharding) |\n| Checkpointing | Loading/saving often requires cross-worker communication to reconstruct full tensors | Can save/load sharded state dicts without collective communication, supporting parallel save/restore and streaming |\n| Frozen parameters | Difficult to manage; groups must be updated if freezing layers | Frozen (non-trainable) parameters are handled naturally and don’t require extra grouping steps |\n| Wrapping flexibility | Rigid or error-prone; cannot always wrap at arbitrary module boundaries | Fine-grained, easy wrapping at any module or submodule |\n\nThe core insight of FSDP2 is that each parameter tensor is partitioned (by default along the first dimension, dim-0) and distributed across all participating GPUs (“ranks”). This eliminates the need to flatten and concatenate parameters into a single tensor per sharding group. The result is simpler parameter management, compatibility with a wider variety of models, and much simpler integration with other parallelism strategies.\n\n## **Prerequisite: Setting Up a Ray Cluster**\n\nBefore getting into the code, we need to first have a Ray cluster up and running. 
I strongly recommend using [Anyscale](https://www.anyscale.com/), as it provides an easy way to launch and manage Ray clusters with GPU workers.\n\nFor detailed instructions, you can refer to the [GitHub Repository](https://github.com/debnsuma/vhol-ray-train?tab=readme-ov-file#getting-started-with-anyscale). You can also check out the [Anyscale documentation](https://docs.anyscale.com/get-started) for more details.\n\nIn brief, here’s how you can get started:\n\n1. **Create an Anyscale Account:** First, sign up at [https://www.anyscale.com/](https://www.anyscale.com/).\n2. **Provision a Ray Cluster on Anyscale:** After logging in, start a new project and create a new Ray cluster. Make sure the cluster configuration includes:\n   - One head node\n   - Two or more worker nodes with GPUs (such as NVIDIA V100, A100, T4, L4, H100, etc.). For this tutorial, you may want to use L4-based GPU workers.\n3. **Open the Workspace:** Once the cluster is ready, open the workspace by clicking the `Workspace` button in the Anyscale dashboard.\n4. Clone the [GitHub repository](https://github.com/debnsuma/vhol-ray-train) and install the dependencies.\n\n```\ngit clone https://github.com/debnsuma/vhol-ray-train.git\ncd vhol-ray-train\npip install -r requirements.txt\n```\n\n## **Setting Up the Environment**\n\n``` python\nimport os\n\n# Enable the Ray Train V2 APIs\nos.environ[\"RAY_TRAIN_V2_ENABLED\"] = \"1\"\n\nimport tempfile\nimport uuid\nimport torch\nimport ray\n\nprint(f\"PyTorch version: {torch.__version__}\")\nprint(f\"Ray version: {ray.__version__}\")\n```\n\nOutput:\n\n```\nPyTorch version: 2.10.0+cu128\nRay version: 2.53.0\n```\n\n## **Step 1: Define the Model**\n\nWe’ll use a **Vision Transformer (ViT)** for this tutorial. ViT has clear, repeatable block structures (transformer encoder blocks) that map perfectly to our **units** concept from the theory section. 
That said, you can use any model of your choice.\n\n``` python\nfrom torchvision.models import VisionTransformer\nfrom torchvision.datasets import FashionMNIST\nfrom torchvision.transforms import ToTensor, Normalize, Compose\n\ndef init_model():\n    \"\"\"Initialize Vision Transformer for FashionMNIST (28x28 grayscale, 10 classes).\"\"\"\n    model = VisionTransformer(\n        image_size=28, patch_size=7, num_layers=10, num_heads=2,\n        hidden_dim=128, mlp_dim=128, num_classes=10,\n    )\n    # Modify for grayscale input\n    model.conv_proj = torch.nn.Conv2d(1, 128, kernel_size=7, stride=7)\n    return model\n\n# Verify model\ntest_model = init_model()\nprint(f\"Model parameters: {sum(p.numel() for p in test_model.parameters()):,}\")\ndel test_model\n```\n\nOutput:\n\n```\nModel parameters: 1,006,090\n```\n\n## **Step 2: Apply FSDP2 Sharding**\n\nNow we implement the sharding strategy we discussed earlier. Each encoder block becomes a **unit** that we shard individually:\n\n``` python\nfrom torch.distributed.fsdp import fully_shard\nfrom torch.distributed.device_mesh import init_device_mesh\nimport ray.train\n\ndef shard_model(model):\n    \"\"\"Apply FSDP2 sharding to the model.\"\"\"\n    world_size = ray.train.get_context().get_world_size()\n\n    # Create device mesh for data parallelism\n    mesh = init_device_mesh(\"cuda\", (world_size,), mesh_dim_names=(\"dp\",))\n\n    # Shard each encoder block individually (one unit per block)\n    for block in model.encoder.layers.children():\n        fully_shard(block, mesh=mesh, reshard_after_forward=True)\n\n    # Shard the root model; its remaining params (patch embedding, positional embedding, final norm, head) form one more unit\n    fully_shard(model, mesh=mesh, reshard_after_forward=True)\n```\n\n**Tip: reshard_after_forward Trade-off**\n\nSetting `reshard_after_forward=True` implements the memory optimization we discussed earlier: parameters are freed after the forward pass and re-gathered during backward. 
This reduces peak memory but increases communication.\n\n**Optional: Advanced Policies**\n\nFor memory-constrained scenarios, you can add:\n\n- CPU Offloading: `CPUOffloadPolicy()` offloads sharded parameters, gradients, and optimizer state to CPU when they are not in use.\n- Mixed Precision: `MixedPrecisionPolicy(param_dtype=torch.float16)` casts the gathered parameters to FP16 for compute, reducing memory and communication. For modern GPUs (A100, H100), prefer BF16 over FP16. BF16 has the same exponent range as FP32, which reduces overflow/underflow issues and removes the need for loss scaling in most cases.\n\n``` python\nimport torch\nfrom torch.distributed.fsdp import fully_shard, CPUOffloadPolicy, MixedPrecisionPolicy\n\n# `block` and `mesh` come from shard_model() above\nfully_shard(block,\n            mesh=mesh,\n            reshard_after_forward=True,\n            offload_policy=CPUOffloadPolicy(),\n            mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))  # prefer torch.bfloat16 on A100/H100\n```\n\n## **Step 3: Distributed Checkpointing**\n\nNow let’s tackle checkpointing for sharded models. Conventional checkpointing approaches (e.g., `torch.save(model.state_dict())`) require gathering all model parameters on rank 0. For large, sharded models, this is infeasible because of excessive memory usage and communication overhead.\n\nPyTorch Distributed Checkpoint (DCP) solves this by enabling efficient, scalable checkpointing across all workers. Its key features:\n\n- **Parallel I/O**: Each worker saves only its portion (shard) of the model and optimizer state in parallel, with no need to gather everything to a single process.\n- **Automatic Resharding**: When resuming, DCP automatically reshards states if the number of workers changes between save and load. 
This means you can resume training with a different world size (e.g., after a node failure or scale-up).\n- **Full Optimizer State**: DCP can checkpoint both the model and the full optimizer state, enabling robust training resumption and fault tolerance.\n\nThis class wraps the model and optimizer so that PyTorch Distributed Checkpoint (DCP) can save and load their state together.\n\n``` python\nfrom torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict, get_model_state_dict, StateDictOptions\nfrom torch.distributed.checkpoint.stateful import Stateful\nimport torch.distributed.checkpoint as dcp\n\nclass AppState(Stateful):\n    \"\"\"Wrapper for DCP checkpointing.\"\"\"\n    def __init__(self, model, optimizer=None, epoch=None):\n        self.model, self.optimizer, self.epoch = model, optimizer, epoch\n\n    def state_dict(self):\n        model_sd, optim_sd = get_state_dict(self.model, self.optimizer)\n        return {\"model\": model_sd, \"optim\": optim_sd, \"epoch\": self.epoch}\n\n    def load_state_dict(self, state_dict):\n        set_state_dict(self.model, self.optimizer,\n                      model_state_dict=state_dict[\"model\"],\n                      optim_state_dict=state_dict[\"optim\"])\n        self.epoch = state_dict.get(\"epoch\")\n```\n\nThe training loop below calls a `save_checkpoint` helper that is not shown elsewhere. A minimal version is included here alongside `load_checkpoint`; it assumes each worker writes its DCP shard to a temporary directory and reports it to Ray Train together with the metrics. `load_checkpoint` restores the model, optimizer, and epoch for training resumption, and DCP reshards automatically if the world size has changed.\n\n``` python\ndef save_checkpoint(model, optimizer, metrics, epoch):\n    \"\"\"Save a sharded DCP checkpoint and report it to Ray Train (minimal version).\"\"\"\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        dcp.save(state_dict={\"app\": AppState(model, optimizer, epoch)}, checkpoint_id=tmp_dir)\n        ray.train.report(metrics, checkpoint=ray.train.Checkpoint.from_directory(tmp_dir))\n\ndef load_checkpoint(model, optimizer, ckpt):\n    \"\"\"Load FSDP checkpoint (handles resharding automatically).\"\"\"\n    with ckpt.as_directory() as ckpt_dir:\n        app_state = AppState(model, optimizer)\n        dcp.load(state_dict={\"app\": app_state}, checkpoint_id=ckpt_dir)\n    return app_state.epoch\n```\n\nThis function collects sharded model weights onto rank 0 and saves a full PyTorch model checkpoint for inference use.\n\n``` python\ndef save_model_for_inference(model, world_rank):\n    \"\"\"Consolidate sharded model for inference 
(rank 0 saves full model).\"\"\"\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        # Collective call: every rank must run it; with full_state_dict + cpu_offload, only rank 0 receives the full weights\n        model_sd = get_model_state_dict(model, options=StateDictOptions(full_state_dict=True, cpu_offload=True))\n        ckpt = None\n        if world_rank == 0:\n            torch.save(model_sd, os.path.join(tmp_dir, \"full-model.pt\"))\n            ckpt = ray.train.Checkpoint.from_directory(tmp_dir)\n        # Every rank reports; only rank 0 attaches the consolidated checkpoint\n        ray.train.report({}, checkpoint=ckpt, checkpoint_dir_name=\"full_model\")\n```\n\n**Note: Model Consolidation for Inference**\n\nThe `save_model_for_inference` function all-gathers weights to rank 0 and saves a standard PyTorch checkpoint. This consolidated model can be loaded without FSDP for inference.\n\n## **Step 4: The Training Function**\n\nLet’s now implement the training function. This function is executed on each Ray worker and orchestrates the end-to-end FSDP training lifecycle.\n\nPay special attention to:\n\n- **Checkpoint handling:** This supports training resumption and fault tolerance, and is critical for distributed workflows.\n- **Model sharding:** The `shard_model(model)` call applies FSDP sharding to your model.\n- **Data loading:** Note the use of `ray.train.torch.prepare_data_loader`, which shards the data across workers and moves batches to the right device.\n- **Reporting and model saving:** Checkpoints are reported with metrics for the Ray dashboard, and the final model weights are consolidated for inference after training.\n\nTogether, these ensure that distributed training runs robustly, can be resumed after interruptions, and produces a standard model for inference.\n\n``` python\nimport ray.train.torch\nfrom torch.nn import CrossEntropyLoss\nfrom torch.optim import Adam\nfrom torch.utils.data import DataLoader\n\ndef train_func(config):\n    \"\"\"FSDP2 training function.\"\"\"\n    # Model setup\n    model = init_model()\n    device = ray.train.torch.get_device()\n    torch.cuda.set_device(device)\n    model.to(device)\n    shard_model(model)  # Prepares your 
model for FSDP sharding and distributed execution\n\n    # Training setup (create the optimizer after sharding so it tracks the sharded parameters)\n    criterion = CrossEntropyLoss()\n    optimizer = Adam(model.parameters(), lr=config.get('lr', 0.001))\n\n    # Resume from checkpoint if available\n    start_epoch = 0\n    if ray.train.get_checkpoint():\n        # Checkpoint loading lets you resume or recover from failures safely\n        start_epoch = load_checkpoint(model, optimizer, ray.train.get_checkpoint()) + 1\n\n    # Data loading\n    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])\n    train_data = FashionMNIST(root=tempfile.gettempdir(), train=True, download=True, transform=transform)\n    train_loader = DataLoader(train_data, batch_size=config.get('batch_size', 64), shuffle=True)\n    train_loader = ray.train.torch.prepare_data_loader(train_loader)  # Adds a DistributedSampler and moves batches to the device\n\n    # Global rank of this worker (used to print only once)\n    world_rank = ray.train.get_context().get_world_rank()\n\n    # Training loop\n    for epoch in range(start_epoch, config.get('epochs', 1)):\n        # Reseed the DistributedSampler so each epoch gets a different shuffle\n        if ray.train.get_context().get_world_size() > 1:\n            train_loader.sampler.set_epoch(epoch)\n\n        total_loss, num_batches = 0.0, 0\n        for images, labels in train_loader:\n            outputs = model(images)\n            loss = criterion(outputs, labels)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n            total_loss += loss.item()\n            num_batches += 1\n\n        avg_loss = total_loss / num_batches\n        # Checkpoint saving is key: it enables Ray's fault tolerance and progress tracking\n        save_checkpoint(model, optimizer, {\"loss\": avg_loss, \"epoch\": epoch}, epoch)\n        if world_rank == 0:\n            print(f\"Epoch {epoch}: loss={avg_loss:.4f}\")\n\n    # Consolidate and save the full model for downstream inference (call on every rank; only rank 0 writes the file)\n    save_model_for_inference(model, 
world_rank)\n```\n\n## Step 5: Launch Distributed Training\n\nLet’s now launch the distributed training. Ray Train’s `TorchTrainer` handles worker spawning, process group initialization, and checkpoint coordination.\n\nPay special attention to:\n\n- Each experiment gets a **unique name** for later tracking and artifact separation.\n- **The ScalingConfig** sets the number of distributed workers and enables GPU use.\n- **The RunConfig** configures where Ray persists checkpoints and outputs.\n\n``` python\nimport ray.train.torch\nimport uuid\n\n# Configuration\nexperiment_name = f\"fsdp_{uuid.uuid4().hex[:8]}\"\nscaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)\nrun_config = ray.train.RunConfig(storage_path=\"/mnt/cluster_storage/\", name=experiment_name)\ntrain_config = {\"epochs\": 1, \"lr\": 0.001, \"batch_size\": 64}\n\nprint(f\"Experiment: {experiment_name}\")\n```\n\nNow let’s launch the training:\n\n``` python\n# Create and run trainer\ntrainer = ray.train.torch.TorchTrainer(\n    train_loop_per_worker=train_func,\n    scaling_config=scaling_config,\n    train_loop_config=train_config,\n    run_config=run_config,\n)\nresult = trainer.fit()\nprint(f\"Training complete! Checkpoint: {result.checkpoint}\")\n```\n\nTraining output:\n\n```\nExperiment: fsdp_b2f564ce\n(RayTrainWorker) Epoch 0: loss=0.7410\n\nTraining complete! Checkpoint: Checkpoint(filesystem=local, path=/mnt/cluster_storage/fsdp_b2f564ce/full_model)\n```\n\nNote: Parameter-Efficient Fine-Tuning\n\nIn this example, we’re fine-tuning the **entire model** using full parameter updates. If you’d like to use parameter-efficient fine-tuning methods like **LoRA** or **QLoRA**, you can integrate them here as well; the distributed FSDP training pipeline will largely remain the same. 
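\n\nAs a rough illustration, here is a minimal from-scratch LoRA sketch in plain PyTorch. `LoRALinear` and `add_lora` are hypothetical names, not part of PyTorch, PEFT, or the author’s code. The sketch freezes each `nn.Linear` and adds a trainable low-rank update scaled by `alpha / r`. It is meant for models whose forward passes call their `nn.Linear` layers directly, so it skips `nn.MultiheadAttention`, which reads `out_proj.weight` itself:\n\n``` python\nimport math\n\nimport torch\nimport torch.nn as nn\n\n\nclass LoRALinear(nn.Module):\n    \"\"\"Hypothetical sketch: a frozen nn.Linear plus a trainable low-rank update (LoRA).\"\"\"\n\n    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):\n        super().__init__()\n        self.base = base\n        for p in self.base.parameters():\n            p.requires_grad = False  # pretrained weights stay frozen\n        # A starts random and B starts at zero, so the wrapped layer initially equals the base layer\n        self.lora_a = nn.Parameter(torch.empty(r, base.in_features))\n        self.lora_b = nn.Parameter(torch.zeros(base.out_features, r))\n        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))\n        self.scaling = alpha / r\n\n    def forward(self, x):\n        # y = base(x) + (alpha / r) * B(A(x))\n        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling\n\n\ndef add_lora(module: nn.Module, r: int = 8, alpha: float = 16.0) -> nn.Module:\n    \"\"\"Recursively wrap every nn.Linear inside `module` with LoRALinear.\"\"\"\n    if isinstance(module, nn.MultiheadAttention):\n        return module  # its forward uses out_proj.weight directly, so leave it unchanged\n    for name, child in list(module.named_children()):\n        if isinstance(child, nn.Linear):\n            setattr(module, name, LoRALinear(child, r, alpha))\n        else:\n            add_lora(child, r, alpha)\n    return module\n```\n\nYou would call `add_lora(model)` after `init_model()` and before `shard_model(model)`, then build the optimizer only from parameters that still have `requires_grad=True`. This assumes that FSDP2’s per-parameter sharding handles frozen and trainable parameters in the same module, which is worth verifying on your PyTorch version.\n\n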
Just wrap or modify your model, optimizer, and training loop as needed, and use Ray Train as shown.\n\n## Step 6: Inspect Training Artifacts\n\nBefore we move on to the next step, let’s inspect the training artifacts:\n\n- `checkpoint_*/`: epoch checkpoints with distributed shards\n- `full_model/`: consolidated model for inference\n\n``` python\nimport os\n\n# List artifacts\nstorage_path = f\"/mnt/cluster_storage/{experiment_name}/\"\nprint(f\"Artifacts in {storage_path}:\")\nfor item in sorted(os.listdir(storage_path)):\n    print(f\"  {item}/\" if os.path.isdir(os.path.join(storage_path, item)) else f\"  {item}\")\n```\n\nOutput:\n\n```\nArtifacts in /mnt/cluster_storage/fsdp_b2f564ce/:\n  .validate_storage_marker\n  checkpoint_2026-02-02_06-52-14.180406/\n  checkpoint_manager_snapshot.json\n  full_model/\n```\n\n## Step 7: Load Model for Inference\n\nNow let’s load the model for inference. The consolidated model (`full-model.pt`) is a standard PyTorch checkpoint that works without FSDP2:\n\n``` python\nimport torch\nfrom torchvision.datasets import FashionMNIST\nfrom torchvision.transforms import Compose, Normalize, ToTensor\n\n# Load model for inference\nmodel_path = f\"/mnt/cluster_storage/{experiment_name}/full_model/full-model.pt\"\ninference_model = init_model()\ninference_model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))\ninference_model.eval()\nprint(\"Model loaded.\")\n\n# Test inference on one test image, normalized the same way as in training\ntest_data = FashionMNIST(root=\"/tmp\", train=False, download=True,\n                         transform=Compose([ToTensor(), Normalize((0.5,), (0.5,))]))\nwith torch.no_grad():\n    image, label = test_data[0]  # indexing applies the transform; image has shape [1, 28, 28]\n    output = inference_model(image.unsqueeze(0))  # add a batch dimension: [1, 1, 28, 28]\nprint(f\"Inference output shape: {output.shape}\")\n```\n\nOutput:\n\n```\nInference output shape: torch.Size([1, 10])\n```\n\n## DeepSpeed: An Alternative to FSDP2\n\nWhen it comes to large-scale model training, there are several strategies that help distribute computation and memory efficiently across multiple GPUs or nodes. 
**DeepSpeed**, an open-source library developed by Microsoft, is designed to make distributed training fast, scalable, and easy to use for very large models.\n\nIt offers training optimizations such as **ZeRO**, advanced optimizers, and mixed precision, enabling researchers and practitioners to train models that would otherwise not fit in GPU memory.\n\nWhile FSDP2 is PyTorch’s native solution for sharded training, DeepSpeed is another popular and feature-rich framework for distributed training. Let’s introduce it briefly and show how it compares.\n\n## Key Differences from FSDP2\n\n| Aspect | FSDP2 | DeepSpeed |\n| --- | --- | --- |\n| Setup | `fully_shard(model)` | `deepspeed.initialize(...)` |\n| Optimizer | User creates separately | Managed by DeepSpeed |\n| Backward | `loss.backward()` | `model_engine.backward(loss)` |\n| Config | Python API | JSON/dict config |\n\nWe already discussed the **ZeRO** stages in the [previous post](https://debnsuma.github.io/my-blog/posts/distributed-training-from-scratch/#zero-zero-redundancy-optimizer). DeepSpeed implements all three ZeRO stages (FSDP2 corresponds to ZeRO-3) and lets you select a stage and its settings through a simple configuration file.\n\n## How DeepSpeed Works\n\nDeepSpeed is designed as a drop-in wrapper around the standard PyTorch training loop, which makes it approachable for most PyTorch users.\n\nIn contrast to FSDP2, which is configured entirely through its Python API, DeepSpeed uses a configuration file (JSON, or an equivalent Python dict) to define its distributed behavior and optimization strategies. 
Here’s a simple example of such a configuration, defined programmatically as a Python dictionary (the same content can also live in a JSON file):\n\n``` python\ndef get_deepspeed_config(batch_size=64, lr=0.001):\n    \"\"\"A minimal DeepSpeed ZeRO Stage 2 config\"\"\"\n    return {\n        \"optimizer\": {\n            \"type\": \"Adam\",\n            \"params\": {\"lr\": lr, \"betas\": [0.9, 0.999], \"eps\": 1e-8},\n        },\n        \"fp16\": {\"enabled\": False},  # Change to True to enable mixed precision\n        \"zero_optimization\": {\n            \"stage\": 2,  # ZeRO Stage 2 for optimizer and gradient state partitioning\n            \"allgather_bucket_size\": 2e8,\n            \"reduce_bucket_size\": 2e8,\n            \"overlap_comm\": True,\n            \"contiguous_gradients\": True,\n        },\n        \"train_micro_batch_size_per_gpu\": batch_size,\n        \"gradient_accumulation_steps\": 1,\n        \"gradient_clipping\": 1.0,\n        \"steps_per_print\": 1000,\n    }\n```\n\nYou pass this dictionary (or the path to a JSON config file) to `deepspeed.initialize`, with no complex code rewrite needed. For mixed precision, NVMe offload, or other advanced features, you add the corresponding keys to this config. You can read more in [DeepSpeed’s official Getting Started guide](https://www.deepspeed.ai/getting-started/), which includes example configs, performance tips, and other features like ZeRO stage 3 and offloading to NVMe for truly massive models.\n\n## DeepSpeed Training Function\n\nTraining with DeepSpeed looks very similar to a plain PyTorch loop. 
The two main steps are:\n\n1. **Initialize your model with DeepSpeed:** `deepspeed.initialize` wraps it into a model engine that handles distributed parallelism, memory optimizations (like ZeRO), optimizer state, and learning rate scheduling.\n2. **Use the DeepSpeed engine in the training loop:**\n   - `model_engine.backward(loss)` replaces the usual `loss.backward()`\n   - `model_engine.step()` replaces the usual `optimizer.step()`\n\nLet’s see what the training function looks like:\n\n``` python\nimport tempfile\n\nimport ray.train\nimport torch\nfrom torch.nn import CrossEntropyLoss\nfrom torch.utils.data import DataLoader\nfrom torchvision.datasets import FashionMNIST\nfrom torchvision.transforms import Compose, Normalize, ToTensor\n\ndef train_func(config):\n    \"\"\"DeepSpeed training function (same structure as the plain PyTorch loop).\"\"\"\n    import deepspeed\n\n    # Setup model and DeepSpeed engine\n    model = init_model()\n    ds_config = get_deepspeed_config(batch_size=config.get('batch_size', 64), lr=config.get('lr', 0.001))\n    model_engine, optimizer, _, _ = deepspeed.initialize(\n        model=model, config=ds_config, model_parameters=model.parameters()\n    )\n    device = model_engine.device\n\n    criterion = CrossEntropyLoss()\n\n    # Distributed sampler and dataloader (just like PyTorch DDP)\n    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])\n    train_data = FashionMNIST(root=tempfile.gettempdir(), train=True, download=True, transform=transform)\n    sampler = torch.utils.data.DistributedSampler(\n        train_data,\n        num_replicas=ray.train.get_context().get_world_size(),\n        rank=ray.train.get_context().get_world_rank(),\n        shuffle=True,\n    )\n    train_loader = DataLoader(train_data, batch_size=config.get('batch_size', 64), sampler=sampler)\n\n    # Training loop\n    for epoch in range(config.get('epochs', 1)):\n        sampler.set_epoch(epoch)\n        total_loss, num_batches = 0.0, 0\n\n        for images, labels in train_loader:\n            images, labels = images.to(device), labels.to(device)\n\n            # Forward/backward/step handled by DeepSpeed\n            outputs = model_engine(images)\n            loss = criterion(outputs, labels)\n         
   model_engine.backward(loss)\n            model_engine.step()\n\n            total_loss += loss.item()\n            num_batches += 1\n\n        avg_loss = total_loss / num_batches\n        print(f\"Epoch {epoch}: loss={avg_loss:.4f}\")\n```\n\n## Project: Fine-tuning Qwen3-TTS (Voice Cloning)\n\nNow that we’ve covered FSDP and distributed training concepts, let’s put everything together with a real-world application that’s genuinely exciting: **fine-tuning a 1.7B-parameter text-to-speech model to clone your own voice**.\n\nThis section walks through building a voice cloning system using [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) from Alibaba, applying the FSDP and Ray Train techniques we’ve learned.\n\n## What is Qwen3-TTS?\n\nQwen3-TTS is an open-source text-to-speech model with 1.7B parameters. Its architecture has four main components:\n\n- **Text encoder**: processes input text into hidden representations\n- **Speaker encoder**: extracts voice characteristics (x-vector embeddings)\n- **Audio decoder**: generates discrete audio codes (12Hz, 16 codebooks)\n- **Vocoder**: converts audio codes to waveforms\n\nThe model can perform **zero-shot voice cloning**, generating speech in any voice from just a reference audio sample. With fine-tuning, however, we can make it **much** better at matching a specific voice.\n\n## Qwen3-TTS Voice Cloning with Ray Distributed Training\n\nThis project demonstrates how to build a custom voice with Qwen3-TTS, using Ray Train for scalable, fault-tolerant distributed fine-tuning. The aim is to clone your unique voice by adapting the [Qwen3-TTS-12Hz-1.7B-Base](https://github.com/QwenLM/Qwen3-TTS) model using your audio recordings. 
In this case, I cloned my own voice :) All I did was download 3 to 4 hours of recordings of my own voice and transcribe them with Whisper.\n\n**Goal:** Fine-tune Qwen3-TTS with your own speech samples to create a personalized, high-quality voice clone.\n\n### Pipeline\n\nThe workflow follows these main stages:\n\n1. **Raw Audio Files**\n2. **Data Processing (Ray Data):** distributed segmentation and transcription of your audio into training samples\n3. **Audio Code Extraction:** conversion of the processed audio into the discrete codes the TTS model consumes\n4. **SFT Training (Ray Train):** distributed fine-tuning that adapts the model to your voice\n5. **Inference:** generation of custom speech from text inputs\n\n### Step 1: Data Processing\n\nTo begin, we transform your recorded audio files into usable training segments, distributing the workload across the cluster with `Ray`.\n\n``` python\nfrom pathlib import Path\n\nimport librosa\nimport ray\nimport soundfile as sf\nimport whisper\n\n@ray.remote(num_gpus=0.5)  # Use a GPU for Whisper (two tasks can share one GPU)\ndef process_audio_ray(audio_path: str, output_dir: str, config: dict):\n    \"\"\"Process a single audio file on a Ray worker.\"\"\"\n    # Whisper expects 16kHz audio, while the saved segments must be 24kHz for Qwen3-TTS,\n    # so load the file at both sample rates\n    audio_16k, _ = librosa.load(audio_path, sr=16000, mono=True)\n    audio_24k, _ = librosa.load(audio_path, sr=24000, mono=True)\n\n    # Transcribe with Whisper\n    model = whisper.load_model(\"base\")\n    result = model.transcribe(audio_16k, language=\"en\", word_timestamps=True)\n\n    # Segment based on Whisper's detected segments (start/end are in seconds)\n    segments = []\n    for seg in result[\"segments\"]:\n        if 1.0 < (seg[\"end\"] - seg[\"start\"]) < 15.0:  # Keep 1-15 second segments\n            segments.append({\n                \"audio\": audio_24k[int(seg[\"start\"]*24000):int(seg[\"end\"]*24000)],\n                \"text\": seg[\"text\"].strip()\n            })\n\n    # Save segments as individual WAV files\n    Path(output_dir).mkdir(parents=True, exist_ok=True)\n    results = []\n    for i, seg in enumerate(segments):\n        seg_path = f\"{output_dir}/{Path(audio_path).stem}_seg{i:04d}.wav\"\n        
sf.write(seg_path, seg[\"audio\"], 24000)  # Qwen3-TTS expects 24kHz\n        results.append({\"audio\": seg_path, \"text\": seg[\"text\"]})\n\n    return results\n```\n\nWe can now process your recorded **WAV** files in parallel with `Ray` by passing each file to the `process_audio_ray` remote function, then write the collected segments to a JSONL file:\n\n``` python\nimport json\n# Process all audio files in parallel (one Ray task per file)\naudio_files = list(Path(\"data/\").glob(\"*.wav\"))\nfutures = [process_audio_ray.remote(str(f), \"output/wav\", {}) for f in audio_files]\nall_segments = ray.get(futures)  # one list of segment dicts per input file\nPath(\"output\").mkdir(exist_ok=True)\nwith open(\"output/train.jsonl\", \"w\") as f:  # illustrative output file name\n    f.writelines(json.dumps(seg) + \"\\n\" for segs in all_segments for seg in segs)\n```\n\nEach line of the JSONL file contains an `audio` path and its `text` transcript:\n\n```\n{\"audio\": \"output/wav/recording_seg0001.wav\", \"text\": \"Hello, this is my voice.\"}\n{\"audio\": \"output/wav/recording_seg0002.wav\", \"text\": \"I'm recording samples for training.\"}\n```\n\n### Step 2: Extract Audio Codes\n\nQwen3-TTS doesn’t work with raw audio waveforms. Instead, it uses discrete audio codes, a compressed representation that captures the essential acoustic information:\n\n``` python\nimport torch\nfrom qwen_tts import Qwen3TTSModel\n\n# Load the tokenizer model\ntokenizer_model = Qwen3TTSModel.from_pretrained(\n    \"Qwen/Qwen3-TTS-Tokenizer-12Hz\",\n    device_map=\"cuda:0\",\n    dtype=torch.bfloat16,\n)\n\ndef extract_audio_codes(audio_path: str) -> list:\n    \"\"\"Convert audio waveform to discrete codes.\"\"\"\n    import librosa\n\n    # Load audio at 24kHz\n    audio, sr = librosa.load(audio_path, sr=24000, mono=True)\n\n    # Extract codes: [time_steps, 16 codebooks]\n    with torch.no_grad():\n        codes = tokenizer_model.encode_audio(audio, sr=24000)\n\n    return codes.tolist()\n```\n\nFor example, a 10-second clip gives 10 s × 12 Hz = 120 time steps, and 120 time steps × 16 codebooks = **1,920 tokens**.\n\n### Step 3: The Training Function\n\nHere’s the core training function that runs on each Ray worker. 
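\n\nBefore looking at it, note that it reads its samples through a `TTSDataset` class that is not shown in this post. The version sketched here is an assumption, not the author’s implementation. It expects each JSONL line to hold a `text` string and an `audio_codes` list of shape `[time, 16]`, which are the two fields the training loop reads. It also adds a hypothetical `crop_collate` helper, because clips of different lengths cannot be stacked into a single batch tensor as-is:\n\n``` python\nimport json\n\nimport torch\nfrom torch.utils.data import Dataset\n\n\nclass TTSDataset(Dataset):\n    \"\"\"Hypothetical sketch: one sample per JSONL line with \"text\" and \"audio_codes\" ([time, 16]).\"\"\"\n\n    def __init__(self, jsonl_path: str, max_steps: int = 100):\n        with open(jsonl_path) as f:\n            self.rows = [json.loads(line) for line in f if line.strip()]\n        self.max_steps = max_steps  # the training loop reads at most 100 time steps per clip\n\n    def __len__(self):\n        return len(self.rows)\n\n    def __getitem__(self, idx):\n        row = self.rows[idx]\n        codes = torch.tensor(row[\"audio_codes\"], dtype=torch.long)[: self.max_steps]\n        return {\"text\": row[\"text\"], \"audio_codes\": codes}\n\n\ndef crop_collate(batch):\n    \"\"\"Crop every clip to the shortest one in the batch so the codes stack to [batch, time, 16].\"\"\"\n    min_steps = min(item[\"audio_codes\"].shape[0] for item in batch)\n    return {\n        \"text\": [item[\"text\"] for item in batch],\n        \"audio_codes\": torch.stack([item[\"audio_codes\"][:min_steps] for item in batch]),\n    }\n```\n\nWith this sketch, you would pass `collate_fn=crop_collate` to the `DataLoader` inside the training function. Cropping to the shortest clip in each batch throws away some frames. When that matters, padding plus a loss mask is the usual alternative.\n\nNow for the training function itself. 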
This is where all our FSDP knowledge comes together:\n\n``` python\nimport ray.train.torch\nfrom ray import train as ray_train\n\ndef train_func(config: dict):\n    \"\"\"Qwen3-TTS fine-tuning with speaker embedding conditioning.\"\"\"\n    import torch\n    from qwen_tts import Qwen3TTSModel\n    from torch.utils.data import DataLoader, DistributedSampler\n\n    # Setup distributed context\n    rank = ray_train.get_context().get_world_rank()\n    world_size = ray_train.get_context().get_world_size()\n    local_rank = ray_train.get_context().get_local_rank()\n\n    device = torch.device(f\"cuda:{local_rank}\")\n    torch.cuda.set_device(device)\n\n    print(f\"Worker {rank}/{world_size} starting on {device}\")\n\n    # Load pre-trained model\n    wrapper = Qwen3TTSModel.from_pretrained(\n        \"Qwen/Qwen3-TTS-12Hz-1.7B-Base\",\n        device_map=f\"cuda:{local_rank}\",\n        dtype=torch.bfloat16,\n    )\n    model = wrapper.model\n    talker = model.talker\n\n    # Freeze most parameters, only train the talker (audio generation)\n    for param in model.parameters():\n        param.requires_grad = False\n    for param in talker.parameters():\n        param.requires_grad = True\n\n    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n    print(f\"Trainable parameters: {trainable:,}\")\n\n    # Extract speaker embedding from reference audio (our voice signature)\n    import librosa\n    ref_audio, sr = librosa.load(config[\"ref_audio\"], sr=24000, mono=True)\n    with torch.no_grad():\n        speaker_embedding = model.extract_speaker_embedding(ref_audio, sr=24000)\n        speaker_embedding = speaker_embedding.to(device).to(torch.bfloat16)\n\n    # Setup data loading with DistributedSampler\n    dataset = TTSDataset(config[\"train_jsonl\"])\n    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)\n    dataloader = DataLoader(dataset, batch_size=config[\"batch_size\"], sampler=sampler)\n\n    # Optimizer with 
weight decay (no learning-rate scheduler is created in this snippet)\n    optimizer = torch.optim.AdamW(\n        [p for p in model.parameters() if p.requires_grad],\n        lr=config[\"learning_rate\"],\n        weight_decay=0.01,\n    )\n\n    # Training loop\n    for epoch in range(config[\"num_epochs\"]):\n        sampler.set_epoch(epoch)\n        epoch_loss = 0.0\n\n        for batch_idx, batch in enumerate(dataloader):\n            # Tokenize text\n            text_inputs = wrapper.processor.tokenizer(\n                batch[\"text\"], padding=True, return_tensors=\"pt\"\n            ).to(device)\n\n            # Get audio codes [batch, time, 16]\n            audio_codes = torch.tensor(batch[\"audio_codes\"]).to(device)\n\n            # Get text embeddings and add speaker conditioning\n            with torch.no_grad():\n                text_embeds = talker.get_text_embeddings()(text_inputs[\"input_ids\"])\n                text_embeds = talker.text_projection(text_embeds)\n\n            # Forward pass with speaker-conditioned hidden states\n            loss = torch.tensor(0.0, device=device)\n            for t in range(min(audio_codes.shape[1], 100)):\n                codec_ids = audio_codes[:, t, :]\n\n                # Condition on speaker embedding\n                text_hidden = text_embeds[:, min(t, text_embeds.shape[1]-1), :]\n                talker_hidden = text_hidden + 0.1 * speaker_embedding\n\n                # Compute loss on audio code predictions\n                _, step_loss = talker.forward_sub_talker_finetune(\n                    codec_ids=codec_ids,\n                    talker_hidden_states=talker_hidden.to(torch.bfloat16)\n                )\n                if step_loss is not None:\n                    loss = loss + step_loss\n\n            # Backward pass\n            loss = loss / config[\"gradient_accumulation_steps\"]\n            loss.backward()\n\n            if (batch_idx + 1) % config[\"gradient_accumulation_steps\"] == 0:\n                # Average gradients across workers so every replica applies the same update\n                if world_size > 1:\n                    for p in model.parameters():\n                        if p.requires_grad:\n                            if p.grad is None:\n                                p.grad = torch.zeros_like(p)\n                            torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.AVG)\n                torch.nn.utils.clip_grad_norm_(\n            
        [p for p in model.parameters() if p.requires_grad], 1.0\n                )\n                optimizer.step()\n                optimizer.zero_grad()\n\n            epoch_loss += loss.item()\n\n        # Report metrics to Ray Train\n        avg_loss = epoch_loss / len(dataloader)\n        ray_train.report({\"loss\": avg_loss, \"epoch\": epoch})\n\n        if rank == 0:\n            print(f\"Epoch {epoch}: loss={avg_loss:.4f}\")\n```\n\nNote: Speaker Embedding Conditioning\n\nThe key to voice cloning is the **speaker embedding**. We extract an x-vector from our reference audio that captures our voice’s unique characteristics (pitch, timbre, speaking style). During training, we add this embedding to the text hidden states, teaching the model to generate audio codes that sound like us.\n\n### Step 4: Launch Distributed Training\n\nNow we launch training across multiple GPUs with Ray Train, as in the previous examples:\n\n``` python\nfrom ray.train.torch import TorchTrainer\nfrom ray.train import ScalingConfig, RunConfig\n\n# Training configuration\ntrain_config = {\n    \"train_jsonl\": \"output/train_with_codes.jsonl\",\n    \"ref_audio\": \"output/wav/reference.wav\",\n    \"batch_size\": 2,\n    \"learning_rate\": 1e-5,\n    \"num_epochs\": 10,\n    \"gradient_accumulation_steps\": 4,\n}\n\n# Scale across 4 GPUs\nscaling_config = ScalingConfig(\n    num_workers=4,\n    use_gpu=True,\n    resources_per_worker={\"CPU\": 4, \"GPU\": 1}\n)\n\nrun_config = RunConfig(\n    name=\"qwen_tts_voice_clone\",\n    storage_path=\"/mnt/cluster_storage/\",\n)\n\n# Launch training\ntrainer = TorchTrainer(\n    train_func,\n    train_loop_config=train_config,\n    scaling_config=scaling_config,\n    run_config=run_config,\n)\n\nprint(\"Starting voice cloning training...\")\nresult = trainer.fit()\nprint(f\"Training complete! 
Checkpoint: {result.checkpoint}\")\n```\n\nExpected training output:\n\n```\nWorker 0/4 starting on cuda:0\nWorker 1/4 starting on cuda:1\nWorker 2/4 starting on cuda:2\nWorker 3/4 starting on cuda:3\nTrainable parameters: 847,234,560\nEpoch 0: loss=2.4521\nEpoch 1: loss=1.8734\nEpoch 2: loss=1.5289\n...\nEpoch 9: loss=0.8142\nTraining complete!\n```\n\n### Step 5: Generate Speech with Your Voice\n\nAfter training, we can generate speech in our cloned voice. The code below assumes the fine-tuned weights were saved to `final_model/model.pt` as a dictionary with a `model_state_dict` key. The training function above reports metrics but does not show this saving step:\n\n``` python\nimport soundfile as sf\nimport torch\nfrom qwen_tts import Qwen3TTSModel\n\n# Load base model\nwrapper = Qwen3TTSModel.from_pretrained(\n    \"Qwen/Qwen3-TTS-12Hz-1.7B-Base\",\n    device_map=\"cuda:0\",\n    dtype=torch.bfloat16,\n)\n\n# Load fine-tuned weights\ncheckpoint = torch.load(\"final_model/model.pt\", map_location=\"cuda:0\")\nwrapper.model.load_state_dict(checkpoint[\"model_state_dict\"], strict=False)\nwrapper.model.eval()\n\n# Generate speech\ntext = \"Hello! This is my cloned voice speaking. Pretty cool, right?\"\n\nwith torch.no_grad():\n    wavs, sr = wrapper.generate_voice_clone(\n        text=text,\n        language=\"english\",\n        ref_audio=(\"reference.wav\", 24000),\n        x_vector_only_mode=True,\n    )\n\n# Save output\nsf.write(\"my_voice_output.wav\", wavs[0].cpu().numpy(), sr)\nprint(\"Generated speech saved to my_voice_output.wav\")\n```\n\nHere is one of the samples generated by the fine-tuned model (first 10 seconds of the audio):\n\n🔊 __This is my voice generated by the fine-tuned model__\n\n## Conclusion\n\nIn this post, we took a deep dive into Fully Sharded Data Parallel (FSDP) and explored how it can be leveraged, together with Ray Train, to address the challenges of large-scale deep learning. We started by examining why traditional, sequential training approaches fail to fully utilize available GPU resources and why they quickly become infeasible as model sizes grow. 
Through hands-on sections, we learned how FSDP partitions the model both vertically (into units of layers) and horizontally (sharding each unit across GPUs). This enables efficient training of massive models through careful sharding and communication.\n\nAlong the way, we broke down what actually happens during one full FSDP training iteration. Each unit’s parameters are all-gathered just before they are needed and resharded right after, in both the forward and the backward pass. Gradients are reduce-scattered so that each GPU keeps only its own shard, and each GPU’s optimizer updates only that shard. Building on these foundations, we put theory into practice: first by training a Vision Transformer with production-quality distributed code, then by scaling up to a real-world application, cloning a voice by fine-tuning a 1.7-billion-parameter text-to-speech model.\n\nFSDP makes a crucial tradeoff: it reduces memory usage by sharding parameters, gradients, and optimizer states, at the cost of more communication between devices. The communication volume is roughly 1.5× that of DDP, because the parameters are all-gathered again during the backward pass. Thanks to techniques like overlapping computation and communication, this overhead is manageable, allowing us to train much larger models than would otherwise fit in GPU memory.\n\nDistributed training frameworks like FSDP and Ray Train unlock capabilities that were, until recently, reserved for only the largest research labs. The fine-tuned voice cloning model we built demonstrates the practical power of training large models at scale. 
Although the model was not all that large by today’s standards, it was a good starting point to understand the basics of distributed training and FSDP.", "url": "https://wpnews.pro/news/inside-fsdp-with-pytorch-and-ray-scaling-model-training-with-fully-sharded-data", "canonical_source": "https://anyscale.com/blog/fsdp-pytorch-deepspeed-ray-large-scale-distributed-training", "published_at": "2026-06-12 00:00:00+00:00", "updated_at": "2026-06-16 06:54:10.599121+00:00", "lang": "en", "topics": ["large-language-models", "ai-infrastructure", "ai-tools", "machine-learning", "ai-research"], "entities": ["PyTorch", "Ray", "DeepSpeed", "Alibaba", "Qwen3-TTS", "Hugging Face", "FashionMNIST", "Suman Debnath"], "alternates": {"html": "https://wpnews.pro/news/inside-fsdp-with-pytorch-and-ray-scaling-model-training-with-fully-sharded-data", "markdown": 
"https://wpnews.pro/news/inside-fsdp-with-pytorch-and-ray-scaling-model-training-with-fully-sharded-data.md", "text": "https://wpnews.pro/news/inside-fsdp-with-pytorch-and-ray-scaling-model-training-with-fully-sharded-data.txt", "jsonld": "https://wpnews.pro/news/inside-fsdp-with-pytorch-and-ray-scaling-model-training-with-fully-sharded-data.jsonld"}}