摘要


OpenVLA

核心创新:ActionTokenizer

ActionTokenizer 通过将连续的机器人动作映射到离散的词元空间(256 bins),将机器人控制转化为了一个语言任务。(与论文对应)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class ActionTokenizer:
def __init__(self, tokenizer, bins=256, min_action=-1, max_action=1):
self.tokenizer = tokenizer
self.bins = bins
# 创建用于离散化的bin边界
self.bin_centers = np.linspace(min_action, max_action, bins)

def encode(self, actions):
"""将连续动作编码为token序列"""
# 1. 量化:将动作值映射到最近的bin
action_bins = np.digitize(actions, self.bin_centers) - 1
action_bins = np.clip(action_bins, 0, self.bins - 1)

action_tokens = []
for bin_idx in action_bins.flatten():
# 2. token 化:将bin索引转换为特殊的文本token,如 "<action_025>"
token = f"<action_{bin_idx:03d}>"
action_tokens.append(self.tokenizer.convert_tokens_to_ids(token))
return action_tokens
  • 将连续动作空间量化为256个离散bins,并映射到LLM的词汇表中。
  • 利用 pretrained LLM 来统一处理视觉、语言和动作。

端到端推理流程实现

predict_action 函数封装了完整的推理逻辑,从接收多模态输入到输出最终的机器人动作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# AI 辅助
@torch.inference_mode()
def predict_action(
self, image: Image, instruction: str, unnorm_key: Optional[str] = None, **kwargs: str
) -> np.ndarray:
"""
VLA推理的核心函数;将输入图像和任务指令映射为连续动作。

Args:
image: PIL图像,格式为[高度, 宽度, 3]的RGB图像
instruction: 任务指令字符串,描述机器人应该执行的任务
unnorm_key: 可选的数据集名称,用于检索去归一化统计信息;
如果为None,则检查模型是否仅在单个数据集上训练,并检索该统计信息
**kwargs: 传递给generate方法的额外参数

Returns:
np.ndarray: 去归一化的(连续)动作向量 --> 末端执行器增量变化
"""
# 获取视觉backbone的图像变换和语言模型的分词器
image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer

# 构建VLA提示词
prompt_builder = self.get_prompt_builder()
prompt_builder.add_turn(role="human", message=f"What action should the robot take to {instruction.lower()}?")
prompt_text = prompt_builder.get_prompt()

# 准备输入数据
# 将提示词文本转换为token IDs并移动到设备上
input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device)
...
# 预处理图像
# 应用图像变换(如缩放、归一化等)
pixel_values = image_transform(image)
if isinstance(pixel_values, torch.Tensor):
# 如果是张量,添加batch维度并移动到设备
pixel_values = pixel_values[None, ...].to(self.device)
elif isinstance(pixel_values, dict):
# 如果是字典(多分辨率等情况),对每个值添加batch维度并移动到设备
pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()}
else:
raise ValueError(f"不支持的像素值类型 = {type(pixel_values)}")

# 调用super().generate --> 利用`GenerationMixin`,它会重定向到`forward()`方法
...
# 根据掩码选择性地去归一化动作
# 对于掩码为True的维度:从[-1,1]范围映射回原始动作范围
# 对于掩码为False的维度:保持原始归一化值
actions = np.where(
mask,
0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
normalized_actions,
)
return actions

通过动作即文本的方式,成功将机器人控制转化为文本生成任务。


openpi

流匹配算法与JAX实现

π₀ 模型采用流匹配(类似扩散),直接学习从噪声到真实动作的向量场,并借助 jax 框架实现静态编译和高效并行化。

  • 流匹配核心公式:x_t = t * noise + (1-t) * actions,模型学习的是直接将噪声推向真实数据的速度场,比扩散模型路径更短。
  • jax优势:函数式编程范式、JIT(编译以及自动并行化能力,支持高效的GPU/TPU计算),以及自动微分,便于中间结果的缓存和复用。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def step(carry):
