Skip to content

API Reference

BatchProcessor

batch_please.batchers.BatchProcessor

Bases: Generic[T, R]

A class for processing items in batches.

This class takes an iterable of items or a dictionary of key-value pairs, processes them in batches using a provided function, and optionally saves progress to allow for checkpoint recovery. When using dictionary input, the keys are used as unique identifiers and the values are unpacked as keyword arguments to the processing function.

Attributes:

Name Type Description
process_func Callable

The function used to process each item. Can accept either a single item or keyword arguments from dictionary values.

batch_size int

The number of items to process in each batch.

pickle_file Optional[str]

The file to use for saving/loading progress.

processed_items Dict[str, R]

A dictionary mapping item keys to their processing results.

recover_from_checkpoint bool

Whether to attempt to recover from a checkpoint.

use_tqdm bool

Whether to use tqdm progress bars.

Source code in src/batch_please/batchers.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
class BatchProcessor(Generic[T, R]):
    """
    A class for processing items in batches.

    This class takes an iterable of items or a dictionary of key-value pairs, processes them
    in batches using a provided function, and optionally saves progress to allow for checkpoint recovery.
    When using dictionary input, the keys are used as unique identifiers and the values are unpacked
    as keyword arguments to the processing function.

    Attributes:
        process_func (Callable): The function used to process each item.
            Can accept either a single item or keyword arguments from dictionary values.
        batch_size (int): The number of items to process in each batch.
        pickle_file (Optional[str]): The file to use for saving/loading progress.
        processed_items (Dict[str, R]): A dictionary mapping item keys to their processing results.
        recover_from_checkpoint (bool): Whether to attempt to recover from a checkpoint.
        use_tqdm (bool): Whether to use tqdm progress bars.
    """

    def __init__(
        self,
        process_func: Callable,
        batch_size: int = 100,
        pickle_file: Optional[str] = None,
        logfile: Optional[str] = None,
        recover_from_checkpoint: bool = False,
        use_tqdm: bool = False,
    ):
        """
        Initialize the BatchProcessor.

        Args:
            process_func (Callable): The function to process each item. Can either:
                - Accept a single positional argument when processing iterables
                - Accept keyword arguments when processing dictionaries
            batch_size (int, optional): The number of items to process in each batch. Defaults to 100.
            pickle_file (Optional[str], optional): The file to use for saving/loading progress. Defaults to None.
            logfile (Optional[str], optional): The file to use for logging. Defaults to None.
            recover_from_checkpoint (bool, optional): Whether to attempt to recover from a checkpoint. Defaults to False.
            use_tqdm (bool, optional): Whether to use tqdm progress bars. Defaults to False.
        """
        self.process_func = process_func
        self.batch_size = batch_size
        self.pickle_file = pickle_file
        # Keys are string identifiers; this dict doubles as the checkpoint payload.
        self.processed_items: Dict[str, R] = {}
        self.recover_from_checkpoint = recover_from_checkpoint
        self.use_tqdm = use_tqdm

        # Set up logging.
        # NOTE(review): this is the shared module-level logger, so clearing its
        # handlers here affects every instance created from this module.
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        self.logger.handlers = []  # Clear any existing handlers
        formatter = logging.Formatter("%(asctime)s - %(message)s")

        # File handler (if logfile is provided)
        if logfile:
            file_handler = logging.FileHandler(logfile)
            file_handler.setFormatter(formatter)
            self.logger.addHandler(file_handler)

        # Recover from checkpoint if enabled (no-op when the file is absent).
        if self.recover_from_checkpoint:
            self.load_progress()

    def process_item(self, job_number: int, item: Union[T, Dict[str, Any]]) -> R:
        """
        Process a single item.

        Args:
            job_number (int): The number of the job being processed.
            item (Union[T, Dict[str, Any]]): The item to process, either a direct value
                or a dictionary of keyword arguments.

        Returns:
            R: The result of processing the item.
        """
        if isinstance(item, dict):
            # Dict items are unpacked as keyword arguments.
            result = self.process_func(**cast(Dict[str, Any], item))
        else:
            result = self.process_func(item)  # type: ignore
        self.logger.info(f"Processed job {job_number}: {item}")
        return result

    def process_batch(
        self, batch: Union[List[T], Dict[str, Any]], batch_number: int, total_jobs: int
    ):
        """
        Process a batch of items.

        Args:
            batch (Union[List[T], Dict[str, Any]]): The batch of items to process, either a list of items
                or a dictionary where keys are identifiers and values are items or dictionaries of kwargs.
            batch_number (int): The number of the current batch.
            total_jobs (int): The total number of jobs to process.
        """
        # Wrap the iteration source in tqdm when requested instead of
        # duplicating the comprehension per progress-bar setting.
        if isinstance(batch, dict):
            # Dict input - keys are identifiers, values are kwargs dictionaries
            pairs = list(batch.items())
            if self.use_tqdm:
                pairs = tqdm(pairs, desc=f"Batch {batch_number}")
            batch_results = {
                key: self.process_item(i, value)
                for i, (key, value) in enumerate(pairs)
            }
        else:
            # Standard list input; results are keyed by str(item).
            source = tqdm(batch, desc=f"Batch {batch_number}") if self.use_tqdm else batch
            batch_results = {
                str(item): self.process_item(i, item)
                for i, item in enumerate(source)
            }

        self.processed_items.update(batch_results)

        if self.pickle_file:
            self.save_progress()

        completion_message = f"Batch {batch_number} completed. Total processed: {len(self.processed_items)}/{total_jobs}"
        print(completion_message)
        self.logger.info(completion_message)

    def load_progress(self):
        """
        Load progress from a checkpoint file if it exists.
        """
        if self.pickle_file and os.path.exists(self.pickle_file):
            with open(self.pickle_file, "rb") as f:
                data = pickle.load(f)
                self.processed_items = data
            self.logger.info(
                f"Recovered {len(self.processed_items)} items from checkpoint"
            )
        else:
            self.logger.info(
                "No checkpoint file found or checkpoint recovery not enabled"
            )

    def save_progress(self):
        """
        Save current progress to a checkpoint file.

        A no-op when no ``pickle_file`` was configured (previously this raised
        ``TypeError`` from ``open(None)`` when called directly).
        """
        if not self.pickle_file:
            return
        with open(self.pickle_file, "wb") as f:
            pickle.dump(
                self.processed_items,
                f,
            )

    def process_items_in_batches(
        self, input_items: Union[Iterable[T], Dict[str, Any]]
    ) -> Dict[str, R]:
        """
        Process all input items in batches.

        Args:
            input_items (Union[Iterable[T], Dict[str, Any]]): The items to process. Can be either:
                - An iterable of items (each item is passed directly to the process function)
                - A dictionary where keys are identifiers and values are dictionaries of kwargs
                  to be unpacked into the process function

        Returns:
            Dict[str, R]: A dictionary containing the processed items and their results.
                For iterable inputs, the keys are the string representation of each item.
                For dictionary inputs, the original keys are preserved.
        """
        is_dict_input = isinstance(input_items, dict)

        if is_dict_input:
            # Handle dictionary input
            dict_input = cast(Dict[str, Any], input_items)

            if self.recover_from_checkpoint:
                recovered_items = set(self.processed_items.keys())
                dict_input = {
                    k: v for k, v in dict_input.items() if k not in recovered_items
                }

            total_jobs = len(dict_input)
            dict_items = list(dict_input.items())

            for i in range(0, total_jobs, self.batch_size):
                batch_items = dict_items[i : i + self.batch_size]
                batch_dict = dict(batch_items)
                self.process_batch(batch_dict, i // self.batch_size + 1, total_jobs)
        else:
            # Handle iterable input (original behavior)
            list_input = list(input_items)  # Convert iterable to list
            if self.recover_from_checkpoint:
                recovered_items = set(self.processed_items.keys())
                # NOTE: recovery matching relies on str(item) being a stable key.
                list_input = [
                    item for item in list_input if str(item) not in recovered_items
                ]

            total_jobs = len(list_input)

            for i in range(0, total_jobs, self.batch_size):
                batch = list_input[i : i + self.batch_size]
                self.process_batch(batch, i // self.batch_size + 1, total_jobs)

        return self.processed_items

__init__(process_func, batch_size=100, pickle_file=None, logfile=None, recover_from_checkpoint=False, use_tqdm=False)

Initialize the BatchProcessor.

Parameters:

Name Type Description Default
process_func Callable

The function to process each item. Can either: - Accept a single positional argument when processing iterables - Accept keyword arguments when processing dictionaries

required
batch_size int

The number of items to process in each batch. Defaults to 100.

100
pickle_file Optional[str]

The file to use for saving/loading progress. Defaults to None.

None
logfile Optional[str]

The file to use for logging. Defaults to None.

None
recover_from_checkpoint bool

Whether to attempt to recover from a checkpoint. Defaults to False.

False
use_tqdm bool

Whether to use tqdm progress bars. Defaults to False.

False
Source code in src/batch_please/batchers.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def __init__(
    self,
    process_func: Callable,
    batch_size: int = 100,
    pickle_file: Optional[str] = None,
    logfile: Optional[str] = None,
    recover_from_checkpoint: bool = False,
    use_tqdm: bool = False,
):
    """
    Initialize the BatchProcessor.

    Args:
        process_func (Callable): The function to process each item. Can either:
            - Accept a single positional argument when processing iterables
            - Accept keyword arguments when processing dictionaries
        batch_size (int, optional): The number of items to process in each batch. Defaults to 100.
        pickle_file (Optional[str], optional): The file to use for saving/loading progress. Defaults to None.
        logfile (Optional[str], optional): The file to use for logging. Defaults to None.
        recover_from_checkpoint (bool, optional): Whether to attempt to recover from a checkpoint. Defaults to False.
        use_tqdm (bool, optional): Whether to use tqdm progress bars. Defaults to False.
    """
    self.process_func = process_func
    self.batch_size = batch_size
    self.pickle_file = pickle_file
    # Keys are string identifiers; this dict doubles as the checkpoint payload.
    self.processed_items: Dict[str, R] = {}
    self.recover_from_checkpoint = recover_from_checkpoint
    self.use_tqdm = use_tqdm

    # Set up logging
    # NOTE(review): this is the shared module-level logger, so clearing its
    # handlers below affects every instance created from this module.
    self.logger = logging.getLogger(__name__)
    self.logger.setLevel(logging.INFO)
    self.logger.handlers = []  # Clear any existing handlers
    formatter = logging.Formatter("%(asctime)s - %(message)s")

    # File handler (if logfile is provided); without one, records propagate
    # to whatever handlers ancestor loggers provide (possibly none).
    if logfile:
        file_handler = logging.FileHandler(logfile)
        file_handler.setFormatter(formatter)
        self.logger.addHandler(file_handler)

    # Recover from checkpoint if enabled; load_progress only logs when the
    # pickle file does not exist.
    if self.recover_from_checkpoint:
        self.load_progress()

load_progress()

Load progress from a checkpoint file if it exists.

Source code in src/batch_please/batchers.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
def load_progress(self):
    """
    Load progress from a checkpoint file if it exists.
    """
    checkpoint = self.pickle_file
    if not (checkpoint and os.path.exists(checkpoint)):
        # Nothing to restore: either no file configured or it is absent.
        self.logger.info(
            "No checkpoint file found or checkpoint recovery not enabled"
        )
        return
    with open(checkpoint, "rb") as fh:
        self.processed_items = pickle.load(fh)
    self.logger.info(
        f"Recovered {len(self.processed_items)} items from checkpoint"
    )

process_batch(batch, batch_number, total_jobs)

Process a batch of items.

Parameters:

Name Type Description Default
batch Union[List[T], Dict[str, Any]]

The batch of items to process, either a list of items or a dictionary where keys are identifiers and values are items or dictionaries of kwargs.

required
batch_number int

The number of the current batch.

required
total_jobs int

The total number of jobs to process.

required
Source code in src/batch_please/batchers.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def process_batch(
    self, batch: Union[List[T], Dict[str, Any]], batch_number: int, total_jobs: int
):
    """
    Process a batch of items.

    Args:
        batch (Union[List[T], Dict[str, Any]]): The batch of items to process, either a list of items
            or a dictionary where keys are identifiers and values are items or dictionaries of kwargs.
        batch_number (int): The number of the current batch.
        total_jobs (int): The total number of jobs to process.
    """
    # Wrap the iteration source with tqdm when requested, instead of
    # duplicating the result-building comprehension per setting.
    if isinstance(batch, dict):
        # Dict input: keys are identifiers, values are passed to process_item.
        pairs = list(batch.items())
        if self.use_tqdm:
            pairs = tqdm(pairs, desc=f"Batch {batch_number}")
        batch_results = {
            key: self.process_item(i, value)
            for i, (key, value) in enumerate(pairs)
        }
    else:
        # List input: results are keyed by str(item).
        source = tqdm(batch, desc=f"Batch {batch_number}") if self.use_tqdm else batch
        batch_results = {
            str(item): self.process_item(i, item)
            for i, item in enumerate(source)
        }

    self.processed_items.update(batch_results)

    # Checkpoint after every batch when a pickle file is configured.
    if self.pickle_file:
        self.save_progress()

    completion_message = f"Batch {batch_number} completed. Total processed: {len(self.processed_items)}/{total_jobs}"
    print(completion_message)
    self.logger.info(completion_message)

process_item(job_number, item)

Process a single item.

Parameters:

Name Type Description Default
job_number int

The number of the job being processed.

required
item Union[T, Dict[str, Any]]

The item to process, either a direct value or a dictionary of keyword arguments.

required

Returns:

Name Type Description
R R

The result of processing the item.

Source code in src/batch_please/batchers.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def process_item(self, job_number: int, item: Union[T, Dict[str, Any]]) -> R:
    """
    Process a single item.

    Args:
        job_number (int): The number of the job being processed.
        item (Union[T, Dict[str, Any]]): The item to process, either a direct value
            or a dictionary of keyword arguments.

    Returns:
        R: The result of processing the item.
    """
    if isinstance(item, dict):
        # Dict items are expanded into keyword arguments.
        kwargs = cast(Dict[str, Any], item)
        outcome = self.process_func(**kwargs)
    else:
        # Plain items are handed over as a single positional argument.
        outcome = self.process_func(item)  # type: ignore
    self.logger.info(f"Processed job {job_number}: {item}")
    return outcome

process_items_in_batches(input_items)

Process all input items in batches.

Parameters:

Name Type Description Default
input_items Union[Iterable[T], Dict[str, Any]]

The items to process. Can be either: - An iterable of items (each item is passed directly to the process function) - A dictionary where keys are identifiers and values are dictionaries of kwargs to be unpacked into the process function

required

Returns:

Type Description
Dict[str, R]

Dict[str, R]: A dictionary containing the processed items and their results. For iterable inputs, the keys are the string representation of each item. For dictionary inputs, the original keys are preserved.

Source code in src/batch_please/batchers.py
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
def process_items_in_batches(
    self, input_items: Union[Iterable[T], Dict[str, Any]]
) -> Dict[str, R]:
    """
    Process all input items in batches.

    Args:
        input_items (Union[Iterable[T], Dict[str, Any]]): The items to process. Can be either:
            - An iterable of items (each item is passed directly to the process function)
            - A dictionary where keys are identifiers and values are dictionaries of kwargs
              to be unpacked into the process function

    Returns:
        Dict[str, R]: A dictionary containing the processed items and their results.
            For iterable inputs, the keys are the string representation of each item.
            For dictionary inputs, the original keys are preserved.
    """
    step = self.batch_size
    if isinstance(input_items, dict):
        # Dict input: identifiers come from the dict keys.
        pending = cast(Dict[str, Any], input_items)
        if self.recover_from_checkpoint:
            # Drop entries already present in the recovered results.
            done = set(self.processed_items.keys())
            pending = {k: v for k, v in pending.items() if k not in done}
        ordered = list(pending.items())
        total_jobs = len(ordered)
        for start in range(0, total_jobs, step):
            chunk = dict(ordered[start : start + step])
            self.process_batch(chunk, start // step + 1, total_jobs)
    else:
        # Iterable input (original behavior): materialize once up front.
        remaining = list(input_items)
        if self.recover_from_checkpoint:
            # Recovery matches on str(item), mirroring the result keys.
            done = set(self.processed_items.keys())
            remaining = [it for it in remaining if str(it) not in done]
        total_jobs = len(remaining)
        for start in range(0, total_jobs, step):
            self.process_batch(
                remaining[start : start + step], start // step + 1, total_jobs
            )

    return self.processed_items

save_progress()

Save current progress to a checkpoint file.

Source code in src/batch_please/batchers.py
174
175
176
177
178
179
180
181
182
def save_progress(self):
    """
    Save current progress to a checkpoint file.

    A no-op when no ``pickle_file`` was configured; previously a direct call
    without one raised ``TypeError`` from ``open(None)``.
    """
    if not self.pickle_file:
        return
    with open(self.pickle_file, "wb") as f:
        pickle.dump(
            self.processed_items,
            f,
        )

AsyncBatchProcessor

batch_please.batchers.AsyncBatchProcessor

Bases: BatchProcessor[T, R]

An asynchronous version of the BatchProcessor.

This class processes items or dictionaries of kwargs in batches asynchronously, with optional concurrency limits. When using dictionary input, the keys are used as unique identifiers and the values are unpacked as keyword arguments to the processing function.

Source code in src/batch_please/batchers.py
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
class AsyncBatchProcessor(BatchProcessor[T, R]):
    """
    An asynchronous version of the BatchProcessor.

    This class processes items or dictionaries of kwargs in batches asynchronously,
    with optional concurrency limits. When using dictionary input, the keys are used
    as unique identifiers and the values are unpacked as keyword arguments to the
    processing function.
    """

    def __init__(
        self,
        process_func: Callable,
        batch_size: int = 100,
        pickle_file: Optional[str] = None,
        logfile: Optional[str] = None,
        recover_from_checkpoint: bool = False,
        max_concurrent: Optional[int] = None,
        use_tqdm: bool = False,
    ):
        """
        Initialize the AsyncBatchProcessor.

        Args:
            process_func (Callable): The async function to process each item. Can either:
                - Accept a single positional argument when processing iterables
                - Accept keyword arguments when processing dictionaries
                Must return an awaitable.
            batch_size (int, optional): The number of items to process in each batch. Defaults to 100.
            pickle_file (Optional[str], optional): The file to use for saving/loading progress. Defaults to None.
            logfile (Optional[str], optional): The file to use for logging. Defaults to None.
            recover_from_checkpoint (bool, optional): Whether to attempt to recover from a checkpoint. Defaults to False.
            max_concurrent (Optional[int], optional): The maximum number of concurrent operations. Defaults to None.
            use_tqdm (bool, optional): Whether to use tqdm progress bars. Defaults to False.
        """
        super().__init__(
            process_func,
            batch_size,
            pickle_file,
            logfile,
            recover_from_checkpoint,
            use_tqdm,
        )
        # None means unlimited concurrency; otherwise a semaphore caps the
        # number of in-flight process_item calls.
        self.semaphore = (
            asyncio.Semaphore(max_concurrent) if max_concurrent is not None else None
        )

    async def process_item(
        self, job_number: int, item: Union[T, Dict[str, Any]]
    ) -> Tuple[Union[T, str], R]:
        """
        Process a single item asynchronously.

        Args:
            job_number (int): The number of the job being processed.
            item (Union[T, Dict[str, Any]]): The item to process, either a direct value
                or a dictionary of keyword arguments.

        Returns:
            Tuple[Union[T, str], R]: A tuple containing the input item (or its key for dict inputs)
                and the result of processing it.
        """

        async def _process():
            if isinstance(item, dict):
                result = await self.process_func(**cast(Dict[str, Any], item))
            else:
                result = await self.process_func(item)  # type: ignore
            self.logger.info(f"Processed job {job_number}: {item}")
            return item, result

        if self.semaphore:
            async with self.semaphore:
                return await _process()
        else:
            return await _process()

    async def process_batch(
        self, batch: Union[List[T], Dict[str, Any]], batch_number: int, total_jobs: int
    ):
        """
        Process a batch of items asynchronously.

        Args:
            batch (Union[List[T], Dict[str, Any]]): The batch of items to process, either a list of items
                or a dictionary where keys are identifiers and values are items or dictionaries of kwargs.
            batch_number (int): The number of the current batch.
            total_jobs (int): The total number of jobs to process.
        """
        if self.use_tqdm:
            pbar = tqdm(total=len(batch), desc=f"Batch {batch_number}")

        # Global job numbering continues across batches.
        offset = (batch_number - 1) * self.batch_size
        batch_results: Dict[str, R] = {}

        if isinstance(batch, dict):
            # Dictionary input. Pair each coroutine with its key up front:
            # the previous implementation recovered the key by scanning the
            # batch for an identity match on the value, which misattributed
            # (or dropped) results when two keys shared the same value object
            # and cost O(n) per completed task.
            async def _keyed(key: str, job_number: int, value: Any):
                _, result = await self.process_item(job_number, value)
                return key, result

            tasks = [
                _keyed(key, i + offset, value)
                for i, (key, value) in enumerate(batch.items())
            ]
            for task in asyncio.as_completed(tasks):
                key, result = await task
                batch_results[key] = result
                if self.use_tqdm:
                    pbar.update(1)
        else:
            # List input (original behavior): results are keyed by str(item).
            tasks = [
                self.process_item(i + offset, item) for i, item in enumerate(batch)
            ]
            for task in asyncio.as_completed(tasks):
                item, result = await task
                batch_results[str(item)] = result
                if self.use_tqdm:
                    pbar.update(1)

        if self.use_tqdm:
            pbar.close()

        self.processed_items.update(batch_results)

        if self.pickle_file:
            self.save_progress()

        self.logger.info(
            f"Batch {batch_number} completed. Total processed: {len(self.processed_items)}/{total_jobs}"
        )

    async def process_items_in_batches(
        self, input_items: Union[Iterable[T], Dict[str, Any]]
    ) -> Dict[str, R]:
        """
        Process all input items in batches asynchronously.

        Args:
            input_items (Union[Iterable[T], Dict[str, Any]]): The items to process. Can be either:
                - An iterable of items (each item is passed directly to the process function)
                - A dictionary where keys are identifiers and values are dictionaries of kwargs
                  to be unpacked into the process function

        Returns:
            Dict[str, R]: A dictionary containing the processed items and their results.
                For iterable inputs, the keys are the string representation of each item.
                For dictionary inputs, the original keys are preserved.
        """
        is_dict_input = isinstance(input_items, dict)

        if is_dict_input:
            # Handle dictionary input
            dict_input = cast(Dict[str, Any], input_items)

            if self.recover_from_checkpoint:
                recovered_items = set(self.processed_items.keys())
                dict_input = {
                    k: v for k, v in dict_input.items() if k not in recovered_items
                }

            total_jobs = len(dict_input)
            dict_items = list(dict_input.items())

            for i in range(0, total_jobs, self.batch_size):
                batch_items = dict_items[i : i + self.batch_size]
                batch_dict = dict(batch_items)
                await self.process_batch(
                    batch_dict, i // self.batch_size + 1, total_jobs
                )
        else:
            # Handle iterable input (original behavior)
            list_input = list(input_items)  # Convert iterable to list
            if self.recover_from_checkpoint:
                recovered_items = set(self.processed_items.keys())
                # NOTE: recovery matching relies on str(item) as the key.
                list_input = [
                    item for item in list_input if str(item) not in recovered_items
                ]

            total_jobs = len(list_input)

            for i in range(0, total_jobs, self.batch_size):
                batch = list_input[i : i + self.batch_size]
                await self.process_batch(batch, i // self.batch_size + 1, total_jobs)

        return self.processed_items

__init__(process_func, batch_size=100, pickle_file=None, logfile=None, recover_from_checkpoint=False, max_concurrent=None, use_tqdm=False)

Initialize the AsyncBatchProcessor.

Parameters:

Name Type Description Default
process_func Callable

The async function to process each item. Can either: - Accept a single positional argument when processing iterables - Accept keyword arguments when processing dictionaries Must return an awaitable.

required
batch_size int

The number of items to process in each batch. Defaults to 100.

100
pickle_file Optional[str]

The file to use for saving/loading progress. Defaults to None.

None
logfile Optional[str]

The file to use for logging. Defaults to None.

None
recover_from_checkpoint bool

Whether to attempt to recover from a checkpoint. Defaults to False.

False
max_concurrent Optional[int]

The maximum number of concurrent operations. Defaults to None.

None
use_tqdm bool

Whether to use tqdm progress bars. Defaults to False.

False
Source code in src/batch_please/batchers.py
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
def __init__(
    self,
    process_func: Callable,
    batch_size: int = 100,
    pickle_file: Optional[str] = None,
    logfile: Optional[str] = None,
    recover_from_checkpoint: bool = False,
    max_concurrent: Optional[int] = None,
    use_tqdm: bool = False,
):
    """
    Initialize the AsyncBatchProcessor.

    Args:
        process_func (Callable): The async function invoked per item. It may
            take a single positional argument (iterable input) or keyword
            arguments (dictionary input), and must return an awaitable.
        batch_size (int, optional): Number of items per batch. Defaults to 100.
        pickle_file (Optional[str], optional): Path used to save/load progress. Defaults to None.
        logfile (Optional[str], optional): Path used for logging. Defaults to None.
        recover_from_checkpoint (bool, optional): Whether to resume from a saved checkpoint. Defaults to False.
        max_concurrent (Optional[int], optional): Cap on concurrent operations;
            None means unlimited. Defaults to None.
        use_tqdm (bool, optional): Whether to show tqdm progress bars. Defaults to False.
    """
    super().__init__(
        process_func,
        batch_size,
        pickle_file,
        logfile,
        recover_from_checkpoint,
        use_tqdm,
    )
    # Only create a semaphore when a concurrency cap was requested;
    # process_item skips the semaphore entirely when this stays None.
    if max_concurrent is None:
        self.semaphore = None
    else:
        self.semaphore = asyncio.Semaphore(max_concurrent)

process_batch(batch, batch_number, total_jobs) async

Process a batch of items asynchronously.

Parameters:

Name Type Description Default
batch Union[List[T], Dict[str, Any]]

The batch of items to process, either a list of items or a dictionary where keys are identifiers and values are items or dictionaries of kwargs.

required
batch_number int

The number of the current batch.

required
total_jobs int

The total number of jobs to process.

required
Source code in src/batch_please/batchers.py
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
async def process_batch(
    self, batch: Union[List[T], Dict[str, Any]], batch_number: int, total_jobs: int
):
    """
    Process a batch of items asynchronously.

    For dictionary batches, each result is recorded under the caller-supplied
    key; for list batches, results are keyed by ``str(item)``. After the batch
    completes, results are merged into ``self.processed_items``, progress is
    optionally pickled, and a summary line is logged.

    Args:
        batch (Union[List[T], Dict[str, Any]]): The batch of items to process, either a list of items
            or a dictionary where keys are identifiers and values are items or dictionaries of kwargs.
        batch_number (int): The number of the current batch.
        total_jobs (int): The total number of jobs to process.
    """
    if self.use_tqdm:
        pbar = tqdm(total=len(batch), desc=f"Batch {batch_number}")

    batch_results: Dict[str, Any] = {}

    if isinstance(batch, dict):
        # Dictionary input: pair each task with its original key up front.
        # (The previous implementation recovered the key afterwards by an
        # identity search over batch values, which assigned every result to
        # the first matching key when the same value object appeared under
        # several keys, and dropped the caller's key entirely when a value
        # was not a dict.)
        async def _keyed(key: str, job_number: int, value: Any):
            _, result = await self.process_item(job_number, value)
            return key, result

        tasks = [
            _keyed(key, i + (batch_number - 1) * self.batch_size, value)
            for i, (key, value) in enumerate(batch.items())
        ]

        for task in asyncio.as_completed(tasks):
            key, result = await task
            batch_results[key] = result
            if self.use_tqdm:
                pbar.update(1)
    else:
        # List input (original behavior): results keyed by str(item).
        tasks = [
            self.process_item(i + (batch_number - 1) * self.batch_size, item)
            for i, item in enumerate(batch)
        ]

        for task in asyncio.as_completed(tasks):
            item, result = await task
            batch_results[str(item)] = result
            if self.use_tqdm:
                pbar.update(1)

    if self.use_tqdm:
        pbar.close()

    self.processed_items.update(batch_results)

    if self.pickle_file:
        self.save_progress()

    self.logger.info(
        f"Batch {batch_number} completed. Total processed: {len(self.processed_items)}/{total_jobs}"
    )

process_item(job_number, item) async

Process a single item asynchronously.

Parameters:

Name Type Description Default
job_number int

The number of the job being processed.

required
item Union[T, Dict[str, Any]]

The item to process, either a direct value or a dictionary of keyword arguments.

required

Returns:

Type Description
Tuple[Union[T, str], R]

Tuple[Union[T, str], R]: A tuple containing the input item (or its key for dict inputs) and the result of processing it.

Source code in src/batch_please/batchers.py
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
async def process_item(
    self, job_number: int, item: Union[T, Dict[str, Any]]
) -> Tuple[Union[T, str], R]:
    """
    Process one item asynchronously, optionally gated by the semaphore.

    Dict items are unpacked into ``process_func`` as keyword arguments;
    anything else is passed as a single positional argument.

    Args:
        job_number (int): The number of the job being processed.
        item (Union[T, Dict[str, Any]]): The item to process, either a direct value
            or a dictionary of keyword arguments.

    Returns:
        Tuple[Union[T, str], R]: The input item paired with its processing result.
    """

    async def _run():
        if isinstance(item, dict):
            kwargs = cast(Dict[str, Any], item)
            outcome = await self.process_func(**kwargs)
        else:
            outcome = await self.process_func(item)  # type: ignore
        self.logger.info(f"Processed job {job_number}: {item}")
        return item, outcome

    # No concurrency cap configured: run directly.
    if self.semaphore is None:
        return await _run()
    # Otherwise bound concurrency with the shared semaphore.
    async with self.semaphore:
        return await _run()

process_items_in_batches(input_items) async

Process all input items in batches asynchronously.

Parameters:

Name Type Description Default
input_items Union[Iterable[T], Dict[str, Any]]

The items to process. This can be either an iterable of items (each item is passed directly to the process function) or a dictionary whose keys are identifiers and whose values are dictionaries of keyword arguments to be unpacked into the process function.

required

Returns:

Type Description
Dict[str, R]

Dict[str, R]: A dictionary containing the processed items and their results. For iterable inputs, the keys are the string representation of each item. For dictionary inputs, the original keys are preserved.

Source code in src/batch_please/batchers.py
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
async def process_items_in_batches(
    self, input_items: Union[Iterable[T], Dict[str, Any]]
) -> Dict[str, R]:
    """
    Process all input items in batches asynchronously.

    Args:
        input_items (Union[Iterable[T], Dict[str, Any]]): The items to process. Either
            an iterable of items (each passed directly to the process function) or a
            dictionary whose keys are identifiers and whose values are kwargs dicts
            to be unpacked into the process function.

    Returns:
        Dict[str, R]: The accumulated processing results. Iterable inputs are keyed
            by each item's string representation; dictionary inputs keep their keys.
    """
    if isinstance(input_items, dict):
        pending = cast(Dict[str, Any], input_items)

        # Skip any keys already completed in a previous run.
        if self.recover_from_checkpoint:
            done = set(self.processed_items.keys())
            pending = {k: v for k, v in pending.items() if k not in done}

        entries = list(pending.items())
        total_jobs = len(entries)

        batch_no = 0
        for start in range(0, total_jobs, self.batch_size):
            batch_no += 1
            chunk = dict(entries[start : start + self.batch_size])
            await self.process_batch(chunk, batch_no, total_jobs)
    else:
        # Iterable input (original behavior): materialize so we can slice.
        remaining = list(input_items)

        # Items are checkpointed under str(item), so filter on that form.
        if self.recover_from_checkpoint:
            done = set(self.processed_items.keys())
            remaining = [x for x in remaining if str(x) not in done]

        total_jobs = len(remaining)

        batch_no = 0
        for start in range(0, total_jobs, self.batch_size):
            batch_no += 1
            await self.process_batch(
                remaining[start : start + self.batch_size], batch_no, total_jobs
            )

    return self.processed_items

```