import codecs
import collections
import logging
import mimetypes
import netrc
import os
import re
import sys
import time
import typing
import warnings
from pathlib import Path
from urllib.request import getproxies

import sniffio

from ._types import PrimitiveData

if typing.TYPE_CHECKING:  # pragma: no cover
    from ._models import URL

_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
    {chr(c): "%{:02X}".format(c) for c in range(0x00, 0x1F + 1) if c != 0x1B}
)
_HTML5_FORM_ENCODING_RE = re.compile(
    r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
)


def normalize_header_key(
    value: typing.Union[str, bytes], encoding: typing.Optional[str] = None
) -> bytes:
    """
    Coerce str/bytes into a strictly byte-wise HTTP header key.
    """
    if isinstance(value, bytes):
        return value.lower()
    return value.encode(encoding or "ascii").lower()


def normalize_header_value(
    value: typing.Union[str, bytes], encoding: typing.Optional[str] = None
) -> bytes:
    """
    Coerce str/bytes into a strictly byte-wise HTTP header value.
    """
    if isinstance(value, bytes):
        return value
    return value.encode(encoding or "ascii")
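
# Illustrative behaviour (informal sketch, not part of the public API):
#     >>> normalize_header_key("Content-Type")
#     b'content-type'
#     >>> normalize_header_value(b"application/json")
#     b'application/json'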


def str_query_param(value: "PrimitiveData") -> str:
    """
    Coerce a primitive data type into a string value for query params.
    Note that we prefer JSON-style 'true'/'false' for boolean values here.
    """
    if value is True:
        return "true"
    elif value is False:
        return "false"
    elif value is None:
        return ""
    return str(value)
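
# Booleans and None follow JSON conventions rather than Python's str():
#     >>> str_query_param(True), str_query_param(None), str_query_param(23)
#     ('true', '', '23')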


def is_known_encoding(encoding: str) -> bool:
    """
    Return `True` if `encoding` is a known codec.
    """
    try:
        codecs.lookup(encoding)
    except LookupError:
        return False
    return True
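
# A quick sketch of the expected results:
#     >>> is_known_encoding("latin-1")
#     True
#     >>> is_known_encoding("not-a-codec")
#     False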


def format_form_param(name: str, value: typing.Union[str, bytes]) -> bytes:
    """
    Encode a name/value pair within a multipart form.
    """
    if isinstance(value, bytes):
        value = value.decode()

    def replacer(match: typing.Match[str]) -> str:
        return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)]

    value = _HTML5_FORM_ENCODING_RE.sub(replacer, value)
    return f'{name}="{value}"'.encode()
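
# Double quotes are percent-encoded per the HTML5 rules above, for example:
#     >>> format_form_param("name", 'say "hi"')
#     b'name="say %22hi%22"'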


# Null bytes; no need to recreate these on each call to guess_json_utf
_null = b"\x00"
_null2 = _null * 2
_null3 = _null * 3


def guess_json_utf(data: bytes) -> typing.Optional[str]:
    # JSON always starts with two ASCII characters, so detection is simply a
    # matter of counting the null bytes, and using their position and count
    # to determine the encoding. Also detect a BOM, if present.
    sample = data[:4]
    if sample in (codecs.BOM_UTF32_LE, codecs.BOM_UTF32_BE):
        return "utf-32"  # BOM included
    if sample[:3] == codecs.BOM_UTF8:
        return "utf-8-sig"  # BOM included, MS style (discouraged)
    if sample[:2] in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE):
        return "utf-16"  # BOM included
    nullcount = sample.count(_null)
    if nullcount == 0:
        return "utf-8"
    if nullcount == 2:
        if sample[::2] == _null2:  # 1st and 3rd are null
            return "utf-16-be"
        if sample[1::2] == _null2:  # 2nd and 4th are null
            return "utf-16-le"
        # Did not detect 2 valid UTF-16 ascii-range characters
    if nullcount == 3:
        if sample[:3] == _null3:
            return "utf-32-be"
        if sample[1:] == _null3:
            return "utf-32-le"
        # Did not detect a valid UTF-32 ascii-range character
    return None
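
# For example, UTF-16-LE JSON interleaves ASCII bytes with nulls:
#     >>> guess_json_utf(b'{"a": 1}')
#     'utf-8'
#     >>> guess_json_utf("{}".encode("utf-16-le"))  # b'{\x00}\x00'
#     'utf-16-le'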


class NetRCInfo:
    def __init__(self, files: typing.Optional[typing.List[str]] = None) -> None:
        if files is None:
            files = [os.getenv("NETRC", ""), "~/.netrc", "~/_netrc"]
        self.netrc_files = files

    @property
    def netrc_info(self) -> typing.Optional[netrc.netrc]:
        if not hasattr(self, "_netrc_info"):
            self._netrc_info = None
            for file_path in self.netrc_files:
                expanded_path = Path(file_path).expanduser()
                try:
                    if expanded_path.is_file():
                        self._netrc_info = netrc.netrc(str(expanded_path))
                        break
                except (netrc.NetrcParseError, IOError):  # pragma: nocover
                    # Issue while reading the netrc file, ignore...
                    pass
        return self._netrc_info

    def get_credentials(self, host: str) -> typing.Optional[typing.Tuple[str, str]]:
        if self.netrc_info is None:
            return None
        auth_info = self.netrc_info.authenticators(host)
        if auth_info is None or auth_info[2] is None:
            return None
        return (auth_info[0], auth_info[2])
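
# Hypothetical usage, assuming ~/.netrc contains a matching machine entry:
#     >>> NetRCInfo().get_credentials("example.org")
#     ('alice', 's3cret')
# Returns None if no entry (or no password) is found for the host.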


def get_ca_bundle_from_env() -> typing.Optional[str]:
    if "SSL_CERT_FILE" in os.environ:
        ssl_file = Path(os.environ["SSL_CERT_FILE"])
        if ssl_file.is_file():
            return str(ssl_file)
    if "SSL_CERT_DIR" in os.environ:
        ssl_path = Path(os.environ["SSL_CERT_DIR"])
        if ssl_path.is_dir():
            return str(ssl_path)
    return None


def parse_header_links(value: str) -> typing.List[typing.Dict[str, str]]:
    """
    Returns a list of parsed link headers, for more info see:
    https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link
    The generic syntax of those is:
    Link: < uri-reference >; param1=value1; param2="value2"
    So for instance:
    Link: '<http://.../front.jpeg>; type="image/jpeg",<http://.../back.jpeg>;'
    would return
        [
            {"url": "http://.../front.jpeg", "type": "image/jpeg"},
            {"url": "http://.../back.jpeg"},
        ]
    :param value: HTTP Link entity-header field
    :return: list of parsed link headers
    """
    links: typing.List[typing.Dict[str, str]] = []
    replace_chars = " '\""
    value = value.strip(replace_chars)
    if not value:
        return links
    for val in re.split(", *<", value):
        try:
            url, params = val.split(";", 1)
        except ValueError:
            url, params = val, ""
        link = {"url": url.strip("<> '\"")}
        for param in params.split(";"):
            try:
                key, value = param.split("=")
            except ValueError:
                break
            link[key.strip(replace_chars)] = value.strip(replace_chars)
        links.append(link)
    return links
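
# A concrete round-trip of the parsing above:
#     >>> parse_header_links('<https://example.com/next>; rel="next"')
#     [{'url': 'https://example.com/next', 'rel': 'next'}]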


SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}


def obfuscate_sensitive_headers(
    items: typing.Iterable[typing.Tuple[typing.AnyStr, typing.AnyStr]]
) -> typing.Iterator[typing.Tuple[typing.AnyStr, typing.AnyStr]]:
    for k, v in items:
        if to_str(k.lower()) in SENSITIVE_HEADERS:
            v = to_bytes_or_str("[secure]", match_type_of=v)
        yield k, v
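
# Sensitive values are masked while other headers pass through unchanged; the
# replacement matches the str/bytes type of the original value:
#     >>> list(obfuscate_sensitive_headers([("Authorization", "Bearer abc")]))
#     [('Authorization', '[secure]')]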


_LOGGER_INITIALIZED = False
TRACE_LOG_LEVEL = 5


class Logger(logging.Logger):
    # Stub for type checkers.
    def trace(self, message: str, *args: typing.Any, **kwargs: typing.Any) -> None:
        ...  # pragma: nocover


def get_logger(name: str) -> Logger:
    """
    Get a `logging.Logger` instance, and optionally
    set up debug logging based on the HTTPX_LOG_LEVEL environment variable.
    """
    global _LOGGER_INITIALIZED

    if not _LOGGER_INITIALIZED:
        _LOGGER_INITIALIZED = True
        logging.addLevelName(TRACE_LOG_LEVEL, "TRACE")

        log_level = os.environ.get("HTTPX_LOG_LEVEL", "").upper()
        if log_level in ("DEBUG", "TRACE"):
            logger = logging.getLogger("httpx")
            logger.setLevel(logging.DEBUG if log_level == "DEBUG" else TRACE_LOG_LEVEL)
            handler = logging.StreamHandler(sys.stderr)
            handler.setFormatter(
                logging.Formatter(
                    fmt="%(levelname)s [%(asctime)s] %(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                )
            )
            logger.addHandler(handler)

    logger = logging.getLogger(name)

    def trace(message: str, *args: typing.Any, **kwargs: typing.Any) -> None:
        logger.log(TRACE_LOG_LEVEL, message, *args, **kwargs)

    logger.trace = trace  # type: ignore

    return typing.cast(Logger, logger)
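
# Illustrative setup (hypothetical logger name; the env var must be set before
# the first get_logger() call, since initialization happens only once):
#     os.environ["HTTPX_LOG_LEVEL"] = "TRACE"
#     logger = get_logger("httpx.example")
#     logger.trace("connection acquired")  # emitted at the custom TRACE level (5)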


def port_or_default(url: "URL") -> typing.Optional[int]:
    if url.port is not None:
        return url.port
    return {"http": 80, "https": 443}.get(url.scheme)


def same_origin(url: "URL", other: "URL") -> bool:
    """
    Return `True` if the given URLs share the same origin.
    """
    return (
        url.scheme == other.scheme
        and url.host == other.host
        and port_or_default(url) == port_or_default(other)
    )
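
# Illustrative, assuming httpx.URL parsing; the default port is filled in, so
# an explicit :443 on an https URL still counts as the same origin:
#     >>> same_origin(URL("https://example.com:443"), URL("https://example.com"))
#     True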


def get_environment_proxies() -> typing.Dict[str, typing.Optional[str]]:
    """Gets proxy information from the environment"""

    # urllib.request.getproxies() falls back on System
    # Registry and Config for proxies on Windows and macOS.
    # We don't want to propagate non-HTTP proxies into
    # our configuration such as 'TRAVIS_APT_PROXY'.
    proxy_info = getproxies()
    mounts: typing.Dict[str, typing.Optional[str]] = {}

    for scheme in ("http", "https", "all"):
        if proxy_info.get(scheme):
            hostname = proxy_info[scheme]
            mounts[f"{scheme}://"] = (
                hostname if "://" in hostname else f"http://{hostname}"
            )

    no_proxy_hosts = [host.strip() for host in proxy_info.get("no", "").split(",")]
    for hostname in no_proxy_hosts:
        # See https://curl.haxx.se/libcurl/c/CURLOPT_NOPROXY.html for details
        # on how names in `NO_PROXY` are handled.
        if hostname == "*":
            # If NO_PROXY=* is used or if "*" occurs as any one of the comma
            # separated hostnames, then we should just bypass any information
            # from HTTP_PROXY, HTTPS_PROXY, ALL_PROXY, and always ignore
            # proxies.
            return {}
        elif hostname:
            # NO_PROXY=.google.com is marked as "all://*.google.com",
            # which disables "www.google.com" but not "google.com".
            # NO_PROXY=google.com is marked as "all://*google.com",
            # which disables "www.google.com" and "google.com".
            # (But not "wwwgoogle.com")
            mounts[f"all://*{hostname}"] = None

    return mounts
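
# As a sketch: with HTTPS_PROXY=http://127.0.0.1:3128 and NO_PROXY=example.com
# set in the environment, the result would be along the lines of:
#     {"https://": "http://127.0.0.1:3128", "all://*example.com": None}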


def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes:
    return value.encode(encoding) if isinstance(value, str) else value


def to_str(value: typing.Union[str, bytes], encoding: str = "utf-8") -> str:
    return value if isinstance(value, str) else value.decode(encoding)


def to_bytes_or_str(value: str, match_type_of: typing.AnyStr) -> typing.AnyStr:
    return value if isinstance(match_type_of, str) else value.encode()


def unquote(value: str) -> str:
    return value[1:-1] if value[0] == value[-1] == '"' else value


def guess_content_type(filename: typing.Optional[str]) -> typing.Optional[str]:
    if filename:
        return mimetypes.guess_type(filename)[0] or "application/octet-stream"
    return None
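
# The small coercion helpers in action:
#     >>> to_bytes("abc"), to_str(b"abc"), unquote('"quoted"')
#     (b'abc', 'abc', 'quoted')
#     >>> guess_content_type("photo.jpeg")
#     'image/jpeg'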


def peek_filelike_length(stream: typing.IO) -> int:
    """
    Given a file-like stream object, return its length in number of bytes
    without reading it into memory.
    """
    try:
        # Is it an actual file?
        fd = stream.fileno()
    except OSError:
        # No... Maybe it's something that supports random access, like `io.BytesIO`?
        try:
            # Assuming so, go to end of stream to figure out its length,
            # then put it back in place.
            offset = stream.tell()
            length = stream.seek(0, os.SEEK_END)
            stream.seek(offset)
        except OSError:
            # Not even that? Sorry, we're doomed...
            raise
        else:
            return length
    else:
        # Yup, seems to be an actual file.
        return os.fstat(fd).st_size
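
# In-memory streams take the seek/tell branch, and the read position is
# restored afterwards:
#     >>> import io
#     >>> buf = io.BytesIO(b"hello")
#     >>> peek_filelike_length(buf), buf.tell()
#     (5, 0)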


def flatten_queryparams(
    queryparams: typing.Mapping[
        str, typing.Union["PrimitiveData", typing.Sequence["PrimitiveData"]]
    ]
) -> typing.List[typing.Tuple[str, "PrimitiveData"]]:
    """
    Convert a mapping of query params into a flat list of two-tuples
    representing each item.
    Example:
    >>> flatten_queryparams({"q": "httpx", "tag": ["python", "dev"]})
    [("q", "httpx"), ("tag", "python"), ("tag", "dev")]
    """
    items = []

    for k, v in queryparams.items():
        if isinstance(v, collections.abc.Sequence) and not isinstance(v, (str, bytes)):
            for u in v:
                items.append((k, u))
        else:
            items.append((k, typing.cast("PrimitiveData", v)))

    return items


class Timer:
    async def _get_time(self) -> float:
        library = sniffio.current_async_library()
        if library == "trio":
            import trio

            return trio.current_time()
        elif library == "curio":  # pragma: nocover
            import curio

            return await curio.clock()

        import asyncio

        return asyncio.get_event_loop().time()

    def sync_start(self) -> None:
        self.started = time.perf_counter()

    async def async_start(self) -> None:
        self.started = await self._get_time()

    def sync_elapsed(self) -> float:
        now = time.perf_counter()
        return now - self.started

    async def async_elapsed(self) -> float:
        now = await self._get_time()
        return now - self.started
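
# A minimal synchronous sketch; the async variants mirror this shape but read
# the clock from the running async library via sniffio:
#     timer = Timer()
#     timer.sync_start()
#     ...  # do some work
#     elapsed_seconds = timer.sync_elapsed()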


class URLPattern:
    """
    A utility class currently used for making lookups against proxy keys...

    # Wildcard matching...
    >>> pattern = URLPattern("all")
    >>> pattern.matches(httpx.URL("http://example.com"))
    True

    # With scheme matching...
    >>> pattern = URLPattern("https")
    >>> pattern.matches(httpx.URL("https://example.com"))
    True
    >>> pattern.matches(httpx.URL("http://example.com"))
    False

    # With domain matching...
    >>> pattern = URLPattern("https://example.com")
    >>> pattern.matches(httpx.URL("https://example.com"))
    True
    >>> pattern.matches(httpx.URL("http://example.com"))
    False
    >>> pattern.matches(httpx.URL("https://other.com"))
    False

    # Wildcard scheme, with domain matching...
    >>> pattern = URLPattern("all://example.com")
    >>> pattern.matches(httpx.URL("https://example.com"))
    True
    >>> pattern.matches(httpx.URL("http://example.com"))
    True
    >>> pattern.matches(httpx.URL("https://other.com"))
    False

    # With port matching...
    >>> pattern = URLPattern("https://example.com:1234")
    >>> pattern.matches(httpx.URL("https://example.com:1234"))
    True
    >>> pattern.matches(httpx.URL("https://example.com"))
    False
    """

    def __init__(self, pattern: str) -> None:
        from ._models import URL

        if pattern and ":" not in pattern:
            warn_deprecated(
                "Proxy keys should use proper URL forms rather "
                "than plain scheme strings. "
                f'Instead of "{pattern}", use "{pattern}://"'
            )
            pattern += "://"

        url = URL(pattern)
        self.pattern = pattern
        self.scheme = "" if url.scheme == "all" else url.scheme
        self.host = "" if url.host == "*" else url.host
        self.port = url.port
        if not url.host or url.host == "*":
            self.host_regex: typing.Optional[typing.Pattern[str]] = None
        else:
            if url.host.startswith("*."):
                # *.example.com should match "www.example.com", but not "example.com"
                domain = re.escape(url.host[2:])
                self.host_regex = re.compile(f"^.+\\.{domain}$")
            elif url.host.startswith("*"):
                # *example.com should match "www.example.com" and "example.com"
                domain = re.escape(url.host[1:])
                self.host_regex = re.compile(f"^(.+\\.)?{domain}$")
            else:
                # example.com should match "example.com" but not "www.example.com"
                domain = re.escape(url.host)
                self.host_regex = re.compile(f"^{domain}$")

    def matches(self, other: "URL") -> bool:
        if self.scheme and self.scheme != other.scheme:
            return False
        if (
            self.host
            and self.host_regex is not None
            and not self.host_regex.match(other.host)
        ):
            return False
        if self.port is not None and self.port != other.port:
            return False
        return True

    @property
    def priority(self) -> tuple:
        """
        The priority allows URLPattern instances to be sortable, so that
        we can match from most specific to least specific.
        """
        # URLs with a port should take priority over URLs without a port.
        port_priority = 0 if self.port is not None else 1
        # Longer hostnames should match first.
        host_priority = -len(self.host)
        # Longer schemes should match first.
        scheme_priority = -len(self.scheme)
        return (port_priority, host_priority, scheme_priority)

    def __hash__(self) -> int:
        return hash(self.pattern)

    def __lt__(self, other: "URLPattern") -> bool:
        return self.priority < other.priority

    def __eq__(self, other: typing.Any) -> bool:
        return isinstance(other, URLPattern) and self.pattern == other.pattern
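
# Sorting patterns therefore yields most-specific-first ordering, which is how
# proxy mounts are intended to be matched (illustrative):
#     >>> sorted([URLPattern("all://"), URLPattern("https://example.com")])[0].pattern
#     'https://example.com'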


def warn_deprecated(message: str) -> None:  # pragma: nocover
    warnings.warn(message, DeprecationWarning, stacklevel=2)