Skip to content

Classification

arkindex_worker.worker.classification

ElementsWorker methods for classifications and ML classes.

load_corpus_classes

load_corpus_classes()

Load all ML classes available in the worker’s corpus and store them in the self.classes cache.

Source code in `arkindex_worker/worker/classification.py` (lines 18–29):
def load_corpus_classes(self):
    """
    Fetch every ML class of the worker's corpus and cache it in ``self.classes``.

    The cache maps each class name to its ID. It is only replaced once the whole
    ``ListCorpusMLClasses`` listing has been iterated, so a failure while paginating
    leaves the previous cache untouched.
    """
    loaded = {}
    for ml_class in self.api_client.paginate("ListCorpusMLClasses", id=self.corpus_id):
        loaded[ml_class["name"]] = ml_class["id"]
    self.classes = loaded

    logger.info(
        f"Loaded {len(self.classes)} ML classes in corpus ({self.corpus_id})"
    )

get_ml_class_id

get_ml_class_id(ml_class)

Return the MLClass ID corresponding to the given class name on a specific corpus.

If no MLClass exists for this class name, a new one is created.

Parameters:

- `ml_class` (`str`, required): Name of the MLClass.

Returns:

- `str`: ID of the retrieved or created MLClass.

Source code in `arkindex_worker/worker/classification.py` (lines 31–66):
def get_ml_class_id(self, ml_class: str) -> str:
    """
    Return the MLClass ID corresponding to the given class name on a specific corpus.

    If no MLClass exists for this class name, a new one is created.
    The corpus classes are loaded into the ``self.classes`` cache on first use,
    and a newly created class is added to that cache.

    :param ml_class: Name of the MLClass.
    :returns: ID of the retrieved or created MLClass.
    :raises ErrorResponse: If ``CreateMLClass`` fails with a status other than 400.
    :raises AssertionError: If ``CreateMLClass`` fails with a 400 error and the class
        is still missing from the corpus after reloading the classes.
    """
    # An empty cache (including a corpus with no classes yet) triggers a load.
    if not self.classes:
        self.load_corpus_classes()

    ml_class_id = self.classes.get(ml_class)
    if ml_class_id is None:
        logger.info(f"Creating ML class {ml_class} on corpus {self.corpus_id}")
        try:
            response = self.request(
                "CreateMLClass", id=self.corpus_id, body={"name": ml_class}
            )
            ml_class_id = self.classes[ml_class] = response["id"]
            logger.debug(f"Created ML class {response['id']}")
        except ErrorResponse as e:
            # Only reload for 400 errors
            if e.status_code != 400:
                raise

            # A 400 may mean the class already exists even though our cache did
            # not know about it: reload and make sure we have the class.
            logger.info(
                f"Reloading corpus classes to see if {ml_class} already exists"
            )
            self.load_corpus_classes()
            assert (
                ml_class in self.classes
            ), f"Missing class {ml_class} even after reloading"
            ml_class_id = self.classes[ml_class]

    return ml_class_id

retrieve_ml_class

retrieve_ml_class(ml_class_id)

Retrieve the name of the MLClass from its ID.

Parameters:

- `ml_class_id` (`str`, required): ID of the searched MLClass.

Returns:

- `str`: The MLClass’s name.

Source code in `arkindex_worker/worker/classification.py` (lines 68–90):
def retrieve_ml_class(self, ml_class_id: str) -> str:
    """
    Look up the name of an MLClass from its ID in the ``self.classes`` cache.

    The corpus classes are loaded first if the cache is empty. When several names
    share the same ID, the first one in cache order is returned.

    :param ml_class_id: ID of the searched MLClass.
    :return: The MLClass's name
    :raises AssertionError: If no cached class has this ID.
    """
    if not self.classes:
        self.load_corpus_classes()

    matching_names = (
        name for name, class_id in self.classes.items() if class_id == ml_class_id
    )
    ml_class_name = next(matching_names, None)

    assert (
        ml_class_name is not None
    ), f"Missing class with id ({ml_class_id}) in corpus ({self.corpus_id})"
    return ml_class_name

create_classification

create_classification(
    element, ml_class, confidence, high_confidence=False
)

Create a classification on the given element through the API.

Parameters:

- `element` (`Union[Element, CachedElement]`, required): The element to create a classification on.
- `ml_class` (`str`, required): Name of the MLClass to use.
- `confidence` (`float`, required): Confidence score for the classification. Must be between 0 and 1.
- `high_confidence` (`Optional[bool]`, default `False`): Whether or not the classification is of high confidence.

Returns:

- `Dict[str, str]`: The created classification, as returned by the `CreateClassification` API endpoint.

Source code in `arkindex_worker/worker/classification.py` (lines 92–181):
def create_classification(
    self,
    element: Union[Element, CachedElement],
    ml_class: str,
    confidence: float,
    high_confidence: Optional[bool] = False,
) -> Dict[str, str]:
    """
    Create a classification on the given element through the API.

    The MLClass ID is resolved (or created) from ``ml_class`` and the
    ``CreateClassification`` endpoint is called. When the local cache is enabled,
    the result is also stored as a ``CachedClassification``.

    :param element: The element to create a classification on.
    :param ml_class: Name of the MLClass to use.
    :param confidence: Confidence score for the classification. Must be between 0 and 1.
    :param high_confidence: Whether or not the classification is of high confidence.
    :returns: The created classification, as returned by the ``CreateClassification`` API endpoint.
        ``None`` is returned when the worker is in read-only mode or when the API reports
        that this classification already exists.
    :raises ErrorResponse: For any API error other than the "unique set" duplicate errors.
    """
    assert element and isinstance(
        element, (Element, CachedElement)
    ), "element shouldn't be null and should be an Element or CachedElement"
    assert ml_class and isinstance(
        ml_class, str
    ), "ml_class shouldn't be null and should be of type str"
    assert (
        isinstance(confidence, float) and 0 <= confidence <= 1
    ), "confidence shouldn't be null and should be a float in [0..1] range"
    assert isinstance(
        high_confidence, bool
    ), "high_confidence shouldn't be null and should be of type bool"

    if self.is_read_only:
        logger.warning(
            "Cannot create classification as this worker is in read-only mode"
        )
        return

    # API messages signalling that this element already carries this ML class,
    # depending on which unique constraint was hit.
    duplicate_for_version = (
        "The fields element, worker_version, ml_class must make a unique set."
    )
    duplicate_for_run = "The fields element, worker_run, ml_class must make a unique set."

    try:
        # The ML class lookup stays inside the try so its API errors are handled below.
        created = self.request(
            "CreateClassification",
            body={
                "element": str(element.id),
                "ml_class": self.get_ml_class_id(ml_class),
                "worker_run_id": self.worker_run_id,
                "confidence": confidence,
                "high_confidence": high_confidence,
            },
        )

        if self.use_cache:
            # Mirror the new classification in the local cache; a conflict there is not fatal.
            try:
                cached_row = {
                    "id": created["id"],
                    "element_id": element.id,
                    "class_name": ml_class,
                    "confidence": created["confidence"],
                    "state": created["state"],
                    "worker_run_id": self.worker_run_id,
                }
                CachedClassification.insert_many([cached_row]).execute()
            except IntegrityError as e:
                logger.warning(
                    f"Couldn't save created classification in local cache: {e}"
                )
    except ErrorResponse as e:
        # Anything that is not a 400 with non-field errors is propagated as-is.
        if e.status_code != 400 or "non_field_errors" not in e.content:
            raise

        non_field_errors = e.content["non_field_errors"]
        if duplicate_for_version in non_field_errors:
            logger.warning(
                f"This worker version has already set {ml_class} on element {element.id}"
            )
        elif duplicate_for_run in non_field_errors:
            logger.warning(
                f"This worker run has already set {ml_class} on element {element.id}"
            )
        else:
            raise
        return

    self.report.add_classification(element.id, ml_class)
    return created

create_classifications

create_classifications(element, classifications)

Create multiple classifications at once on the given element through the API.

Parameters:

- `element` (`Union[Element, CachedElement]`, required): The element to create classifications on.
- `classifications` (`List[Dict[str, Union[str, float, bool]]]`, required): The classifications to create, as a list of dicts. Each dict contains:
  - `ml_class_id` (`str`): the ID of the MLClass for this classification;
  - `confidence` (`float`): the confidence score, between 0 and 1;
  - `high_confidence` (`bool`): the high-confidence state of the classification.

Returns:

- `List[Dict[str, Union[str, float, bool]]]`: The list of created classifications, as returned in the `classifications` field by the `CreateClassifications` API endpoint.

Source code in `arkindex_worker/worker/classification.py` (lines 183–273):
def create_classifications(
    self,
    element: Union[Element, CachedElement],
    classifications: List[Dict[str, Union[str, float, bool]]],
) -> List[Dict[str, Union[str, float, bool]]]:
    """
    Create multiple classifications at once on the given element through the API.

    Every classification is validated before any API call is made.

    :param element: The element to create classifications on.
    :param classifications: The classifications to create, a list of dicts. Each of them contains
        a **ml_class_id** (str), the ID of the MLClass for this classification;
        a **confidence** (float), the confidence score, between 0 and 1;
        a **high_confidence** (bool), the high confidence state of the classification.

    :returns: List of created classifications, as returned in the ``classifications`` field by
       the ``CreateClassifications`` API endpoint, or ``None`` when the worker is in read-only mode.
    :raises AssertionError: If ``element``, ``classifications`` or one of their entries is invalid.
    :raises ValueError: If an ``ml_class_id`` is not a valid UUID.
    """
    assert element and isinstance(
        element, (Element, CachedElement)
    ), "element shouldn't be null and should be an Element or CachedElement"
    assert classifications and isinstance(
        classifications, list
    ), "classifications shouldn't be null and should be of type list"

    for index, classification in enumerate(classifications):
        ml_class_id = classification.get("ml_class_id")
        assert ml_class_id and isinstance(
            ml_class_id, str
        ), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str"

        # Make sure it's a valid UUID
        try:
            UUID(ml_class_id)
        except ValueError as e:
            # Chain explicitly so the original UUID parsing error stays attached as the cause.
            raise ValueError(
                f"Classification at index {index} in classifications: ml_class_id is not a valid uuid."
            ) from e

        confidence = classification.get("confidence")
        assert (
            confidence is not None
            and isinstance(confidence, float)
            and 0 <= confidence <= 1
        ), f"Classification at index {index} in classifications: confidence shouldn't be null and should be a float in [0..1] range"

        # high_confidence is optional; only its type is checked when present.
        high_confidence = classification.get("high_confidence")
        if high_confidence is not None:
            assert isinstance(
                high_confidence, bool
            ), f"Classification at index {index} in classifications: high_confidence should be of type bool"

    if self.is_read_only:
        logger.warning(
            "Cannot create classifications as this worker is in read-only mode"
        )
        return

    created_cls = self.request(
        "CreateClassifications",
        body={
            "parent": str(element.id),
            "worker_run_id": self.worker_run_id,
            "classifications": classifications,
        },
    )["classifications"]

    # The API answers with ML class IDs; resolve their names for the report and the cache.
    for created_cl in created_cls:
        created_cl["class_name"] = self.retrieve_ml_class(created_cl["ml_class"])
        self.report.add_classification(element.id, created_cl["class_name"])

    if self.use_cache:
        # Store classifications in local cache
        # NOTE(review): ``class_name`` is popped back out of the returned dicts only in this
        # branch, so callers see a ``class_name`` key only when the cache is disabled —
        # confirm whether this asymmetry is intended.
        try:
            to_insert = [
                {
                    "id": created_cl["id"],
                    "element_id": element.id,
                    "class_name": created_cl.pop("class_name"),
                    "confidence": created_cl["confidence"],
                    "state": created_cl["state"],
                    "worker_run_id": self.worker_run_id,
                }
                for created_cl in created_cls
            ]
            CachedClassification.insert_many(to_insert).execute()
        except IntegrityError as e:
            logger.warning(
                f"Couldn't save created classifications in local cache: {e}"
            )

    return created_cls