Skip to content

Commit 5ba291f

Browse files
committed
Modified the signature of get_sa_coefficients()
1 parent 74c1af6 commit 5ba291f

File tree

5 files changed

+75
-9
lines changed

5 files changed

+75
-9
lines changed

t3/simulate/adapter.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
from abc import ABC, abstractmethod
8+
from typing import Optional
89

910

1011
class SimulateAdapter(ABC):
@@ -27,7 +28,12 @@ def simulate(self):
2728
pass
2829

2930
@abstractmethod
30-
def get_sa_coefficients(self):
31+
def get_sa_coefficients(self,
32+
top_SA_species: int = 10,
33+
top_SA_reactions: int = 10,
34+
max_workers: int = 24,
35+
save_yaml: bool = True,
36+
) -> Optional[dict]:
3137
"""
3238
Obtain the sensitivity analysis coefficients.
3339

t3/simulate/cantera_constantHP.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from rmgpy.tools.canteramodel import generate_cantera_conditions
1111
from rmgpy.tools.data import GenericData
1212

13+
from arc.common import save_yaml_file
14+
1315
from t3.common import get_observable_label_from_header, get_parameter_from_header
1416
from t3.logger import Logger
1517
from t3.simulate.adapter import SimulateAdapter
@@ -382,12 +384,23 @@ def simulate(self):
382384

383385
self.all_data.append((time, condition_data, reaction_sensitivity_data, thermodynamic_sensitivity_data))
384386

385-
def get_sa_coefficients(self):
387+
def get_sa_coefficients(self,
388+
top_SA_species: int = 10,
389+
top_SA_reactions: int = 10,
390+
max_workers: int = 24,
391+
save_yaml: bool = True,
392+
) -> Optional[dict]:
386393
"""
387394
Obtain the SA coefficients.
388395
396+
Args:
397+
top_SA_species (int, optional): The number of top sensitive species to return.
398+
top_SA_reactions (int, optional): The number of top sensitive reactions to return.
399+
max_workers (int, optional): The maximal number of workers to use for parallel processing.
400+
save_yaml (bool, optional): Save the SA dictionary to a YAML file.
401+
389402
Returns:
390-
sa_dict (dict): a SA dictionary, whose structure is given in the docstring for T3/t3/main.py
403+
sa_dict (Optional[dict]): a SA dictionary, whose structure is given in the docstring for T3/t3/main.py
391404
"""
392405
sa_dict = {'kinetics': dict(), 'thermo': dict(), 'time': list()}
393406

@@ -411,7 +424,8 @@ def get_sa_coefficients(self):
411424
sa_dict['thermo'][observable_label] = dict()
412425
parameter = get_parameter_from_header(spc)
413426
sa_dict['thermo'][observable_label][parameter] = spc.data
414-
427+
if save_yaml:
428+
save_yaml_file(path=self.paths['SA dict'], content=sa_dict)
415429
return sa_dict
416430

417431
def get_idt_by_T(self):

t3/simulate/cantera_constantTP.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from rmgpy.tools.canteramodel import generate_cantera_conditions
1111
from rmgpy.tools.data import GenericData
1212

13+
from arc.common import save_yaml_file
14+
1315
from t3.logger import Logger
1416
from t3.simulate.adapter import SimulateAdapter
1517
from t3.simulate.factory import register_simulate_adapter
@@ -381,12 +383,23 @@ def simulate(self):
381383

382384
self.all_data.append((time, condition_data, reaction_sensitivity_data, thermodynamic_sensitivity_data))
383385

384-
def get_sa_coefficients(self):
386+
def get_sa_coefficients(self,
387+
top_SA_species: int = 10,
388+
top_SA_reactions: int = 10,
389+
max_workers: int = 24,
390+
save_yaml: bool = True,
391+
) -> Optional[dict]:
385392
"""
386393
Obtain the SA coefficients.
387394
395+
Args:
396+
top_SA_species (int, optional): The number of top sensitive species to return.
397+
top_SA_reactions (int, optional): The number of top sensitive reactions to return.
398+
max_workers (int, optional): The maximal number of workers to use for parallel processing.
399+
save_yaml (bool, optional): Save the SA dictionary to a YAML file.
400+
388401
Returns:
389-
sa_dict (dict): a SA dictionary, whose structure is given in the docstring for T3/t3/main.py
402+
sa_dict (Optional[dict]): a SA dictionary, whose structure is given in the docstring for T3/t3/main.py
390403
"""
391404
sa_dict = {'kinetics': dict(), 'thermo': dict(), 'time': list()}
392405

@@ -415,6 +428,9 @@ def get_sa_coefficients(self):
415428
parameter = spc.label.split('[')[2].split(']')[0]
416429
sa_dict['thermo'][observable_label][parameter] = spc.data
417430

431+
if save_yaml:
432+
save_yaml_file(path=self.paths['SA dict'], content=sa_dict)
433+
418434
return sa_dict
419435

420436
def get_idt_by_T(self):

t3/simulate/cantera_constantUV.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from rmgpy.tools.canteramodel import generate_cantera_conditions
1111
from rmgpy.tools.data import GenericData
1212

13+
from arc.common import save_yaml_file
14+
1315
from t3.logger import Logger
1416
from t3.simulate.adapter import SimulateAdapter
1517
from t3.simulate.factory import register_simulate_adapter
@@ -381,12 +383,23 @@ def simulate(self):
381383

382384
self.all_data.append((time, condition_data, reaction_sensitivity_data, thermodynamic_sensitivity_data))
383385

384-
def get_sa_coefficients(self):
386+
def get_sa_coefficients(self,
387+
top_SA_species: int = 10,
388+
top_SA_reactions: int = 10,
389+
max_workers: int = 24,
390+
save_yaml: bool = True,
391+
) -> Optional[dict]:
385392
"""
386393
Obtain the SA coefficients.
387394
395+
Args:
396+
top_SA_species (int, optional): The number of top sensitive species to return.
397+
top_SA_reactions (int, optional): The number of top sensitive reactions to return.
398+
max_workers (int, optional): The maximal number of workers to use for parallel processing.
399+
save_yaml (bool, optional): Save the SA dictionary to a YAML file.
400+
388401
Returns:
389-
sa_dict (dict): a SA dictionary, whose structure is given in the docstring for T3/t3/main.py
402+
sa_dict (Optional[dict]): a SA dictionary, whose structure is given in the docstring for T3/t3/main.py
390403
"""
391404
sa_dict = {'kinetics': dict(), 'thermo': dict(), 'time': list()}
392405

@@ -415,6 +428,8 @@ def get_sa_coefficients(self):
415428
parameter = spc.label.split('[')[2].split(']')[0]
416429
sa_dict['thermo'][observable_label][parameter] = spc.data
417430

431+
if save_yaml:
432+
save_yaml_file(path=self.paths['SA dict'], content=sa_dict)
418433
return sa_dict
419434

420435
def get_idt_by_T(self):

t3/simulate/rmg_constantTP.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from rmgpy.tools.loader import load_rmg_py_job
1919
from rmgpy.tools.plot import plot_sensitivity
2020

21+
from arc.common import save_yaml_file
22+
2123
from t3.common import get_chem_to_rmg_rxn_index_map, get_species_by_label, get_values_within_range, \
2224
get_observable_label_from_header, get_parameter_from_header, time_lapse
2325
from t3.simulate.adapter import SimulateAdapter
@@ -224,10 +226,21 @@ def simulate(self):
224226

225227
self.logger.info(f'Simulation via RMG completed, execution time: {time_lapse(tic)}')
226228

227-
def get_sa_coefficients(self) -> Optional[dict]:
229+
def get_sa_coefficients(self,
230+
top_SA_species: int = 10,
231+
top_SA_reactions: int = 10,
232+
max_workers: int = 24,
233+
save_yaml: bool = True,
234+
) -> Optional[dict]:
228235
"""
229236
Obtain the SA coefficients.
230237
238+
Args:
239+
top_SA_species (int, optional): The number of top sensitive species to return.
240+
top_SA_reactions (int, optional): The number of top sensitive reactions to return.
241+
max_workers (int, optional): The maximal number of workers to use for parallel processing.
242+
save_yaml (bool, optional): Save the SA dictionary to a YAML file.
243+
231244
Returns:
232245
sa_dict (Optional[dict]): An SA dictionary, structure is given in the docstring for T3/t3/main.py
233246
"""
@@ -265,6 +278,8 @@ def get_sa_coefficients(self) -> Optional[dict]:
265278
parameter = chem_to_rmg_rxn_index_map[int(parameter)] \
266279
if all(c.isdigit() for c in parameter) else parameter
267280
sa_dict[sa_type][observable_label][parameter] = df[header].values
281+
if save_yaml:
282+
save_yaml_file(path=self.paths['SA dict'], content=sa_dict)
268283
return sa_dict
269284

270285
def generate_rmg_reactors_for_simulation(self) -> dict:

0 commit comments

Comments
 (0)