Import of the watch repository from Pebble

This commit is contained in:
Matthieu Jeanson 2024-12-12 16:43:03 -08:00 committed by Katharine Berry
commit 3b92768480
10334 changed files with 2564465 additions and 0 deletions

23
tools/pulse/__init__.py Normal file
View file

@ -0,0 +1,23 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import flash_imaging
from . import pulse_logging
from . import pulse_prompt
from .socket import Connection
from .exceptions import PulseError
Connection.register_extension('flash', flash_imaging.FlashImagingProtocol)
Connection.register_extension('logging', pulse_logging.LoggingProtocol)
Connection.register_extension('prompt', pulse_prompt.PromptProtocol)

56
tools/pulse/exceptions.py Normal file
View file

@ -0,0 +1,56 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class PulseError(Exception):
pass
class ProtocolAlreadyRegistered(PulseError):
pass
class FrameDecodeError(PulseError):
pass
class ReceiveQueueEmpty(PulseError):
pass
class ResponseParseError(PulseError):
pass
class CommandTimedOut(PulseError):
pass
class WriteError(PulseError):
pass
class EraseError(PulseError):
pass
class RegionDoesNotExist(PulseError):
pass
class TTYAutodetectionUnavailable(PulseError):
pass
class InvalidOperation(PulseError):
pass

View file

