Skip to content

Classification

arkindex_worker.worker.classification

ElementsWorker methods for classifications and ML classes.

Attributes

Classes

ClassificationMixin

Functions
load_corpus_classes
load_corpus_classes()

Load all ML classes available in the worker’s corpus and store them in the self.classes cache.

Source code in arkindex_worker/worker/classification.py
20
21
22
23
24
25
26
27
28
29
30
31
def load_corpus_classes(self):
    """
    Fetch every ML class of the worker's corpus and cache them in ``self.classes``.

    The cache is rebuilt as a mapping of ML class name to ML class ID, using the
    paginated ``ListCorpusMLClasses`` endpoint for ``self.corpus_id``. It is only
    assigned once pagination has completed.
    """
    name_to_id = {}
    for entry in self.api_client.paginate("ListCorpusMLClasses", id=self.corpus_id):
        name_to_id[entry["name"]] = entry["id"]
    self.classes = name_to_id

    class_count = len(name_to_id)
    logger.info(
        f'Loaded {class_count} ML {pluralize("class", class_count)} in corpus ({self.corpus_id})'
    )
get_ml_class_id
get_ml_class_id(ml_class: str) -> str

Return the MLClass ID corresponding to the given class name in the worker's corpus.

If no MLClass exists for this class name, a new one is created.

Parameters:

Name Type Description Default
ml_class str

Name of the MLClass.

required

Returns:

Type Description
str

ID of the retrieved or created MLClass.

Source code in arkindex_worker/worker/classification.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def get_ml_class_id(self, ml_class: str) -> str:
    """
    Return the MLClass ID corresponding to the given class name in the worker's corpus.

    The corpus classes are loaded into ``self.classes`` first if that cache is empty.
    If no MLClass exists for this class name, a new one is created through the
    ``CreateMLClass`` API endpoint and added to the cache.

    :param ml_class: Name of the MLClass.
    :returns: ID of the retrieved or created MLClass.
    :raises ErrorResponse: If ``CreateMLClass`` fails with a status other than 400.
    :raises AssertionError: If ``CreateMLClass`` fails with a 400 error and the class
       is still missing after reloading the corpus classes.
    """
    if not self.classes:
        self.load_corpus_classes()

    ml_class_id = self.classes.get(ml_class)
    if ml_class_id is not None:
        return ml_class_id

    logger.info(f"Creating ML class {ml_class} on corpus {self.corpus_id}")
    try:
        response = self.api_client.request(
            "CreateMLClass", id=self.corpus_id, body={"name": ml_class}
        )
    except ErrorResponse as e:
        # Only reload for 400 errors: the class may already exist on the corpus
        # even though it was not in our cache. Anything else is propagated.
        if e.status_code != 400:
            raise

        # Reload and make sure we have the class
        logger.info(f"Reloading corpus classes to see if {ml_class} already exists")
        self.load_corpus_classes()
        # f-string so the failing class name is actually shown in the message
        assert (
            ml_class in self.classes
        ), f"Missing class {ml_class} even after reloading"
        return self.classes[ml_class]

    ml_class_id = self.classes[ml_class] = response["id"]
    logger.debug(f"Created ML class {response['id']}")
    return ml_class_id
retrieve_ml_class
retrieve_ml_class(ml_class_id: str) -> str

Retrieve the name of the MLClass from its ID.

Parameters:

Name Type Description Default
ml_class_id str

ID of the searched MLClass.

required

Returns:

Type Description
str

The MLClass’s name

Source code in arkindex_worker/worker/classification.py
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def retrieve_ml_class(self, ml_class_id: str) -> str:
    """
    Look up the name of an MLClass from its ID.

    The corpus classes are loaded into ``self.classes`` first if that cache is empty.
    When several cached names share the ID, the first one in cache order is returned.

    :param ml_class_id: ID of the searched MLClass.
    :return: The MLClass's name
    :raises AssertionError: If no cached class has this ID.
    """
    if not self.classes:
        self.load_corpus_classes()

    # Reverse lookup over the name -> ID cache
    matching_names = (
        name for name, class_id in self.classes.items() if class_id == ml_class_id
    )
    ml_class_name = next(matching_names, None)

    assert (
        ml_class_name is not None
    ), f"Missing class with id ({ml_class_id}) in corpus ({self.corpus_id})"
    return ml_class_name
