Skip to content

reader

DB utilities for reading artifact-related data from the database.

DBReader

Class encapsulating functions that read artifact-related data from the database.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
class DBReader:
    """Class encapsulating functions to read artifact related data from the DB.

    Every method is a ``staticmethod`` that receives an explicit SQLAlchemy
    ``Session``; the reader only issues ``SELECT`` statements through it.
    """

    @staticmethod
    def get_model(model_id: str, session: Session) -> tuple[Model, DBModel]:
        """Read the model with the given identifier.

        Args:
            model_id: Model identifier, matched against ``DBModel.name``.
            session: Open session used to run the query.

        Returns:
            A ``(Model, DBModel)`` pair: the internal model object (with the
            identifiers of all of its versions) and the underlying ORM row.

        Raises:
            errors.ErrorNotFound: If no model with that identifier exists.
        """
        model_orm = session.scalar(
            select(DBModel).where(DBModel.name == model_id)
        )
        if model_orm is None:
            raise errors.ErrorNotFound(
                f"Model with identifier {model_id} was not found in the artifact store."
            )

        model = Model(
            identifier=model_orm.name,
            versions=[
                Version(identifier=version_orm.name)
                for version_orm in model_orm.versions
            ],
        )
        return model, model_orm

    @staticmethod
    def get_version(
        model_id: str,
        version_id: str,
        session: Session,
    ) -> tuple[Version, DBVersion]:
        """Read the version with the given identifier that belongs to a model.

        Args:
            model_id: Identifier of the owning model (``DBModel.name``).
            version_id: Version identifier (``DBVersion.name``).
            session: Open session used to run the query.

        Returns:
            A ``(Version, DBVersion)`` pair: the internal version object and
            the underlying ORM row.

        Raises:
            errors.ErrorNotFound: If the model has no version with that
                identifier (or the model itself does not exist).
        """
        version_orm = session.scalar(
            select(DBVersion)
            .where(DBVersion.name == version_id)
            .where(DBVersion.model_id == DBModel.id)
            .where(DBModel.name == model_id)
        )
        if version_orm is None:
            raise errors.ErrorNotFound(
                f"Version with identifier {version_id}  and associated to model {model_id} was not found in the artifact store."
            )

        return Version(identifier=version_orm.name), version_orm

    @staticmethod
    def get_artifact_type(
        type: ArtifactType, session: Session
    ) -> DBArtifactType:
        """Get the artifact type DB object corresponding to the given internal type.

        Args:
            type: Internal artifact type, matched against ``DBArtifactType.name``.
                (The parameter shadows the ``type`` builtin; the name is kept
                for caller compatibility.)
            session: Open session used to run the query.

        Returns:
            The matching ``DBArtifactType`` row.

        Raises:
            Exception: If no artifact type row matches ``type``.
        """
        artifact_type_orm = session.scalar(
            select(DBArtifactType).where(DBArtifactType.name == type)
        )
        if artifact_type_orm is None:
            raise Exception(f"Unknown artifact type requested: {type}")
        return artifact_type_orm

    @staticmethod
    def _get_model_level_artifacts_stmt(
        model_id: str,
    ) -> Select[tuple[DBArtifact]]:
        """Build a statement selecting artifacts attached to a model but to no version."""
        return (
            select(DBArtifact)
            .where(DBArtifact.model_id == DBModel.id)
            .where(DBModel.name == model_id)
            # Model-level artifacts are the ones without a version.
            .where(DBArtifact.version_id.is_(None))
        )

    @staticmethod
    def _get_version_artifacts_stmt(
        model_id: str,
        version_id: str,
    ) -> Select[tuple[DBArtifact]]:
        """Build a statement selecting artifacts attached to a specific model version."""
        return (
            select(DBArtifact)
            .where(DBArtifact.version_id == DBVersion.id)
            .where(DBVersion.name == version_id)
            .where(DBVersion.model_id == DBModel.id)
            .where(DBModel.name == model_id)
        )

    @staticmethod
    def get_artifact(
        model_id: str,
        version_id: str,
        artifact_id: str,
        session: Session,
    ) -> tuple[ArtifactModel, DBArtifact]:
        """Read the artifact with the given identifier for a model/version.

        The artifact is first searched among the version's artifacts; if it is
        not found there, the model-level artifacts are searched.

        Args:
            model_id: Identifier of the owning model.
            version_id: Identifier of the version to search first.
            artifact_id: Artifact identifier (``DBArtifact.identifier``).
            session: Open session used to run the queries.

        Returns:
            An ``(ArtifactModel, DBArtifact)`` pair: the internal artifact
            model built by ``main_factory`` and the underlying ORM row.

        Raises:
            errors.ErrorNotFound: If the artifact exists at neither level.
        """
        identifier_matches = DBArtifact.identifier == artifact_id

        # Version-level artifacts take precedence over model-level ones.
        artifact_orm: Optional[DBArtifact] = session.scalar(
            DBReader._get_version_artifacts_stmt(model_id, version_id).where(
                identifier_matches
            )
        )
        if artifact_orm is None:
            artifact_orm = session.scalar(
                DBReader._get_model_level_artifacts_stmt(model_id).where(
                    identifier_matches
                )
            )
        if artifact_orm is None:
            raise errors.ErrorNotFound(
                f"Artifact with identifier {artifact_id}  and associated to model {model_id}, and version {version_id} was not found in the artifact store."
            )

        return main_factory.create_artifact_model(artifact_orm), artifact_orm

    @staticmethod
    def get_artifacts(
        model_id: str,
        version_id: str,
        session: Session,
    ) -> list[ArtifactModel]:
        """Load all artifacts visible from the given model/version.

        Returns the version-level artifacts followed by the model-level
        artifacts (those not attached to any version).
        """
        version_stmt = DBReader._get_version_artifacts_stmt(model_id, version_id)
        artifact_models = DBReader._load_artifact_models(version_stmt, session)

        model_stmt = DBReader._get_model_level_artifacts_stmt(model_id)
        artifact_models.extend(
            DBReader._load_artifact_models(model_stmt, session)
        )

        return artifact_models

    @staticmethod
    def _load_artifact_models(
        select_stmt: Select[tuple[DBArtifact]], session: Session
    ) -> list[ArtifactModel]:
        """Run ``select_stmt`` and convert each resulting row into an ArtifactModel."""
        artifact_orms: ScalarResult[DBArtifact] = session.scalars(select_stmt)
        return [
            main_factory.create_artifact_model(artifact_orm)
            for artifact_orm in artifact_orms
        ]

