Skip to content

reader

mlte/store/artifact/underlying/rdbs/reader.py

DB utils for getting artifact related data from the DB.

DBReader

Class encapsulating functions to read artifact related data from the DB.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
class DBReader:
    """Class encapsulating functions to read artifact related data from the DB.

    All methods are static and receive the DB session to query with as an explicit argument.
    """

    # Maps each supported artifact type to the ORM class holding its type-specific data.
    SUPPORTED_ARTIFACT_DB_CLASSES: dict[
        ArtifactType,
        Union[
            type[DBSpec],
            type[DBValidatedSpec],
            type[DBNegotiationCard],
            type[DBReport],
            type[DBValue],
        ],
    ] = {
        ArtifactType.SPEC: DBSpec,
        ArtifactType.VALIDATED_SPEC: DBValidatedSpec,
        ArtifactType.NEGOTIATION_CARD: DBNegotiationCard,
        ArtifactType.REPORT: DBReport,
        ArtifactType.VALUE: DBValue,
    }

    @staticmethod
    def get_model(model_id: str, session: Session) -> Tuple[Model, DBModel]:
        """Reads the model with the given identifier using the provided session, and returns a Model and DBModel object.

        :param model_id: Name of the model to read (matched against DBModel.name).
        :param session: DB session used for the query.
        :return: The internal Model, listing the names of all its versions, and the DBModel row.
        :raises ErrorNotFound: If no model with that name exists.
        """
        db_model = session.scalar(
            select(DBModel).where(DBModel.name == model_id)
        )
        if db_model is None:
            raise errors.ErrorNotFound(
                f"Model with identifier {model_id} was not found in the artifact store."
            )

        versions = [
            Version(identifier=db_version.name)
            for db_version in db_model.versions
        ]
        return Model(identifier=db_model.name, versions=versions), db_model

    @staticmethod
    def get_version(
        model_id: str,
        version_id: str,
        session: Session,
    ) -> Tuple[Version, DBVersion]:
        """Reads the version with the given identifier using the provided session, and returns a Version and DBVersion object. Raises ErrorNotFound if not found.

        :param model_id: Name of the model owning the version.
        :param version_id: Name of the version (matched against DBVersion.name).
        :param session: DB session used for the query.
        :return: The internal Version and the DBVersion row.
        :raises ErrorNotFound: If the version does not exist for that model.
        """
        query = (
            select(DBVersion)
            .where(DBVersion.name == version_id)
            .where(DBVersion.model_id == DBModel.id)
            .where(DBModel.name == model_id)
        )
        db_version = session.scalar(query)
        if db_version is None:
            raise errors.ErrorNotFound(
                f"Version with identifier {version_id}  and associated to model {model_id} was not found in the artifact store."
            )
        return Version(identifier=db_version.name), db_version

    @staticmethod
    def get_artifact_type(
        type: ArtifactType, session: Session
    ) -> DBArtifactType:
        """Gets the artifact type DB object corresponding to the given internal type.

        :param type: Internal artifact type, compared against DBArtifactType.name.
        :param session: DB session used for the query.
        :return: The matching DBArtifactType row.
        :raises Exception: If no row exists for the given type.
        """
        db_artifact_type = session.scalar(
            select(DBArtifactType).where(DBArtifactType.name == type)
        )
        if db_artifact_type is not None:
            return db_artifact_type
        raise Exception(f"Unknown artifact type requested: {type}")

    @staticmethod
    def get_artifact(
        model_id: str,
        version_id: str,
        artifact_id: str,
        session: Session,
    ) -> Tuple[
        ArtifactModel,
        Union[DBSpec, DBValidatedSpec, DBNegotiationCard, DBReport, DBValue],
    ]:
        """Reads the artifact with the given identifier for the given model and version.

        :param model_id: Name of the model owning the artifact.
        :param version_id: Name of the version owning the artifact.
        :param artifact_id: Identifier of the artifact (DBArtifactHeader.identifier).
        :param session: DB session used for the queries.
        :return: The internal artifact model built from the artifact's header, and the typed DB artifact object.
        :raises ErrorNotFound: If the artifact does not exist in the given model/version.
        """
        not_found_message = f"Artifact with identifier {artifact_id}  and associated to model {model_id}, and version {version_id} was not found in the artifact store."

        # Resolve the header scoped to the requested model and version. A lookup by identifier
        # alone returns the first match in the whole store. If the same identifier is used in
        # another model/version, that foreign header would decide the artifact type and the
        # internal model returned below.
        artifact_header_obj = session.scalar(
            select(DBArtifactHeader)
            .where(DBArtifactHeader.identifier == artifact_id)
            .where(DBArtifactHeader.version_id == DBVersion.id)
            .where(DBVersion.name == version_id)
            .where(DBVersion.model_id == DBModel.id)
            .where(DBModel.name == model_id)
        )
        if artifact_header_obj is None:
            raise errors.ErrorNotFound(not_found_message)

        # The header's type tells which typed table holds the artifact's data.
        artifact_type = ArtifactType(artifact_header_obj.type.name)
        artifact_class = DBReader.get_artifact_class(artifact_type)
        artifact_obj: Union[
            DBSpec, DBValidatedSpec, DBNegotiationCard, DBReport, DBValue
        ] = session.scalar(
            select(artifact_class).where(
                artifact_class.artifact_header_id == artifact_header_obj.id
            )
        )
        if artifact_obj is None:
            raise errors.ErrorNotFound(not_found_message)

        return (
            factory.create_artifact_from_db(artifact_header_obj, session),
            artifact_obj,
        )

    @staticmethod
    def get_artifacts_for_type(
        model_id: str,
        version_id: str,
        artifact_type: ArtifactType,
        session: Session,
    ) -> List[ArtifactModel]:
        """Loads and returns a list with all the artifacts of the given type, for the given model/version.

        :param model_id: Name of the model owning the artifacts.
        :param version_id: Name of the version owning the artifacts.
        :param artifact_type: Type of the artifacts to load.
        :param session: DB session used for the queries.
        :return: Internal artifact models, one per matching DB artifact.
        :raises Exception: If the artifact type is not supported.
        """
        artifact_class = DBReader.get_artifact_class(artifact_type)
        query = (
            select(artifact_class)
            .where(DBVersion.model_id == DBModel.id)
            .where(DBVersion.name == version_id)
            .where(DBModel.name == model_id)
            .where(DBArtifactHeader.id == artifact_class.artifact_header_id)
            .where(DBArtifactHeader.version_id == DBVersion.id)
        )
        db_artifacts: ScalarResult[
            Union[DBSpec, DBValidatedSpec, DBNegotiationCard, DBReport, DBValue]
        ] = session.scalars(query)
        return [
            factory.create_artifact_from_db(db_artifact.artifact_header, session)
            for db_artifact in db_artifacts
        ]

    @staticmethod
    def get_artifact_header(
        artifact_id: str, session: Session
    ) -> DBArtifactHeader:
        """Gets the artifact header object of the artifact identifier provided.

        Note: the lookup is by identifier only (not scoped to a model/version) and returns the
        first matching row.

        :param artifact_id: Identifier of the artifact.
        :param session: DB session used for the query.
        :return: The matching DBArtifactHeader row.
        :raises ErrorNotFound: If no header has that identifier.
        """
        header = session.scalar(
            select(DBArtifactHeader).where(
                DBArtifactHeader.identifier == artifact_id
            )
        )
        if header is not None:
            return header
        raise errors.ErrorNotFound(
            f"Artifact with identifier {artifact_id} was not found in the artifact store."
        )

    @staticmethod
    def get_artifact_class(
        artifact_type: ArtifactType,
    ) -> Union[
        type[DBSpec],
        type[DBValidatedSpec],
        type[DBNegotiationCard],
        type[DBReport],
        type[DBValue],
    ]:
        """Gets the DB class corresponding to the given artifact type.

        :param artifact_type: The artifact type to map.
        :return: The ORM class storing artifacts of that type.
        :raises Exception: If the type is not in SUPPORTED_ARTIFACT_DB_CLASSES.
        """
        supported = DBReader.SUPPORTED_ARTIFACT_DB_CLASSES
        if artifact_type not in supported:
            raise Exception(f"Unsupported artifact type: {artifact_type.value}")
        return supported[ArtifactType(artifact_type)]

    @staticmethod
    def get_spec(
        spec_identifier: str, version_id: int, session: Session
    ) -> DBSpec:
        """Gets the Spec with the given identifier.

        :param spec_identifier: Identifier of the spec's artifact header.
        :param version_id: DB id of the version (compared with DBArtifactHeader.version_id).
        :param session: DB session used for the query.
        :return: The matching DBSpec row.
        :raises ErrorNotFound: If no such spec exists in that version.
        """
        query = (
            select(DBSpec)
            .where(DBSpec.artifact_header_id == DBArtifactHeader.id)
            .where(DBArtifactHeader.identifier == spec_identifier)
            .where(DBArtifactHeader.version_id == version_id)
        )
        db_spec = session.scalar(query)
        if db_spec is not None:
            return db_spec
        raise errors.ErrorNotFound(
            f"Spec with identifier {spec_identifier} was not found in the artifact store."
        )

    @staticmethod
    def get_validated_spec(
        validated_spec_identifier: str, version_id: int, session: Session
    ) -> DBValidatedSpec:
        """Gets the ValidatedSpec with the given identifier.

        :param validated_spec_identifier: Identifier of the validated spec's artifact header.
        :param version_id: DB id of the version (compared with DBArtifactHeader.version_id).
        :param session: DB session used for the query.
        :return: The matching DBValidatedSpec row.
        :raises ErrorNotFound: If no such validated spec exists in that version.
        """
        query = (
            select(DBValidatedSpec)
            .where(DBValidatedSpec.artifact_header_id == DBArtifactHeader.id)
            .where(DBArtifactHeader.identifier == validated_spec_identifier)
            .where(DBArtifactHeader.version_id == version_id)
        )
        db_validated_spec = session.scalar(query)
        if db_validated_spec is not None:
            return db_validated_spec
        raise errors.ErrorNotFound(
            f"ValidatedSpec with identifier {validated_spec_identifier} was not found in the artifact store."
        )

    @staticmethod
    def get_qa_category_id(
        qa_category_name: str,
        spec_identifier: str,
        version_id: int,
        session: Session,
    ) -> int:
        """Gets the id of the qa category with the given name for the indicated Spec.

        :param qa_category_name: Name of the QA category.
        :param spec_identifier: Identifier of the spec that owns the category.
        :param version_id: DB id of the version (compared with DBArtifactHeader.version_id).
        :param session: DB session used for the query.
        :return: The DBQACategory id.
        :raises ErrorNotFound: If no such category exists for that spec.
        """
        query = (
            select(DBQACategory.id)
            .where(DBQACategory.name == qa_category_name)
            .where(DBSpec.id == DBQACategory.spec_id)
            .where(DBSpec.artifact_header_id == DBArtifactHeader.id)
            .where(DBArtifactHeader.identifier == spec_identifier)
            .where(DBArtifactHeader.version_id == version_id)
        )
        category_id = session.scalar(query)
        if category_id is not None:
            return category_id
        raise errors.ErrorNotFound(
            f"Quality attribute category with name {qa_category_name} for Spec with identifier {spec_identifier} was not found in the artifact store."
        )

    @staticmethod
    def get_problem_type(type: ProblemType, session: Session) -> DBProblemType:
        """Gets the problem type DB object corresponding to the given internal type.

        :param type: Internal problem type, compared against DBProblemType.name.
        :param session: DB session used for the query.
        :return: The matching DBProblemType row.
        :raises Exception: If no row exists for the given type.
        """
        db_problem_type = session.scalar(
            select(DBProblemType).where(DBProblemType.name == type)
        )
        if db_problem_type is not None:
            return db_problem_type
        raise Exception(f"Unknown problem type requested: {type}")

    @staticmethod
    def get_classification_type(
        type: DataClassification, session: Session
    ) -> DBDataClassification:
        """Gets the data classification DB object corresponding to the given internal type.

        :param type: Internal data classification, compared against DBDataClassification.name.
        :param session: DB session used for the query.
        :return: The matching DBDataClassification row.
        :raises Exception: If no row exists for the given classification.
        """
        db_classification = session.scalar(
            select(DBDataClassification).where(
                DBDataClassification.name == type
            )
        )
        if db_classification is not None:
            return db_classification
        raise Exception(f"Unknown data classification requested: {type}")

