Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ def _to_json_safe(value: Any) -> Any:
return json.loads(json.dumps(value, ensure_ascii=False, default=str))


def _selected_agent(item: AgentRolloutItem) -> dict[str, Any] | None:
if item.infer.agent is None:
return None
return item.infer.agent.model_dump(mode="json")


def _count_tool_turns(messages: list[dict[str, Any]]) -> int:
return sum(
1
Expand Down Expand Up @@ -249,7 +255,11 @@ async def _fill_rollout_state(self, rollout_state: RolloutState, item: AgentRoll
)
rollout_state.reward = _extract_reward_payload(item)
rollout_state.extra_fields["agent_status"] = item.status.value
rollout_state.extra_fields["agent_artifacts"] = item.artifacts
selected_agent = _selected_agent(item)
if selected_agent is not None:
rollout_state.extra_fields["agent_name"] = selected_agent.get("name")
rollout_state.extra_fields["agent_selected"] = _to_json_safe(selected_agent)
rollout_state.extra_fields["agent_artifacts"] = _to_json_safe(item.artifacts)
rollout_state.extra_fields["agent_judgers"] = {
name: record.model_dump(mode="json") for name, record in item.judgers.items()
}
Expand Down Expand Up @@ -305,7 +315,11 @@ def _fill_eval_rollout_state(self, rollout_state: RolloutState, item: AgentRollo
rollout_state.response_mask = None
rollout_state.response_model_steps = None
rollout_state.extra_fields["agent_status"] = item.status.value
rollout_state.extra_fields["agent_artifacts"] = item.artifacts
selected_agent = _selected_agent(item)
if selected_agent is not None:
rollout_state.extra_fields["agent_name"] = selected_agent.get("name")
rollout_state.extra_fields["agent_selected"] = _to_json_safe(selected_agent)
rollout_state.extra_fields["agent_artifacts"] = _to_json_safe(item.artifacts)
rollout_state.extra_fields["agent_judgers"] = {
name: record.model_dump(mode="json") for name, record in item.judgers.items()
}
Expand All @@ -321,7 +335,6 @@ def _fill_eval_rollout_state(self, rollout_state: RolloutState, item: AgentRollo

messages, tools = _load_eval_trace_segment(item.artifacts)
if messages:
rollout_state.extra_fields["agent_trajectory"] = _to_json_safe({"messages": messages, "tools": tools})
rollout_state.extra_fields["agent_messages"] = messages
rollout_state.extra_fields["agent_tools"] = tools
rollout_state.extra_fields["agent_tool_turns"] = _count_tool_turns(messages)
19 changes: 13 additions & 6 deletions xtuner/v1/rl/agent_loop/sandbox_agent_loop/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import io
import json
import re
import shlex
import tarfile
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -448,12 +450,17 @@ async def __call__(self, client: Any, item: AgentRolloutItem, record: StageRecor
if chosen is None:
raise RuntimeError("PickAgent must run before RunAgentInstallDeps")
script = f"{self.workspace}/agent/{chosen.name}/install-deps.sh"
await exec_in(
client,
f'[ -f "{script}" ] && bash "{script}" || true',
timeout_sec=self.timeout,
raise_on_error=True,
)
script_q = shlex.quote(script)
t0 = time.monotonic()
try:
await exec_in(
client,
f"if [ -f {script_q} ]; then bash {script_q}; fi",
timeout_sec=self.timeout,
raise_on_error=True,
)
finally:
record.metadata["install_agent_time_s"] = time.monotonic() - t0


# ─────────────────────────────────────────────────────────────────
Expand Down
34 changes: 23 additions & 11 deletions xtuner/v1/rl/agent_loop/sandbox_agent_loop/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@
from xtuner.v1.utils import get_logger


# ─────────────────────────────────────────────────────────────────
# Runner
# ─────────────────────────────────────────────────────────────────


class Runner:
"""Pairs one infer stage with one validation judger."""

Expand Down Expand Up @@ -155,15 +150,32 @@ def _log_final(
t_infer: float | None,
t_validate: float | None,
) -> None:
def format_seconds(label: str, val: Any) -> str | None:
if isinstance(val, (int, float)):
return f"{label}={val:.1f}s"
return None

parts: list[str] = [f"status={item.status.value}"]
if item.reward is not None:
parts.append(f"reward={item.reward:.4f}")
if t_acquire is not None:
parts.append(f"t_acquire={t_acquire:.1f}s")
if t_infer is not None:
parts.append(f"t_infer={t_infer:.1f}s")
if t_validate is not None:
parts.append(f"t_validate={t_validate:.1f}s")
if item.infer.agent is not None:
parts.append(f"agent={item.infer.agent.name}")
timing_parts = [
format_seconds("t_acquire", t_acquire),
format_seconds("t_acquire_rate_limit", item.infer.metadata.get("sandbox_acquire_rate_limit_wait_s")),
format_seconds("t_sandbox_ready", item.infer.metadata.get("sandbox_create_to_ready_time_s")),
format_seconds("t_install_agent", item.infer.metadata.get("install_agent_time_s")),
format_seconds("t_infer", t_infer),
format_seconds("t_validate", t_validate),
]
parts.extend(part for part in timing_parts if part is not None)
attempts = item.infer.metadata.get("sandbox_create_attempts")
if isinstance(attempts, int):
parts.append(f"sandbox_create_attempts={attempts}")
if item.infer.sandbox_image:
parts.append(f"sandbox_image={item.infer.sandbox_image}")
if item.infer.sandbox_url:
parts.append(f"sandbox_url={item.infer.sandbox_url}")
if item.status == RolloutStatus.FAILED and item.error is not None:
parts.append(f"error={item.error.stage or '?'}/{item.error.category}")
get_logger().info(f"[{tid}] done {' '.join(parts)}")
Expand Down
25 changes: 21 additions & 4 deletions xtuner/v1/rl/agent_loop/sandbox_agent_loop/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ async def get(self, name: str, *, record: StageRecord | None = None) -> Any:
self.validate_name(name)
spec = self._specs[name]
try:
client, env_id = await self._acquire_ready(spec)
client, env_id = await self._acquire_ready(spec, record=record)
except Exception as exc:
if record is not None:
record.status = StageStatus.FAILED
Expand Down Expand Up @@ -931,9 +931,12 @@ def _url_of(client: Any) -> str | None:
return str(val)
return None

async def _acquire_ready(self, spec: SandboxSpec) -> tuple[Any, str]:
async def _acquire_ready(self, spec: SandboxSpec, *, record: StageRecord | None = None) -> tuple[Any, str]:
last_err: Exception | None = None
t_ready: float | None = None
for attempt in range(1, self._max_attempts + 1):
if record is not None:
record.metadata["sandbox_create_attempts"] = attempt
try:
create_kwargs: dict[str, Any] = {}
if spec.key:
Expand All @@ -943,7 +946,14 @@ async def _acquire_ready(self, spec: SandboxSpec) -> tuple[Any, str]:
if spec.resources:
create_kwargs["resources"] = spec.resources
if self._create_limiter is not None:
t_limit = time.monotonic()
await self._create_limiter.acquire()
if record is not None:
record.metadata["sandbox_acquire_rate_limit_wait_s"] = (
record.metadata.get("sandbox_acquire_rate_limit_wait_s", 0.0) + time.monotonic() - t_limit
)
if t_ready is None:
t_ready = time.monotonic()
client, env_id = await self._provider.create(
image_tag=spec.image,
ttl_seconds=spec.ttl_seconds,
Expand All @@ -954,7 +964,10 @@ async def _acquire_ready(self, spec: SandboxSpec) -> tuple[Any, str]:
await asyncio.sleep(min(2**attempt, 8))
continue

if await self._wait_healthy(client):
healthy = await self._wait_healthy(client)
if healthy:
if record is not None and t_ready is not None:
record.metadata["sandbox_create_to_ready_time_s"] = time.monotonic() - t_ready
return client, env_id

try:
Expand Down Expand Up @@ -1043,7 +1056,11 @@ async def exec_in(
result = await client.execute(command, cwd, timeout_sec, detach)
rc = _result_code(result)
if raise_on_error and rc != 0:
raise RuntimeError(f"command failed (return_code={rc}): {command}\nstderr: {result.get('stderr', '')}")
raise RuntimeError(
f"command failed (return_code={rc}): {command}\n"
f"stdout: {result.get('stdout', '')}\n"
f"stderr: {result.get('stderr', '')}"
)
return result


Expand Down
7 changes: 7 additions & 0 deletions xtuner/v1/train/rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,10 +1413,13 @@ def _save_trajectories(self, data_groups: list[list[RolloutState]], save_path: P
"response_len": response_len,
"reward_payload": data.reward,
"agent": {
"name": data.extra_fields.get("agent_name"),
"selected": data.extra_fields.get("agent_selected"),
"status": data.extra_fields.get("agent_status", None),
"judgers": data.extra_fields.get("agent_judgers", None),
"finish_info": data.extra_fields.get("agent_finish_info", None),
"tool_turns": data.extra_fields.get("agent_tool_turns", None),
"artifacts": data.extra_fields.get("agent_artifacts"),
"messages": data.extra_fields.get("agent_messages"),
"tools": data.extra_fields.get("agent_tools"),
},
Expand Down Expand Up @@ -1476,9 +1479,13 @@ def _save_eval_trajectories(self, data_groups: list[list[RolloutState]], save_pa
"finish_reason": data.finish_reason,
"error_msg": data.error_msg,
"agent": {
"name": data.extra_fields.get("agent_name"),
"selected": data.extra_fields.get("agent_selected"),
"status": data.extra_fields.get("agent_status", None),
"judgers": data.extra_fields.get("agent_judgers", None),
"finish_info": data.extra_fields.get("agent_finish_info", None),
"tool_turns": data.extra_fields.get("agent_tool_turns", None),
"artifacts": data.extra_fields.get("agent_artifacts"),
"messages": data.extra_fields.get("agent_messages"),
"tools": data.extra_fields.get("agent_tools"),
},
Expand Down
Loading