Source code for torchsig.models.model_utils.layer_tools

def get_layer_list(model):
    """
    Return the sub-modules of ``model`` as a flat list.

    The list is built from ``model.modules()``. When that call yields more
    than one entry, the first entry is dropped (for a ``torch.nn.Module``
    the first entry is the model itself) and the remaining entries are
    returned in the order ``modules()`` produced them, nested modules
    included. When it yields zero or one entry, that list is returned
    unchanged, so a single layer without children gives ``[layer]``.

    Args:
        model: object exposing a ``modules()`` method that returns an iterable.

    Returns:
        list: the enumerated modules as described above.

    Raises:
        NotImplementedError: if ``model.modules()`` cannot be enumerated,
            e.g. because ``model`` has no such method.
    """
    try:
        found = list(model.modules())
    except Exception as err:
        # Keep the historical exception type and message for callers, but
        # chain the real cause and let KeyboardInterrupt/SystemExit through.
        raise NotImplementedError(
            "expected module list to be populated, but no '_modules' field was found"
        ) from err

    # Each sub-module is kept as one list element. The previous
    # ``final_arr += module`` tried to iterate the module itself, which
    # fails for any non-iterable layer.
    return found[1:] if len(found) > 1 else found
def replace_layer(old_layer, new_layer, model):
    """
    Replace the first occurrence of ``old_layer`` inside ``model``.

    The ``_modules`` mapping of ``model`` is walked depth-first. Each child
    is compared to ``old_layer`` with ``==``; on the first match the entry
    is overwritten with ``new_layer`` and the search stops. Children that
    do not match are searched recursively.

    Args:
        old_layer: the layer to look for.
        new_layer: the object stored in place of ``old_layer``.
        model: object exposing a ``_modules`` mapping of name -> child.

    Returns:
        bool: True if ``old_layer`` was found and replaced, False otherwise.

    Raises:
        NotImplementedError: if ``model`` (or a visited child) has no
            ``_modules`` attribute.
    """
    # Only the attribute lookup is guarded, so unrelated errors are not
    # reported as a missing '_modules' field.
    try:
        children = model._modules
    except AttributeError as err:
        raise NotImplementedError(
            "expected module list to be populated, but no '_modules' field was found"
        ) from err

    for name, child in children.items():
        if child == old_layer:
            children[name] = new_layer
            return True
        if replace_layer(old_layer, new_layer, child):
            return True
    return False
def is_same_type(layer1, layer2):
    """
    Tell whether ``layer1`` has the type described by ``layer2``.

    ``layer2`` may be given in three forms:

    * a class (e.g. ``is_same_type(my_conv_layer, Conv2d)``): the type of
      ``layer1`` must be exactly that class;
    * a string: it must equal the class name of ``layer1``
      (``type(layer1).__name__``);
    * any other object: ``layer1`` and ``layer2`` must have exactly the
      same type.

    Matching is exact: an instance of a subclass does not match its parent
    class.

    Returns:
        bool: True when the types match as described, False otherwise.
    """
    actual = type(layer1)
    # isinstance(..., type) also recognises classes built with a custom
    # metaclass (e.g. abc.ABCMeta), which ``type(layer2) == type`` missed.
    if isinstance(layer2, type):
        return actual == layer2
    if isinstance(layer2, str):
        return actual.__name__ == layer2
    return actual == type(layer2)
def same_type_fn(layer1):
    """
    Build a one-argument predicate bound to ``layer1``.

    The returned function ``f`` satisfies
    ``f(x) == is_same_type(x, layer1)``: ``layer1`` plays the role of the
    reference (a class, a class name, or an example layer) and ``x`` is the
    layer being tested.
    """

    def matches_reference(candidate):
        return is_same_type(candidate, layer1)

    return matches_reference