get_artifact(model_id, version_id, artifact_id, session) staticmethod

Reads the artifact with the given identifier for the given model and version using the provided session, and returns a tuple of the internal artifact object and its DB object.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
@staticmethod
def get_artifact(
    model_id: str,
    version_id: str,
    artifact_id: str,
    session: Session,
) -> Tuple[
    ArtifactModel,
    Union[DBSpec, DBValidatedSpec, DBNegotiationCard, DBReport, DBValue],
]:
    """Reads the artifact with the given identifier for the given model and version.

    :param model_id: Name of the model owning the artifact.
    :param version_id: Name of the version owning the artifact.
    :param artifact_id: Identifier of the artifact (DBArtifactHeader.identifier).
    :param session: DB session used for the queries.
    :return: The internal artifact model built from the artifact's header, and the typed DB artifact object.
    :raises ErrorNotFound: If the artifact does not exist in the given model/version.
    """
    not_found_message = f"Artifact with identifier {artifact_id}  and associated to model {model_id}, and version {version_id} was not found in the artifact store."

    # Resolve the header scoped to the requested model and version. A lookup by identifier
    # alone returns the first match in the whole store. If the same identifier is used in
    # another model/version, that foreign header would decide the artifact type and the
    # internal model returned below.
    artifact_header_obj = session.scalar(
        select(DBArtifactHeader)
        .where(DBArtifactHeader.identifier == artifact_id)
        .where(DBArtifactHeader.version_id == DBVersion.id)
        .where(DBVersion.name == version_id)
        .where(DBVersion.model_id == DBModel.id)
        .where(DBModel.name == model_id)
    )
    if artifact_header_obj is None:
        raise errors.ErrorNotFound(not_found_message)

    # The header's type tells which typed table holds the artifact's data.
    artifact_type = ArtifactType(artifact_header_obj.type.name)
    artifact_class = DBReader.get_artifact_class(artifact_type)
    artifact_obj: Union[
        DBSpec, DBValidatedSpec, DBNegotiationCard, DBReport, DBValue
    ] = session.scalar(
        select(artifact_class).where(
            artifact_class.artifact_header_id == artifact_header_obj.id
        )
    )
    if artifact_obj is None:
        raise errors.ErrorNotFound(not_found_message)

    return (
        factory.create_artifact_from_db(artifact_header_obj, session),
        artifact_obj,
    )

