@@ -150,29 +150,46 @@ def __init__(self, sequence: str):
150150 self .sequence = sequence
151151 self .mutations = self .initialize (sequence )
152152
153- def initialize (self , sequence : str ) -> dict [int , list [str ]]:
153+ def initialize (self , sequence : str ) -> dict [int , set [str ]]:
154154 """Initialize with no changes allowed to the sequence."""
155- return {i : [ aa ] for i , aa in enumerate (sequence , start = 1 )}
155+ return {i : { aa } for i , aa in enumerate (sequence , start = 1 )}
156156
157- def allow (self , positions : int | list [int ], amino_acids : list [str ] | str ) -> None :
157+ def allow (
158+ self ,
159+ amino_acids : list [str ] | str | None = None ,
160+ positions : int | list [int ] | None = None ,
161+ ) -> None :
158162 """Allow specific amino acids at given positions."""
159163 if isinstance (positions , int ):
160164 positions = [positions ]
165+ elif positions is None :
166+ positions = [i + 1 for i in range (len (self .sequence ))]
161167 if isinstance (amino_acids , str ):
162168 amino_acids = list (amino_acids )
169+ elif amino_acids is None :
170+ amino_acids = list (self .sequence )
163171
164172 for position in positions :
165173 if position in self .mutations :
166- self .mutations [position ].extend (amino_acids )
174+ for aa in amino_acids :
175+ self .mutations [position ].add (aa )
167176 else :
168- self .mutations [position ] = amino_acids
177+ self .mutations [position ] = set ( amino_acids )
169178
170- def remove (self , positions : int | list [int ], amino_acids : list [str ] | str ) -> None :
179+ def remove (
180+ self ,
181+ amino_acids : list [str ] | str | None = None ,
182+ positions : int | list [int ] | None = None ,
183+ ) -> None :
171184 """Remove specific amino acids from being allowed at given positions."""
172185 if isinstance (positions , int ):
173186 positions = [positions ]
187+ elif positions is None :
188+ positions = [i + 1 for i in range (len (self .sequence ))]
174189 if isinstance (amino_acids , str ):
175190 amino_acids = list (amino_acids )
191+ elif amino_acids is None :
192+ amino_acids = list (self .sequence )
176193
177194 for position in positions :
178195 if position in self .mutations :
@@ -182,4 +199,4 @@ def remove(self, positions: int | list[int], amino_acids: list[str] | str) -> No
182199
183200 def as_dict (self ) -> dict [int , list [str ]]:
184201 """Convert the internal mutations representation into a dictionary."""
185- return self .mutations
202+ return { i : list ( aa ) for i , aa in self .mutations . items ()}
0 commit comments