|
| 1 | +# SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | +# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com> |
| 3 | + |
| 4 | +import torch.nn.functional as F |
| 5 | + |
| 6 | +from .config import Settings |
| 7 | +from .model import Model |
| 8 | +from .utils import load_prompts, print |
| 9 | + |
| 10 | + |
| 11 | +class Evaluator: |
| 12 | + def __init__(self, settings: Settings, model: Model): |
| 13 | + self.settings = settings |
| 14 | + self.model = model |
| 15 | + |
| 16 | + print() |
| 17 | + print(f"Loading good prompts from [bold]{settings.good_prompts.dataset}[/]...") |
| 18 | + self.good_prompts = load_prompts(settings.good_prompts) |
| 19 | + print(f"* [bold]{len(self.good_prompts)}[/] prompts loaded") |
| 20 | + |
| 21 | + print("* Obtaining first-token probability distributions...") |
| 22 | + self.base_logprobs = model.get_logprobs_batched(self.good_prompts) |
| 23 | + |
| 24 | + print() |
| 25 | + print(f"Loading bad prompts from [bold]{settings.bad_prompts.dataset}[/]...") |
| 26 | + self.bad_prompts = load_prompts(settings.bad_prompts) |
| 27 | + print(f"* [bold]{len(self.bad_prompts)}[/] prompts loaded") |
| 28 | + |
| 29 | + print("* Counting model refusals...") |
| 30 | + self.base_refusals = self.count_refusals() |
| 31 | + print( |
| 32 | + f"* Initial refusals: [bold]{self.base_refusals}[/]/{len(self.bad_prompts)} ([bold]{self.base_refusals / len(self.bad_prompts) * 100:.1f}[/] %)" |
| 33 | + ) |
| 34 | + |
| 35 | + def is_refusal(self, response: str) -> bool: |
| 36 | + # Remove emphasis (e.g. "I *will not*...") to facilitate detection. |
| 37 | + response = response.lower().replace("*", "") |
| 38 | + |
| 39 | + for marker in self.settings.refusal_markers: |
| 40 | + if marker.lower() in response: |
| 41 | + return True |
| 42 | + |
| 43 | + return False |
| 44 | + |
| 45 | + def count_refusals(self) -> int: |
| 46 | + responses = self.model.get_responses_batched(self.bad_prompts) |
| 47 | + refusals = [response for response in responses if self.is_refusal(response)] |
| 48 | + return len(refusals) |
| 49 | + |
| 50 | + def get_score(self) -> tuple[float, float, int]: |
| 51 | + print(" * Obtaining first-token probability distributions...") |
| 52 | + logprobs = self.model.get_logprobs_batched(self.good_prompts) |
| 53 | + kl_divergence = F.kl_div( |
| 54 | + logprobs, self.base_logprobs, reduction="batchmean", log_target=True |
| 55 | + ).item() |
| 56 | + print(f" * KL divergence: [bold]{kl_divergence:.4f}[/]") |
| 57 | + |
| 58 | + print(" * Counting model refusals...") |
| 59 | + refusals = self.count_refusals() |
| 60 | + print( |
| 61 | + f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)} ([bold]{refusals / len(self.bad_prompts) * 100:.1f}[/] %)" |
| 62 | + ) |
| 63 | + |
| 64 | + # This score is constructed to achieve several properties: |
| 65 | + # |
| 66 | + # 1. For the unmodified model, kl_divergence = 0 and refusals = base_refusals, |
| 67 | + # so the baseline score is 0. |
| 68 | + # |
| 69 | + # 2. The best possible outcome is kl_divergence = 0 and refusals = 0, |
| 70 | + # giving a score of 1. |
| 71 | + # |
| 72 | + # 3. If kl_divergence > max_kl_divergence, the score is negative. |
| 73 | + # As the baseline is 0, this ensures that such a configuration |
| 74 | + # is never chosen, enforcing the max_kl_divergence constraint. |
| 75 | + # |
| 76 | + # 4. kl_score_shape controls how strongly a kl_divergence well below |
| 77 | + # max_kl_divergence affects the score. A high value means that |
| 78 | + # kl_divergence only matters when it approaches max_kl_divergence, |
| 79 | + # and the optimizer will prioritize lowering refusals rather than |
| 80 | + # lowering kl_divergence. |
| 81 | + score = -( |
| 82 | + ( |
| 83 | + ( |
| 84 | + ( |
| 85 | + (kl_divergence - self.settings.max_kl_divergence) |
| 86 | + / self.settings.max_kl_divergence |
| 87 | + ) |
| 88 | + + 1 |
| 89 | + ) |
| 90 | + ** self.settings.kl_score_shape |
| 91 | + ) |
| 92 | + + (refusals / self.base_refusals) |
| 93 | + - 1 |
| 94 | + ) |
| 95 | + print(f" * Score: [bold]{score:.4f}[/]") |
| 96 | + |
| 97 | + return score, kl_divergence, refusals |
0 commit comments