# tsuserver3, an Attorney Online server # # Copyright (C) 2017 argoneus # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . # Partly authored by Johan Hanssen Seferidis (MIT license): # https://github.com/Pithikos/python-websocket-server import asyncio import re import struct from base64 import b64encode from hashlib import sha1 from server import logger class Bitmasks: FIN = 0x80 OPCODE = 0x0f MASKED = 0x80 PAYLOAD_LEN = 0x7f PAYLOAD_LEN_EXT16 = 0x7e PAYLOAD_LEN_EXT64 = 0x7f class Opcode: CONTINUATION = 0x0 TEXT = 0x1 BINARY = 0x2 CLOSE_CONN = 0x8 PING = 0x9 PONG = 0xA class WebSocket: """ State data for clients that are connected via a WebSocket that wraps over a conventional TCP connection. """ def __init__(self, client, protocol): self.client = client self.transport = client.transport self.protocol = protocol self.keep_alive = True self.handshake_done = False self.valid = False def handle(self, data): if not self.handshake_done: return self.handshake(data) return self.parse(data) def parse(self, data): b1, b2 = 0, 0 if len(data) >= 2: b1, b2 = data[0], data[1] fin = b1 & Bitmasks.FIN opcode = b1 & Bitmasks.OPCODE masked = b2 & Bitmasks.MASKED payload_length = b2 & Bitmasks.PAYLOAD_LEN if not b1: # Connection closed self.keep_alive = 0 return if opcode == Opcode.CLOSE_CONN: # Connection close requested self.keep_alive = 0 return if not masked: # Client was not masked (spec violation) logger.log_debug("ws: client was not masked.", self.client) self.keep_alive = 0 print(data) return if opcode == Opcode.CONTINUATION: # No continuation frames supported logger.log_debug("ws: client tried to send continuation frame.", self.client) return elif opcode == Opcode.BINARY: # No binary frames supported logger.log_debug("ws: client tried to send binary frame.", self.client) return elif opcode == Opcode.TEXT: def opcode_handler(s, msg): return msg elif opcode == Opcode.PING: opcode_handler = self.send_pong elif opcode == Opcode.PONG: opcode_handler = lambda s, msg: None else: # Unknown opcode logger.log_debug("ws: unknown opcode!", self.client) self.keep_alive = 0 return if payload_length == 126: payload_length = struct.unpack(">H", data[2:4])[0] elif payload_length == 127: payload_length = struct.unpack(">Q", data[2:10])[0] masks = data[2:6] decoded = "" for char in data[6:payload_length + 6]: char ^= masks[len(decoded) % 4] decoded += chr(char) return opcode_handler(self, decoded) def send_message(self, message): self.send_text(message) def send_pong(self, message): self.send_text(message, Opcode.PONG) def send_text(self, message, opcode=Opcode.TEXT): """ Important: Fragmented (continuation) messages are not supported since their usage cases are limited - when we don't know the payload length. """ # Validate message if isinstance(message, bytes): message = message.decode("utf-8") elif isinstance(message, str): pass else: raise TypeError("Message must be either str or bytes") header = bytearray() payload = message.encode("utf-8") payload_length = len(payload) # Normal payload if payload_length <= 125: header.append(Bitmasks.FIN | opcode) header.append(payload_length) # Extended payload elif payload_length >= 126 and payload_length <= 65535: header.append(Bitmasks.FIN | opcode) header.append(Bitmasks.PAYLOAD_LEN_EXT16) header.extend(struct.pack(">H", payload_length)) # Huge extended payload elif payload_length < (1 << 64): header.append(Bitmasks.FIN | opcode) header.append(Bitmasks.PAYLOAD_LEN_EXT64) header.extend(struct.pack(">Q", payload_length)) else: raise Exception("Message is too big") self.transport.write(header + payload) def handshake(self, data): try: message = data[0:1024].decode().strip() except UnicodeDecodeError: return False upgrade = re.search('\nupgrade[\s]*:[\s]*websocket', message.lower()) if not upgrade: self.keep_alive = False return False key = re.search('\n[sS]ec-[wW]eb[sS]ocket-[kK]ey[\s]*:[\s]*(.*)\r\n', message) if key: key = key.group(1) else: logger.log_debug("Client tried to connect but was missing a key", self.client) self.keep_alive = False return False response = self.make_handshake_response(key) print(response.encode()) self.transport.write(response.encode()) self.handshake_done = True self.valid = True return True def make_handshake_response(self, key): return \ 'HTTP/1.1 101 Switching Protocols\r\n'\ 'Upgrade: websocket\r\n' \ 'Connection: Upgrade\r\n' \ 'Sec-WebSocket-Accept: %s\r\n' \ '\r\n' % self.calculate_response_key(key) def calculate_response_key(self, key): GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' hash = sha1(key.encode() + GUID.encode()) response_key = b64encode(hash.digest()).strip() return response_key.decode('ASCII') def finish(self): self.protocol.connection_lost(self)