get_artifact_class(artifact_type) staticmethod

Gets the DB class corresponding to the given artifact type.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
@staticmethod
def get_artifact_class(
    artifact_type: ArtifactType,
) -> Union[
    type[DBSpec],
    type[DBValidatedSpec],
    type[DBNegotiationCard],
    type[DBReport],
    type[DBValue],
]:
    """Gets the DB class corresponding to the given artifact type.

    :param artifact_type: The artifact type to map.
    :return: The ORM class storing artifacts of that type.
    :raises Exception: If the type is not in SUPPORTED_ARTIFACT_DB_CLASSES.
    """
    supported = DBReader.SUPPORTED_ARTIFACT_DB_CLASSES
    if artifact_type not in supported:
        raise Exception(f"Unsupported artifact type: {artifact_type.value}")
    return supported[ArtifactType(artifact_type)]

get_artifact_header(artifact_id, session) staticmethod

Gets the artifact header object of the artifact identifier provided.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
@staticmethod
def get_artifact_header(
    artifact_id: str, session: Session
) -> DBArtifactHeader:
    """Gets the artifact header object of the artifact identifier provided.

    Note: the lookup is by identifier only (not scoped to a model/version) and returns the
    first matching row.

    :param artifact_id: Identifier of the artifact.
    :param session: DB session used for the query.
    :return: The matching DBArtifactHeader row.
    :raises ErrorNotFound: If no header has that identifier.
    """
    header = session.scalar(
        select(DBArtifactHeader).where(
            DBArtifactHeader.identifier == artifact_id
        )
    )
    if header is not None:
        return header
    raise errors.ErrorNotFound(
        f"Artifact with identifier {artifact_id} was not found in the artifact store."
    )

