ml-pipeline-automation
ML Pipeline Automation
Orchestrate end-to-end machine learning workflows from data ingestion to production deployment with production-tested Airflow, Kubeflow, and MLflow patterns.
When to Use This Skill
Load this skill when:
- Building ML Pipelines: Orchestrating data → train → deploy workflows
- Scheduling Retraining: Setting up automated model retraining schedules
- Experiment Tracking: Tracking experiments, parameters, metrics across runs
- MLOps Implementation: Building reproducible, monitored ML infrastructure
- Workflow Orchestration: Managing complex multi-step ML workflows
- Model Registry: Managing model versions and deployment lifecycle
Quick Start: ML Pipeline in 5 Steps
```bash
# 1. Install Airflow and MLflow (check for the latest versions at time of use)
pip install apache-airflow==3.1.5 mlflow==3.7.0
# Note: these versions are current as of December 2025.
# Check PyPI for the latest stable releases: https://pypi.org/project/apache-airflow/

# 2. Initialize the Airflow metadata database
# (Airflow 3 removed `airflow db init`; use `airflow db migrate` instead)
airflow db migrate

# 3. Create the DAG file: dags/ml_training_pipeline.py
cat > dags/ml_training_pipeline.py << 'EOF'
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta

default_args = {
    'owner': 'ml-team',
    'retries': 2,
    'retry_delay': timedelta(minutes=5)
}

dag = DAG(
    'ml_training_pipeline',
    default_args=default_args,
    schedule='@daily',  # Airflow 3 renamed schedule_interval to schedule
    start_date=datetime(2025, 1, 1),
    catchup=False
)

def train_model(**context):
    import mlflow
    import mlflow.sklearn
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split

    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    mlflow.set_tracking_uri('http://localhost:5000')
    mlflow.set_experiment('iris-training')
    with mlflow.start_run():
        model = RandomForestClassifier(n_estimators=100)
        model.fit(X_train, y_train)
        accuracy = model.score(X_test, y_test)
        mlflow.log_metric('accuracy', accuracy)
        mlflow.sklearn.log_model(model, 'model')

train = PythonOperator(
    task_id='train_model',
    python_callable=train_model,
    dag=dag
)
EOF
```
```bash
# 4. Start Airflow locally
# (Airflow 3 replaces `airflow webserver` with `airflow api-server`;
#  `airflow standalone` starts all components for local development)
airflow standalone

# 5. Trigger the pipeline
airflow dags trigger ml_training_pipeline
```
Access UI: http://localhost:8080
**Result**: Working ML pipeline with experiment tracking in under 5 minutes.
Core Concepts
Pipeline Stages
- Data Collection → Fetch raw data from sources
- Data Validation → Check schema, quality, distributions
- Feature Engineering → Transform raw data to features
- Model Training → Train with hyperparameter tuning
- Model Evaluation → Validate performance on test set
- Model Deployment → Push to production if metrics pass
- Monitoring → Track drift, performance in production
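The gating idea behind these stages can be sketched with plain-Python stand-ins (the `collect_data`, `validate`, `engineer_features`, `train`, and `evaluate` functions below are hypothetical placeholders, not part of any framework): each stage's output feeds the next, and a failed validation stops the run before training:

```python
def collect_data():
    # Stand-in for fetching raw data from a source
    return [{"feature_a": i, "target": i % 2} for i in range(1200)]

def validate(data):
    # Schema/quality gate: fail fast before spending compute on training
    assert len(data) > 1000, "insufficient data"
    assert all("target" in row for row in data), "missing target"
    return data

def engineer_features(data):
    # Toy transform: derive one extra feature per row
    return [{**row, "feature_b": row["feature_a"] * 2} for row in data]

def train(data):
    # Placeholder "model": majority-class predictor
    positives = sum(row["target"] for row in data)
    return {"predict": int(positives >= len(data) / 2)}

def evaluate(model, data):
    correct = sum(1 for row in data if model["predict"] == row["target"])
    return correct / len(data)

data = validate(collect_data())
features = engineer_features(data)
model = train(features)
accuracy = evaluate(model, features)
print(f"accuracy={accuracy:.2f}")  # prints accuracy=0.50; deploy only if a threshold passes
```

In a real orchestrator each function becomes a task, and the "deploy if metrics pass" step becomes a branch (see Conditional Execution below).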
Orchestration Tools Comparison
| Tool | Best For | Strengths |
|---|---|---|
| Airflow | General ML workflows | Mature, flexible, Python-native |
| Kubeflow | Kubernetes-native ML | Container-based, scalable |
| MLflow | Experiment tracking | Model registry, versioning |
| Prefect | Modern Python workflows | Dynamic DAGs, native caching |
| Dagster | Asset-oriented pipelines | Data-aware, testable |
Basic Airflow DAG
```python
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta
import logging

logger = logging.getLogger(__name__)

default_args = {
    'owner': 'ml-team',
    'depends_on_past': False,
    'email': ['alerts@example.com'],
    'email_on_failure': True,
    'retries': 2,
    'retry_delay': timedelta(minutes=5)
}

dag = DAG(
    'ml_training_pipeline',
    default_args=default_args,
    description='End-to-end ML training pipeline',
    schedule='@daily',  # Airflow 3 renamed schedule_interval to schedule
    start_date=datetime(2025, 1, 1),
    catchup=False
)

def validate_data(**context):
    """Validate input data quality."""
    import pandas as pd

    data_path = "/data/raw/latest.csv"
    df = pd.read_csv(data_path)

    # Validation checks
    assert len(df) > 1000, f"Insufficient data: {len(df)} rows"
    assert df.isnull().sum().sum() < len(df) * 0.1, "Too many nulls"

    context['ti'].xcom_push(key='data_path', value=data_path)
    logger.info(f"Data validation passed: {len(df)} rows")

def train_model(**context):
    """Train ML model with MLflow tracking."""
    import mlflow
    import mlflow.sklearn
    from sklearn.ensemble import RandomForestClassifier

    data_path = context['ti'].xcom_pull(key='data_path', task_ids='validate_data')

    mlflow.set_tracking_uri('http://mlflow:5000')
    mlflow.set_experiment('production-training')
    with mlflow.start_run():
        # Training logic here
        model = RandomForestClassifier(n_estimators=100)
        # model.fit(X, y) ...
        mlflow.log_param('n_estimators', 100)
        mlflow.sklearn.log_model(model, 'model')

validate = PythonOperator(
    task_id='validate_data',
    python_callable=validate_data,
    dag=dag
)

train = PythonOperator(
    task_id='train_model',
    python_callable=train_model,
    dag=dag
)

validate >> train
```
Known Issues Prevention
1. Task Failures Without Alerts
Problem: Pipeline fails silently, no one notices until users complain.
Solution: Configure email/Slack alerts on failure:
```python
default_args = {
    'email': ['ml-team@example.com'],
    'email_on_failure': True,
    'email_on_retry': False
}

def on_failure_callback(context):
    """Send Slack alert on failure."""
    from airflow.providers.slack.operators.slack_webhook import SlackWebhookOperator

    slack_msg = f"""
    :red_circle: Task Failed: {context['task_instance'].task_id}
    DAG: {context['task_instance'].dag_id}
    Execution Date: {context['ds']}
    Error: {context.get('exception')}
    """
    SlackWebhookOperator(
        task_id='slack_alert',
        slack_webhook_conn_id='slack_webhook',
        message=slack_msg
    ).execute(context)

task = PythonOperator(
    task_id='critical_task',
    python_callable=my_function,
    on_failure_callback=on_failure_callback,
    dag=dag
)
```
Problem: Pipeline fails silently, no one notices until users complain.
2. Missing XCom Data Between Tasks
Problem: Task expects XCom value from previous task, gets None, crashes.
Solution: Always validate XCom pulls:
```python
def process_data(**context):
    data_path = context['ti'].xcom_pull(
        key='data_path',
        task_ids='upstream_task'
    )
    if data_path is None:
        raise ValueError("No data_path from upstream_task - check XCom push")
    # Process data...
```
3. DAG Not Appearing in UI
Problem: The DAG file exists in `dags/` but doesn't show up in the Airflow UI.
Solution: Check for DAG parsing errors:
```bash
# Check for syntax errors
python dags/my_dag.py

# View DAG import errors in the UI:
# Navigate to: Browse → DAG Import Errors
```
Common fixes:
1. Ensure a DAG object is defined in the file
2. Check for circular imports
3. Verify all dependencies are installed
4. Fix syntax errors
4. Hardcoded Paths Break in Production
Problem: Paths like `/Users/myname/data/` work locally but fail in production.
Solution: Use Airflow Variables or environment variables:
```python
from airflow.models import Variable
import os

def load_data(**context):
    # ❌ Bad: hardcoded path
    # data_path = "/Users/myname/data/train.csv"

    # ✅ Good: use an Airflow Variable
    data_dir = Variable.get("data_directory", "/data")
    data_path = f"{data_dir}/train.csv"

    # Or use an environment variable
    data_path = os.getenv("DATA_PATH", "/data/train.csv")
```
5. Stuck Tasks Consume Resources
Problem: Task hangs indefinitely, blocks worker slot, wastes resources.
Solution: Set execution_timeout on tasks:
```python
from datetime import timedelta

task = PythonOperator(
    task_id='long_running_task',
    python_callable=my_function,
    execution_timeout=timedelta(hours=2),  # Kill after 2 hours
    dag=dag
)
```
解决方案:为任务设置执行超时时间:
python
from datetime import timedelta
task = PythonOperator(
task_id='long_running_task',
python_callable=my_function,
execution_timeout=timedelta(hours=2), # 2小时后终止任务
dag=dag
)6. No Data Validation = Bad Model Training
Problem: Train on corrupted/incomplete data, model performs poorly in production.
Solution: Add data quality validation tasks:
```python
def validate_data_quality(**context):
    """Comprehensive data validation."""
    import pandas as pd

    data_path = "/data/raw/latest.csv"  # or pull via XCom from an upstream task
    df = pd.read_csv(data_path)

    # Schema validation
    required_cols = ['user_id', 'timestamp', 'feature_a', 'target']
    missing_cols = set(required_cols) - set(df.columns)
    if missing_cols:
        raise ValueError(f"Missing columns: {missing_cols}")

    # Statistical validation
    if df['target'].isnull().sum() > 0:
        raise ValueError("Target column contains nulls")
    if len(df) < 1000:
        raise ValueError(f"Insufficient data: {len(df)} rows")

    logger.info("✅ Data quality validation passed")
```
解决方案:添加数据质量验证任务:
python
def validate_data_quality(**context):
"""全面的数据质量验证。"""
import pandas as pd
df = pd.read_csv(data_path)
# Schema验证
required_cols = ['user_id', 'timestamp', 'feature_a', 'target']
missing_cols = set(required_cols) - set(df.columns)
if missing_cols:
raise ValueError(f"缺失列:{missing_cols}")
# 统计验证
if df['target'].isnull().sum() > 0:
raise ValueError("目标列包含空值")
if len(df) < 1000:
raise ValueError(f"数据量不足:仅{len(df)}行")
logger.info("✅ 数据质量验证通过")7. Untracked Experiments = Lost Knowledge
Problem: Can't reproduce results, don't know which hyperparameters worked.
Solution: Use MLflow for all experiments:
```python
import mlflow
import mlflow.sklearn

mlflow.set_tracking_uri('http://mlflow:5000')
mlflow.set_experiment('model-experiments')

with mlflow.start_run(run_name='rf_v1'):
    # Log ALL hyperparameters
    mlflow.log_params({
        'model_type': 'random_forest',
        'n_estimators': 100,
        'max_depth': 10,
        'random_state': 42
    })
    # Log ALL metrics
    mlflow.log_metrics({
        'train_accuracy': 0.95,
        'test_accuracy': 0.87,
        'f1_score': 0.89
    })
    # Log the trained model
    mlflow.sklearn.log_model(model, 'model')
```
When to Load References
Load reference files for detailed production implementations:
- Airflow DAG Patterns: Load `references/airflow-patterns.md` when building complex DAGs with error handling, dynamic generation, sensors, task groups, or retry logic. Contains complete production DAG examples.
- Kubeflow & MLflow Integration: Load `references/kubeflow-mlflow.md` when using Kubeflow Pipelines for container-native orchestration, integrating MLflow tracking, building KFP components, or managing the model registry.
- Pipeline Monitoring: Load `references/pipeline-monitoring.md` when implementing data quality checks, drift detection, alert configuration, or pipeline health monitoring with Prometheus.
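As a taste of what drift detection involves, here is a minimal sketch in plain Python (the `drifted` helper is hypothetical; production pipelines typically use PSI or KS tests): flag drift when a live feature's mean moves more than a few standard errors from the training baseline.

```python
import math

def drifted(baseline, live, z_threshold=3.0):
    """Return True if the live mean deviates from the baseline mean
    by more than z_threshold standard errors (a crude z-test)."""
    mean_b = sum(baseline) / len(baseline)
    mean_l = sum(live) / len(live)
    # Sample variance of the baseline feature
    var_b = sum((x - mean_b) ** 2 for x in baseline) / (len(baseline) - 1)
    se = math.sqrt(var_b / len(live))
    return abs(mean_l - mean_b) / se > z_threshold

baseline = [float(i % 10) for i in range(1000)]   # training distribution, mean 4.5
stable = [float(i % 10) for i in range(200)]      # same distribution
shifted = [float(i % 10) + 2.0 for i in range(200)]  # mean shifted by +2
print(drifted(baseline, stable), drifted(baseline, shifted))  # prints: False True
```

A check like this would run as a scheduled task that alerts (or triggers retraining) when it returns True.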
Best Practices
- Idempotent Tasks: Tasks should produce same result when re-run
- Atomic Operations: Each task does one thing well
- Version Everything: Data, code, models, dependencies
- Comprehensive Logging: Log all important events with context
- Error Handling: Fail fast with clear error messages
- Monitoring: Track pipeline health, data quality, model drift
- Testing: Test tasks independently before integrating
- Documentation: Document DAG purpose, task dependencies
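Idempotency is the practice that pays off most under Airflow, because retries and backfills re-run tasks. One common way to get it, sketched below with a hypothetical `write_partition` helper: key the output path by the run's logical date and always overwrite, never append, so a re-run reproduces the same partition instead of duplicating data.

```python
import json
import os
import tempfile

def write_partition(base_dir, ds, records):
    """Idempotent output: the path is keyed by the logical date (ds),
    and the file is overwritten, so re-runs are safe."""
    path = os.path.join(base_dir, f"ds={ds}")
    os.makedirs(path, exist_ok=True)
    out = os.path.join(path, "data.json")
    with open(out, "w") as f:  # overwrite, never append
        json.dump(records, f)
    return out

base = tempfile.mkdtemp()
p1 = write_partition(base, "2025-01-01", [1, 2, 3])
p2 = write_partition(base, "2025-01-01", [1, 2, 3])  # re-run: same path, same content
assert p1 == p2
```

In a DAG, `ds` would come from the task context (the `{{ ds }}` template value), so every retry of a given run writes to the same partition.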
Common Patterns
Conditional Execution
```python
from airflow.operators.python import BranchPythonOperator

def choose_branch(**context):
    accuracy = context['ti'].xcom_pull(key='accuracy', task_ids='evaluate')
    if accuracy > 0.9:
        return 'deploy_to_production'
    else:
        return 'retrain_with_more_data'

branch = BranchPythonOperator(
    task_id='check_accuracy',
    python_callable=choose_branch,
    dag=dag
)

train >> evaluate >> branch >> [deploy, retrain]
```
Parallel Training
```python
from airflow.utils.task_group import TaskGroup

with TaskGroup('train_models', dag=dag) as train_group:
    train_rf = PythonOperator(task_id='train_rf', ...)
    train_lr = PythonOperator(task_id='train_lr', ...)
    train_xgb = PythonOperator(task_id='train_xgb', ...)

# All models train in parallel
preprocess >> train_group >> select_best
```
Waiting for Data
```python
from airflow.sensors.filesystem import FileSensor

wait_for_data = FileSensor(
    task_id='wait_for_data',
    filepath='/data/input/{{ ds }}.csv',
    poke_interval=60,    # Check every 60 seconds
    timeout=3600,        # Time out after 1 hour
    mode='reschedule',   # Don't block a worker slot while waiting
    dag=dag
)

wait_for_data >> process_data
```