301 lines
8.8 KiB
Python
301 lines
8.8 KiB
Python
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
import json
|
|
import socket
|
|
import time
|
|
|
|
|
|
class SocketTimeout(object):
|
|
def __init__(self, socket, timeout):
|
|
self.sock = socket
|
|
self.timeout = timeout
|
|
self.old_timeout = None
|
|
|
|
def __enter__(self):
|
|
self.old_timeout = self.sock.gettimeout()
|
|
self.sock.settimeout(self.timeout)
|
|
|
|
def __exit__(self, *args, **kwargs):
|
|
self.sock.settimeout(self.old_timeout)
|
|
|
|
|
|
class Message(object):
|
|
def __init__(self, msgid):
|
|
self.id = msgid
|
|
|
|
def __eq__(self, other):
|
|
return self.id == other.id
|
|
|
|
def __ne__(self, other):
|
|
return not self.__eq__(other)
|
|
|
|
|
|
class Command(Message):
|
|
TYPE = 0
|
|
|
|
def __init__(self, msgid, name, params):
|
|
Message.__init__(self, msgid)
|
|
self.name = name
|
|
self.params = params
|
|
|
|
def __str__(self):
|
|
return "<Command id={0}, name={1}, params={2}>".format(self.id, self.name, self.params)
|
|
|
|
def to_msg(self):
|
|
msg = [Command.TYPE, self.id, self.name, self.params]
|
|
return json.dumps(msg)
|
|
|
|
@staticmethod
|
|
def from_msg(payload):
|
|
data = json.loads(payload)
|
|
assert data[0] == Command.TYPE
|
|
cmd = Command(data[1], data[2], data[3])
|
|
return cmd
|
|
|
|
|
|
class Response(Message):
|
|
TYPE = 1
|
|
|
|
def __init__(self, msgid, error, result):
|
|
Message.__init__(self, msgid)
|
|
self.error = error
|
|
self.result = result
|
|
|
|
def __str__(self):
|
|
return "<Response id={0}, error={1}, result={2}>".format(self.id, self.error, self.result)
|
|
|
|
def to_msg(self):
|
|
msg = [Response.TYPE, self.id, self.error, self.result]
|
|
return json.dumps(msg)
|
|
|
|
@staticmethod
|
|
def from_msg(payload):
|
|
data = json.loads(payload)
|
|
assert data[0] == Response.TYPE
|
|
return Response(data[1], data[2], data[3])
|
|
|
|
|
|
class Proto2Command(Command):
|
|
"""Compatibility shim that marshals messages from a protocol level
|
|
2 and below remote into ``Command`` objects.
|
|
"""
|
|
|
|
def __init__(self, name, params):
|
|
Command.__init__(self, None, name, params)
|
|
|
|
|
|
class Proto2Response(Response):
|
|
"""Compatibility shim that marshals messages from a protocol level
|
|
2 and below remote into ``Response`` objects.
|
|
"""
|
|
|
|
def __init__(self, error, result):
|
|
Response.__init__(self, None, error, result)
|
|
|
|
@staticmethod
|
|
def from_data(data):
|
|
err, res = None, None
|
|
if "error" in data:
|
|
err = data
|
|
else:
|
|
res = data
|
|
return Proto2Response(err, res)
|
|
|
|
|
|
class TcpTransport(object):
|
|
"""Socket client that communciates with Marionette via TCP.
|
|
|
|
It speaks the protocol of the remote debugger in Gecko, in which
|
|
messages are always preceded by the message length and a colon, e.g.:
|
|
|
|
7:MESSAGE
|
|
|
|
On top of this protocol it uses a Marionette message format, that
|
|
depending on the protocol level offered by the remote server, varies.
|
|
Supported protocol levels are 1 and above.
|
|
"""
|
|
max_packet_length = 4096
|
|
|
|
def __init__(self, addr, port, socket_timeout=60.0):
|
|
"""If `socket_timeout` is `0` or `0.0`, non-blocking socket mode
|
|
will be used. Setting it to `1` or `None` disables timeouts on
|
|
socket operations altogether.
|
|
"""
|
|
self.addr = addr
|
|
self.port = port
|
|
self._socket_timeout = socket_timeout
|
|
|
|
self.protocol = 1
|
|
self.application_type = None
|
|
self.last_id = 0
|
|
self.expected_response = None
|
|
self.sock = None
|
|
|
|
@property
|
|
def socket_timeout(self):
|
|
return self._socket_timeout
|
|
|
|
@socket_timeout.setter
|
|
def socket_timeout(self, value):
|
|
if self.sock:
|
|
self.sock.settimeout(value)
|
|
self._socket_timeout = value
|
|
|
|
def _unmarshal(self, packet):
|
|
msg = None
|
|
|
|
# protocol 3 and above
|
|
if self.protocol >= 3:
|
|
typ = int(packet[1])
|
|
if typ == Command.TYPE:
|
|
msg = Command.from_msg(packet)
|
|
elif typ == Response.TYPE:
|
|
msg = Response.from_msg(packet)
|
|
|
|
# protocol 2 and below
|
|
else:
|
|
data = json.loads(packet)
|
|
|
|
msg = Proto2Response.from_data(data)
|
|
|
|
return msg
|
|
|
|
def receive(self, unmarshal=True):
|
|
"""Wait for the next complete response from the remote.
|
|
|
|
:param unmarshal: Default is to deserialise the packet and
|
|
return a ``Message`` type. Setting this to false will return
|
|
the raw packet.
|
|
"""
|
|
now = time.time()
|
|
data = ""
|
|
bytes_to_recv = 10
|
|
|
|
while self.socket_timeout is None or (time.time() - now < self.socket_timeout):
|
|
try:
|
|
chunk = self.sock.recv(bytes_to_recv)
|
|
data += chunk
|
|
except socket.timeout:
|
|
pass
|
|
else:
|
|
if not chunk:
|
|
raise socket.error("No data received over socket")
|
|
|
|
sep = data.find(":")
|
|
if sep > -1:
|
|
length = data[0:sep]
|
|
remaining = data[sep + 1:]
|
|
|
|
if len(remaining) == int(length):
|
|
if unmarshal:
|
|
msg = self._unmarshal(remaining)
|
|
self.last_id = msg.id
|
|
|
|
if self.protocol >= 3:
|
|
self.last_id = msg.id
|
|
|
|
# keep reading incoming responses until
|
|
# we receive the user's expected response
|
|
if isinstance(msg, Response) and msg != self.expected_response:
|
|
return self.receive(unmarshal)
|
|
|
|
return msg
|
|
|
|
else:
|
|
return remaining
|
|
|
|
bytes_to_recv = int(length) - len(remaining)
|
|
|
|
raise socket.timeout("Connection timed out after {}s".format(self.socket_timeout))
|
|
|
|
def connect(self):
|
|
"""Connect to the server and process the hello message we expect
|
|
to receive in response.
|
|
|
|
Returns a tuple of the protocol level and the application type.
|
|
"""
|
|
try:
|
|
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
self.sock.settimeout(self.socket_timeout)
|
|
|
|
self.sock.connect((self.addr, self.port))
|
|
except:
|
|
# Unset self.sock so that the next attempt to send will cause
|
|
# another connection attempt.
|
|
self.sock = None
|
|
raise
|
|
|
|
with SocketTimeout(self.sock, 2.0):
|
|
# first packet is always a JSON Object
|
|
# which we can use to tell which protocol level we are at
|
|
raw = self.receive(unmarshal=False)
|
|
hello = json.loads(raw)
|
|
self.protocol = hello.get("marionetteProtocol", 1)
|
|
self.application_type = hello.get("applicationType")
|
|
|
|
return (self.protocol, self.application_type)
|
|
|
|
def send(self, obj):
|
|
"""Send message to the remote server. Allowed input is a
|
|
``Message`` instance or a JSON serialisable object.
|
|
"""
|
|
if not self.sock:
|
|
self.connect()
|
|
|
|
if isinstance(obj, Message):
|
|
data = obj.to_msg()
|
|
if isinstance(obj, Command):
|
|
self.expected_response = obj
|
|
else:
|
|
data = json.dumps(obj)
|
|
payload = "{0}:{1}".format(len(data), data)
|
|
|
|
totalsent = 0
|
|
while totalsent < len(payload):
|
|
sent = self.sock.send(payload[totalsent:])
|
|
if sent == 0:
|
|
raise IOError("Socket error after sending {0} of {1} bytes"
|
|
.format(totalsent, len(payload)))
|
|
else:
|
|
totalsent += sent
|
|
|
|
def respond(self, obj):
|
|
"""Send a response to a command. This can be an arbitrary JSON
|
|
serialisable object or an ``Exception``.
|
|
"""
|
|
res, err = None, None
|
|
if isinstance(obj, Exception):
|
|
err = obj
|
|
else:
|
|
res = obj
|
|
msg = Response(self.last_id, err, res)
|
|
self.send(msg)
|
|
return self.receive()
|
|
|
|
def request(self, name, params):
|
|
"""Sends a message to the remote server and waits for a response
|
|
to come back.
|
|
"""
|
|
self.last_id = self.last_id + 1
|
|
cmd = Command(self.last_id, name, params)
|
|
self.send(cmd)
|
|
return self.receive()
|
|
|
|
def close(self):
|
|
"""Close the socket."""
|
|
if self.sock:
|
|
try:
|
|
self.sock.shutdown(socket.SHUT_RDWR)
|
|
except IOError as exc:
|
|
# Errno 57 is "socket not connected", which we don't care about here.
|
|
if exc.errno != 57:
|
|
raise
|
|
|
|
self.sock.close()
|
|
self.sock = None
|
|
|
|
def __del__(self):
|
|
self.close()
|