File size: 4,348 Bytes
ab4488b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# SPDX-License-Identifier: MIT
# Copyright (c) 2019 Martijn Pieters
# Licensed under the MIT license as detailed in LICENSE.txt
import asyncio
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import Dict, Optional, Type
from .compat import wait_for
class AsyncLimiter(AbstractAsyncContextManager):
    """A leaky bucket rate limiter.

    This is an :ref:`asynchronous context manager <async-context-managers>`;
    when used with :keyword:`async with`, entering the context acquires
    capacity::

        limiter = AsyncLimiter(10)
        for foo in bar:
            async with limiter:
                # process foo elements at 10 items per minute
                ...

    :param max_rate: Allow up to `max_rate` / `time_period` acquisitions before
       blocking.
    :param time_period: duration, in seconds, of the time period in which to
       limit the rate. Note that up to `max_rate` acquisitions are allowed
       within this time period in a burst.

    """

    __slots__ = (
        "max_rate",
        "time_period",
        "_rate_per_sec",
        "_level",
        "_last_check",
        "_waiters",
    )

    max_rate: float  #: The configured `max_rate` value for this limiter.
    time_period: float  #: The configured `time_period` value for this limiter.

    def __init__(self, max_rate: float, time_period: float = 60) -> None:
        self.max_rate = max_rate
        self.time_period = time_period
        # capacity that drains out of the bucket per second
        self._rate_per_sec = max_rate / time_period
        # current fill level of the bucket (capacity currently in use)
        self._level = 0.0
        # event-loop time at which _leak() last ran
        self._last_check = 0.0
        # queue of waiting futures to signal capacity to; dicts keep insertion
        # order, so iteration visits the earliest-registered waiter first
        self._waiters: Dict[asyncio.Task, asyncio.Future] = {}

    def _leak(self) -> None:
        """Drip out capacity from the bucket.

        Lowers the fill level by the capacity that has drained since the
        previous call (never below zero) and records the current loop time.
        Must be called from within a running event loop.
        """
        # Read the clock once: with two separate reads, the time between
        # computing ``elapsed`` and recording ``_last_check`` would never be
        # credited to the leak, so capacity would drain slightly too slowly.
        now = asyncio.get_running_loop().time()
        if self._level:
            # drip out enough level for the elapsed time since
            # we last checked
            elapsed = now - self._last_check
            decrement = elapsed * self._rate_per_sec
            self._level = max(self._level - decrement, 0.0)
        self._last_check = now

    def has_capacity(self, amount: float = 1) -> bool:
        """Check if there is enough capacity remaining in the limiter.

        As a side effect this drains the bucket for the elapsed time. If
        there is room beyond `amount`, it also wakes the first task still
        waiting in :meth:`acquire`.

        :param amount: How much capacity you need to be available.
        :return: ``True`` if acquiring `amount` would not exceed
           :attr:`max_rate`.
        """
        self._leak()
        requested = self._level + amount
        # if there are tasks waiting for capacity, signal to the first
        # that there may be some now (they won't wake up until this task
        # yields with an await)
        if requested < self.max_rate:
            for fut in self._waiters.values():
                if not fut.done():
                    fut.set_result(True)
                    break
        return requested <= self.max_rate

    async def acquire(self, amount: float = 1) -> None:
        """Acquire capacity in the limiter.

        If the limit has been reached, blocks until enough capacity has been
        freed before returning.

        :param amount: How much capacity you need to be available.
        :exception: Raises :exc:`ValueError` if `amount` is greater than
           :attr:`max_rate`.
        """
        if amount > self.max_rate:
            raise ValueError("Can't acquire more than the maximum capacity")

        loop = asyncio.get_running_loop()
        task = asyncio.current_task(loop)
        assert task is not None
        try:
            while not self.has_capacity(amount):
                # wait for the next drip to have left the bucket
                # add a future to the _waiters map to be notified
                # 'early' if capacity has come up; re-assigning an existing
                # key keeps this task's original position in the map
                fut = loop.create_future()
                self._waiters[task] = fut
                try:
                    await wait_for(
                        asyncio.shield(fut),
                        1 / self._rate_per_sec * amount,
                        loop=loop,
                    )
                except asyncio.TimeoutError:
                    pass
                finally:
                    # the shield keeps fut pending if the wait is interrupted;
                    # mark it done so has_capacity() never signals it
                    fut.cancel()
        finally:
            # Always deregister, including when the wait is cancelled or
            # raises. A stale entry would keep the task alive and could absorb
            # a wake-up signal meant for a live waiter.
            self._waiters.pop(task, None)

        self._level += amount

        return None

    async def __aenter__(self) -> None:
        """Acquire one unit of capacity when entering ``async with``."""
        await self.acquire()
        return None

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        """Leave the context without giving capacity back.

        Capacity is only freed by leaking over time. Returning ``None`` lets
        any exception raised in the body propagate.
        """
        return None
|