Skip to content

Commit 58a8cc2

Browse files
committed
solving problems for directied_evolution: correct the parameter iteration
1 parent 6d61d60 commit 58a8cc2

1 file changed

Lines changed: 156 additions & 58 deletions

File tree

tools/directed_evolution.py

Lines changed: 156 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@ def parse_args():
4040
)
4141

4242
# Input/Output
43-
parser.add_argument("--input", "-i", type=str, required=True,
43+
parser.add_argument("--input", "-i", type=str, default=None,
4444
help="Input FASTA file with single RNA sequence")
45-
parser.add_argument("--output", "-o", type=str, required=True,
45+
parser.add_argument("--output", "-o", type=str, default=None,
4646
help="Output FASTA file for results")
4747
parser.add_argument("--config", "-c", type=str, default=None,
4848
help="YAML configuration file (overrides CLI args)")
49-
parser.add_argument("--checkpoint", type=str, required=True,
49+
parser.add_argument("--checkpoint", type=str, default=None,
5050
help="Path to model checkpoint directory")
5151

5252
# RNA Type and Species Conditions
@@ -99,7 +99,15 @@ def parse_args():
9999
parser.add_argument("--verbose", action="store_true",
100100
help="Print verbose output")
101101

102-
return parser.parse_args()
102+
args = parser.parse_args()
103+
104+
# Get parser defaults for config merging
105+
parser_defaults = {}
106+
for action in parser._actions:
107+
if action.dest != 'help':
108+
parser_defaults[action.dest] = action.default
109+
110+
return args, parser_defaults
103111

104112

105113
def load_config(config_path: str) -> dict:
@@ -108,11 +116,20 @@ def load_config(config_path: str) -> dict:
108116
return yaml.safe_load(f)
109117

110118

111-
def merge_config_with_args(config: dict, args) -> None:
112-
"""Merge YAML config with command line arguments."""
119+
def merge_config_with_args(config: dict, args, parser_defaults: dict) -> None:
120+
"""Merge YAML config with command line arguments.
121+
122+
Priority: CLI args > Config file > Parser defaults
123+
"""
113124
for key, value in config.items():
114-
if value is not None and getattr(args, key, None) is None:
115-
setattr(args, key, value)
125+
if value is not None:
126+
# Check if this argument was explicitly provided on command line
127+
current_value = getattr(args, key, None)
128+
default_value = parser_defaults.get(key)
129+
130+
# Only override if the current value is the default (not explicitly set by user)
131+
if current_value == default_value:
132+
setattr(args, key, value)
116133

117134

118135
def parse_mutate_positions(positions_str: str, seq_length: int) -> List[int]:
@@ -139,49 +156,77 @@ def parse_mutate_range(range_str: str, seq_length: int) -> List[int]:
139156
return list(range(start, end))
140157

141158

142-
def generate_mutations(sequence: str, positions: List[int]) -> List[Tuple[str, List[Tuple[int, str, str]]]]:
143-
"""Generate point mutation candidates.
159+
def generate_mutations_single_position(sequence: str, position: int) -> List[str]:
160+
"""Generate mutations for a single position.
144161
145162
Args:
146-
sequence: Input RNA sequence (e.g., "AUCG...")
147-
positions: Positions to mutate
163+
sequence: Input RNA sequence
164+
position: Position to mutate
148165
149166
Returns:
150-
List of (mutated_sequence, mutation_records)
151-
- mutation_records: [(position, original_base, new_base), ...]
167+
List of 4 mutated sequences (including original base)
152168
"""
153169
bases = ['A', 'U', 'C', 'G']
154170
candidates = []
155-
mutation_records = []
156-
157-
# Generate all 4^len(positions) combinations
158-
if len(positions) == 1:
159-
# Single position - generate 4 candidates (one for each base)
160-
pos = positions[0]
161-
orig_base = sequence[pos]
162-
for new_base in bases:
163-
if new_base != orig_base:
164-
mut_seq = sequence[:pos] + new_base + sequence[pos+1:]
165-
candidates.append(mut_seq)
166-
mutation_records.append([(pos, orig_base, new_base)])
167-
else:
168-
# Multiple positions - use recursion to generate combinations
169-
def generate_recursive(idx: int, current_seq: str, current_records: List[Tuple[int, str, str]]):
170-
if idx == len(positions):
171-
candidates.append(current_seq)
172-
mutation_records.append(current_records.copy())
173-
return
174-
175-
pos = positions[idx]
176-
orig_base = sequence[pos]
177-
for new_base in bases:
178-
if new_base != orig_base:
179-
new_seq = current_seq[:pos] + new_base + current_seq[pos+1:]
180-
generate_recursive(idx + 1, new_seq, current_records + [(pos, orig_base, new_base)])
181-
182-
generate_recursive(0, sequence, [])
183-
184-
return candidates, mutation_records
171+
172+
for new_base in bases:
173+
mut_seq = sequence[:position] + new_base + sequence[position+1:]
174+
candidates.append(mut_seq)
175+
176+
return candidates
177+
178+
179+
def generate_mutations_with_beam_search(
180+
initial_sequence: str,
181+
positions: List[int],
182+
beam_width: int,
183+
score_fn
184+
) -> List[str]:
185+
"""Generate mutations with incremental beam search to avoid exponential explosion.
186+
187+
Args:
188+
initial_sequence: Starting RNA sequence
189+
positions: Positions to mutate
190+
beam_width: Maximum number of candidates to keep after each position
191+
score_fn: Function to score candidates, takes List[str] and returns List[float]
192+
193+
Returns:
194+
List of final candidate sequences (up to beam_width)
195+
"""
196+
if not positions:
197+
return [initial_sequence]
198+
199+
# Start with the initial sequence
200+
current_candidates = [initial_sequence]
201+
202+
# Process each mutation position incrementally
203+
for pos_idx, position in enumerate(positions):
204+
print(f" Processing mutation position {pos_idx + 1}/{len(positions)}: {position}")
205+
206+
# Generate mutations for all current candidates at this position
207+
next_candidates = []
208+
for candidate in current_candidates:
209+
mutations = generate_mutations_single_position(candidate, position)
210+
next_candidates.extend(mutations)
211+
212+
print(f" Generated {len(next_candidates)} candidates")
213+
214+
# Apply beam search if we exceed beam_width
215+
if len(next_candidates) > beam_width:
216+
# Score all candidates
217+
scores = score_fn(next_candidates)
218+
219+
# Sort by score and keep top beam_width
220+
scored_pairs = list(zip(next_candidates, scores))
221+
scored_pairs.sort(key=lambda x: x[1], reverse=True)
222+
current_candidates = [seq for seq, _ in scored_pairs[:beam_width]]
223+
224+
print(f" Beam search: kept top {len(current_candidates)} candidates")
225+
else:
226+
current_candidates = next_candidates
227+
print(f" Kept all {len(current_candidates)} candidates (< beam_width)")
228+
229+
return current_candidates
185230

186231

187232
def dynamic_beam_search(candidates: List[str], scores: List[float], beam_width: int,
@@ -384,17 +429,23 @@ def select_mutate_positions(self) -> List[int]:
384429
positions = random.sample(range(seq_length), min(self.args.mutations_per_iter, seq_length))
385430
return sorted(positions)
386431

387-
def generate_candidate_sequences(self, positions: List[int]) -> Tuple[List[str], List[List]]:
388-
"""Generate candidate mutated sequences.
432+
def generate_candidate_sequences(self, positions: List[int]) -> List[str]:
433+
"""Generate candidate mutated sequences with incremental beam search.
389434
390435
Args:
391436
positions: Positions to mutate
392437
393438
Returns:
394-
Tuple of (candidate_sequences, mutation_records)
439+
List of candidate sequences
395440
"""
396-
candidates, records = generate_mutations(self.input_sequence, positions)
397-
return candidates, records
441+
# Use the new beam search approach
442+
candidates = generate_mutations_with_beam_search(
443+
initial_sequence=self.input_sequence,
444+
positions=positions,
445+
beam_width=self.args.beam_width,
446+
score_fn=self.score_candidates
447+
)
448+
return candidates
398449

399450
def score_candidates(self, candidates: List[str]) -> List[float]:
400451
"""Score candidates using LLM.
@@ -439,6 +490,9 @@ def run(self):
439490
self.current_sequences = [self.input_sequence]
440491
self.current_scores = [0.0] # Initial score is 0 (not evaluated)
441492

493+
# Candidate pool for collecting diverse sequences
494+
candidate_pool = [] # List of (sequence, score) tuples
495+
442496
temperature = self.args.T_init
443497
best_sequence = self.input_sequence
444498
best_score = float('-inf')
@@ -451,9 +505,9 @@ def run(self):
451505
positions = self.select_mutate_positions()
452506
print(f"Mutating positions: {positions}")
453507

454-
# Step 2: Generate candidate sequences
455-
candidates, mutation_records = self.generate_candidate_sequences(positions)
456-
print(f"Generated {len(candidates)} candidates")
508+
# Step 2: Generate candidate sequences with incremental beam search
509+
candidates = self.generate_candidate_sequences(positions)
510+
print(f"Generated {len(candidates)} candidates (after beam search)")
457511

458512
# Step 3: Score candidates with LLM
459513
scores = self.score_candidates(candidates)
@@ -491,16 +545,22 @@ def run(self):
491545
accepted_sequences.append(seq)
492546
accepted_scores.append(combined_score)
493547

548+
# Add to candidate pool
549+
candidate_pool.append((seq, combined_score))
550+
494551
# If no sequences accepted, keep the best beam candidate
495552
if not accepted_sequences:
496553
best_seq, best_s = beam_candidates[0]
554+
best_mfe = mfe_values[0]
555+
combined = best_s - 0.1 * best_mfe
497556
accepted_sequences = [best_seq]
498-
accepted_scores = [best_s]
557+
accepted_scores = [combined]
558+
candidate_pool.append((best_seq, combined))
499559
print(" (SA rejected all, keeping best beam candidate)")
500560

501-
# Update current pool
502-
self.current_sequences = accepted_sequences[:self.args.output_count]
503-
self.current_scores = accepted_scores[:self.args.output_count]
561+
# Update current pool for next iteration
562+
self.current_sequences = accepted_sequences[:self.args.beam_width]
563+
self.current_scores = accepted_scores[:self.args.beam_width]
504564

505565
# Track best
506566
if self.current_scores:
@@ -513,6 +573,8 @@ def run(self):
513573
best_sequence = current_best_seq
514574
print(f" New best: score={best_score:.4f}")
515575

576+
print(f" Candidate pool size: {len(candidate_pool)}")
577+
516578
# Log iteration
517579
self.log_entries.append({
518580
'iteration': iteration + 1,
@@ -525,6 +587,34 @@ def run(self):
525587
# Cool down
526588
temperature = max(self.args.T_min, temperature * self.args.cooling_rate)
527589

590+
# Final: Select top output_count sequences from candidate pool
591+
print(f"\n{'='*60}")
592+
print("Evolution Complete")
593+
print(f"{'='*60}")
594+
print(f"Best sequence score: {best_score:.4f}")
595+
print(f"Total candidates in pool: {len(candidate_pool)}")
596+
597+
# Sort by score and select top sequences
598+
candidate_pool.sort(key=lambda x: x[1], reverse=True)
599+
600+
# Remove duplicates while preserving order
601+
seen = set()
602+
unique_candidates = []
603+
for seq, score in candidate_pool:
604+
if seq not in seen:
605+
seen.add(seq)
606+
unique_candidates.append((seq, score))
607+
608+
# Select top output_count
609+
final_candidates = unique_candidates[:self.args.output_count]
610+
611+
self.current_sequences = [seq for seq, _ in final_candidates]
612+
self.current_scores = [score for _, score in final_candidates]
613+
614+
print(f"Selected {len(self.current_sequences)} unique sequences for output")
615+
616+
return self.current_sequences, self.current_scores
617+
528618
print(f"\n{'='*60}")
529619
print("Evolution Complete")
530620
print(f"{'='*60}")
@@ -562,12 +652,20 @@ def save_log(self):
562652

563653
def main():
564654
"""Main entry point."""
565-
args = parse_args()
655+
args, parser_defaults = parse_args()
566656

567657
# Load config if provided
568658
if args.config:
569659
config = load_config(args.config)
570-
merge_config_with_args(config, args)
660+
merge_config_with_args(config, args, parser_defaults)
661+
662+
# Validate required arguments
663+
if not args.input:
664+
raise ValueError("--input is required (either via CLI or config file)")
665+
if not args.output:
666+
raise ValueError("--output is required (either via CLI or config file)")
667+
if not args.checkpoint:
668+
raise ValueError("--checkpoint is required (either via CLI or config file)")
571669

572670
# Validate arguments
573671
if args.rna_type:

0 commit comments

Comments
 (0)