Finding all subclasses dynamically in Python

Tags:

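Python provides `__subclasses__()` to find the subclasses of a class, but it only returns direct subclasses that already exist — a class whose module hasn’t been imported yet won’t show up. One solution is to import the subclass modules dynamically.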

Assuming that the ‘impls’ subpackage contains the subclasses, this is how one can do that. It’s a dump of my current code, but you’ll get the idea.

import importlib
import logging
import pkgutil
from collections import Counter
from pathlib import Path
from typing import Type

from coinwhip.model.prediction_model import PredictionModel
from coinwhip.web.chart.renderer import ChartKeys, ChartRenderer

# Logger for this module, named after the module itself.
logger: logging.Logger = logging.getLogger(name=__name__)

# This is the main function to load subclasses given a parent class. It
# assumes the subclasses live in an 'impls' package that sits next to the
# module defining the parent class.
def import_modules(parent_klass: type) -> None:
    """Import every module of the ``impls`` package next to *parent_klass*.

    ``type.__subclasses__()`` only reports classes whose defining module has
    already been imported, so this imports each module of
    ``<parent package>.impls`` (e.g. ``a.b.impls`` for a class defined in
    ``a.b.base``). Modules whose name ends in ``_test`` are skipped.

    Args:
        parent_klass: The base class whose implementations should be loaded.

    Raises:
        RuntimeError: If the ``impls`` package does not exist or contains no
            modules.
    """
    parent_module = parent_klass.__module__
    # rpartition (unlike rindex) does not raise for a top-level module; in
    # that case the impls package is looked up at the top level.
    parent_pkg = parent_module.rpartition(".")[0]
    impls_pkg = f"{parent_pkg}.impls" if parent_pkg else "impls"
    try:
        impls = importlib.import_module(impls_pkg)
    except ModuleNotFoundError as err:
        if err.name != impls_pkg:
            # Some other module is missing (e.g. imported by impls/__init__).
            raise
        raise RuntimeError(f"Cannot import package {impls_pkg}") from err
    # Locate the package through the import system instead of a
    # CWD-relative "src/..." path, so this works from any working directory.
    impls_paths = list(getattr(impls, "__path__", []))
    # iter_modules returns a lazy generator, which is always truthy;
    # materialize it so the emptiness check can actually fire.
    modules = list(pkgutil.iter_modules(impls_paths))
    if not modules:
        raise RuntimeError(f"No modules found in {impls_paths}")
    for module_info in modules:
        module_name = module_info.name
        if module_name.endswith("_test") or module_name == "__init__":
            continue
        qual_module_name = f"{impls_pkg}.{module_name}"
        logger.info("Dynamically importing %s", qual_module_name)
        importlib.import_module(qual_module_name)


# Usage example
def get_all_chart_keys() -> list[str]:
    """Return the member values of every direct subclass of ``ChartKeys``.

    ``ChartKeys`` and ``ChartRenderer`` are imported from the same module,
    so importing the ``ChartRenderer`` implementations loads the ``impls``
    modules where the ``ChartKeys`` subclasses are expected to live.

    Raises:
        ValueError: If a value appears more than once across the subclasses.
            The message lists only the duplicated values.
    """
    import_modules(ChartRenderer)
    # __subclasses__() lists direct subclasses only. Each one is iterated as
    # an enum (my own biz logic) to collect its member values.
    keys = [
        member.value
        for subclass in ChartKeys.__subclasses__()
        for member in subclass
    ]
    # My own biz logic: make sure there's no duplicate enum value. Report
    # only the offending values rather than the whole key list.
    duplicates = [key for key, count in Counter(keys).items() if count > 1]
    if duplicates:
        raise ValueError(f"Duplicate chart keys found: {sorted(duplicates)}")
    return keys

# Another usage example. Here, we find all leaf classes
def get_all_prediction_models() -> list[Type[PredictionModel]]:
    """Return every leaf class in the ``PredictionModel`` hierarchy.

    Implementation modules are imported first so their classes are
    registered. The hierarchy is then walked depth-first with an explicit
    stack. A class with no subclasses is a leaf. If ``PredictionModel`` itself
    has no subclasses, the result is ``[PredictionModel]``.

    Each leaf is reported once, even when multiple inheritance makes it
    reachable through several parents.
    """
    import_modules(PredictionModel)
    leaves: list[Type[PredictionModel]] = []
    seen: set[type] = {PredictionModel}
    stack: list[Type[PredictionModel]] = [PredictionModel]
    while stack:
        cur = stack.pop()
        children = cur.__subclasses__()
        if not children:
            leaves.append(cur)
            continue
        for child in children:
            # A class with several parents in the hierarchy is listed by each
            # of them; visit it only once so it is not reported twice.
            if child not in seen:
                seen.add(child)
                stack.append(child)
    return leaves