摘要
OpenVLA
核心创新:ActionTokenizer
ActionTokenizer
通过将连续的机器人动作映射到离散的词元空间(256 bins),将机器人控制转化为了一个语言任务。(与论文对应)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 class ActionTokenizer : def __init__ (self, tokenizer, bins=256 , min_action=-1 , max_action=1 ): self.tokenizer = tokenizer self.bins = bins self.bin_centers = np.linspace(min_action, max_action, bins) def encode (self, actions ): """将连续动作编码为token序列""" action_bins = np.digitize(actions, self.bin_centers) - 1 action_bins = np.clip(action_bins, 0 , self.bins - 1 ) action_tokens = [] for bin_idx in action_bins.flatten(): token = f"<action_{bin_idx:03d} >" action_tokens.append(self.tokenizer.convert_tokens_to_ids(token)) return action_tokens
将连续动作空间量化为256个离散bins,并映射到LLM的词汇表中。
利用 pretrained LLM 来统一处理视觉、语言和动作。
端到端推理流程实现
predict_action
函数封装了完整的推理逻辑,从接收多模态输入到输出最终的机器人动作。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 @torch.inference_mode() def predict_action ( self, image: Image, instruction: str , unnorm_key: Optional [str ] = None , **kwargs: str ) -> np.ndarray: """ VLA推理的核心函数;将输入图像和任务指令映射为连续动作。 Args: image: PIL图像,格式为[高度, 宽度, 3]的RGB图像 instruction: 任务指令字符串,描述机器人应该执行的任务 unnorm_key: 可选的数据集名称,用于检索去归一化统计信息; 如果为None,则检查模型是否仅在单个数据集上训练,并检索该统计信息 **kwargs: 传递给generate方法的额外参数 Returns: np.ndarray: 去归一化的(连续)动作向量 --> 末端执行器增量变化 """ image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer prompt_builder = self.get_prompt_builder() prompt_builder.add_turn(role="human" , message=f"What action should the robot take to {instruction.lower()} ?" ) prompt_text = prompt_builder.get_prompt() input_ids = tokenizer(prompt_text, truncation=True , return_tensors="pt" ).input_ids.to(self.device) ... pixel_values = image_transform(image) if isinstance (pixel_values, torch.Tensor): pixel_values = pixel_values[None , ...].to(self.device) elif isinstance (pixel_values, dict ): pixel_values = {k: v[None , ...].to(self.device) for k, v in pixel_values.items()} else : raise ValueError(f"不支持的像素值类型 = {type (pixel_values)} " ) ... actions = np.where( mask, 0.5 * (normalized_actions + 1 ) * (action_high - action_low) + action_low, normalized_actions, ) return actions
通过动作即文本的方式,成功将机器人控制转化为文本生成任务。
openpi
流匹配算法与JAX实现
π₀ 模型采用流匹配(类似扩散),直接学习从噪声到真实动作的向量场,并借助 jax 框架实现静态编译和高效并行化。
流匹配核心公式:x_t = t * noise + (1-t) * actions
,模型学习的是直接将噪声推向真实数据的速度场,比扩散模型路径更短。
jax优势:函数式编程范式、JIT(编译以及自动并行化能力,支持高效的GPU/TPU计算),以及自动微分,便于中间结果的缓存和复用。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 def step (carry ): """单步扩散采样。""" x_t, time = carry suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(...) suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask) prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p" , s=suffix_tokens.shape[1 ]) full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1 ) positions = jnp.sum (prefix_mask, axis=-1 )[:, None ] + jnp.cumsum(suffix_mask, axis=-1 ) - 1 (prefix_out, suffix_out), _ = self.PaliGemma.llm( [None , suffix_tokens], mask=full_attn_mask, positions=positions, kv_cache=kv_cache ) assert prefix_out is None v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :]) return x_t + dt * v_t, time + dt
并行化与分布式训练优化(lerobot, openpi_pytorch, openpi)
PyTorch的并行优化
openpi_pytorch 中,通过 einops
库对多视图图像嵌入进行了并行优化。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 def embed_prefix (self, images, img_masks, lang_tokens, lang_masks ): images = einops.rearrange(images, "b n c h w -> (b n) c h w" ) img_emb = self.paligemma_with_expert.embed_image(images) img_emb = einops.rearrange(img_emb, "(b n) l d -> b (n l) d" , b=bsize) def embed_prefix (self, images, img_masks, lang_tokens, lang_masks ): embs = [] pad_masks = [] att_masks = [] for (img, img_mask) in zip (images, img_masks, strict=False ): img_emb = self.paligemma_with_expert.embed_image(img)
JAX的自动分片与分布式策略
jax 通过 Sharding
机制,实现复杂的分布式训练策略。
1 2 3 4 5 6 7 8 9 10 11 12 13 def fsdp_sharding (pytree, mesh: jax.sharding.Mesh, *, min_size_mbytes: int = 4 ): """智能分片:遍历模型参数,根据张量大小和设备网格自动分配""" def _shard_arr (kp, array: jax.ShapeDtypeStruct ): axes = np.argsort(array.shape)[::-1 ] spec = [None ] * len (axes) for i in axes: if array.shape[i] % mesh.shape[FSDP_AXIS] == 0 : spec[i] = FSDP_AXIS return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec)) return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
jax :提供自动化并行策略,只需定义设备网格即可自动完成参数分片和通信。
PyTorch :需要使用 DistributedDataParallel
或通过张量操作库 einops
来重组数据以提升并行效率。
KV缓存优化机制
KV缓存能够提升多步采样的模型的推理速度。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 def sample_actions (self, rng, observation, *, num_steps: int = 10 ): prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation) prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask) positions = jnp.cumsum(prefix_mask, axis=1 ) - 1 _, kv_cache = self.PaliGemma.llm([prefix_tokens, None ], mask=prefix_attn_mask, positions=positions) ... def step (carry ): x_t, time = carry (prefix_out, suffix_out), _ = self.PaliGemma.llm( [None , suffix_tokens], mask=full_attn_mask, positions=positions, kv_cache=kv_cache ) return x_t + dt * v_t, time + dt ... x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0 ))
缓存策略 :视觉和语言等上下文特征是固定的,其KV值只需计算一次并被后续所有动作预测步骤复用,极大地减少了重复计算。
编译优化 :jax 能够将整个包含while_loop
的采样函数JIT编译成一个单一计算图,性能更好。