@ -0,0 +1,301 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import struct
import time
from . import exceptions
from . import socket
class EraseCommand(object):
command_type = 1
command_struct = struct.Struct('<BII')
response_type = 128
response_struct = struct.Struct('<xII?')
Response = collections.namedtuple(
'EraseResponse', 'address length complete')
def __init__(self, address, length):
self.address = address
self.length = length
@property
def packet(self):
return self.command_struct.pack(
self.command_type, self.address, self.length)
def parse_response(self, response):
if ord(response[0]) != self.response_type:
raise exceptions.ResponseParseError(
'Unexpected response: %r' % response)
unpacked = self.Response._make(self.response_struct.unpack(response))
if unpacked.address != self.address or unpacked.length != self.length:
raise exceptions.ResponseParseError(
'Response does not match command: '
'address=%#.08x legnth=%d (expected %#.08x, %d)' % (
unpacked.address, unpacked.length, self.address,
self.length))
return unpacked
class WriteCommand(object):
command_type = 2
command_struct = struct.Struct('<BI')
header_len = command_struct.size
def __init__(self, address, data):
self.address = address
self.data = data
@property
def packet(self):
header = self.command_struct.pack(self.command_type, self.address)
return header + self.data
class WriteResponse(object):
response_type = 129
response_struct = struct.Struct('<xII?')
Response = collections.namedtuple(
'WriteResponse', 'address length complete')
@classmethod
def parse(cls, response):
if ord(response[0]) != cls.response_type:
raise exceptions.ResponseParseError(
'Unexpected response: %r' % response)
return cls.Response._make(cls.response_struct.unpack(response))
class CrcCommand(object):
command_type = 3
command_struct = struct.Struct('<BII')
response_type = 130
response_struct = struct.Struct('<xIII')
Response = collections.namedtuple('CrcResponse', 'address length crc')
def __init__(self, address, length):
self.address = address
self.length = length
@property
def packet(self):
return self.command_struct.pack(self.command_type, self.address,
self.length)
def parse_response(self, response):
if ord(response[0]) != self.response_type:
raise exceptions.ResponseParseError(
'Unexpected response: %r' % response)
unpacked = self.Response._make(self.response_struct.unpack(response))
if unpacked.address != self.address or unpacked.length != self.length:
raise exceptions.ResponseParseError(
'Response does not match command: '
'address=%#.08x legnth=%d (expected %#.08x, %d)' % (
unpacked.address, unpacked.length, self.address,
self.length))
return unpacked
class QueryFlashRegionCommand(object):
command_type = 4
command_struct = struct.Struct('<BB')
REGION_PRF = 1
REGION_SYSTEM_RESOURCES = 2
response_type = 131
response_struct = struct.Struct('<xBII')
Response = collections.namedtuple(
'FlashRegionGeometry', 'region address length')
def __init__(self, region):
self.region = region
@property
def packet(self):
return self.command_struct.pack(self.command_type, self.region)
def parse_response(self, response):
if ord(response[0]) != self.response_type:
raise exceptions.ResponseParseError(
'Unexpected response: %r' % response)
unpacked = self.Response._make(self.response_struct.unpack(response))
if unpacked.address == 0 and unpacked.length == 0:
raise exceptions.RegionDoesNotExist(self.region)
return unpacked
class FinalizeFlashRegionCommand(object):
command_type = 5
command_struct = struct.Struct('<BB')
response_type = 132
response_struct = struct.Struct('<xB')
def __init__(self, region):
self.region = region
@property
def packet(self):
return self.command_struct.pack(self.command_type, self.region)
def parse_response(self, response):
if ord(response[0]) != self.response_type:
raise exceptions.ResponseParseError(
'Unexpected response: %r' % response)
region, = self.response_struct.unpack(response)
if region != self.region:
raise exceptions.ResponseParseError(
'Response does not match command: '
'response is for region %d (expected %d)' % (
region, self.region))
class FlashImagingProtocol(object):
PROTOCOL_NUMBER = 0x02
RESP_BAD_CMD = 192
RESP_INTERNAL_ERROR = 193
REGION_PRF = QueryFlashRegionCommand.REGION_PRF
REGION_SYSTEM_RESOURCES = QueryFlashRegionCommand.REGION_SYSTEM_RESOURCES
def __init__(self, connection):
self.socket = socket.ProtocolSocket(connection,
self.PROTOCOL_NUMBER)
def erase(self, address, length):
cmd = EraseCommand(address, length)
ack_received = False
retries = 0
while retries < 10:
if not ack_received:
self.socket.send(cmd.packet)
try:
packet = self.socket.receive(timeout=5 if ack_received else 1.5)
response = cmd.parse_response(packet)
ack_received = True
if response.complete:
return
except exceptions.ReceiveQueueEmpty:
ack_received = False
retries += 1
continue
raise exceptions.CommandTimedOut
def write(self, address, data, max_retries=5, max_in_flight=5,
progress_cb=None):
mtu = self.socket.mtu - WriteCommand.header_len
assert(mtu > 0)
unsent = collections.OrderedDict()
for offset in xrange(0, len(data), mtu):
segment = data[offset:offset+mtu]
assert(len(segment))
seg_address = address + offset
unsent[seg_address] = WriteCommand(seg_address, segment)
in_flight = collections.OrderedDict()
retries = 0
while unsent or in_flight:
try:
while True:
# Process ACKs (if any)
ack = WriteResponse.parse(
self.socket.receive(block=False))
try:
cmd, _, _ = in_flight[ack.address]
except KeyError:
raise exceptions.WriteError(
'Received ACK for an unknown segment: '
'%#.08x' % ack.address)
if len(cmd.data) != ack.length:
raise exceptions.WriteError(
'ACK length %d != data length %d' % (
ack.length, len(cmd.data)))
assert(ack.complete)
del in_flight[ack.address]
if progress_cb:
progress_cb(True)
except exceptions.ReceiveQueueEmpty:
pass
# Retry any in_flight writes where the ACK has timed out
to_retry = []
timeout_time = time.time() - 0.5
for seg_address, (_, send_time, _) in in_flight.iteritems():
if send_time > timeout_time:
# in_flight is an OrderedDict so iteration is in
# chronological order.
break
to_retry.append(seg_address)
retries += len(to_retry)
for seg_address in to_retry:
cmd, send_time, retry_count = in_flight[seg_address]
del in_flight[seg_address]
if retry_count >= max_retries:
raise exceptions.WriteError(
'Segment %#.08x exceeded the max retry count (%d)' % (
seg_address, max_retries))
retry_count += 1
self.socket.send(cmd.packet)
in_flight[seg_address] = (cmd, time.time(), retry_count)
if progress_cb:
progress_cb(False)
# Send out fresh segments
try:
while len(in_flight) < max_in_flight:
seg_address, cmd = unsent.popitem(last=False)
self.socket.send(cmd.packet)
in_flight[cmd.address] = (cmd, time.time(), 0)
except KeyError:
pass
# Give other threads a chance to run
time.sleep(0)
return retries
def _command_and_response(self, cmd, timeout=0.5):
for attempt in xrange(5):
self.socket.send(cmd.packet)
try:
packet = self.socket.receive(timeout=timeout)
return cmd.parse_response(packet)
except exceptions.ReceiveQueueEmpty:
pass
raise exceptions.CommandTimedOut
def crc(self, address, length):
cmd = CrcCommand(address, length)
return self._command_and_response(cmd, timeout=1).crc
def query_region_geometry(self, region):
cmd = QueryFlashRegionCommand(region)
return self._command_and_response(cmd)
def finalize_region(self, region):
cmd = FinalizeFlashRegionCommand(region)
return self._command_and_response(cmd)

