Mypal/testing/tools/iceserver/iceserver.py

759 lines
27 KiB
Python

# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
import ipaddr
import socket
import hmac
import hashlib
import passlib.utils # for saslprep
import copy
import random
import operator
import platform
import time
from string import Template
from twisted.internet import reactor, protocol
from twisted.internet.task import LoopingCall
from twisted.internet.address import IPv4Address
# Protocol constants used throughout this file.
# Cookie value placed in every built message (StunMessage.cookie) and required
# in the cookie field of every parsed message.
MAGIC_COOKIE = 0x2112A442
# Values for StunMessage.msg_class.
REQUEST = 0
INDICATION = 1
SUCCESS_RESPONSE = 2
ERROR_RESPONSE = 3
# Values for StunMessage.method.
BINDING = 0x001
ALLOCATE = 0x003
REFRESH = 0x004
SEND = 0x006
DATA_MSG = 0x007
CREATE_PERMISSION = 0x008
CHANNEL_BIND = 0x009
# Address family codes carried in the XOR address attributes.
IPV4 = 1
IPV6 = 2
# Values for StunAttribute.attr_type.
MAPPED_ADDRESS = 0x0001
USERNAME = 0x0006
MESSAGE_INTEGRITY = 0x0008
ERROR_CODE = 0x0009
UNKNOWN_ATTRIBUTES = 0x000A
LIFETIME = 0x000D
DATA_ATTR = 0x0013
XOR_PEER_ADDRESS = 0x0012
REALM = 0x0014
NONCE = 0x0015
XOR_RELAYED_ADDRESS = 0x0016
REQUESTED_TRANSPORT = 0x0019
DONT_FRAGMENT = 0x001A
XOR_MAPPED_ADDRESS = 0x0020
SOFTWARE = 0x8022
ALTERNATE_SERVER = 0x8023
FINGERPRINT = 0x8028
def unpack_uint(bytes_buf):
    """Decode a big-endian unsigned integer from a sequence of byte values.

    Iterating bytes_buf must yield integers (as a bytearray does). An empty
    buffer decodes to 0.
    """
    total = 0
    for octet in bytes_buf:
        # Shifting left by one byte is the same as multiplying by 256.
        total = total * 256 + octet
    return total
def pack_uint(value, width):
    """Encode value as a big-endian bytearray of exactly width bytes.

    Bits of value that do not fit in width bytes are discarded.

    Raises ValueError when value is negative.
    """
    if value < 0:
        raise ValueError("Invalid value: {}".format(value))
    # Shift amounts for the most significant byte down to the least.
    shifts = range(8 * (width - 1), -1, -8)
    return bytearray((value >> shift) & 0xFF for shift in shifts)
def unpack(bytes_buf, format_array):
    """Split bytes_buf into consecutive big-endian unsigned integers.

    format_array lists the width in bytes of each field, in order. Returns a
    tuple holding one integer per field; bytes past the last field are
    ignored.
    """
    fields = []
    offset = 0
    for width in format_array:
        fields.append(unpack_uint(bytes_buf[offset:offset + width]))
        offset += width
    return tuple(fields)
def pack(values, format_array):
    """Serialize values into a single bytearray of big-endian fields.

    values[i] is written using format_array[i] bytes.

    Raises ValueError when values and format_array differ in length.
    """
    if len(values) != len(format_array):
        raise ValueError()
    packed = bytearray()
    for value, width in zip(values, format_array):
        packed.extend(pack_uint(value, width))
    return packed
def bitwise_pack(source, dest, start_bit, num_bits):
    """Append a bit field taken from source to the low end of dest.

    The field is the num_bits bits of source whose most significant bit is
    start_bit (bit 0 is the least significant). dest is shifted left by
    num_bits and the field is placed in the vacated low bits; the combined
    value is returned.

    Raises ValueError when num_bits is not positive or the field would extend
    below bit 0 of source.
    """
    if not 0 < num_bits <= start_bit + 1:
        raise ValueError("Invalid num_bits: {}, start_bit = {}"
                         .format(num_bits, start_bit))
    lowest_bit = start_bit - num_bits + 1
    field = (source >> lowest_bit) & ((1 << num_bits) - 1)
    return (dest << num_bits) + field