def replace_layers_on_condition(model, condition_fn, new_layer_factory_fn):
    """
    Replace every layer of ``model`` for which ``condition_fn`` is true.

    The ``_modules`` mapping of ``model`` is walked depth-first. A child
    ``L`` with ``condition_fn(L)`` true is overwritten by
    ``new_layer_factory_fn(L)`` and is not searched any further; every other
    child is searched recursively.

    Args:
        model: object exposing a ``_modules`` mapping of name -> child.
        condition_fn: callable taking a layer and returning a truthy value
            when that layer must be replaced.
        new_layer_factory_fn: callable taking the matched layer and
            returning its replacement.

    Returns:
        bool: True if at least one layer was replaced, False otherwise.

    Raises:
        NotImplementedError: if ``model`` (or a visited child) has no
            ``_modules`` attribute. Exceptions raised by ``condition_fn`` or
            ``new_layer_factory_fn`` propagate unchanged.
    """
    # Guard only the attribute lookup so that failures inside the user
    # callbacks are not reported as a missing '_modules' field.
    try:
        children = model._modules
    except AttributeError as err:
        raise NotImplementedError(
            "expected module list to be populated, but no '_modules' field was found"
        ) from err

    has_replaced = False
    for name, child in children.items():
        if condition_fn(child):
            children[name] = new_layer_factory_fn(child)
            has_replaced = True
        elif replace_layers_on_condition(child, condition_fn, new_layer_factory_fn):
            has_replaced = True
    return has_replaced
def replace_layers_on_conditions(model, condition_factory_pairs):
    """
    Replace layers of ``model`` according to an ordered list of rules.

    ``condition_factory_pairs`` is an iterable of
    ``(condition_fn, new_layer_factory_fn)`` pairs. The ``_modules`` mapping
    of ``model`` is walked depth-first. For each child ``L`` the pairs are
    tried in order; the first pair whose ``condition_fn(L)`` is true
    replaces ``L`` with ``new_layer_factory_fn(L)`` and no later pair is
    consulted for that layer. A child matched by no pair is searched
    recursively; a replaced child is not searched.

    Args:
        model: object exposing a ``_modules`` mapping of name -> child.
        condition_factory_pairs: iterable of
            ``(condition_fn, new_layer_factory_fn)`` pairs.

    Returns:
        bool: True if at least one layer was replaced, False otherwise.

    Raises:
        NotImplementedError: if ``model`` (or a visited child) has no
            ``_modules`` attribute. Exceptions raised by the supplied
            callables propagate unchanged.
    """
    # Guard only the attribute lookup so that failures inside the user
    # callbacks are not reported as a missing '_modules' field.
    try:
        children = model._modules
    except AttributeError as err:
        raise NotImplementedError(
            "expected module list to be populated, but no '_modules' field was found"
        ) from err

    # Materialise once: the pairs are re-scanned for every child and passed
    # down the recursion, which a one-shot iterator would not survive.
    pairs = list(condition_factory_pairs)

    has_replaced = False
    for name, child in children.items():
        for condition_fn, new_layer_factory_fn in pairs:
            if condition_fn(child):
                children[name] = new_layer_factory_fn(child)
                has_replaced = True
                break
        else:
            # No rule matched this child: look inside it instead.
            if replace_layers_on_conditions(child, pairs):
                has_replaced = True
    return has_replaced
def replace_layers_of_type(model, layer_type, new_layer_factory_fn):
    """
    Replace every layer of ``model`` whose type matches ``layer_type``.

    A layer ``L`` is selected when ``is_same_type(L, layer_type)`` is true
    and is then replaced by ``new_layer_factory_fn(L)``; the traversal and
    replacement are delegated to ``replace_layers_on_condition``.

    Returns:
        bool: True if at least one layer was replaced, False otherwise.
    """

    def has_requested_type(layer):
        return is_same_type(layer, layer_type)

    return replace_layers_on_condition(model, has_requested_type, new_layer_factory_fn)
def replace_layers_of_types(model, type_factory_pairs):
    """
    Replace layers of ``model`` according to ``(type, factory)`` rules.

    Each ``(layer_type, new_layer_factory_fn)`` pair in
    ``type_factory_pairs`` is turned into a
    ``(same_type_fn(layer_type), new_layer_factory_fn)`` pair, and the
    resulting list is handed to ``replace_layers_on_conditions``, which
    performs the traversal and replacement.

    Returns:
        bool: True if at least one layer was replaced, False otherwise.
    """
    condition_factory_pairs = []
    for layer_type, new_layer_factory_fn in type_factory_pairs:
        type_predicate = same_type_fn(layer_type)
        condition_factory_pairs.append((type_predicate, new_layer_factory_fn))
    return replace_layers_on_conditions(model, condition_factory_pairs)