View file

@ -0,0 +1,95 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import json
import re
import struct
import threading
import time
from datetime import datetime
from . import socket
class LogMessage(collections.namedtuple('LogMessage',
'log_level task timestamp file_name line_number message')):
__slots__ = ()
response_struct = struct.Struct('<c16sccQH')
def __str__(self):
msec_timestamp = self.timestamp.strftime("%H:%M:%S.%f")[:-3]
template = ('{self.log_level} {self.task} {msec_timestamp} '
'{self.file_name}:{self.line_number}> {self.message}')
return template.format(self=self, msec_timestamp=msec_timestamp)
@classmethod
def parse(cls, packet):
result = cls.response_struct.unpack(packet[:cls.response_struct.size])
msg = packet[cls.response_struct.size:]
log_level = result[2]
task = result[3]
timestamp = datetime.fromtimestamp(result[4] / 1000.0)
file_name = result[1].split('\x00', 1)[0] # NUL terminated
line_number = result[5]
return cls(log_level, task, timestamp, file_name, line_number, msg)
class LoggingProtocol(object):
PROTOCOL_NUMBER = 0x03
def __init__(self, connection):
self.socket = socket.ProtocolSocket(connection, self.PROTOCOL_NUMBER)
def receive(self, block=True, timeout=None):
return LogMessage.parse(self.socket.receive(block, timeout))
if __name__ == '__main__':
import readline
import sys
from log_hashing import log_dehash
if len(sys.argv) != 2:
print 'Usage: python ' + sys.argv[0] + ' <loghash_dict_path>'
sys.exit(1)
loghash_dict_path = sys.argv[1]
json_dict = json.load(open(loghash_dict_path, 'rb'))
log_hash_dict = {int(key): value
for (key, value) in json_dict.iteritems() if key.isdigit()}
def dehash(msg):
return log_dehash.dehash_logstring(msg, log_hash_dict)
def start_logging(logger):
while True:
msg = logger.receive()
print dehash(str(msg))
with socket.Connection.open_dbgserial('ftdi://ftdi:4232:1/3') as connection:
logging_thread = threading.Thread(target=start_logging, args=[connection.logging])
logging_thread.daemon = True
logging_thread.start()
inputCommand = raw_input('>')
while inputCommand:
for message in connection.prompt.command_and_response(inputCommand):
print message
inputCommand = raw_input('>')

128
tools/pulse/pulse_prompt.py Normal file
View file

@ -0,0 +1,128 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import struct
from datetime import datetime
from . import exceptions
from . import socket
class PromptProtocol(object):
PROTOCOL_NUMBER = 0x04
def __init__(self, connection):
self.socket = socket.ProtocolSocket(connection, self.PROTOCOL_NUMBER)
def command_and_response(self, command_string, timeout=1):
log = []
cmd = PromptCommand(command_string)
for _ in xrange(5):
self.socket.send(cmd.packet)
try:
response = PromptResponse.parse(self.socket.receive(timeout=timeout))
except exceptions.ReceiveQueueEmpty:
continue
# Retry the command if we don't get an ACK
if not response.is_ack_response():
continue
# Receive messages until DONE
is_done = False
retries = 0
while not is_done and retries < 3:
try:
response = PromptResponse.parse(self.socket.receive(timeout=timeout))
if response.is_done_response():
is_done = True
elif response.is_message_response():
log.append(response.message)
retries = 0
except exceptions.ReceiveQueueEmpty:
self.socket.send(cmd.packet)
retries += 1
if retries == 3:
raise exceptions.CommandTimedOut('Lost connection while waiting')
return log
raise exceptions.CommandTimedOut('Command not acknowledged')
class PromptResponse(collections.namedtuple('PromptResponse',
'response_type timestamp message')):
ACK_RESPONSE = 100
DONE_RESPONSE = 101
MESSAGE_RESPONSE = 102
response_struct = struct.Struct('<BQ')
def is_ack_response(self):
return self.response_type == self.ACK_RESPONSE
def is_done_response(self):
return self.response_type == self.DONE_RESPONSE
def is_message_response(self):
return self.response_type == self.MESSAGE_RESPONSE
@classmethod
def parse(cls, response):
result = cls.response_struct.unpack(response[:cls.response_struct.size])
response_type = result[0]
timestamp = datetime.fromtimestamp(result[1] / 1000.0)
message = response[cls.response_struct.size:]
return cls(response_type, timestamp, message)
class PromptCommand(object):
_cookie = 0
def __init__(self, body):
self.body = body
self.cookie = self._get_cookie()
@property
def packet(self):
return chr(self.cookie) + str(self.body)
@classmethod
def _get_cookie(cls):
cookie = cls._cookie
cls._cookie = (cls._cookie + 1) % 256
return cookie
if __name__ == '__main__':
import readline
with socket.Connection.open_dbgserial('ftdi://ftdi:4232:1/3') as connection:
inputCommand = raw_input('>')
while inputCommand:
for message in connection.prompt.command_and_response(inputCommand):
print message
inputCommand = raw_input('>')

