@@ -10,7 +10,16 @@ from jinja2 import BaseLoader, Environment
 from numpy.typing import NDArray

 from ..._parser import TokenParser
-from ..._schema import EngineResponse, GenToken, GenTokenExtra, SamplingParams, TokenUsage
+from ..._schema import (
+    EngineResponse,
+    GenToken,
+    GenTokenExtra,
+    SamplingParams,
+    StepConfig,
+    StepContext,
+    StepFeedback,
+    TokenUsage,
+)
 from ..._utils import apply_min_p_filter, apply_repetition_penalty, apply_top_k_and_top_p_filter, log_init, softmax
 from ._state import EngineState
 from ._tokenizer import Tokenizer
@@ -88,6 +97,7 @@ class Engine(ABC):
         grammar: str,
         ensure_bos_token: bool = True,
         sampling_params: SamplingParams | None = None,
+        step_config: StepConfig | None = None,
     ) -> Generator[EngineResponse, None, TokenUsage]:
         """Main entry point for the inference-parser loop. Yields EngineResponse objects as
         the parser advances through the grammar.
@@ -123,9 +133,24 @@ class Engine(ABC):
         issued_token: GenToken | None = None
         usage = TokenUsage(round_trips=1, ff_tokens=0)

+        step_every_k: int | None = None
+        step_stop_strings: set[str] = set()
+        step_callback = None
+        if step_config is not None:
+            step_every_k = step_config.get("step_every_k")  # type: ignore[assignment]
+            step_stop_strings = set(step_config.get("step_stop_tokens", set()))  # type: ignore[assignment]
+            step_callback = step_config.get("callback")  # type: ignore[assignment]
+        step_counter = 0
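+        # For illustration only (hypothetical values, not part of this change):
+        #     step_config = {"step_every_k": 16,
+        #                    "step_stop_tokens": {"\n\n"},
+        #                    "callback": on_step}
+        # where on_step is a user-supplied callable that receives a StepContext
+        # and may return a StepFeedback; see StepConfig in ..._schema.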
+
+        step_tokens_buffer: list[int] = []
+        all_generated_tokens: list[int] = []
+        all_text_bytes = bytearray()
+
         while not parser.done():
             t1 = time.monotonic()
             recode = False
+            has_injection_backtrack = False  # Track if this response has an injection backtrack
+
             if issued_token is None:
                 prefix_tokens, backtrack, ff_tokens, mask_fut = parser.process_prompt(
                     prompt_tokens=tokens,
@@ -269,18 +294,199 @@
                 )
             )

+            new_bytes_acc = bytearray(legacy_engine_response.new_bytes)
+            captures_acc = dict(legacy_engine_response.capture_groups)
+            cap_log_probs_acc = dict(legacy_engine_response.capture_group_log_probs)
+
+            new_token_ids_this_iter = [t.token_id for t in gen_tokens]
+            step_tokens_buffer.extend(new_token_ids_this_iter)
+            all_generated_tokens.extend(new_token_ids_this_iter)
+            all_text_bytes += legacy_engine_response.new_bytes
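+
+            # Step-boundary detection: either step_every_k tokens have accumulated
+            # since the last boundary, or the generated text now ends with one of
+            # the configured stop strings (only checked if every_k didn't fire).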
+            boundary_hit = False
+            boundary_type = None  # "every_k" or "stop_string"
+            matched_stop_string = None
+            if new_token_ids_this_iter:
+                if step_every_k is not None and step_every_k > 0:
+                    if len(step_tokens_buffer) >= step_every_k:
+                        boundary_hit = True
+                        boundary_type = "every_k"
+                if step_stop_strings and not boundary_hit:
+                    # Check if the accumulated text ends with any stop string
+                    accumulated_text = bytes(all_text_bytes).decode("utf-8", errors="ignore")
+                    for stop_string in step_stop_strings:
+                        if accumulated_text.endswith(stop_string):
+                            boundary_hit = True
+                            boundary_type = "stop_string"
+                            matched_stop_string = stop_string
+                            break
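+
+            # On a boundary, hand the callback a snapshot of progress; the returned
+            # StepFeedback may ask for text (or raw bytes) to be injected into the
+            # generation.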
+            if boundary_hit and step_callback is not None:
+                ctx: StepContext = {
+                    "last_step_text": self.tokenizer.decode(step_tokens_buffer).decode("utf-8", errors="ignore"),
+                    "last_step_tokens": list(step_tokens_buffer),
+                    "all_text": bytes(all_text_bytes).decode("utf-8", errors="ignore"),
+                    "all_tokens": list(all_generated_tokens),
+                    "captures": dict(captures_acc),
+                    "step_counter": step_counter,
+                }
+                feedback: StepFeedback | None = step_callback(ctx)  # type: ignore[misc]
+                step_counter = ctx["step_counter"]
+                if feedback:
+                    inj_bytes: bytes | None = feedback.get("injected_bytes")
+                    if not inj_bytes and feedback.get("injected_text"):
+                        inj_bytes = feedback["injected_text"].encode("utf-8")
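+
+                    # For a stop_string boundary the matched stop string is first
+                    # rolled back out of the context, so the injection replaces it
+                    # rather than following it.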
+                    if inj_bytes:
+                        # Only roll back generated tokens for the stop_string case
+                        backtrack_token_ids = []
+                        if boundary_type == "stop_string" and matched_stop_string:
+                            # Walk backwards through the recent tokens to find the ones
+                            # whose bytes make up the matched stop string
+                            accumulated_bytes = b""
+                            for i in range(len(all_generated_tokens) - 1, -1, -1):
+                                token_id = all_generated_tokens[i]
+                                token_bytes = self.tokenizer.decode([token_id])
+                                accumulated_bytes = token_bytes + accumulated_bytes
+                                backtrack_token_ids.insert(0, token_id)
+
+                                # Stop once the accumulated bytes cover the stop string
+                                if matched_stop_string in accumulated_bytes.decode("utf-8", errors="ignore"):
+                                    break
+
+                                # Safety: don't go back more than 20 tokens
+                                if len(backtrack_token_ids) >= 20:
+                                    break
+
+                            backtrack_bytes_to_remove = (
+                                self.tokenizer.decode(backtrack_token_ids) if backtrack_token_ids else b""
+                            )
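+
+                            # The rolled-back tokens may span both the response being
+                            # built now and ones already yielded; they must be handled
+                            # differently, so split them below.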
+
+                            # Remove the rolled-back tokens from the model context
+                            if backtrack_token_ids and len(tokens) >= len(backtrack_token_ids):
+                                tokens = tokens[: -len(backtrack_token_ids)]
+
+                            # Count how many of the rolled-back tokens belong to the
+                            # current response (they sit at the end of new_bytes_acc)
+                            current_response_token_count = 0
+                            temp_bytes = bytes(new_bytes_acc)
+                            for i in range(len(backtrack_token_ids) - 1, -1, -1):
+                                token_bytes = self.tokenizer.decode([backtrack_token_ids[i]])
+                                if temp_bytes.endswith(token_bytes):
+                                    current_response_token_count += 1
+                                    temp_bytes = temp_bytes[: -len(token_bytes)]
+                                else:
+                                    break
+
+                            previous_response_token_count = len(backtrack_token_ids) - current_response_token_count
+
+                            # Remove current-response tokens from the response being built
+                            if current_response_token_count > 0:
+                                # Remove from new_bytes_acc
+                                for i in range(
+                                    len(backtrack_token_ids) - 1,
+                                    len(backtrack_token_ids) - 1 - current_response_token_count,
+                                    -1,
+                                ):
+                                    token_bytes = self.tokenizer.decode([backtrack_token_ids[i]])
+                                    if new_bytes_acc.endswith(token_bytes):
+                                        new_bytes_acc = new_bytes_acc[: -len(token_bytes)]
+                                # Remove from gen_tokens
+                                if len(gen_tokens) >= current_response_token_count:
+                                    gen_tokens = gen_tokens[:-current_response_token_count]
+
+                            # Backtrack bytes reported to the consumer come only from
+                            # previously yielded responses
+                            backtrack_bytes_from_previous = (
+                                self.tokenizer.decode(backtrack_token_ids[:previous_response_token_count])
+                                if previous_response_token_count > 0
+                                else b""
+                            )
+
+                            # Remove from the tracking buffers (for future context)
+                            if len(step_tokens_buffer) >= len(backtrack_token_ids):
+                                step_tokens_buffer = step_tokens_buffer[: -len(backtrack_token_ids)]
+                            if len(all_generated_tokens) >= len(backtrack_token_ids):
+                                all_generated_tokens = all_generated_tokens[: -len(backtrack_token_ids)]
+                            if backtrack_bytes_to_remove and all_text_bytes.endswith(backtrack_bytes_to_remove):
+                                all_text_bytes = all_text_bytes[: -len(backtrack_bytes_to_remove)]
+
+                            # Fold the injection backtrack into any existing parser backtrack
+                            backtracked_bytes = backtrack_bytes_from_previous + backtracked_bytes
+                            backtrack = previous_response_token_count + backtrack
+
+                            # Flag this response as carrying an injection backtrack
+                            has_injection_backtrack = True
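+
+                        # Re-tokenize the injected bytes through the parser one token at
+                        # a time so grammar state, fast-forward tokens, and captures stay
+                        # consistent with the model context.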
+                        # Inject tokens (applies to both every_k and stop_string cases)
+                        inj_token_ids = self.tokenizer.encode(inj_bytes)
+                        for inj_token_id in inj_token_ids:
+                            backtrack2, ff_tokens2, mask_fut2 = parser.advance(token_id=inj_token_id)
+                            if backtrack2:
+                                tokens[:] = tokens[:-backtrack2]
+                            # Add the injected token to the model's context
+                            tokens.append(inj_token_id)
+                            tokens += ff_tokens2
+                            mask2, ll_response2, _ = mask_fut2.result()
+                            legacy2 = ll_response2.progress.to_engine_call_response()
+                            # The injected tokens are appended to the current response
+                            # below, after this loop; here we only merge the captures
+                            captures_acc.update(legacy2.capture_groups)
+                            cap_log_probs_acc.update(legacy2.capture_group_log_probs)
+
+                            usage.ff_tokens += len(ff_tokens2)
+                            step_tokens_buffer.append(inj_token_id)
+                            step_tokens_buffer.extend(ff_tokens2)
+                            all_generated_tokens.append(inj_token_id)
+                            all_generated_tokens.extend(ff_tokens2)
+                            all_text_bytes += legacy2.new_bytes
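+
+                        # Injected tokens are recorded as force-forwarded rather than
+                        # sampled, with prob NaN so downstream stats can tell them apart.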
+                        # Append the injected tokens to the CURRENT response
+                        inj_bytes_acc = bytearray()
+                        for inj_token_id in inj_token_ids:
+                            gen_tokens.append(
+                                GenTokenExtra(
+                                    token_id=inj_token_id,
+                                    bytes=self.tokenizer.decode([inj_token_id]),
+                                    prob=float("nan"),
+                                    latency_ms=0.0,
+                                    is_generated=False,
+                                    is_force_forwarded=True,
+                                    is_input=False,
+                                    is_backtracked=False,
+                                    is_masked=False,
+                                    top_k=[],
+                                )
+                            )
+                            inj_bytes_acc += self.tokenizer.decode([inj_token_id])
+
+                        new_bytes_acc += inj_bytes_acc
+
+                # Reset the window buffer for the next step
+                step_tokens_buffer = []
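+
+            # injection_backtrack (set on stop-string rollbacks above) lets consumers
+            # distinguish a callback-driven rollback from an ordinary parser backtrack.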
             engine_response = EngineResponse(
-                new_bytes=legacy_engine_response.new_bytes,
+                new_bytes=bytes(new_bytes_acc),
                 backtrack_bytes=backtracked_bytes,
-                capture_groups=legacy_engine_response.capture_groups,
-                capture_group_log_probs=legacy_engine_response.capture_group_log_probs,
+                capture_groups=captures_acc,
+                capture_group_log_probs=cap_log_probs_acc,
                 backtrack=backtrack,
                 tokens=gen_tokens,
+                injection_backtrack=has_injection_backtrack,
             )

             # process engine_response
             # NOTE (loc): We should not yield the engine_response if new_bytes are invalid utf-8 bytes
             # delayed bytes should be handled here in the engine
             yield engine_response

             if ll_response.stop: