#!/usr/bin/env python3
#
# Copyright 2022 Graviti. Licensed under MIT License.
#
"""Paging list related class."""
from itertools import repeat
from math import ceil
from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar
import pyarrow as pa
from graviti.paging.lists import MappedPagingList, PagingList, PyArrowPagingList
from graviti.paging.offset import Offsets
from graviti.paging.wrapper import StructArrayWrapper
_T = TypeVar("_T")
[docs]class LazyFactoryBase:
"""LazyFactoryBase is the base class of the lazy facotry."""
_patype: pa.DataType
def __getitem__(self, key: str) -> "LazySubFactory":
raise NotImplementedError
def __contains__(self, key: str) -> bool:
try:
self._patype.__getitem__(key)
return True
except KeyError:
return False
[docs] def create_list(self, mapper: Callable[[Any], _T]) -> PagingList[_T]:
"""Create a paging list from the factory.
Arguments:
mapper: A callable object to convert every item in the pyarrow array.
Raises:
NotImplementedError: The method of the base class should not be called.
"""
raise NotImplementedError
[docs] def create_mapped_list(self, mapper: Callable[[Any], _T]) -> MappedPagingList[_T]:
"""Create a paging list from the factory.
Arguments:
mapper: A callable object to convert every item in the pyarrow array.
Raises:
NotImplementedError: The method of the base class should not be called.
"""
raise NotImplementedError
[docs] def create_pyarrow_list(self) -> PyArrowPagingList[Any]:
"""Create a paging list from the factory.
Raises:
NotImplementedError: The method of the base class should not be called.
"""
raise NotImplementedError
[docs]class LazyFactory(LazyFactoryBase):
"""LazyFactory is a factory for requesting source data and creating paging lists.
Arguments:
total_count: The total count of the elements in the paging lists.
limit: The size of each lazy load page.
getter: A callable object to get the source data.
patype: The pyarrow DataType of the data in the factory.
Examples:
>>> import pyarrow as pa
>>> patype = pa.struct(
... {
... "remotePath": pa.string(),
... "label": pa.struct({"CLASSIFICATION": pa.struct({"category": pa.string()})}),
... }
... )
>>> TOTAL_COUNT = 1000
>>> def getter(offset: int, limit: int) -> List[Dict[str, Any]]:
... stop = min(offset + limit, TOTAL_COUNT)
... return [
... {
... "remotePath": f"{i:06}.jpg",
... "label": {"CLASSIFICATION": {"category": "cat" if i % 2 else "dog"}},
... }
... for i in range(offset, stop)
... ]
...
>>> factory = LazyFactory(TOTAL_COUNT, 128, getter, patype)
>>> paths = factory["remotePath"].create_pyarrow_list()
>>> categories = factory["label"]["CLASSIFICATION"]["category"].create_pyarrow_list()
>>> len(paths)
1000
>>> list(paths)
[<pyarrow.StringScalar: '000000.jpg'>,
<pyarrow.StringScalar: '000001.jpg'>,
<pyarrow.StringScalar: '000002.jpg'>,
<pyarrow.StringScalar: '000003.jpg'>,
<pyarrow.StringScalar: '000004.jpg'>,
<pyarrow.StringScalar: '000005.jpg'>,
...
...
<pyarrow.StringScalar: '000999.jpg'>]
>>> len(categories)
1000
>>> list(categories)
[<pyarrow.StringScalar: 'dog'>,
<pyarrow.StringScalar: 'cat'>,
<pyarrow.StringScalar: 'dog'>,
<pyarrow.StringScalar: 'cat'>,
<pyarrow.StringScalar: 'dog'>,
...
...
<pyarrow.StringScalar: 'cat'>]
"""
def __init__(
self, total_count: int, limit: int, getter: Callable[[int, int], Any], patype: pa.DataType
) -> None:
self._getter = getter
self._total_count = total_count
self._limit = limit
self._patype = patype
self._pages: List[Optional[pa.StructArray]] = [None] * ceil(total_count / limit)
def __getitem__(self, key: str) -> "LazySubFactory":
return LazySubFactory(self, (key,), self._patype[key].type)
[docs] def get_array(self, pos: int, keys: Tuple[str, ...]) -> pa.Array:
"""Get the array from the factory.
Arguments:
pos: The page number.
keys: The keys to access the array from factory.
Returns:
The requested pyarrow array.
"""
array = self._pages[pos]
if array is None:
array = pa.array(self._getter(pos * self._limit, self._limit), type=self._patype)
self._pages[pos] = array
for key in keys:
array = array.field(key)
return array
[docs] def create_list(self, mapper: Callable[[Any], _T]) -> PagingList[_T]:
"""Create a paging list from the factory.
Arguments:
mapper: A callable object to convert every item in the pyarrow array.
Returns:
A paging list created from the factory.
"""
return PagingList.from_factory(self, (), mapper)
[docs] def create_mapped_list(self, mapper: Callable[[Any], _T]) -> MappedPagingList[_T]:
"""Create a paging list from the factory.
Arguments:
mapper: A callable object to convert every item in the pyarrow array.
Returns:
A paging list created from the factory.
"""
return MappedPagingList.from_factory(self, (), mapper)
[docs] def create_pyarrow_list(self) -> PyArrowPagingList[Any]:
"""Create a paging list from the factory.
Returns:
A paging list created from the factory.
"""
return PyArrowPagingList.from_factory(self, (), self._patype)
[docs] def get_page_lengths(self) -> Iterator[int]:
"""A Generator which generates the length of the pages in the factory.
Yields:
The page lengths.
"""
div, mod = divmod(self._total_count, self._limit)
yield from repeat(self._limit, div)
if mod != 0:
yield mod
[docs] def get_offsets(self) -> Offsets:
"""Get the Offsets instance created by the total_count and limit of this factory.
Returns:
The Offsets instance created by the total_count and limit of this factory.
"""
return Offsets(self._total_count, self._limit)
[docs]class LazySubFactory(LazyFactoryBase):
"""LazySubFactory is a factory for creating paging lists.
Arguments:
factory: The source LazyFactory instance.
keys: The keys to access the array from the source LazyFactory.
patype: The pyarrow DataType of the data in the sub-factory.
"""
def __init__(self, factory: LazyFactory, keys: Tuple[str, ...], patype: pa.DataType) -> None:
self._factory = factory
self._keys = keys
self._patype = patype
def __getitem__(self, key: str) -> "LazySubFactory":
return LazySubFactory(self._factory, self._keys + (key,), self._patype[key].type)
[docs] def create_list(self, mapper: Callable[[Any], _T]) -> PagingList[_T]:
"""Create a paging list from the factory.
Arguments:
mapper: A callable object to convert every item in the pyarrow array.
Returns:
A paging list created from the factory.
"""
return PagingList.from_factory(self._factory, self._keys, mapper)
[docs] def create_mapped_list(self, mapper: Callable[[Any], _T]) -> MappedPagingList[_T]:
"""Create a paging list from the factory.
Arguments:
mapper: A callable object to convert every item in the pyarrow array.
Returns:
A paging list created from the factory.
"""
return MappedPagingList.from_factory(self._factory, self._keys, mapper)
[docs] def create_pyarrow_list(self) -> PyArrowPagingList[Any]:
"""Create a paging list from the factory.
Returns:
A paging list created from the factory.
"""
return PyArrowPagingList.from_factory(self._factory, self._keys, self._patype)
[docs]class LazyLowerCaseFactory(LazyFactory):
"""LazyLowerCaseFactory is a factory to handle the case insensitive data from graviti back-end.
Arguments:
total_count: The total count of the elements in the paging lists.
limit: The size of each lazy load page.
getter: A callable object to get the source data.
patype: The pyarrow DataType of the data in the factory.
"""
def __init__(
self, total_count: int, limit: int, getter: Callable[[int, int], Any], patype: pa.DataType
) -> None:
super().__init__(total_count, limit, getter, self._lower_patype(patype))
def __getitem__(self, key: str) -> "LazyLowerCaseSubFactory":
lower_key = key.lower()
return LazyLowerCaseSubFactory(self, (lower_key,), self._patype[lower_key].type)
def _lower_patype(self, patype: pa.DataType) -> pa.DataType:
if isinstance(patype, pa.StructType):
return pa.struct(
{field.name.lower(): self._lower_patype(field.type) for field in patype}
)
if isinstance(patype, pa.ListType):
return pa.list_(self._lower_patype(patype.value_type))
return patype
[docs] def get_array(self, pos: int, keys: Tuple[str, ...]) -> pa.Array:
"""Get the array from the factory.
Arguments:
pos: The page number.
keys: The keys to access the array from factory.
Returns:
The requested pyarrow array.
"""
array = self._pages[pos]
if array is None:
array = pa.array(self._getter(pos * self._limit, self._limit), type=self._patype)
array = StructArrayWrapper(array)
self._pages[pos] = array
for key in keys:
array = array.field(key)
return array
[docs]class LazyLowerCaseSubFactory(LazySubFactory):
"""LazyLowerCaseSubFactory is a sub-factory to handle the case insensitive data.
Arguments:
factory: The source LazyFactory instance.
keys: The keys to access the array from the source LazyFactory.
patype: The pyarrow DataType of the data in the sub-factory.
"""
def __getitem__(self, key: str) -> "LazyLowerCaseSubFactory":
lower_key = key.lower()
return LazyLowerCaseSubFactory(
self._factory, self._keys + (lower_key,), self._patype[lower_key].type
)