2026/4/11 12:11:55
网站建设
项目流程
搭建国外网站的步骤,广东省住房和城乡建设局官网,贵阳网站设计公司,计算机网络技术主要就业方向ms-swift扩展性揭秘#xff1a;如何自定义loss函数和优化器
在大模型微调实践中#xff0c;一个常被忽视却至关重要的能力是——框架是否真正开放其训练内核。很多开发者在使用主流微调工具时会遇到这样的困境#xff1a;当标准交叉熵损失无法满足特定任务需求#xff08;…ms-swift扩展性揭秘如何自定义loss函数和优化器在大模型微调实践中一个常被忽视却至关重要的能力是——框架是否真正开放其训练内核。很多开发者在使用主流微调工具时会遇到这样的困境当标准交叉熵损失无法满足特定任务需求比如需要引入领域知识约束、多目标平衡或对抗性正则或者默认AdamW优化器在稀疏梯度场景下收敛缓慢时往往束手无策。他们要么被迫修改底层源码、破坏框架升级路径要么绕道重写整个训练循环丧失ms-swift带来的工程效率优势。而ms-swift的设计哲学从一开始就拒绝这种妥协它不是把训练逻辑“黑盒化”后提供一层薄薄的API而是将训练流程解耦为可插拔、可替换、可组合的模块化组件。其中loss函数与优化器正是最核心、最常需定制的两个环节。本文将带你深入ms-swift的扩展机制不讲抽象概念只聚焦三件事怎么安全地替换默认loss且不影响其他功能如LoRA注入、梯度检查点如何无缝接入自定义优化器比如Lion、Adan或带warmup重启动的SGD真实案例演示为医疗问答任务添加实体一致性loss使模型在生成答案时自动对齐病历中的关键实体所有操作均基于ms-swift v1.10版本无需修改任何源码全部通过配置与轻量Python代码完成。1. 理解ms-swift的训练流程架构为什么能轻松定制在动手前先建立一个清晰的认知地图。ms-swift并非简单封装HuggingFace Transformers的Trainer而是构建了一套更细粒度的控制流--------------------- ---------------------- ---------------------- | 数据加载与预处理 | -- | 模型前向与loss计算 | -- | 反向传播与参数更新 | | (Dataset, Collator) | | (Model.forward loss) | | (Optimizer.step ...) | --------------------- ---------------------- ---------------------- | | | v v v --------------------------------------------------------------- | 训练主循环Seq2SeqTrainer | | - 自动管理梯度累积、混合精度、分布式同步、日志记录等 | --------------------------------------------------------------- | v ---------------------- | 用户可干预的钩子点 | | on_train_begin | ← 可注册初始化逻辑 | compute_loss | ← 本文重点loss计算入口 | create_optimizer | ← 本文重点优化器创建入口 | on_step_end | ← 可注入梯度裁剪、权重冻结等 ----------------------关键洞察在于compute_loss和create_optimizer这两个钩子是ms-swift官方明确暴露给用户的标准化扩展接口。它们被设计为纯函数式、无状态、与框架内部调度完全解耦——这意味着你的自定义逻辑可以像搭积木一样插入而不会干扰到LoRA参数隔离、序列并行Ulysses/Ring-Attention、vLLM推理加速等高级特性。重要提示所有自定义必须通过继承Seq2SeqTrainer并重写对应方法实现切勿直接修改ms-swift源码中的trainer.py。这保证了未来框架升级时你的业务代码依然可用。2. 自定义Loss函数从零开始构建领域感知损失2.1 标准Loss替换三步完成交叉熵增强假设你正在微调一个用于法律文书摘要的模型发现标准CrossEntropyLoss会让模型过度关注高频词汇如“原告”、“被告”而忽略关键法律条款编号如“《民法典》第1165条”。你需要一个能强化条款识别能力的loss。ms-swift的compute_loss方法接收两个参数model当前训练模型和inputs字典格式的batch数据含input_ids,labels,attention_mask等。返回值必须是标量Tensorloss值。from torch import nn import torch class LegalClauseLoss(nn.Module): def __init__(self, base_loss_fnnn.CrossEntropyLoss(), clause_weight2.0): super().__init__() self.base_loss base_loss_fn self.clause_weight clause_weight # 预编译正则匹配《XXX》第XXX条模式 import re self.clause_pattern re.compile(r《[^》]》第\d条) def forward(self, logits, labels): # 1. 基础交叉熵损失 shift_logits logits[..., :-1, :].contiguous() shift_labels labels[..., 1:].contiguous() base_loss self.base_loss( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) # 2. 条款识别奖励鼓励模型在label含条款时logits对应位置概率更高 # 获取label中所有条款token的索引简化版假设条款token id已知 # 实际项目中应通过tokenizer.convert_tokens_to_ids获取 clause_token_ids [12345, 67890] # 示例ID需根据实际tokenizer确定 clause_mask torch.isin(shift_labels, torch.tensor(clause_token_ids, devicelabels.device)) if clause_mask.any(): # 提取clause_mask为True位置的logits概率 clause_logits shift_logits[clause_mask] clause_labels shift_labels[clause_mask] # 计算这些位置的交叉熵仅针对条款token clause_loss nn.functional.cross_entropy( clause_logits, clause_labels, reductionmean ) # 将条款损失加权回总loss注意这是奖励所以用减法 total_loss base_loss - self.clause_weight * clause_loss else: total_loss base_loss return total_loss # 创建自定义Trainer类 from swift import Seq2SeqTrainer class CustomLegalTrainer(Seq2SeqTrainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.criterion LegalClauseLoss() def compute_loss(self, model, inputs, return_outputsFalse): outputs model(**inputs) logits outputs.get(logits) labels inputs.get(labels) loss self.criterion(logits, labels) return (loss, outputs) if return_outputs else loss关键要点解析compute_loss中必须调用model(**inputs)不能跳过前向传播。ms-swift依赖此调用触发LoRA适配层、梯度检查点等机制。labels是原始输入的标签张量-100填充需按惯例做shift处理以对齐logits预测位置。自定义loss中禁止访问model.config以外的模型属性如model.lm_head.weight否则可能破坏量化或并行逻辑。2.2 高级场景多任务联合Loss与动态权重更复杂的场景如多模态医疗问答需同时优化文本生成answer和图像区域定位bbox。ms-swift天然支持多输出模型其outputs字典可包含任意键值对。class MultiTaskLoss(nn.Module): def __init__(self, weights{text: 1.0, bbox: 0.5}): super().__init__() self.weights weights self.text_loss nn.CrossEntropyLoss() self.bbox_loss nn.SmoothL1Loss() # 或GIoULoss def forward(self, outputs, inputs): total_loss 0.0 # 文本生成loss if logits in outputs: shift_logits outputs[logits][..., :-1, :].contiguous() shift_labels inputs[labels][..., 1:].contiguous() text_loss self.text_loss( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) total_loss self.weights[text] * text_loss # 边界框回归loss if pred_boxes in outputs and gt_boxes in inputs: pred_boxes outputs[pred_boxes] gt_boxes inputs[gt_boxes] bbox_loss self.bbox_loss(pred_boxes, gt_boxes) total_loss self.weights[bbox] * bbox_loss return total_loss # 在CustomTrainer中集成 class MultiModalTrainer(Seq2SeqTrainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.criterion MultiTaskLoss() def compute_loss(self, model, inputs, return_outputsFalse): outputs model(**inputs) # 此处model需支持多输出 loss self.criterion(outputs, inputs) return (loss, outputs) if return_outputs else loss验证技巧在训练启动前用小batch数据手动调用compute_loss打印各子loss值确保数值合理非NaN、非Inf、量级正常。3. 自定义优化器超越AdamW的收敛控制3.1 替换优化器一行代码切换Lion优化器Lion优化器因其内存效率高、在大batch场景下收敛快而广受青睐。ms-swift的create_optimizer钩子允许你完全接管优化器构建过程。from lion_pytorch import Lion from transformers import TrainerState class LionTrainer(Seq2SeqTrainer): def create_optimizer(self): 重写create_optimizer返回自定义优化器实例 注意必须调用get_decay_parameter_names获取需衰减参数 opt_model self.model_wrapped if is_sagemaker_mp_enabled() else self.model # 1. 获取所有可训练参数自动包含LoRA权重 decay_parameters self.get_decay_parameter_names(opt_model) optimizer_grouped_parameters [ { params: [ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) ], weight_decay: self.args.weight_decay, }, { params: [ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) ], weight_decay: 0.0, }, ] # 2. 创建Lion优化器需pip install lion-pytorch optimizer Lion( optimizer_grouped_parameters, lrself.args.learning_rate, weight_decayself.args.weight_decay, use_tritonTrue # 启用Triton加速 ) return optimizer为什么安全self.get_decay_parameter_names()是ms-swift提供的标准方法能正确识别LoRA适配层如lora_A.default.weight是否参与weight decay避免错误地对低秩矩阵施加衰减。所有参数分组逻辑复用ms-swift原生逻辑确保与--train_type lora、--quant_bits 4等参数完全兼容。3.2 高级定制带warmup重启动的SGD优化器某些强化学习微调如GRPO需要优化器在每轮策略更新后重置动量以消除历史梯度干扰。我们构建一个WarmupRestartsSGDimport math from torch.optim import SGD class WarmupRestartsSGD(SGD): def __init__(self, params, lr0.01, momentum0, dampening0, weight_decay0, nesterovFalse, restart_interval1000): super().__init__(params, lr, momentum, dampening, weight_decay, nesterov) self.restart_interval restart_interval self.step_count 0 def step(self, closureNone): self.step_count 1 # 每restart_interval步重置所有参数的momentum_buffer if self.step_count % self.restart_interval 0: for group in self.param_groups: for p in group[params]: if momentum_buffer in self.state[p]: del self.state[p][momentum_buffer] super().step(closure) # 在Trainer中集成 class GRPORestarterTrainer(Seq2SeqTrainer): def create_optimizer(self): opt_model self.model_wrapped if is_sagemaker_mp_enabled() else self.model decay_params self.get_decay_parameter_names(opt_model) optimizer_grouped_parameters [ {params: [p for n, p in opt_model.named_parameters() if n in decay_params and p.requires_grad], weight_decay: self.args.weight_decay}, {params: [p for n, p in opt_model.named_parameters() if n not in decay_params and p.requires_grad], weight_decay: 0.0} ] # 使用自定义SGD optimizer WarmupRestartsSGD( optimizer_grouped_parameters, lrself.args.learning_rate, momentum0.9, restart_interval500 # 每500步重启一次 ) return optimizer工程实践建议在args中增加自定义参数如--restart_interval 500通过self.args.restart_interval读取保持命令行接口一致性。重启动逻辑应放在step()而非__init__()中确保分布式训练时各GPU步数同步。4. 完整端到端示例为Qwen2.5-7B-Instruct添加领域一致性Loss现在我们将上述技术整合构建一个生产就绪的定制方案针对电商客服微调任务添加商品属性一致性Loss确保模型在回答中提及的商品参数如“颜色红色”、“尺寸XL”与用户query中明确声明的属性严格一致。4.1 数据准备结构化标注query中的属性首先改造你的数据集在每条样本中加入attributes字段{ query: 这个连衣裙有红色和XL码吗, response: 有的这款连衣裙提供红色和XL码。, attributes: {color: [红色], size: [XL]} }ms-swift支持自定义dataset处理器只需继承DatasetPreprocessorfrom swift import DatasetPreprocessor class EcommerceAttrPreprocessor(DatasetPreprocessor): def __call__(self, dataset, num_proc1): def add_attr_features(example): # 将attributes字典转为tokenized ID序列便于模型学习 attr_str .join([f{k}:{|.join(v)} for k, v in example.get(attributes, {}).items()]) example[attr_input_ids] self.tokenizer.encode(attr_str, add_special_tokensFalse) return example return dataset.map(add_attr_features, num_procnum_proc)4.2 构建属性一致性Lossclass AttrConsistencyLoss(nn.Module): def __init__(self, tokenizer, attr_weight1.0): super().__init__() self.tokenizer tokenizer self.attr_weight attr_weight # 构建属性关键词到token id的映射实际项目中应从知识库加载 self.attr_keywords { color: self.tokenizer.convert_tokens_to_ids([红色, 蓝色, 黑色]), size: self.tokenizer.convert_tokens_to_ids([S, M, L, XL]) } def forward(self, logits, labels, attr_input_ids): # 1. 基础loss shift_logits logits[..., :-1, :].contiguous() shift_labels labels[..., 1:].contiguous() base_loss nn.functional.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index-100, reductionmean ) # 2. 属性一致性loss鼓励模型在生成response时对attr_input_ids中出现的token # 给予更高概率KL散度最小化 if len(attr_input_ids) 0: # 获取attr_input_ids在logits中的对应位置概率 # 简化假设attr_input_ids出现在response开头实际需更精准对齐 attr_logits shift_logits[:, :len(attr_input_ids), :] attr_probs nn.functional.softmax(attr_logits, dim-1) # 构建target分布对每个attr token设其对应位置概率为1 target_dist torch.zeros_like(attr_probs) for i, token_id in enumerate(attr_input_ids): if i target_dist.size(1): target_dist[:, i, token_id] 1.0 # KL散度损失 attr_loss nn.functional.kl_div( torch.log(attr_probs 1e-8), target_dist, reductionbatchmean ) total_loss base_loss self.attr_weight * attr_loss else: total_loss base_loss return total_loss4.3 启动训练命令行与Python双模式命令行方式推荐用于快速验证# 1. 准备自定义preprocessor需提前注册到ms-swift python -c from swift import register_dataset_preprocessor register_dataset_preprocessor(ecommerce_attr, EcommerceAttrPreprocessor) # 2. 启动训练指定自定义Trainer类 CUDA_VISIBLE_DEVICES0 \ swift sft \ --model Qwen/Qwen2.5-7B-Instruct \ --dataset your-ecommerce-dataset \ --train_type lora \ --custom_trainer_class path.to.CustomEcommerceTrainer \ --output_dir output_ecom \ --learning_rate 2e-4 \ --per_device_train_batch_size 2 \ --gradient_accumulation_steps 8Python API方式适合复杂pipelinefrom swift import Swift, Seq2SeqTrainer, get_model_tokenizer, load_dataset from transformers import TrainingArguments # 加载模型与tokenizer model, tokenizer get_model_tokenizer(Qwen/Qwen2.5-7B-Instruct) # 注册自定义preprocessor from swift import register_dataset_preprocessor register_dataset_preprocessor(ecommerce_attr, EcommerceAttrPreprocessor) # 加载并预处理数据 train_dataset, _ load_dataset(your-ecommerce-dataset) train_dataset EcommerceAttrPreprocessor(tokenizer)(train_dataset) # 配置训练参数 training_args TrainingArguments( output_dir./output_ecom, per_device_train_batch_size2, gradient_accumulation_steps8, learning_rate2e-4, num_train_epochs2, save_steps100, logging_steps10, report_tonone ) # 创建自定义Trainer实例 trainer CustomEcommerceTrainer( modelmodel, argstraining_args, train_datasettrain_dataset, tokenizertokenizer, data_collator... # 使用默认collator即可 ) # 开始训练 trainer.train()5. 调试与验证确保定制逻辑正确生效定制功能上线前必须进行三层验证5.1 日志层验证确认钩子被调用在compute_loss和create_optimizer方法开头添加日志def compute_loss(self, model, inputs, return_outputsFalse): print(f[DEBUG] CustomLoss called with input shape: {inputs[input_ids].shape}) # ... rest of logic观察训练日志中是否出现该打印确认钩子未被跳过。5.2 数值层验证监控Loss组成修改compute_loss返回各子loss供日志记录def compute_loss(self, model, inputs, return_outputsFalse): outputs model(**inputs) logits outputs.get(logits) labels inputs.get(labels) base_loss self.base_criterion(logits, labels) attr_loss self.attr_criterion(logits, labels, inputs.get(attr_input_ids, [])) # 记录到日志 self.log({base_loss: base_loss.item(), attr_loss: attr_loss.item()}) total_loss base_loss 0.5 * attr_loss return (total_loss, outputs) if return_outputs else total_loss5.3 行为层验证人工评估生成质量训练100步后用固定prompt测试from swift import PtEngine engine PtEngine(Qwen/Qwen2.5-7B-Instruct, adapters./output_ecom/checkpoint-100) resp engine.infer([{role: user, content: 这个手机有512GB存储和绿色吗}]) print(resp.choices[0].message.content) # 期望输出包含512GB和绿色且无矛盾表述获取更多AI镜像想探索更多AI镜像和应用场景访问 CSDN星图镜像广场提供丰富的预置镜像覆盖大模型推理、图像生成、视频生成、模型微调等多个领域支持一键部署。