get_artifact_type(type, session) staticmethod

Gets the artifact type DB object corresponding to the given internal type.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
105
106
107
108
109
110
111
112
113
114
115
116
@staticmethod
def get_artifact_type(
    type: ArtifactType, session: Session
) -> DBArtifactType:
    """Gets the artifact type DB object corresponding to the given internal type.

    :param type: Internal artifact type, compared against DBArtifactType.name.
    :param session: DB session used for the query.
    :return: The matching DBArtifactType row.
    :raises Exception: If no row exists for the given type.
    """
    db_artifact_type = session.scalar(
        select(DBArtifactType).where(DBArtifactType.name == type)
    )
    if db_artifact_type is not None:
        return db_artifact_type
    raise Exception(f"Unknown artifact type requested: {type}")

get_artifacts_for_type(model_id, version_id, artifact_type, session) staticmethod

Loads and returns a list with all the artifacts of the given type, for the given model/version.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
@staticmethod
def get_artifacts_for_type(
    model_id: str,
    version_id: str,
    artifact_type: ArtifactType,
    session: Session,
) -> List[ArtifactModel]:
    """Loads and returns a list with all the artifacts of the given type, for the given model/version.

    :param model_id: Name of the model owning the artifacts.
    :param version_id: Name of the version owning the artifacts.
    :param artifact_type: Type of the artifacts to load.
    :param session: DB session used for the queries.
    :return: Internal artifact models, one per matching DB artifact.
    :raises Exception: If the artifact type is not supported.
    """
    artifact_class = DBReader.get_artifact_class(artifact_type)
    query = (
        select(artifact_class)
        .where(DBVersion.model_id == DBModel.id)
        .where(DBVersion.name == version_id)
        .where(DBModel.name == model_id)
        .where(DBArtifactHeader.id == artifact_class.artifact_header_id)
        .where(DBArtifactHeader.version_id == DBVersion.id)
    )
    db_artifacts: ScalarResult[
        Union[DBSpec, DBValidatedSpec, DBNegotiationCard, DBReport, DBValue]
    ] = session.scalars(query)
    return [
        factory.create_artifact_from_db(db_artifact.artifact_header, session)
        for db_artifact in db_artifacts
    ]

