You've got a large-model demo running on your laptop, and it works well. But when you try to turn it into a real product, problems arise:

- What if the model has too many parameters to fit on a single GPU?
- Training takes weeks. How do you resume from a checkpoint if something fails midway?
- With millions of user requests a day, how do you keep latency under 1 second?
- How do you collect user feedback to keep the model evolving?

These are the problems that AI system architecture solves.

Demos focus on whether a model can run at all; production systems focus on whether it can run stably, efficiently, and cost-effectively.

Characteristics of production-grade AI systems: 24/7 availability, support for millions of concurrent users, observability, scalability, reliability with disaster recovery, and cost control.

Unique Challenges of AI Systems

Compared to traditional web services, AI systems face three unique challenges.

Challenge 1: Non-deterministic Output

Traditional systems produce deterministic output: input 1+1 and you always get 2.

AI systems produce probabilistic output: the same prompt may generate a different result each time.

This raises several questions: How do you ensure output quality? How do you evaluate performance? How do you handle hallucinations?

Typical solutions include: tuning the sampling strategy at the output layer, post-processing and filtering results, human feedback loops, and multi-model voting.
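
Multi-model voting, in its simplest form, asks several models (or one model several times) and keeps the answer most of them agree on. The sketch below is an illustration under assumptions: answers are compared after simple normalization, and the three "models" are hypothetical lambdas standing in for real inference calls.

```python
from collections import Counter
from typing import Callable, List, Optional


def majority_vote(candidates: List[str], min_agreement: float = 0.5) -> Optional[str]:
    """Return the most common (normalized) answer if its vote share exceeds
    min_agreement; otherwise return None so the caller can fall back,
    e.g. to human review."""
    if not candidates:
        return None
    normalized = [c.strip().lower() for c in candidates]
    answer, count = Counter(normalized).most_common(1)[0]
    return answer if count / len(normalized) > min_agreement else None


# Hypothetical stand-ins for three different models answering the same prompt
models: List[Callable[[str], str]] = [
    lambda prompt: "Paris",
    lambda prompt: "paris ",
    lambda prompt: "Lyon",
]
answers = [model("What is the capital of France?") for model in models]
print(majority_vote(answers))  # paris (2 of 3 votes)
```

Requiring a strict majority means a three-way split returns None instead of an arbitrary pick, which is a natural hand-off point for a human feedback loop.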

Challenge 2: The Latency-Cost Tradeoff

AI inference requires massive computation, so there is a natural tension between latency and cost.

Want it fast? Use more GPUs, and costs skyrocket.

Want to save money? Queue requests and process them in batches, and the user experience suffers.

The core of a production system is finding the balance point between the SLA (Service Level Agreement) and cost.

| Optimization Direction | Common Techniques | Typical Effect |
|---|---|---|
| Model Compression | Quantization, Pruning, Distillation | 2-4x speedup, slight accuracy drop |
| Inference Optimization | vLLM, TensorRT, FlashAttention | 3-10x throughput improvement |
| Architecture Design | Batch processing, multi-level caching | 50%-80% reduction in per-request cost |

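To make the Model Compression row concrete, here is a minimal sketch of symmetric per-tensor INT8 quantization, one of the simplest quantization schemes (production toolkits typically use finer-grained per-channel or per-group scales). Each weight becomes an 8-bit integer plus one shared floating-point scale, so FP32 weights shrink to a quarter of their size.

```python
from typing import List, Tuple


def quantize_int8(weights: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor INT8 quantization: w ~= scale * q with q in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # Round to the nearest integer step and clamp to the INT8 range
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize_int8(q: List[int], scale: float) -> List[float]:
    """Recover approximate floating-point weights from the integers and the shared scale."""
    return [scale * v for v in q]


weights = [0.12, -0.50, 0.33, 0.01, -0.27]
q, scale = quantize_int8(weights)
restored = dequantize_int8(q, scale)
max_err = max(abs(w - r) for w, r in zip(weights, restored))

print(q)  # [30, -127, 84, 3, -69]
# Rounding error is at most half a quantization step (scale / 2)
print(f"scale={scale:.5f}, max error={max_err:.5f}")
# Storage: 4 bytes per FP32 weight vs. 1 byte per INT8 weight (plus one scale per tensor)
```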
Challenge 3: Building the Data Flywheel

AI systems aren't "done once deployed"; they require continuous iteration.

The more people use the product, the more feedback data it generates; that data trains a better model, which in turn attracts more users. This is the data flywheel.

But getting the flywheel spinning isn't easy: How do you collect useful feedback? How do you label data? How do you train continuously? How do you evaluate new versions?

There's no standard answer to these questions, but every successful AI product has its own flywheel design.

Large-Scale Training Infrastructure

Training models with hundreds of billions or even trillions of parameters requires supercomputer-scale infrastructure.

GPU Cluster Architecture

Modern AI training clusters typically consist of hundreds or thousands of GPUs.

Take the H100 GPU as an example: a single H100 (SXM) has 80GB of memory and delivers about 1,979 TFLOPS of dense FP8 compute.

But a single GPU is far from enough: by a widely cited estimate, training GPT-3 on a single V100 would take about 355 years, so it was trained on a large GPU cluster instead.
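
A back-of-envelope calculation shows why one GPU is not enough for a model of GPT-3's size. It assumes the common approximation of about 6 FLOPs per parameter per training token, GPT-3's published scale (175B parameters, roughly 300B training tokens), and roughly 28 TFLOPS sustained on one V100. All three are approximations, not measurements.

```python
SECONDS_PER_YEAR = 365 * 24 * 3600


def training_flops(num_params: float, num_tokens: float) -> float:
    """Approximate training compute: ~6 FLOPs per parameter per token
    (about 2 for the forward pass and 4 for the backward pass)."""
    return 6 * num_params * num_tokens


def gpu_years(total_flops: float, sustained_flops_per_gpu: float) -> float:
    """Wall-clock years on a single GPU at a given sustained throughput."""
    return total_flops / sustained_flops_per_gpu / SECONDS_PER_YEAR


flops = training_flops(175e9, 300e9)  # GPT-3 scale: 175B params, ~300B tokens
years = gpu_years(flops, 28e12)       # assumption: ~28 TFLOPS sustained on one V100
print(f"{flops:.2e} FLOPs -> about {years:.0f} V100-years")

# Dividing by the target duration gives the GPU count needed (perfect scaling assumed)
print(f"GPUs needed to finish in 3 months: about {years / 0.25:.0f}")
```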

A typical cluster topology is:

| Level | Devices | Connection | Bandwidth |
|---|---|---|---|
| Within a single machine | GPU-GPU | NVLink | 900 GB/s per GPU (H100, total bidirectional) |
| Same rack | Server-Server | InfiniBand | 400 Gb/s per link (about 50 GB/s) |
| Cross-rack | Switch-Switch | InfiniBand fabric | 400 Gb/s per link (about 50 GB/s) |

The network is often the bottleneck in distributed training. If communication bandwidth is insufficient, GPU utilization may drop from 90% to 30%, with most of the time spent waiting for data.

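A bandwidth-only estimate of ring all-reduce (a collective commonly used to average gradients) shows why link speed matters: each GPU sends and receives about 2(N-1)/N times the gradient size. The 7B-parameter model and the per-direction bandwidths below are illustrative assumptions, and the estimate ignores latency and overlap with computation.

```python
def ring_allreduce_seconds(num_bytes: float, num_gpus: int, bytes_per_second: float) -> float:
    """Bandwidth-only ring all-reduce time: each GPU sends (and receives)
    2 * (N - 1) / N times the buffer. Ignores latency and compute overlap."""
    return 2 * (num_gpus - 1) / num_gpus * num_bytes / bytes_per_second


grad_bytes = 7e9 * 2  # hypothetical 7B-parameter model with FP16 gradients (2 bytes each)
links = {
    "NVLink (~450 GB/s per direction)": 450e9,
    "InfiniBand 400 Gb/s (= 50 GB/s)": 50e9,
}
for name, bandwidth in links.items():
    t = ring_allreduce_seconds(grad_bytes, num_gpus=8, bytes_per_second=bandwidth)
    print(f"{name}: {t:.3f} s per gradient all-reduce")
```

Because this cost is paid on every training step, a gap of roughly 10x between links translates directly into idle GPU time unless communication is overlapped with computation.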
InfiniBand High-Speed Interconnect
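
Standard Ethernet is generally too slow for this, so large-scale distributed training typically uses InfiniBand.

InfiniBand offers extremely low latency (on the order of microseconds), extremely high bandwidth, and support for RDMA (Remote Direct Memory Access).

RDMA lets one server read and write another server's memory directly, bypassing the OS kernel and the remote CPU. With GPUDirect RDMA, the network card can also access GPU memory directly, so data moves between GPUs on different servers without being copied through host memory, which is much faster.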

Storage System Design

Training data is typically TB or even PB scale, so the storage system is also critical.

A typical tiered storage design:

- Hot data: local SSD or NVMe, holding the data currently being trained on
- Warm data: distributed storage (e.g., Ceph, Lustre), holding the complete training set
- Cold data: object storage (e.g., S3), holding historical data and backups

Fault Tolerance and Checkpoints

Training takes weeks. What if a GPU fails along the way? Starting over would be far too wasteful.

The solution is checkpointing: periodically save the model state (weights, optimizer states, and training progress) to disk, and if an error occurs, resume from the most recent checkpoint.

But checkpoints also have costs: saving one may take several minutes and occupy dozens of GB of space.

Typical strategy: save every few hundred steps, keep only the most recent few checkpoints, and automatically clean up older ones.

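The strategy of saving every N steps, keeping the newest K, and resuming from the latest can be sketched with the standard library alone. This is a framework-agnostic illustration with a hypothetical file layout (`ckpt-<step>.pkl`) and a toy one-number "model"; real training code would save model and optimizer state through its framework, for example `torch.save` or DeepSpeed's `save_checkpoint()`.

```python
import os
import pickle
import re
import tempfile
from typing import Any, Dict, Optional

# Hypothetical naming scheme; zero-padding makes lexical order match step order
CKPT_RE = re.compile(r"ckpt-(\d{8})\.pkl$")


def save_checkpoint(ckpt_dir: str, step: int, state: Dict[str, Any], keep_last: int = 3) -> None:
    """Write a checkpoint atomically, then delete all but the newest `keep_last`."""
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, f"ckpt-{step:08d}.pkl")
    with open(path + ".tmp", "wb") as f:
        pickle.dump({**state, "step": step}, f)
    # Atomic rename: a crash mid-write never leaves a half-written checkpoint behind
    os.replace(path + ".tmp", path)
    names = sorted(n for n in os.listdir(ckpt_dir) if CKPT_RE.match(n))
    for old in names[:-keep_last]:
        os.remove(os.path.join(ckpt_dir, old))


def load_latest_checkpoint(ckpt_dir: str) -> Optional[Dict[str, Any]]:
    """Return the newest checkpoint's state, or None when starting fresh."""
    if not os.path.isdir(ckpt_dir):
        return None
    names = sorted(n for n in os.listdir(ckpt_dir) if CKPT_RE.match(n))
    if not names:
        return None
    with open(os.path.join(ckpt_dir, names[-1]), "rb") as f:
        return pickle.load(f)


def train(ckpt_dir: str, total_steps: int, save_every: int = 200) -> Dict[str, Any]:
    """Toy training loop that resumes from the latest checkpoint if one exists."""
    state = load_latest_checkpoint(ckpt_dir) or {"step": 0, "weight": 0.0}
    for step in range(state["step"] + 1, total_steps + 1):
        state["weight"] += 0.001  # stand-in for one optimizer update
        state["step"] = step
        if step % save_every == 0:
            save_checkpoint(ckpt_dir, step, state)
    return state


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        train(d, total_steps=1000)            # checkpoints at 200, 400, ..., 1000
        print(sorted(os.listdir(d)))          # only steps 600, 800, 1000 remain
        resumed = train(d, total_steps=1200)  # simulated restart: resumes at step 1001
        print(resumed["step"], round(resumed["weight"], 6))  # 1200 1.2
```

Writing to a temporary file and renaming it means a failure during the save itself cannot corrupt the latest good checkpoint.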
Distributed Training Strategies

A single GPU can't hold a large model, so the training job has to be split across multiple GPUs.

There are three main parallel strategies: data parallelism, tensor parallelism, and pipeline parallelism. Combined, they form "3D parallelism".

Data Parallelism

The simplest and most commonly used strategy: each GPU holds a complete copy of the model but processes different data.

For example, with 8 GPUs and a global batch size of 1024, each GPU processes 128 samples.

Each GPU runs the forward and backward passes independently. The gradients are then averaged across all GPUs (an all-reduce operation), and every GPU applies the same update, so all model copies stay identical.

The problem with data parallelism is that memory is still the bottleneck: if the model is too large for a single GPU, data parallelism doesn't help.

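The claim that averaging per-GPU gradients gives the same update as one large batch can be checked on a toy problem. The sketch below simulates the 8-GPU, 1024-sample example with a one-parameter linear model and mean-squared-error loss; the "GPUs" are list slices and the all-reduce is a plain average.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def mse_grad(w: float, batch: Batch) -> float:
    """Gradient of mean((w * x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def data_parallel_step(w: float, batch: Batch, num_gpus: int, lr: float) -> float:
    """One data-parallel SGD step: shard the batch, compute local gradients, average, update."""
    shard_size = len(batch) // num_gpus  # 1024 / 8 = 128 samples per GPU
    shards = [batch[g * shard_size:(g + 1) * shard_size] for g in range(num_gpus)]
    local_grads = [mse_grad(w, shard) for shard in shards]  # computed independently per GPU
    avg_grad = sum(local_grads) / num_gpus                  # the all-reduce + average
    return w - lr * avg_grad                                # same update on every replica


# Toy data where the true weight is 3.0
batch = [(float(i % 7), 3.0 * (i % 7)) for i in range(1024)]
w_parallel = data_parallel_step(0.0, batch, num_gpus=8, lr=0.01)
w_single = 0.0 - 0.01 * mse_grad(0.0, batch)
print(abs(w_parallel - w_single) < 1e-9)  # True: equal-sized shards reproduce the full-batch update
```

The equality relies on equal-sized shards; with uneven shards, each local gradient would have to be weighted by its shard size before averaging.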
\nZeRO Memory Optimization

ZeRO (Zero Redundancy Optimizer) is an enhanced form of data parallelism that saves considerably more memory.

In ordinary data parallelism, every GPU stores the complete model parameters, gradients, and optimizer states, which is highly redundant.

ZeRO's idea is to partition these states across GPUs and communicate only when a piece is needed.

| ZeRO Stage | Partitioned States | Memory Savings |
|---|---|---|
| ZeRO-1 | Optimizer states | ~4x |
| ZeRO-2 | Optimizer states + Gradients | ~8x |
| ZeRO-3 | Optimizer states + Gradients + Parameters | Linear in the number of GPUs |

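The savings in the table follow from the memory accounting in the ZeRO paper: with mixed-precision Adam, each parameter costs 2 bytes (FP16 weights) + 2 bytes (FP16 gradients) + 12 bytes of optimizer state (FP32 master weights, momentum, and variance). The sketch below reproduces that accounting for model states only (activations are not counted), using the paper's illustrative setting of 7.5B parameters on 64 GPUs.

```python
def zero_memory_gb(num_params: float, num_gpus: int, stage: int, k: int = 12) -> float:
    """Per-GPU memory for model states in GB under mixed-precision Adam.
    stage 0 = plain data parallelism; stages 1-3 partition progressively more states."""
    params = 2 * num_params   # FP16 weights
    grads = 2 * num_params    # FP16 gradients
    optim = k * num_params    # FP32 master weights + momentum + variance
    if stage >= 1:
        optim /= num_gpus     # ZeRO-1: partition optimizer states
    if stage >= 2:
        grads /= num_gpus     # ZeRO-2: also partition gradients
    if stage >= 3:
        params /= num_gpus    # ZeRO-3: also partition parameters
    return (params + grads + optim) / 1e9


for stage in range(4):
    print(f"stage {stage}: {zero_memory_gb(7.5e9, num_gpus=64, stage=stage):.1f} GB per GPU")
# stage 0: 120.0 GB per GPU
# stage 1: 31.4 GB per GPU
# stage 2: 16.6 GB per GPU
# stage 3: 1.9 GB per GPU
```

As the GPU count grows, stage 1 approaches 4 bytes per parameter (16 / 4 = 4x savings) and stage 2 approaches 2 bytes (8x), while stage 3 keeps dividing everything by the GPU count.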
Configuring ZeRO with DeepSpeed is simple:

```json
{
  "train_batch_size": 1024,
  "train_micro_batch_size_per_gpu": 16,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0001,
      "betas": [0.9, 0.95],
      "eps": 1e-8,
      "weight_decay": 0.01
    }
  },
  "zero_optimization": {
    "stage": 3,
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 2e8,
    "contiguous_gradients": true,
    "stage3_prefetch_bucket_size": 1e8,
    "stage3_param_persistence_threshold": 1e5,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9
  },
  "gradient_clipping": 1.0,
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "checkpoint": {
    "tag_validation": "Warn",
    "load_universal": false
  }
}
```

This configuration uses ZeRO-3, which partitions parameters, gradients, and optimizer states across all data-parallel GPUs, so per-GPU memory for model states decreases roughly linearly with the number of GPUs. `"loss_scale": 0` enables dynamic loss scaling for FP16. Checkpoint tags are not a config key; they are passed to the engine's `save_checkpoint()` and `load_checkpoint()` calls, and `load_universal` should only be enabled when resuming from a checkpoint converted to DeepSpeed's universal format.
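
DeepSpeed requires `train_batch_size` to equal `train_micro_batch_size_per_gpu` × gradient accumulation steps × the number of data-parallel GPUs, so with a global batch of 1024 and a micro-batch of 16, the accumulation steps depend on cluster size. The helper below is hypothetical (not part of DeepSpeed) and only makes that arithmetic explicit.

```python
def grad_accum_steps(train_batch_size: int, micro_batch_per_gpu: int, dp_world_size: int) -> int:
    """Solve train_batch_size = micro_batch_per_gpu * grad_accum_steps * dp_world_size."""
    samples_per_micro_step = micro_batch_per_gpu * dp_world_size
    if train_batch_size % samples_per_micro_step != 0:
        raise ValueError(
            f"train_batch_size={train_batch_size} is not divisible by "
            f"{micro_batch_per_gpu} x {dp_world_size}"
        )
    return train_batch_size // samples_per_micro_step


for gpus in (8, 16, 64):
    print(f"{gpus} GPUs -> accumulation steps = {grad_accum_steps(1024, 16, gpus)}")
# 8 GPUs -> accumulation steps = 8
# 16 GPUs -> accumulation steps = 4
# 64 GPUs -> accumulation steps = 1
```

With 128 GPUs the numbers no longer fit (16 × 128 = 2048 > 1024), so either the global batch size or the micro-batch size has to change.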

Tensor Parallelism

If ZeRO is still not enough, tensor parallelism is needed: it splits the computation of a single layer across multiple GPUs.

Matrix multiplications in Transformers can be split by rows or by columns. For example, to compute A × B on two GPUs:

- Split matrix A by rows into A₁ and A₂, then compute A₁ × B on GPU 0 and A₂ × B on GPU 1
- Concatenate the two partial results row-wise to obtain A × B

This requires communication at every layer, but each GPU only has to store and process half of A.
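
The row split can be verified with plain Python lists: each simulated GPU multiplies its block of rows by B, and stacking the blocks reproduces the full product. On real hardware the concatenation is a collective operation (an all-gather) across GPUs; here it is just list concatenation.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix multiplication of a (n x k) by b (k x m)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def row_parallel_matmul(a: Matrix, b: Matrix, num_gpus: int = 2) -> Matrix:
    """Split A by rows across simulated GPUs, multiply each block by B, concatenate."""
    rows = len(a) // num_gpus  # assumes the rows divide evenly
    blocks = [a[g * rows:(g + 1) * rows] for g in range(num_gpus)]  # A1, A2, ...
    partials = [matmul(block, b) for block in blocks]                # GPU g computes Ag x B
    return [row for part in partials for row in part]                # concatenation step


A = [[1, 2], [3, 4], [5, 6], [7, 8]]
B = [[1, 0, 2], [0, 1, 3]]
print(row_parallel_matmul(A, B) == matmul(A, B))  # True
```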

Megatron-LM is NVIDIA's PyTorch-based framework for training large Transformer models; it implements tensor parallelism (as well as pipeline parallelism).

Pipeline Parallelism

Tensor parallelism is "intra-layer splitting"; pipeline parallelism is "inter-layer splitting".

For example, with a 32-layer model, GPU 0 holds layers 1-8, GPU 1 holds layers 9-16, GPU 2 holds layers 17-24, and GPU 3 holds layers 25-32.

Data flows from GPU 0 to GPU 3, like a factory assembly line.

But a pipeline has a problem: bubbles. When GPU 0 starts computing, GPUs 1-3 sit idle; when the data reaches GPU 1, GPU 0 is idle again.

The solution is to split each batch into "micro-batches" and feed them in one after another, so different GPUs work on different micro-batches at the same time and bubble time shrinks.
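
How much micro-batching helps can be quantified. For a GPipe-style schedule with p stages and m micro-batches, and assuming every stage takes the same time, the fraction of time a GPU sits idle is (p - 1) / (m + p - 1):

```python
def bubble_fraction(num_stages: int, num_micro_batches: int) -> float:
    """Idle fraction of a GPipe-style pipeline schedule, assuming equal stage times."""
    return (num_stages - 1) / (num_micro_batches + num_stages - 1)


for m in (1, 4, 16, 64):
    print(f"4 stages, {m:2d} micro-batches: {bubble_fraction(4, m):.0%} idle")
# 4 stages,  1 micro-batches: 75% idle
# 4 stages,  4 micro-batches: 43% idle
# 4 stages, 16 micro-batches: 16% idle
# 4 stages, 64 micro-batches: 4% idle
```

With a single micro-batch, the 4-GPU pipeline from the example is idle 75% of the time; more micro-batches shrink the bubble, at the cost of smaller per-GPU batches.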

3D Parallelism (DP+TP+PP)

The three strategies can be combined:

- Pipeline parallelism: split the model's layers across nodes
- Tensor parallelism: split intra-layer computation across GPUs within a node
- Data parallelism: replicate the entire pipeline at a larger scale

For example, with 64 GPUs, you could plan:

- 8 pipeline stages (PP=8)
- 2 GPUs for tensor parallelism within each stage (TP=2)
- 4 replicas of the whole pipeline for data parallelism (DP=4)
- Total: 8 × 2 × 4 = 64 GPUs

This is 3D parallelism, the standard configuration for modern large-model training.

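In the 64-GPU plan, every GPU needs to know which data-parallel replica, pipeline stage, and tensor-parallel slice it belongs to. One way to assign this is to decompose the global rank. The ordering below (tensor-parallel ranks adjacent, so each TP pair can share a node) is an assumption for illustration; frameworks such as Megatron-LM define their own orderings.

```python
from typing import List, Tuple


def rank_to_coords(rank: int, tp: int, pp: int, dp: int) -> Tuple[int, int, int]:
    """Map a global rank to (dp_index, pp_index, tp_index).
    Assumed ordering: tp varies fastest, then pp, then dp."""
    if not 0 <= rank < tp * pp * dp:
        raise ValueError(f"rank {rank} out of range")
    return rank // (tp * pp), (rank // tp) % pp, rank % tp


def tensor_parallel_group(rank: int, tp: int) -> List[int]:
    """Ranks that split the same layers with this rank (consecutive ranks under this ordering)."""
    start = (rank // tp) * tp
    return list(range(start, start + tp))


TP, PP, DP = 2, 8, 4
coords = [rank_to_coords(r, TP, PP, DP) for r in range(TP * PP * DP)]
print(len(set(coords)))                              # 64: each GPU gets a unique slot
print(coords[0], coords[1], coords[2], coords[63])   # (0, 0, 0) (0, 0, 1) (0, 1, 0) (3, 7, 1)
print(tensor_parallel_group(5, TP))                  # [4, 5]
```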
Data Engineering

Good models require good data; data engineering is often estimated to account for over 60% of the work in building an AI system.

Data Collection and Cleaning Pipeline

Training data typically comes from multiple sources: web pages, books, code, conversations, and more.

Typical processing workflow:

- Deduplication: remove duplicate or highly similar documents
- Quality filtering: remove low-quality, toxic, or biased content
- Format unification: convert data from different sources into a unified format
- Tokenization: convert text into model input sequences

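A toy version of this workflow fits in a few functions. Everything here is a simplified stand-in chosen for illustration: normalization is just lowercasing and whitespace collapsing (run first so deduplication compares normalized text), deduplication is exact hashing (near-duplicates need MinHash, covered next), the quality rule is a hypothetical heuristic, and the tokenizer is word-level rather than the subword tokenizers real models use.

```python
import hashlib
import re
from typing import Dict, List


def normalize(text: str) -> str:
    """Format unification: collapse whitespace and lowercase (deliberately simple)."""
    return re.sub(r"\s+", " ", text).strip().lower()


def passes_quality(text: str, min_words: int = 5, min_alpha_ratio: float = 0.6) -> bool:
    """Hypothetical quality heuristic: enough words and mostly alphabetic characters."""
    chars = [c for c in text if not c.isspace()]
    if len(text.split()) < min_words or not chars:
        return False
    return sum(c.isalpha() for c in chars) / len(chars) >= min_alpha_ratio


def tokenize(text: str, vocab: Dict[str, int]) -> List[int]:
    """Toy word-level tokenizer; new words get the next free ID."""
    return [vocab.setdefault(word, len(vocab)) for word in re.findall(r"\w+", text)]


def clean_pipeline(raw_docs: List[str]) -> List[List[int]]:
    seen: set = set()
    vocab: Dict[str, int] = {}
    sequences: List[List[int]] = []
    for doc in raw_docs:
        text = normalize(doc)
        digest = hashlib.sha256(text.encode()).hexdigest()
        if digest in seen:            # deduplication (exact, after normalization)
            continue
        seen.add(digest)
        if not passes_quality(text):  # quality filtering
            continue
        sequences.append(tokenize(text, vocab))  # tokenization
    return sequences


docs = [
    "The cat sat on the mat today.",
    "the  cat sat on the MAT today.",  # duplicate once normalized
    "$$$ 404 ### !!!",                 # fails the quality filter
    "Data engineering matters a lot for models.",
]
print(clean_pipeline(docs))  # [[0, 1, 2, 3, 0, 4, 5], [6, 7, 8, 9, 10, 11, 12]]
```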
Data Deduplication: MinHash LSH
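
Computing the similarity of every pair of documents directly is far too slow at scale; the common method is MinHash + LSH (Locality-Sensitive Hashing).

The idea: convert each document into a short signature (a "fingerprint") such that the fraction of matching positions in two signatures estimates the Jaccard similarity of the documents' shingle sets. LSH then splits each signature into bands and hashes every band into a bucket, so only documents that share a bucket are compared.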

```python
import hashlib
import re
from typing import Dict, List, Set


def generate_shingles(text: str, k: int = 5) -> Set[str]:
    """Generate k-shingles: sequences of k consecutive words.
    e.g., "I love tutorial tutorials", k=2 -> {"i love", "love tutorial", "tutorial tutorials"}
    """
    # Simple lowercase word tokenization (professional tools can be used in production)
    words = re.findall(r"\w+", text.lower())
    if len(words) < k:
        # Documents shorter than k words become a single shingle
        return {" ".join(words)}
    return {" ".join(words[i:i + k]) for i in range(len(words) - k + 1)}


def minhash_signature(shingles: Set[str], num_hashes: int = 100) -> List[int]:
    """Generate a MinHash signature.
    Use num_hashes different hash functions and keep the minimum value of each.
    """
    signature = []
    for i in range(num_hashes):
        # Salting each shingle with i simulates a different hash function
        signature.append(min(
            int(hashlib.sha256(f"{shingle}-{i}".encode()).hexdigest(), 16)
            for shingle in shingles
        ))
    return signature


def lsh_banding(signature: List[int], bands: int = 20) -> List[str]:
    """Use the banding method to generate LSH keys.
    Split the signature into bands (20 bands x 5 rows for 100 hashes) and hash each band.
    """
    keys = []
    rows_per_band = len(signature) // bands
    for i in range(bands):
        band = tuple(signature[i * rows_per_band:(i + 1) * rows_per_band])
        band_hash = hashlib.sha256(str(band).encode()).hexdigest()[:16]
        keys.append(f"band-{i}-{band_hash}")
    return keys


def deduplicate_documents(documents: List[str], threshold: float = 0.7) -> List[str]:
    """Deduplicate documents using MinHash + LSH.
    Returns the deduplicated document list.
    """
    buckets: Dict[str, List[int]] = {}     # LSH key -> indices of retained documents
    signatures: Dict[int, List[int]] = {}  # document index -> signature
    duplicates: Set[int] = set()           # indices of documents judged duplicates

    for idx, doc in enumerate(documents):
        sig = minhash_signature(generate_shingles(doc))
        signatures[idx] = sig
        keys = lsh_banding(sig)

        # Only documents sharing at least one bucket are compared
        is_duplicate = False
        for key in keys:
            for other_idx in buckets.get(key, []):
                # Fraction of equal signature positions approximates Jaccard similarity
                matches = sum(1 for a, b in zip(sig, signatures[other_idx]) if a == b)
                if matches / len(sig) >= threshold:
                    is_duplicate = True
                    break
            if is_duplicate:
                break

        if is_duplicate:
            duplicates.add(idx)
        else:
            # Not a duplicate: register this document in each of its buckets
            for key in keys:
                buckets.setdefault(key, []).append(idx)

    # Return non-duplicate documents in their original order
    return [doc for idx, doc in enumerate(documents) if idx not in duplicates]


# ============================================
# Test tutorial data deduplication
# ============================================

if __name__ == "__main__":
    documents = [
        "Welcome to tutorial tutorials, this is a great place to learn programming.",
        "Welcome to tutorial tutorials, this is a great place to learn programming.",  # Exact duplicate
        "Python is a concise and elegant language, suitable for beginners.",
        # Reworded: the pair shares only 2 of 12 distinct 5-word shingles
        # (Jaccard ~0.17), so word-level MinHash keeps both
        "Python is a concise and elegant programming language, very suitable for beginners.",
        "Machine learning lets computers learn patterns from data.",
        "This is a completely different article.",
    ]

    print(f"Before deduplication: {len(documents)} documents")
    deduplicated = deduplicate_documents(documents, threshold=0.6)
    print(f"After deduplication: {len(deduplicated)} documents\n")

    print("Retained documents:")
    for i, doc in enumerate(deduplicated):
        print(f"  [{i}] {doc}")

    # Output:
    # Before deduplication: 6 documents
    # After deduplication: 5 documents
    #
    # Retained documents:
    #   [0] Welcome to tutorial tutorials, this is a great place to learn programming.
    #   [1] Python is a concise and elegant language, suitable for beginners.
    #   [2] Python is a concise and elegant programming language, very suitable for beginners.
    #   [3] Machine learning lets computers learn patterns from data.
    #   [4] This is a completely different article.
```

In actual production, more efficient implementations are used (such as the datasketch library), but the core idea is the same. Word-level 5-shingles only catch near-verbatim copies; catching a paraphrase like the Python pair requires smaller shingles (at k = 1 their Jaccard similarity rises to about 0.83), which also makes unrelated documents look more alike, so k and the threshold have to be tuned together.

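The banding parameters control which pairs get compared at all. Under the standard LSH banding analysis, two documents with Jaccard similarity s share at least one bucket with probability 1 - (1 - s^r)^b. The sketch below evaluates that curve assuming b = 20 bands of r = 5 rows (100 hash functions in total):

```python
def candidate_probability(similarity: float, bands: int = 20, rows: int = 5) -> float:
    """Probability that two documents with Jaccard similarity s share at least one
    LSH bucket: 1 - (1 - s**rows) ** bands."""
    return 1 - (1 - similarity ** rows) ** bands


for s in (0.2, 0.4, 0.6, 0.8):
    print(f"similarity {s:.1f}: candidate probability {candidate_probability(s):.3f}")
# similarity 0.2: candidate probability 0.006
# similarity 0.4: candidate probability 0.186
# similarity 0.6: candidate probability 0.802
# similarity 0.8: candidate probability 1.000

# The S-curve rises most steeply near (1 / bands) ** (1 / rows)
print(f"approximate threshold: {(1 / 20) ** (1 / 5):.2f}")  # 0.55
```

With a similarity threshold of 0.6, as in the deduplication example, a pair right at 0.6 is only compared about 80% of the time; more bands (or fewer rows per band) shift the curve left, trading extra comparisons for fewer missed duplicates.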
\nData Format: WebDataset
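
Small datasets can be stored in almost any format, but TB-scale datasets need specialized formats.

WebDataset is a commonly used format: it packages samples into tar archives ("shards"), each containing thousands of samples, where files that share a basename (for example `000001.jpg` and `000001.txt`) form one sample. It is designed for fast sequential, streaming reads; shuffling happens at the shard level and through an in-memory buffer rather than by random access to individual samples.

Benefits:

- Reduces filesystem pressure (millions of small files are slow to open and list)
- Supports streaming access, with no need to load the entire dataset into memory
- Supports distributed loading: each worker reads a different subset of tar files
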
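The shard layout can be illustrated with Python's standard `tarfile` module alone. This is a sketch of the WebDataset convention (members named `<key>.<extension>`, grouped by key, read as a stream), not the `webdataset` library's API; the shard names and sample fields are hypothetical.

```python
import io
import os
import tarfile
import tempfile
from typing import Dict, Iterator, List


def write_shard(path: str, samples: List[Dict[str, bytes]], first_key: int = 0) -> None:
    """Write each sample as tar members named <key>.<ext>, one member per field."""
    with tarfile.open(path, "w") as tar:
        for i, sample in enumerate(samples):
            key = f"{first_key + i:06d}"
            for ext, data in sample.items():
                info = tarfile.TarInfo(name=f"{key}.{ext}")
                info.size = len(data)
                tar.addfile(info, io.BytesIO(data))


def read_shard(path: str) -> Iterator[Dict[str, bytes]]:
    """Stream a shard sequentially, grouping consecutive members that share a key."""
    current_key, sample = None, {}
    with tarfile.open(path, "r|") as tar:  # "r|": read as a non-seekable stream
        for member in tar:
            if not member.isfile():
                continue
            key, ext = member.name.split(".", 1)
            if key != current_key and sample:
                yield sample
                sample = {}
            current_key = key
            sample[ext] = tar.extractfile(member).read()
    if sample:
        yield sample


def shards_for_worker(shards: List[str], worker_id: int, num_workers: int) -> List[str]:
    """Distributed loading: each worker reads a disjoint subset of shards."""
    return shards[worker_id::num_workers]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        shards = []
        for s in range(4):
            path = os.path.join(d, f"shard-{s:06d}.tar")
            write_shard(path, [{"txt": f"sample {s}-{j}".encode(), "cls": str(j).encode()}
                               for j in range(3)], first_key=s * 3)
            shards.append(path)
        for path in shards_for_worker(shards, worker_id=0, num_workers=2):
            for sample in read_shard(path):
                print(os.path.basename(path), sample)
```

Because each shard is read from start to finish, storage sees large sequential reads instead of millions of small-file lookups.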
Data Flywheel Design
\nData fly