@@ -40,13 +40,13 @@ def parse_args():
4040 )
4141
4242 # Input/Output
43- parser .add_argument ("--input" , "-i" , type = str , required = True ,
43+ parser .add_argument ("--input" , "-i" , type = str , default = None ,
4444 help = "Input FASTA file with single RNA sequence" )
45- parser .add_argument ("--output" , "-o" , type = str , required = True ,
45+ parser .add_argument ("--output" , "-o" , type = str , default = None ,
4646 help = "Output FASTA file for results" )
4747 parser .add_argument ("--config" , "-c" , type = str , default = None ,
4848 help = "YAML configuration file (overrides CLI args)" )
49- parser .add_argument ("--checkpoint" , type = str , required = True ,
49+ parser .add_argument ("--checkpoint" , type = str , default = None ,
5050 help = "Path to model checkpoint directory" )
5151
5252 # RNA Type and Species Conditions
@@ -99,7 +99,15 @@ def parse_args():
9999 parser .add_argument ("--verbose" , action = "store_true" ,
100100 help = "Print verbose output" )
101101
102- return parser .parse_args ()
102+ args = parser .parse_args ()
103+
104+ # Get parser defaults for config merging
105+ parser_defaults = {}
106+ for action in parser ._actions :
107+ if action .dest != 'help' :
108+ parser_defaults [action .dest ] = action .default
109+
110+ return args , parser_defaults
103111
104112
105113def load_config (config_path : str ) -> dict :
@@ -108,11 +116,20 @@ def load_config(config_path: str) -> dict:
108116 return yaml .safe_load (f )
109117
110118
111- def merge_config_with_args (config : dict , args ) -> None :
112- """Merge YAML config with command line arguments."""
119+ def merge_config_with_args (config : dict , args , parser_defaults : dict ) -> None :
120+ """Merge YAML config with command line arguments.
121+
122+ Priority: CLI args > Config file > Parser defaults
123+ """
113124 for key , value in config .items ():
114- if value is not None and getattr (args , key , None ) is None :
115- setattr (args , key , value )
125+ if value is not None :
126+ # Check if this argument was explicitly provided on command line
127+ current_value = getattr (args , key , None )
128+ default_value = parser_defaults .get (key )
129+
130+ # Only override if the current value is the default (not explicitly set by user)
131+ if current_value == default_value :
132+ setattr (args , key , value )
116133
117134
118135def parse_mutate_positions (positions_str : str , seq_length : int ) -> List [int ]:
@@ -139,49 +156,77 @@ def parse_mutate_range(range_str: str, seq_length: int) -> List[int]:
139156 return list (range (start , end ))
140157
141158
142- def generate_mutations (sequence : str , positions : List [ int ] ) -> List [Tuple [ str , List [ Tuple [ int , str , str ]]] ]:
143- """Generate point mutation candidates .
159+ def generate_mutations_single_position (sequence : str , position : int ) -> List [str ]:
160+ """Generate mutations for a single position .
144161
145162 Args:
146- sequence: Input RNA sequence (e.g., "AUCG...")
147- positions: Positions to mutate
163+ sequence: Input RNA sequence
164+ position: Position to mutate
148165
149166 Returns:
150- List of (mutated_sequence, mutation_records)
151- - mutation_records: [(position, original_base, new_base), ...]
167+ List of 4 mutated sequences (including original base)
152168 """
153169 bases = ['A' , 'U' , 'C' , 'G' ]
154170 candidates = []
155- mutation_records = []
156-
157- # Generate all 4^len(positions) combinations
158- if len (positions ) == 1 :
159- # Single position - generate 4 candidates (one for each base)
160- pos = positions [0 ]
161- orig_base = sequence [pos ]
162- for new_base in bases :
163- if new_base != orig_base :
164- mut_seq = sequence [:pos ] + new_base + sequence [pos + 1 :]
165- candidates .append (mut_seq )
166- mutation_records .append ([(pos , orig_base , new_base )])
167- else :
168- # Multiple positions - use recursion to generate combinations
169- def generate_recursive (idx : int , current_seq : str , current_records : List [Tuple [int , str , str ]]):
170- if idx == len (positions ):
171- candidates .append (current_seq )
172- mutation_records .append (current_records .copy ())
173- return
174-
175- pos = positions [idx ]
176- orig_base = sequence [pos ]
177- for new_base in bases :
178- if new_base != orig_base :
179- new_seq = current_seq [:pos ] + new_base + current_seq [pos + 1 :]
180- generate_recursive (idx + 1 , new_seq , current_records + [(pos , orig_base , new_base )])
181-
182- generate_recursive (0 , sequence , [])
183-
184- return candidates , mutation_records
171+
172+ for new_base in bases :
173+ mut_seq = sequence [:position ] + new_base + sequence [position + 1 :]
174+ candidates .append (mut_seq )
175+
176+ return candidates
177+
178+
179+ def generate_mutations_with_beam_search (
180+ initial_sequence : str ,
181+ positions : List [int ],
182+ beam_width : int ,
183+ score_fn
184+ ) -> List [str ]:
185+ """Generate mutations with incremental beam search to avoid exponential explosion.
186+
187+ Args:
188+ initial_sequence: Starting RNA sequence
189+ positions: Positions to mutate
190+ beam_width: Maximum number of candidates to keep after each position
191+ score_fn: Function to score candidates, takes List[str] and returns List[float]
192+
193+ Returns:
194+ List of final candidate sequences (up to beam_width)
195+ """
196+ if not positions :
197+ return [initial_sequence ]
198+
199+ # Start with the initial sequence
200+ current_candidates = [initial_sequence ]
201+
202+ # Process each mutation position incrementally
203+ for pos_idx , position in enumerate (positions ):
204+ print (f" Processing mutation position { pos_idx + 1 } /{ len (positions )} : { position } " )
205+
206+ # Generate mutations for all current candidates at this position
207+ next_candidates = []
208+ for candidate in current_candidates :
209+ mutations = generate_mutations_single_position (candidate , position )
210+ next_candidates .extend (mutations )
211+
212+ print (f" Generated { len (next_candidates )} candidates" )
213+
214+ # Apply beam search if we exceed beam_width
215+ if len (next_candidates ) > beam_width :
216+ # Score all candidates
217+ scores = score_fn (next_candidates )
218+
219+ # Sort by score and keep top beam_width
220+ scored_pairs = list (zip (next_candidates , scores ))
221+ scored_pairs .sort (key = lambda x : x [1 ], reverse = True )
222+ current_candidates = [seq for seq , _ in scored_pairs [:beam_width ]]
223+
224+ print (f" Beam search: kept top { len (current_candidates )} candidates" )
225+ else :
226+ current_candidates = next_candidates
227+ print (f" Kept all { len (current_candidates )} candidates (< beam_width)" )
228+
229+ return current_candidates
185230
186231
187232def dynamic_beam_search (candidates : List [str ], scores : List [float ], beam_width : int ,
@@ -384,17 +429,23 @@ def select_mutate_positions(self) -> List[int]:
384429 positions = random .sample (range (seq_length ), min (self .args .mutations_per_iter , seq_length ))
385430 return sorted (positions )
386431
387- def generate_candidate_sequences (self , positions : List [int ]) -> Tuple [ List [str ], List [ List ] ]:
388- """Generate candidate mutated sequences.
432+ def generate_candidate_sequences (self , positions : List [int ]) -> List [str ]:
433+ """Generate candidate mutated sequences with incremental beam search .
389434
390435 Args:
391436 positions: Positions to mutate
392437
393438 Returns:
394- Tuple of (candidate_sequences, mutation_records)
439+ List of candidate sequences
395440 """
396- candidates , records = generate_mutations (self .input_sequence , positions )
397- return candidates , records
441+ # Use the new beam search approach
442+ candidates = generate_mutations_with_beam_search (
443+ initial_sequence = self .input_sequence ,
444+ positions = positions ,
445+ beam_width = self .args .beam_width ,
446+ score_fn = self .score_candidates
447+ )
448+ return candidates
398449
399450 def score_candidates (self , candidates : List [str ]) -> List [float ]:
400451 """Score candidates using LLM.
@@ -439,6 +490,9 @@ def run(self):
439490 self .current_sequences = [self .input_sequence ]
440491 self .current_scores = [0.0 ] # Initial score is 0 (not evaluated)
441492
493+ # Candidate pool for collecting diverse sequences
494+ candidate_pool = [] # List of (sequence, score) tuples
495+
442496 temperature = self .args .T_init
443497 best_sequence = self .input_sequence
444498 best_score = float ('-inf' )
@@ -451,9 +505,9 @@ def run(self):
451505 positions = self .select_mutate_positions ()
452506 print (f"Mutating positions: { positions } " )
453507
454- # Step 2: Generate candidate sequences
455- candidates , mutation_records = self .generate_candidate_sequences (positions )
456- print (f"Generated { len (candidates )} candidates" )
508+ # Step 2: Generate candidate sequences with incremental beam search
509+ candidates = self .generate_candidate_sequences (positions )
510+ print (f"Generated { len (candidates )} candidates (after beam search) " )
457511
458512 # Step 3: Score candidates with LLM
459513 scores = self .score_candidates (candidates )
@@ -491,16 +545,22 @@ def run(self):
491545 accepted_sequences .append (seq )
492546 accepted_scores .append (combined_score )
493547
548+ # Add to candidate pool
549+ candidate_pool .append ((seq , combined_score ))
550+
494551 # If no sequences accepted, keep the best beam candidate
495552 if not accepted_sequences :
496553 best_seq , best_s = beam_candidates [0 ]
554+ best_mfe = mfe_values [0 ]
555+ combined = best_s - 0.1 * best_mfe
497556 accepted_sequences = [best_seq ]
498- accepted_scores = [best_s ]
557+ accepted_scores = [combined ]
558+ candidate_pool .append ((best_seq , combined ))
499559 print (" (SA rejected all, keeping best beam candidate)" )
500560
501- # Update current pool
502- self .current_sequences = accepted_sequences [:self .args .output_count ]
503- self .current_scores = accepted_scores [:self .args .output_count ]
561+ # Update current pool for next iteration
562+ self .current_sequences = accepted_sequences [:self .args .beam_width ]
563+ self .current_scores = accepted_scores [:self .args .beam_width ]
504564
505565 # Track best
506566 if self .current_scores :
@@ -513,6 +573,8 @@ def run(self):
513573 best_sequence = current_best_seq
514574 print (f" New best: score={ best_score :.4f} " )
515575
576+ print (f" Candidate pool size: { len (candidate_pool )} " )
577+
516578 # Log iteration
517579 self .log_entries .append ({
518580 'iteration' : iteration + 1 ,
@@ -525,6 +587,34 @@ def run(self):
525587 # Cool down
526588 temperature = max (self .args .T_min , temperature * self .args .cooling_rate )
527589
590+ # Final: Select top output_count sequences from candidate pool
591+ print (f"\n { '=' * 60 } " )
592+ print ("Evolution Complete" )
593+ print (f"{ '=' * 60 } " )
594+ print (f"Best sequence score: { best_score :.4f} " )
595+ print (f"Total candidates in pool: { len (candidate_pool )} " )
596+
597+ # Sort by score and select top sequences
598+ candidate_pool .sort (key = lambda x : x [1 ], reverse = True )
599+
600+ # Remove duplicates while preserving order
601+ seen = set ()
602+ unique_candidates = []
603+ for seq , score in candidate_pool :
604+ if seq not in seen :
605+ seen .add (seq )
606+ unique_candidates .append ((seq , score ))
607+
608+ # Select top output_count
609+ final_candidates = unique_candidates [:self .args .output_count ]
610+
611+ self .current_sequences = [seq for seq , _ in final_candidates ]
612+ self .current_scores = [score for _ , score in final_candidates ]
613+
614+ print (f"Selected { len (self .current_sequences )} unique sequences for output" )
615+
616+ return self .current_sequences , self .current_scores
617+
528618 print (f"\n { '=' * 60 } " )
529619 print ("Evolution Complete" )
530620 print (f"{ '=' * 60 } " )
@@ -562,12 +652,20 @@ def save_log(self):
562652
563653def main ():
564654 """Main entry point."""
565- args = parse_args ()
655+ args , parser_defaults = parse_args ()
566656
567657 # Load config if provided
568658 if args .config :
569659 config = load_config (args .config )
570- merge_config_with_args (config , args )
660+ merge_config_with_args (config , args , parser_defaults )
661+
662+ # Validate required arguments
663+ if not args .input :
664+ raise ValueError ("--input is required (either via CLI or config file)" )
665+ if not args .output :
666+ raise ValueError ("--output is required (either via CLI or config file)" )
667+ if not args .checkpoint :
668+ raise ValueError ("--checkpoint is required (either via CLI or config file)" )
571669
572670 # Validate arguments
573671 if args .rna_type :
0 commit comments