2 Commits

Author               SHA1        Message                                        Date
Richard Edgar        6c73b371e7  Merge branch 'main' into riedgar-ms/build-314  1 week ago
Nagarajan Natarajan  ddf9eed4f9  [Feature] Monitor-guided Inference (#1391)     1 week ago
5 changed files with 319 additions and 12 deletions
1. guidance/_schema.py (+27 −1)
2. guidance/models/_base/_model.py (+11 −1)
3. guidance/models/_engine/_engine.py (+213 −7)
4. guidance/models/_engine/_interpreter.py (+20 −3)
5. tests/unit/test_model.py (+48 −0)

guidance/_schema.py (+27 −1)

@@ -1,5 +1,5 @@
from functools import cached_property
-from typing import Any, Literal, TypedDict
+from typing import Any, Callable, Literal, Set, TypedDict

from annotated_types import Ge, Le
from pydantic import BaseModel, Field, NonNegativeInt, RootModel, computed_field, model_validator
@@ -97,6 +97,7 @@ class EngineResponse(BaseModel):
    capture_group_log_probs: dict
    backtrack: NonNegativeInt = 0  # number of tokens backtracked by the parser
    tokens: list["GenToken"] = []  # tokens associated with the generated bytes
+    injection_backtrack: bool = False  # if True, backtrack should be processed before adding new_bytes


class LegacyEngineCallResponse(BaseModel):
@@ -234,3 +235,28 @@ class SamplingParams(TypedDict):
    top_k: int | None
    min_p: float | None
    repetition_penalty: float | None
+
+
+class StepContext(TypedDict):
+    last_step_text: str
+    last_step_tokens: list[int]
+    all_text: str
+    all_tokens: list[int]
+    captures: dict
+    step_counter: int
+
+
+class StepFeedback(TypedDict, total=False):
+    # Either injected_text (utf-8) or injected_bytes can be provided.
+    injected_text: str
+    injected_bytes: bytes
+
+
+class StepConfig(TypedDict, total=False):
+    # Trigger every k generated tokens (including fast-forwarded and injected)
+    step_every_k: int
+    # Trigger when the accumulated generated text ends with any of these strings
+    step_stop_tokens: Set[str]
+    # Callback invoked at each step boundary.
+    # Returns optional feedback to inject next.
+    callback: Callable[[StepContext], StepFeedback]
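
Taken together, these three TypedDicts define the monitor contract: the engine fills a StepContext at each step boundary, the user callback returns a StepFeedback, and StepConfig wires both into the model. A minimal usage sketch (not part of this diff; assumes the models.Mock backend used in the tests below):

from guidance import gen, models

def monitor(ctx):  # receives a StepContext
    # Returning a StepFeedback injects text before generation resumes;
    # an empty dict (falsy) leaves the stream untouched.
    if "ERROR" in ctx["last_step_text"]:
        return {"injected_text": "[RETRY]"}
    return {}

cfg = {"step_every_k": 8, "callback": monitor}  # a StepConfig
lm = models.Mock(echo=False).with_step_config(cfg)
lm += gen(max_tokens=32, stop="\n")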

guidance/models/_base/_model.py (+11 −1)

@@ -19,7 +19,7 @@ from ..._ast import (
    RoleStart,
    _parse_tags,
)
-from ..._schema import SamplingParams, TokenUsage
+from ..._schema import SamplingParams, StepConfig, TokenUsage
from ...trace import (
    ImageInput,
    LiteralInput,
@@ -302,6 +302,16 @@ class Model:
        self.sampling_params = sampling_params
        return self

+    def with_step_config(self, step_config: StepConfig) -> Self:
+        """Return a new model with step interjection configured (engine-backed models only)."""
+        self = self.copy()
+        # Only EngineInterpreter has step_config; guard for other interpreter types
+        if hasattr(self._interpreter, "step_config"):
+            setattr(self._interpreter, "step_config", step_config)
+        else:
+            raise NotImplementedError("Step interjection is only supported for engine-backed models.")
+        return self
+
    def __getattribute__(self, name):
        if name == "engine":
            # For legacy model.engine access (mostly for tests...)
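
Because with_step_config calls self.copy() before touching the interpreter, configuration is copy-on-write: the original model keeps its behavior. A small sketch of that contract (not part of this diff; assumes models.Mock as in the tests):

from guidance import models

base = models.Mock(echo=False)
stepped = base.with_step_config({"step_every_k": 4, "callback": lambda ctx: {}})
# `base` is unaffected; only `stepped`'s interpreter carries the StepConfig.
assert base is not stepped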


guidance/models/_engine/_engine.py (+213 −7)

@@ -10,7 +10,16 @@ from jinja2 import BaseLoader, Environment
from numpy.typing import NDArray

from ..._parser import TokenParser
-from ..._schema import EngineResponse, GenToken, GenTokenExtra, SamplingParams, TokenUsage
+from ..._schema import (
+    EngineResponse,
+    GenToken,
+    GenTokenExtra,
+    SamplingParams,
+    StepConfig,
+    StepContext,
+    StepFeedback,
+    TokenUsage,
+)
from ..._utils import apply_min_p_filter, apply_repetition_penalty, apply_top_k_and_top_p_filter, log_init, softmax
from ._state import EngineState
from ._tokenizer import Tokenizer
@@ -88,6 +97,7 @@ class Engine(ABC):
        grammar: str,
        ensure_bos_token: bool = True,
        sampling_params: SamplingParams | None = None,
+        step_config: StepConfig | None = None,
    ) -> Generator[EngineResponse, None, TokenUsage]:
        """Main entry point for the inference-parser loop. Yields EngineCallResponse objects as
        the parser advances through the grammar.
@@ -123,9 +133,24 @@ class Engine(ABC):
        issued_token: GenToken | None = None
        usage = TokenUsage(round_trips=1, ff_tokens=0)

+        step_every_k: int | None = None
+        step_stop_strings: set[str] = set()
+        step_callback = None
+        if step_config is not None:
+            step_every_k = step_config.get("step_every_k")  # type: ignore[assignment]
+            step_stop_strings = set(step_config.get("step_stop_tokens", set()))  # type: ignore[assignment]
+            step_callback = step_config.get("callback")  # type: ignore[assignment]
+        step_counter = 0
+
+        step_tokens_buffer: list[int] = []
+        all_generated_tokens: list[int] = []
+        all_text_bytes = bytearray()
+
        while not parser.done():
            t1 = time.monotonic()
+            recode = False
+            has_injection_backtrack = False  # Track if this response has injection backtrack

            if issued_token is None:
                prefix_tokens, backtrack, ff_tokens, mask_fut = parser.process_prompt(
                    prompt_tokens=tokens,
@@ -269,18 +294,199 @@ class Engine(ABC):
                    )
                )

+            new_bytes_acc = bytearray(legacy_engine_response.new_bytes)
+            captures_acc = dict(legacy_engine_response.capture_groups)
+            cap_log_probs_acc = dict(legacy_engine_response.capture_group_log_probs)
+
+            new_token_ids_this_iter = [t.token_id for t in gen_tokens]
+            step_tokens_buffer.extend(new_token_ids_this_iter)
+            all_generated_tokens.extend(new_token_ids_this_iter)
+            all_text_bytes += legacy_engine_response.new_bytes
+
+            boundary_hit = False
+            boundary_type = None  # Track whether it's "every_k" or "stop_string"
+            matched_stop_string = None
+            if new_token_ids_this_iter:
+                if step_every_k is not None and step_every_k > 0:
+                    if len(step_tokens_buffer) >= step_every_k:
+                        boundary_hit = True
+                        boundary_type = "every_k"
+                if step_stop_strings and not boundary_hit:
+                    # Check if the accumulated text ends with any stop string
+                    accumulated_text = bytes(all_text_bytes).decode("utf-8", errors="ignore")
+                    for stop_string in step_stop_strings:
+                        if accumulated_text.endswith(stop_string):
+                            boundary_hit = True
+                            boundary_type = "stop_string"
+                            matched_stop_string = stop_string
+                            break
+
+            if boundary_hit and step_callback is not None:
+                ctx: StepContext = {
+                    "last_step_text": self.tokenizer.decode(step_tokens_buffer).decode("utf-8", errors="ignore"),
+                    "last_step_tokens": list(step_tokens_buffer),
+                    "all_text": bytes(all_text_bytes).decode("utf-8", errors="ignore"),
+                    "all_tokens": list(all_generated_tokens),
+                    "captures": dict(captures_acc),
+                    "step_counter": step_counter,
+                }
+                feedback: StepFeedback | None = step_callback(ctx)  # type: ignore[misc]
+                step_counter = ctx["step_counter"]
+                if feedback:
+                    inj_bytes: bytes | None = None
+                    if "injected_bytes" in feedback and feedback["injected_bytes"]:
+                        inj_bytes = feedback["injected_bytes"]
+                    elif "injected_text" in feedback and feedback["injected_text"]:
+                        inj_bytes = feedback["injected_text"].encode("utf-8")
+                    if inj_bytes:
+                        # Only rollback for stop_string case
+                        backtrack_token_ids = []
+                        if boundary_type == "stop_string" and matched_stop_string:
+                            # Calculate how many tokens to backtrack based on the matched stop string
+                            # We need to find which recent tokens form the stop string
+                            stop_string_bytes = matched_stop_string.encode("utf-8")
+
+                            # Search backwards through recent tokens to find which ones form the stop string
+                            accumulated_bytes = b""
+                            for i in range(len(all_generated_tokens) - 1, -1, -1):
+                                token_id = all_generated_tokens[i]
+                                token_bytes = self.tokenizer.decode([token_id])
+                                accumulated_bytes = token_bytes + accumulated_bytes
+                                backtrack_token_ids.insert(0, token_id)
+
+                                # Check if we've accumulated enough to match the stop string
+                                accumulated_text = accumulated_bytes.decode("utf-8", errors="ignore")
+                                if stop_string_bytes.decode("utf-8", errors="ignore") in accumulated_text:
+                                    # We've found all tokens that contribute to the stop string
+                                    break
+
+                                # Safety: don't go back more than 20 tokens
+                                if len(backtrack_token_ids) >= 20:
+                                    break
+
+                            backtrack_bytes_to_remove = (
+                                self.tokenizer.decode(backtrack_token_ids) if backtrack_token_ids else b""
+                            )
+
+                            # Remove the tokens from model context
+                            if len(tokens) >= len(backtrack_token_ids):
+                                tokens = tokens[: -len(backtrack_token_ids)]
+
+                            # Determine which backtrack tokens are in the current response vs previous
+                            # Tokens in current response are in new_bytes_acc and gen_tokens
+                            current_response_token_count = 0
+                            temp_bytes = bytes(new_bytes_acc)
+                            for i in range(len(backtrack_token_ids) - 1, -1, -1):
+                                token_bytes = self.tokenizer.decode([backtrack_token_ids[i]])
+                                if temp_bytes.endswith(token_bytes):
+                                    current_response_token_count += 1
+                                    temp_bytes = temp_bytes[: -len(token_bytes)]
+                                else:
+                                    break
+
+                            previous_response_token_count = len(backtrack_token_ids) - current_response_token_count
+
+                            # Remove tokens from current response
+                            if current_response_token_count > 0:
+                                # Remove from new_bytes_acc
+                                for i in range(
+                                    len(backtrack_token_ids) - 1,
+                                    len(backtrack_token_ids) - 1 - current_response_token_count,
+                                    -1,
+                                ):
+                                    token_bytes = self.tokenizer.decode([backtrack_token_ids[i]])
+                                    if new_bytes_acc.endswith(token_bytes):
+                                        new_bytes_acc = new_bytes_acc[: -len(token_bytes)]
+                                # Remove from gen_tokens
+                                if len(gen_tokens) >= current_response_token_count:
+                                    gen_tokens = gen_tokens[:-current_response_token_count]
+
+                            # Backtrack bytes are only from previous responses
+                            backtrack_bytes_from_previous = (
+                                self.tokenizer.decode(backtrack_token_ids[:previous_response_token_count])
+                                if previous_response_token_count > 0
+                                else b""
+                            )
+
+                            # Remove from tracking buffers (for future context)
+                            if len(step_tokens_buffer) >= len(backtrack_token_ids):
+                                step_tokens_buffer = step_tokens_buffer[: -len(backtrack_token_ids)]
+                            if len(all_generated_tokens) >= len(backtrack_token_ids):
+                                all_generated_tokens = all_generated_tokens[: -len(backtrack_token_ids)]
+                            if backtrack_bytes_to_remove and all_text_bytes.endswith(backtrack_bytes_to_remove):
+                                all_text_bytes = all_text_bytes[: -len(backtrack_bytes_to_remove)]
+
+                            # Add injection backtrack to any existing parser backtrack
+                            # For injection: only backtrack what's in previous responses
+                            backtracked_bytes = backtrack_bytes_from_previous + backtracked_bytes
+                            backtrack = previous_response_token_count + backtrack

+                            # Set flag to indicate this is an injection backtrack
+                            has_injection_backtrack = True
+
+                        # Inject tokens (applies to both every_k and stop_string cases)
+                        inj_token_ids = self.tokenizer.encode(inj_bytes)
+                        for inj_token_id in inj_token_ids:
+                            backtrack2, ff_tokens2, mask_fut2 = parser.advance(token_id=inj_token_id)
+                            if backtrack2:
+                                tokens[:] = tokens[:-backtrack2]
+                            # Add the injected token to the model's context
+                            tokens.append(inj_token_id)
+                            tokens += ff_tokens2
+                            mask2, ll_response2, _ = mask_fut2.result()
+                            legacy2 = ll_response2.progress.to_engine_call_response()
+                            # DON'T add injected tokens to current response - they'll appear in next iteration
+                            for k, v in legacy2.capture_groups.items():
+                                captures_acc[k] = v
+                            for k, v in legacy2.capture_group_log_probs.items():
+                                cap_log_probs_acc[k] = v
+
+                            usage.ff_tokens += len(ff_tokens2)
+                            step_tokens_buffer.append(inj_token_id)
+                            step_tokens_buffer.extend(ff_tokens2)
+                            all_generated_tokens.append(inj_token_id)
+                            all_generated_tokens.extend(ff_tokens2)
+                            all_text_bytes += legacy2.new_bytes
+
+                        # Add injected tokens to the CURRENT response
+                        inj_bytes_acc = bytearray()
+                        for inj_token_id in inj_token_ids:
+                            gen_tokens.append(
+                                GenTokenExtra(
+                                    token_id=inj_token_id,
+                                    bytes=self.tokenizer.decode([inj_token_id]),
+                                    prob=float("nan"),
+                                    latency_ms=0.0,
+                                    is_generated=False,
+                                    is_force_forwarded=True,
+                                    is_input=False,
+                                    is_backtracked=False,
+                                    is_masked=False,
+                                    top_k=[],
+                                )
+                            )
+                            inj_bytes_acc += self.tokenizer.decode([inj_token_id])
+
+                        new_bytes_acc += inj_bytes_acc
+
+                        # Set flag to indicate this is an injection backtrack
+                        if boundary_type == "stop_string" and matched_stop_string:
+                            has_injection_backtrack = True
+
+                step_tokens_buffer = []

            engine_response = EngineResponse(
-                new_bytes=legacy_engine_response.new_bytes,
+                new_bytes=bytes(new_bytes_acc),
                backtrack_bytes=backtracked_bytes,
-                capture_groups=legacy_engine_response.capture_groups,
-                capture_group_log_probs=legacy_engine_response.capture_group_log_probs,
+                capture_groups=captures_acc,
+                capture_group_log_probs=cap_log_probs_acc,
                backtrack=backtrack,
                tokens=gen_tokens,
+                injection_backtrack=has_injection_backtrack,
            )

            # process engine_response
            # NOTE (loc): We should not yield the engine_response if new_bytes are invalid utf-8 bytes
            # delayed bytes should be handled here in the engine
            yield engine_response

            if ll_response.stop:
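
The boundary rule buried in the loop above is small: an every_k boundary fires once the per-step token buffer reaches k, and a stop-string boundary fires when the accumulated text ends with one of the configured strings, with every_k taking precedence. A sketch of that rule in isolation (not the shipped code; the engine additionally skips the check when no new tokens arrived in the iteration):

def find_boundary(step_tokens_buffer, all_text, step_every_k, step_stop_strings):
    """Return (hit, boundary_type, matched_stop_string), mirroring the loop above."""
    if step_every_k is not None and step_every_k > 0 and len(step_tokens_buffer) >= step_every_k:
        return True, "every_k", None
    for s in step_stop_strings:
        if all_text.endswith(s):
            return True, "stop_string", s
    return False, None, None

# every_k wins when both would fire, matching the `not boundary_hit` guard:
assert find_boundary([1, 2, 3, 4], "step.", 4, {"."}) == (True, "every_k", None)
assert find_boundary([1, 2], "step.", 4, {"."}) == (True, "stop_string", ".")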


guidance/models/_engine/_interpreter.py (+20 −3)

@@ -5,7 +5,7 @@ from io import BytesIO
from typing import Iterator

from ..._ast import GrammarNode, ImageBlob, JoinNode, LiteralNode, RoleEnd, RoleStart, SpecialToken, ToolCallNode
-from ..._schema import GenTokenExtra, TokenUsage
+from ..._schema import GenTokenExtra, StepConfig, TokenUsage
from ..._utils import to_utf8_or_bytes_string
from ...trace import Backtrack, ImageOutput, OutputAttr, Token, TokenOutput
from .._base import Interpreter
@@ -18,6 +18,7 @@ class EngineInterpreter(Interpreter[EngineState]):
        super().__init__(state=EngineState())
        self.engine = engine
        self.chat_template = self.engine.get_chat_template()
+        self.step_config: StepConfig | None = None

    def __deepcopy__(self, memo):
        """Custom deepcopy to ensure engine is not copied."""
@@ -67,6 +68,7 @@ class EngineInterpreter(Interpreter[EngineState]):
            grammar=node.ll_grammar(),
            ensure_bos_token=True,
            sampling_params=kwargs.pop("sampling_params", None),
+            step_config=self.step_config,
        )

        delayed_bytes = b""
@@ -81,13 +83,28 @@ class EngineInterpreter(Interpreter[EngineState]):

            new_bytes = recode_special_tokens(self.engine.tokenizer, chunk.new_bytes)
            new_text, delayed_bytes = partial_decode(delayed_bytes + new_bytes)
-            self.state.prompt += new_text

-            if chunk.backtrack:
+            # Check if this is an injection backtrack (should happen before adding text)
+            if chunk.injection_backtrack and chunk.backtrack:
+                # Remove backtracked text from the prompt BEFORE adding new text
+                backtrack_text = chunk.backtrack_bytes.decode("utf-8", errors="ignore")
+                if self.state.prompt.endswith(backtrack_text):
+                    self.state.prompt = self.state.prompt[: -len(backtrack_text)]
                yield Backtrack(
                    n_tokens=chunk.backtrack,
                    bytes=b64encode(chunk.backtrack_bytes),
                )
+                # Now add new text after backtrack
+                self.state.prompt += new_text
+            else:
+                # Normal flow: add text first, then backtrack
+                self.state.prompt += new_text
+
+                if chunk.backtrack:
+                    yield Backtrack(
+                        n_tokens=chunk.backtrack,
+                        bytes=b64encode(chunk.backtrack_bytes),
+                    )

            for token in chunk.tokens:
                if isinstance(token, GenTokenExtra):

tests/unit/test_model.py (+48 −0)

@@ -53,3 +53,51 @@ def test_trace():
    m2 = m1 + "Roses are red and " + gen(name="suffix", regex="[A-Za-z]{2,5}", max_tokens=5)

    assert m2["suffix"] is not None


+def test_step_every_k_injection():
+    import re
+
+    lm = models.Mock(echo=False)
+
+    calls = {"count": 0}
+
+    def cb(ctx):
+        calls["count"] += 1
+        return {"injected_text": "[FIX]"}
+
+    cfg = {
+        "step_every_k": 4,
+        "callback": cb,
+    }
+    lm = lm.with_step_config(cfg)
+
+    lm = lm + gen(max_tokens=20, stop="\n", temperature=0.0)
+
+    s = str(lm)
+    # find all occurrences of [FIX] in s and their positions
+    occurrences = [m.start() for m in re.finditer(r"\[FIX\]", s)]
+    assert occurrences == [6, 18]
+    assert calls["count"] == len(occurrences)
+
+
+def test_step_stop_token_trigger_injection():
+    lm = models.Mock(byte_patterns=[b"abc!\n"], echo=False)
+
+    calls = {"count": 0}
+
+    def cb(ctx):
+        calls["count"] += 1
+        return {"injected_text": "[FIX2]"}
+
+    cfg = {
+        "step_stop_tokens": {"ym"},
+        "callback": cb,
+    }
+    lm = lm.with_step_config(cfg)
+
+    lm = lm + gen(max_tokens=20, stop="\n", temperature=0.0)
+
+    s = str(lm)
+    assert "[FIX2]" in s and "ym" not in s
+    assert calls["count"] == 1
