Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/unit_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
else
while IFS= read -r file; do
case "$file" in
xtuner/v1/rl/*)
tests/rl/*|xtuner/v1/rl/*)
;;
*)
only_rl=false
Expand Down
29 changes: 10 additions & 19 deletions tests/rl/test_qwen35_vl_moe_async_train_2step.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import json
import math
import os
import tempfile
import time
import unittest
from pathlib import Path
Expand Down Expand Up @@ -68,21 +69,8 @@
MAX_PROMPT_LENGTH = 4096
MAX_RESPONSE_LENGTH = 2048
PACK_MAX_LENGTH = 8192
MISMATCH_KL_MAX = float(
os.environ.get(
"XTUNER_TRAIN_2STEP_MISMATCH_KL_MAX",
os.environ.get("XTUNER_PR_REAL_SMOKE_MISMATCH_KL_MAX", "0.005"),
)
)
MISMATCH_K3_KL_MAX = float(
os.environ.get(
"XTUNER_TRAIN_2STEP_MISMATCH_K3_KL_MAX",
os.environ.get("XTUNER_PR_REAL_SMOKE_MISMATCH_K3_KL_MAX", "0.005"),
)
)
RUN_ROOT = Path(
os.environ.get("XTUNER_TRAIN_2STEP_RUN_ROOT", os.environ.get("XTUNER_PR_REAL_SMOKE_RUN_ROOT", "."))
).resolve()
MISMATCH_KL_MAX = 0.005
MISMATCH_K3_KL_MAX = 0.005

REQUIRED_STEP_METRICS = (
"mismatch/mismatch_kl",
Expand Down Expand Up @@ -111,9 +99,12 @@ def setUp(self):
if not DATA_PATH.exists():
raise FileNotFoundError(f"Long-tail training dataset does not exist: {DATA_PATH}")

self.temp_dir = RUN_ROOT / f"{EXPERIMENT_NAME}_{time.strftime('%Y%m%d%H%M%S')}_{os.getpid()}"
self.temp_dir.mkdir(parents=True, exist_ok=False)
print(f"qwen35 vl moe async train 2-step work dir: {self.temp_dir}")
self.temp_dir = tempfile.TemporaryDirectory(
prefix=f"{EXPERIMENT_NAME}_{time.strftime('%Y%m%d%H%M%S')}_{os.getpid()}_",
)
self.addCleanup(self.temp_dir.cleanup)
self.temp_dir_path = Path(self.temp_dir.name)
print(f"qwen35 vl moe async train 2-step temp dir: {self.temp_dir_path}")
self.produce_calls: list[dict[str, Any]] = []
self.produce_results: list[ProduceBatchResult] = []
self.update_weight_calls = 0
Expand All @@ -137,7 +128,7 @@ def tearDown(self):
self._restore_env()

def test_qwen35_vl_moe_async_train_2step_and_metrics(self):
work_dir = Path(self.temp_dir) / "work_dir"
work_dir = self.temp_dir_path / "work_dir"
work_dir.mkdir(parents=True, exist_ok=True)

start_s = time.perf_counter()
Expand Down
2 changes: 0 additions & 2 deletions tests/rl/test_rollout_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,6 @@ def _build_controller(self, router):
controller.config = SimpleNamespace(rollout_timeout=1.0, random_seed=0)
controller.timeout_multiplier = 1.0
controller.router = router
controller._tool_call_parser = None
controller._reasoning_parser = None
controller.logger = MagicMock()
return controller

Expand Down
35 changes: 0 additions & 35 deletions xtuner/v1/rl/rollout/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,12 @@
from ray.actor import ActorProxy
from ray.util.placement_group import PlacementGroup

from transformers import AutoTokenizer
from xtuner.v1.data_proto.rl_data import RolloutState, Status
from xtuner.v1.rl.utils import AutoAcceleratorWorkers
from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger

from .constants import ROLLOUT_RAY_GENERATE_MAX_CONCURRENCY
from .health_manager import ROLLOUT_RAY_GET_TIMEOUT, RolloutHealthManager
from .parser.factory import build_reasoning_parser, build_tool_call_parser
from .parser.reasoning_parser import ReasoningParser
from .parser.tool_parser import ToolCallParser
from .proxy_manager import RolloutProxyManager
from .utils import SessionRouter
from .worker import (
Expand Down Expand Up @@ -66,7 +62,6 @@ def __init__(
worker_lifecycle_listeners=[self.proxy_manager] if self.proxy_manager is not None else None,
)
self.health_manager.start()
self._tool_call_parser, self._reasoning_parser = self._build_output_parsers()

def get_rollout_metadata(self) -> RolloutWorkerMetadata:
"""Get information about the current rollout setup.
Expand All @@ -93,20 +88,6 @@ def validate_registered_workers_to_proxy(self) -> None:
return
self.proxy_manager.validate_registered_session_urls()

def _build_output_parsers(self) -> tuple[ToolCallParser | None, ReasoningParser | None]:
    """Create the optional output parsers selected by ``self.config``.

    A parser is built only when its config entry differs from the string
    ``"none"``. The reasoning parser also needs a tokenizer, which is loaded
    from ``config.tokenizer_path`` or, when that value is falsy, from
    ``config.model_path``.

    Returns:
        A ``(tool_call_parser, reasoning_parser)`` pair; an item is ``None``
        when the corresponding parser is disabled.
    """
    cfg = self.config

    tool_parser = build_tool_call_parser(cfg.tool_call_parser) if cfg.tool_call_parser != "none" else None

    if cfg.reasoning_parser != "none":
        # The tokenizer is only loaded when a reasoning parser is requested.
        tokenizer = AutoTokenizer.from_pretrained(
            cfg.tokenizer_path or cfg.model_path,
            trust_remote_code=True,
        )
        return tool_parser, build_reasoning_parser(cfg.reasoning_parser, tokenizer)

    return tool_parser, None

@ray.method(concurrency_group=ROLLOUT_CONCURRENCY_GROUP_GENERATE)
async def generate(self, rollout_state: RolloutState) -> RolloutState:
if XTUNER_DETERMINISTIC:
Expand All @@ -129,7 +110,6 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState:
response_ref,
timeout=self.config.rollout_timeout * self.timeout_multiplier,
)
self._apply_output_parsers(response_rollout_state)
return response_rollout_state
except asyncio.TimeoutError:
self.logger.error(
Expand All @@ -142,21 +122,6 @@ async def generate(self, rollout_state: RolloutState) -> RolloutState:
)
return rollout_state

def _apply_output_parsers(self, rollout_state: RolloutState) -> None:
    """Post-process ``rollout_state`` in place with the configured parsers.

    The tool-call parser (if any) runs first and writes ``tool_calls`` and
    ``response``. The reasoning parser (if any) runs second, overwrites
    ``response`` with its own remaining text, and stores or removes the
    ``"reasoning_text"`` entry of ``extra_fields``.

    Args:
        rollout_state: The state object handed to each parser and mutated
            with the parse results.
    """
    tool_parser = self._tool_call_parser
    if tool_parser is not None:
        tool_result = tool_parser.parse(rollout_state)
        rollout_state.tool_calls = tool_result.tool_calls
        # Falsy leftover text (e.g. an empty string) is stored as None.
        rollout_state.response = tool_result.remaining_text or None

    reasoning_parser = self._reasoning_parser
    if reasoning_parser is None:
        return

    reasoning_result = reasoning_parser.parse(rollout_state)
    rollout_state.response = reasoning_result.remaining_text

    reasoning_text = reasoning_result.reasoning_text
    if reasoning_text:
        rollout_state.extra_fields["reasoning_text"] = reasoning_text
    else:
        # Drop any stale entry when this parse produced no reasoning text.
        rollout_state.extra_fields.pop("reasoning_text", None)

def set_enable_partial_rollout(self, enable: bool) -> None:
"""Propagate enable_partial_rollout flag to all active workers."""
active_workers = self.registry.active_workers()
Expand Down
19 changes: 0 additions & 19 deletions xtuner/v1/rl/rollout/parser/__init__.py

This file was deleted.

36 changes: 0 additions & 36 deletions xtuner/v1/rl/rollout/parser/factory.py

This file was deleted.

59 changes: 0 additions & 59 deletions xtuner/v1/rl/rollout/parser/qwen3_reasoning_parser.py

This file was deleted.

115 changes: 0 additions & 115 deletions xtuner/v1/rl/rollout/parser/qwen3_tool_parser.py

This file was deleted.

Loading
Loading