class StunAttribute(object):
    """
    Represents a STUN attribute in a raw format, according to the following:

     0                   1                   2                   3
     0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |    StunAttribute.attr_type    |  Length (derived as needed)   |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |         StunAttribute.data (variable length)               ....
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    """

    # Field widths, in bytes, of the attribute header (type, length).
    __attr_header_fmt = [2, 2]
    __attr_header_size = sum(__attr_header_fmt)

    def __init__(self, attr_type=0, buf=None):
        """Create an attribute of attr_type whose payload is buf.

        When buf is omitted the attribute gets its own empty bytearray, so
        attributes never share a payload buffer by accident.
        """
        self.attr_type = attr_type
        self.data = bytearray() if buf is None else buf

    def build(self):
        """Return the wire form: header, payload, and zero padding to a
        multiple of 4 bytes. The length field counts the payload only."""
        buf = pack((self.attr_type, len(self.data)), self.__attr_header_fmt)
        buf.extend(self.data)
        # add padding if necessary
        if len(buf) % 4:
            buf.extend([0]*(4 - (len(buf) % 4)))
        return buf

    def parse(self, buf):
        """Parse one attribute from the start of buf.

        Sets attr_type and data, and returns the number of bytes consumed
        (header + payload + padding).

        Raises Exception when buf ends before the header, payload or padding
        is complete, and ValueError when a padding byte is not zero.
        """
        if self.__attr_header_size > len(buf):
            raise Exception('truncated at attribute: incomplete header')
        self.attr_type, length = unpack(buf, self.__attr_header_fmt)
        length += self.__attr_header_size
        if length > len(buf):
            raise Exception('truncated at attribute: incomplete contents')
        self.data = buf[self.__attr_header_size:length]
        # verify padding; check it is actually present first so a short
        # buffer is reported as truncation rather than as an IndexError
        padded_length = length + (-length % 4)
        if padded_length > len(buf):
            raise Exception('truncated at attribute: incomplete padding')
        if any(buf[length:padded_length]):
            raise ValueError("Non-zero padding")
        return padded_length