get_classification_type(type, session) staticmethod

Gets the data classification DB object corresponding to the given internal type.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
291
292
293
294
295
296
297
298
299
300
301
302
303
304
@staticmethod
def get_classification_type(
    type: DataClassification, session: Session
) -> DBDataClassification:
    """Gets the data classification DB object corresponding to the given internal type.

    :param type: Internal data classification, compared against DBDataClassification.name.
    :param session: DB session used for the query.
    :return: The matching DBDataClassification row.
    :raises Exception: If no row exists for the given classification.
    """
    db_classification = session.scalar(
        select(DBDataClassification).where(
            DBDataClassification.name == type
        )
    )
    if db_classification is not None:
        return db_classification
    raise Exception(f"Unknown data classification requested: {type}")

get_model(model_id, session) staticmethod

Reads the model with the given identifier using the provided session, and returns a Model and DBModel object.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
@staticmethod
def get_model(model_id: str, session: Session) -> Tuple[Model, DBModel]:
    """Reads the model with the given identifier using the provided session, and returns a Model and DBModel object.

    :param model_id: Name of the model to read (matched against DBModel.name).
    :param session: DB session used for the query.
    :return: The internal Model, listing the names of all its versions, and the DBModel row.
    :raises ErrorNotFound: If no model with that name exists.
    """
    db_model = session.scalar(
        select(DBModel).where(DBModel.name == model_id)
    )
    if db_model is None:
        raise errors.ErrorNotFound(
            f"Model with identifier {model_id} was not found in the artifact store."
        )

    versions = [
        Version(identifier=db_version.name)
        for db_version in db_model.versions
    ]
    return Model(identifier=db_model.name, versions=versions), db_model

get_problem_type(type, session) staticmethod

Gets the problem type DB object corresponding to the given internal type.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
280
281
282
283
284
285
286
287
288
289
@staticmethod
def get_problem_type(type: ProblemType, session: Session) -> DBProblemType:
    """Gets the problem type DB object corresponding to the given internal type.

    :param type: Internal problem type, compared against DBProblemType.name.
    :param session: DB session used for the query.
    :return: The matching DBProblemType row.
    :raises Exception: If no row exists for the given type.
    """
    db_problem_type = session.scalar(
        select(DBProblemType).where(DBProblemType.name == type)
    )
    if db_problem_type is not None:
        return db_problem_type
    raise Exception(f"Unknown problem type requested: {type}")

get_qa_category_id(qa_category_name, spec_identifier, version_id, session) staticmethod

