import codecs
import collections
import logging
import mimetypes
import netrc
import os
import re
import sys
import time
import typing
import warnings
from pathlib import Path
from urllib.request import getproxies

import sniffio

from ._types import PrimitiveData

if typing.TYPE_CHECKING:  # pragma: no cover
    from ._models import URL


# Replacement table used when quoting a value inside a multipart form
# `name="..."` parameter: double quote is percent-encoded, backslash is
# doubled, and every C0 control character except ESC (0x1B) is
# percent-encoded.
_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
    {chr(c): "%{:02X}".format(c) for c in range(0x00, 0x1F + 1) if c != 0x1B}
)
# A single alternation regex matching any character that needs replacing.
_HTML5_FORM_ENCODING_RE = re.compile(
    r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()])
)


def normalize_header_key(
    value: typing.Union[str, bytes], encoding: typing.Optional[str] = None
) -> bytes:
    """
    Coerce str/bytes into a strictly byte-wise HTTP header key.

    The result is always lower-cased. `str` input is encoded with `encoding`,
    falling back to ASCII when no encoding is given.
    """
    if isinstance(value, bytes):
        return value.lower()
    return value.encode(encoding or "ascii").lower()


def normalize_header_value(
    value: typing.Union[str, bytes], encoding: typing.Optional[str] = None
) -> bytes:
    """
    Coerce str/bytes into a strictly byte-wise HTTP header value.

    `bytes` input is returned unchanged; `str` input is encoded with
    `encoding`, falling back to ASCII when no encoding is given.
    """
    if isinstance(value, bytes):
        return value
    return value.encode(encoding or "ascii")


def str_query_param(value: "PrimitiveData") -> str:
    """
    Coerce a primitive data type into a string value for query params.

    Note that we prefer JSON-style 'true'/'false' for boolean values here,
    and that `None` becomes the empty string.
    """
    if value is True:
        return "true"
    elif value is False:
        return "false"
    elif value is None:
        return ""
    return str(value)


def is_known_encoding(encoding: str) -> bool:
    """
    Return `True` if `encoding` is a known codec.
    """
    try:
        codecs.lookup(encoding)
    except LookupError:
        return False
    return True


def format_form_param(name: str, value: typing.Union[str, bytes]) -> bytes:
    """
    Encode a name/value pair within a multipart form.

    Returns `name="value"` encoded as UTF-8 bytes, with special characters in
    `value` replaced according to `_HTML5_FORM_ENCODING_REPLACEMENTS`.
    `bytes` values are first decoded using the default (UTF-8) codec.
    """
    if isinstance(value, bytes):
        value = value.decode()

    def replacer(match: typing.Match[str]) -> str:
        return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)]

    value = _HTML5_FORM_ENCODING_RE.sub(replacer, value)
    return f'{name}="{value}"'.encode()


# Null bytes; no need to recreate these on each call to guess_json_utf
_null = b"\x00"
_null2 = _null * 2
_null3 = _null * 3


def guess_json_utf(data: bytes) -> typing.Optional[str]:
    """
    Guess the Unicode encoding of a JSON payload from its first four bytes.

    Returns a codec name, or `None` if no encoding could be determined.
    """
    # JSON always starts with two ASCII characters, so detection is as
    # easy as counting the nulls and from their location and count
    # determine the encoding. Also detect a BOM, if present.
    sample = data[:4]
    if sample in (codecs.BOM_UTF32_LE, codecs.BOM_UTF32_BE):
        return "utf-32"  # BOM included
    if sample[:3] == codecs.BOM_UTF8:
        return "utf-8-sig"  # BOM included, MS style (discouraged)
    if sample[:2] in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE):
        return "utf-16"  # BOM included
    nullcount = sample.count(_null)
    if nullcount == 0:
        return "utf-8"
    if nullcount == 2:
        if sample[::2] == _null2:  # 1st and 3rd are null
            return "utf-16-be"
        if sample[1::2] == _null2:  # 2nd and 4th are null
            return "utf-16-le"
        # Did not detect 2 valid UTF-16 ascii-range characters
    if nullcount == 3:
        if sample[:3] == _null3:
            return "utf-32-be"
        if sample[1:] == _null3:
            return "utf-32-le"
        # Did not detect a valid UTF-32 ascii-range character
    return None