create_classification
create_classification(
    element: Element | CachedElement,
    ml_class: str,
    confidence: float,
    high_confidence: bool = False,
) -> dict[str, str]

Create a classification on the given element through the API.

Parameters:

Name Type Description Default
element Element | CachedElement

The element to create a classification on.

required
ml_class str

Name of the MLClass to use.

required
confidence float

Confidence score for the classification, as a float between 0 and 1 (inclusive).

required
high_confidence bool

Whether or not the classification is of high confidence.

False

Returns:

Type Description
dict[str, str]

The created classification, as returned by the CreateClassification API endpoint.

Source code in arkindex_worker/worker/classification.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def create_classification(
    self,
    element: Element | CachedElement,
    ml_class: str,
    confidence: float,
    high_confidence: bool = False,
) -> dict[str, str]:
    """
    Create one classification on ``element`` via the ``CreateClassification`` API endpoint.

    The class name is resolved to an ID with ``get_ml_class_id``. When ``self.use_cache``
    is set, the created classification is also inserted in the local cache; an
    ``IntegrityError`` there is only logged as a warning.

    :param element: The element to create a classification on.
    :param ml_class: Name of the MLClass to use.
    :param confidence: Confidence score for the classification. Must be between 0 and 1.
    :param high_confidence: Whether or not the classification is of high confidence.
    :returns: The created classification, as returned by the ``CreateClassification`` API
       endpoint, or ``None`` when the worker is read-only or this worker run already set
       ``ml_class`` on the element.
    :raises ErrorResponse: For any API error other than the duplicate-classification 400 error.
    """
    assert element and isinstance(
        element, Element | CachedElement
    ), "element shouldn't be null and should be an Element or CachedElement"
    assert ml_class and isinstance(
        ml_class, str
    ), "ml_class shouldn't be null and should be of type str"
    assert (
        isinstance(confidence, float) and 0 <= confidence <= 1
    ), "confidence shouldn't be null and should be a float in [0..1] range"
    assert isinstance(
        high_confidence, bool
    ), "high_confidence shouldn't be null and should be of type bool"

    if self.is_read_only:
        logger.warning(
            "Cannot create classification as this worker is in read-only mode"
        )
        return

    try:
        payload = {
            "element": str(element.id),
            "ml_class": self.get_ml_class_id(ml_class),
            "worker_run_id": self.worker_run_id,
            "confidence": confidence,
            "high_confidence": high_confidence,
        }
        created = self.api_client.request("CreateClassification", body=payload)

        if self.use_cache:
            # Mirror the created classification in the local cache
            cache_row = {
                "id": created["id"],
                "element_id": element.id,
                "class_name": ml_class,
                "confidence": created["confidence"],
                "state": created["state"],
                "worker_run_id": self.worker_run_id,
            }
            try:
                CachedClassification.insert_many([cache_row]).execute()
            except IntegrityError as e:
                logger.warning(
                    f"Couldn't save created classification in local cache: {e}"
                )
    except ErrorResponse as e:
        # Only a 400 response carrying non_field_errors can be the duplicate case
        if e.status_code != 400 or "non_field_errors" not in e.content:
            raise
        if (
            "The fields element, worker_run, ml_class must make a unique set."
            not in e.content["non_field_errors"]
        ):
            raise
        logger.warning(
            f"This worker run has already set {ml_class} on element {element.id}"
        )
        return

    return created
create_classifications
create_classifications(
    element: Element | CachedElement,
    classifications: list[dict[str, str | float | bool]],
    batch_size: int = DEFAULT_BATCH_SIZE,
) -> list[dict[str, str | float | bool]]

