decoding.generators

Generators are the user's highest-level interface to the decoding library. By composing instances of decoding.models.LanguageModel, decoding.scorers.Scorer, and control flow parameters that specify sync, stop, and search conditions, users can implement a wide variety of decoding algorithms with very little code.

The BestOfN and TreeSearch generators are currently fully supported. There is also experimental support for RolloutTreeSearch in the decoding.experimental module, which provides a simple wrapper interface for a more standard Monte Carlo Tree Search (MCTS) algorithm. It is also on the roadmap to bring twisted SMC to the decoding library.

NB: The examples below are illustrative of the API, but not particularly useful. See the examples directory for more interesting examples.
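
For orientation, here is a minimal sketch of that composition. It mirrors the docstring examples below and assumes the same small `gpt2` model used there; the specific parameter values are illustrative only.

```python
from decoding.generators import BestOfN
from decoding.models import LanguageModel
from decoding.scorers import Scorer

# Compose the three ingredients: a model, a scorer, and a generator call.
llm = LanguageModel.from_id("gpt2")
scorer = Scorer.from_f_str_to_num(lambda x: -len(x))  # prefer shorter completions

# Draw 10 samples and rank them; the result is a best-first list of ScoredItem objects.
samples = BestOfN(prompt="The", llm=llm, scorer=scorer, n=10, max_tokens=20, seed=0)
print(samples[0].item, samples[0].score)
```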

"""
Generators are the user's highest-level interface to the decoding library.
By composing instances of `decoding.models.LanguageModel`, `decoding.scorers.Scorer`,
and control flow parameters that specify sync, stop, and search conditions, users can
implement a wide variety of decoding algorithms with very little code.

The `BestOfN` and `TreeSearch` generators are currently fully supported. There is also
experimental support for `RolloutTreeSearch` in the `decoding.experimental` module,
which provides a simple wrapper interface for a more standard Monte Carlo Tree Search
(MCTS) algorithm. It is also on the roadmap to bring twisted `SMC` to the
`decoding` library.

**NB**: The examples below are illustrative of the API, but not particularly useful.
See the [`examples`](https://github.com/benlipkin/decoding/tree/main/examples)
directory for more interesting examples.
"""

from collections.abc import Callable
from dataclasses import dataclass

from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.transformers_utils.tokenizers import MistralTokenizer

from decoding.models import LanguageModel
from decoding.pmf import LogPMF, ScoredItem, sort_scored_items, topk_scored_items
from decoding.scorers import Scorer


@dataclass(frozen=True, kw_only=True)
class _SearchParams:
    n: int
    width: int
    max_steps: int
    stop_pass: Callable[[str], bool]
    stop_fail: Callable[[str], bool]


def BestOfN(  # noqa: PLR0913
    *,
    prompt: str,
    llm: LanguageModel,
    scorer: Scorer,
    n: int = 1,
    min_tokens: int = 0,
    max_tokens: int | None = None,
    stop_str: list[str] | str | None = None,
    stop_token_ids: list[int] | str | None = None,
    include_stop_str_in_output: bool = True,
    track_logprobs: bool = False,
    temperature: float = 1.0,
    logits_processors: list[LogitsProcessor] | None = None,
    seed: int | None = None,
) -> list[ScoredItem[str]]:
    """
    Generate `n` samples from the language model `llm` using the `scorer` to rank them.
    See the [`vLLM.SamplingParams`](https://docs.vllm.ai/en/latest/dev/sampling_params.html)
    docs to learn more about some of these parameters such as `logits_processors`.

    Args:
        prompt: The input prompt string.
        llm: The language model to generate samples from.
        scorer: The scorer to rank the samples.
        n: The number of samples to generate.
        min_tokens: The minimum number of tokens in each sample.
        max_tokens: The maximum number of tokens in each sample.
        stop_str: A string or list of strings that, if generated, will stop decoding.
        stop_token_ids: A list of token IDs that, if generated, will stop decoding.
            A string can also be passed, which will specify all token IDs that contain
            that substring.
        include_stop_str_in_output: Whether to include the stop string in the output.
        track_logprobs: Whether to track log probabilities. This comes at a performance
            cost, so it is off by default. In most cases, as you are already sampling
            from the model, you do not want to double count the probabilities in the
            scorer anyway.
        temperature: The temperature for sampling.
        logits_processors: A list of logits processors.
        seed: The random seed.

    Returns:
        A list of `decoding.pmf.ScoredItem` objects sorted by the `scorer`.

    Raises:
        ValueError: If any of the argument configurations are invalid.

    Examples:
        ```python
        from decoding.generators import BestOfN
        from decoding.models import LanguageModel
        from decoding.scorers import Scorer

        llm = LanguageModel.from_id("gpt2")
        scorer = Scorer.from_f_str_to_num(lambda x: -len(x))
        samples = BestOfN(
            prompt="The",
            llm=llm,
            scorer=scorer,
            n=20,
            stop_str=".",
            seed=42,
        )
        assert len(samples) == 20
        assert all(s.item.endswith(".") for s in samples)
        assert all(s.score == -len(s.item) for s in samples)
        assert samples[0].score >= samples[-1].score
        ```

    """
    sampling_params = SamplingParams(
        n=_guard_positive_int(n),
        min_tokens=min_tokens,
        max_tokens=max_tokens,
        stop=stop_str,
        stop_token_ids=_prepare_token_ids(stop_token_ids, llm=llm),
        include_stop_str_in_output=include_stop_str_in_output,
        logprobs=_prepare_track_logprobs(track_logprobs),
        prompt_logprobs=_prepare_track_logprobs(track_logprobs),
        temperature=temperature,
        logits_processors=logits_processors,
        seed=seed,
        **_default_sampling_kwargs,  # type: ignore[reportArgumentType]
    )
    samples = _BestOfN([prompt], llm, scorer, sampling_params)
    return sort_scored_items(samples)


def TreeSearch(  # noqa: PLR0913
    *,
    prompt: str,
    llm: LanguageModel,
    step_scorer: Scorer,
    final_scorer: Scorer | None = None,
    stop_cond_pass: Callable[[str], bool],
    stop_cond_fail: Callable[[str], bool] | None = None,
    n: int = 1,
    beam_width: int = 1,
    beam_factor: int = 1,
    max_steps: int | None = None,
    min_tokens_per_step: int = 0,
    max_tokens_per_step: int | None = None,
    sync_str: list[str] | str | None = None,
    sync_token_ids: list[int] | str | None = None,
    include_sync_str_in_output: bool = True,
    track_logprobs: bool = False,
    temperature: float = 1.0,
    logits_processors: list[LogitsProcessor] | None = None,
    seed: int | None = None,
) -> list[ScoredItem[str]]:
    """
    Generate `n` samples from the language model `llm` using the `step_scorer` to
    rank them at each sync step and the `final_scorer` to rank the final beam.

    Args:
        prompt: The input prompt string.
        llm: The language model to generate samples from.
        step_scorer: The scorer to rank the samples at each sync step.
        final_scorer: The scorer to rank the final beam.
        stop_cond_pass: A function that returns `True` if the sample should pass.
            This stops the sample from being extended.
        stop_cond_fail: A function that returns `True` if the sample should fail.
            This filters the sample from the live beam.
        n: The number of passing samples to generate before returning.
        beam_width: The width of the beam. This is the number of samples to
            keep at each step.
        beam_factor: The branching factor of the beam. This is the number of
            new samples to generate from each live sample at each sync step.
        max_steps: The maximum number of sync steps to take.
        min_tokens_per_step: The minimum number of tokens in each step's extension.
        max_tokens_per_step: The maximum number of tokens in each step's extension.
        sync_str: A string or list of strings that, if generated, will stop extending
            each sample in the live beam and await scoring, ranking, and filtering.
        sync_token_ids: A list of token IDs that, if generated, will stop extending
            each sample in the live beam and await scoring, ranking, and filtering.
            A string can also be passed, which will specify all token IDs that contain
            that substring.
        include_sync_str_in_output: Whether to include the stop string in the output.
        track_logprobs: Whether to track log probabilities. This comes at a performance
            cost, so it is off by default. In most cases, as you are already sampling
            from the model, you do not want to double count the probabilities in the
            scorer anyway.
        temperature: The temperature for sampling.
        logits_processors: A list of logits processors.
            NB: This is applied within each step as opposed to globally.
        seed: The random seed.

    Returns:
        A list of `decoding.pmf.ScoredItem` objects sorted by the `final_scorer`.

    Raises:
        ValueError: If any of the argument configurations are invalid.
        RuntimeError: If all live samples in the beam fail,
            or if max steps is reached before any samples pass.

    Examples:
        ```python
        from decoding.generators import TreeSearch
        from decoding.models import LanguageModel
        from decoding.pmf import ScoredItem
        from decoding.scorers import Scorer

        def f(x):
            if "." in x:
                x = x.split(".")[0] + "."
            return ScoredItem(item=x, score=-len(x))

        llm = LanguageModel.from_id("gpt2")
        scorer = Scorer.from_f_str_to_sample(f)
        samples = TreeSearch(
            prompt="The",
            sync_token_ids=" ",
            stop_cond_pass=lambda x: x.endswith("."),
            llm=llm,
            step_scorer=scorer,
            final_scorer=scorer,
            n=3,
            beam_width=50,
            beam_factor=5,
            seed=42,
        )
        assert len(samples) == 3
        assert all(s.item.endswith(".") for s in samples)
        assert all(s.score == -len(s.item) for s in samples)
        assert samples[0].score >= samples[-1].score
        ```

    """
    if final_scorer is None:
        final_scorer = step_scorer
    search_params = _SearchParams(
        n=_guard_positive_int(n),
        width=_guard_positive_int(beam_width),
        max_steps=_prepare_max_steps(max_steps),
        stop_pass=_prepare_stop(stop_cond_pass),
        stop_fail=_prepare_stop(stop_cond_fail),
    )
    _validate_search_params(search_params)
    sampling_params = SamplingParams(
        n=_guard_positive_int(beam_factor),
        min_tokens=min_tokens_per_step,
        max_tokens=max_tokens_per_step,
        stop=sync_str,
        stop_token_ids=_prepare_token_ids(sync_token_ids, llm=llm),
        include_stop_str_in_output=include_sync_str_in_output,
        logprobs=_prepare_track_logprobs(track_logprobs),
        prompt_logprobs=_prepare_track_logprobs(track_logprobs),
        temperature=temperature,
        logits_processors=logits_processors,
        seed=seed,
        **_default_sampling_kwargs,  # type: ignore[reportArgumentType]
    )
    samples = _TreeSearch([prompt], llm, step_scorer, search_params, sampling_params)
    return sort_scored_items(final_scorer(LogPMF.from_samples(samples)))


def _BestOfN(
    prompts: list[str],
    llm: LanguageModel,
    scorer: Scorer,
    sampling_params: SamplingParams,
) -> list[ScoredItem[str]]:
    return scorer(llm(prompts=prompts, params=sampling_params))


def _TreeSearch(
    prompts: list[str],
    llm: LanguageModel,
    scorer: Scorer,
    search_params: _SearchParams,
    sampling_params: SamplingParams,
) -> list[ScoredItem[str]]:
    beam = [ScoredItem(item=p, score=-float("inf")) for p in prompts]
    passing = []
    for _ in range(search_params.max_steps):
        stop_pass = [search_params.stop_pass(s.item) for s in beam]
        stop_fail = [search_params.stop_fail(s.item) for s in beam]
        passing = []
        prompts = []
        for sample, passed, failed in zip(beam, stop_pass, stop_fail, strict=True):
            if passed and not failed:
                passing.append(sample)
            elif not failed:
                prompts.append(sample.item)
            else:  # failed
                pass
        if len(passing) >= search_params.n:
            return passing
        if len(prompts) == 0:
            return _handle_failed_beam(passing)
        live = _BestOfN(prompts, llm, scorer, sampling_params)
        beam = passing + live
        if len(beam) > search_params.width:
            beam = topk_scored_items(beam, search_params.width)
    return _handle_maxsteps(passing)


def _prepare_token_ids(
    token_ids: list[int] | str | None, *, llm: LanguageModel
) -> list[int] | None:
    if isinstance(token_ids, str):
        return _get_token_ids_from_delimiter(llm=llm, delimiter=token_ids)
    return token_ids


def _get_token_ids_from_delimiter(*, llm: LanguageModel, delimiter: str) -> list[int]:
    _validate_delimiter(delimiter)
    tokenizer = llm.tokenizer
    if isinstance(tokenizer, MistralTokenizer):
        msg = "vLLM Mistral tokenizer does not currently support `batch_decode`."
        raise NotImplementedError(msg)
    tokens = list(tokenizer.get_vocab().values())
    strs = tokenizer.batch_decode(tokens)
    return [tokens[i] for i, s in enumerate(strs) if delimiter in s]


def _validate_search_params(params: _SearchParams) -> None:
    if params.n > params.width:
        msg = "`beam_width` cannot be less than `n`."
        raise ValueError(msg)


def _validate_delimiter(delimiter: str) -> None:
    if len(delimiter) != 1:
        msg = f"Delimiter must be a single character, got: {delimiter}."
        raise ValueError(msg)


def _prepare_stop(
    stop: Callable[[str], bool] | None,
) -> Callable[[str], bool]:
    if stop is None:

        def _dont_stop(_: str) -> bool:
            return False

        return _dont_stop
    return stop


def _prepare_max_steps(max_steps: int | None) -> int:
    if max_steps is None:
        return 2**32
    return _guard_positive_int(max_steps)


def _prepare_track_logprobs(track_logprobs: bool) -> int | None:  # noqa: FBT001
    return 0 if track_logprobs else None


def _guard_positive_int(n: int) -> int:
    if n < 1:
        msg = f"Expected a positive integer, got: {n}."
        raise ValueError(msg)
    return n


def _handle_failed_beam(passing: list[ScoredItem[str]]) -> list[ScoredItem[str]]:
    if len(passing) == 0:
        msg = "All live samples failed before any passed stop conditions."
        msg += " Check compatibility of stop conditions or expand search."
        raise RuntimeError(msg)
    import warnings

    msg = "All live samples failed before completing search,"
    msg += " but some completed samples have already passed stopping conditions."
    msg += " Returning available passing samples."
    warnings.warn(msg, stacklevel=2)
    return passing


def _handle_maxsteps(passing: list[ScoredItem[str]]) -> list[ScoredItem[str]]:
    if len(passing) == 0:
        msg = "Max steps reached, and no samples passed stop conditions."
        raise RuntimeError(msg)
    import warnings

    msg = "Max steps reached before completing search,"
    msg += " but some samples have already passed stopping conditions."
    msg += " Returning available passing samples."
    warnings.warn(msg, stacklevel=2)
    return passing


_default_sampling_kwargs = {
    "detokenize": True,
    "ignore_eos": False,
    "truncate_prompt_tokens": None,
}
def BestOfN( *, prompt: str, llm: decoding.models.LanguageModel, scorer: decoding.scorers.Scorer, n: int = 1, min_tokens: int = 0, max_tokens: int | None = None, stop_str: list[str] | str | None = None, stop_token_ids: list[int] | str | None = None, include_stop_str_in_output: bool = True, track_logprobs: bool = False, temperature: float = 1.0, logits_processors: list[typing.Union[typing.Callable[[typing.List[int], torch.Tensor], torch.Tensor], typing.Callable[[typing.List[int], typing.List[int], torch.Tensor], torch.Tensor]]] | None = None, seed: int | None = None) -> list[decoding.pmf.ScoredItem[str]]:

Generate n samples from the language model llm using the scorer to rank them. See the vLLM.SamplingParams docs to learn more about some of these parameters such as logits_processors.

Arguments:
  • prompt: The input prompt string.
  • llm: The language model to generate samples from.
  • scorer: The scorer to rank the samples.
  • n: The number of samples to generate.
  • min_tokens: The minimum number of tokens in each sample.
  • max_tokens: The maximum number of tokens in each sample.
  • stop_str: A string or list of strings that, if generated, will stop decoding.
  • stop_token_ids: A list of token IDs that, if generated, will stop decoding. A string can also be passed, which will specify all token IDs that contain that substring.
  • include_stop_str_in_output: Whether to include the stop string in the output.
  • track_logprobs: Whether to track log probabilities. This comes at a performance cost, so it is off by default. In most cases, as you are already sampling from the model, you do not want to double count the probabilities in the scorer anyway.
  • temperature: The temperature for sampling.
  • logits_processors: A list of logits processors.
  • seed: The random seed.
Returns:

A list of decoding.pmf.ScoredItem objects sorted by the scorer.

Raises:
  • ValueError: If any of the argument configurations are invalid.
Examples:
from decoding.generators import BestOfN
from decoding.models import LanguageModel
from decoding.scorers import Scorer

llm = LanguageModel.from_id("gpt2")
scorer = Scorer.from_f_str_to_num(lambda x: -len(x))
samples = BestOfN(
    prompt="The",
    llm=llm,
    scorer=scorer,
    n=20,
    stop_str=".",
    seed=42,
)
assert len(samples) == 20
assert all(s.item.endswith(".") for s in samples)
assert all(s.score == -len(s.item) for s in samples)
assert samples[0].score >= samples[-1].score
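
The example above stops on a string; the `stop_token_ids`-as-string behavior described under Arguments is not shown. A hedged sketch of that variant, reusing the `llm` and `scorer` defined above, where the single-character delimiter is expanded to every token ID whose decoded text contains it:

```python
# Reuses `llm` and `scorer` from the example above.
# Stop each sample at the first generated token whose text contains a newline;
# the delimiter must be a single character, otherwise a ValueError is raised.
samples = BestOfN(
    prompt="The",
    llm=llm,
    scorer=scorer,
    n=5,
    stop_token_ids="\n",
    seed=0,
)
```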
def TreeSearch( *, prompt: str, llm: decoding.models.LanguageModel, step_scorer: decoding.scorers.Scorer, final_scorer: decoding.scorers.Scorer | None = None, stop_cond_pass: Callable[[str], bool], stop_cond_fail: Callable[[str], bool] | None = None, n: int = 1, beam_width: int = 1, beam_factor: int = 1, max_steps: int | None = None, min_tokens_per_step: int = 0, max_tokens_per_step: int | None = None, sync_str: list[str] | str | None = None, sync_token_ids: list[int] | str | None = None, include_sync_str_in_output: bool = True, track_logprobs: bool = False, temperature: float = 1.0, logits_processors: list[typing.Union[typing.Callable[[typing.List[int], torch.Tensor], torch.Tensor], typing.Callable[[typing.List[int], typing.List[int], torch.Tensor], torch.Tensor]]] | None = None, seed: int | None = None) -> list[decoding.pmf.ScoredItem[str]]:

Generate n samples from the language model llm using the step_scorer to rank them at each sync step and the final_scorer to rank the final beam.

Arguments:
  • prompt: The input prompt string.
  • llm: The language model to generate samples from.
  • step_scorer: The scorer to rank the samples at each sync step.
  • final_scorer: The scorer to rank the final beam.
  • stop_cond_pass: A function that returns True if the sample should pass. This stops the sample from being extended.
  • stop_cond_fail: A function that returns True if the sample should fail. This filters the sample from the live beam.
  • n: The number of passing samples to generate before returning.
  • beam_width: The width of the beam. This is the number of samples to keep at each step.
  • beam_factor: The branching factor of the beam. This is the number of new samples to generate from each live sample at each sync step.
  • max_steps: The maximum number of sync steps to take.
  • min_tokens_per_step: The minimum number of tokens in each step's extension.
  • max_tokens_per_step: The maximum number of tokens in each step's extension.
  • sync_str: A string or list of strings that, if generated, will stop extending each sample in the live beam and await scoring, ranking, and filtering.
  • sync_token_ids: A list of token IDs that, if generated, will stop extending each sample in the live beam and await scoring, ranking, and filtering. A string can also be passed, which will specify all token IDs that contain that substring.
  • include_sync_str_in_output: Whether to include the stop string in the output.
  • track_logprobs: Whether to track log probabilities. This comes at a performance cost, so it is off by default. In most cases, as you are already sampling from the model, you do not want to double count the probabilities in the scorer anyway.
  • temperature: The temperature for sampling.
  • logits_processors: A list of logits processors. NB: This is applied within each step as opposed to globally.
  • seed: The random seed.
Returns:

A list of decoding.pmf.ScoredItem objects sorted by the final_scorer.

Raises:
  • ValueError: If any of the argument configurations are invalid.
  • RuntimeError: If all live samples in the beam fail, or if max steps is reached before any samples pass.
Examples:
from decoding.generators import TreeSearch
from decoding.models import LanguageModel
from decoding.pmf import ScoredItem
from decoding.scorers import Scorer

def f(x):
    if "." in x:
        x = x.split(".")[0] + "."
    return ScoredItem(item=x, score=-len(x))

llm = LanguageModel.from_id("gpt2")
scorer = Scorer.from_f_str_to_sample(f)
samples = TreeSearch(
    prompt="The",
    sync_token_ids=" ",
    stop_cond_pass=lambda x: x.endswith("."),
    llm=llm,
    step_scorer=scorer,
    final_scorer=scorer,
    n=3,
    beam_width=50,
    beam_factor=5,
    seed=42,
)
assert len(samples) == 3
assert all(s.item.endswith(".") for s in samples)
assert all(s.score == -len(s.item) for s in samples)
assert samples[0].score >= samples[-1].score
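
The example above does not exercise `stop_cond_fail` or `max_steps`. A hedged variation, reusing the `llm` and `scorer` defined above, that prunes branches growing past 200 characters and caps the search at 10 sync steps (per the Raises note, a `RuntimeError` is raised if no samples pass within those bounds):

```python
# Reuses `llm` and `scorer` from the example above.
samples = TreeSearch(
    prompt="The",
    sync_token_ids=" ",
    stop_cond_pass=lambda x: x.endswith("."),
    stop_cond_fail=lambda x: len(x) > 200,  # drop overly long branches from the beam
    max_steps=10,  # bound the number of sync steps
    llm=llm,
    step_scorer=scorer,
    final_scorer=scorer,
    n=3,
    beam_width=50,
    beam_factor=5,
    seed=42,
)
```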