Gets the id of the qa category with the given name for the indicated Spec.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
@staticmethod
def get_qa_category_id(
    qa_category_name: str,
    spec_identifier: str,
    version_id: int,
    session: Session,
) -> int:
    """Gets the id of the qa category with the given name for the indicated Spec.

    :param qa_category_name: Name of the QA category.
    :param spec_identifier: Identifier of the spec that owns the category.
    :param version_id: DB id of the version (compared with DBArtifactHeader.version_id).
    :param session: DB session used for the query.
    :return: The DBQACategory id.
    :raises ErrorNotFound: If no such category exists for that spec.
    """
    query = (
        select(DBQACategory.id)
        .where(DBQACategory.name == qa_category_name)
        .where(DBSpec.id == DBQACategory.spec_id)
        .where(DBSpec.artifact_header_id == DBArtifactHeader.id)
        .where(DBArtifactHeader.identifier == spec_identifier)
        .where(DBArtifactHeader.version_id == version_id)
    )
    category_id = session.scalar(query)
    if category_id is not None:
        return category_id
    raise errors.ErrorNotFound(
        f"Quality attribute category with name {qa_category_name} for Spec with identifier {spec_identifier} was not found in the artifact store."
    )

get_spec(spec_identifier, version_id, session) staticmethod

Gets the Spec with the given identifier.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
@staticmethod
def get_spec(
    spec_identifier: str, version_id: int, session: Session
) -> DBSpec:
    """Gets the Spec with the given identifier.

    :param spec_identifier: Identifier of the spec's artifact header.
    :param version_id: DB id of the version (compared with DBArtifactHeader.version_id).
    :param session: DB session used for the query.
    :return: The matching DBSpec row.
    :raises ErrorNotFound: If no such spec exists in that version.
    """
    query = (
        select(DBSpec)
        .where(DBSpec.artifact_header_id == DBArtifactHeader.id)
        .where(DBArtifactHeader.identifier == spec_identifier)
        .where(DBArtifactHeader.version_id == version_id)
    )
    db_spec = session.scalar(query)
    if db_spec is not None:
        return db_spec
    raise errors.ErrorNotFound(
        f"Spec with identifier {spec_identifier} was not found in the artifact store."
    )

get_validated_spec(validated_spec_identifier, version_id, session) staticmethod

Gets the ValidatedSpec with the given identifier.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
@staticmethod
def get_validated_spec(
    validated_spec_identifier: str, version_id: int, session: Session
) -> DBValidatedSpec:
    """Gets the ValidatedSpec with the given identifier.

    :param validated_spec_identifier: Identifier of the validated spec's artifact header.
    :param version_id: DB id of the version (compared with DBArtifactHeader.version_id).
    :param session: DB session used for the query.
    :return: The matching DBValidatedSpec row.
    :raises ErrorNotFound: If no such validated spec exists in that version.
    """
    query = (
        select(DBValidatedSpec)
        .where(DBValidatedSpec.artifact_header_id == DBArtifactHeader.id)
        .where(DBArtifactHeader.identifier == validated_spec_identifier)
        .where(DBArtifactHeader.version_id == version_id)
    )
    db_validated_spec = session.scalar(query)
    if db_validated_spec is not None:
        return db_validated_spec
    raise errors.ErrorNotFound(
        f"ValidatedSpec with identifier {validated_spec_identifier} was not found in the artifact store."
    )

get_version(model_id, version_id, session) staticmethod

Reads the version with the given identifier using the provided session, and returns a Version and DBVersion object. Raises ErrorNotFound if not found.

Source code in mlte/store/artifact/underlying/rdbs/reader.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
@staticmethod
def get_version(
    model_id: str,
    version_id: str,
    session: Session,
) -> Tuple[Version, DBVersion]:
    """Reads the version with the given identifier using the provided session, and returns a Version and DBVersion object. Raises ErrorNotFound if not found.

    :param model_id: Name of the model owning the version.
    :param version_id: Name of the version (matched against DBVersion.name).
    :param session: DB session used for the query.
    :return: The internal Version and the DBVersion row.
    :raises ErrorNotFound: If the version does not exist for that model.
    """
    query = (
        select(DBVersion)
        .where(DBVersion.name == version_id)
        .where(DBVersion.model_id == DBModel.id)
        .where(DBModel.name == model_id)
    )
    db_version = session.scalar(query)
    if db_version is None:
        raise errors.ErrorNotFound(
            f"Version with identifier {version_id}  and associated to model {model_id} was not found in the artifact store."
        )
    return Version(identifier=db_version.name), db_version