Build your own iterator over enumerate
Posted on Mon, 27 Jun 2016 in Python
A couple of days ago one of my colleagues asked me to implement a simple self-logging iterator inherited from enumerate. I tried to inherit directly from it and failed, because I had completely forgotten how to use the __new__ magic method. It was so embarrassing. I was in a hurry at the time, so I promised myself I would solve this puzzle later. And I can tell you it is a simple task: it takes only 18 lines of code.
Initial task
First of all, I have to explain why we need such an iterator and why the standard enumerate isn't enough.
In our project we had a lot of console tasks with a similar structure:
- Get a bunch of objects from the DB
- Log how many objects you have
- Do something with each object from the bunch, logging progress every X objects
- Log after finishing
In Python it looked like this:
iterable = get_bunch()
total = len(iterable)
print("total: {}".format(total))
for i, item in enumerate(iterable, start=1):
    try:
        func(item)
    except Exception as e:
        print("catch exception: {}".format(e))
    if not i % 100:
        print("done {} of {}".format(i, total))
print("Done!")
The tasks differed only in their log messages and in the func function, so the code was essentially copy-paste. We decided to refactor it.
Implementation
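Let's try to make a class inherited from enumerate. As I mentioned above, we have to override the __new__ method, because that is where enumerate does its work: its __new__ accepts only the iterable and start arguments and raises TypeError on anything else. Remember that, according to the documentation, if __new__() returns an instance of cls, then the new instance’s __init__() method will be invoked with the same arguments that were passed to the class constructor.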
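To see why __new__ has to be overridden, here is a minimal sketch with two hypothetical classes, NaiveLogEnumerate and FixedLogEnumerate, which are not part of the original code. The naive one overrides only __init__, so the extra step keyword reaches enumerate.__new__ and is rejected. The fixed one forwards only iterable and start to enumerate.__new__ and still receives step in __init__. The exact TypeError message differs between Python versions, so the sketch reports only the exception type.

```python
class NaiveLogEnumerate(enumerate):  # hypothetical, for illustration only
    # Only __init__ is overridden, so enumerate.__new__ still sees 'step'.
    def __init__(self, iterable, start=0, step=10):
        self.step = step


class FixedLogEnumerate(enumerate):  # hypothetical, for illustration only
    def __new__(cls, iterable, start=0, *args, **kwargs):
        # Forward only the arguments that enumerate.__new__ understands.
        return super().__new__(cls, iterable, start)

    def __init__(self, iterable, start=0, step=10):
        # Python calls __init__ with the same arguments as the constructor,
        # so 'step' arrives here intact.
        self.step = step


def try_naive():
    """Return the name of the exception raised by the naive subclass, if any."""
    try:
        NaiveLogEnumerate(['a', 'b'], step=5)
    except TypeError as exc:
        return type(exc).__name__
    return 'ok'


fixed = FixedLogEnumerate(['a', 'b'], step=5)
print(try_naive())   # TypeError
print(fixed.step)    # 5
print(list(fixed))   # [(0, 'a'), (1, 'b')]
```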
So our implementation looks like this:
class LogEnumerate(enumerate):
    def __new__(cls, iterable, start=1, *args, **kwargs):
        # enumerate.__new__ accepts only iterable and start,
        # so the extra logging arguments are dropped here
        return super().__new__(cls, iterable, start)

    def __init__(self, iterable, start=1, step=10,
                 start_message='', progress_message='', stop_message=''):
        # __init__ receives the same arguments as __new__
        self.progress_message = progress_message
        self.stop_message = stop_message
        self.step = step
        self.total = len(iterable)
        # log the total number of items, e.g. 'total: {}'
        print(start_message.format(self.total))

    def __next__(self):
        try:
            i, item = super().__next__()
            if not i % self.step:
                print(self.progress_message.format(i, self.total))
            # only the item is returned; the index is used for logging
            return item
        except StopIteration:
            print(self.stop_message)
            raise
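Here is a minimal usage sketch that replaces the original loop with LogEnumerate. It repeats the class so the block runs on its own. The list of numbers stands in for get_bunch(), and func is a hypothetical stand-in that fails on one item to show that exception handling stays in the loop body. Because this LogEnumerate yields only the item, not an (index, item) pair, the loop variable is just item.

```python
class LogEnumerate(enumerate):
    def __new__(cls, iterable, start=1, *args, **kwargs):
        # enumerate.__new__ accepts only iterable and start
        return super().__new__(cls, iterable, start)

    def __init__(self, iterable, start=1, step=10,
                 start_message='', progress_message='', stop_message=''):
        self.progress_message = progress_message
        self.stop_message = stop_message
        self.step = step
        self.total = len(iterable)
        print(start_message.format(self.total))

    def __next__(self):
        try:
            i, item = super().__next__()
            if not i % self.step:
                print(self.progress_message.format(i, self.total))
            return item
        except StopIteration:
            print(self.stop_message)
            raise


def func(item):
    # hypothetical stand-in for the real per-object work
    if item == 3:
        raise ValueError('bad item {}'.format(item))


items = list(range(1, 6))  # hypothetical stand-in for get_bunch()
for item in LogEnumerate(items, step=2,
                         start_message='total: {}',
                         progress_message='done {} of {}',
                         stop_message='Done!'):
    try:
        func(item)
    except Exception as e:
        print('catch exception: {}'.format(e))

# Expected output:
# total: 5
# done 2 of 5
# catch exception: bad item 3
# done 4 of 5
# Done!
```

Each task now only supplies its own messages and its own func, while counting and progress logging live in one place.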
Got a question? Hit me on Twitter: avkorablev