444
tools/pulse/socket.py Normal file
View file

@ -0,0 +1,444 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import Queue
import struct
import sys
import threading
import time
import traceback
import uuid
import weakref
from cobs import cobs
import serial
from . import exceptions
import stm32_crc
logger = logging.getLogger(__name__)
try:
import pyftdi.serialext
except ImportError:
pass
DBGSERIAL_PORT_SETTINGS = dict(baudrate=230400, timeout=0.1,
interCharTimeout=0.01)
def get_dbgserial_tty():
# Local import so that we only depend on this package if we're attempting
# to autodetect the TTY. This package isn't always available (e.g., MFG),
# so we don't want it to be required.
try:
import pebble_tty
return pebble_tty.find_dbgserial_tty()
except ImportError:
raise exceptions.TTYAutodetectionUnavailable
def frame_splitter(istream, size=1024, timeout=1, delimiter='\0'):
'''Returns an iterator which yields complete frames.'''
partial = []
start_time = time.time()
while not istream.closed:
data = istream.read(size)
logger.debug('frame_splitter: received %r', data)
while True:
left, delim, data = data.partition(delimiter)
if left:
partial.append(left)
if delim:
frame = ''.join(partial)
partial = []
if frame:
yield frame
if not data:
break
if timeout > 0 and time.time() > start_time + timeout:
yield
def decode_frame(frame):
'''Decodes a PULSE frame.
Returns a tuple (protocol, payload) of the decoded frame.
Raises FrameDecodeError if the frame is not valid.
'''
try:
data = cobs.decode(frame)
except cobs.DecodeError, e:
raise exceptions.FrameDecodeError(e.message)
if len(data) < 5:
raise exceptions.FrameDecodeError('frame too short')
fcs = struct.unpack('<I', data[-4:])[0]
crc = stm32_crc.crc32(data[:-4])
if fcs != crc:
raise exceptions.FrameDecodeError('FCS 0x%.08x != CRC 0x%.08x' % (fcs, crc))
protocol = ord(data[0])
return (protocol, data[1:-4])
def encode_frame(protocol, payload):
frame = struct.pack('<B', protocol)
frame += payload
fcs = stm32_crc.crc32(frame)
frame += struct.pack('<I', fcs)
return cobs.encode(frame)
class Connection(object):
'''A socket for sending and receiving datagrams over the PULSE serial
protocol.
'''
PROTOCOL_LLC = 0x01
LLC_LINK_OPEN_REQUEST = '\x01\x03\x08\x08\x08PULSEv1\r\n'
LLC_LINK_CLOSE_REQUEST = '\x03'
LLC_ECHO_REQUEST = '\x05'
LLC_CHANGE_BAUD = '\x07'
LLC_LINK_OPENED = 0x02
LLC_LINK_CLOSED = 0x04
LLC_ECHO_REPLY = 0x06
EXTENSIONS = {}
# Maximum round-trip time
rtt = 0.4
def __init__(self, iostream, infinite_reconnect=False):
self.infinite_reconnect = infinite_reconnect
self.iostream = iostream
self.closed = False
try:
self.initial_port_settings = self.iostream.getSettingsDict()
except AttributeError:
self.initial_port_settings = None
self.port_settings_altered = False
# Whether the link is open for sending.
self._link_open = threading.Event()
# Whether the link has been severed.
self._link_closed = threading.Event()
self.send_lock = threading.RLock()
self.echoes_inflight = weakref.WeakValueDictionary()
self.protocol_handlers = weakref.WeakValueDictionary()
self.receive_thread = threading.Thread(target=self.run_receive_thread)
self.receive_thread.daemon = True
self.receive_thread.start()
self._open_link()
self.keepalive_thread = threading.Thread(
target=self.run_keepalive_thread)
self.keepalive_thread.daemon = True
self.keepalive_thread.start()
# Instantiate and bind all known extensions
for name, factory in self.EXTENSIONS.iteritems():
setattr(self, name, factory(self))
@classmethod
def register_extension(cls, name, factory):
'''Register a PULSE connection extension.
When a Connection object is instantiated, the object returned by
factory(connection_object) is assigned to connection_object.<name>.
'''
try:
getattr(cls, name)
except AttributeError:
cls.EXTENSIONS[name] = factory
else:
raise ValueError('extension name %r clashes with existing attribute'
% (name,))
@classmethod
def open_dbgserial(cls, url=None, infinite_reconnect=False):
if url is None:
url = get_dbgserial_tty()
if url == "qemu":
url = 'socket://localhost:12345'
ser = serial.serial_for_url(url, **DBGSERIAL_PORT_SETTINGS)
if url.startswith('socket://'):
# Socket class for PySerial does some pointless buffering
# setting a very small timeout effectively negates it
ser._timeout = 0.00001
return cls(ser, infinite_reconnect=infinite_reconnect)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def __del__(self):
self.close()
def send(self, protocol, payload):
if self.closed:
raise exceptions.PulseError('I/O operation on closed connection')
frame = ''.join(('\0', encode_frame(protocol, payload), '\0'))
logger.debug('Connection: sending %r', frame)
with self.send_lock:
self.iostream.write(frame)
def run_receive_thread(self):
logger.debug('Connection: receive thread started')
receiver = frame_splitter(self.iostream, timeout=0)
while True:
try:
protocol, payload = decode_frame(next(receiver))
except exceptions.FrameDecodeError:
continue
except:
# Probably a PySerial exception complaining about reading from a
# closed port. Eat the exception and shut down the thread; users
# don't need to see the stack trace.
logger.debug('Connection: exception in receive thread:\n%s',
traceback.format_exc())
break
logger.debug('Connection:run_receive_thread: '
'protocol=%d payload=%r', protocol, payload)
if protocol == self.PROTOCOL_LLC: # LLC can't be overridden
self.llc_handler(payload)
continue
try:
handler = self.protocol_handlers[protocol]
except KeyError:
self.default_receiver(protocol, payload)
else:
handler.on_receive(payload)
logger.debug('Connection: receive thread exiting')
def default_receiver(self, protocol, frame):
logger.info('Connection:default_receiver: received frame '
'with protocol %d: %r', protocol, frame)
def register_protocol_handler(self, protocol, handler):
'''Register a handler for frames bearing the specified protocol number.
handler.on_receive(payload) is called for each frame received with the
protocol number.
Protocol handlers can be unregistered by calling this function with a
handler of None.
'''
if not handler:
try:
del self.protocol_handlers[protocol]
except KeyError:
pass
return
if protocol in self.protocol_handlers:
raise exceptions.ProtocolAlreadyRegistered(
'Protocol %d is already registered by %r' % (
protocol, self.protocol_handlers[protocol]))
if not hasattr(handler, 'on_receive'):
raise ValueError('%r does not have an on_receive method')
self.protocol_handlers[protocol] = handler
def llc_handler(self, frame):
opcode = ord(frame[0])
if opcode == self.LLC_LINK_OPENED:
# MTU and MRU are from the perspective of this side of the
# connection
version, mru, mtu, timeout = struct.unpack('<xBHHB', frame)
self.version = version
# The server reports the MTU inclusive of protocol number and FCS,
# but we only care about the maximum payload length.
self.mtu = mtu - 5
self.mru = mru
# Timeout is specified in deciseconds. Convert to seconds.
self.timeout = timeout / 10.0
self._link_closed.clear()
self._link_open.set()
elif opcode == self.LLC_LINK_CLOSED:
logger.info('PULSE connection closed.')
self._link_closed.set()
elif opcode == self.LLC_ECHO_REPLY:
self._on_echo_reply(frame[1:])
else:
logger.warning('Received LLC frame with unknown type %d: %r',
opcode, frame)
def run_keepalive_thread(self):
'''The keepalive thread monitors the link, reopening it if necessary.
'''
logger.debug('Connection: keepalive thread started')
OPEN, TEST_LIVENESS, RECONNECT = range(3)
state = OPEN
next_state = state
ping_attempts = 0
ping_wait = self.rtt
while True:
# Check whether the link is being closed from our side before
# trying to keep it alive.
if not self._link_open.is_set():
return
if state == OPEN:
time.sleep(1)
if self._link_closed.is_set():
next_state = RECONNECT
else:
next_state = TEST_LIVENESS
elif state == TEST_LIVENESS:
if ping_attempts < 3:
ping_attempts += 1
ping_wait *= 2 # Exponential backoff
if self.ping(ping_wait):
next_state = OPEN
else:
logger.info('No response to keepalive ping -- '
'strike %d', ping_attempts)
else:
logger.info('Connection: keepalive timed out.')
next_state = RECONNECT
elif state == RECONNECT:
# Lock out everyone from sending so that applications don't send
# to a connection that's in an indeterminate state.
with self.send_lock:
if self.port_settings_altered:
# Ensure that the server has timed out and reset its
# baud rate so we don't get into the bad situation where
# we try to reconnect at the default baud rate but the
# server is listening at a different rate, which is
# practically guaranteed to fail.
logger.info('Letting connection time out before '
'attempting to reconnect.')
time.sleep(self.timeout + self.rtt)
self._link_open.clear()
while not self._link_open.is_set():
try:
self._open_link()
except exceptions.PulseError as e:
logger.warning('Connection: reconnect failed. %s', e)
if not self.infinite_reconnect:
break
logger.warning('Will try again.')
logger.info('Backing off for a while before retrying.')
time.sleep(self.timeout + self.rtt)
else:
next_state = OPEN
else:
assert False, 'Invalid state %d' % state
if next_state != state:
if next_state == TEST_LIVENESS:
ping_attempts = 0
ping_wait = self.rtt
state = next_state
def _open_link(self):
self.closed = False
if self.initial_port_settings:
self.iostream.applySettingsDict(self.initial_port_settings)
for attempt in xrange(5):
logger.info('Opening link (attempt %d)...', attempt)
self.send(self.PROTOCOL_LLC, self.LLC_LINK_OPEN_REQUEST)
if self._link_open.wait(self.rtt):
logger.info('Established PULSE connection!')
logger.info('Version=%d MTU=%d MRU=%d Timeout=%.1f',
self.version, self.mtu, self.mru, self.timeout)
break
else:
self._link_closed.set()
self.closed = True
raise exceptions.PulseError('Could not establish connection')
def close(self):
self._link_open.clear()
if not self._link_closed.is_set():
for attempt in xrange(3):
self.send(self.PROTOCOL_LLC, self.LLC_LINK_CLOSE_REQUEST)
if self._link_closed.wait(self.rtt):
break
else:
logger.warning('Could not confirm link close.')
self._link_closed.set()
self.iostream.close()
self.closed = True
def ping(self, timeout=None):
if not timeout:
timeout = 2 * self.rtt
nonce = uuid.uuid4().bytes
is_received = threading.Event()
self.echoes_inflight[nonce] = is_received
self.send(self.PROTOCOL_LLC, self.LLC_ECHO_REQUEST + nonce)
return is_received.wait(timeout)
def _on_echo_reply(self, payload):
try:
receive_event = self.echoes_inflight[payload]
receive_event.set()
except KeyError:
pass
def change_baud_rate(self, new_baud):
# Fail fast if the IO object doesn't support changing the baud rate
old_baud = self.iostream.baudrate
self.send(self.PROTOCOL_LLC,
self.LLC_CHANGE_BAUD + struct.pack('<I', new_baud))
# Be extra sure that the message has been sent and it's safe to adjust
# the baud rate on the port.
time.sleep(0.1)
self.iostream.baudrate = new_baud
self.port_settings_altered = True
class ProtocolSocket(object):
'''A socket for sending and receiving datagrams of a single protocol over a
PULSE connection.
It is also an example of a Connection protocol handler implementation.
'''
def __init__(self, connection, protocol):
self.connection = connection
self.protocol = protocol
self.receive_queue = Queue.Queue()
self.connection.register_protocol_handler(protocol, self)
def on_receive(self, frame):
self.receive_queue.put(frame)
def receive(self, block=True, timeout=None):
try:
return self.receive_queue.get(block, timeout)
except Queue.Empty:
raise exceptions.ReceiveQueueEmpty
def send(self, frame):
self.connection.send(self.protocol, frame)
@property
def mtu(self):
return self.connection.mtu
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
with Connection.open_dbgserial(sys.argv[1]) as sock:
sock.change_baud_rate(921600)
for _ in xrange(20):
time.sleep(0.5)
send_time = time.time()
if sock.ping():
print "Ping rtt=%.2f ms" % ((time.time() - send_time) * 1000)
else:
print "No echo"