摘要 
OpenVLA 
核心创新:ActionTokenizer 
ActionTokenizer 通过将连续的机器人动作映射到离散的词元空间(256 bins),将机器人控制转化为了一个语言任务。(与论文对应)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 class  ActionTokenizer :    def  __init__ (self, tokenizer, bins=256 , min_action=-1 , max_action=1  ):         self.tokenizer = tokenizer         self.bins = bins                  self.bin_centers = np.linspace(min_action, max_action, bins)     def  encode (self, actions ):         """将连续动作编码为token序列"""                   action_bins = np.digitize(actions, self.bin_centers) - 1          action_bins = np.clip(action_bins, 0 , self.bins - 1 )         action_tokens = []         for  bin_idx in  action_bins.flatten():                          token = f"<action_{bin_idx:03d} >"              action_tokens.append(self.tokenizer.convert_tokens_to_ids(token))         return  action_tokens 
将连续动作空间量化为256个离散bins,并映射到LLM的词汇表中。 
利用 pretrained LLM 来统一处理视觉、语言和动作。 
 
端到端推理流程实现 
predict_action 函数封装了完整的推理逻辑,从接收多模态输入到输出最终的机器人动作。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 @torch.inference_mode() def  predict_action (    self, image: Image, instruction: str , unnorm_key: Optional [str ] = None , **kwargs: str      """      VLA推理的核心函数;将输入图像和任务指令映射为连续动作。     Args:         image: PIL图像,格式为[高度, 宽度, 3]的RGB图像         instruction: 任务指令字符串,描述机器人应该执行的任务         unnorm_key: 可选的数据集名称,用于检索去归一化统计信息;                     如果为None,则检查模型是否仅在单个数据集上训练,并检索该统计信息         **kwargs: 传递给generate方法的额外参数     Returns:         np.ndarray: 去归一化的(连续)动作向量 --> 末端执行器增量变化     """          image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer          prompt_builder = self.get_prompt_builder()     prompt_builder.add_turn(role="human" , message=f"What action should the robot take to {instruction.lower()} ?" )     prompt_text = prompt_builder.get_prompt()               input_ids = tokenizer(prompt_text, truncation=True , return_tensors="pt" ).input_ids.to(self.device)     ...               pixel_values = image_transform(image)     if  isinstance (pixel_values, torch.Tensor):                  pixel_values = pixel_values[None , ...].to(self.device)     elif  isinstance (pixel_values, dict ):                  pixel_values = {k: v[None , ...].to(self.device) for  k, v in  pixel_values.items()}     else :         raise  ValueError(f"不支持的像素值类型 = {type (pixel_values)} " )          ...                    actions = np.where(         mask,         0.5  * (normalized_actions + 1 ) * (action_high - action_low) + action_low,         normalized_actions,     )     return  actions 
通过动作即文本的方式,成功将机器人控制转化为文本生成任务。
openpi 
流匹配算法与JAX实现 
π₀ 模型采用流匹配(类似扩散),直接学习从噪声到真实动作的向量场,并借助 jax 框架实现静态编译和高效并行化。
流匹配核心公式:x_t = t * noise + (1-t) * actions,模型学习的是直接将噪声推向真实数据的速度场,比扩散模型路径更短。 
jax优势:函数式编程范式、JIT(编译以及自动并行化能力,支持高效的GPU/TPU计算),以及自动微分,便于中间结果的缓存和复用。 
 
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 def  step (carry ):    """单步扩散采样。"""      x_t, time = carry          suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(...)          suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)          prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p" , s=suffix_tokens.shape[1 ])               full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1 )          positions = jnp.sum (prefix_mask, axis=-1 )[:, None ] + jnp.cumsum(suffix_mask, axis=-1 ) - 1           (prefix_out, suffix_out), _ = self.PaliGemma.llm(         [None , suffix_tokens], mask=full_attn_mask, positions=positions, kv_cache=kv_cache     )     assert  prefix_out is  None           v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])          return  x_t + dt * v_t, time + dt 
并行化与分布式训练优化(lerobot, openpi_pytorch, openpi) 
PyTorch的并行优化 
openpi_pytorch 中,通过 einops 库对多视图图像嵌入进行了并行优化。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 def  embed_prefix (self, images, img_masks, lang_tokens, lang_masks ):         images = einops.rearrange(images, "b n c h w -> (b n) c h w" )          img_emb = self.paligemma_with_expert.embed_image(images)          img_emb = einops.rearrange(img_emb, "(b n) l d -> b (n l) d" , b=bsize) def  embed_prefix (self, images, img_masks, lang_tokens, lang_masks ):         embs = []     pad_masks = []     att_masks = []          for  (img, img_mask) in  zip (images, img_masks, strict=False ):         img_emb = self.paligemma_with_expert.embed_image(img)          
JAX的自动分片与分布式策略 
jax 通过 Sharding 机制,实现复杂的分布式训练策略。
1 2 3 4 5 6 7 8 9 10 11 12 13 def  fsdp_sharding (pytree, mesh: jax.sharding.Mesh, *, min_size_mbytes: int  = 4  ):    """智能分片:遍历模型参数,根据张量大小和设备网格自动分配"""      def  _shard_arr (kp, array: jax.ShapeDtypeStruct ):                  axes = np.argsort(array.shape)[::-1 ]         spec = [None ] * len (axes)         for  i in  axes:             if  array.shape[i] % mesh.shape[FSDP_AXIS] == 0 :                 spec[i] = FSDP_AXIS                 return  jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec))                  return  jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) 
jax :提供自动化并行策略,只需定义设备网格即可自动完成参数分片和通信。PyTorch :需要使用 DistributedDataParallel 或通过张量操作库 einops 来重组数据以提升并行效率。 
KV缓存优化机制 
KV缓存能够提升多步采样的模型的推理速度。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 def  sample_actions (self, rng, observation, *, num_steps: int  = 10  ):         prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)     prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)     positions = jnp.cumsum(prefix_mask, axis=1 ) - 1           _, kv_cache = self.PaliGemma.llm([prefix_tokens, None ], mask=prefix_attn_mask, positions=positions)     ...          def  step (carry ):         x_t, time = carry                  (prefix_out, suffix_out), _ = self.PaliGemma.llm(             [None , suffix_tokens],             mask=full_attn_mask,             positions=positions,             kv_cache=kv_cache          )         return  x_t + dt * v_t, time + dt     ...          x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0 )) 
缓存策略 :视觉和语言等上下文特征是固定的,其KV值只需计算一次并被后续所有动作预测步骤复用,极大地减少了重复计算。编译优化 :jax 能够将整个包含while_loop的采样函数JIT编译成一个单一计算图,性能更好。