"""单步扩散采样。"""
x_t, time = carry
# 嵌入当前状态的动作序列
suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(...)
# `suffix_attn_mask`的形状为(b, suffix_len, suffix_len),表示后缀令牌之间如何互相关注
suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
# `prefix_attn_mask`的形状为(b, suffix_len, prefix_len),表示后缀令牌如何关注前缀令牌
prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
# `full_attn_mask`的形状为(b, suffix_len, prefix_len + suffix_len),表示后缀令牌(生成查询)
# 如何关注完整的前缀+后缀序列(生成键和值)
full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
# `positions`的形状为(b, suffix_len),表示后缀令牌的位置
positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1

# 使用KV缓存进行高效的Transformer推理
(prefix_out, suffix_out), _ = self.PaliGemma.llm(
[None, suffix_tokens], mask=full_attn_mask, positions=positions, kv_cache=kv_cache
)
assert prefix_out is None
# 预测向量场
v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
# 使用欧拉方法更新状态:x_{t+dt} = x_t + dt * v_t
return x_t + dt * v_t, time + dt

并行化与分布式训练优化(lerobot, openpi_pytorch, openpi)

PyTorch的并行优化

openpi_pytorch 中,通过 einops 库对多视图图像嵌入进行了并行优化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# openpi_pytorch
def embed_prefix(self, images, img_masks, lang_tokens, lang_masks):
# 使用einops将多视图(n)合并到批次(b)维度,形成 (b*n) 的大批次
images = einops.rearrange(images, "b n c h w -> (b n) c h w")
# 一次性完成所有图像的嵌入,充分利用GPU并行
img_emb = self.paligemma_with_expert.embed_image(images)
# 恢复维度
img_emb = einops.rearrange(img_emb, "(b n) l d -> b (n l) d", b=bsize)

# lerobot
def embed_prefix(self, images, img_masks, lang_tokens, lang_masks):
# TODO: 避免Python中的列表和torch.cat;优先使用torch.empty预分配
embs = []
pad_masks = []
att_masks = []

# TODO: 移除for循环
for (img, img_mask) in zip(images, img_masks, strict=False):
img_emb = self.paligemma_with_expert.embed_image(img)
# 逐个处理每个图像

JAX的自动分片与分布式策略

jax 通过 Sharding 机制,实现复杂的分布式训练策略。

1
2
3
4
5
6
7
8
9
10
11
12
13
# jax FSDP 自动分片策略
def fsdp_sharding(pytree, mesh: jax.sharding.Mesh, *, min_size_mbytes: int = 4):
"""智能分片:遍历模型参数,根据张量大小和设备网格自动分配"""
def _shard_arr(kp, array: jax.ShapeDtypeStruct):
# 优先在张量最大的维度上进行分片
axes = np.argsort(array.shape)[::-1]
spec = [None] * len(axes)
for i in axes:
if array.shape[i] % mesh.shape[FSDP_AXIS] == 0:
spec[i] = FSDP_AXIS
return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec))
# 如果无法分片,则在所有设备上复制
return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
  • jax:提供自动化并行策略,只需定义设备网格即可自动完成参数分片和通信。
  • PyTorch:需要使用 DistributedDataParallel 或通过张量操作库 einops 来重组数据以提升并行效率。

KV缓存优化机制

KV缓存能够提升多步采样的模型的推理速度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# JAX版本的高效扩散采样
def sample_actions(self, rng, observation, *, num_steps: int = 10):
# 首先通过前缀的前向传播填充KV缓存,这样后续步骤只需要处理后缀
prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
positions = jnp.cumsum(prefix_mask, axis=1) - 1
# 第一次前向传播,获取KV缓存
_, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)

...
# 定义单步采样函数
def step(carry):
x_t, time = carry
# 复用缓存:在每一步采样时,只计算后缀(动作)部分的嵌入,并传入之前缓存的KV值
(prefix_out, suffix_out), _ = self.PaliGemma.llm(
[None, suffix_tokens],
mask=full_attn_mask,
positions=positions,
kv_cache=kv_cache # 复用缓存
)
return x_t + dt * v_t, time + dt
...
# 使用jax.lax.while_loop,整个循环可被JIT编译
x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
  • 缓存策略:视觉和语言等上下文特征是固定的,其KV值只需计算一次并被后续所有动作预测步骤复用,极大地减少了重复计算。
  • 编译优化:jax 能够将整个包含while_loop的采样函数JIT编译成一个单一计算图,性能更好。