class StunMessage(object):
    """
    Represents a STUN message. Contains a method, msg_class, cookie,
    transaction_id, and attributes (as an array of StunAttribute).
    Has various functions for getting/adding attributes.
    """

    def __init__(self):
        self.method = 0
        self.msg_class = 0
        self.cookie = MAGIC_COOKIE
        self.transaction_id = 0
        self.attributes = []

    #  0                   1                   2                   3
    #  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    # |0 0|M M M M M|C|M M M|C|M M M M|         Message Length        |
    # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    # |                         Magic Cookie                          |
    # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    # |                                                               |
    # |                     Transaction ID (96 bits)                  |
    # |                                                               |
    # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    # Field widths, in bytes: message type, length, cookie, transaction id.
    __header_fmt = [2, 2, 4, 12]
    __header_size = sum(__header_fmt)

    # Returns how many bytes were parsed if buf was large enough, or how many
    # bytes we would have needed if not. Throws if buf is malformed.
    def parse(self, buf):
        min_buf_size = self.__header_size
        if len(buf) < min_buf_size:
            return min_buf_size
        message_type, length, cookie, self.transaction_id = unpack(
            buf, self.__header_fmt)
        min_buf_size += length
        if len(buf) < min_buf_size:
            return min_buf_size

        # Avert your eyes...
        # The method (M) and class (C) bits are interleaved in message_type;
        # gather each from the most significant piece to the least.
        self.method = bitwise_pack(message_type, 0, 13, 5)
        self.msg_class = bitwise_pack(message_type, 0, 8, 1)
        self.method = bitwise_pack(message_type, self.method, 7, 3)
        self.msg_class = bitwise_pack(message_type, self.msg_class, 4, 1)
        self.method = bitwise_pack(message_type, self.method, 3, 4)

        if cookie != self.cookie:
            raise Exception('Invalid cookie: {}'.format(cookie))

        # Only the bytes the length field covers belong to this message.
        buf = buf[self.__header_size:min_buf_size]
        while len(buf):
            attr = StunAttribute()
            length = attr.parse(buf)
            buf = buf[length:]
            self.attributes.append(attr)

        return min_buf_size

    # stop_after_attr_type is useful for calculating the MESSAGE-INTEGRITY
    # digest: attributes after the first one of that type are left out, and
    # the length field only counts the attributes that were included.
    def build(self, stop_after_attr_type=0):
        attrs = bytearray()
        for attr in self.attributes:
            attrs.extend(attr.build())
            if attr.attr_type == stop_after_attr_type:
                break

        # Interleave the method and class bits (inverse of parse()).
        message_type = bitwise_pack(self.method, 0, 11, 5)
        message_type = bitwise_pack(self.msg_class, message_type, 1, 1)
        message_type = bitwise_pack(self.method, message_type, 6, 3)
        message_type = bitwise_pack(self.msg_class, message_type, 0, 1)
        message_type = bitwise_pack(self.method, message_type, 3, 4)

        message = pack((message_type,
                        len(attrs),
                        self.cookie,
                        self.transaction_id), self.__header_fmt)
        message.extend(attrs)

        return message

    def add_error_code(self, code, phrase=None):
        """Append an ERROR-CODE attribute for code, with optional phrase."""
        #  0                   1                   2                   3
        #  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
        # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        # |          Reserved, should be 0          |Class|     Number    |
        # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        # |      Reason Phrase (variable)                                ..
        # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        error_code_fmt = [3, 1]
        error_code = pack((code // 100, code % 100), error_code_fmt)
        if phrase is not None:
            error_code.extend(bytearray(phrase))
        self.attributes.append(StunAttribute(ERROR_CODE, error_code))

    #  0                   1                   2                   3
    #  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    # |x x x x x x x x|    Family     |         X-Port                |
    # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    # |                X-Address (Variable)
    # +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    __xor_v4addr_fmt = [1, 1, 2, 4]
    __xor_v6addr_fmt = [1, 1, 2, 16]
    __xor_v4addr_size = sum(__xor_v4addr_fmt)
    __xor_v6addr_size = sum(__xor_v6addr_fmt)

    def get_xaddr(self, ip_addr, version):
        """XOR an integer address with the cookie (IPv4) or with the cookie
        followed by the transaction id (IPv6). The operation is its own
        inverse, so this both obfuscates and recovers addresses.

        Raises ValueError when version is neither IPV4 nor IPV6.
        """
        if version == IPV4:
            return self.cookie ^ ip_addr
        elif version == IPV6:
            return ((self.cookie << 96) + self.transaction_id) ^ ip_addr
        else:
            raise ValueError("Invalid family: {}".format(version))

    def get_xport(self, port):
        """XOR port with the upper 16 bits of the cookie."""
        return (self.cookie >> 16) ^ port

    def add_xor_address(self, addr_port, attr_type):
        """Append an XOR-encoded address attribute of attr_type.

        addr_port must expose .host and .port. Raises ValueError when the
        host is neither an IPv4 nor an IPv6 address.
        """
        ip_address = ipaddr.IPAddress(addr_port.host)
        xport = self.get_xport(addr_port.port)

        if ip_address.version == 4:
            xaddr = self.get_xaddr(int(ip_address), IPV4)
            xor_address = pack((0, IPV4, xport, xaddr), self.__xor_v4addr_fmt)
        elif ip_address.version == 6:
            xaddr = self.get_xaddr(int(ip_address), IPV6)
            xor_address = pack((0, IPV6, xport, xaddr), self.__xor_v6addr_fmt)
        else:
            raise ValueError("Invalid ip version: {}"
                             .format(ip_address.version))

        self.attributes.append(StunAttribute(attr_type, xor_address))

    def add_data(self, buf):
        """Append a DATA attribute carrying buf."""
        self.attributes.append(StunAttribute(DATA_ATTR, buf))

    def find(self, attr_type):
        """Return the first attribute of attr_type, or None if absent."""
        for attr in self.attributes:
            if attr.attr_type == attr_type:
                return attr
        return None

    def get_xor_address(self, attr_type):
        """Decode the first XOR address attribute of attr_type.

        Returns a twisted IPv4Address/IPv6Address, or None when the attribute
        is absent. Raises ValueError on an unknown address family.
        """
        addr_attr = self.find(attr_type)
        if addr_attr is None:
            return None

        # Decode as IPv4 first; the family field sits at the same offset in
        # both layouts, so it tells us whether to decode again as IPv6.
        padding, family, xport, xaddr = unpack(addr_attr.data,
                                               self.__xor_v4addr_fmt)
        addr_ctor = IPv4Address
        if family == IPV6:
            from twisted.internet.address import IPv6Address
            padding, family, xport, xaddr = unpack(addr_attr.data,
                                                   self.__xor_v6addr_fmt)
            addr_ctor = IPv6Address
        elif family != IPV4:
            raise ValueError("Invalid family: {}".format(family))

        return addr_ctor('UDP',
                         str(ipaddr.IPAddress(self.get_xaddr(xaddr, family))),
                         self.get_xport(xport))

    def add_nonce(self, nonce):
        """Append a NONCE attribute."""
        self.attributes.append(StunAttribute(NONCE, bytearray(nonce)))

    def add_realm(self, realm):
        """Append a REALM attribute."""
        self.attributes.append(StunAttribute(REALM, bytearray(realm)))

    def calculate_message_digest(self, username, realm, password):
        """Return the HMAC-SHA1 digest for this message as a bytearray.

        The key is MD5("username:realm:password") with the password run
        through saslprep. The message must already contain a
        MESSAGE-INTEGRITY attribute (its current value is irrelevant).
        """
        digest_buf = self.build(MESSAGE_INTEGRITY)
        # Trim off the MESSAGE-INTEGRITY attr (4-byte header + 20-byte value);
        # the header length field still counts it.
        digest_buf = digest_buf[:len(digest_buf) - 24]
        password = passlib.utils.saslprep(unicode(password))
        key_string = "{}:{}:{}".format(username, realm, password)
        md5 = hashlib.md5()
        md5.update(key_string)
        key = md5.digest()
        return bytearray(hmac.new(key, digest_buf, hashlib.sha1).digest())

    def add_lifetime(self, lifetime):
        """Append a LIFETIME attribute holding lifetime as 4 bytes."""
        self.attributes.append(StunAttribute(LIFETIME, pack_uint(lifetime, 4)))

    def get_lifetime(self):
        """Return the LIFETIME value as an integer, or None if absent."""
        lifetime_attr = self.find(LIFETIME)
        if lifetime_attr is None:
            return None
        return unpack_uint(lifetime_attr.data[0:4])

    def get_username(self):
        """Return the USERNAME attribute as a str, or None if absent."""
        username = self.find(USERNAME)
        if username is None:
            return None
        return str(username.data)

    def add_message_integrity(self, username, realm, password):
        """Append a MESSAGE-INTEGRITY attribute covering the message so far."""
        # Append a placeholder first: the digest is computed over a message
        # whose length field already includes this attribute.
        dummy_value = bytearray([0]*20)
        self.attributes.append(StunAttribute(MESSAGE_INTEGRITY, dummy_value))
        digest = self.calculate_message_digest(username, realm, password)
        self.find(MESSAGE_INTEGRITY).data = digest
class Allocation(protocol.DatagramProtocol):
    """
    Comprises the socket for a TURN allocation, a back-reference to the
    transport we will forward received traffic on, the allocator's address and
    username, the set of permissions for the allocation, and the allocation's
    expiry.
    """

    def __init__(self, other_transport_handler, allocator_address, username):
        self.permissions = set()  # host strings of peers allowed to send
        # Handler to use when sending stuff that arrives on the allocation
        self.other_transport_handler = other_transport_handler
        self.allocator_address = allocator_address
        self.username = username
        # Starts out already expired; the creator sets the real expiry.
        self.expiry = time.time()
        # Port 0: let the OS pick a free port for the relayed address.
        self.port = reactor.listenUDP(0, self, interface=v4_address)

    def datagramReceived(self, data, addr):
        """Wrap a datagram from a permitted peer in a Data indication and
        forward it to the allocator; drop datagrams from other hosts."""
        # Unpacked here instead of in the signature: tuple parameters are
        # not valid syntax on Python 3.
        host, port = addr
        if host not in self.permissions:
            print("Dropping packet from {}:{}, no permission on allocation {}"
                  .format(host, port, self.transport.getHost()))
            return

        data_indication = StunMessage()
        data_indication.method = DATA_MSG
        data_indication.msg_class = INDICATION
        data_indication.transaction_id = random.getrandbits(96)

        # Only handles UDP allocations. Doubtful that we need more than this.
        data_indication.add_xor_address(IPv4Address('UDP', host, port),
                                        XOR_PEER_ADDRESS)
        data_indication.add_data(data)

        self.other_transport_handler.write(data_indication.build(),
                                           self.allocator_address)

    def close(self):
        """Stop listening on the relay port. Safe to call more than once."""
        if self.port is not None:
            self.port.stopListening()
            self.port = None
class StunHandler(object):
    """
    Frames and handles STUN messages. This is the core logic of the TURN
    server, along with Allocation.
    """

    def __init__(self, transport_handler):
        self.client_address = None
        # Bytes received but not yet consumed as complete STUN messages.
        self.data = bytearray()
        self.transport_handler = transport_handler

    def data_received(self, data, address):
        """Buffer data, then handle every complete STUN message it holds,
        writing each non-empty response back to address."""
        self.data += bytearray(data)
        while True:
            stun_message = StunMessage()
            parsed_len = stun_message.parse(self.data)
            # parse() returns the size it needs when the buffer is too short.
            if parsed_len > len(self.data):
                break
            self.data = self.data[parsed_len:]

            response = self.handle_stun(stun_message, address)
            if response:
                self.transport_handler.write(response, address)

    def handle_stun(self, stun_message, address):
        """Dispatch one parsed message. Returns the built response bytes, or
        None when nothing should be sent back."""
        self.client_address = address
        if stun_message.msg_class == INDICATION:
            if stun_message.method == SEND:
                self.handle_send_indication(stun_message)
            else:
                print("Dropping unknown indication method: {}"
                      .format(stun_message.method))
            return None

        if stun_message.msg_class != REQUEST:
            print("Dropping STUN response, method: {}"
                  .format(stun_message.method))
            return None

        if stun_message.method == BINDING:
            return self.make_success_response(stun_message).build()
        elif stun_message.method == ALLOCATE:
            return self.handle_allocation(stun_message).build()
        elif stun_message.method == REFRESH:
            return self.handle_refresh(stun_message).build()
        elif stun_message.method == CREATE_PERMISSION:
            return self.handle_permission(stun_message).build()
        else:
            return self.make_error_response(
                stun_message,
                400,
                ("Unsupported STUN request, method: {}"
                 .format(stun_message.method))).build()

    def get_allocation_tuple(self):
        """Return the key identifying this client's allocation: client host
        and port plus the type, host and port of the server transport."""
        return (self.client_address.host,
                self.client_address.port,
                self.transport_handler.transport.getHost().type,
                self.transport_handler.transport.getHost().host,
                self.transport_handler.transport.getHost().port)

    def handle_allocation(self, request):
        """Handle an Allocate request; returns the response message."""
        allocate_response = self.check_long_term_auth(request)
        if allocate_response.msg_class == SUCCESS_RESPONSE:
            allocation_tuple = self.get_allocation_tuple()
            if allocation_tuple in allocations:
                return self.make_error_response(
                    request,
                    437,
                    ("Duplicate allocation request for tuple {}"
                     .format(allocation_tuple)))

            # Validate the request before creating the Allocation: creating
            # it opens a listening UDP port, which would be leaked if we
            # bailed out afterwards without storing or closing it.
            lifetime = request.get_lifetime()
            if lifetime is None:
                return self.make_error_response(
                    request,
                    400,
                    "Missing lifetime attribute in allocation request")

            allocation = Allocation(self.transport_handler,
                                    self.client_address,
                                    request.get_username())

            allocate_response.add_xor_address(allocation.transport.getHost(),
                                              XOR_RELAYED_ADDRESS)

            lifetime = min(lifetime, 3600)
            allocate_response.add_lifetime(lifetime)
            allocation.expiry = time.time() + lifetime

            allocate_response.add_message_integrity(turn_user,
                                                    turn_realm,
                                                    turn_pass)
            allocations[allocation_tuple] = allocation
        return allocate_response

    def handle_refresh(self, request):
        """Handle a Refresh request; returns the response message."""
        refresh_response = self.check_long_term_auth(request)
        if refresh_response.msg_class == SUCCESS_RESPONSE:
            try:
                allocation = allocations[self.get_allocation_tuple()]
            except KeyError:
                return self.make_error_response(
                    request,
                    437,
                    ("Refresh request for non-existing allocation, tuple {}"
                     .format(self.get_allocation_tuple())))

            if allocation.username != request.get_username():
                return self.make_error_response(
                    request,
                    441,
                    ("Refresh request with wrong user, exp {}, got {}"
                     .format(allocation.username, request.get_username())))

            lifetime = request.get_lifetime()
            if lifetime is None:
                return self.make_error_response(
                    request,
                    400,
                    "Missing lifetime attribute in allocation request")

            lifetime = min(lifetime, 3600)
            refresh_response.add_lifetime(lifetime)
            allocation.expiry = time.time() + lifetime

            refresh_response.add_message_integrity(turn_user,
                                                   turn_realm,
                                                   turn_pass)
        return refresh_response

    def handle_permission(self, request):
        """Handle a CreatePermission request; returns the response message."""
        permission_response = self.check_long_term_auth(request)
        if permission_response.msg_class == SUCCESS_RESPONSE:
            try:
                allocation = allocations[self.get_allocation_tuple()]
            except KeyError:
                return self.make_error_response(
                    request,
                    437,
                    ("No such allocation for permission request, tuple {}"
                     .format(self.get_allocation_tuple())))

            if allocation.username != request.get_username():
                return self.make_error_response(
                    request,
                    441,
                    ("Permission request with wrong user, exp {}, got {}"
                     .format(allocation.username, request.get_username())))

            # TODO: Handle multiple XOR-PEER-ADDRESS
            peer_address = request.get_xor_address(XOR_PEER_ADDRESS)
            if not peer_address:
                return self.make_error_response(
                    request,
                    400,
                    "Missing XOR-PEER-ADDRESS on permission request")

            permission_response.add_message_integrity(turn_user,
                                                      turn_realm,
                                                      turn_pass)
            # Permissions are tracked per peer host; the port is ignored.
            allocation.permissions.add(peer_address.host)

        return permission_response

    def handle_send_indication(self, indication):
        """Relay the DATA of a Send indication to its peer through the
        client's allocation. Invalid indications are logged and dropped."""
        try:
            allocation = allocations[self.get_allocation_tuple()]
        except KeyError:
            print("Dropping send indication; no allocation for tuple {}"
                  .format(self.get_allocation_tuple()))
            return

        peer_address = indication.get_xor_address(XOR_PEER_ADDRESS)
        if not peer_address:
            print("Dropping send indication, missing XOR-PEER-ADDRESS")
            return

        data_attr = indication.find(DATA_ATTR)
        if not data_attr:
            print("Dropping send indication, missing DATA")
            return

        if indication.find(DONT_FRAGMENT):
            print("Dropping send indication, DONT-FRAGMENT set")
            return

        if peer_address.host not in allocation.permissions:
            print("Dropping send indication, no permission for {} on tuple {}"
                  .format(peer_address.host, self.get_allocation_tuple()))
            return

        allocation.transport.write(data_attr.data,
                                   (peer_address.host, peer_address.port))

    def make_success_response(self, request):
        """Return a success response for request: same method and transaction
        id, carrying only an XOR-MAPPED-ADDRESS for the client."""
        response = copy.deepcopy(request)
        response.attributes = []
        response.add_xor_address(self.client_address, XOR_MAPPED_ADDRESS)
        response.msg_class = SUCCESS_RESPONSE
        return response

    def make_error_response(self, request, code, reason=None):
        """Return an error response for request carrying only an ERROR-CODE.
        A non-empty reason is also logged."""
        if reason:
            print("{}: rejecting with {}".format(reason, code))
        response = copy.deepcopy(request)
        response.attributes = []
        response.add_error_code(code, reason)
        response.msg_class = ERROR_RESPONSE
        return response

    def make_challenge_response(self, request, reason=None):
        """Return a 401 response with a random NONCE and the server REALM."""
        response = self.make_error_response(request, 401, reason)
        # 65 means the hex encoding will need padding half the time
        response.add_nonce("{:x}".format(random.getrandbits(65)))
        response.add_realm(turn_realm)
        return response

    def check_long_term_auth(self, request):
        """Authenticate request against the configured credentials.

        Returns a success response when USERNAME and MESSAGE-INTEGRITY match,
        a 401 challenge when MESSAGE-INTEGRITY is absent or wrong, and a 400
        error when USERNAME, REALM or NONCE is missing. REALM and NONCE are
        only checked for presence, not for their values.
        """
        message_integrity = request.find(MESSAGE_INTEGRITY)
        if not message_integrity:
            return self.make_challenge_response(request)

        username = request.find(USERNAME)
        realm = request.find(REALM)
        nonce = request.find(NONCE)
        if not username or not realm or not nonce:
            return self.make_error_response(
                request,
                400,
                "Missing either USERNAME, NONCE, or REALM")

        if str(username.data) != turn_user:
            return self.make_challenge_response(
                request,
                "Wrong user {}, exp {}".format(username.data, turn_user))

        expected_message_digest = request.calculate_message_digest(turn_user,
                                                                   turn_realm,
                                                                   turn_pass)
        if message_integrity.data != expected_message_digest:
            return self.make_challenge_response(request,
                                                "Incorrect message disgest")

        return self.make_success_response(request)
class UdpStunHandler(protocol.DatagramProtocol):
    """
    UDP listen port for TURN. Every datagram is handed to a fresh
    StunHandler, so no framing state is kept between datagrams.
    """

    def datagramReceived(self, data, address):
        sender = IPv4Address('UDP', address[0], address[1])
        StunHandler(self).data_received(data, sender)

    def write(self, data, address):
        destination = (address.host, address.port)
        self.transport.write(str(data), destination)
class TcpStunHandlerFactory(protocol.Factory):
    """
    TCP listen port for TURN; builds one TcpStunHandler per connection.
    """

    def buildProtocol(self, addr):
        handler = TcpStunHandler(addr)
        return handler
class TcpStunHandler(protocol.Protocol):
    """
    Represents a connected TCP port for TURN.
    """

    def __init__(self, addr):
        self.address = addr
        self.stun_handler = None

    def dataReceived(self, data):
        # This needs to persist, since it handles framing
        if self.stun_handler is None:
            self.stun_handler = StunHandler(self)
        self.stun_handler.data_received(data, self.address)

    def connectionLost(self, reason):
        print("Lost connection from {}".format(self.address))
        # Destroy allocations that this connection made.
        # Iterate over a snapshot: entries are deleted inside the loop, and
        # deleting from a dict while iterating its live view is an error on
        # Python 3.
        for key, allocation in list(allocations.items()):
            if allocation.other_transport_handler is self:
                print("Closing allocation due to dropped connection: {}"
                      .format(key))
                del allocations[key]
                allocation.close()

    def write(self, data, address):
        # address is unused: a TCP connection has a single peer.
        self.transport.write(str(data))
def get_default_route(family):
    """Return the local address used to reach a public host over family.

    A datagram socket is connected to a well-known public address and the
    local address chosen for it is read back with getsockname().

    family is socket.AF_INET; any other value is treated as IPv6.
    Propagates the socket error when the connect fails (for example when the
    host has no route for that family); the socket is closed either way.
    """
    if family == socket.AF_INET:
        probe_target = ("8.8.8.8", 53)
    else:
        probe_target = ("2001:4860:4860::8888", 53)

    dummy_socket = socket.socket(family, socket.SOCK_DGRAM)
    try:
        dummy_socket.connect(probe_target)
        return dummy_socket.getsockname()[0]
    finally:
        # Always release the descriptor, including when connect() raises.
        dummy_socket.close()
# Credentials and realm the server accepts and signs its responses with.
turn_user = "foo"
turn_pass = "bar"
turn_realm = "mozilla.invalid"

# Live allocations, keyed by the tuple from StunHandler.get_allocation_tuple().
allocations = {}

v4_address = get_default_route(socket.AF_INET)
try:
    v6_address = get_default_route(socket.AF_INET6)
except Exception:
    # Best effort: fall back to an empty address when no IPv6 route can be
    # determined. Exception (not a bare except) so that KeyboardInterrupt
    # and SystemExit still propagate.
    v6_address = ""
def prune_allocations():
    """Remove and close every allocation whose expiry time has passed."""
    now = time.time()
    # Iterate over a snapshot: entries are deleted inside the loop, and
    # deleting from a dict while iterating its live view is an error on
    # Python 3.
    for key, allocation in list(allocations.items()):
        if allocation.expiry < now:
            print("Allocation expired: {}".format(key))
            del allocations[key]
            allocation.close()
if __name__ == "__main__":
    random.seed()

    # Compare by value: "is" tests object identity, which is not reliable
    # for strings, so the Windows branch could be skipped on Windows.
    if platform.system() == "Windows":
        # Windows is finicky about allowing real interfaces to talk to loopback.
        interface_4 = v4_address
        interface_6 = v6_address
        hostname = socket.gethostname()
    else:
        # Our linux builders do not have a hostname that resolves to the real
        # interface.
        interface_4 = "127.0.0.1"
        interface_6 = "::1"
        hostname = "localhost"

    reactor.listenUDP(3478, UdpStunHandler(), interface=interface_4)
    reactor.listenTCP(3478, TcpStunHandlerFactory(), interface=interface_4)

    try:
        reactor.listenUDP(3478, UdpStunHandler(), interface=interface_6)
        reactor.listenTCP(3478, TcpStunHandlerFactory(), interface=interface_6)
    except Exception:
        # IPv6 is optional: keep serving IPv4 when the v6 listeners cannot
        # be set up. KeyboardInterrupt/SystemExit are not swallowed.
        pass

    # Check for expired allocations once per second.
    allocation_pruner = LoopingCall(prune_allocations)
    allocation_pruner.start(1)

    # Print the ICE server configuration as a JSON array on stdout.
    template = Template(
'[\
{"url":"stun:$hostname"}, \
{"url":"stun:$hostname?transport=tcp"}, \
{"username":"$user","credential":"$pwd","url":"turn:$hostname"}, \
{"username":"$user","credential":"$pwd","url":"turn:$hostname?transport=tcp"}]'
    )

    print(template.substitute(user=turn_user,
                              pwd=turn_pass,
                              hostname=hostname))

    reactor.run()