class NetRCInfo:
    """
    Lazily load credentials from the first readable netrc file.

    By default the candidate files are `$NETRC`, `~/.netrc` and `~/_netrc`,
    tried in that order. Files that fail to parse or read are skipped.
    """

    def __init__(self, files: typing.Optional[typing.List[str]] = None) -> None:
        if files is None:
            # An unset $NETRC yields "", which resolves to the current
            # directory and is therefore never picked up as a file.
            files = [os.getenv("NETRC", ""), "~/.netrc", "~/_netrc"]
        self.netrc_files = files

    @property
    def netrc_info(self) -> typing.Optional[netrc.netrc]:
        # Parsed once on first access and cached on the instance.
        if not hasattr(self, "_netrc_info"):
            self._netrc_info = None
            for file_path in self.netrc_files:
                expanded_path = Path(file_path).expanduser()
                try:
                    if expanded_path.is_file():
                        self._netrc_info = netrc.netrc(str(expanded_path))
                        break
                except (netrc.NetrcParseError, IOError):  # pragma: nocover
                    # Issue while reading the netrc file, ignore it and try
                    # the next candidate.
                    pass
        return self._netrc_info

    def get_credentials(self, host: str) -> typing.Optional[typing.Tuple[str, str]]:
        """
        Return `(login, password)` for `host`, or `None` if unavailable.
        """
        if self.netrc_info is None:
            return None

        # `authenticators` returns a (login, account, password) triple.
        auth_info = self.netrc_info.authenticators(host)
        if auth_info is None or auth_info[2] is None:
            return None
        return (auth_info[0], auth_info[2])


def get_ca_bundle_from_env() -> typing.Optional[str]:
    """
    Return a CA bundle path from the environment, if one exists on disk.

    `SSL_CERT_FILE` (must be a file) takes precedence over `SSL_CERT_DIR`
    (must be a directory). Returns `None` if neither applies.
    """
    if "SSL_CERT_FILE" in os.environ:
        ssl_file = Path(os.environ["SSL_CERT_FILE"])
        if ssl_file.is_file():
            return str(ssl_file)
    if "SSL_CERT_DIR" in os.environ:
        ssl_path = Path(os.environ["SSL_CERT_DIR"])
        if ssl_path.is_dir():
            return str(ssl_path)
    return None


def parse_header_links(value: str) -> typing.List[typing.Dict[str, str]]:
    """
    Returns a list of parsed link headers, for more info see:
    https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link

    The generic syntax of those is:
    Link: < uri-reference >; param1=value1; param2="value2"

    So for instance:
    Link: '<http:/.../front.jpeg>; type="image/jpeg",<http://.../back.jpeg>;'
    would return
        [
            {"url": "http:/.../front.jpeg", "type": "image/jpeg"},
            {"url": "http://.../back.jpeg"},
        ]

    :param value: HTTP Link entity-header field
    :return: list of parsed link headers
    """
    links: typing.List[typing.Dict[str, str]] = []
    replace_chars = " '\""
    value = value.strip(replace_chars)
    if not value:
        return links

    for val in re.split(", *<", value):
        try:
            url, params = val.split(";", 1)
        except ValueError:
            url, params = val, ""
        link = {"url": url.strip("<> '\"")}

        for param in params.split(";"):
            try:
                # Split on the first "=" only, so parameter values that
                # themselves contain "=" (e.g. a URL with a query string)
                # are kept intact instead of aborting parsing.
                key, param_value = param.split("=", 1)
            except ValueError:
                # Not a key=value pair (e.g. the empty segment after a
                # trailing ";"); stop reading params for this link.
                break
            link[key.strip(replace_chars)] = param_value.strip(replace_chars)

        links.append(link)

    return links


SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}


