decoding.generators
Generators are the user's highest-level interface to the decoding library. By composing instances of `decoding.models.LanguageModel`, `decoding.scorers.Scorer`, and control flow parameters that specify sync, stop, and search conditions, users can implement a wide variety of decoding algorithms with very little code.
The `BestOfN` and `TreeSearch` generators are currently fully supported. There is also experimental support for `RolloutTreeSearch` in the `decoding.experimental` module, which provides a simple wrapper interface for a more standard Monte Carlo Tree Search (MCTS) algorithm. Bringing twisted Sequential Monte Carlo (SMC) to the `decoding` library is also on the roadmap.
NB: The examples below are illustrative of the API, but not particularly useful. See the [`examples`](https://github.com/benlipkin/decoding/tree/main/examples) directory for more interesting examples.
1""" 2Generators are the user's highest-level interface to the decoding library. 3By composing instances of `decoding.models.LanguageModel`, `decoding.scorers.Scorer`, 4and control flow parameters that specify sync, stop, and search conditions, users can 5implement a wide variety of decoding algorithms with very little code. 6 7The `BestOfN` and `TreeSearch` generators are currently fully supported. There is also 8experimental support for `RolloutTreeSearch` in the `decoding.experimental` module, 9which supports a simple wrapper interface for a more standard Monte Carlo Tree Search 10(MCTS) algorithm. It is also on the roadmap to bring twisted `SMC` to the 11`decoding` library. 12 13**NB**: The examples below are illustrative of the API, but not particularly useful. 14See the [`examples`](https://github.com/benlipkin/decoding/tree/main/examples) 15directory for more interesting examples. 16""" 17 18from collections.abc import Callable 19from dataclasses import dataclass 20 21from vllm.sampling_params import LogitsProcessor, SamplingParams 22from vllm.transformers_utils.tokenizers import MistralTokenizer 23 24from decoding.models import LanguageModel 25from decoding.pmf import LogPMF, ScoredItem, sort_scored_items, topk_scored_items 26from decoding.scorers import Scorer 27 28 29@dataclass(frozen=True, kw_only=True) 30class _SearchParams: 31 n: int 32 width: int 33 max_steps: int 34 stop_pass: Callable[[str], bool] 35 stop_fail: Callable[[str], bool] 36 37 38def BestOfN( # noqa: PLR0913 39 *, 40 prompt: str, 41 llm: LanguageModel, 42 scorer: Scorer, 43 n: int = 1, 44 min_tokens: int = 0, 45 max_tokens: int | None = None, 46 stop_str: list[str] | str | None = None, 47 stop_token_ids: list[int] | str | None = None, 48 include_stop_str_in_output: bool = True, 49 track_logprobs: bool = False, 50 temperature: float = 1.0, 51 logits_processors: list[LogitsProcessor] | None = None, 52 seed: int | None = None, 53) -> list[ScoredItem[str]]: 54 """ 55 Generate `n` samples from the language model `llm` using the `scorer` to rank them. 56 See the [`vLLM.SamplingParams`](https://docs.vllm.ai/en/latest/dev/sampling_params.html) 57 docs to learn more about some of these parameters such as `logits_processors`. 58 59 Args: 60 prompt: The input prompt string. 61 llm: The language model to generate samples from. 62 scorer: The scorer to rank the samples. 63 n: The number of samples to generate. 64 min_tokens: The minimum number of tokens in each sample. 65 max_tokens: The maximum number of tokens in each sample. 66 stop_str: A string or list of strings that, if generated, will stop decoding. 67 stop_token_ids: A list of token IDs that, if generated, will stop decoding. 68 A string can also be passed, which will specify all token IDs that contain 69 that substring. 70 include_stop_str_in_output: Whether to include the stop string in the output. 71 track_logprobs: Whether to track log probabilities. This comes at a performance 72 cost, so it is off by default. In most cases, as you are alrady sampling 73 from the model, you do not want to double count the probabilities in the 74 scorer anyways. 75 temperature: The temperature for sampling. 76 logits_processors: A list of logits processors. 77 seed: The random seed. 78 79 Returns: 80 A list of `decoding.pmf.ScoredItem` objects sorted by the `scorer`. 81 82 Raises: 83 ValueError: If any of the argument configurations are invalid. 
84 85 Examples: 86 ```python 87 from decoding.generators import BestOfN 88 from decoding.models import LanguageModel 89 from decoding.scorers import Scorer 90 91 llm = LanguageModel.from_id("gpt2") 92 scorer = Scorer.from_f_str_to_num(lambda x: -len(x)) 93 samples = BestOfN( 94 prompt="The", 95 llm=llm, 96 scorer=scorer, 97 n=20, 98 stop_str=".", 99 seed=42, 100 ) 101 assert len(samples) == 20 102 assert all(s.item.endswith(".") for s in samples) 103 assert all(s.score == -len(s.item) for s in samples) 104 assert samples[0].score >= samples[-1].score 105 ``` 106 107 """ 108 sampling_params = SamplingParams( 109 n=_guard_positive_int(n), 110 min_tokens=min_tokens, 111 max_tokens=max_tokens, 112 stop=stop_str, 113 stop_token_ids=_prepare_token_ids(stop_token_ids, llm=llm), 114 include_stop_str_in_output=include_stop_str_in_output, 115 logprobs=_prepare_track_logprobs(track_logprobs), 116 prompt_logprobs=_prepare_track_logprobs(track_logprobs), 117 temperature=temperature, 118 logits_processors=logits_processors, 119 seed=seed, 120 **_default_sampling_kwargs, # type: ignore[reportArgumentType] 121 ) 122 samples = _BestOfN([prompt], llm, scorer, sampling_params) 123 return sort_scored_items(samples) 124 125 126def TreeSearch( # noqa: PLR0913 127 *, 128 prompt: str, 129 llm: LanguageModel, 130 step_scorer: Scorer, 131 final_scorer: Scorer | None = None, 132 stop_cond_pass: Callable[[str], bool], 133 stop_cond_fail: Callable[[str], bool] | None = None, 134 n: int = 1, 135 beam_width: int = 1, 136 beam_factor: int = 1, 137 max_steps: int | None = None, 138 min_tokens_per_step: int = 0, 139 max_tokens_per_step: int | None = None, 140 sync_str: list[str] | str | None = None, 141 sync_token_ids: list[int] | str | None = None, 142 include_sync_str_in_output: bool = True, 143 track_logprobs: bool = False, 144 temperature: float = 1.0, 145 logits_processors: list[LogitsProcessor] | None = None, 146 seed: int | None = None, 147) -> list[ScoredItem[str]]: 148 """ 149 Generate `n` samples from the language model `llm` using the `step_scorer` to 150 rank them at each sync step and the `final_scorer` to rank the final beam. 151 152 Args: 153 prompt: The input prompt string. 154 llm: The language model to generate samples from. 155 step_scorer: The scorer to rank the samples at each sync step. 156 final_scorer: The scorer to rank the final beam. 157 stop_cond_pass: A function that returns `True` if the sample should pass. 158 This stops the sample from being extended. 159 stop_cond_fail: A function that returns `True` if the sample should fail. 160 This filters the sample from the live beam. 161 n: The number of passing samples to generate before returning. 162 beam_width: The width of the beam. This is the number of samples to 163 keep at each step. 164 beam_factor: The branching factor of the beam. This is the number of 165 new samples to generate from each live sample at each sync step. 166 max_steps: The maximum number of sync steps to take. 167 min_tokens_per_step: The minimum number of tokens in each step's extension. 168 max_tokens_per_step: The maximum number of tokens in each step's extension. 169 sync_str: A string or list of strings that, if generated, will stop extending 170 each sample in the live beam and await scoring, ranking, and filtering. 171 sync_token_ids: A list of token IDs that, if generated, will stop extending 172 each sample in the live beam and await scoring, ranking, and filtering. 173 A string can also be passed, which will specify all token IDs that contain 174 that substring. 
175 include_sync_str_in_output: Whether to include the stop string in the output. 176 track_logprobs: Whether to track log probabilities. This comes at a performance 177 cost, so it is off by default. In most cases, as you are already sampling 178 from the model, you do not want to double count the probabilities in the 179 scorer anyways. 180 temperature: The temperature for sampling. 181 logits_processors: A list of logits processors. 182 NB: This is applied within each step as opposed to globally. 183 seed: The random seed. 184 185 Returns: 186 A list of `decoding.pmf.ScoredItem` objects sorted by the `final_scorer`. 187 188 Raises: 189 ValueError: If any of the argument configurations are invalid 190 RuntimeError: if all live samples in the beam fail, 191 or if max steps is reached before any samples pass. 192 193 Examples: 194 ```python 195 from decoding.generators import TreeSearch 196 from decoding.models import LanguageModel 197 from decoding.pmf import ScoredItem 198 from decoding.scorers import Scorer 199 200 def f(x): 201 if "." in x: 202 x = x.split(".")[0] + "." 203 return ScoredItem(item=x, score=-len(x)) 204 205 llm = LanguageModel.from_id("gpt2") 206 scorer = Scorer.from_f_str_to_sample(f) 207 samples = TreeSearch( 208 prompt="The", 209 sync_token_ids=" ", 210 stop_cond_pass=lambda x: x.endswith("."), 211 llm=llm, 212 step_scorer=scorer, 213 final_scorer=scorer, 214 n=3, 215 beam_width=50, 216 beam_factor=5, 217 seed=42, 218 ) 219 assert len(samples) == 3 220 assert all(s.item.endswith(".") for s in samples) 221 assert all(s.score == -len(s.item) for s in samples) 222 assert samples[0].score >= samples[-1].score 223 ``` 224 225 """ 226 if final_scorer is None: 227 final_scorer = step_scorer 228 search_params = _SearchParams( 229 n=_guard_positive_int(n), 230 width=_guard_positive_int(beam_width), 231 max_steps=_prepare_max_steps(max_steps), 232 stop_pass=_prepare_stop(stop_cond_pass), 233 stop_fail=_prepare_stop(stop_cond_fail), 234 ) 235 _validate_search_params(search_params) 236 sampling_params = SamplingParams( 237 n=_guard_positive_int(beam_factor), 238 min_tokens=min_tokens_per_step, 239 max_tokens=max_tokens_per_step, 240 stop=sync_str, 241 stop_token_ids=_prepare_token_ids(sync_token_ids, llm=llm), 242 include_stop_str_in_output=include_sync_str_in_output, 243 logprobs=_prepare_track_logprobs(track_logprobs), 244 prompt_logprobs=_prepare_track_logprobs(track_logprobs), 245 temperature=temperature, 246 logits_processors=logits_processors, 247 seed=seed, 248 **_default_sampling_kwargs, # type: ignore[reportArgumentType] 249 ) 250 samples = _TreeSearch([prompt], llm, step_scorer, search_params, sampling_params) 251 return sort_scored_items(final_scorer(LogPMF.from_samples(samples))) 252 253 254def _BestOfN( 255 prompts: list[str], 256 llm: LanguageModel, 257 scorer: Scorer, 258 sampling_params: SamplingParams, 259) -> list[ScoredItem[str]]: 260 return scorer(llm(prompts=prompts, params=sampling_params)) 261 262 263def _TreeSearch( 264 prompts: list[str], 265 llm: LanguageModel, 266 scorer: Scorer, 267 search_params: _SearchParams, 268 sampling_params: SamplingParams, 269) -> list[ScoredItem[str]]: 270 beam = [ScoredItem(item=p, score=-float("inf")) for p in prompts] 271 passing = [] 272 for _ in range(search_params.max_steps): 273 stop_pass = [search_params.stop_pass(s.item) for s in beam] 274 stop_fail = [search_params.stop_fail(s.item) for s in beam] 275 passing = [] 276 prompts = [] 277 for sample, passed, failed in zip(beam, stop_pass, stop_fail, strict=True): 278 if 
passed and not failed: 279 passing.append(sample) 280 elif not failed: 281 prompts.append(sample.item) 282 else: # failed 283 pass 284 if len(passing) >= search_params.n: 285 return passing 286 if len(prompts) == 0: 287 return _handle_failed_beam(passing) 288 live = _BestOfN(prompts, llm, scorer, sampling_params) 289 beam = passing + live 290 if len(beam) > search_params.width: 291 beam = topk_scored_items(beam, search_params.width) 292 return _handle_maxsteps(passing) 293 294 295def _prepare_token_ids( 296 token_ids: list[int] | str | None, *, llm: LanguageModel 297) -> list[int] | None: 298 if isinstance(token_ids, str): 299 return _get_token_ids_from_delimiter(llm=llm, delimiter=token_ids) 300 return token_ids 301 302 303def _get_token_ids_from_delimiter(*, llm: LanguageModel, delimiter: str) -> list[int]: 304 _validate_delimiter(delimiter) 305 tokenizer = llm.tokenizer 306 if isinstance(tokenizer, MistralTokenizer): 307 msg = "vLLM Mistral tokenizer does not currently support `batch_decode`." 308 raise NotImplementedError(msg) 309 tokens = list(tokenizer.get_vocab().values()) 310 strs = tokenizer.batch_decode(tokens) 311 return [tokens[i] for i, s in enumerate(strs) if delimiter in s] 312 313 314def _validate_search_params(params: _SearchParams) -> None: 315 if params.n > params.width: 316 msg = "`beam_width` cannot be less than `n`." 317 raise ValueError(msg) 318 319 320def _validate_delimiter(delimiter: str) -> None: 321 if len(delimiter) != 1: 322 msg = f"Delimiter must be a single character, got: {delimiter}." 323 raise ValueError(msg) 324 325 326def _prepare_stop( 327 stop: Callable[[str], bool] | None, 328) -> Callable[[str], bool]: 329 if stop is None: 330 331 def _dont_stop(_: str) -> bool: 332 return False 333 334 return _dont_stop 335 return stop 336 337 338def _prepare_max_steps(max_steps: int | None) -> int: 339 if max_steps is None: 340 return 2**32 341 return _guard_positive_int(max_steps) 342 343 344def _prepare_track_logprobs(track_logprobs: bool) -> int | None: # noqa: FBT001 345 return 0 if track_logprobs else None 346 347 348def _guard_positive_int(n: int) -> int: 349 if n < 1: 350 msg = f"Expected a positive integer, got: {n}." 351 raise ValueError(msg) 352 return n 353 354 355def _handle_failed_beam(passing: list[ScoredItem[str]]) -> list[ScoredItem[str]]: 356 if len(passing) == 0: 357 msg = "All live samples failed before any passed stop conditions." 358 msg += " Check compatibility of stop conditions or expand search." 359 raise RuntimeError(msg) 360 import warnings 361 362 msg = "All live samples failed before completing search," 363 msg += " but some completed samples have already passed stopping conditions." 364 msg += " Returning available passing samples." 365 warnings.warn(msg, stacklevel=2) 366 return passing 367 368 369def _handle_maxsteps(passing: list[ScoredItem[str]]) -> list[ScoredItem[str]]: 370 if len(passing) == 0: 371 msg = "Max steps reached, and no samples passed stop conditions." 372 raise RuntimeError(msg) 373 import warnings 374 375 msg = "Max steps reached before completing search," 376 msg += "but some samples have already passed stopping conditions." 377 msg += " Returning available passing samples." 378 warnings.warn(msg, stacklevel=2) 379 return passing 380 381 382_default_sampling_kwargs = { 383 "detokenize": True, 384 "ignore_eos": False, 385 "truncate_prompt_tokens": None, 386}
```python
def BestOfN(
    *,
    prompt: str,
    llm: LanguageModel,
    scorer: Scorer,
    n: int = 1,
    min_tokens: int = 0,
    max_tokens: int | None = None,
    stop_str: list[str] | str | None = None,
    stop_token_ids: list[int] | str | None = None,
    include_stop_str_in_output: bool = True,
    track_logprobs: bool = False,
    temperature: float = 1.0,
    logits_processors: list[LogitsProcessor] | None = None,
    seed: int | None = None,
) -> list[ScoredItem[str]]: ...
```
Generate `n` samples from the language model `llm` using the `scorer` to rank them. See the [`vLLM.SamplingParams`](https://docs.vllm.ai/en/latest/dev/sampling_params.html) docs to learn more about some of these parameters, such as `logits_processors`.
Arguments:
- prompt: The input prompt string.
- llm: The language model to generate samples from.
- scorer: The scorer to rank the samples.
- n: The number of samples to generate.
- min_tokens: The minimum number of tokens in each sample.
- max_tokens: The maximum number of tokens in each sample.
- stop_str: A string or list of strings that, if generated, will stop decoding.
- stop_token_ids: A list of token IDs that, if generated, will stop decoding. A string can also be passed, which will specify all token IDs that contain that substring (see the sketch after the examples below).
- include_stop_str_in_output: Whether to include the stop string in the output.
- track_logprobs: Whether to track log probabilities. This comes at a performance cost, so it is off by default. In most cases, as you are already sampling from the model, you do not want to double count the probabilities in the scorer anyway.
- temperature: The temperature for sampling.
- logits_processors: A list of logits processors.
- seed: The random seed.
Returns:
A list of `decoding.pmf.ScoredItem` objects sorted by the `scorer`.
Raises:
- ValueError: If any of the argument configurations are invalid.
Examples:
```python
from decoding.generators import BestOfN
from decoding.models import LanguageModel
from decoding.scorers import Scorer

llm = LanguageModel.from_id("gpt2")
scorer = Scorer.from_f_str_to_num(lambda x: -len(x))
samples = BestOfN(
    prompt="The",
    llm=llm,
    scorer=scorer,
    n=20,
    stop_str=".",
    seed=42,
)
assert len(samples) == 20
assert all(s.item.endswith(".") for s in samples)
assert all(s.score == -len(s.item) for s in samples)
assert samples[0].score >= samples[-1].score
```
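The arguments list above also describes a string form of `stop_token_ids`. The following is a minimal sketch of that form, reusing the `llm` and `scorer` objects from the example above: the delimiter must be a single character, and it expands to every token ID whose decoded text contains that character. The specific parameter values here are illustrative assumptions, not recommendations.

```python
# Sketch (assumes `llm` and `scorer` from the example above):
# stop decoding at any token whose decoded text contains a newline.
samples = BestOfN(
    prompt="The",
    llm=llm,
    scorer=scorer,
    n=5,
    stop_token_ids="\n",  # single-character string, expanded to all matching token IDs
    max_tokens=64,
    seed=0,
)
assert len(samples) == 5
```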
```python
def TreeSearch(
    *,
    prompt: str,
    llm: LanguageModel,
    step_scorer: Scorer,
    final_scorer: Scorer | None = None,
    stop_cond_pass: Callable[[str], bool],
    stop_cond_fail: Callable[[str], bool] | None = None,
    n: int = 1,
    beam_width: int = 1,
    beam_factor: int = 1,
    max_steps: int | None = None,
    min_tokens_per_step: int = 0,
    max_tokens_per_step: int | None = None,
    sync_str: list[str] | str | None = None,
    sync_token_ids: list[int] | str | None = None,
    include_sync_str_in_output: bool = True,
    track_logprobs: bool = False,
    temperature: float = 1.0,
    logits_processors: list[LogitsProcessor] | None = None,
    seed: int | None = None,
) -> list[ScoredItem[str]]: ...
```
Generate `n` samples from the language model `llm` using the `step_scorer` to rank them at each sync step and the `final_scorer` to rank the final beam.
Arguments:
- prompt: The input prompt string.
- llm: The language model to generate samples from.
- step_scorer: The scorer to rank the samples at each sync step.
- final_scorer: The scorer to rank the final beam.
- stop_cond_pass: A function that returns `True` if the sample should pass. This stops the sample from being extended.
- stop_cond_fail: A function that returns `True` if the sample should fail. This filters the sample from the live beam.
- n: The number of passing samples to generate before returning.
- beam_width: The width of the beam. This is the number of samples to keep at each step.
- beam_factor: The branching factor of the beam. This is the number of new samples to generate from each live sample at each sync step.
- max_steps: The maximum number of sync steps to take.
- min_tokens_per_step: The minimum number of tokens in each step's extension.
- max_tokens_per_step: The maximum number of tokens in each step's extension.
- sync_str: A string or list of strings that, if generated, will stop extending each sample in the live beam and await scoring, ranking, and filtering.
- sync_token_ids: A list of token IDs that, if generated, will stop extending each sample in the live beam and await scoring, ranking, and filtering. A string can also be passed, which will specify all token IDs that contain that substring.
- include_sync_str_in_output: Whether to include the stop string in the output.
- track_logprobs: Whether to track log probabilities. This comes at a performance cost, so it is off by default. In most cases, as you are already sampling from the model, you do not want to double count the probabilities in the scorer anyways.
- temperature: The temperature for sampling.
- logits_processors: A list of logits processors. NB: This is applied within each step as opposed to globally.
- seed: The random seed.
Returns:
A list of `decoding.pmf.ScoredItem` objects sorted by the `final_scorer`.
Raises:
- ValueError: If any of the argument configurations are invalid.
- RuntimeError: If all live samples in the beam fail, or if max steps is reached before any samples pass.
Examples:
```python
from decoding.generators import TreeSearch
from decoding.models import LanguageModel
from decoding.pmf import ScoredItem
from decoding.scorers import Scorer

def f(x):
    if "." in x:
        x = x.split(".")[0] + "."
    return ScoredItem(item=x, score=-len(x))

llm = LanguageModel.from_id("gpt2")
scorer = Scorer.from_f_str_to_sample(f)
samples = TreeSearch(
    prompt="The",
    sync_token_ids=" ",
    stop_cond_pass=lambda x: x.endswith("."),
    llm=llm,
    step_scorer=scorer,
    final_scorer=scorer,
    n=3,
    beam_width=50,
    beam_factor=5,
    seed=42,
)
assert len(samples) == 3
assert all(s.item.endswith(".") for s in samples)
assert all(s.score == -len(s.item) for s in samples)
assert samples[0].score >= samples[-1].score
```
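As a follow-up, here is a minimal sketch that bounds the search with `max_steps` and prunes runaway candidates with `stop_cond_fail`, reusing `llm`, `f`, and `scorer` from the example above. The length cutoff and step budget are illustrative assumptions, not recommendations; as documented above, if no sample passes within the step budget, `TreeSearch` raises a `RuntimeError`.

```python
# Sketch (assumes `llm`, `f`, and `scorer` from the example above):
# bound the number of sync steps and drop overly long samples from the live beam.
samples = TreeSearch(
    prompt="The",
    sync_str=" ",                            # re-rank the beam at every space
    stop_cond_pass=lambda x: x.endswith("."),
    stop_cond_fail=lambda x: len(x) > 200,   # illustrative length cutoff
    max_steps=20,                            # illustrative step budget
    llm=llm,
    step_scorer=scorer,                      # final_scorer defaults to step_scorer
    n=3,
    beam_width=25,
    beam_factor=5,
    seed=0,
)
assert all(s.item.endswith(".") for s in samples)
```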