|
""" |
|
Telnet server. |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import asyncio |
|
import contextvars |
|
import socket |
|
from asyncio import get_running_loop |
|
from typing import Any, Callable, Coroutine, TextIO, cast |
|
|
|
from prompt_toolkit.application.current import create_app_session, get_app |
|
from prompt_toolkit.application.run_in_terminal import run_in_terminal |
|
from prompt_toolkit.data_structures import Size |
|
from prompt_toolkit.formatted_text import AnyFormattedText, to_formatted_text |
|
from prompt_toolkit.input import PipeInput, create_pipe_input |
|
from prompt_toolkit.output.vt100 import Vt100_Output |
|
from prompt_toolkit.renderer import print_formatted_text as print_formatted_text |
|
from prompt_toolkit.styles import BaseStyle, DummyStyle |
|
|
|
from .log import logger |
|
from .protocol import ( |
|
DO, |
|
ECHO, |
|
IAC, |
|
LINEMODE, |
|
MODE, |
|
NAWS, |
|
SB, |
|
SE, |
|
SEND, |
|
SUPPRESS_GO_AHEAD, |
|
TTYPE, |
|
WILL, |
|
TelnetProtocolParser, |
|
) |
|
|
|
# Names exported by `from ... import *`: only the server class is public.
__all__ = ["TelnetServer"]
|
|
|
|
|
def int2byte(number: int) -> bytes:
    """Return a ``bytes`` object of length one holding `number` (must be 0-255)."""
    return bytes([number])
|
|
|
|
|
def _initialize_telnet(connection: socket.socket) -> None:
    """
    Send the initial telnet option negotiation to a freshly accepted client.

    The commands are written in order, one command per call. Replies from the
    client are not read here; this function only sends.

    `socket.sendall` is used instead of `socket.send`: `send` may transmit only
    part of its argument (returning the count), which would leave the client
    with a truncated command sequence.
    """
    logger.info("Initializing telnet connection")

    negotiation = (
        # Ask the client to use LINEMODE.
        IAC + DO + LINEMODE,
        # Offer to suppress go-ahead signals.
        IAC + WILL + SUPPRESS_GO_AHEAD,
        # LINEMODE subnegotiation setting MODE to 0 (presumably: no local line
        # editing, i.e. character-at-a-time input -- see RFC 1184).
        IAC + SB + LINEMODE + MODE + int2byte(0) + IAC + SE,
        # Announce that the server will do the echoing.
        IAC + WILL + ECHO,
        # Ask the client to report its window size (NAWS).
        IAC + DO + NAWS,
        # Ask the client to negotiate its terminal type ...
        IAC + DO + TTYPE,
        # ... and request that it actually send it.
        IAC + SB + TTYPE + SEND + IAC + SE,
    )
    for command in negotiation:
        connection.sendall(command)
|
|
|
|
|
class _ConnectionStdout: |
|
""" |
|
Wrapper around socket which provides `write` and `flush` methods for the |
|
Vt100_Output output. |
|
""" |
|
|
|
def __init__(self, connection: socket.socket, encoding: str) -> None: |
|
self._encoding = encoding |
|
self._connection = connection |
|
self._errors = "strict" |
|
self._buffer: list[bytes] = [] |
|
self._closed = False |
|
|
|
def write(self, data: str) -> None: |
|
data = data.replace("\n", "\r\n") |
|
self._buffer.append(data.encode(self._encoding, errors=self._errors)) |
|
self.flush() |
|
|
|
def isatty(self) -> bool: |
|
return True |
|
|
|
def flush(self) -> None: |
|
try: |
|
if not self._closed: |
|
self._connection.send(b"".join(self._buffer)) |
|
except OSError as e: |
|
logger.warning(f"Couldn't send data over socket: {e}") |
|
|
|
self._buffer = [] |
|
|
|
def close(self) -> None: |
|
self._closed = True |
|
|
|
@property |
|
def encoding(self) -> str: |
|
return self._encoding |
|
|
|
@property |
|
def errors(self) -> str: |
|
return self._errors |
|
|
|
|
|
class TelnetConnection:
    """
    Class that represents one Telnet connection.

    On construction the telnet option negotiation is sent to the client.
    Incoming bytes go through a `TelnetProtocolParser`; decoded user input is
    forwarded to `vt100_input`. The `Vt100_Output` is only created once the
    client has reported its terminal type.
    """

    def __init__(
        self,
        conn: socket.socket,
        addr: tuple[str, int],
        interact: Callable[[TelnetConnection], Coroutine[Any, Any, None]],
        server: TelnetServer,
        encoding: str,
        style: BaseStyle | None,
        vt100_input: PipeInput,
        enable_cpr: bool = True,
    ) -> None:
        self.conn = conn
        self.addr = addr
        self.interact = interact
        self.server = server
        self.encoding = encoding
        self.style = style
        self._closed = False
        # Set when the terminal type arrives (output becomes usable) or when the
        # connection is closed, whichever happens first.
        self._ready = asyncio.Event()
        self.vt100_input = vt100_input
        self.enable_cpr = enable_cpr
        self.vt100_output: Vt100_Output | None = None

        # Default size, used until the client reports one (see `size_received`).
        self.size = Size(rows=40, columns=79)

        # Send the initial option negotiation to the client.
        _initialize_telnet(conn)

        def get_size() -> Size:
            return self.size

        self.stdout = cast(TextIO, _ConnectionStdout(conn, encoding=encoding))

        def data_received(data: bytes) -> None:
            """TelnetProtocolParser 'data_received' callback"""
            self.vt100_input.send_bytes(data)

        def size_received(rows: int, columns: int) -> None:
            """TelnetProtocolParser 'size_received' callback"""
            self.size = Size(rows=rows, columns=columns)
            # Only notify the application once it runs inside its app session
            # (`self.context` is captured in `run_application`).
            if self.vt100_output is not None and self.context is not None:
                self.context.run(lambda: get_app()._on_resize())

        def ttype_received(ttype: str) -> None:
            """TelnetProtocolParser 'ttype_received' callback"""
            self.vt100_output = Vt100_Output(
                self.stdout, get_size, term=ttype, enable_cpr=enable_cpr
            )
            self._ready.set()

        self.parser = TelnetProtocolParser(data_received, size_received, ttype_received)
        self.context: contextvars.Context | None = None

    async def run_application(self) -> None:
        """
        Run application.

        Registers a reader for the socket, waits until the client's terminal
        type is known, then runs `interact` inside an app session bound to this
        connection's input and output. The connection is closed when this
        returns.
        """

        def handle_incoming_data() -> None:
            try:
                data = self.conn.recv(1024)
            except OSError as e:
                # E.g. connection reset by peer. Close, so the reader doesn't stay
                # registered and keep firing on the broken socket.
                logger.info("Connection error %r %r: %s", *self.addr, e)
                self.close()
                return

            if data:
                self.feed(data)
            else:
                # An empty read means the client closed the connection.
                logger.info("Connection closed by client. %r %r", *self.addr)
                self.close()

        loop = get_running_loop()
        loop.add_reader(self.conn, handle_incoming_data)

        try:
            # Wait for `vt100_output` to be created (terminal type received).
            # `close()` also sets this event, so a client that disconnects before
            # negotiation completes doesn't leave this task waiting forever.
            # NOTE: a client that stays connected but never sends its terminal
            # type still keeps us waiting here.
            await self._ready.wait()
            if self._closed:
                return

            with create_app_session(input=self.vt100_input, output=self.vt100_output):
                self.context = contextvars.copy_context()
                await self.interact(self)
        finally:
            self.close()

    def feed(self, data: bytes) -> None:
        """
        Handler for incoming data. (Called by the socket reader registered in
        `run_application`.)
        """
        self.parser.feed(data)

    def close(self) -> None:
        """
        Close the connection. Safe to call more than once.

        Closes the input pipe, unregisters the socket reader, closes the socket
        and the stdout wrapper, and wakes up `run_application` if it is still
        waiting for the terminal type. Called when the client disconnects and
        when `run_application` finishes.
        """
        if not self._closed:
            self._closed = True

            self.vt100_input.close()
            get_running_loop().remove_reader(self.conn)
            self.conn.close()
            self.stdout.close()
            self._ready.set()

    def send(self, formatted_text: AnyFormattedText) -> None:
        """
        Send text to the client.

        Does nothing while the output isn't available yet (terminal type not
        received).
        """
        if self.vt100_output is None:
            return
        formatted_text = to_formatted_text(formatted_text)
        print_formatted_text(
            self.vt100_output, formatted_text, self.style or DummyStyle()
        )

    def send_above_prompt(self, formatted_text: AnyFormattedText) -> None:
        """
        Send text to the client while its application is running.

        The text is printed through `run_in_terminal`, inside the application's
        context. Returns nothing. Raises `RuntimeError` when called before
        `run_application` has started the interaction.
        """
        formatted_text = to_formatted_text(formatted_text)
        self._run_in_terminal(lambda: self.send(formatted_text))

    def _run_in_terminal(self, func: Callable[[], None]) -> None:
        # Run `run_in_terminal` inside the app-session context captured in
        # `run_application`, so it targets this connection's application.
        if self.context is not None:
            self.context.run(run_in_terminal, func)
        else:
            raise RuntimeError("Called _run_in_terminal outside `run_application`.")

    def erase_screen(self) -> None:
        """
        Erase the screen and move the cursor to the top.
        """
        if self.vt100_output is None:
            return
        self.vt100_output.erase_screen()
        self.vt100_output.cursor_goto(0, 0)
        self.vt100_output.flush()
|
|
|
|
|
async def _dummy_interact(connection: TelnetConnection) -> None: |
|
pass |
|
|
|
|
|
class TelnetServer:
    """
    Telnet server implementation.

    Example::

        async def interact(connection):
            connection.send("Welcome")
            session = PromptSession()
            result = await session.prompt_async(message="Say something: ")
            connection.send(f"You said: {result}\n")

        async def main():
            server = TelnetServer(interact=interact, port=2323)
            await server.run()
    """

    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 23,
        interact: Callable[
            [TelnetConnection], Coroutine[Any, Any, None]
        ] = _dummy_interact,
        encoding: str = "utf-8",
        style: BaseStyle | None = None,
        enable_cpr: bool = True,
    ) -> None:
        self.host = host
        self.port = port
        self.interact = interact
        self.encoding = encoding
        self.style = style
        self.enable_cpr = enable_cpr

        # Task created by the deprecated `start()` API.
        self._run_task: asyncio.Task[None] | None = None
        # One task per connected client; each removes itself when done.
        self._application_tasks: list[asyncio.Task[None]] = []

        self.connections: set[TelnetConnection] = set()

    @classmethod
    def _create_socket(cls, host: str, port: int) -> socket.socket:
        """
        Create a listening IPv4 TCP socket bound to (`host`, `port`).

        The socket is closed again if configuring, binding or listening fails,
        so e.g. a busy port doesn't leak a file descriptor.
        """
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind((host, port))
            s.listen(4)
        except BaseException:
            s.close()
            raise
        return s

    async def run(self, ready_cb: Callable[[], None] | None = None) -> None:
        """
        Run the telnet server, until this gets cancelled.

        :param ready_cb: Callback that will be called at the point that we're
            actually listening.
        """
        # Named `listen_socket` to avoid shadowing the `socket` module.
        listen_socket = self._create_socket(self.host, self.port)
        logger.info(
            "Listening for telnet connections on %s port %r", self.host, self.port
        )

        loop = get_running_loop()
        loop.add_reader(listen_socket, lambda: self._accept(listen_socket))

        try:
            # Inside `try`, so a failing callback still triggers the cleanup below.
            if ready_cb:
                ready_cb()

            # Run forever, until cancelled.
            await asyncio.Future()
        finally:
            loop.remove_reader(listen_socket)
            listen_socket.close()

            # Cancel every running application and wait for all to finish.
            for t in self._application_tasks:
                t.cancel()

            # `asyncio.wait` doesn't raise the tasks' exceptions; each application
            # task handles and reports its own errors.
            if self._application_tasks:
                await asyncio.wait(
                    self._application_tasks,
                    timeout=None,
                    return_when=asyncio.ALL_COMPLETED,
                )

    def start(self) -> None:
        """
        Deprecated: Use `.run()` instead.

        Start the telnet server (stop by calling and awaiting `stop()`).
        """
        if self._run_task is not None:
            # Already started.
            return

        self._run_task = get_running_loop().create_task(self.run())

    async def stop(self) -> None:
        """
        Deprecated: Use `.run()` instead.

        Stop a telnet server that was started using `.start()` and wait for the
        cancellation to complete.
        """
        if self._run_task is not None:
            self._run_task.cancel()
            try:
                await self._run_task
            except asyncio.CancelledError:
                pass

    def _accept(self, listen_socket: socket.socket) -> None:
        """
        Accept new incoming connection.

        Spawns a task running the connection's application; the task is kept in
        `_application_tasks` until it finishes.
        """
        conn, addr = listen_socket.accept()
        logger.info("New connection %r %r", *addr)

        async def run() -> None:
            try:
                with create_pipe_input() as vt100_input:
                    connection = TelnetConnection(
                        conn,
                        addr,
                        self.interact,
                        self,
                        encoding=self.encoding,
                        style=self.style,
                        vt100_input=vt100_input,
                        enable_cpr=self.enable_cpr,
                    )
                    self.connections.add(connection)

                    logger.info("Starting interaction %r %r", *addr)
                    try:
                        await connection.run_application()
                    finally:
                        self.connections.remove(connection)
                        logger.info("Stopping interaction %r %r", *addr)
            except EOFError:
                # Presumably raised when the input ends (e.g. client disconnected
                # or control-d in a prompt) and `interact` didn't handle it.
                logger.info("Unhandled EOFError in telnet application.")
            except KeyboardInterrupt:
                logger.info("Unhandled KeyboardInterrupt in telnet application.")
            except asyncio.CancelledError:
                # Normal shutdown: `run()` cancels all application tasks. Don't
                # report it as a crash, and let the cancellation propagate.
                logger.info("Telnet application cancelled %r %r", *addr)
                raise
            except BaseException as e:
                print(f"Got {type(e).__name__}", e)
                import traceback

                traceback.print_exc()
            finally:
                # Release the socket even if `TelnetConnection` couldn't be
                # created. Closing an already closed socket is a no-op.
                conn.close()
                self._application_tasks.remove(task)

        task = get_running_loop().create_task(run())
        self._application_tasks.append(task)
|
|