Create multiple classifications at once on the given element through the API.

Parameters:

Name Type Description Default
element Element | CachedElement

The element to create classifications on.

required
classifications list[dict[str, str | float | bool]]

A list of dicts, each representing one classification, with the following keys: ml_class (str, required): name of the MLClass to use; confidence (float, required): confidence score for the classification, between 0 and 1; high_confidence (bool, optional): whether or not the classification is of high confidence.

required
batch_size int

The size of each batch, which will be used to split the publication to avoid API errors.

DEFAULT_BATCH_SIZE

Returns:

Type Description
list[dict[str, str | float | bool]]

List of created classifications, as returned in the classifications field by the CreateClassifications API endpoint.

Source code in arkindex_worker/worker/classification.py
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
@batch_publication
def create_classifications(
    self,
    element: Element | CachedElement,
    classifications: list[dict[str, str | float | bool]],
    batch_size: int = DEFAULT_BATCH_SIZE,
) -> list[dict[str, str | float | bool]]:
    """
    Create multiple classifications at once on the given element through the API.

    :param element: The element to create classifications on.
    :param classifications: A list of dicts representing a classification each, with the following keys:

        ml_class (str)
            Required. Name of the MLClass to use.
        confidence (float)
            Required. Confidence score for the classification. Must be between 0 and 1.
        high_confidence (bool)
            Optional. Whether or not the classification is of high confidence.

    :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.

    :returns: List of created classifications, as returned in the ``classifications`` field by
       the ``CreateClassifications`` API endpoint. The returned dicts are not modified,
       whether or not the local cache is enabled.
    """
    assert element and isinstance(
        element, Element | CachedElement
    ), "element shouldn't be null and should be an Element or CachedElement"
    assert classifications and isinstance(
        classifications, list
    ), "classifications shouldn't be null and should be of type list"

    for index, classification in enumerate(classifications):
        ml_class = classification.get("ml_class")
        assert (
            ml_class and isinstance(ml_class, str)
        ), f"Classification at index {index} in classifications: ml_class shouldn't be null and should be of type str"

        confidence = classification.get("confidence")
        # isinstance() already rejects a missing (None) confidence
        assert (
            isinstance(confidence, float) and 0 <= confidence <= 1
        ), f"Classification at index {index} in classifications: confidence shouldn't be null and should be a float in [0..1] range"

        high_confidence = classification.get("high_confidence")
        assert high_confidence is None or isinstance(
            high_confidence, bool
        ), f"Classification at index {index} in classifications: high_confidence should be of type bool"

    if self.is_read_only:
        logger.warning(
            "Cannot create classifications as this worker is in read-only mode"
        )
        return

    # One CreateClassifications request per batch; class names are resolved to IDs
    created_cls = [
        created_cl
        for batch in make_batches(classifications, "classification", batch_size)
        for created_cl in self.api_client.request(
            "CreateClassifications",
            body={
                "parent": str(element.id),
                "worker_run_id": self.worker_run_id,
                "classifications": [
                    {
                        **classification,
                        "ml_class": self.get_ml_class_id(
                            classification["ml_class"]
                        ),
                    }
                    for classification in batch
                ],
            },
        )["classifications"]
    ]

    if self.use_cache:
        # Store classifications in local cache. The class name is resolved from the
        # returned ML class ID only here, instead of being written into the API
        # results, so callers always get the API response unchanged.
        try:
            to_insert = [
                {
                    "id": created_cl["id"],
                    "element_id": element.id,
                    "class_name": self.retrieve_ml_class(created_cl["ml_class"]),
                    "confidence": created_cl["confidence"],
                    "state": created_cl["state"],
                    "worker_run_id": self.worker_run_id,
                }
                for created_cl in created_cls
            ]
            CachedClassification.insert_many(to_insert).execute()
        except IntegrityError as e:
            logger.warning(
                f"Couldn't save created classifications in local cache: {e}"
            )

    return created_cls

Functions