Skip to content

move dp loop to model agent #3598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions lmdeploy/pytorch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch

from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend


Expand Down Expand Up @@ -217,5 +218,14 @@ def from_hf_config(cls,

@dataclass
class MiscConfig:
    """Miscellaneous engine options.

    Plain dataclass of three fields. Use ``from_engine_config`` to build an
    instance from an engine config object.
    """
    prefill_interval: int = 16
    # NOTE(review): annotated as ``str`` but the default is ``None``.
    custom_module_map: str = None
    empty_init: bool = False

    @classmethod
    def from_engine_config(cls, engine_config: PytorchEngineConfig):
        """Build a ``MiscConfig`` from an engine config.

        Args:
            engine_config: object providing ``custom_module_map``,
                ``empty_init`` and ``prefill_interval`` attributes.

        Returns:
            A new instance of ``cls`` carrying those three values.
        """
        return cls(
            custom_module_map=engine_config.custom_module_map,
            empty_init=engine_config.empty_init,
            prefill_interval=engine_config.prefill_interval,
        )
118 changes: 29 additions & 89 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _build_dist_config(engine_config: PytorchEngineConfig):

def _build_misc_config(engine_config: PytorchEngineConfig):
    """Build misc config.

    Delegates to ``MiscConfig.from_engine_config``. The earlier direct
    ``MiscConfig(...)`` construction is dropped: its result was overwritten
    on the very next line and it did not pass ``prefill_interval``.

    Args:
        engine_config: engine config to read the misc options from.

    Returns:
        The ``MiscConfig`` produced by ``MiscConfig.from_engine_config``.
    """
    return MiscConfig.from_engine_config(engine_config)


Expand Down Expand Up @@ -160,19 +160,6 @@ def set(self, idx: int = None):
raise NotImplementedError('Not implemented.')


class RunableEventSync(RunableEventBase):
    """Runable event for the sync path: waiting and setting are no-ops."""

    def __init__(self, scheduler: Scheduler):
        # only the scheduler reference is kept; no event object is created
        self.scheduler = scheduler

    async def wait(self):
        """Return immediately without awaiting anything."""
        return None

    def set(self):
        """Do nothing; this class keeps no event state."""
        return None

class RunableEventAsnyc(RunableEventBase):
"""Awaitable async runable event."""

Expand All @@ -184,13 +171,6 @@ async def wait(self):
"""Wait event."""
await self.event.wait()

def set_single(self):
"""Set single."""
if self.scheduler.has_unfinished():
self.event.set()
else:
self.event.clear()

def set(self):
"""Set event."""
if self.scheduler.has_unfinished():
Expand All @@ -199,12 +179,9 @@ def set(self):
self.event.clear()


def build_runable_event(scheduler: Scheduler, sync: bool):
def build_runable_event(scheduler: Scheduler):
"""Build runable event."""
if sync:
return RunableEventSync(scheduler)
else:
return RunableEventAsnyc(scheduler)
return RunableEventAsnyc(scheduler)


class InputsMakerBase:
Expand Down Expand Up @@ -234,7 +211,28 @@ def __init__(self, engine: 'Engine'):
self.scheduler = self.engine.scheduler
self.forward_inputs = None

def do_prefill(self):
self.dp = self.engine.dist_config.dp
self.role = self.engine.cache_config.role

self.next_is_prefill = True
if self.dp == 1:
self.do_prefill = self.do_prefill_default
else:
self.do_prefill = self.do_prefill_dp

def do_prefill_dp(self):
    """Decide whether the next step should be a prefill.

    Returns:
        bool: ``True`` unconditionally when ``self.role`` is
        ``EngineRole.Prefill``. Otherwise, if ``self.next_is_prefill`` is
        set, whether the scheduler has waiting requests; if it is not set,
        whether the scheduler has nothing running.
    """
    # prefill-role engine: always prefill
    if self.role == EngineRole.Prefill:
        return True

    if not self.next_is_prefill:
        # prefill only when the scheduler reports nothing running
        return not self.scheduler.has_running()

    # prefill only when the scheduler reports waiting requests
    return self.scheduler.has_waiting()

def do_prefill_default(self):
# decoding if no waiting
scheduler = self.scheduler
if not scheduler.has_waiting():
Expand Down Expand Up @@ -262,6 +260,7 @@ async def _send_next_inputs_impl(self, prefill: bool = None, enable_empty: bool
if logger.level <= logging.DEBUG:
session_ids = [seq.session_id for seq in next_running]
logger.debug(f'Forward session_ids: {session_ids}')
self.next_is_prefill = inputs.is_decoding
await self.executor.forward_async(forward_inputs)
self.forward_inputs = forward_inputs
return forward_inputs, next_running
Expand Down Expand Up @@ -292,37 +291,9 @@ async def prefetch_next_inputs(self):
return None, None


class InputsMakerSync(InputsMakerAsync):
    """Inputs maker synchronize.

    Alternates between prefill and decoding on successive calls for
    Hybrid/Decode engines, and always prefills for Prefill engines.
    """

    def __init__(self, engine: 'Engine'):
        super().__init__(engine)
        # the first do_prefill() call on a Hybrid/Decode engine returns True
        self._is_prefill = True

    def do_prefill(self):
        """Return whether the next step is a prefill.

        Returns:
            bool: ``True`` for ``EngineRole.Prefill``. For
            ``EngineRole.Hybrid`` and ``EngineRole.Decode``, the current
            ``_is_prefill`` flag, which is flipped on every call.

        Raises:
            ValueError: if the engine role is none of the three above.
                Previously this path failed with an unbound local variable.
        """
        role = self.engine.engine_config.role
        if role == EngineRole.Prefill:
            return True
        if role in (EngineRole.Hybrid, EngineRole.Decode):
            ret = self._is_prefill
            self._is_prefill = not ret
            return ret
        raise ValueError(f'Unsupported engine role: {role}')

    async def send_next_inputs(self):
        """Pick prefill/decoding via ``do_prefill`` and send the inputs."""
        prefill = self.do_prefill()
        return await self._send_next_inputs_impl(prefill)

    async def prefetch_next_inputs(self):
        """prefetch."""
        logger.info('Prefetching next forward inputs.')
        return await self.send_next_inputs()


def build_inputs_maker(engine: 'Engine'):
"""Build inputs makers."""
if engine.should_execute_dummy_batch:
return InputsMakerSync(engine)
else:
return InputsMakerAsync(engine)
return InputsMakerAsync(engine)


class Engine:
Expand Down Expand Up @@ -370,7 +341,6 @@ def __init__(self,
backend_config = _build_backend_config(engine_config)
dist_config = _build_dist_config(engine_config)
misc_config = _build_misc_config(engine_config)
self.should_execute_dummy_batch = dist_config.need_dummy_batch()

# build model agent
raw_tokenizer = None
Expand Down Expand Up @@ -873,43 +843,16 @@ def __need_logits(seqs: SeqList):
"""Need logits."""
return any(seq.return_logits for seq in seqs)

def __make_dummy_inputs():
"""Make dummy inputs."""
logger.info(f'make dummy forward inputs: prefill={prefill}.')
num_loops = 1 if prefill else prefill_interval

batch_size = 2 if self.dist_config.enable_microbatch else 1
batch_size = min(self.cache_config.max_batches, batch_size)
return dict(
running=[],
inputs=ModelInputs.make_dummy(batch_size,
is_decoding=not prefill,
vocab_size=self.model_config.vocab_size),
swap_in_map=dict(),
swap_out_map=dict(),
loop_count=num_loops,
is_dummy=True,
sync_long_context=False,
)

scheduler = self.scheduler
logger.info(f'Make forward inputs with prefill={prefill}, enable_empty={enable_empty}')

if self.should_execute_dummy_batch:
if prefill and scheduler.num_waiting() == 0:
return __make_dummy_inputs()
if not prefill and scheduler.num_running() == 0:
return __make_dummy_inputs()

scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prefill_interval)

if enable_empty and len(scheduler_output.running) == 0:
return None

# schedule decoding if no valid prefill reqs.
if prefill and len(
scheduler_output.running
) == 0 and not self.should_execute_dummy_batch and self.engine_config.role != EngineRole.Prefill:
if prefill and len(scheduler_output.running) == 0 and self.engine_config.role != EngineRole.Prefill:
prefill = False
scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prefill_interval)

Expand All @@ -918,9 +861,6 @@ def __make_dummy_inputs():
swap_in_map = scheduler_output.swap_in_map
swap_out_map = scheduler_output.swap_out_map

if (self.should_execute_dummy_batch or self.engine_config.role == EngineRole.Prefill) and len(running) == 0:
return __make_dummy_inputs()

assert len(running) > 0

# create inputs
Expand Down Expand Up @@ -1144,7 +1084,7 @@ async def async_loop(self):

# preprocess task
logger.info('Starting async task MainLoopPreprocessMessage.')
has_runable_event = build_runable_event(self.scheduler, self.should_execute_dummy_batch)
has_runable_event = build_runable_event(self.scheduler)
loop_msg_proc = event_loop.create_task(self._async_loop_preprocess_message(
forward_event, has_runable_event),
name='MainLoopPreprocessMessage')
Expand Down
Loading