get_artifact(model_id, version_id, artifact_id, session) staticmethod

Reads the artifact with the given identifier using the provided session (searching the version first, then the model level), and returns a tuple of the internal artifact model and its DB object.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
@staticmethod
def get_artifact(
    model_id: str,
    version_id: str,
    artifact_id: str,
    session: Session,
) -> tuple[ArtifactModel, DBArtifact]:
    """Look up an artifact by identifier, preferring the version over the model.

    The version-level artifacts are queried first; the model-level ones are
    queried only if nothing matched. Returns the internal artifact model
    together with its ORM row, or raises ``errors.ErrorNotFound``.
    """
    search_order = (
        DBReader._get_version_artifacts_stmt(model_id, version_id),
        DBReader._get_model_level_artifacts_stmt(model_id),
    )
    for scope_stmt in search_order:
        hit: Optional[DBArtifact] = session.scalar(
            scope_stmt.where(DBArtifact.identifier == artifact_id)
        )
        if hit is not None:
            return main_factory.create_artifact_model(hit), hit

    raise errors.ErrorNotFound(
        f"Artifact with identifier {artifact_id}  and associated to model {model_id}, and version {version_id} was not found in the artifact store."
    )

get_artifact_type(type, session) staticmethod

Gets the artifact type DB object corresponding to the given internal type.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
70
71
72
73
74
75
76
77
78
79
80
81
@staticmethod
def get_artifact_type(
    type: ArtifactType, session: Session
) -> DBArtifactType:
    """Return the ``DBArtifactType`` row whose name equals the given type.

    Raises a plain ``Exception`` when the type is not present in the DB.
    """
    lookup = select(DBArtifactType).where(DBArtifactType.name == type)
    found = session.scalar(lookup)
    if found is not None:
        return found
    raise Exception(f"Unknown artifact type requested: {type}")

get_artifacts(model_id, version_id, session) staticmethod

Loads and returns a list of all artifacts for the given model/version, including the model-level artifacts that are not tied to any version.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
@staticmethod
def get_artifacts(
    model_id: str,
    version_id: str,
    session: Session,
) -> list[ArtifactModel]:
    """Collect every artifact for a model/version.

    Version-level artifacts come first, followed by the artifacts attached
    to the model itself (those without a version).
    """
    version_scope = DBReader._get_version_artifacts_stmt(model_id, version_id)
    model_scope = DBReader._get_model_level_artifacts_stmt(model_id)
    return [
        *DBReader._load_artifact_models(version_scope, session),
        *DBReader._load_artifact_models(model_scope, session),
    ]

get_model(model_id, session) staticmethod

Reads the model with the given identifier using the provided session, and returns a Model and DBModel object.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
@staticmethod
def get_model(model_id: str, session: Session) -> tuple[Model, DBModel]:
    """Fetch a model by name and return it as a ``(Model, DBModel)`` pair.

    The internal ``Model`` lists the identifiers of all of its versions.
    Raises ``errors.ErrorNotFound`` if no model has that name.
    """
    found = session.scalar(select(DBModel).where(DBModel.name == model_id))
    if found is None:
        raise errors.ErrorNotFound(
            f"Model with identifier {model_id} was not found in the artifact store."
        )

    version_list = [Version(identifier=v.name) for v in found.versions]
    return Model(identifier=found.name, versions=version_list), found

get_version(model_id, version_id, session) staticmethod

Reads the version with the given identifier using the provided session, and returns a Version and DBVersion object. Raises ErrorNotFound if not found.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
@staticmethod
def get_version(
    model_id: str,
    version_id: str,
    session: Session,
) -> tuple[Version, DBVersion]:
    """Fetch a model's version by name as a ``(Version, DBVersion)`` pair.

    Raises ``errors.ErrorNotFound`` when the model has no such version.
    """
    found = session.scalar(
        select(DBVersion).where(
            DBVersion.name == version_id,
            DBVersion.model_id == DBModel.id,
            DBModel.name == model_id,
        )
    )
    if found is None:
        raise errors.ErrorNotFound(
            f"Version with identifier {version_id}  and associated to model {model_id} was not found in the artifact store."
        )
    return Version(identifier=found.name), found