Machine learning during new times
Picture a data scientist or an ML engineer in 2015, working diligently on their laptop to train a machine learning model. Their workflow was straightforward: load data into memory, process it, train the model, and deploy.
We (the authors) have a vivid memory of this.
Fast forward to today, and this simple approach seems extremely naive. The explosion of big data and the increasing complexity of modern ML, from LLMs to AI agents, have fundamentally changed how we approach machine learning inference.
Inference at Scale
The journey from single-machine learning to distributed systems mirrors the broader evolution of computing itself. Just as early computers gave way to networked systems and eventually to cloud computing, machine learning has undergone a similar transformation. Inference at scale refers to the process of using trained machine learning models to make predictions on new data, scaled to handle massive datasets or real-time requirements.
Why Traditional Approaches No Longer Suffice
Consider a modern recommendation engine processing millions of user interactions per second or a fraud detection system analyzing transactions in real-time across global markets. These scenarios represent a perfect storm of challenges that have pushed us beyond the capabilities of traditional single-node solutions. Modern machine learning faces three critical demands:
Processing massive volumes of data that exceed single-machine capacity.
Delivering predictions with near-instantaneous latency.
Integrating seamlessly with modern data architectures in the cloud.
The rise of the Data Lake Architecture
A data lake architecture provides a centralized repository to store structured and unstructured data at any scale, leveraging technologies like Apache Kafka, Spark, and Iceberg, just to name a few. This has become a foundational component in modern distributed machine learning workflows.
This timeless tweet by Matt Turck hits home:
The Distributed ML Paradigm Shift
Just as backend and data engineers learned to scale data processing some 20 years ago with the MapReduce paradigm, machine learning is now catching up, and distributed inference is here to stay. Distributed ML is more than splitting work across machines; it changes how data scientists and ML engineers work, a set of practices often grouped under the name MLOps. This shift manifests in two primary approaches:
Data Parallelism: Each machine processes a portion of the dataset using the same model, enabling analysis of massive datasets efficiently.
Model Parallelism: Different machines handle separate parts of the model computations, optimizing resources for highly complex architectures.
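To make the distinction concrete, here is a minimal single-machine sketch of both ideas. It assumes a toy "model" (a fixed linear function split into two hypothetical layers) and uses a thread pool as a stand-in for a cluster; none of these names come from a real framework.

```python
from concurrent.futures import ThreadPoolExecutor

# Hypothetical two-layer "model"; together the layers compute 2x + 1.
def layer_one(x):
    return 2.0 * x

def layer_two(h):
    return h + 1.0

def model(x):
    return layer_two(layer_one(x))

def split(data, num_partitions):
    """Cut the dataset into contiguous chunks, one per worker."""
    size = max(1, -(-len(data) // num_partitions))  # ceiling division
    return [data[i:i + size] for i in range(0, len(data), size)]

def data_parallel_predict(data, num_partitions=4):
    """Data parallelism: every worker runs the SAME full model on its own chunk."""
    def predict_partition(partition):
        return [model(x) for x in partition]
    with ThreadPoolExecutor(max_workers=num_partitions) as pool:
        results = pool.map(predict_partition, split(data, num_partitions))
    return [y for part in results for y in part]  # map() preserves chunk order

def model_parallel_predict(data):
    """Model parallelism: each worker owns ONE part of the model; data flows through both."""
    with ThreadPoolExecutor(max_workers=2) as pool:
        hidden = list(pool.map(layer_one, data))   # "machine A" holds layer one
        return list(pool.map(layer_two, hidden))   # "machine B" holds layer two

inputs = list(range(10))
data_parallel = data_parallel_predict(inputs, num_partitions=3)
model_parallel = model_parallel_predict(inputs)
```

Both functions return the same predictions; what differs is whether the dataset or the model is the thing being divided among workers.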
Common Distributed Compute Frameworks for Inference
Distributed inference benefits from several frameworks optimized for scalability and flexibility:
Apache Spark: Combines big data processing with distributed ML pipelines, excelling in batch inference, streaming data predictions, and data preparation.
TensorFlow and PyTorch: Both support distributed training and scalable inference. TensorFlow uses TensorFlow Serving for production environments, while PyTorch leverages torch.distributed for custom setups.
Horovod: Optimized for distributed training, it also supports scalable inference for GPU-based workloads across multiple nodes.
Ray Serve: Designed for high-performance distributed inference, it integrates seamlessly with Python-based models like TensorFlow and PyTorch.
H2O.ai and Databricks: Offer platforms for distributed training and inference, with Databricks specifically integrating Apache Spark for unified data and ML workflows.
ONNX Runtime: Lightweight and portable, it enables fast inference across diverse environments.
Microsoft CNTK: Supports large-scale deep learning with distributed training and inference.
These frameworks provide robust options for building efficient, scalable machine learning pipelines, tailored to diverse use cases and environments.
Deep Dive: Apache Spark's Approach to Scalable Inference
The evolution of distributed computing has given rise to a rich ecosystem of frameworks, each with its strengths. As this is Big Data Performance Weekly, we will focus on Spark, which is especially useful for organizations that are already invested in big data infrastructure and have Spark in their stack.
Spark’s success in distributed machine learning stems from its ability to seamlessly integrate with existing data workflows while providing powerful tools for both preprocessing and inference. Spark offers three primary approaches to deploying models:
Batch Inference Approach: Best for scenarios where predictions are performed on pre-collected datasets, such as generating insights or reports.
REST Endpoint Approach: Creates a dedicated service to answer prediction requests.
Streaming Integration Approach: Embeds the model directly into data streams for real-time processing.
Many production systems end up combining these approaches: Spark Streaming for real-time data, plus batch inference and a REST endpoint for user-facing or ad-hoc predictions.
For a more in-depth overview of the second and third approaches, you are welcome to check Avichay’s personal blog post.
Example: Batch inference
Here’s a basic example where we have a pre-trained ML model and use Spark for inference at scale:
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.SparkSession

object BatchInferenceExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("BatchInferenceExample")
      .getOrCreate()

    // Load a pre-trained model from storage
    val pipelineModel = PipelineModel.load("s3://my-models/spark-pipeline")

    // Load a large dataset for prediction
    val inputData = spark.read.format("parquet").load("s3://my-data/input")

    // Perform distributed inference
    val predictions = pipelineModel.transform(inputData)
    predictions.show()
  }
}
What happens here?
Spark reads the data and distributes it across worker nodes.
The PipelineModel is loaded on the driver and shipped to the worker nodes.
Each node applies the model to its portion of the data.
This parallel processing makes Spark well-suited for batch inference over massive datasets.
Performance and Optimization Tips for Spark Inference
We are here for performance!
We have stumbled upon a few frequent mistakes. Below are consolidated tips to ensure your workflows are optimized for scale:
1. Optimize Model Broadcasting
When performing distributed inference, efficiently broadcasting the model to worker nodes is crucial. Without optimization, Spark sends the model from the driver to each worker node redundantly, leading to:
High Network Overhead: The model is repeatedly transferred to the cluster for every task.
Increased Latency: Workers spend time waiting for the model, delaying computation.
Resource Waste: Unnecessary memory and bandwidth are consumed.
Optimized Approach: Use Spark's broadcast mechanism to send the model to all worker nodes only once. Each worker keeps a local copy of the model, avoiding repeated network transfers:
val broadcastModel = spark.sparkContext.broadcast(pipelineModel)
Broadcasting is especially beneficial when:
The model is large and could strain the network.
Many tasks running on the same worker need access to the same model.
By leveraging broadcast, you can significantly reduce communication costs and speed up inference.
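A rough back-of-the-envelope simulation shows why this matters. The sketch below is plain Python, not Spark: it only counts how many megabytes would cross the network when the model travels with every task versus when each worker fetches it once and caches it. The task count, worker count, and 500 MB model size are made-up numbers for illustration.

```python
def transfer_without_broadcast(task_to_worker, model_size_mb):
    """The model is serialized and shipped along with every single task."""
    return len(task_to_worker) * model_size_mb

def transfer_with_broadcast(task_to_worker, model_size_mb):
    """The model is fetched once per worker, then served from a local cache."""
    cached_workers = set()
    transferred_mb = 0
    for worker in task_to_worker:
        if worker not in cached_workers:
            transferred_mb += model_size_mb
            cached_workers.add(worker)
    return transferred_mb

# Hypothetical job: 200 tasks scheduled round-robin across 4 workers.
tasks = [task_id % 4 for task_id in range(200)]
naive_mb = transfer_without_broadcast(tasks, model_size_mb=500)
broadcast_mb = transfer_with_broadcast(tasks, model_size_mb=500)
print(f"without broadcast: {naive_mb} MB, with broadcast: {broadcast_mb} MB")
```

The gap grows with the number of tasks per worker, which is why the saving is most visible on large models combined with many small partitions.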
2. Leverage the Right UDF Patterns for Efficient Inference
When using Pandas UDFs for inference, the pattern you choose impacts performance, especially with larger models:
Series-to-Series UDFs: This pattern initializes the model for each batch of data within a partition, which can be very slow in scenarios with frequent batch processing or larger models.
Iterator-to-Iterator UDFs: This approach initializes the model once per partition, reducing the initialization overhead and making it more suitable for distributed inference at scale.
Why It Matters: If your model is large (e.g., loaded with libraries like joblib), the repeated initialization in Series-to-Series UDFs can lead to unnecessary delays. Switching to an Iterator-to-Iterator pattern minimizes these delays, allowing the model to be reused across batches within the same partition.
More on this topic can be found in this Databricks UDF documentation.
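The difference between the two patterns is easiest to see by counting model loads. The sketch below is plain Python so it runs without a cluster: batches are ordinary lists, and load_model is a hypothetical stand-in for an expensive call such as joblib.load. In PySpark, the second function's shape corresponds to a pandas_udf whose type hints are Iterator[pd.Series] -> Iterator[pd.Series].

```python
from typing import Iterable, Iterator, List

load_count = 0  # how many times the (hypothetical) model has been initialized

def load_model():
    """Stand-in for an expensive model load; returns a toy model that doubles its input."""
    global load_count
    load_count += 1
    return lambda x: x * 2

def predict_series_to_series(batch: List[float]) -> List[float]:
    """Invoked once per batch, so the model is re-initialized for every batch."""
    model = load_model()
    return [model(x) for x in batch]

def predict_iterator_to_iterator(batches: Iterable[List[float]]) -> Iterator[List[float]]:
    """Invoked once per partition: load the model once, then stream over the batches."""
    model = load_model()
    for batch in batches:
        yield [model(x) for x in batch]

# One partition that arrives as three batches.
partition = [[1.0, 2.0], [3.0], [4.0, 5.0]]

load_count = 0
per_batch_result = [predict_series_to_series(b) for b in partition]
loads_series_to_series = load_count

load_count = 0
per_partition_result = list(predict_iterator_to_iterator(partition))
loads_iterator_to_iterator = load_count
```

Both variants produce identical predictions; only the number of initializations changes, from one per batch to one per partition.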
3. Understand Estimators and Transformers in MLlib
Spark MLlib distinguishes between Estimators and Transformers, a fundamental concept that underpins its machine learning workflows:
Estimators: Objects that learn from data by performing a fit operation. They produce a trained model (Transformer). Examples include:
Distributed algorithms like RandomForestClassifier, PCA, or ALS.
Feature engineering tools like OneHotEncoder or MinMaxScaler.
Transformers (no, not the LLM kind): Objects that apply transformations to data. These include:
Stateless Transformers: Require no training, e.g., VectorAssembler.
Transformers resulting from Estimators: E.g., a trained RandomForest model.
Pipelines allow combining multiple Estimators and Transformers into a single workflow. The fit operation on a Pipeline produces a PipelineModel, which itself is a Transformer capable of transforming data end-to-end.
By understanding this structure, you can design more effective and maintainable ML workflows in Spark.
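The fit/transform contract can be imitated in a few lines of plain Python. The classes below borrow MLlib's names for readability, but they are toy re-implementations that work on lists of numbers, not the real MLlib classes; Doubler is a made-up stateless Transformer.

```python
class MinMaxScalerModel:
    """Transformer produced by fitting: rescales values using the learned min and max."""
    def __init__(self, lo, hi):
        self.lo, self.hi = lo, hi

    def transform(self, data):
        span = (self.hi - self.lo) or 1.0  # avoid division by zero on constant data
        return [(x - self.lo) / span for x in data]

class MinMaxScaler:
    """Estimator: fit() learns from the data and returns a Transformer."""
    def fit(self, data):
        return MinMaxScalerModel(min(data), max(data))

class Doubler:
    """Stateless Transformer: nothing to learn, so it has no fit() step."""
    def transform(self, data):
        return [2 * x for x in data]

class PipelineModel:
    """A fitted pipeline is itself a Transformer that applies each stage in order."""
    def __init__(self, transformers):
        self.transformers = transformers

    def transform(self, data):
        for transformer in self.transformers:
            data = transformer.transform(data)
        return data

class Pipeline:
    """An Estimator made of stages; fit() turns every Estimator stage into a Transformer."""
    def __init__(self, stages):
        self.stages = stages

    def fit(self, data):
        fitted = []
        for stage in self.stages:
            # Estimators are fitted on the data as transformed by the earlier stages.
            transformer = stage.fit(data) if hasattr(stage, "fit") else stage
            data = transformer.transform(data)
            fitted.append(transformer)
        return PipelineModel(fitted)

pipeline_model = Pipeline([Doubler(), MinMaxScaler()]).fit([1.0, 2.0, 3.0])
scored = pipeline_model.transform([1.0, 2.0, 3.0, 4.0])
```

Note that the last input (4.0) scores above 1.0 because the scaler's range was learned from the training data only, which is exactly the separation between fitting and transforming that the Estimator/Transformer split enforces.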
Recommended watch of the week:
Matt Turck, who was already mentioned, hosted Chip Huyen on his podcast. She is known for her expertise in machine learning systems and her impactful writing on AI. They discussed topics related to AI engineering, and this episode dovetails perfectly with our post.
Conclusion
Distributed inference is critical for scaling machine learning workflows to meet modern data demands. While options like Ray Serve and TensorFlow Serving offer specialized solutions, Apache Spark’s integration with big data pipelines and its ability to handle both preprocessing and inference make it a versatile choice for many use cases. By understanding backend concepts, leveraging Spark’s strengths, and avoiding common pitfalls, ML engineers can build scalable, efficient inference pipelines that meet the needs of today’s data-driven world.