def obfuscate_sensitive_headers(
    items: typing.Iterable[typing.Tuple[typing.AnyStr, typing.AnyStr]]
) -> typing.Iterator[typing.Tuple[typing.AnyStr, typing.AnyStr]]:
    """
    Yield header pairs, masking values of headers in `SENSITIVE_HEADERS`.

    Masked values are replaced with "[secure]", as `str` or `bytes` to match
    the type of the original value.
    """
    for k, v in items:
        if to_str(k.lower()) in SENSITIVE_HEADERS:
            v = to_bytes_or_str("[secure]", match_type_of=v)
        yield k, v


_LOGGER_INITIALIZED = False
TRACE_LOG_LEVEL = 5


class Logger(logging.Logger):
    # Stub for type checkers.

    def trace(self, message: str, *args: typing.Any, **kwargs: typing.Any) -> None:
        ...  # pragma: nocover


def get_logger(name: str) -> Logger:
    """
    Get a `logging.Logger` instance, and optionally set up debug logging
    based on the HTTPX_LOG_LEVEL environment variable.
    """
    global _LOGGER_INITIALIZED

    # One-time, process-wide setup: register the TRACE level name and, if
    # requested via the environment, attach a stderr handler to "httpx".
    if not _LOGGER_INITIALIZED:
        _LOGGER_INITIALIZED = True
        logging.addLevelName(TRACE_LOG_LEVEL, "TRACE")

        log_level = os.environ.get("HTTPX_LOG_LEVEL", "").upper()
        if log_level in ("DEBUG", "TRACE"):
            logger = logging.getLogger("httpx")
            logger.setLevel(logging.DEBUG if log_level == "DEBUG" else TRACE_LOG_LEVEL)
            handler = logging.StreamHandler(sys.stderr)
            handler.setFormatter(
                logging.Formatter(
                    fmt="%(levelname)s [%(asctime)s] %(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                )
            )
            logger.addHandler(handler)

    logger = logging.getLogger(name)

    def trace(message: str, *args: typing.Any, **kwargs: typing.Any) -> None:
        logger.log(TRACE_LOG_LEVEL, message, *args, **kwargs)

    # Attach a `.trace()` convenience method to this logger instance.
    logger.trace = trace  # type: ignore

    return typing.cast(Logger, logger)


def port_or_default(url: "URL") -> typing.Optional[int]:
    """
    Return the URL's explicit port, or the default port for http/https.

    Returns `None` for other schemes without an explicit port.
    """
    if url.port is not None:
        return url.port
    return {"http": 80, "https": 443}.get(url.scheme)


def same_origin(url: "URL", other: "URL") -> bool:
    """
    Return 'True' if the given URLs share the same origin.
    """
    return (
        url.scheme == other.scheme
        and url.host == other.host
        and port_or_default(url) == port_or_default(other)
    )


def get_environment_proxies() -> typing.Dict[str, typing.Optional[str]]:
    """Gets proxy information from the environment"""

    # urllib.request.getproxies() falls back on System
    # Registry and Config for proxies on Windows and macOS.
    # We don't want to propagate non-HTTP proxies into
    # our configuration such as 'TRAVIS_APT_PROXY'.
    proxy_info = getproxies()
    mounts: typing.Dict[str, typing.Optional[str]] = {}

    for scheme in ("http", "https", "all"):
        if proxy_info.get(scheme):
            hostname = proxy_info[scheme]
            mounts[f"{scheme}://"] = (
                hostname if "://" in hostname else f"http://{hostname}"
            )

    no_proxy_hosts = [host.strip() for host in proxy_info.get("no", "").split(",")]
    for hostname in no_proxy_hosts:
        # See https://curl.haxx.se/libcurl/c/CURLOPT_NOPROXY.html for details
        # on how names in `NO_PROXY` are handled.
        if hostname == "*":
            # If NO_PROXY=* is used or if "*" occurs as any one of the comma
            # separated hostnames, then we should just bypass any information
            # from HTTP_PROXY, HTTPS_PROXY, ALL_PROXY, and always ignore
            # proxies.
            return {}
        elif hostname:
            # NO_PROXY=.google.com is marked as "all://*.google.com",
            #   which disables "www.google.com" but not "google.com"
            # NO_PROXY=google.com is marked as "all://*google.com",
            #   which disables "www.google.com" and "google.com".
            #   (But not "wwwgoogle.com")
            mounts[f"all://*{hostname}"] = None

    return mounts


def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes:
    """Encode `str` to `bytes`; `bytes` input is returned unchanged."""
    return value.encode(encoding) if isinstance(value, str) else value


def to_str(value: typing.Union[str, bytes], encoding: str = "utf-8") -> str:
    """Decode `bytes` to `str`; `str` input is returned unchanged."""
    return value if isinstance(value, str) else value.decode(encoding)


def to_bytes_or_str(value: str, match_type_of: typing.AnyStr) -> typing.AnyStr:
    """Return `value` as `str` or (UTF-8) `bytes`, matching `match_type_of`."""
    return value if isinstance(match_type_of, str) else value.encode()


def unquote(value: str) -> str:
    """
    Strip one pair of surrounding double quotes from `value`, if present.

    Strings shorter than two characters (including the empty string), or not
    wrapped in double quotes, are returned unchanged.
    """
    if len(value) >= 2 and value[0] == value[-1] == '"':
        return value[1:-1]
    return value


def guess_content_type(filename: typing.Optional[str]) -> typing.Optional[str]:
    """
    Guess a MIME type from `filename`.

    Falls back to "application/octet-stream" for unknown extensions, and
    returns `None` when no filename is given.
    """
    if filename:
        return mimetypes.guess_type(filename)[0] or "application/octet-stream"
    return None


def peek_filelike_length(stream: typing.IO) -> int:
    """
    Given a file-like stream object, return its length in number of bytes
    without reading it into memory.

    Raises `OSError` if the stream neither exposes a file descriptor nor
    supports `tell`/`seek`.
    """
    try:
        # Is it an actual file?
        fd = stream.fileno()
    except OSError:
        # No... Maybe it's something that supports random access, like
        # `io.BytesIO`? Go to the end of the stream to figure out its length,
        # then put it back in place. If this fails too, the OSError
        # propagates to the caller.
        offset = stream.tell()
        length = stream.seek(0, os.SEEK_END)
        stream.seek(offset)
        return length

    # Yup, seems to be an actual file.
    return os.fstat(fd).st_size


def flatten_queryparams(
    queryparams: typing.Mapping[
        str, typing.Union["PrimitiveData", typing.Sequence["PrimitiveData"]]
    ]
) -> typing.List[typing.Tuple[str, "PrimitiveData"]]:
    """
    Convert a mapping of query params into a flat list of two-tuples
    representing each item.

    Sequence values (other than str/bytes) are expanded into one tuple per
    element, preserving order.

    Example:
    >>> flatten_queryparams({"q": "httpx", "tag": ["python", "dev"]})
    [("q", "httpx"), ("tag", "python"), ("tag", "dev")]
    """
    items: typing.List[typing.Tuple[str, "PrimitiveData"]] = []

    for k, v in queryparams.items():
        if isinstance(v, collections.abc.Sequence) and not isinstance(v, (str, bytes)):
            for u in v:
                items.append((k, u))
        else:
            items.append((k, typing.cast("PrimitiveData", v)))

    return items


class Timer:
    """
    Measure elapsed time.

    Sync code uses `time.perf_counter`; async code uses the clock of the
    async library detected by `sniffio` (trio, curio, or asyncio).
    Call the matching `*_start` method before the `*_elapsed` one.
    """

    async def _get_time(self) -> float:
        library = sniffio.current_async_library()
        if library == "trio":
            import trio

            return trio.current_time()
        elif library == "curio":  # pragma: nocover
            import curio

            return await curio.clock()

        import asyncio

        return asyncio.get_event_loop().time()

    def sync_start(self) -> None:
        self.started = time.perf_counter()

    async def async_start(self) -> None:
        self.started = await self._get_time()

    def sync_elapsed(self) -> float:
        now = time.perf_counter()
        return now - self.started

    async def async_elapsed(self) -> float:
        now = await self._get_time()
        return now - self.started


class URLPattern:
    """
    A utility class currently used for making lookups against proxy keys...

    # Wildcard matching...
    >>> pattern = URLPattern("all")
    >>> pattern.matches(httpx.URL("http://example.com"))
    True

    # With scheme matching...
    >>> pattern = URLPattern("https")
    >>> pattern.matches(httpx.URL("https://example.com"))
    True
    >>> pattern.matches(httpx.URL("http://example.com"))
    False

    # With domain matching...
    >>> pattern = URLPattern("https://example.com")
    >>> pattern.matches(httpx.URL("https://example.com"))
    True
    >>> pattern.matches(httpx.URL("http://example.com"))
    False
    >>> pattern.matches(httpx.URL("https://other.com"))
    False

    # Wildcard scheme, with domain matching...
    >>> pattern = URLPattern("all://example.com")
    >>> pattern.matches(httpx.URL("https://example.com"))
    True
    >>> pattern.matches(httpx.URL("http://example.com"))
    True
    >>> pattern.matches(httpx.URL("https://other.com"))
    False

    # With port matching...
    >>> pattern = URLPattern("https://example.com:1234")
    >>> pattern.matches(httpx.URL("https://example.com:1234"))
    True
    >>> pattern.matches(httpx.URL("https://example.com"))
    False
    """

    def __init__(self, pattern: str) -> None:
        from ._models import URL

        if pattern and ":" not in pattern:
            warn_deprecated(
                f"Proxy keys should use proper URL forms rather "
                f"than plain scheme strings. "
                f'Instead of "{pattern}", use "{pattern}://"'
            )
            pattern += "://"

        url = URL(pattern)
        self.pattern = pattern
        # Empty scheme/host mean "match anything" for that component.
        self.scheme = "" if url.scheme == "all" else url.scheme
        self.host = "" if url.host == "*" else url.host
        self.port = url.port
        if not url.host or url.host == "*":
            self.host_regex: typing.Optional[typing.Pattern[str]] = None
        else:
            if url.host.startswith("*."):
                # *.example.com should match "www.example.com", but not "example.com"
                domain = re.escape(url.host[2:])
                self.host_regex = re.compile(f"^.+\\.{domain}$")
            elif url.host.startswith("*"):
                # *example.com should match "www.example.com" and "example.com"
                domain = re.escape(url.host[1:])
                self.host_regex = re.compile(f"^(.+\\.)?{domain}$")
            else:
                # example.com should match "example.com" but not "www.example.com"
                domain = re.escape(url.host)
                self.host_regex = re.compile(f"^{domain}$")

    def matches(self, other: "URL") -> bool:
        if self.scheme and self.scheme != other.scheme:
            return False
        if (
            self.host
            and self.host_regex is not None
            and not self.host_regex.match(other.host)
        ):
            return False
        if self.port is not None and self.port != other.port:
            return False
        return True

    @property
    def priority(self) -> tuple:
        """
        The priority allows URLPattern instances to be sortable, so that
        we can match from most specific to least specific.
        """
        # URLs with a port should take priority over URLs without a port.
        port_priority = 0 if self.port is not None else 1
        # Longer hostnames should match first.
        host_priority = -len(self.host)
        # Longer schemes should match first.
        scheme_priority = -len(self.scheme)
        return (port_priority, host_priority, scheme_priority)

    def __hash__(self) -> int:
        return hash(self.pattern)

    def __lt__(self, other: "URLPattern") -> bool:
        return self.priority < other.priority

    def __eq__(self, other: typing.Any) -> bool:
        return isinstance(other, URLPattern) and self.pattern == other.pattern


def warn_deprecated(message: str) -> None:  # pragma: nocover
    warnings.warn(message, DeprecationWarning, stacklevel=2)