Import of the watch repository from Pebble

Matthieu Jeanson 2024-12-12 16:43:03 -08:00 committed by Katharine Berry
commit 3b92768480
10334 changed files with 2564465 additions and 0 deletions

python_libs/pblprog/.gitignore vendored Normal file

@ -0,0 +1,93 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject
.waf*
.lock*
*.swp


@ -0,0 +1,16 @@
# pblprog
Utility for flashing Pebble BigBoards over SWD
Installation
------------
Install from PyPI (https://pebbletechnology.atlassian.net/wiki/display/DEV/pypi) under the package name `pebble.programmer`.
Supported devices
-----------------
It is relatively easy to add support for any STM32/SWD-based Pebble BigBoard. Currently supported are:
- silk_bb
- robert_bb2
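
Usage
-----
Flash one or more hex files with the `pblprog` console script (a minimal example; the firmware path below is hypothetical):

    pblprog --board silk_bb build/firmware.hex

The same flow is available from Python:

    from pebble.programmer import flash

    flash('silk_bb', ['build/firmware.hex'])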


@ -0,0 +1,116 @@
IRQ_DEF(0, WWDG) // Window WatchDog
IRQ_DEF(1, PVD) // PVD through EXTI Line detection
IRQ_DEF(2, TAMP_STAMP) // Tamper and TimeStamps through the EXTI line
IRQ_DEF(3, RTC_WKUP) // RTC Wakeup through the EXTI line
IRQ_DEF(4, FLASH) // FLASH
IRQ_DEF(5, RCC) // RCC
IRQ_DEF(6, EXTI0) // EXTI Line0
IRQ_DEF(7, EXTI1) // EXTI Line1
IRQ_DEF(8, EXTI2) // EXTI Line2
IRQ_DEF(9, EXTI3) // EXTI Line3
IRQ_DEF(10, EXTI4) // EXTI Line4
IRQ_DEF(11, DMA1_Stream0) // DMA1 Stream 0
IRQ_DEF(12, DMA1_Stream1) // DMA1 Stream 1
IRQ_DEF(13, DMA1_Stream2) // DMA1 Stream 2
IRQ_DEF(14, DMA1_Stream3) // DMA1 Stream 3
IRQ_DEF(15, DMA1_Stream4) // DMA1 Stream 4
IRQ_DEF(16, DMA1_Stream5) // DMA1 Stream 5
IRQ_DEF(17, DMA1_Stream6) // DMA1 Stream 6
IRQ_DEF(18, ADC) // ADC1, ADC2 and ADC3s
IRQ_DEF(19, CAN1_TX) // CAN1 TX
IRQ_DEF(20, CAN1_RX0) // CAN1 RX0
IRQ_DEF(21, CAN1_RX1) // CAN1 RX1
IRQ_DEF(22, CAN1_SCE) // CAN1 SCE
IRQ_DEF(23, EXTI9_5) // External Line[9:5]s
IRQ_DEF(24, TIM1_BRK_TIM9) // TIM1 Break and TIM9
IRQ_DEF(25, TIM1_UP_TIM10) // TIM1 Update and TIM10
IRQ_DEF(26, TIM1_TRG_COM_TIM11) // TIM1 Trigger and Commutation and TIM11
IRQ_DEF(27, TIM1_CC) // TIM1 Capture Compare
IRQ_DEF(28, TIM2) // TIM2
IRQ_DEF(29, TIM3) // TIM3
IRQ_DEF(30, TIM4) // TIM4
IRQ_DEF(31, I2C1_EV) // I2C1 Event
IRQ_DEF(32, I2C1_ER) // I2C1 Error
IRQ_DEF(33, I2C2_EV) // I2C2 Event
IRQ_DEF(34, I2C2_ER) // I2C2 Error
IRQ_DEF(35, SPI1) // SPI1
IRQ_DEF(36, SPI2) // SPI2
IRQ_DEF(37, USART1) // USART1
IRQ_DEF(38, USART2) // USART2
IRQ_DEF(39, USART3) // USART3
IRQ_DEF(40, EXTI15_10) // External Line[15:10]s
IRQ_DEF(41, RTC_Alarm) // RTC Alarm (A and B) through EXTI Line
IRQ_DEF(42, OTG_FS_WKUP) // USB OTG FS Wakeup through EXTI line
IRQ_DEF(43, TIM8_BRK_TIM12) // TIM8 Break and TIM12
IRQ_DEF(44, TIM8_UP_TIM13) // TIM8 Update and TIM13
IRQ_DEF(45, TIM8_TRG_COM_TIM14) // TIM8 Trigger and Commutation and TIM14
IRQ_DEF(46, TIM8_CC) // TIM8 Capture Compare
IRQ_DEF(47, DMA1_Stream7) // DMA1 Stream7
IRQ_DEF(48, FSMC) // FSMC
IRQ_DEF(49, SDIO) // SDIO
IRQ_DEF(50, TIM5) // TIM5
IRQ_DEF(51, SPI3) // SPI3
#if !defined(STM32F412xG)
IRQ_DEF(52, UART4) // UART4
IRQ_DEF(53, UART5) // UART5
IRQ_DEF(54, TIM6_DAC) // TIM6 and DAC1&2 underrun errors
#else
IRQ_DEF(54, TIM6) // TIM6
#endif
IRQ_DEF(55, TIM7) // TIM7
IRQ_DEF(56, DMA2_Stream0) // DMA2 Stream 0
IRQ_DEF(57, DMA2_Stream1) // DMA2 Stream 1
IRQ_DEF(58, DMA2_Stream2) // DMA2 Stream 2
IRQ_DEF(59, DMA2_Stream3) // DMA2 Stream 3
IRQ_DEF(60, DMA2_Stream4) // DMA2 Stream 4
#if !defined(STM32F412xG)
IRQ_DEF(61, ETH) // Ethernet
IRQ_DEF(62, ETH_WKUP) // Ethernet Wakeup through EXTI line
#else
IRQ_DEF(61, DFSDM1) // DFSDM1
IRQ_DEF(62, DFSDM2) // DFSDM2
#endif
IRQ_DEF(63, CAN2_TX) // CAN2 TX
IRQ_DEF(64, CAN2_RX0) // CAN2 RX0
IRQ_DEF(65, CAN2_RX1) // CAN2 RX1
IRQ_DEF(66, CAN2_SCE) // CAN2 SCE
IRQ_DEF(67, OTG_FS) // USB OTG FS
IRQ_DEF(68, DMA2_Stream5) // DMA2 Stream 5
IRQ_DEF(69, DMA2_Stream6) // DMA2 Stream 6
IRQ_DEF(70, DMA2_Stream7) // DMA2 Stream 7
IRQ_DEF(71, USART6) // USART6
IRQ_DEF(72, I2C3_EV) // I2C3 event
IRQ_DEF(73, I2C3_ER) // I2C3 error
#if !defined(STM32F412xG)
IRQ_DEF(74, OTG_HS_EP1_OUT) // USB OTG HS End Point 1 Out
IRQ_DEF(75, OTG_HS_EP1_IN) // USB OTG HS End Point 1 In
IRQ_DEF(76, OTG_HS_WKUP) // USB OTG HS Wakeup through EXTI
IRQ_DEF(77, OTG_HS) // USB OTG HS
IRQ_DEF(78, DCMI) // DCMI
IRQ_DEF(79, CRYP) // CRYP crypto
#endif
#if !defined(STM32F412xG)
IRQ_DEF(80, HASH_RNG) // Hash and Rng
#else
IRQ_DEF(80, RNG) // Rng
#endif
#if !defined(STM32F2XX) // STM32F2 IRQs end here
IRQ_DEF(81, FPU) // FPU
#if !defined(STM32F412xG)
IRQ_DEF(82, UART7) // UART7
IRQ_DEF(83, UART8) // UART8
#endif
IRQ_DEF(84, SPI4) // SPI4
IRQ_DEF(85, SPI5) // SPI5
#if !defined(STM32F412xG)
IRQ_DEF(86, SPI6) // SPI6
IRQ_DEF(87, SAI1) // SAI1
IRQ_DEF(88, LTDC) // LTDC
IRQ_DEF(89, LTDC_ER) // LTDC_ER
IRQ_DEF(90, DMA2D) // DMA2D
#else
IRQ_DEF(92, QUADSPI) // QUADSPI
IRQ_DEF(95, FMPI2C1_EV) // FMPI2C1 Event
IRQ_DEF(96, FMPI2C1_ER) // FMPI2C1 Error
#endif
#endif // !defined(STM32F2XX)


@ -0,0 +1,112 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdint.h>
#define HEADER_ADDR (0x20000400)
#define DATA_ADDR (0x20000800)
#define FLASH_SR_ADDR (0x40023C0C)
#define STATE_WAITING (0)
#define STATE_WRITE (1)
#define STATE_CRC (2)
// typedef to make it easy to change the program size
typedef uint8_t p_size_t;
typedef struct __attribute__((__packed__)) {
uint32_t state;
volatile p_size_t *addr;
uint32_t length;
} Header;
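// Host <-> loader protocol (as implemented here and in the host-side programmer):
// the host copies this binary into SRAM and starts it, writes the destination
// flash address and chunk length into the Header at HEADER_ADDR, stages the
// chunk at DATA_ADDR, and sets `state` to STATE_WRITE to program it (or to
// STATE_CRC to checksum `length` bytes at `addr`, leaving the CRC-8 at
// DATA_ADDR). The loader sets `state` back to STATE_WAITING when it is done.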
static uint8_t prv_crc8(const uint8_t *data, uint32_t data_len) {
uint8_t crc = 0;
// nibble lookup table for (x^8 + x^5 + x^3 + x^2 + x + 1)
static const uint8_t lookup_table[] =
{ 0, 47, 94, 113, 188, 147, 226, 205, 87, 120, 9, 38, 235, 196, 181, 154 };
for (uint32_t i = 0; i < data_len * 2; i++) {
uint8_t nibble = data[i / 2];
if (i % 2 == 0) {
nibble >>= 4;
}
uint8_t index = nibble ^ (crc >> 4);
crc = lookup_table[index & 0xf] ^ ((crc << 4) & 0xf0);
}
return crc;
}
static void prv_wait_for_flash_not_busy(void) {
while ((*(volatile uint32_t *)FLASH_SR_ADDR) & (1 << 16)); // BSY flag in FLASH_SR
}
__attribute__((__noreturn__)) void Reset_Handler(void) {
// Disable all interrupts
__asm__("cpsid i" : : : "memory");
volatile uint32_t *flash_sr = (volatile uint32_t *)FLASH_SR_ADDR;
volatile p_size_t *data = (volatile p_size_t *)DATA_ADDR;
volatile Header *header = (volatile Header *)HEADER_ADDR;
header->state = STATE_WAITING;
while(1) {
switch (header->state) {
case STATE_WRITE:
prv_wait_for_flash_not_busy();
for (uint32_t i = 0; i < header->length / sizeof(p_size_t); i++) {
header->addr[i] = data[i];
__asm__("isb 0xF":::"memory");
__asm__("dsb 0xF":::"memory");
/// Wait until flash isn't busy
prv_wait_for_flash_not_busy();
if (*flash_sr & (0x1f << 4)) {
// error raised, set bad state
header->state = *flash_sr;
}
if (header->addr[i] != data[i]) {
header->state = 0xbd;
}
}
header->addr += header->length / sizeof(p_size_t);
header->state = STATE_WAITING;
break;
case STATE_CRC:
*data = prv_crc8((uint8_t *)header->addr, header->length);
header->state = STATE_WAITING;
break;
default:
break;
}
}
__builtin_unreachable();
}
//! These symbols are defined in the linker script for use in initializing
//! the data sections. uint8_t since we do arithmetic with section lengths.
//! These are arrays to avoid the need for an & when dealing with linker symbols.
extern uint8_t _estack[];
__attribute__((__section__(".isr_vector"))) const void * const vector_table[] = {
_estack,
Reset_Handler
};


@ -0,0 +1,136 @@
__Stack_Size = 128;
PROVIDE ( _Stack_Size = __Stack_Size ) ;
__Stack_Init = _estack - __Stack_Size ;
PROVIDE ( _Stack_Init = __Stack_Init ) ;
MEMORY
{
RAM (rwx) : ORIGIN = 0x20000000, LENGTH = 1K
}
SECTIONS
{
    /* for Cortex devices, the beginning of the startup code is stored in the .isr_vector section; this loader runs entirely from RAM, so it is placed in RAM */
.isr_vector :
{
. = ALIGN(4);
KEEP(*(.isr_vector)) /* Startup code */
. = ALIGN(4);
} >RAM
    /* for some STRx devices, the beginning of the startup code is stored in the .flashtext section, which also goes to RAM here */
.flashtext :
{
. = ALIGN(4);
*(.flashtext) /* Startup code */
. = ALIGN(4);
} >RAM
/* Exception handling sections. "contains index entries for section unwinding" */
.ARM.exidx :
{
. = ALIGN(4);
*(.ARM.exidx)
. = ALIGN(4);
} >RAM
    /* the program code is stored in the .text section, which goes to RAM for this loader */
.text :
{
. = ALIGN(4);
*(.text) /* remaining code */
*(.text.*) /* remaining code */
*(.rodata) /* read-only data (constants) */
*(.rodata*)
*(.constdata) /* read-only data (constants) */
*(.constdata*)
*(.glue_7)
*(.glue_7t)
*(i.*)
. = ALIGN(4);
} >RAM
    /* This is the initialized data section.
       Because this loader binary is copied directly into RAM and executed there,
       the initial values are already in place and no startup copy from FLASH is needed. */
.data : {
. = ALIGN(4);
        /* This is used by the startup in order to initialize the .data section */
__data_start = .;
*(.data)
*(.data.*)
. = ALIGN(4);
        __data_end = .; /* This is used by the startup in order to initialize the .data section */
} >RAM
__data_load_start = LOADADDR(.data);
/* This is the uninitialized data section */
.bss (NOLOAD) : {
. = ALIGN(4);
        __bss_start = .; /* This is used by the startup in order to initialize the .bss section */
*(.bss)
*(.bss.*)
*(COMMON)
. = ALIGN(4);
        __bss_end = .; /* This is used by the startup in order to initialize the .bss section */
} >RAM
.stack (NOLOAD) : {
. = ALIGN(8);
_sstack = .;
. = . + __Stack_Size;
. = ALIGN(8);
_estack = .;
} >RAM
/* after that it's only debugging information. */
/* remove the debugging information from the standard libraries */
DISCARD : {
libc.a ( * )
libm.a ( * )
libgcc.a ( * )
}
/* Stabs debugging sections. */
.stab 0 : { *(.stab) }
.stabstr 0 : { *(.stabstr) }
.stab.excl 0 : { *(.stab.excl) }
.stab.exclstr 0 : { *(.stab.exclstr) }
.stab.index 0 : { *(.stab.index) }
.stab.indexstr 0 : { *(.stab.indexstr) }
.comment 0 : { *(.comment) }
/* DWARF debug sections.
Symbols in the DWARF debugging sections are relative to the beginning
of the section so we begin them at 0. */
/* DWARF 1 */
.debug 0 : { *(.debug) }
.line 0 : { *(.line) }
/* GNU DWARF 1 extensions */
.debug_srcinfo 0 : { *(.debug_srcinfo) }
.debug_sfnames 0 : { *(.debug_sfnames) }
/* DWARF 1.1 and DWARF 2 */
.debug_aranges 0 : { *(.debug_aranges) }
.debug_pubnames 0 : { *(.debug_pubnames) }
/* DWARF 2 */
.debug_info 0 : { *(.debug_info .gnu.linkonce.wi.*) }
.debug_abbrev 0 : { *(.debug_abbrev) }
.debug_line 0 : { *(.debug_line) }
.debug_frame 0 : { *(.debug_frame) }
.debug_str 0 : { *(.debug_str) }
.debug_loc 0 : { *(.debug_loc) }
.debug_macinfo 0 : { *(.debug_macinfo) }
/* SGI/MIPS DWARF 2 extensions */
.debug_weaknames 0 : { *(.debug_weaknames) }
.debug_funcnames 0 : { *(.debug_funcnames) }
.debug_typenames 0 : { *(.debug_typenames) }
.debug_varnames 0 : { *(.debug_varnames) }
}


@ -0,0 +1,29 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from waflib import Utils, Errors
from waflib.TaskGen import after, feature
@after('apply_link')
@feature('cprogram', 'cshlib')
def process_ldscript(self):
if not getattr(self, 'ldscript', None) or self.env.CC_NAME != 'gcc':
return
node = self.path.find_resource(self.ldscript)
if not node:
raise Errors.WafError('could not find %r' % self.ldscript)
self.link_task.env.append_value('LINKFLAGS', '-T%s' % node.abspath())
self.link_task.dep_nodes.append(node)


@ -0,0 +1,84 @@
#!/usr/bin/python
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Grygoriy Fuchedzhy 2010
"""
Support for converting linked targets to ihex, srec or binary files using
objcopy. Use the 'objcopy' feature in conjunction with the 'cc' or 'cxx'
feature. The 'objcopy' feature uses the following attributes:
objcopy_bfdname Target object format name (e.g. ihex, srec, binary).
Defaults to ihex.
objcopy_target File name used for objcopy output. This defaults to the
target name with objcopy_bfdname as extension.
objcopy_install_path Install path for objcopy_target file. Defaults to ${PREFIX}/firmware.
objcopy_flags Additional flags passed to objcopy.
"""
from waflib.Utils import def_attrs
from waflib import Task
from waflib.TaskGen import feature, after_method
class objcopy(Task.Task):
run_str = '${OBJCOPY} -O ${TARGET_BFDNAME} ${OBJCOPYFLAGS} ${SRC} ${TGT}'
color = 'CYAN'
@feature('objcopy')
@after_method('apply_link')
def objcopy(self):
def_attrs(self,
objcopy_bfdname='ihex',
objcopy_target=None,
objcopy_install_path="${PREFIX}/firmware",
objcopy_flags='')
link_output = self.link_task.outputs[0]
if not self.objcopy_target:
self.objcopy_target = link_output.change_ext('.' + self.objcopy_bfdname).name
elif isinstance(self.objcopy_target, str):
self.objcopy_target = self.path.find_or_declare(self.objcopy_target)
task = self.create_task('objcopy',
src=link_output,
tgt=self.objcopy_target)
task.env.append_unique('TARGET_BFDNAME', self.objcopy_bfdname)
try:
task.env.append_unique('OBJCOPYFLAGS', getattr(self, 'objcopy_flags'))
except AttributeError:
pass
if self.objcopy_install_path:
self.bld.install_files(self.objcopy_install_path,
task.outputs[0],
env=task.env.derive())
def configure(ctx):
objcopy = ctx.find_program('objcopy', var='OBJCOPY', mandatory=True)
def objcopy_simple(task, mode):
return task.exec_command('arm-none-eabi-objcopy -S -R .stack -R .priv_bss'
' -R .bss -O %s "%s" "%s"' %
(mode, task.inputs[0].abspath(), task.outputs[0].abspath()))
def objcopy_simple_bin(task):
return objcopy_simple(task, 'binary')


@ -0,0 +1,63 @@
# Build script for the silk loader
import sys
import os
from waflib import Logs
def options(opt):
opt.load('gcc')
def configure(conf):
# Find our binary tools
conf.find_program('arm-none-eabi-gcc', var='CC', mandatory=True)
conf.env.AS = conf.env.CC
conf.find_program('arm-none-eabi-gcc-ar', var='AR', mandatory=True)
conf.load('gcc')
for tool in 'ar objcopy'.split():
conf.find_program('arm-none-eabi-' + tool, var=tool.upper(), mandatory=True)
# Set up our compiler configuration
CPU_FLAGS = ['-mcpu=cortex-m3', '-mthumb']
OPT_FLAGS = ['-Os', '-g']
C_FLAGS = [
'-std=c11', '-ffunction-sections',
'-Wall', '-Wextra', '-Werror', '-Wpointer-arith',
'-Wno-unused-parameter', '-Wno-missing-field-initializers',
'-Wno-error=unused-function', '-Wno-error=unused-variable',
'-Wno-error=unused-parameter', '-Wno-error=unused-but-set-variable',
'-Wno-packed-bitfield-compat'
]
LINK_FLAGS = ['-Wl,--gc-sections', '-specs=nano.specs']
conf.env.append_unique('CFLAGS', CPU_FLAGS + OPT_FLAGS + C_FLAGS)
conf.env.append_unique('LINKFLAGS', LINK_FLAGS + CPU_FLAGS + OPT_FLAGS)
conf.env.append_unique('DEFINES', ['_REENT_SMALL=1'])
# Load up other waftools that we need
conf.load('objcopy ldscript', tooldir='waftools')
def build(bld):
elf_node = bld.path.get_bld().make_node('loader.elf')
linkflags = ['-Wl,-Map,loader.map']
sources = ['src/**/*.c']
includes = ['src']
bld.program(features="objcopy",
source=bld.path.ant_glob(sources),
includes=includes,
target=elf_node,
ldscript='src/stm32f4_loader.ld',
linkflags=linkflags,
objcopy_bfdname='ihex',
objcopy_target=elf_node.change_ext('.hex'))
import objcopy
bld(rule=objcopy.objcopy_simple_bin, source='loader.elf', target='loader.bin')
# vim:filetype=python


@ -0,0 +1,15 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__import__('pkg_resources').declare_namespace(__name__)


@ -0,0 +1,44 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from targets import STM32F4FlashProgrammer, STM32F7FlashProgrammer
from swd_port import SerialWireDebugPort
from ftdi_swd import FTDISerialWireDebug
def get_device(board, reset=True, frequency=None):
boards = {
'silk_bb': (0x7893, 10E6, STM32F4FlashProgrammer),
'robert_bb2': (0x7894, 3E6, STM32F7FlashProgrammer)
}
if board not in boards:
raise Exception('Invalid board: {}'.format(board))
usb_pid, default_frequency, board_ctor = boards[board]
if not frequency:
frequency = default_frequency
ftdi = FTDISerialWireDebug(vid=0x0403, pid=usb_pid, interface=0, direction=0x1b,
output_mask=0x02, reset_mask=0x40, frequency=frequency)
swd_port = SerialWireDebugPort(ftdi, reset)
return board_ctor(swd_port)
def flash(board, hex_files):
with get_device(board) as programmer:
programmer.execute_loader()
for hex_file in hex_files:
programmer.load_hex(hex_file)
programmer.reset_core()


@ -0,0 +1,38 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from . import flash
def main(args=None):
if args is None:
parser = argparse.ArgumentParser(description='A tool for flashing a bigboard via FTDI+SWD')
parser.add_argument('hex_files', type=str, nargs='+',
help='Path to one or more hex files to flash')
parser.add_argument('--board', action='store', choices=['robert_bb2', 'silk_bb'], required=True,
help='Which board is being programmed')
parser.add_argument('--verbose', action='store_true',
help='Output lots of debugging info to the console.')
args = parser.parse_args()
logging.basicConfig(level=(logging.DEBUG if args.verbose else logging.INFO))
flash(args.board, args.hex_files)
if __name__ == '__main__':
main()


@ -0,0 +1,137 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from array import array
from pyftdi.pyftdi.ftdi import Ftdi
import usb.util
class FTDISerialWireDebug(object):
def __init__(self, vid, pid, interface, direction, output_mask, reset_mask, frequency):
self._direction = direction
self._output_mask = output_mask
self._reset_mask = reset_mask
self._ftdi = Ftdi()
try:
self._ftdi.open_mpsse(vid, pid, interface, direction=direction, frequency=frequency,
latency=1)
except:
self._ftdi = None
raise
        # get the FTDI FIFO size and increase the chunksize to match
self._ftdi_fifo_size = min(self._ftdi.fifo_sizes)
self._ftdi.write_data_set_chunksize(self._ftdi_fifo_size)
self._cmd_buffer = array('B')
self._output_enabled = False
self._pending_acks = 0
self._sequence_cmd_buffer = None
def close(self):
if not self._ftdi:
return
self.send_cmds()
self._ftdi.close()
# PyFTDI doesn't do a good job of cleaning up - make sure we release the usb device
usb.util.dispose_resources(self._ftdi.usb_dev)
self._ftdi = None
def _fatal(self, message):
raise Exception('FATAL ERROR: {}'.format(message))
def _queue_cmd(self, write_data):
if len(write_data) > self._ftdi_fifo_size:
raise Exception('Data too big!')
if self._sequence_cmd_buffer is not None:
self._sequence_cmd_buffer.extend(write_data)
else:
if len(self._cmd_buffer) + len(write_data) > self._ftdi_fifo_size:
self.send_cmds()
self._cmd_buffer.extend(write_data)
def _set_output_enabled(self, enabled):
if enabled == self._output_enabled:
return
self._output_enabled = enabled
direction = self._direction & ~(0x00 if enabled else self._output_mask)
self._queue_cmd([Ftdi.SET_BITS_LOW, 0, direction])
def reset(self):
# toggle the reset line
self.reset_lo()
self.reset_hi()
def reset_lo(self):
direction = self._direction & ~(0x00 if self._output_enabled else self._output_mask)
self._queue_cmd([Ftdi.SET_BITS_LOW, 0, direction | self._reset_mask])
self.send_cmds()
def reset_hi(self):
direction = self._direction & ~(0x00 if self._output_enabled else self._output_mask)
self._queue_cmd([Ftdi.SET_BITS_LOW, 0, direction & ~self._reset_mask])
self.send_cmds()
def send_cmds(self):
if self._sequence_cmd_buffer is not None:
self._ftdi.write_data(self._sequence_cmd_buffer)
elif len(self._cmd_buffer) > 0:
self._ftdi.write_data(self._cmd_buffer)
self._cmd_buffer = array('B')
def write_bits_cmd(self, data, num_bits):
if num_bits < 0 or num_bits > 8:
self._fatal('Invalid num_bits')
elif (data & ((1 << num_bits) - 1)) != data:
self._fatal('Invalid data!')
self._set_output_enabled(True)
self._queue_cmd([Ftdi.WRITE_BITS_NVE_LSB, num_bits - 1, data])
def write_bytes_cmd(self, data):
length = len(data) - 1
if length < 0 or length > 0xffff:
self._fatal('Invalid length')
self._set_output_enabled(True)
self._queue_cmd([Ftdi.WRITE_BYTES_NVE_LSB, length & 0xff, length >> 8] + data)
def read_bits_cmd(self, num_bits):
if num_bits < 0 or num_bits > 8:
self._fatal('Invalid num_bits')
self._set_output_enabled(False)
self._queue_cmd([Ftdi.READ_BITS_PVE_LSB, num_bits - 1])
def read_bytes_cmd(self, length):
length -= 1
if length < 0 or length > 0xffff:
self._fatal('Invalid length')
self._set_output_enabled(False)
self._queue_cmd([Ftdi.READ_BYTES_PVE_LSB, length & 0xff, length >> 8])
def get_read_bytes(self, length):
return self._ftdi.read_data_bytes(length)
def get_read_fifo_size(self):
return self._ftdi_fifo_size
def start_sequence(self):
if self._sequence_cmd_buffer is not None:
self._fatal('Attempted to start a sequence while one is in progress')
self.send_cmds()
self._sequence_cmd_buffer = array('B')
def end_sequence(self):
if self._sequence_cmd_buffer is None:
self._fatal('No sequence started')
self._sequence_cmd_buffer = None


@ -0,0 +1,287 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import struct
import time
class SerialWireDebugPort(object):
# debug port registers
DP_IDCODE_ADDR = 0x00
DP_ABORT_ADDR = 0x00
DP_CTRLSTAT_ADDR = 0x04
DP_SELECT_ADDR = 0x08
DP_RDBUFF_ADDR = 0x0c
# MEM-AP register
MEM_AP_CSW_ADDR = 0x0
MEM_AP_CSW_MASTER_DEBUG = (1 << 29)
MEM_AP_CSW_PRIVILEGED_MODE = (1 << 25)
MEM_AP_CSW_ADDRINCWORD = (1 << 4)
MEM_AP_CSW_SIZE8BITS = (0 << 1)
MEM_AP_CSW_SIZE32BITS = (1 << 1)
MEM_AP_TAR_ADDR = 0x4
MEM_AP_DRW_ADDR = 0xc
MEM_AP_IDR_VALUES = [0x24770011, 0x74770001]
def __init__(self, driver, reset=True):
self._driver = driver
self._swd_connected = False
self._reset = reset
self._pending_acks = 0
def close(self):
if self._swd_connected:
# power down the system and debug domains
self._write(self.DP_CTRLSTAT_ADDR, 0x00000000, is_access_port=False)
# send 1 byte worth of trailing bits since we're done communicating
self._driver.write_bits_cmd(0x00, 8)
            # this makes the Saleae SWD analyzer happy, and it's otherwise harmless
self._driver.write_bits_cmd(0x3, 2)
self._swd_connected = False
self._driver.close()
@staticmethod
def _fatal(message):
raise Exception('FATAL ERROR: {}'.format(message))
def _get_request_header(self, addr, is_read, is_access_port):
# the header consists of the following fields
# bit 0: start (always 1)
# bit 1: DebugPort (0) or AccessPort (1)
# bit 2: write (0) or read (1)
# bits 3-4: bits 2 and 3 of the address
# bit 5: parity bit such that bits 1-5 contain an even number of 1's
# bit 6: stop (always 0)
# bit 7: park (always 1)
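        # worked example: a DebugPort read of IDCODE (addr 0x00) sets bits 0 (start),
        # 2 (read), 5 (parity) and 7 (park), i.e. the request byte is 0xA5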
header = 0x1
header |= (1 << 1) if is_access_port else 0
header |= (1 << 2) if is_read else 0
header |= ((addr & 0xf) >> 2) << 3
parity = 0
for i in range(1, 5):
parity += (header >> i) & 0x1
header |= (parity & 0x1) << 5
header |= 1 << 7
return header
def _send_request_header(self, addr, is_read, is_access_port):
self._driver.write_bytes_cmd([self._get_request_header(addr, is_read, is_access_port)])
def _check_write_acks(self):
if not self._pending_acks:
return
self._driver.send_cmds()
# the ACK is in the top 3 bits that we get from the FTDI read, so shift right by 5
for ack in [x >> 5 for x in self._driver.get_read_bytes(self._pending_acks)]:
if ack != 0x1:
self._fatal('ACK=0x{:02x}'.format(ack))
self._pending_acks = 0
def _read(self, addr, is_access_port):
# check any pending write ACKs before doing a read
self._check_write_acks()
# send the read request
self._send_request_header(addr, is_read=True, is_access_port=is_access_port)
# do all the reads at the same time as an optimization (and hope we get an ACK)
self._driver.read_bits_cmd(4) # 4 bits for ACK + turnaround
self._driver.read_bytes_cmd(4) # 4 data bytes
self._driver.read_bits_cmd(2) # 2 bits for parity + turnaround
self._driver.send_cmds()
result = self._driver.get_read_bytes(6)
# check the ACK
ack = result[0] >> 5
if ack != 0x1:
self._fatal('ACK=0x{:02x}'.format(ack))
# grab the response
response = struct.unpack('<I', result[1:5])[0]
        # the two extra bits read above are the parity bit and a final turnaround bit
# check that the parity is correct
parity = (result[5] >> 6) & 0x1
if parity != sum((response >> i) & 0x1 for i in range(32)) & 0x1:
self._fatal('Bad parity')
return response
def _write(self, addr, data, is_access_port, no_ack=False):
if data > 0xffffffff:
self._fatal('Bad data')
# send the write request
self._send_request_header(addr, is_read=False, is_access_port=is_access_port)
# OPTIMIZATION: queue the ACK read now and keep going (hope we get an ACK)
self._driver.read_bits_cmd(4)
# calculate the parity and send the data
parity = sum((data >> i) & 0x1 for i in range(32)) & 0x1
# OPTIMIZATION: We need to send 1 turnaround bit, 4 data bytes, and 1 parity bit.
# We can combine this into a single FTDI write by sending it as 5 bytes, so
# let's shift everything such that the extra 6 bits are at the end where they
# will be properly ignored as trailing bits.
temp = ((data << 1) & 0xfffffffe)
data_bytes = [(temp >> (i * 8)) & 0xff for i in range(4)]
data_bytes += [(data >> 31) | (parity << 1)]
self._driver.write_bytes_cmd(data_bytes)
# check the ACK(s) if necessary
self._pending_acks += 1
if not no_ack or self._pending_acks >= self._driver.get_read_fifo_size():
self._check_write_acks()
def connect(self):
if self._reset:
# reset the target
self._driver.reset_lo()
# switch from JTAG to SWD mode (based on what openocd does)
# - line reset
# - magic number of 0xE79E
# - line reset
# - 2 low bits for unknown reasons (maybe padding to nibbles?)
def line_reset():
# a line reset is 50 high bits (6 bytes + 2 bits)
self._driver.write_bytes_cmd([0xff] * 6)
self._driver.write_bits_cmd(0x3, 2)
line_reset()
self._driver.write_bytes_cmd([0x9e, 0xe7])
line_reset()
self._driver.write_bits_cmd(0x0, 2)
idcode = self._read(self.DP_IDCODE_ADDR, is_access_port=False)
# clear the error flags
self._write(self.DP_ABORT_ADDR, 0x0000001E, is_access_port=False)
# power up the system and debug domains
self._write(self.DP_CTRLSTAT_ADDR, 0xF0000001, is_access_port=False)
# check the MEM-AP IDR
        # the IDR register has the same address as the DRW register but on the 0xf bank
self._write(self.DP_SELECT_ADDR, 0xf0, is_access_port=False) # select the 0xf bank
self._read(self.MEM_AP_DRW_ADDR, is_access_port=True) # read the value register (twice)
if self._read(self.DP_RDBUFF_ADDR, is_access_port=False) not in self.MEM_AP_IDR_VALUES:
self._fatal('Invalid MEM-AP IDR')
self._write(self.DP_SELECT_ADDR, 0x0, is_access_port=False) # return to the 0x0 bank
# enable privileged access to the MEM-AP with 32 bit data accesses and auto-incrementing
csw_value = self.MEM_AP_CSW_PRIVILEGED_MODE
csw_value |= self.MEM_AP_CSW_MASTER_DEBUG
csw_value |= self.MEM_AP_CSW_ADDRINCWORD
csw_value |= self.MEM_AP_CSW_SIZE32BITS
self._write(self.MEM_AP_CSW_ADDR, csw_value, is_access_port=True)
self._swd_connected = True
if self._reset:
self._driver.reset_hi()
return idcode
def read_memory_address(self, addr):
self._write(self.MEM_AP_TAR_ADDR, addr, is_access_port=True)
self._read(self.MEM_AP_DRW_ADDR, is_access_port=True)
return self._read(self.DP_RDBUFF_ADDR, is_access_port=False)
def write_memory_address(self, addr, value):
self._write(self.MEM_AP_TAR_ADDR, addr, is_access_port=True)
self._write(self.MEM_AP_DRW_ADDR, value, is_access_port=True)
def write_memory_bulk(self, base_addr, data):
# TAR is configured as auto-incrementing, but it wraps every 4096 bytes, so that's how much
# we can write before we need to explicitly set it again.
WORD_SIZE = 4
BURST_LENGTH = 4096 / WORD_SIZE
assert(base_addr % BURST_LENGTH == 0 and len(data) % WORD_SIZE == 0)
for i in range(0, len(data), WORD_SIZE):
if i % BURST_LENGTH == 0:
# set the target address
self._write(self.MEM_AP_TAR_ADDR, base_addr + i, is_access_port=True, no_ack=True)
# write the word
word = sum(data[i+j] << (j * 8) for j in range(WORD_SIZE))
self._write(self.MEM_AP_DRW_ADDR, word, is_access_port=True, no_ack=True)
def continuous_read(self, addr, duration):
        # This is a highly-optimized function which samples the specified memory address for the
# specified duration. This is generally used for profiling by reading the PC sampling
# register.
NUM_READS = 510 # a magic number which gives us the best sample rate on Silk/Robert
# don't auto-increment the address
csw_value = self.MEM_AP_CSW_PRIVILEGED_MODE
csw_value |= self.MEM_AP_CSW_SIZE32BITS
self._write(self.MEM_AP_CSW_ADDR, csw_value, is_access_port=True)
# set the address
self._write(self.MEM_AP_TAR_ADDR, addr, is_access_port=True)
# discard the previous value
self._read(self.MEM_AP_DRW_ADDR, is_access_port=True)
# flush everything
self._check_write_acks()
header = self._get_request_header(self.MEM_AP_DRW_ADDR, is_read=True, is_access_port=True)
self._driver.start_sequence()
for i in range(NUM_READS):
self._driver.write_bits_cmd(header, 8)
self._driver.read_bits_cmd(6)
self._driver.read_bytes_cmd(4)
raw_data = []
end_time = time.time() + duration
while time.time() < end_time:
# send the read requests
self._driver.send_cmds()
# do all the reads at the same time as an optimization (and hope we get an ACK)
raw_data.extend(self._driver.get_read_bytes(5 * NUM_READS))
self._driver.end_sequence()
def get_value_from_bits(b):
result = 0
for o in range(len(b)):
result |= b[o] << o
return result
values = []
for raw_result in [raw_data[i:i+5] for i in range(0, len(raw_data), 5)]:
result = raw_result
# The result is read as 5 bytes, with the first one containing 6 bits (shifted in from
# the left as they are read). Let's convert this into an array of bits, and then
# reconstruct the values we care about.
bits = []
bits.extend((result[0] >> (2 + j)) & 1 for j in range(6))
for i in range(4):
bits.extend((result[i + 1] >> j) & 1 for j in range(8))
ack = get_value_from_bits(bits[1:4])
response = get_value_from_bits(bits[4:36])
parity = bits[36]
# check the ACK
if ack != 0x1:
self._fatal('ACK=0x{:02x}'.format(ack))
            # the last two bits are the parity bit and a final turnaround bit
# check that the parity is correct
if parity != sum((response >> i) & 0x1 for i in range(32)) & 0x1:
self._fatal('Bad parity')
# store the response
values.append(response)
return values


@ -0,0 +1,16 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .stm32f4 import STM32F4FlashProgrammer
from .stm32f7 import STM32F7FlashProgrammer


@ -0,0 +1,352 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import array
import logging
import os
import struct
import time
from intelhex import IntelHex
LOG = logging.getLogger(__name__)
class STM32FlashProgrammer(object):
# CPUID register
CPUID_ADDR = 0xE000ED00
# Flash constants
FLASH_BASE_ADDR = 0x08000000
# Flash key register (FLASH_KEYR)
FLASH_KEYR_ADDR = 0x40023C04
FLASH_KEYR_VAL1 = 0x45670123
FLASH_KEYR_VAL2 = 0xCDEF89AB
# Flash status register (FLASH_SR)
FLASH_SR_ADDR = 0x40023C0C
FLASH_SR_BSY = (1 << 16)
# Flash control register (FLASH_CR)
FLASH_CR_ADDR = 0x40023C10
FLASH_CR_PG = (1 << 0)
FLASH_CR_SER = (1 << 1)
FLASH_CR_SNB_OFFSET = 3
FLASH_CR_PSIZE_8BIT = (0x0 << 8)
FLASH_CR_PSIZE_16BIT = (0x1 << 8)
FLASH_CR_PSIZE_32BIT = (0x2 << 8)
FLASH_CR_STRT = (1 << 16)
# Debug halting control and status register (DHCSR)
DHCSR_ADDR = 0xE000EDF0
DHCSR_DBGKEY_VALUE = 0xA05F0000
DHCSR_HALT = (1 << 0)
DHCSR_DEBUGEN = (1 << 1)
DHCSR_S_REGRDY = (1 << 16)
DHCSR_S_LOCKUP = (1 << 19)
# Application interrupt and reset control register (AIRCR)
AIRCR_ADDR = 0xE000ED0C
AIRCR_VECTKEY_VALUE = 0x05FA0000
AIRCR_SYSRESETREQ = (1 << 2)
# Debug Core Register Selector Register (DCRSR)
DCRSR_ADDR = 0xE000EDF4
DCRSR_WRITE = (1 << 16)
# Debug Core Register Data register (DCRDR)
DCRDR_ADDR = 0xE000EDF8
# Debug Exception and Monitor Control register (DEMCR)
DEMCR_ADDR = 0xE000EDFC
DEMCR_RESET_CATCH = (1 << 0)
DEMCR_TRCENA = (1 << 24)
# Program Counter Sample Register (PCSR)
PCSR_ADDR = 0xE000101C
# Loader addresses
PBLLDR_HEADER_ADDR = 0x20000400
PBLLDR_HEADER_OFFSET = PBLLDR_HEADER_ADDR + 0x4
PBLLDR_HEADER_LENGTH = PBLLDR_HEADER_ADDR + 0x8
PBLLDR_DATA_ADDR = 0x20000800
PBLLDR_DATA_MAX_LENGTH = 0x20000
PBLLDR_STATE_WAIT = 0
PBLLDR_STATE_WRITE = 1
PBLLDR_STATE_CRC = 2
# SRAM base addr
SRAM_BASE_ADDR = 0x20000000
def __init__(self, driver):
self._driver = driver
self._step_start_time = 0
self.FLASH_SECTOR_SIZES = [x*1024 for x in self.FLASH_SECTOR_SIZES]
def __enter__(self):
try:
self.connect()
return self
except:
self.close()
raise
def __exit__(self, exc, value, trace):
self.close()
def _fatal(self, message):
raise Exception('FATAL ERROR: {}'.format(message))
def _start_step(self, msg):
LOG.info(msg)
self._step_start_time = time.time()
def _end_step(self, msg, no_time=False, num_bytes=None):
total_time = round(time.time() - self._step_start_time, 2)
if not no_time:
msg += ' in {}s'.format(total_time)
if num_bytes:
kibps = round(num_bytes / 1024.0 / total_time, 2)
msg += ' ({} KiB/s)'.format(kibps)
LOG.info(msg)
def connect(self):
self._start_step('Connecting...')
# connect and check the IDCODE
if self._driver.connect() != self.IDCODE:
self._fatal('Invalid IDCODE')
# check the CPUID register
if self._driver.read_memory_address(self.CPUID_ADDR) != self.CPUID_VALUE:
self._fatal('Invalid CPU ID')
self._end_step('Connected', no_time=True)
def halt_core(self):
# halt the core immediately
dhcsr_value = self.DHCSR_DBGKEY_VALUE | self.DHCSR_DEBUGEN | self.DHCSR_HALT
self._driver.write_memory_address(self.DHCSR_ADDR, dhcsr_value)
def resume_core(self):
# resume the core
dhcsr_value = self.DHCSR_DBGKEY_VALUE
self._driver.write_memory_address(self.DHCSR_ADDR, dhcsr_value)
def reset_core(self, halt=False):
if self._driver.read_memory_address(self.DHCSR_ADDR) & self.DHCSR_S_LOCKUP:
# halt the core first to clear the lockup
LOG.info('Clearing lockup condition')
self.halt_core()
# enable reset vector catch
demcr_value = 0
if halt:
demcr_value |= self.DEMCR_RESET_CATCH
self._driver.write_memory_address(self.DEMCR_ADDR, demcr_value)
self._driver.read_memory_address(self.DHCSR_ADDR)
# reset the core
aircr_value = self.AIRCR_VECTKEY_VALUE | self.AIRCR_SYSRESETREQ
self._driver.write_memory_address(self.AIRCR_ADDR, aircr_value)
if halt:
self.halt_core()
def unlock_flash(self):
# unlock the flash
self._driver.write_memory_address(self.FLASH_KEYR_ADDR, self.FLASH_KEYR_VAL1)
self._driver.write_memory_address(self.FLASH_KEYR_ADDR, self.FLASH_KEYR_VAL2)
def _poll_register(self, timeout=0.5):
end_time = time.time() + timeout
while end_time > time.time():
val = self._driver.read_memory_address(self.DHCSR_ADDR)
if val & self.DHCSR_S_REGRDY:
break
else:
raise Exception('Register operation was not confirmed')
def write_register(self, reg, val):
self._driver.write_memory_address(self.DCRDR_ADDR, val)
reg |= self.DCRSR_WRITE
self._driver.write_memory_address(self.DCRSR_ADDR, reg)
self._poll_register()
def read_register(self, reg):
self._driver.write_memory_address(self.DCRSR_ADDR, reg)
self._poll_register()
return self._driver.read_memory_address(self.DCRDR_ADDR)
def erase_flash(self, flash_offset, length):
self._start_step('Erasing...')
def overlap(a1, a2, b1, b2):
return max(a1, b1) < min(a2, b2)
# find all the sectors which we need to erase
erase_sectors = []
for i, size in enumerate(self.FLASH_SECTOR_SIZES):
addr = self.FLASH_BASE_ADDR + sum(self.FLASH_SECTOR_SIZES[:i])
if overlap(flash_offset, flash_offset+length, addr, addr+size):
erase_sectors += [i]
if not erase_sectors:
self._fatal('Could not find sectors to erase!')
# erase the sectors
for sector in erase_sectors:
# start the erase
reg_value = (sector << self.FLASH_CR_SNB_OFFSET)
reg_value |= self.FLASH_CR_PSIZE_8BIT
reg_value |= self.FLASH_CR_STRT
reg_value |= self.FLASH_CR_SER
self._driver.write_memory_address(self.FLASH_CR_ADDR, reg_value)
# wait for the erase to finish
while self._driver.read_memory_address(self.FLASH_SR_ADDR) & self.FLASH_SR_BSY:
time.sleep(0)
self._end_step('Erased')
def close(self):
self._driver.close()
def _write_loader_state(self, state):
self._driver.write_memory_address(self.PBLLDR_HEADER_ADDR, state)
def _wait_loader_state(self, wanted_state, timeout=3):
end_time = time.time() + timeout
state = -1
while time.time() < end_time:
time.sleep(0)
state = self._driver.read_memory_address(self.PBLLDR_HEADER_ADDR)
if state == wanted_state:
break
else:
raise Exception("Timed out waiting for loader state %d, got %d" % (wanted_state, state))
@staticmethod
def _chunks(l, n):
for i in xrange(0, len(l), n):
yield l[i:i+n], len(l[i:i+n]), i
def execute_loader(self):
# reset and halt the core
self.reset_core(halt=True)
with open(os.path.join(os.path.dirname(__file__), "loader.bin")) as f:
loader_bin = f.read()
# load loader binary into SRAM
self._driver.write_memory_bulk(self.SRAM_BASE_ADDR, array.array('B', loader_bin))
        # set SP based on value in loader
reg_sp, = struct.unpack("<I", loader_bin[:4])
self.write_register(13, reg_sp)
# set PC to new reset handler
pc, = struct.unpack('<I', loader_bin[4:8])
self.write_register(15, pc)
# unlock flash
self.unlock_flash()
self.resume_core()
@staticmethod
def generate_crc(data):
length = len(data)
lookup_table = [0, 47, 94, 113, 188, 147, 226, 205, 87, 120, 9, 38, 235, 196, 181, 154]
crc = 0
for i in xrange(length*2):
nibble = data[i / 2]
if i % 2 == 0:
nibble >>= 4
index = nibble ^ (crc >> 4)
crc = lookup_table[index & 0xf] ^ ((crc << 4) & 0xf0)
return crc
def read_crc(self, addr, length):
self._driver.write_memory_address(self.PBLLDR_HEADER_OFFSET, addr)
self._driver.write_memory_address(self.PBLLDR_HEADER_LENGTH, length)
self._write_loader_state(self.PBLLDR_STATE_CRC)
self._wait_loader_state(self.PBLLDR_STATE_WAIT)
return self._driver.read_memory_address(self.PBLLDR_DATA_ADDR) & 0xFF
def load_hex(self, hex_path):
self._start_step("Loading binary: %s" % hex_path)
ih = IntelHex(hex_path)
offset = ih.minaddr()
data = ih.tobinarray()
self.load_bin(offset, data)
self._end_step("Loaded binary", num_bytes=len(data))
def load_bin(self, offset, data):
while len(data) % 4 != 0:
data.append(0xFF)
length = len(data)
# prepare the flash for programming
self.erase_flash(offset, length)
cr_value = self.FLASH_CR_PSIZE_8BIT | self.FLASH_CR_PG
self._driver.write_memory_address(self.FLASH_CR_ADDR, cr_value)
# set the base address
self._wait_loader_state(self.PBLLDR_STATE_WAIT)
self._driver.write_memory_address(self.PBLLDR_HEADER_OFFSET, offset)
for chunk, chunk_length, pos in self._chunks(data, self.PBLLDR_DATA_MAX_LENGTH):
LOG.info("Written %d/%d", pos, length)
self._driver.write_memory_address(self.PBLLDR_HEADER_LENGTH, chunk_length)
self._driver.write_memory_bulk(self.PBLLDR_DATA_ADDR, chunk)
self._write_loader_state(self.PBLLDR_STATE_WRITE)
self._wait_loader_state(self.PBLLDR_STATE_WAIT)
expected_crc = self.generate_crc(data)
actual_crc = self.read_crc(offset, length)
if actual_crc != expected_crc:
raise Exception("Bad CRC, expected %d, found %d" % (expected_crc, actual_crc))
LOG.info("CRC-8 matched: %d", actual_crc)
def profile(self, duration):
LOG.info('Collecting %f second(s) worth of samples...', duration)
# ensure DWT is enabled so we can get PC samples from PCSR
demcr_value = self._driver.read_memory_address(self.DEMCR_ADDR)
self._driver.write_memory_address(self.DEMCR_ADDR, demcr_value | self.DEMCR_TRCENA)
# take the samples
samples = self._driver.continuous_read(self.PCSR_ADDR, duration)
# restore the original DEMCR value
self._driver.write_memory_address(self.DEMCR_ADDR, demcr_value)
# process the samples
pcs = dict()
for sample in samples:
sample = '0x%08x' % sample
pcs[sample] = pcs.get(sample, 0) + 1
LOG.info('Collected %d samples!', len(samples))
return pcs


@ -0,0 +1,21 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from stm32 import STM32FlashProgrammer
class STM32F4FlashProgrammer(STM32FlashProgrammer):
IDCODE = 0x2BA01477
CPUID_VALUE = 0x410FC241
FLASH_SECTOR_SIZES = [16, 16, 16, 16, 64, 128, 128, 128, 128, 128, 128, 128]


@ -0,0 +1,21 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from stm32 import STM32FlashProgrammer
class STM32F7FlashProgrammer(STM32FlashProgrammer):
IDCODE = 0x5BA02477
CPUID_VALUE = 0x411FC270
FLASH_SECTOR_SIZES = [32, 32, 32, 32, 128, 256, 256, 256, 256, 256, 256, 256]


@ -0,0 +1,44 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Always prefer setuptools over distutils
from setuptools import setup, find_packages
setup(
name='pebble.programmer',
version='0.0.3',
description='Pebble Programmer',
url='https://github.com/pebble/pblprog',
author='Pebble Technology Corporation',
author_email='liam@pebble.com',
packages=find_packages(exclude=['contrib', 'docs', 'tests']),
namespace_packages=['pebble'],
install_requires=[
'intelhex>=2.1,<3',
'pyftdi==0.10.5'
],
package_data={
'pebble.programmer.targets': ['loader.bin']
},
entry_points={
'console_scripts': [
'pblprog = pebble.programmer.__main__:main',
],
},
zip_safe=False
)

python_libs/pebble-commander/.gitignore vendored Normal file

@ -0,0 +1,89 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject


@ -0,0 +1,2 @@
Pebble Commander
================
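
An interactive console and Python library for driving a Pebble over its serial
prompt. A typical invocation (a sketch: the module name is assumed from the
package layout, and the tty path is an example; the dictionary file is optional):

    python -m pebble_commander --tty /dev/ttyUSB0 loghash_dict.json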


@ -0,0 +1,15 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__import__('pkg_resources').declare_namespace(__name__)


@ -0,0 +1,16 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .commander import PebbleCommander
from . import _commands


@ -0,0 +1,67 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import argparse
import logging
import sys
from . import interactive
def main(args=None):
def reattach_handler(logger, formatter, handler):
if handler is not None:
logger.removeHandler(handler)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
logger.addHandler(handler)
return handler
if args is None:
parser = argparse.ArgumentParser(description='Pebble Commander.')
parser.add_argument('-v', '--verbose', help='verbose logging', action='count',
default=0)
parser.add_argument('-t', '--tty', help='serial port (defaults to auto-detect)', metavar='TTY',
default=None)
parser.add_argument('-c', '--pcap', metavar='FILE', default=None,
help='write packet capture to pcap file')
parser.add_argument('dict', help='log-hashing dictionary file', metavar='loghash_dict.json',
nargs='?', default=None)
args = parser.parse_args()
log_level = (logging.DEBUG if args.verbose >= 2
else logging.INFO if args.verbose >= 1
else logging.WARNING)
use_colors = True
formatter_string = '%(name)-12s: %(levelname)-8s %(message)s'
if use_colors:
formatter_string = '\x1b[33m%s\x1b[m' % formatter_string
formatter = logging.Formatter(formatter_string)
handler = reattach_handler(logging.getLogger(), formatter, None)
logging.getLogger().setLevel(log_level)
with interactive.InteractivePebbleCommander(
loghash_path=args.dict, tty=args.tty, capfile=args.pcap) as cmdr:
cmdr.attach_prompt_toolkit()
# Re-create the logging handler to use the patched stdout
handler = reattach_handler(logging.getLogger(), formatter, handler)
cmdr.command_loop()
if __name__ == '__main__':
main()


@ -0,0 +1,27 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import app
from . import battery
from . import bluetooth
from . import clicks
from . import flash
from . import help
from . import imaging
from . import misc
from . import pfs
from . import resets
from . import system
from . import time
from . import windows


@ -0,0 +1,85 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import PebbleCommander, exceptions, parsers
@PebbleCommander.command()
def app_list(cmdr):
""" List applications.
"""
return cmdr.send_prompt_command("app list")
@PebbleCommander.command()
def app_load_metadata(cmdr):
""" Ghetto metadata loading for pbw_image.py
"""
ret = cmdr.send_prompt_command("app load metadata")
if not ret[0].startswith("OK"):
raise exceptions.PromptResponseError(ret)
@PebbleCommander.command()
def app_launch(cmdr, idnum):
""" Launch an application.
"""
idnum = int(str(idnum), 0)
if idnum == 0:
raise exceptions.ParameterError('idnum out of range: %d' % idnum)
ret = cmdr.send_prompt_command("app launch %d" % idnum)
if not ret[0].startswith("OK"):
raise exceptions.PromptResponseError(ret)
@PebbleCommander.command()
def app_remove(cmdr, idnum):
""" Remove an application.
"""
idnum = int(str(idnum), 0)
if idnum == 0:
raise exceptions.ParameterError('idnum out of range: %d' % idnum)
ret = cmdr.send_prompt_command("app remove %d" % idnum)
if not ret[0].startswith("OK"):
raise exceptions.PromptResponseError(ret)
@PebbleCommander.command()
def app_resource_bank(cmdr, idnum=0):
""" Get resource bank info for an application.
"""
idnum = int(str(idnum), 0)
if idnum < 0:
raise exceptions.ParameterError('idnum out of range: %d' % idnum)
ret = cmdr.send_prompt_command("resource bank info %d" % idnum)
if not ret[0].startswith("OK "):
raise exceptions.PromptResponseError(ret)
return [ret[0][3:]]
@PebbleCommander.command()
def app_next_id(cmdr):
""" Get next free application ID.
"""
return cmdr.send_prompt_command("app next id")
@PebbleCommander.command()
def app_available(cmdr, idnum):
""" Check if an application is available.
"""
idnum = int(str(idnum), 0)
if idnum == 0:
raise exceptions.ParameterError('idnum out of range: %d' % idnum)
return cmdr.send_prompt_command("app available %d" % idnum)


@ -0,0 +1,35 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import PebbleCommander, exceptions, parsers
@PebbleCommander.command()
def battery_force_charge(cmdr, charging=True):
""" Force the device to believe it is or isn't charging.
"""
if parsers.str2bool(charging):
charging = "enable"
else:
charging = "disable"
ret = cmdr.send_prompt_command("battery chargeopt %s" % charging)
if ret:
raise exceptions.PromptResponseError(ret)
@PebbleCommander.command()
def battery_status(cmdr):
""" Get current battery status.
"""
return cmdr.send_prompt_command("battery status")

View file

@ -0,0 +1,82 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import PebbleCommander, exceptions, parsers
@PebbleCommander.command()
def bt_airplane_mode(cmdr, enter=True):
""" Enter or exit airplane mode.
`enter` should either be a boolean, "enter", or "exit".
"""
if parsers.str2bool(enter, also_true=["enter"], also_false=["exit"]):
enter = "enter"
else:
enter = "exit"
ret = cmdr.send_prompt_command("bt airplane mode %s" % enter)
if ret:
raise exceptions.PromptResponseError(ret)
@PebbleCommander.command()
def bt_prefs_wipe(cmdr):
""" Wipe bluetooth preferences.
"""
ret = cmdr.send_prompt_command("bt prefs wipe")
if ret:
raise exceptions.PromptResponseError(ret)
@PebbleCommander.command()
def bt_mac(cmdr):
""" Get the bluetooth MAC address.
"""
ret = cmdr.send_prompt_command("bt mac")
if not ret[0].startswith("0x"):
raise exceptions.PromptResponseError(ret)
retstr = ret[0][2:]
return [':'.join(retstr[i:i+2] for i in range(0, len(retstr), 2))]
@PebbleCommander.command()
def bt_set_addr(cmdr, new_mac=None):
""" Set the bluetooth MAC address.
Don't specify `new_mac` to revert to default.
`new_mac` should be six hex octets separated by colons.
"""
if not new_mac:
new_mac = "00:00:00:00:00:00"
mac = parsers.str2mac(new_mac)
macstr = ''.join(["%02X" % byte for byte in mac])
ret = cmdr.send_prompt_command("bt set addr %s" % macstr)
if ret[0] != new_mac:
raise exceptions.PromptResponseError(ret)
@PebbleCommander.command()
def bt_set_name(cmdr, new_name=None):
""" Set the bluetooth name.
"""
if not new_name:
new_name = ""
# Note: this restriction exists only because the prompt protocol can't handle spaces.
# It can probably be removed when prompt goes away.
if ' ' in new_name:
raise exceptions.ParameterError("bluetooth name must not have spaces")
ret = cmdr.send_prompt_command("bt set name %s" % new_name)
if ret:
raise exceptions.PromptResponseError(ret)

View file

@ -0,0 +1,58 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import PebbleCommander, exceptions, parsers
@PebbleCommander.command()
def click_short(cmdr, button):
""" Click a button.
"""
button = int(str(button), 0)
if not 0 <= button <= 3:
raise exceptions.ParameterError('button out of range: %d' % button)
ret = cmdr.send_prompt_command("click short %d" % button)
if not ret[0].startswith("OK"):
raise exceptions.PromptResponseError(ret)
@PebbleCommander.command()
def click_long(cmdr, button, hold_ms=20):
""" Hold a button.
`hold_ms` is how many ms to hold the button down before releasing.
"""
return cmdr.click_multiple(button, hold_ms=hold_ms)
@PebbleCommander.command()
def click_multiple(cmdr, button, count=1, hold_ms=20, delay_ms=0):
""" Rhythmically click a button.
"""
button = int(str(button), 0)
count = int(str(count), 0)
hold_ms = int(str(hold_ms), 0)
delay_ms = int(str(delay_ms), 0)
if not 0 <= button <= 3:
raise exceptions.ParameterError('button out of range: %d' % button)
if not count > 0:
raise exceptions.ParameterError('count out of range: %d' % count)
if hold_ms < 0:
raise exceptions.ParameterError('hold_ms out of range: %d' % hold_ms)
if delay_ms < 0:
raise exceptions.ParameterError('delay_ms out of range: %d' % delay_ms)
ret = cmdr.send_prompt_command(
"click multiple {button:d} {count:d} {hold_ms:d} {delay_ms:d}".format(**locals()))
if not ret[0].startswith("OK"):
raise exceptions.PromptResponseError(ret)

View file

@ -0,0 +1,61 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import PebbleCommander, exceptions, parsers
# TODO: flash-write
# Can't do it with pulse prompt :(
@PebbleCommander.command()
def flash_erase(cmdr, address, length):
""" Erase flash area.
"""
address = int(str(address), 0)
length = int(str(length), 0)
if address < 0:
raise exceptions.ParameterError('address out of range: %d' % address)
if length <= 0:
raise exceptions.ParameterError('length out of range: %d' % length)
# TODO: I guess catch errors
ret = cmdr.send_prompt_command("erase flash 0x%X %d" % (address, length))
if not ret[1].startswith("OK"):
raise exceptions.PromptResponseError(ret)
@PebbleCommander.command()
def flash_crc(cmdr, address, length):
""" Calculate CRC of flash area.
"""
address = int(str(address), 0)
length = int(str(length), 0)
if address < 0:
raise exceptions.ParameterError('address out of range: %d' % address)
if length <= 0:
raise exceptions.ParameterError('length out of range: %d' % length)
# TODO: I guess catch errors
ret = cmdr.send_prompt_command("crc flash 0x%X %d" % (address, length))
if not ret[0].startswith("CRC: "):
raise exceptions.PromptResponseError(ret)
return [ret[0][5:]]
@PebbleCommander.command()
def prf_address(cmdr):
""" Get address of PRF.
"""
ret = cmdr.send_prompt_command("prf image address")
if not ret[0].startswith("OK "):
raise exceptions.PromptResponseError(ret)
return [ret[0][3:]]

View file

@ -0,0 +1,133 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import sys
from .. import PebbleCommander, exceptions, parsers
def trim_docstring(var):
return inspect.getdoc(var) or ''
def get_help_short(cmdr, cmd_name, help_output=None):
"""
cmd_name is the command's name.
help_output is the raw output of the `!help` command.
"""
output = None
func = cmdr.get_command(cmd_name)
if func: # Host command
# cmdstr starts as the registered command name
cmdstr = func.name
spec = inspect.getargspec(func)
if len(spec.args) > 1:
maxargs = len(spec.args) - 1
if spec.defaults is None:
cmdstr += " {%d args}" % maxargs
else:
minargs = maxargs - len(spec.defaults)
cmdstr += " {%d~%d args}" % (minargs, maxargs)
if func.__doc__ is not None:
output = "%-30s - %s" % (cmdstr, trim_docstring(func).splitlines()[0])
else:
output = cmdstr
else: # Prompt command
if cmd_name[0] == '!': # Strip the bang if it's there
cmd_name = cmd_name[1:]
# Get the output if it wasn't provided
if help_output is None:
help_output = cmdr.send_prompt_command("help")
for prompt_cmd in help_output[1:]:
# Match, even with argument count provided
if prompt_cmd == cmd_name or prompt_cmd.startswith(cmd_name+" "):
# Output should be the full argument string with the bang
output = '!' + prompt_cmd
break
return output
def help_arginfo_nodefault(arg):
return "%s" % arg.upper()
def help_arginfo_default(arg, dflt):
return "[%s (default: %s)]" % (arg.upper(), str(dflt))
def get_help_long(cmdr, cmd_name):
output = ""
func = cmdr.get_command(cmd_name)
if func:
spec = inspect.getargspec(func)
specstr = []
for i, arg in enumerate(spec.args[1:]):
if spec.defaults is not None:
minargs = len(spec.args[1:]) - len(spec.defaults)
if i >= minargs:
specstr.append(help_arginfo_default(arg, spec.defaults[i - minargs]))
else:
specstr.append(help_arginfo_nodefault(arg))
else:
specstr.append(help_arginfo_nodefault(arg))
specstr = ' '.join(specstr)
cmdstr = func.name + " " + specstr
if func.__doc__ is None:
output = "%s\n\nNo help available." % cmdstr
else:
output = "%s - %s" % (cmdstr, trim_docstring(func))
else: # Prompt command
cmdstr = get_help_short(cmdr, cmd_name)
if cmdstr is None:
output = None
else:
output = "%s\n\nNo help available, due to being a prompt command." % cmdstr
return output
@PebbleCommander.command()
def help(cmdr, cmd=None):
""" Show help.
You're lookin' at it, dummy!
"""
out = []
if cmd is not None:
helpstr = get_help_long(cmdr, cmd)
if helpstr is None:
raise exceptions.ParameterError("No command '%s' found." % cmd)
out.append(helpstr)
else: # List commands
out.append("===Host commands===")
# Bonus, this list is sorted for us already
for cmd_name in dir(cmdr):
if cmdr.get_command(cmd_name):
out.append(get_help_short(cmdr, cmd_name))
out.append("\n===Prompt commands===")
ret = cmdr.send_prompt_command("help")
if ret[0] != 'Available Commands:':
raise exceptions.PromptResponseError("'help' prompt command output invalid")
for cmd_name in ret[1:]:
out.append(get_help_short(cmdr, "!" + cmd_name, ret))
return out

View file

@ -0,0 +1,224 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from binascii import crc32
import os
import struct
import sys
import traceback
import pebble.pulse2.exceptions
from .. import PebbleCommander, exceptions, parsers
from ..util import stm32_crc
class PebbleFirmwareBinaryInfo(object):
V1_STRUCT_VERSION = 1
V1_STRUCT_DEFINTION = [
('20s', 'build_id'),
('L', 'version_timestamp'),
('32s', 'version_tag'),
('8s', 'version_short'),
('?', 'is_recovery_firmware'),
('B', 'hw_platform'),
('B', 'metadata_version')
]
# The platforms which use a legacy defective crc32
LEGACY_CRC_PLATFORMS = [
0, # unknown (assume legacy)
1, # OneEV1
2, # OneEV2
3, # OneEV2_3
4, # OneEV2_4
5, # OnePointFive
6, # TwoPointFive
7, # SnowyEVT2
8, # SnowyDVT
9, # SpaldingEVT
10, # BobbyDVT
11, # Spalding
0xff, # OneBigboard
0xfe, # OneBigboard2
0xfd, # SnowyBigboard
0xfc, # SnowyBigboard2
0xfb, # SpaldingBigboard
]
def get_crc(self):
_, ext = os.path.splitext(self.path)
assert ext == '.bin', 'Can only calculate crc for .bin files'
with open(self.path, 'rb') as f:
image = f.read()
if self.hw_platform in self.LEGACY_CRC_PLATFORMS:
# use the legacy defective crc
return stm32_crc.crc32(image)
else:
# use a regular crc
return crc32(image) & 0xFFFFFFFF
def _get_footer_struct(self):
fmt = '<' + reduce(lambda s, t: s + t[0],
PebbleFirmwareBinaryInfo.V1_STRUCT_DEFINTION, '')
return struct.Struct(fmt)
def _get_footer_data_from_bin(self, path):
with open(path, 'rb') as f:
struct_size = self.struct.size
f.seek(-struct_size, 2)
footer_data = f.read()
return footer_data
def _parse_footer_data(self, footer_data):
z = zip(PebbleFirmwareBinaryInfo.V1_STRUCT_DEFINTION,
self.struct.unpack(footer_data))
return {entry[1]: data for entry, data in z}
def __init__(self, bin_path):
self.path = bin_path
self.struct = self._get_footer_struct()
_, ext = os.path.splitext(bin_path)
if ext != '.bin':
raise ValueError('Unexpected extension. Must be ".bin"')
footer_data = self._get_footer_data_from_bin(bin_path)
self.info = self._parse_footer_data(footer_data)
# Trim trailing NULs from the strings:
for k in ["version_tag", "version_short"]:
self.info[k] = self.info[k].rstrip("\x00")
def __str__(self):
return str(self.info)
def __repr__(self):
return self.info.__repr__()
def __getattr__(self, name):
if name in self.info:
return self.info[name]
raise AttributeError
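# Illustrative sketch (not part of the original module): inspecting the footer of a
# built firmware binary. The path is only an assumption, borrowed from the
# image_firmware default further down.
#
#   fw = PebbleFirmwareBinaryInfo('build/prf/src/fw/tintin_fw.bin')
#   print fw.version_tag, fw.version_short, fw.hw_platform
#   print '0x%08x' % fw.get_crc()  # legacy or standard CRC, depending on platform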
# typedef struct ATTR_PACKED FirmwareDescription {
# uint32_t description_length;
# uint32_t firmware_length;
# uint32_t checksum;
# } FirmwareDescription;
FW_DESCR_FORMAT = '<III'
FW_DESCR_SIZE = struct.calcsize(FW_DESCR_FORMAT)
def _generate_firmware_description_struct(firmware_length, firmware_crc):
return struct.pack(FW_DESCR_FORMAT, FW_DESCR_SIZE, firmware_length, firmware_crc)
def insert_firmware_description_struct(input_binary, output_binary=None):
fw_bin_info = PebbleFirmwareBinaryInfo(input_binary)
with open(input_binary, 'rb') as inf:
fw_bin = inf.read()
fw_crc = fw_bin_info.get_crc()
return _generate_firmware_description_struct(len(fw_bin), fw_crc) + fw_bin
def _load(connection, image, progress, verbose, address):
image_crc = stm32_crc.crc32(image)
progress_cb = None
if progress or verbose:
def progress_cb(acked):
print('.' if acked else 'R', end='')
sys.stdout.flush()
if progress or verbose:
print('Erasing... ', end='')
sys.stdout.flush()
try:
connection.flash.erase(address, len(image))
except pebble.pulse2.exceptions.PulseException as e:
detail = ''.join(traceback.format_exception_only(type(e), e))
if verbose:
detail = '\n' + traceback.format_exc()
print('Erase failed! ' + detail)
return False
if progress or verbose:
print('done.')
sys.stdout.flush()
try:
retries = connection.flash.write(address, image,
progress_cb=progress_cb)
except pebble.pulse2.exceptions.PulseException as e:
detail = ''.join(traceback.format_exception_only(type(e), e))
if verbose:
detail = '\n' + traceback.format_exc()
print('Write failed! ' + detail)
return False
result_crc = connection.flash.crc(address, len(image))
if progress or verbose:
print()
if verbose:
print('Retries: %d' % retries)
return result_crc == image_crc
def load_firmware(connection, fin, progress, verbose, address=None):
if address is None:
# If address is unspecified, assume we want the prf address
_, address, length = connection.flash.query_region_geometry(
connection.flash.REGION_PRF)
address = int(address)
image = insert_firmware_description_struct(fin)
if _load(connection, image, progress, verbose, address):
connection.flash.finalize_region(
connection.flash.REGION_PRF)
return True
return False
def load_resources(connection, fin, progress, verbose):
_, address, length = connection.flash.query_region_geometry(
connection.flash.REGION_SYSTEM_RESOURCES)
with open(fin, 'rb') as f:
data = f.read()
assert len(data) <= length
if _load(connection, data, progress, verbose, address):
connection.flash.finalize_region(
connection.flash.REGION_SYSTEM_RESOURCES)
return True
return False
@PebbleCommander.command()
def image_resources(cmdr, pack='build/system_resources.pbpack'):
""" Image resources.
"""
load_resources(cmdr.connection, pack,
progress=cmdr.interactive, verbose=cmdr.interactive)
@PebbleCommander.command()
def image_firmware(cmdr, firm='build/prf/src/fw/tintin_fw.bin', address=None):
""" Image recovery firmware.
"""
if address is not None:
address = int(str(address), 0)
load_firmware(cmdr.connection, firm, progress=cmdr.interactive,
verbose=cmdr.interactive, address=address)

View file

@ -0,0 +1,22 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import PebbleCommander, exceptions, parsers
@PebbleCommander.command()
def audit_delay(cmdr):
""" Audit delay_us.
"""
return cmdr.send_prompt_command("audit delay")

View file

@ -0,0 +1,45 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import PebbleCommander, exceptions, parsers
@PebbleCommander.command()
def pfs_prepare(cmdr, size):
""" Prepare for file creation.
"""
size = int(str(size), 0)
if size <= 0:
raise exceptions.ParameterError('size out of range: %d' % size)
# TODO: I guess catch errors
ret = cmdr.send_prompt_command("pfs prepare %d" % size)
if not ret[0].startswith("Success"):
raise exceptions.PromptResponseError(ret)
# TODO: pfs-write
# Can't do it with pulse prompt :(
@PebbleCommander.command()
def pfs_litter(cmdr):
""" Fragment the filesystem.
Fragments the filesystem by creating a large number of small files
and deleting only a few of them.
"""
ret = cmdr.send_prompt_command("litter pfs")
if not ret[0].startswith("OK "):
raise exceptions.PromptResponseError(ret)
return [ret[0][3:]]

View file

@ -0,0 +1,45 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import PebbleCommander, exceptions, parsers
@PebbleCommander.command()
def reset(cmdr):
""" Reset the device.
"""
cmdr.send_prompt_command("reset")
@PebbleCommander.command()
def crash(cmdr):
""" Crash the device.
"""
cmdr.send_prompt_command("crash")
@PebbleCommander.command()
def factory_reset(cmdr, fast=False):
""" Perform a factory reset.
If `fast` is specified as true or "fast", do a fast factory reset.
"""
if parsers.str2bool(fast, also_true=["fast"]):
fast = " fast"
else:
fast = ""
ret = cmdr.send_prompt_command("factory reset%s" % fast)
if ret:
raise exceptions.PromptResponseError(ret)

View file

@ -0,0 +1,38 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import PebbleCommander, exceptions, parsers
@PebbleCommander.command()
def version(cmdr):
""" Get version information.
"""
return cmdr.send_prompt_command("version")
@PebbleCommander.command()
def boot_bit_set(cmdr, bit, value):
""" Set some boot bits.
`bit` should be between 0 and 31.
`value` should be a boolean.
"""
bit = int(str(bit), 0)
value = int(parsers.str2bool(value))
if not 0 <= bit <= 31:
raise exceptions.ParameterError('bit index out of range: %d' % bit)
ret = cmdr.send_prompt_command("boot bit set %d %d" % (bit, value))
if not ret[0].startswith("OK"):
raise exceptions.PromptResponseError(ret)

View file

@ -0,0 +1,39 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import PebbleCommander, exceptions, parsers
@PebbleCommander.command()
def set_time(cmdr, new_time):
""" Set the time.
`new_time` should be in epoch seconds.
"""
new_time = int(str(new_time), 0)
if new_time < 1262304000:
raise exceptions.ParameterError("time must be later than 2010-01-01")
ret = cmdr.send_prompt_command("set time %s" % new_time)
if not ret[0].startswith("Time is now"):
raise exceptions.PromptResponseError(ret)
return ret
@PebbleCommander.command()
def timezone_clear(cmdr):
""" Clear timezone settings.
"""
ret = cmdr.send_prompt_command("timezone clear")
if ret:
raise exceptions.PromptResponseError(ret)

View file

@ -0,0 +1,29 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import PebbleCommander, exceptions, parsers
@PebbleCommander.command()
def window_stack(cmdr):
""" Dump the window stack.
"""
return cmdr.send_prompt_command("window stack")
@PebbleCommander.command()
def modal_stack(cmdr):
""" Dump the modal stack.
"""
return cmdr.send_prompt_command("modal stack")

View file

@ -0,0 +1,2 @@
Application-layer PULSEv2 Protocols
===================================

View file

@ -0,0 +1,19 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Public aliases for the classes that users will interact with directly.
from .bulkio import BulkIO
from .flash_imaging import FlashImaging
from .prompt import Prompt
from .streaming_logs import StreamingLogs

View file

@ -0,0 +1,451 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import collections
import logging
import struct
from ..exceptions import PebbleCommanderError
class ResponseParseError(PebbleCommanderError):
pass
class EraseError(PebbleCommanderError):
pass
class OpenCommand(object):
command_type = 1
command_struct = struct.Struct('<BB')
def __init__(self, domain, extra=None):
self.domain = domain
self.extra = extra
@property
def packet(self):
cmd = self.command_struct.pack(self.command_type, self.domain)
if self.extra:
cmd += self.extra
return cmd
class CloseCommand(object):
command_type = 2
command_struct = struct.Struct('<BB')
def __init__(self, fd):
self.fd = fd
@property
def packet(self):
return self.command_struct.pack(self.command_type, self.fd)
class ReadCommand(object):
command_type = 3
command_struct = struct.Struct('<BBII')
def __init__(self, fd, address, length):
self.fd = fd
self.address = address
self.length = length
@property
def packet(self):
return self.command_struct.pack(self.command_type, self.fd,
self.address, self.length)
class WriteCommand(object):
command_type = 4
command_struct = struct.Struct('<BBI')
header_size = command_struct.size
def __init__(self, fd, address, data):
self.fd = fd
self.address = address
self.data = data
@property
def packet(self):
return self.command_struct.pack(self.command_type, self.fd,
self.address) + self.data
class CRCCommand(object):
command_type = 5
command_struct = struct.Struct('<BBII')
def __init__(self, fd, address, length):
self.fd = fd
self.address = address
self.length = length
@property
def packet(self):
return self.command_struct.pack(self.command_type, self.fd,
self.address, self.length)
class StatCommand(object):
command_type = 6
command_struct = struct.Struct('<BB')
def __init__(self, fd):
self.fd = fd
@property
def packet(self):
return self.command_struct.pack(self.command_type, self.fd)
class EraseCommand(object):
command_type = 7
command_struct = struct.Struct('<BB')
def __init__(self, domain, extra=None):
self.domain = domain
self.extra = extra
@property
def packet(self):
cmd = self.command_struct.pack(self.command_type, self.domain)
if self.extra:
cmd += self.extra
return cmd
class OpenResponse(object):
response_type = 128
response_format = '<xB'
response_struct = struct.Struct(response_format)
header_size = response_struct.size
Response = collections.namedtuple('OpenResponse', 'fd')
@classmethod
def parse(cls, response):
response_type = ord(response[0])
if response_type != cls.response_type:
raise ResponseParseError('Unexpected response type: %r' % response_type)
return cls.Response._make(cls.response_struct.unpack(response))
class CloseResponse(object):
response_type = 129
response_format = '<xB'
response_struct = struct.Struct(response_format)
header_size = response_struct.size
Response = collections.namedtuple('CloseResponse', 'fd')
@classmethod
def parse(cls, response):
response_type = ord(response[0])
if response_type != cls.response_type:
raise ResponseParseError('Unexpected response type: %r' % response_type)
return cls.Response._make(cls.response_struct.unpack(response))
class ReadResponse(object):
response_type = 130
response_format = '<xBI'
response_struct = struct.Struct(response_format)
header_size = response_struct.size
Response = collections.namedtuple('ReadResponse', 'fd address data')
@classmethod
def parse(cls, response):
if ord(response[0]) != cls.response_type:
raise ResponseParseError('Unexpected response: %r' % response)
header = response[:cls.header_size]
body = response[cls.header_size:]
fd, address, = cls.response_struct.unpack(header)
return cls.Response(fd, address, body)
class WriteResponse(object):
response_type = 131
response_format = '<xBII'
response_struct = struct.Struct(response_format)
header_size = response_struct.size
Response = collections.namedtuple('WriteResponse', 'fd address length')
@classmethod
def parse(cls, response):
response_type = ord(response[0])
if response_type != cls.response_type:
raise ResponseParseError('Unexpected response type: %r' % response_type)
return cls.Response._make(cls.response_struct.unpack(response))
class CRCResponse(object):
response_type = 132
response_format = '<xBIII'
response_struct = struct.Struct(response_format)
header_size = response_struct.size
Response = collections.namedtuple('CRCResponse', 'fd address length crc')
@classmethod
def parse(cls, response):
response_type = ord(response[0])
if response_type != cls.response_type:
raise ResponseParseError('Unexpected response type: %r' % response_type)
return cls.Response._make(cls.response_struct.unpack(response))
class StatResponse(object):
response_type = 133
def __init__(self, name, format, fields):
self.name = name
self.struct = struct.Struct('<xBB' + format)
self.tuple = collections.namedtuple(name, 'fd flags ' + fields)
def parse(self, response):
response_type = ord(response[0])
if response_type != self.response_type:
raise ResponseParseError('Unexpected response type: %r' % response_type)
return self.tuple._make(self.struct.unpack(response))
def __repr__(self):
return 'StatResponse({self.name!r}, {self.struct!r}, {self.tuple!r})'.format(self=self)
class EraseResponse(object):
response_type = 134
response_format = '<xBb'
response_struct = struct.Struct(response_format)
header_size = response_struct.size
Response = collections.namedtuple('EraseResponse', 'domain status')
@classmethod
def parse(cls, response):
response_type = ord(response[0])
if response_type != cls.response_type:
raise ResponseParseError('Unexpected response type: %r' % response_type)
return cls.Response._make(cls.response_struct.unpack(response))
def enum(**enums):
return type('Enum', (), enums)
ReadDomains = enum(
MEMORY=1,
EXTERNAL_FLASH=2,
FRAMEBUFFER=3,
COREDUMP=4,
PFS=5
)
class PULSEIO_Base(object):
ERASE_FORMAT = None
STAT_FORMAT = None
DOMAIN = None
def __init__(self, socket, *args, **kwargs):
self.socket = socket
self.pos = 0
options = self._process_args(*args, **kwargs)
resp = self._send_and_receive(OpenCommand, OpenResponse, self.DOMAIN, options)
self.fd = resp.fd
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()
@staticmethod
def _process_args(*args, **kwargs):
return ""
def _send_and_receive(self, cmd_type, resp_type, *args):
cmd = cmd_type(*args)
self.socket.send(cmd.packet)
ret = self.socket.receive(block=True)
return resp_type.parse(ret)
def close(self):
if self.fd is not None:
resp = self._send_and_receive(CloseCommand, CloseResponse, self.fd)
assert resp.fd == self.fd
self.fd = None
def seek_absolute(self, pos):
if pos < 0:
raise ValueError('Cannot seek to before start of file')
self.pos = pos
def seek_relative(self, num_bytes):
if (self.pos + num_bytes) < 0:
raise ValueError('Cannot seek to before start of file')
self.pos += num_bytes
@classmethod
def erase(cls, socket, *args):
if cls.ERASE_FORMAT == "raw":
options = "".join(args)
elif cls.ERASE_FORMAT:
options = struct.pack("<" + cls.ERASE_FORMAT, *args)
else:
raise NotImplementedError("Erase is not supported for domain %d" % cls.DOMAIN)
cmd = EraseCommand(cls.DOMAIN, options)
socket.send(cmd.packet)
status = 1
while status > 0:
ret = socket.receive(block=True)
resp = EraseResponse.parse(ret)
logging.debug("ERASE: domain %d status %d", resp.domain, resp.status)
status = resp.status
if status < 0:
raise EraseError(status)
def write(self, data):
if self.fd is None:
raise ValueError('Handle is not open')
mss = self.socket.mtu - WriteCommand.header_size
for offset in xrange(0, len(data), mss):
segment = data[offset:offset+mss]
resp = self._send_and_receive(WriteCommand, WriteResponse, self.fd, self.pos, segment)
assert resp.fd == self.fd
assert resp.address == self.pos
self.pos += len(segment)
def read(self, length):
if self.fd is None:
raise ValueError('Handle is not open')
cmd = ReadCommand(self.fd, self.pos, length)
self.socket.send(cmd.packet)
data = bytearray()
bytes_left = length
while bytes_left > 0:
packet = self.socket.receive(block=True)
fd, chunk_offset, chunk = ReadResponse.parse(packet)
assert fd == self.fd
data.extend(chunk)
bytes_left -= len(chunk)
return data
def crc(self, length):
if self.fd is None:
raise ValueError('Handle is not open')
resp = self._send_and_receive(CRCCommand, CRCResponse, self.fd, self.pos, length)
return resp.crc
def stat(self):
if self.fd is None:
raise ValueError('Handle is not open')
if not self.STAT_FORMAT:
raise NotImplementedError("Stat is not supported for domain %d" % self.DOMAIN)
return self._send_and_receive(StatCommand, self.STAT_FORMAT, self.fd)
class PULSEIO_Memory(PULSEIO_Base):
DOMAIN = ReadDomains.MEMORY
# uint32 for address, uint32 for length
ERASE_FORMAT = "II"
class PULSEIO_ExternalFlash(PULSEIO_Base):
DOMAIN = ReadDomains.EXTERNAL_FLASH
# uint32 for address, uint32 for length
ERASE_FORMAT = "II"
class PULSEIO_Framebuffer(PULSEIO_Base):
DOMAIN = ReadDomains.FRAMEBUFFER
STAT_FORMAT = StatResponse('FramebufferAttributes', 'BBBI', 'width height bpp length')
class PULSEIO_Coredump(PULSEIO_Base):
DOMAIN = ReadDomains.COREDUMP
STAT_FORMAT = StatResponse('CoredumpAttributes', 'BI', 'unread length')
ERASE_FORMAT = "I"
@staticmethod
def _process_args(slot):
return struct.pack("<I", slot)
class PULSEIO_PFS(PULSEIO_Base):
DOMAIN = ReadDomains.PFS
STAT_FORMAT = StatResponse('PFSFileAttributes', 'I', 'length')
ERASE_FORMAT = "raw"
OP_FLAG_READ = 1 << 0
OP_FLAG_WRITE = 1 << 1
OP_FLAG_OVERWRITE = 1 << 2
OP_FLAG_SKIP_HDR_CRC_CHECK = 1 << 3
OP_FLAG_USE_PAGE_CACHE = 1 << 4
@staticmethod
def _process_args(filename, mode='r', flags=0xFE, initial_size=0):
mode_num = PULSEIO_PFS.OP_FLAG_READ
if 'w' in mode:
mode_num |= PULSEIO_PFS.OP_FLAG_WRITE
return struct.pack("<BBI", mode_num, flags, initial_size) + filename
class BulkIO(object):
PROTOCOL_NUMBER = 0x3e21
DOMAIN_MAP = {
'pfs': PULSEIO_PFS,
'framebuffer': PULSEIO_Framebuffer
}
def __init__(self, link):
self.socket = link.open_socket('reliable', self.PROTOCOL_NUMBER)
def open(self, domain, *args, **kwargs):
return self.DOMAIN_MAP[domain](self.socket, *args, **kwargs)
def erase(self, domain, *args, **kwargs):
return self.DOMAIN_MAP[domain].erase(self.socket, *args, **kwargs)
def close(self):
self.socket.close()
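# Illustrative sketch (not part of the original module): dumping the framebuffer
# over an already-open pebble.pulse2.Interface (here called `interface`, an
# assumption).
#
#   bulkio = BulkIO(interface.get_link())
#   with bulkio.open('framebuffer') as fb:
#       attrs = fb.stat()           # FramebufferAttributes: width, height, bpp, length
#       raw = fb.read(attrs.length)
#   bulkio.close()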

View file

@ -0,0 +1,318 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import collections
import struct
import time
import pebble.pulse2.exceptions
from .. import exceptions
class EraseCommand(object):
command_type = 1
command_struct = struct.Struct('<BII')
response_type = 128
response_struct = struct.Struct('<xII?')
Response = collections.namedtuple(
'EraseResponse', 'address length complete')
def __init__(self, address, length):
self.address = address
self.length = length
@property
def packet(self):
return self.command_struct.pack(
self.command_type, self.address, self.length)
def parse_response(self, response):
if ord(response[0]) != self.response_type:
raise exceptions.ResponseParseError(
'Unexpected response: %r' % response)
unpacked = self.Response._make(self.response_struct.unpack(response))
if unpacked.address != self.address or unpacked.length != self.length:
raise exceptions.ResponseParseError(
'Response does not match command: '
'address=%#.08x length=%d (expected %#.08x, %d)' % (
unpacked.address, unpacked.length, self.address,
self.length))
return unpacked
class WriteCommand(object):
command_type = 2
command_struct = struct.Struct('<BI')
header_len = command_struct.size
def __init__(self, address, data):
self.address = address
self.data = data
@property
def packet(self):
header = self.command_struct.pack(self.command_type, self.address)
return header + self.data
class WriteResponse(object):
response_type = 129
response_struct = struct.Struct('<xII?')
Response = collections.namedtuple(
'WriteResponse', 'address length complete')
@classmethod
def parse(cls, response):
if ord(response[0]) != cls.response_type:
raise exceptions.ResponseParseError(
'Unexpected response: %r' % response)
return cls.Response._make(cls.response_struct.unpack(response))
class CrcCommand(object):
command_type = 3
command_struct = struct.Struct('<BII')
response_type = 130
response_struct = struct.Struct('<xIII')
Response = collections.namedtuple('CrcResponse', 'address length crc')
def __init__(self, address, length):
self.address = address
self.length = length
@property
def packet(self):
return self.command_struct.pack(self.command_type, self.address,
self.length)
def parse_response(self, response):
if ord(response[0]) != self.response_type:
raise exceptions.ResponseParseError(
'Unexpected response: %r' % response)
unpacked = self.Response._make(self.response_struct.unpack(response))
if unpacked.address != self.address or unpacked.length != self.length:
raise exceptions.ResponseParseError(
'Response does not match command: '
'address=%#.08x length=%d (expected %#.08x, %d)' % (
unpacked.address, unpacked.length, self.address,
self.length))
return unpacked
class QueryFlashRegionCommand(object):
command_type = 4
command_struct = struct.Struct('<BB')
REGION_PRF = 1
REGION_SYSTEM_RESOURCES = 2
response_type = 131
response_struct = struct.Struct('<xBII')
Response = collections.namedtuple(
'FlashRegionGeometry', 'region address length')
def __init__(self, region):
self.region = region
@property
def packet(self):
return self.command_struct.pack(self.command_type, self.region)
def parse_response(self, response):
if ord(response[0]) != self.response_type:
raise exceptions.ResponseParseError(
'Unexpected response: %r' % response)
unpacked = self.Response._make(self.response_struct.unpack(response))
if unpacked.address == 0 and unpacked.length == 0:
raise exceptions.RegionDoesNotExist(self.region)
return unpacked
class FinalizeFlashRegionCommand(object):
command_type = 5
command_struct = struct.Struct('<BB')
response_type = 132
response_struct = struct.Struct('<xB')
def __init__(self, region):
self.region = region
@property
def packet(self):
return self.command_struct.pack(self.command_type, self.region)
def parse_response(self, response):
if ord(response[0]) != self.response_type:
raise exceptions.ResponseParseError(
'Unexpected response: %r' % response)
region, = self.response_struct.unpack(response)
if region != self.region:
raise exceptions.ResponseParseError(
'Response does not match command: '
'response is for region %d (expected %d)' % (
region, self.region))
class FlashImaging(object):
PORT_NUMBER = 0x0002
RESP_BAD_CMD = 192
RESP_INTERNAL_ERROR = 193
REGION_PRF = QueryFlashRegionCommand.REGION_PRF
REGION_SYSTEM_RESOURCES = QueryFlashRegionCommand.REGION_SYSTEM_RESOURCES
def __init__(self, link):
self.socket = link.open_socket('best-effort', self.PORT_NUMBER)
def close(self):
self.socket.close()
def erase(self, address, length):
cmd = EraseCommand(address, length)
ack_received = False
retries = 0
while retries < 10:
if not ack_received:
self.socket.send(cmd.packet)
try:
packet = self.socket.receive(timeout=5 if ack_received else 1.5)
response = cmd.parse_response(packet)
ack_received = True
if response.complete:
return
except pebble.pulse2.exceptions.ReceiveQueueEmpty:
ack_received = False
retries += 1
continue
raise exceptions.CommandTimedOut
def write(self, address, data, max_retries=5, max_in_flight=5,
progress_cb=None):
mtu = self.socket.mtu - WriteCommand.header_len
assert(mtu > 0)
unsent = collections.deque()
for offset in xrange(0, len(data), mtu):
segment = data[offset:offset+mtu]
assert(len(segment))
seg_address = address + offset
unsent.appendleft(
(seg_address, WriteCommand(seg_address, segment), 0))
in_flight = collections.OrderedDict()
retries = 0
while unsent or in_flight:
try:
while True:
# Process ACKs (if any)
ack = WriteResponse.parse(
self.socket.receive(block=False))
try:
cmd, _, _ = in_flight[ack.address]
del in_flight[ack.address]
except KeyError:
for seg_address, cmd, retry_count in unsent:
if seg_address == ack.address:
if retry_count == 0:
# ACK for a segment we never sent?!
raise exceptions.WriteError(
'Received ACK for an unsent segment: '
'%#.08x' % ack.address)
# Got an ACK for a sent (but timed out) segment
unsent.remove((seg_address, cmd, retry_count))
break
else:
raise exceptions.WriteError(
'Received ACK for an unknown segment: '
'%#.08x' % ack.address)
if len(cmd.data) != ack.length:
raise exceptions.WriteError(
'ACK length %d != data length %d' % (
ack.length, len(cmd.data)))
assert(ack.complete)
if progress_cb:
progress_cb(True)
except pebble.pulse2.exceptions.ReceiveQueueEmpty:
pass
# Retry any in_flight writes where the ACK has timed out
to_retry = []
timeout_time = time.time() - 0.5
for (seg_address,
(cmd, send_time, retry_count)) in in_flight.iteritems():
if send_time > timeout_time:
# in_flight is an OrderedDict so iteration is in
# chronological order.
break
if retry_count >= max_retries:
raise exceptions.WriteError(
'Segment %#.08x exceeded the max retry count (%d)' % (
seg_address, max_retries))
# Enqueue the packet again to resend later.
del in_flight[seg_address]
unsent.appendleft((seg_address, cmd, retry_count+1))
retries += 1
if progress_cb:
progress_cb(False)
# Send out fresh segments
try:
while len(in_flight) < max_in_flight:
seg_address, cmd, retry_count = unsent.pop()
self.socket.send(cmd.packet)
in_flight[cmd.address] = (cmd, time.time(), retry_count)
except IndexError:
pass
# Give other threads a chance to run
time.sleep(0)
return retries
def _command_and_response(self, cmd, timeout=0.5):
for attempt in xrange(5):
self.socket.send(cmd.packet)
try:
packet = self.socket.receive(timeout=timeout)
return cmd.parse_response(packet)
except pebble.pulse2.exceptions.ReceiveQueueEmpty:
pass
raise exceptions.CommandTimedOut
def crc(self, address, length):
cmd = CrcCommand(address, length)
return self._command_and_response(cmd, timeout=1).crc
def query_region_geometry(self, region):
cmd = QueryFlashRegionCommand(region)
return self._command_and_response(cmd)
def finalize_region(self, region):
cmd = FinalizeFlashRegionCommand(region)
return self._command_and_response(cmd)
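# Illustrative sketch (not part of the original module): imaging the PRF region by
# hand. `interface` (an open pebble.pulse2.Interface) and `image` (the firmware
# blob to write) are assumptions; the commander's imaging commands wrap this flow.
#
#   flash = FlashImaging(interface.get_link())
#   _, address, _ = flash.query_region_geometry(flash.REGION_PRF)
#   flash.erase(address, len(image))
#   retries = flash.write(address, image)
#   assert flash.crc(address, len(image)) == stm32_crc.crc32(image)  # stm32_crc from ..util
#   flash.finalize_region(flash.REGION_PRF)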

View file

@ -0,0 +1,78 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import collections
import struct
from datetime import datetime
import pebble.pulse2.exceptions
from .. import exceptions
class Prompt(object):
PORT_NUMBER = 0x3e20
def __init__(self, link):
self.socket = link.open_socket('reliable', self.PORT_NUMBER)
def command_and_response(self, command_string, timeout=20):
log = []
self.socket.send(bytes(command_string))
is_done = False
while not is_done:
try:
response = PromptResponse.parse(
self.socket.receive(timeout=timeout))
if response.is_done_response:
is_done = True
elif response.is_message_response:
log.append(response.message)
except pebble.pulse2.exceptions.ReceiveQueueEmpty:
raise exceptions.CommandTimedOut
return log
def close(self):
self.socket.close()
class PromptResponse(collections.namedtuple('PromptResponse',
'response_type timestamp message')):
DONE_RESPONSE = 101
MESSAGE_RESPONSE = 102
response_struct = struct.Struct('<BQ')
@property
def is_done_response(self):
return self.response_type == self.DONE_RESPONSE
@property
def is_message_response(self):
return self.response_type == self.MESSAGE_RESPONSE
@classmethod
def parse(cls, response):
result = cls.response_struct.unpack(response[:cls.response_struct.size])
response_type = result[0]
timestamp = datetime.fromtimestamp(result[1] / 1000.0)
message = response[cls.response_struct.size:]
return cls(response_type, timestamp, message)
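# Illustrative sketch (not part of the original module): running a prompt command
# and printing its log output. `interface` is assumed to be an open
# pebble.pulse2.Interface.
#
#   prompt = Prompt(interface.get_link())
#   for line in prompt.command_and_response('version'):
#       print line
#   prompt.close()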

View file

@ -0,0 +1,67 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import collections
import struct
from datetime import datetime
class LogMessage(collections.namedtuple('LogMessage',
'log_level task timestamp file_name line_number message')):
__slots__ = ()
response_struct = struct.Struct('<c16sccQH')
def __str__(self):
msec_timestamp = self.timestamp.strftime("%H:%M:%S.%f")[:-3]
template = ('{self.log_level} {self.task} {msec_timestamp} '
'{self.file_name}:{self.line_number}> {self.message}')
return template.format(self=self, msec_timestamp=msec_timestamp)
@classmethod
def parse(cls, packet):
result = cls.response_struct.unpack(packet[:cls.response_struct.size])
msg = packet[cls.response_struct.size:]
log_level = result[2]
task = result[3]
timestamp = datetime.fromtimestamp(result[4] / 1000.0)
file_name = result[1].split('\x00', 1)[0] # NUL terminated
line_number = result[5]
return cls(log_level, task, timestamp, file_name, line_number, msg)
class StreamingLogs(object):
'''App for receiving log messages streamed by the firmware.
'''
PORT_NUMBER = 0x0003
def __init__(self, interface):
try:
self.socket = interface.simplex_transport.open_socket(
self.PORT_NUMBER)
except AttributeError:
raise TypeError('StreamingLogs must be bound directly '
'to an Interface, not a Link')
def receive(self, block=True, timeout=None):
return LogMessage.parse(self.socket.receive(block, timeout))
def close(self):
self.socket.close()
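# Illustrative sketch (not part of the original module): tailing firmware logs.
# `interface` is assumed to be an open pebble.pulse2.Interface.
#
#   logs = StreamingLogs(interface)
#   try:
#       while True:
#           print str(logs.receive())
#   finally:
#       logs.close()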

View file

@ -0,0 +1,167 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import re
import threading
import tokenize
import types
from pebble import pulse2
from . import apps
class Pulse2ConnectionAdapter(object):
'''An adapter for the pulse2 API to look enough like pulse.Connection
to make PebbleCommander work...ish.
Prompt will break spectacularly if the firmware reboots or the link
state otherwise changes. Commander itself needs to be modified to be
link-state aware.
'''
def __init__(self, interface):
self.interface = interface
self.logging = apps.StreamingLogs(interface)
link = interface.get_link()
self.prompt = apps.Prompt(link)
self.flash = apps.FlashImaging(link)
def close(self):
self.interface.close()
class PebbleCommander(object):
""" Pebble Commander.
Implements everything for interfacing with PULSE things.
"""
def __init__(self, tty=None, interactive=False, capfile=None):
if capfile is not None:
interface = pulse2.Interface.open_dbgserial(
url=tty, capture_stream=open(capfile, 'wb'))
else:
interface = pulse2.Interface.open_dbgserial(url=tty)
try:
self.connection = Pulse2ConnectionAdapter(interface)
except:
interface.close()
raise
self.interactive = interactive
self.log_listeners_lock = threading.Lock()
self.log_listeners = []
# Start the logging thread
self.log_thread = threading.Thread(target=self._start_logging)
self.log_thread.daemon = True
self.log_thread.start()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
@classmethod
def command(cls, name=None):
""" Registers a command.
`name` is the command name. If `name` is unspecified, name will be the function name
with underscores converted to hyphens.
The convention for `name` is to separate words with a hyphen. The function name
will be the same as `name` with hyphens replaced with underscores.
Example: `click-short` will result in a PebbleCommander.click_short function existing.
`fn` should return an array of strings (or None), and take the current
`PebbleCommander` as the first argument, and the rest of the argument strings
as subsequent arguments. For errors, `fn` should throw an exception.
# TODO: Probably make the return something structured instead of stringly typed.
"""
def decorator(fn):
# Story time:
# <cory> Things are fine as long as you only read from `name`, but assigning to `name`
# creates a new local which shadows the outer scope's variable, even though it's
# only assigned later on in the block
# <cory> You could work around this by doing something like `name_ = name` and using
# `name_` in the `decorator` scope
cmdname = name
if not cmdname:
cmdname = fn.__name__.replace('_', '-')
funcname = cmdname.replace('-', '_')
if not re.match(tokenize.Name + '$', funcname):
raise ValueError("command name %s isn't a valid name" % funcname)
if hasattr(cls, funcname):
raise ValueError('function name %s clashes with existing attribute' % funcname)
fn.is_command = True
fn.name = cmdname
method = types.MethodType(fn, None, cls)
setattr(cls, funcname, method)
return fn
return decorator
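# Illustrative sketch (not part of the original class): a hypothetical host command
# registered through the decorator above.
#
#   @PebbleCommander.command()
#   def boot_bit_get(cmdr, bit):
#       """ Read back a boot bit. """
#       return cmdr.send_prompt_command("boot bit get %d" % int(str(bit), 0))
#
# This would make `boot-bit-get` usable interactively and add a
# PebbleCommander.boot_bit_get method; "boot bit get" itself is a made-up prompt
# command used only for illustration.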
def close(self):
self.connection.close()
def _start_logging(self):
""" Thread to handle logging messages.
"""
while True:
try:
msg = self.connection.logging.receive()
except pulse2.exceptions.SocketClosed:
break
with self.log_listeners_lock:
# TODO: Buffer log messages if no listeners attached?
for listener in self.log_listeners:
try:
listener(msg)
except:
pass
def attach_log_listener(self, listener):
""" Attaches a listener for log messages.
The listener is called with each incoming log message; its return value is ignored.
"""
with self.log_listeners_lock:
self.log_listeners.append(listener)
def detach_log_listener(self, listener):
""" Removes a listener that was added with `attach_log_listener`
"""
with self.log_listeners_lock:
self.log_listeners.remove(listener)
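# Illustrative sketch (not part of the original class): a minimal log listener.
#
#   def print_log(msg):
#       print str(msg)
#   cmdr.attach_log_listener(print_log)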
def send_prompt_command(self, cmd):
""" Send a prompt command string.
Unfortunately this is stringly typed; a better solution is needed.
"""
return self.connection.prompt.command_and_response(cmd)
def get_command(self, command):
try:
fn = getattr(self, command.replace('-', '_'))
if fn.is_command:
return fn
except AttributeError:
# Method doesn't exist, or isn't a command.
pass
return None

View file

@ -0,0 +1,40 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class PebbleCommanderError(Exception):
pass
class ParameterError(PebbleCommanderError):
pass
class PromptResponseError(PebbleCommanderError):
pass
class ResponseParseError(PebbleCommanderError):
pass
class RegionDoesNotExist(PebbleCommanderError):
pass
class CommandTimedOut(PebbleCommanderError):
pass
class WriteError(PebbleCommanderError):
pass

View file

@ -0,0 +1,137 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import shlex
import traceback
from log_hashing.logdehash import LogDehash
import prompt_toolkit
from .commander import PebbleCommander
class InteractivePebbleCommander(object):
""" Interactive Pebble Commander.
Most/all UI implementations should either use this directly or sub-class it.
"""
def __init__(self, loghash_path=None, tty=None, capfile=None):
self.cmdr = PebbleCommander(tty=tty, interactive=True, capfile=capfile)
if loghash_path is None:
loghash_path = "build/src/fw/loghash_dict.json"
self.dehasher = LogDehash(loghash_path)
self.cmdr.attach_log_listener(self.log_listener)
# The prompt_toolkit CLI is created lazily by attach_prompt_toolkit().
self.cli = None
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()
def __del__(self):
self.close()
def close(self):
try:
self.cmdr.close()
except:
pass
def attach_prompt_toolkit(self):
""" Attaches prompt_toolkit things
"""
self.history = prompt_toolkit.history.InMemoryHistory()
self.cli = prompt_toolkit.CommandLineInterface(
application=prompt_toolkit.shortcuts.create_prompt_application(u"> ",
history=self.history),
eventloop=prompt_toolkit.shortcuts.create_eventloop())
self.patch_context = self.cli.patch_stdout_context(raw=True)
self.patch_context.__enter__()
def log_listener(self, msg):
""" This is called on every incoming log message.
`msg` is the raw log message class, without any dehashing.
Subclasses will probably want to override this.
"""
line_dict = self.dehasher.dehash(msg)
line = self.dehasher.commander_format_line(line_dict)
print line
def dispatch_command(self, string):
""" Dispatches a command string.
Subclasses should not override this.
"""
args = shlex.split(string)
# Starting with '!' passes the rest of the line directly to prompt.
# Otherwise we try to run a command; if that fails, the line goes to prompt.
if string.startswith("!"):
string = string[1:] # Chop off the '!' marker
else:
cmd = self.cmdr.get_command(args[0])
if cmd: # If we provide the command, run it.
return cmd(*args[1:])
return self.cmdr.send_prompt_command(string)
def input_handle(self, string):
""" Handles an input line.
Generally the flow is to handle any UI-specific commands, then pass on to
dispatch_command.
Subclasses will probably want to override this.
"""
# Handle "quit" strings
if string in ["exit", "q", "quit"]:
return False
try:
resp = self.dispatch_command(string)
if resp is not None:
print "\x1b[1m" + '\n'.join(resp) + "\x1b[m"
except:
print "An error occurred!"
traceback.print_exc()
return True
def get_command(self):
""" Get a command input line.
If there is no line, return an empty string or None.
This may block.
Subclasses will probably want to override this.
"""
if self.cli is None:
self.attach_prompt_toolkit()
doc = self.cli.run(reset_current_buffer=True)
if doc:
return doc.text
else:
return None
def command_loop(self):
""" The main command loop.
Subclasses could override this, but it's probably not useful to do.
"""
while True:
try:
cmd = self.get_command()
if cmd and not self.input_handle(cmd):
break
except (KeyboardInterrupt, EOFError):
break
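# Illustrative sketch (not part of the original module): the typical entry point a
# CLI wrapper would use. The tty path is an assumption.
#
#   with InteractivePebbleCommander(tty='/dev/ttyUSB0') as cmdr:
#       cmdr.command_loop()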

View file

@ -0,0 +1,38 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import re
from . import exceptions
def str2bool(s, also_true=[], also_false=[]):
s = str(s).lower()
if s in ("yes", "on", "t", "1", "true", "enable") or s in also_true:
return True
if s in ("no", "off", "f", "0", "false", "disable") or s in also_false:
return False
raise exceptions.ParameterError("%s not a valid bool string." % s)
def str2mac(s):
s = str(s)
if not re.match(r'[0-9a-fA-F]{2}(:[0-9a-fA-F]{2}){5}', s):
raise exceptions.ParameterError('%s is not a valid MAC address' % s)
mac = []
for byte in str(s).split(':'):
mac.append(int(byte, 16))
return tuple(mac)
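# Doctest-style examples for the helpers above:
#
#   >>> str2bool("on")
#   True
#   >>> str2mac("11:22:33:44:55:66")
#   (17, 34, 51, 68, 85, 102)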

View file

@ -0,0 +1,14 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,65 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
CRC_POLY = 0x04C11DB7
def precompute_table(bits):
lookup_table = []
for i in xrange(2**bits):
rr = i << (32 - bits)
for x in xrange(bits):
if rr & 0x80000000:
rr = (rr << 1) ^ CRC_POLY
else:
rr <<= 1
lookup_table.append(rr & 0xffffffff)
return lookup_table
lookup_table = precompute_table(8)
def process_word(data, crc=0xffffffff):
if (len(data) < 4):
# The CRC data is "padded" in a very unique and confusing fashion.
data = data[::-1] + '\0' * (4 - len(data))
for char in reversed(data):
b = ord(char)
crc = ((crc << 8) ^ lookup_table[(crc >> 24) ^ b]) & 0xffffffff
return crc
def process_buffer(buf, c=0xffffffff):
word_count = (len(buf) + 3) / 4
crc = c
for i in xrange(word_count):
crc = process_word(buf[i * 4 : (i + 1) * 4], crc)
return crc
def crc32(data):
return process_buffer(data)
if __name__ == '__main__':
import sys
assert(0x89f3bab2 == process_buffer("123 567 901 34"))
assert(0xaff19057 == process_buffer("123456789"))
assert(0x519b130 == process_buffer("\xfe\xff\xfe\xff"))
assert(0x495e02ca == process_buffer("\xfe\xff\xfe\xff\x88"))
print "All tests passed!"
if len(sys.argv) >= 2:
b = open(sys.argv[1]).read()
crc = crc32(b)
print "%u or 0x%x" % (crc, crc)

View file

@ -0,0 +1,56 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Always prefer setuptools over distutils
from setuptools import setup, find_packages
# To use a consistent encoding
from codecs import open
from os import path
import sys
here = path.abspath(path.dirname(__file__))
# Get the long description from the README file
with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
long_description = f.read()
setup(
name='pebble.commander',
version='0.0.11',
description='Pebble Commander',
long_description=long_description,
url='https://github.com/pebble/pebble-commander',
author='Pebble Technology Corporation',
author_email='cory@pebble.com',
packages=find_packages(exclude=['contrib', 'docs', 'tests']),
namespace_packages = ['pebble'],
install_requires=[
'pebble.pulse2>=0.0.7,<1',
],
extras_require = {
'Interactive': [
'pebble.loghash>=2.6',
'prompt_toolkit>=0.55',
],
},
entry_points={
'console_scripts': [
'pebble-commander = pebble.commander.__main__:main [Interactive]',
],
},
)

7
python_libs/pebble-loghash/.gitignore vendored Normal file
View file

@ -0,0 +1,7 @@
*.pyc
.env
.DS_Store
dist
MANIFEST
__pycache__
.cache

View file

@ -0,0 +1,2 @@
# pyegg-pebble-loghash
A python egg for dealing with hashed TinTin logs

View file

@ -0,0 +1,18 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Namespace module
"""
__import__('pkg_resources').declare_namespace(__name__)

View file

@ -0,0 +1,14 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,78 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python
"""
Constants used in this module
"""
import re
# Hash Mask
HASH_MASK = 0x00FFFFFF
# Regular Expressions
LOG_LINE_CONSOLE_REGEX = r"^(?P<re_level>.)\s+(?P<task>.)\s+(?P<time>.*)\s+(?P<msg>.*:.*>\s+LH.*)$"
LOG_LINE_SUPPORT_REGEX = r"^(?P<date>.*)\s+(?P<time>.*)\s+(?P<msg>.*:.*\s+LH.*)$"
LOG_MSG_REGEX = r"^(?P<f>\w*\.?\w*):(?P<l>\d*)>?\s+(?:LH:)?(?P<h>(?:0x)?[a-f0-9]{1,8}\s?.*)$"
DEHASHED_MSG_REGEX = r"^(\w+\.?\w?):(\d+)?:?(.*)$"
HASHED_INFO_REGEX = r"^(?P<hash_key>(?:0x)?[a-f0-9]{1,8})\s?(?P<arg_list>.+)?$"
FORMAT_TAG_REGEX = r"%(\.\*)?#?[0-9]{0,3}[Xdilupcszxh]+"
STR_LITERAL_REGEX = r"^(.*?)(\".*\"\s*(?:(?:PRI[A-z](?:\d{1,2}|PTR))|B[DT]_.*_FMT)*)(.*)$"
FORMAT_SPECIFIER_REGEX = r"(%#?[0-9]{0,3}[Xdilupcszxh]+)"
# New Logging Regular Expressions
NEWLOG_LINE_CONSOLE_REGEX = LOG_LINE_CONSOLE_REGEX.replace('LH', 'NL')
NEWLOG_LINE_SUPPORT_REGEX = LOG_LINE_SUPPORT_REGEX.replace('LH', 'NL')
NEWLOG_HASHED_INFO_REGEX = r"^(?::0[>]? NL:)(?P<hash_key>(?:0x)?[a-f0-9]{1,8})\s?(?P<arg_list>.+)?$"
POINTER_FORMAT_TAG_REGEX = r"(?P<format>%-?[0-9]*)p"
HEX_FORMAT_SPECIFIER_REGEX = r"%[- +#0]*\d*(\.\d+)?(hh|h|l|ll|j|z|t|L)?(x|X)"
# re patterns
STR_LITERAL_PATTERN = re.compile(STR_LITERAL_REGEX)
FORMAT_SPECIFIER_PATTERN = re.compile(FORMAT_SPECIFIER_REGEX)
LOG_LINE_CONSOLE_PATTERN = re.compile(LOG_LINE_CONSOLE_REGEX)
LOG_LINE_SUPPORT_PATTERN = re.compile(LOG_LINE_SUPPORT_REGEX)
LOG_MSG_PATTERN = re.compile(LOG_MSG_REGEX)
DEHASHED_MSG_PATTERN = re.compile(DEHASHED_MSG_REGEX)
HASHED_INFO_PATTERN = re.compile(HASHED_INFO_REGEX)
FORMAT_TAG_PATTERN = re.compile(FORMAT_TAG_REGEX)
# New Logging Patterns
NEWLOG_LINE_CONSOLE_PATTERN = re.compile(NEWLOG_LINE_CONSOLE_REGEX)
NEWLOG_LINE_SUPPORT_PATTERN = re.compile(NEWLOG_LINE_SUPPORT_REGEX)
NEWLOG_HASHED_INFO_PATTERN = re.compile(NEWLOG_HASHED_INFO_REGEX)
POINTER_FORMAT_TAG_PATTERN = re.compile(POINTER_FORMAT_TAG_REGEX)
HEX_FORMAT_SPECIFIER_PATTERN = re.compile(HEX_FORMAT_SPECIFIER_REGEX)
# Output file lines
FORMAT_IDENTIFIER_STRING_FMT = "char *format_string_{} = \"{}\";\n"
LOOKUP_RESULT_STRING_FMT = "if (loghash == {}) fmt = format_string_{};\n"
LOOKUP_DEFAULT_STRING = "fmt = \"\";\n"
FILE_IGNORE_LIST = []
# Lines to hash
GENERIC_LOG_TYPES = ["PBL_LOG", "PBL_ASSERT", "PBL_CROAK"]
BT_LOG_TYPES = ["BLE_LOG_DEBUG", "BLE_GAP_LOG_DEBUG", "BLE_CORE_LOG_DEBUG",
"BT_LOG_ERROR", "BT_LOG_DEBUG", "HCI_LOG_ERROR", "GAP_LOG_ERROR",
"GAP_LOG_DEBUG", "GAP_LOG_WARNING", "HCI_LOG_DEBUG"]
QEMU_LOG_TYPES = ["QEMU_LOG_DEBUG", "QEMU_LOG_ERROR"]
MISC_LOG_TYPES = ["ACCEL_LOG_DEBUG", "ANIMATION_LOG_DEBUG", "VOICE_LOG",
"ISPP_LOG_DEBUG", "ISPP_LOG_DEBUG_VERBOSE",
"RECONNECT_IOS_DEBUG", "SDP_LOG_DEBUG", "SDP_LOG_ERROR",
"ANALYTICS_LOG_DEBUG"]
LINES_TO_HASH = GENERIC_LOG_TYPES + BT_LOG_TYPES + QEMU_LOG_TYPES + MISC_LOG_TYPES
# Key to force next line to be hashed
HASH_NEXT_LINE = "// HASH_NEXT_LINE"
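# For reference, a console line of the shape matched by LOG_LINE_CONSOLE_REGEX
# and HASHED_INFO_REGEX (example taken from this package's dehashing tests):
#
#   D A 21:35:14.375 :872> LH:b2eb 2 `Success`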

View file

@ -0,0 +1,206 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python
"""
Module for de-hashing log strings
"""
from pebble.loghashing.constants import (LOG_LINE_CONSOLE_PATTERN, LOG_LINE_SUPPORT_PATTERN,
LOG_MSG_PATTERN, DEHASHED_MSG_PATTERN, HASHED_INFO_PATTERN,
FORMAT_TAG_PATTERN)
from pebble.loghashing.newlogging import dehash_line as newlogging_dehash_line
from pebble.loghashing.newlogging import LOG_DICT_KEY_VERSION
def dehash_file(file_name, lookup_dict):
"""
Dehash a file
:param file_name: Path of the file to dehash
:type file_name: str
:param lookup_dict: Hash lookup dictionary
:type lookup_dict: dict
:returns: A list containing the dehashed lines
"""
# Grab the lines from the file
with open(file_name, 'r') as fp:
lines = fp.readlines()
# Dehash the lines
lines = [dehash_line(x, lookup_dict) + "\n" for x in lines]
return lines
def dehash_line(line, lookup_dict):
"""
Dehash a line
:param line: The line to dehash
:type line: str
:param lookup_dict: Hash lookup dictionary
:type lookup_dict: dict
If the lookup dictionary contains the 'new_logging_version' key, it's a newlogging style
print. Pass it off to the appropriate handler.
:returns: A string containing the dehashed line, or the submitted line.
"""
if LOG_DICT_KEY_VERSION in lookup_dict:
return newlogging_dehash_line(line, lookup_dict)
return parse_line(line, lookup_dict) or parse_support_line(line, lookup_dict) or line
def parse_line(line, lookup_dict):
"""
Parse a log line
:param line: The line to parse
:type line: str
:param lookup_dict: Hash lookup dictionary
:type lookup_dict: dict
:returns: A string containing the parsed line, or a null string.
"""
match = LOG_LINE_CONSOLE_PATTERN.search(line)
output = ""
if match:
parsed = parse_message(match.group('msg'), lookup_dict)
output = "{} {} {} {}:{}> {}".format(match.group('re_level'), match.group('task'),
match.group('time'), parsed['file'],
parsed['line'], parsed['msg'])
return output
def parse_support_line(line, lookup_dict):
"""
Parse a log line
:param line: The line to parse
:type line: str
:param lookup_dict: Hash lookup dictionary
:type lookup_dict: dict
:returns: A string containing the parsed line, or a null string.
"""
match = LOG_LINE_SUPPORT_PATTERN.search(line)
output = ""
if match:
parsed = parse_message(match.group('msg'), lookup_dict)
output = "{} {} {}:{}> {}".format(match.group('date'), match.group('time'),
parsed['file'], parsed['line'], parsed['msg'])
return output
def parse_message(msg, lookup_dict):
"""
Parse the log message part of a line
:param msg: The message to parse
:type msg: str
:param lookup_dict: Hash lookup dictionary
:type lookup_dict: dict
:returns: A dictionary containing the parsed message, file name, and line number
"""
output = {'msg':msg, 'file':"", 'line':""}
match = LOG_MSG_PATTERN.search(msg)
if match:
output['file'] = match.group('f')
output['line'] = match.group('l')
hashed = match.group('h')
dehashed_str = dehash_str(hashed, lookup_dict)
output['msg'] = "LH:{}".format(dehashed_str)
match2 = DEHASHED_MSG_PATTERN.search(dehashed_str)
if match2:
output['file'] = match2.group(1) or output['file']
output['line'] = match2.group(2) or output['line']
output['msg'] = match2.group(3) or dehashed_str
return output
def dehash_str(hashed_info, lookup_dict):
"""
Search the lookup dictionary for a match, and return the dehashed string
:param hashed_info: Hash and arguments
:type hashed_info: str
:param lookup_dict: Hash lookup dictionary
:type lookup_dict: dict
:returns: The string obtained by looking up the hash and substituting the arguments
"""
match = HASHED_INFO_PATTERN.search(hashed_info)
# If there's no match, return the hashed info as the log message
output = hashed_info
if match:
# Look for the hex value in the dictionary keys
# If we can't find a match, set formatted string to hashed_info
formatted_string = lookup_dict.get(str(match.group('hash_key')), hashed_info)
# If we couldn't find a match, try converting to base 10 to find a match
# If we can't find a match, set formatted string to hashed_info
if formatted_string == hashed_info:
formatted_string = lookup_dict.get(str(int(match.group('hash_key'), 16)), hashed_info)
# For each argument, substitute it for a C-style format specifier in the string
for arg in parse_args(match.group('arg_list')):
formatted_string = FORMAT_TAG_PATTERN.sub(arg, formatted_string, 1)
# Return the filename, and log message
output = formatted_string
return output
def parse_args(raw_args):
"""
Split the argument list, taking care of `delimited strings`
Idea taken from http://bit.ly/1KHzc0y
:param raw_args: Raw argument list
:type raw_args: str
:returns: A list containing the arguments
"""
args = []
arg_run = []
in_str = False
if raw_args:
for arg_ch in raw_args:
# Start or stop of a ` delimited string
if arg_ch == "`":
in_str = not in_str
# If we find a space, and we're not in a ` delimited string, this is a boundary
elif arg_ch == " " and not in_str:
args.append("".join(arg_run).strip())
arg_run = []
else:
arg_run.append(arg_ch)
if arg_run:
args.append("".join(arg_run).strip())
return args
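# Doctest-style example of the backtick handling (mirrors test_parse_args in
# this package's tests):
#
#   >>> parse_args("foo `bar baz`")
#   ['foo', 'bar baz']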

View file

@ -0,0 +1,204 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python
"""
Module for hashing log strings
"""
import json
import os
import re
from pebble.loghashing.constants import (STR_LITERAL_PATTERN, FORMAT_SPECIFIER_PATTERN,
FORMAT_IDENTIFIER_STRING_FMT, LOOKUP_RESULT_STRING_FMT,
LINES_TO_HASH, HASH_MASK, HASH_NEXT_LINE,
LOOKUP_DEFAULT_STRING)
def hash_directory(path, output_file_name):
"""
Runs the line hasher on every file in a directory tree
:param path: Root of the tree to hash
:type path: str
:param output_file_name: Path of the JSON file to write the hash lookup dictionary to
:type output_file_name: str
"""
lookup_dict = {}
for walk in os.walk(path, followlinks=True):
# First and third item, respectively
root, file_names = walk[0::2]
for file_name in file_names:
lookup_dict.update(hash_file("{}/{}".format(root, file_name)))
# Read in hash_lookup
# Write lines out
with open(output_file_name, 'w') as fp:
json.dump(lookup_dict, fp)
def hash_file(file_name):
"""
Attempt to hash each line of a file
:param file_name: Name of file to hash
:type file_name: str
:returns: A hash lookup dictionary
"""
# Read in lines
with open(file_name, 'r') as fp:
lines = fp.readlines()
hashed_lines = []
lookup_dict = {}
force_hash = False
# Hash appropriate lines with line number, and file name
for index, line in enumerate(lines):
hashed_line, line_dict = hash_line(line, file_name, index + 1, force_hash)
force_hash = False
if HASH_NEXT_LINE in hashed_line:
force_hash = True
hashed_lines.append(hashed_line)
lookup_dict.update(line_dict)
# Write lines out
with open(file_name, 'w') as fp:
fp.writelines(hashed_lines)
return lookup_dict
def hash_line(line, file_name, line_num, force_hash=False):
"""
Search line for hashable strings, and hash them.
:param line: Line to search
:type line: str
:param file_name: Name of the file that the line is in
:type file_name: str
:param line_num: Line number of the line
:type line_num: int
:param force_hash: Hash the line even if it does not contain one of LINES_TO_HASH
:type force_hash: bool
:returns: A tuple with: The input line (with all hashable strings hashed),
and a hash lookup dictionary
"""
hash_dict = {}
# Only match lines that contain one of the following substrings
if force_hash or any(x in line for x in LINES_TO_HASH):
if force_hash or not any(x in line for x in ["PBL_CROAK_OOM"]):
match = STR_LITERAL_PATTERN.search(line)
if match:
# Strip all double quotes from the string
str_literal = re.sub("\"", "", match.group(2))
str_literal = inttype_conversion(str_literal)
# Hash the file name and line number in as well
line_to_hash = "{}:{}:{}".format(os.path.basename(file_name), line_num, str_literal)
hashed_msg = hash_string(line_to_hash)
hash_dict[hashed_msg] = line_to_hash
line = "{}{}{}\n".format(match.group(1), hashed_msg, match.group(3))
return (line, hash_dict)
def hash_string(string):
"""
Hash and return a given string.
:param string: String to hash
:type string: str
:returns: The input string, hashed
"""
return hex(hash(string) & HASH_MASK)
def inttype_conversion(inttype):
"""
Change PRI specifiers into classical printf format specifiers
:param inttype: PRI specifier to convert
:type inttype: str
:returns: The classical printf format specifier that inttype represents
"""
# Change ' PRIu32 ' to '32u'
output = re.sub(r"\s*PRI([diouxX])(8|16|32|64|PTR)\s*", r"\g<2>\g<1>", inttype)
# No length modifier for 8 or 16 modifier
output = re.sub("(8|16)", "", output)
# 'l' modifier for 32 or PTR modifier
output = re.sub("(32|PTR)", "l", output)
# 'll' modifier for 64 modifier
output = re.sub("64", "ll", output)
# Change BT_MAC_FMT and BT_ADDR_FMT
output = re.sub("BT_MAC_FMT", "%02X:%02X:%02X:%02X:%02X:%02X", output)
output = re.sub("BT_ADDR_FMT", "%02X:%02X:%02X:%02X:%02X:%02X", output)
output = re.sub("BT_DEVICE_ADDRESS_FMT", "%02X:%02X:%02X:%02X:%02X:%02X", output)
return output
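# Worked example: with the quotes already stripped by hash_line(), a literal
# such as % PRIu32 becomes '%32u' after the first substitution and '%lu' after
# the width fix-ups, while BT_MAC_FMT expands to the six-octet
# '%02X:%02X:%02X:%02X:%02X:%02X' pattern above.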
def string_formats(string):
"""
Parses a string for all format identifiers
:param string: String to parse
:type string: str
:returns: A list of all format specifiers
"""
return FORMAT_SPECIFIER_PATTERN.findall(string)
def create_lookup_function(lookup_dict, output_file_name):
"""
Create a C source file for hash to format specifiers lookup
:param lookup_dict: Hash to string lookup dictionary
:type lookup_dict: dict
:param output_file_name: Path of the C source file to write
:type output_file_name: str
"""
strings = []
lines = [LOOKUP_DEFAULT_STRING]
format_lookup = {}
index = 1
format_map = [[x, string_formats(lookup_dict[x])] for x in lookup_dict.keys()]
for line, formats in format_map:
# Only make an entry if there's a format string!
if formats:
format_as_string = ''.join(formats)
if format_as_string not in format_lookup:
format_lookup[format_as_string] = index
strings.append(FORMAT_IDENTIFIER_STRING_FMT.format(index, format_as_string))
index = index + 1
lines.append(LOOKUP_RESULT_STRING_FMT.format(line, format_lookup[format_as_string]))
with open(output_file_name, 'w') as fp:
fp.writelines(strings)
fp.writelines(lines)
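# Illustration of the generated C fragments (the hash value and format string
# are made-up examples, not real output):
#
#   char *format_string_1 = "%lu%s";
#   fmt = "";
#   if (loghash == 0x1a2b3c) fmt = format_string_1;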

View file

@ -0,0 +1,303 @@
#!/usr/bin/env python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf8 -*-
"""
Module for dehashing NewLog input
"""
import os
import re
import string
import struct
from pebble.loghashing.constants import (NEWLOG_LINE_CONSOLE_PATTERN,
NEWLOG_LINE_SUPPORT_PATTERN,
NEWLOG_HASHED_INFO_PATTERN,
POINTER_FORMAT_TAG_PATTERN,
HEX_FORMAT_SPECIFIER_PATTERN)
hex_digits = set(string.hexdigits)
LOG_DICT_KEY_VERSION = 'new_logging_version'
NEW_LOGGING_VERSION = 'NL0101'
LOG_LEVEL_ALWAYS = 0
LOG_LEVEL_ERROR = 1
LOG_LEVEL_WARNING = 50
LOG_LEVEL_INFO = 100
LOG_LEVEL_DEBUG = 200
LOG_LEVEL_DEBUG_VERBOSE = 255
level_strings_map = {
LOG_LEVEL_ALWAYS: '*',
LOG_LEVEL_ERROR: 'E',
LOG_LEVEL_WARNING: 'W',
LOG_LEVEL_INFO: 'I',
LOG_LEVEL_DEBUG: 'D',
LOG_LEVEL_DEBUG_VERBOSE: 'V'
}
# Location of the core number in the message hash
PACKED_CORE_OFFSET = 30
PACKED_CORE_MASK = 0x03
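# Worked example: the hash 0x40000000 decodes as core (0x40000000 >> 30) & 0x03 == 1,
# while ordinary small hashes such as 0x72 decode as core 0 (see test_newlogging).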
def dehash_file(file_name, log_dict):
"""
Dehash a file
:param file_name: Path of the file to dehash
:type file_name: str
:param log_dict: dict of dicts created from .log_strings section from tintin_fw.elf
:type log_dict: dict of dicts
:returns: A list containing the dehashed lines
"""
# Grab the lines from the file
with open(file_name, 'r') as fp:
lines = fp.readlines()
# Dehash the lines
lines = [dehash_line(x, log_dict) + "\n" for x in lines]
return lines
def dehash_line(line, log_dict):
"""
Dehash a line. Return with old formatting.
:param line: The line to dehash
:type line: str
:param log_dict: dict of dicts created from .log_strings section from tintin_fw.elf
:type log_dict: dict of dicts
:returns: Formatted line
On error, the provided line
"""
line_dict = dehash_line_unformatted(line, log_dict)
if not line_dict:
return line
output = []
if 'date' not in line_dict and 're_level' in line_dict:
output.append(line_dict['re_level'])
if 'task' in line_dict:
output.append(line_dict['task'])
if 'date' in line_dict:
output.append(line_dict['date'])
if 'time' in line_dict:
output.append(line_dict['time'])
if 'file' in line_dict and 'line' in line_dict:
filename = os.path.basename(line_dict['file'])
output.append('{}:{}>'.format(filename, line_dict['line']))
output.append(line_dict['formatted_msg'])
return " ".join(output)
def dehash_line_unformatted(line, log_dict):
"""
Dehash a line. Return an unformatted dict of the info.
:param line: The line to dehash
:type line: str
:param log_dict: dict of dicts created from .log_strings section from tintin_fw.elf
:type log_dict: dict of dicts
:returns: A line_dict with keys 'formatted_msg', 'level', 'task', 'date', and 'time' added.
On error, a dict containing only 'formatted_msg' set to the input line
"""
line_dict = parse_line(line, log_dict)
if not line_dict:
return { 'formatted_msg': line }
return line_dict
def parse_line(line, log_dict):
"""
Parse a log line
:param line: The line to dehash
:type line: str
:param log_dict: dict of dicts created from .log_strings section from tintin_fw.elf
:type log_dict: dict of dicts
:returns: A line_dict with keys 'formatted_msg', 'level', 'task', 'date', 'time',
'core_number' added.
On error, None
"""
if not log_dict:
return None
# Handle BLE logs. They have no date, time, level in the input string
ble_line = line.startswith(':0> NL:')
match = None
if not ble_line:
match = NEWLOG_LINE_CONSOLE_PATTERN.search(line)
if not match:
match = NEWLOG_LINE_SUPPORT_PATTERN.search(line)
if not match:
return None
# Search for the 'msg' in the entire log dictionary, getting back the sub-dictionary for this
# specific message
if ble_line:
line_dict = parse_message(line, log_dict)
else:
line_dict = parse_message(match.group('msg'), log_dict)
if line_dict:
if ble_line:
line_dict['task'] = '-'
else:
# Add all of the match groups (e.g., date, time, level) to the line dict
line_dict.update(match.groupdict())
# Convert the numeric 'level' from the log dict entry into its ASCII character ('re_level')
if 'level' in line_dict:
line_dict['re_level'] = level_strings_map.get(int(line_dict['level']), '?')
return line_dict
def parse_message(msg, log_dict):
"""
Parse the log message part of a line
:param msg: The message to parse
:type msg: str
:param log_dict: dict of dicts created from .log_strings section from tintin_fw.elf
:type log_dict: dict of dicts
:returns: the dict entry for the log line and the formatted message
"""
match = NEWLOG_HASHED_INFO_PATTERN.search(msg)
if not match:
return None
try:
line_number = int(match.group('hash_key'), 16)
output_dict = log_dict[str(line_number)].copy() # Must be a copy!
except KeyError:
# Hash key not found. Wrong .elf?
return None
# Python's 'printf' doesn't support %p. Sigh. Convert to %x and hope for the best
safe_output_msg = POINTER_FORMAT_TAG_PATTERN.sub('\g<format>x', output_dict['msg'])
# Python's 'printf' doesn't handle (negative) 32-bit hex values correctly. Build a new
# arg list from the parsed arg list by searching for %<format>X conversions and masking
# them to 32 bits.
arg_list = []
index = 0
for arg in parse_args(match.group('arg_list')):
index = safe_output_msg.find('%', index)
if index == -1:
# This is going to cause an error below...
arg_list.append(arg)
elif HEX_FORMAT_SPECIFIER_PATTERN.match(safe_output_msg, index):
# We found a %<format>X
arg_list.append(arg & 0xFFFFFFFF)
else:
arg_list.append(arg)
# Use "printf" to generate the reconstructed string. Make sure the arguments are correct
try:
output_msg = safe_output_msg % tuple(arg_list)
except (TypeError, UnicodeDecodeError) as e:
output_msg = msg + ' ----> ERROR: ' + str(e)
# Add the formatted msg to the copy of our line dict
output_dict['formatted_msg'] = output_msg
# Add the core number to the line dict
output_dict['core_number'] = str((line_number >> PACKED_CORE_OFFSET) & PACKED_CORE_MASK)
return output_dict
def parse_args(raw_args):
"""
Split the argument list, taking care of `delimited strings`
Idea taken from http://bit.ly/1KHzc0y
:param raw_args: Raw argument list. Values are either in hex or in `strings`
:type raw_args: str
:returns: A list containing the arguments
"""
args = []
arg_run = []
in_str = False
if raw_args:
for arg_ch in raw_args:
if arg_ch not in "` ":
arg_run.append(arg_ch)
continue
if in_str:
if arg_ch == ' ':
arg_run.append(' ')
else: # Must be ending `
args.append("".join(arg_run).strip())
in_str = False
arg_run = []
continue
# Start of a string
if arg_ch == '`':
in_str = True
continue
# Must be a space boundary (arg_ch == ' ')
arg = "".join(arg_run).strip()
if not len(arg):
continue
if not all(c in hex_digits for c in arg_run):
# Hack to prevent hex conversion failure
args.append(arg)
else:
# Every parameter is a 32-bit signed integer printed as a hex string with no
# leading zeros. Add the zero padding if necessary, convert to 4 hex bytes,
# and then reinterpret as a 32-bit signed big-endian integer.
args.append(struct.unpack('>i', arg.rjust(8, '0').decode('hex'))[0])
arg_run = []
# Clean up if anything is remaining (there is no trailing space)
arg = "".join(arg_run).strip()
if len(arg):
# Handle the case where the trailing ` is missing.
if not all(c in hex_digits for c in arg):
args.append(arg)
else:
# Every parameter is a 32-bit signed integer printed as a hex string with no
# leading zeros. Add the zero padding if necessary, convert to 4 hex bytes,
# and then reinterpret as a 32-bit signed big-endian integer.
args.append(struct.unpack('>i', arg.rjust(8, '0').decode('hex'))[0])
return args
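# Worked example of the signed conversion above (matches the clock.c case in
# test_newlogging): the argument 'ffff8170' decodes to the bytes ff ff 81 70,
# which unpack as the big-endian signed 32-bit value -32400.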

View file

@ -0,0 +1,33 @@
#!/usr/bin/env python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Setup.py for distutils
"""
from setuptools import setup, find_packages
setup(
name='pebble.loghash',
version='2.6.0',
description='Pebble Log Hashing module',
author='Pebble Technology Corp',
author_email='francois@pebble.com',
url='https://github.com/pebble/pyegg-pebble-loghash',
namespace_packages = ['pebble'],
packages = find_packages()
)

View file

@ -0,0 +1,115 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python
"""
Tests for pebble.loghashing.dehashing
"""
LOOKUP_DICT = {"13108": "activity.c:activity tracking started",
"45803": "ispp.c:Start Authentication Process (%d) %s"}
from pebble.loghashing.dehashing import (dehash_line, parse_line, parse_support_line, parse_message,
dehash_str, parse_args)
def test_dehash_file():
"""
Test for dehash_file()
"""
pass
def test_dehash_line():
"""
Test for dehash_line()
"""
# Console Line - No arguments
assert ("D A 21:35:14.375 activity.c:804> activity tracking started" ==
dehash_line("D A 21:35:14.375 :804> LH:3334", LOOKUP_DICT))
# Console Line - Arguments
assert ("D A 21:35:14.375 ispp.c:872> Start Authentication Process (2) Success" ==
dehash_line("D A 21:35:14.375 :872> LH:b2eb 2 `Success`", LOOKUP_DICT))
# Support Line - No arguments
assert ("2015-09-05 02:16:16:000GMT activity.c:804> activity tracking started" ==
dehash_line("2015-09-05 02:16:16:000GMT :804 LH:3334", LOOKUP_DICT))
# Support Line - Arguments
assert ("2015-09-05 02:16:19:000GMT ispp.c:872> Start Authentication Process (2) Success" ==
dehash_line("2015-09-05 02:16:19:000GMT :872 LH:b2eb 2 `Success`", LOOKUP_DICT))
def test_parse_line():
"""
Test for parse_line()
"""
# No arguments
assert ("D A 21:35:14.375 activity.c:804> activity tracking started" ==
parse_line("D A 21:35:14.375 :804> LH:3334", LOOKUP_DICT))
# Arguments
assert ("D A 21:35:14.375 ispp.c:872> Start Authentication Process (2) Success" ==
parse_line("D A 21:35:14.375 :872> LH:b2eb 2 `Success`", LOOKUP_DICT))
def test_parse_support_line():
"""
Test for parse_support_line()
"""
# No arguments
assert ("2015-09-05 02:16:16:000GMT activity.c:804> activity tracking started" ==
parse_support_line("2015-09-05 02:16:16:000GMT :804 LH:3334", LOOKUP_DICT))
# Arguments
assert ("2015-09-05 02:16:19:000GMT ispp.c:872> Start Authentication Process (2) Success" ==
parse_support_line("2015-09-05 02:16:19:000GMT :872 LH:b2eb 2 `Success`", LOOKUP_DICT))
def test_parse_message():
"""
Test for parse_message()
"""
# Console Line - No arguments
assert ({'msg': 'activity tracking started', 'line': '804', 'file': 'activity.c'} ==
parse_message(":804> LH:3334", LOOKUP_DICT))
# Console Line - Arguments
assert ({'msg': 'Start Authentication Process (2) Success', 'line': '872', 'file': 'ispp.c'} ==
parse_message(":872> LH:b2eb 2 `Success`", LOOKUP_DICT))
# Support Line - No arguments
assert ({'msg': 'activity tracking started', 'line': '804', 'file': 'activity.c'} ==
parse_message(":804 LH:3334", LOOKUP_DICT))
# Support Line - Arguments
assert ({'msg': 'Start Authentication Process (2) Success', 'line': '872', 'file': 'ispp.c'} ==
parse_message(":872 LH:b2eb 2 `Success`", LOOKUP_DICT))
def test_dehash_str():
"""
Test for dehash_str()
"""
# No arguments
assert ("activity.c:activity tracking started" ==
dehash_str("3334", LOOKUP_DICT))
# Arguments
assert ("ispp.c:Start Authentication Process (%d) %s" ==
dehash_str("b2eb", LOOKUP_DICT))
def test_parse_args():
"""
Test for parse_args()
"""
# No `` delimited strings
assert ["foo", "bar", "baz"] == parse_args("foo bar baz")
# `` delimited strings
assert ["foo", "bar baz"] == parse_args("foo `bar baz`")

View file

@ -0,0 +1,226 @@
#! /usr/bin/env python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf8 -*-
"""
Tests for pebble.loghashing.newlogging
"""
from pebble.loghashing.newlogging import dehash_line, dehash_line_unformatted
from pebble.loghashing.dehashing import dehash_line as legacy_dehash_line
import os
test_log_dict = {'43': {'file': '../src/fw/activity/activity.c',
'line': '804',
'level': '200',
'color': 'YELLOW',
'msg': 'activity tracking started'},
'114': {'file': '../src/fw/driver/ispp.c',
'line': '1872',
'level': '0',
'color': 'RED',
'msg': 'Start Authentication Process %d (%x) %s'},
'214': {'file': 'pointer_print.c',
'line': '1872',
'level': '0',
'color': 'RED',
'msg': 'My address is %p %p'},
'64856': {'color': 'GREY',
'file': '../src/fw/services/common/clock.c',
'level': '200',
'line': '768',
'msg': 'Changed timezone to id %u, gmtoff is %ld'},
'100000': {'color': 'GREY',
'file': '../src/fw/services/common/string.c',
'level': '200',
'line': '111',
'msg': 'string 1 %s, string 2 %s'},
'11082': {'color': 'GREY',
'file': '../src/fw/resource/resource_storage.c',
'level': '50',
'line': '120',
'msg': '0x%lx != 0x%lx'},
'1073741824': {'color': 'GREY',
'file': 'hc_protocol.c',
'level': '0',
'line': '69',
'msg': 'Init BLE SPI Protocol'},
'new_logging_version': 'NL0101'
}
def test_dehash_line():
"""
Test for dehash_line()
"""
# Console Line - No arguments
line = "? A 21:35:14.375 :0> NL:{:x}".format(43)
assert ("D A 21:35:14.375 activity.c:804> activity tracking started" ==
dehash_line(line, test_log_dict))
# Console Line - Arguments
line = "? A 21:35:14.375 :0> NL:{:x} a a `Success`".format(114)
assert ("* A 21:35:14.375 ispp.c:1872> Start Authentication Process 10 (a) Success" ==
dehash_line(line, test_log_dict))
# Support Line - No arguments
line = "2015-09-05 02:16:16:000GMT :0> NL:{:x}".format(43)
assert ("2015-09-05 02:16:16:000GMT activity.c:804> activity tracking started" ==
dehash_line(line, test_log_dict))
# Support Line - Arguments
line = "2015-09-05 02:16:19:000GMT :0> NL:{:x} 10 10 `Success`".format(114)
assert ("2015-09-05 02:16:19:000GMT ispp.c:1872> Start Authentication Process 16 (10) Success" ==
dehash_line(line, test_log_dict))
# App Log
line = "D A 21:35:14.375 file.c:0> This is an app debug line"
assert (line == dehash_line(line, test_log_dict))
# Pointer format conversion
line = "2015-09-05 02:16:19:000GMT :0> NL:{:x} 164 1FfF".format(214)
assert ("2015-09-05 02:16:19:000GMT pointer_print.c:1872> My address is 164 1fff" ==
dehash_line(line, test_log_dict))
# Two's complement negative value
line = "2015-09-05 02:16:19:000GMT :0> NL:{:x} 10 ffff8170".format(64856)
assert ("2015-09-05 02:16:19:000GMT clock.c:768> Changed timezone to id 16, gmtoff is -32400" ==
dehash_line(line, test_log_dict))
# Two's complement negative value
line = "2015-09-05 02:16:19:000GMT :0> NL:{:x} 9AEBC155 43073997".format(11082)
assert ("2015-09-05 02:16:19:000GMT resource_storage.c:120> 0x9aebc155 != 0x43073997" ==
dehash_line(line, test_log_dict))
# Empty string parameter - 1
line = "? A 21:35:14.375 :0> NL:{:x} `` `string`".format(100000)
assert ("D A 21:35:14.375 string.c:111> string 1 , string 2 string" ==
dehash_line(line, test_log_dict))
# Empty string parameter - 2 - trailing space
line = "? A 21:35:14.375 :0> NL:{:x} `string` `` ".format(100000)
assert ("D A 21:35:14.375 string.c:111> string 1 string, string 2 " ==
dehash_line(line, test_log_dict))
# Empty string parameter - 2 - no trailing space
line = "? A 21:35:14.375 :0> NL:{:x} `string` ``".format(100000)
assert ("D A 21:35:14.375 string.c:111> string 1 string, string 2 " ==
dehash_line(line, test_log_dict))
# Missing closing `
line = "? A 21:35:14.375 :0> NL:{:x} `string` `string".format(100000)
assert ("D A 21:35:14.375 string.c:111> string 1 string, string 2 string" ==
dehash_line(line, test_log_dict))
def test_dehash_invalid_parameters():
"""
Tests for invalid number of parameters
"""
# Not enough parameters
line = "2015-09-05 02:16:19:000GMT :0> NL:{:x} 164".format(214)
assert ("2015-09-05 02:16:19:000GMT pointer_print.c:1872> :0> NL:d6 164 " \
"----> ERROR: not enough arguments for format string" ==
dehash_line(line, test_log_dict))
# Too many parameters
line = "2015-09-05 02:16:19:000GMT :0> NL:{:x} 164 1FfF 17".format(214)
assert ("2015-09-05 02:16:19:000GMT pointer_print.c:1872> :0> NL:d6 164 1FfF 17 " \
"----> ERROR: not all arguments converted during string formatting" ==
dehash_line(line, test_log_dict))
# Unterminated string (last `)
line = "2015-09-05 02:16:19:000GMT :0> NL:{:x} 10 10 `Success".format(114)
assert ("2015-09-05 02:16:19:000GMT ispp.c:1872> Start Authentication Process 16 (10) Success" ==
dehash_line(line, test_log_dict))
# Unterminated string (first `)
line = "2015-09-05 02:16:19:000GMT :0> NL:{:x} 10 10 Success`".format(114)
assert ("2015-09-05 02:16:19:000GMT ispp.c:1872> Start Authentication Process 16 (10) Success" ==
dehash_line(line, test_log_dict))
# Unterminated string (No `s)
line = "2015-09-05 02:16:19:000GMT :0> NL:{:x} 10 10 Success".format(114)
assert ("2015-09-05 02:16:19:000GMT ispp.c:1872> Start Authentication Process 16 (10) Success" ==
dehash_line(line, test_log_dict))
# Invalid hex character
line = "2015-09-05 02:16:19:000GMT :0> NL:{:x} 10 1q0 Success".format(114)
assert ("2015-09-05 02:16:19:000GMT ispp.c:1872> :0> NL:72 10 1q0 Success " \
"----> ERROR: %x format: a number is required, not str" ==
dehash_line(line, test_log_dict))
# Unicode
line = "? A 21:35:14.375 :0> NL:{:x} `unicode` `Pebble β`".format(100000)
assert ("D A 21:35:14.375 string.c:111> string 1 unicode, string 2 Pebble β" ==
dehash_line(line, test_log_dict))
def test_legacy_dehash_line():
"""
Test legacy dehash_line()
"""
# Console Line - No arguments
line = "? A 21:35:14.375 :0> NL:{:x}".format(43)
assert ("D A 21:35:14.375 activity.c:804> activity tracking started" ==
legacy_dehash_line(line, test_log_dict))
def test_unformatted():
"""
Test dehash_line_unformatted()
"""
line = "? A 21:35:14.375 :0> NL:{:x} a a `Success`".format(114)
line_dict = dehash_line_unformatted(line, test_log_dict)
assert (line_dict['level'] == "0")
assert (line_dict['task'] == "A")
assert (line_dict['time'] == "21:35:14.375")
assert (os.path.basename(line_dict['file']) == "ispp.c")
assert (line_dict['line'] == "1872")
assert (line_dict['formatted_msg'] == "Start Authentication Process 10 (a) Success")
def test_core_number():
"""
Test core number decoding
"""
# Core number 0
line = "? A 21:35:14.375 :0> NL:{:x} a a `Success`".format(114)
line_dict = dehash_line_unformatted(line, test_log_dict)
assert (line_dict['core_number'] == "0")
# Core number 1
line = "? A 21:35:14.375 :0> NL:{:x}".format(1073741824)
line_dict = dehash_line_unformatted(line, test_log_dict)
assert (line_dict['core_number'] == "1")
def test_ble_decode():
"""
Test BLE decode.
datetime.now() is used, so ignore the date/time
"""
line = ":0> NL:{:x}".format(1073741824)
line_dict = dehash_line_unformatted(line, test_log_dict)
assert (line_dict['level'] == "0")
assert (line_dict['task'] == "-")
assert (os.path.basename(line_dict['file']) == "hc_protocol.c")
assert (line_dict['line'] == "69")
assert (line_dict['formatted_msg'] == "Init BLE SPI Protocol")

89
python_libs/pulse2/.gitignore vendored Normal file
View file

@ -0,0 +1,89 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject

View file

@ -0,0 +1,6 @@
pebble.pulse2
=============
pulse2 is a Python implementation of the PULSEv2 protocol suite.
https://pebbletechnology.atlassian.net/wiki/display/DEV/PULSEv2+Protocol+Suite
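
A minimal usage sketch (the serial URL and port number below are placeholders,
not defaults defined by this package)::

    from pebble.pulse2 import Interface

    interface = Interface.open_dbgserial('/dev/ttyUSB0')
    link = interface.get_link()
    socket = link.open_socket('reliable', port=0x3e20)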

View file

@ -0,0 +1,15 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__import__('pkg_resources').declare_namespace(__name__)

View file

@ -0,0 +1,24 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from . import link, transports
# Public aliases for the classes that users will interact with directly.
from .link import Interface
link.Link.register_transport(
'best-effort', transports.BestEffortApplicationTransport)
link.Link.register_transport('reliable', transports.ReliableTransport)

View file

@ -0,0 +1,37 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class PulseException(Exception):
pass
class TTYAutodetectionUnavailable(PulseException):
pass
class ReceiveQueueEmpty(PulseException):
pass
class TransportNotReady(PulseException):
pass
class SocketClosed(PulseException):
pass
class AlreadyInProgressError(PulseException):
'''Another operation is already in progress.
'''

View file

@ -0,0 +1,148 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
PULSEv2 Framing
This module handles encoding and decoding of datagrams in PULSEv2 frames: flag
delimiters, transparency encoding and Frame Check Sequence. The content of the
datagrams themselves are not examined or parsed.
'''
from __future__ import absolute_import
import binascii
import struct
try:
import queue
except ImportError: # Py2
import Queue as queue
from cobs import cobs
FLAG = 0x55
CRC32_RESIDUE = binascii.crc32(b'\0' * 4)
class FramingException(Exception):
pass
class DecodeError(FramingException):
pass
class CorruptFrame(FramingException):
pass
class FrameSplitter(object):
'''Takes a byte stream and partitions it into frames.
Empty frames (two consecutive flag bytes) are silently discarded.
No transparency conversion is applied to the contents of the frames.
FrameSplitter objects support iteration for retrieving split frames.
>>> splitter = FrameSplitter()
>>> splitter.write(b'\x55foo\x55bar\x55')
>>> list(splitter)
[b'foo', b'bar']
'''
def __init__(self, max_frame_length=0):
self.frames = queue.Queue()
self.input_buffer = bytearray()
self.max_frame_length = max_frame_length
self.waiting_for_sync = True
def write(self, data):
'''Write bytes into the splitter for processing.
'''
for char in bytearray(data):
if self.waiting_for_sync:
if char == FLAG:
self.waiting_for_sync = False
else:
if char == FLAG:
if self.input_buffer:
self.frames.put_nowait(bytes(self.input_buffer))
self.input_buffer = bytearray()
else:
if (not self.max_frame_length or
len(self.input_buffer) < self.max_frame_length):
self.input_buffer.append(char)
else:
self.input_buffer = bytearray()
self.waiting_for_sync = True
def __iter__(self):
while True:
try:
yield self.frames.get_nowait()
except queue.Empty:
return
def decode_transparency(frame_bytes):
'''Decode the transparency encoding applied to a PULSEv2 frame.
Returns the decoded frame, or raises `DecodeError`.
'''
frame_bytes = bytearray(frame_bytes)
if FLAG in frame_bytes:
raise DecodeError("flag byte in encoded frame")
try:
return cobs.decode(bytes(frame_bytes.replace(b'\0', bytearray([FLAG]))))
except cobs.DecodeError as e:
raise DecodeError(str(e))
def strip_fcs(frame_bytes):
'''Validates the FCS in a PULSEv2 frame.
The frame is returned with the FCS removed if the FCS check passes.
A `CorruptFrame` exception is raised if the FCS check fails.
The frame must not be transparency-encoded.
'''
if len(frame_bytes) <= 4:
raise CorruptFrame('frame too short')
if binascii.crc32(frame_bytes) != CRC32_RESIDUE:
raise CorruptFrame('FCS check failure')
return frame_bytes[:-4]
def decode_frame(frame_bytes):
'''Decode and validate a PULSEv2-encoded frame.
Returns the datagram extracted from the frame, or raises a
`FramingException` or subclass if there was an error decoding the frame.
'''
return strip_fcs(decode_transparency(frame_bytes))
def encode_frame(datagram):
'''Encode a datagram in a PULSEv2 frame.
'''
datagram = bytearray(datagram)
fcs = binascii.crc32(datagram) & 0xffffffff
fcs_bytes = struct.pack('<I', fcs)
datagram.extend(fcs_bytes)
flag = bytearray([FLAG])
frame = cobs.encode(bytes(datagram)).replace(flag, b'\0')
return flag + frame + flag
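# Round-trip sketch (illustrative): frames produced by encode_frame() can be
# recovered by feeding them through a FrameSplitter (which strips the flag
# bytes) and then decode_frame():
#
#   splitter = FrameSplitter()
#   splitter.write(encode_frame(b'hello'))
#   datagrams = [decode_frame(f) for f in splitter]  # -> [b'hello']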

View file

@ -0,0 +1,314 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
import threading
import serial
from . import exceptions, framing, ppp, transports
from . import logging as pulse2_logging
from . import pcap_file
try:
import pyftdi.serialext
except ImportError:
pass
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
DBGSERIAL_PORT_SETTINGS = dict(baudrate=1000000, timeout=0.1,
interCharTimeout=0.0001)
def get_dbgserial_tty():
# Local import so that we only depend on this package if we're attempting
# to autodetect the TTY. This package isn't always available (e.g., MFG),
# so we don't want it to be required.
try:
import pebble_tty
return pebble_tty.find_dbgserial_tty()
except ImportError:
raise exceptions.TTYAutodetectionUnavailable
class Interface(object):
'''The PULSEv2 lower data-link layer.
An Interface object is roughly analogous to a network interface,
somewhat like an Ethernet port. It provides connectionless service
with PULSEv2 framing, which upper layers build upon to provide
connection-oriented service.
An Interface is bound to an I/O stream, such as a Serial port, and
remains open until either the Interface is explicitly closed or the
underlying I/O stream is closed from underneath it.
'''
def __init__(self, iostream, capture_stream=None):
self.logger = pulse2_logging.TaggedAdapter(
logger, {'tag': type(self).__name__})
self.iostream = iostream
self.closed = False
self.close_lock = threading.RLock()
self.default_packet_handler_cb = None
self.sockets = {}
self.pcap = None
if capture_stream:
self.pcap = pcap_file.PcapWriter(
capture_stream, pcap_file.LINKTYPE_PPP_WITH_DIR)
self.receive_thread = threading.Thread(target=self.receive_loop)
self.receive_thread.daemon = True
self.receive_thread.start()
self.simplex_transport = transports.SimplexTransport(self)
self._link = None
self.link_available = threading.Event()
self.lcp = ppp.LinkControlProtocol(self)
self.lcp.on_link_up = self.on_link_up
self.lcp.on_link_down = self.on_link_down
self.lcp.up()
self.lcp.open()
@classmethod
def open_dbgserial(cls, url=None, capture_stream=None):
if url is None:
url = get_dbgserial_tty()
elif url == 'qemu':
url = 'socket://localhost:12345'
ser = serial.serial_for_url(url, **DBGSERIAL_PORT_SETTINGS)
if url.startswith('socket://'):
# interCharTimeout doesn't apply to sockets, so shrink the receive
# timeout to compensate.
ser.timeout = 0.5
ser._socket.settimeout(0.5)
return cls(ser, capture_stream)
def connect(self, protocol):
'''Open a link-layer socket for sending and receiving packets
of a specific protocol number.
'''
if protocol in self.sockets and not self.sockets[protocol].closed:
raise ValueError('A socket is already bound '
'to protocol 0x%04x' % protocol)
self.sockets[protocol] = socket = InterfaceSocket(self, protocol)
return socket
def unregister_socket(self, protocol):
'''Used by InterfaceSocket objects to unregister themselves when
closing.
'''
try:
del self.sockets[protocol]
except KeyError:
pass
def receive_loop(self):
splitter = framing.FrameSplitter()
while True:
if self.closed:
self.logger.info('Interface closed; receive loop exiting')
break
try:
splitter.write(self.iostream.read(1))
except IOError:
if self.closed:
self.logger.info('Interface closed; receive loop exiting')
else:
self.logger.exception('Unexpected error while reading '
'from iostream')
self._down()
break
for frame in splitter:
try:
datagram = framing.decode_frame(frame)
if self.pcap:
# Prepend pseudo-header meaning "received by this host"
self.pcap.write_packet(b'\0' + datagram)
protocol, information = ppp.unencapsulate(datagram)
if protocol in self.sockets:
self.sockets[protocol].handle_packet(information)
else:
# TODO LCP Protocol-Reject
self.logger.info('Protocol-reject: %04X', protocol)
except (framing.DecodeError, framing.CorruptFrame):
pass
def send_packet(self, protocol, packet):
if self.closed:
raise ValueError('I/O operation on closed interface')
datagram = ppp.encapsulate(protocol, packet)
if self.pcap:
# Prepend pseudo-header meaning "sent by this host"
self.pcap.write_packet(b'\x01' + datagram)
self.iostream.write(framing.encode_frame(datagram))
def close_all_sockets(self):
# Iterating over a copy of sockets since socket.close() can call
# unregister_socket, which modifies the socket dict. Modifying
# a dict during iteration is not allowed, so the iteration is
# completed (by making the copy) before modification can begin.
for socket in list(self.sockets.values()):
socket.close()
def close(self):
with self.close_lock:
if self.closed:
return
self.lcp.shutdown()
self.close_all_sockets()
self._down()
if self.pcap:
self.pcap.close()
def _down(self):
'''The lower layer (iostream) is down. Bring down the interface.
'''
with self.close_lock:
self.closed = True
self.close_all_sockets()
self.lcp.down()
self.simplex_transport.down()
self.iostream.close()
def on_link_up(self):
# FIXME PBL-34320 proper MTU/MRU support
self._link = Link(self, mtu=1500)
# Test whether the link is ready to carry traffic
self.lcp.ping(self._ping_done)
def _ping_done(self, ping_check_succeeded):
if ping_check_succeeded:
self.link_available.set()
else:
self.lcp.restart()
def on_link_down(self):
self.link_available.clear()
self._link.down()
self._link = None
def get_link(self, timeout=60.0):
'''Get the opened Link object for this interface.
This function will block waiting for the Link to be available.
It will return `None` if the timeout expires before the link
is available.
'''
if self.closed:
raise ValueError('No link available on closed interface')
if self.link_available.wait(timeout):
assert self._link is not None
return self._link
class InterfaceSocket(object):
'''A socket for sending and receiving link-layer packets over a
PULSE2 interface.
Callbacks can be registered on the socket by assigning callables to
the appropriate attributes on the socket object. Callbacks can be
unregistered by setting the attributes back to `None`.
Available callbacks:
- `on_packet(information)`
- `on_close()`
'''
on_packet = None
on_close = None
def __init__(self, interface, protocol):
self.interface = interface
self.protocol = protocol
self.closed = False
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()
def send(self, information):
if self.closed:
raise exceptions.SocketClosed('I/O operation on closed socket')
self.interface.send_packet(self.protocol, information)
def handle_packet(self, information):
if self.on_packet and not self.closed:
self.on_packet(information)
def close(self):
if self.closed:
return
self.closed = True
if self.on_close:
self.on_close()
self.interface.unregister_socket(self.protocol)
self.on_packet = None
self.on_close = None
class Link(object):
'''The connectionful portion of a PULSE2 interface.
'''
TRANSPORTS = {}
on_close = None
@classmethod
def register_transport(cls, name, factory):
'''Register a PULSE transport.
'''
if name in cls.TRANSPORTS:
raise ValueError('transport name %r is already registered '
'with %r' % (name, cls.TRANSPORTS[name]))
cls.TRANSPORTS[name] = factory
def __init__(self, interface, mtu):
self.logger = pulse2_logging.TaggedAdapter(
logger, {'tag': type(self).__name__})
self.interface = interface
self.closed = False
self.mtu = mtu
self.transports = {}
for name, factory in self.TRANSPORTS.iteritems():
transport = factory(interface, mtu)
self.transports[name] = transport
def open_socket(self, transport, port, timeout=30.0):
if self.closed:
raise ValueError('Cannot open socket on closed Link')
if transport not in self.transports:
raise KeyError('Unknown transport %r' % transport)
return self.transports[transport].open_socket(port, timeout)
def down(self):
self.closed = True
if self.on_close:
self.on_close()
for transport in self.transports.itervalues():
transport.down()
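# Usage sketch (illustrative; the device path, protocol number, and handler
# are placeholders, not values defined in this module):
#
#   interface = Interface.open_dbgserial('/dev/ttyUSB0')
#   sock = interface.connect(0xf0f1)
#   sock.on_packet = my_handler      # a callable taking the packet contents
#   sock.send(b'ping')
#   sock.close()
#   interface.close()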

View file

@ -0,0 +1,31 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
class TaggedAdapter(logging.LoggerAdapter):
'''Annotates all log messages with a "[tag]" prefix.
The value of the tag is specified in the dict argument passed into
the adapter's constructor.
>>> logger = logging.getLogger(__name__)
>>> adapter = TaggedAdapter(logger, {'tag': 'tag value'})
'''
def process(self, msg, kwargs):
return '[%s] %s' % (self.extra['tag'], msg), kwargs
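    # For reference (behaviour implied by process() above, not original
    # text): adapter.info('hello') emits a record whose message reads
    # '[tag value] hello'.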

View file

@ -0,0 +1,68 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Writer for Libpcap capture files
https://wiki.wireshark.org/Development/LibpcapFileFormat
'''
from __future__ import absolute_import
import struct
import threading
import time
LINKTYPE_PPP_WITH_DIR = 204
class PcapWriter(object):
def __init__(self, outfile, linktype):
self.lock = threading.Lock()
self.outfile = outfile
self._write_pcap_header(linktype)
def close(self):
with self.lock:
self.outfile.close()
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
def _write_pcap_header(self, linktype):
header = struct.pack('!IHHiIII',
0xa1b2c3d4, # guint32 magic_number
2, # guint16 version_major
4, # guint16 version_minor
0, # guint32 thiszone
0, # guint32 sigfigs (unused)
65535, # guint32 snaplen
linktype) # guint32 network
self.outfile.write(header)
def write_packet(self, data, timestamp=None, orig_len=None):
assert len(data) <= 65535
if timestamp is None:
timestamp = time.time()
if orig_len is None:
orig_len = len(data)
ts_seconds = int(timestamp)
ts_usec = int((timestamp - ts_seconds) * 1000000)
header = struct.pack('!IIII', ts_seconds, ts_usec, len(data), orig_len)
with self.lock:
self.outfile.write(header + data)
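# Hedged usage sketch (not part of the original module): write a small
# capture that Wireshark can open.  `frame` is a placeholder for an encoded
# PPP frame; for LINKTYPE_PPP_WITH_DIR the leading byte records direction
# (0x00 = received by this host, 0x01 = sent by this host).
#
#   f = open('pulse2.pcap', 'wb')
#   with PcapWriter(f, LINKTYPE_PPP_WITH_DIR) as pcap:
#       pcap.write_packet(b'\x01' + frame)
#       pcap.write_packet(b'\x00' + frame)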

View file

@ -0,0 +1,201 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''PULSE Control Message Protocol
'''
from __future__ import absolute_import
import codecs
import collections
import enum
import logging
import struct
import threading
from . import exceptions
from . import logging as pulse2_logging
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
class ParseError(exceptions.PulseException):
pass
@enum.unique
class PCMPCode(enum.Enum):
Echo_Request = 1
Echo_Reply = 2
Discard_Request = 3
Port_Closed = 129
Unknown_Code = 130
class PCMPPacket(collections.namedtuple('PCMPPacket', 'code information')):
__slots__ = ()
@classmethod
def parse(cls, packet):
packet = bytes(packet)
if len(packet) < 1:
raise ParseError('packet too short')
return cls(code=struct.unpack('B', packet[0])[0],
information=packet[1:])
@staticmethod
def build(code, information):
return struct.pack('B', code) + bytes(information)
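# Worked example (added for clarity): a PCMP message is one code byte
# followed by the information field, so
#
#   PCMPPacket.build(PCMPCode.Echo_Request.value, b'ping') == b'\x01ping'
#   PCMPPacket.parse(b'\x02pong') == PCMPPacket(code=2, information=b'pong')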
class PulseControlMessageProtocol(object):
'''This protocol is unique in that it is logically part of the
transport but is layered on top of the transport over the wire.
    To avoid spawning a separate thread just to read from the socket,
    the implementation acts as both socket and protocol in a single
    object.
'''
PORT = 0x0001
on_port_closed = None
@classmethod
def bind(cls, transport):
return transport.open_socket(cls.PORT, factory=cls)
def __init__(self, transport, port):
assert port == self.PORT
self.logger = pulse2_logging.TaggedAdapter(
logger, {'tag': 'PCMP(%s)' % (type(transport).__name__)})
self.transport = transport
self.closed = False
self.ping_lock = threading.RLock()
self.ping_cb = None
self.ping_attempts_remaining = 0
self.ping_timer = None
def close(self):
if self.closed:
return
with self.ping_lock:
self.ping_cb = None
if self.ping_timer:
self.ping_timer.cancel()
self.closed = True
self.transport.unregister_socket(self.PORT)
def send_unknown_code(self, bad_code):
self.transport.send(self.PORT, PCMPPacket.build(
PCMPCode.Unknown_Code.value, struct.pack('B', bad_code)))
def send_echo_request(self, data):
self.transport.send(self.PORT, PCMPPacket.build(
PCMPCode.Echo_Request.value, data))
def send_echo_reply(self, data):
self.transport.send(self.PORT, PCMPPacket.build(
PCMPCode.Echo_Reply.value, data))
def on_receive(self, raw_packet):
try:
packet = PCMPPacket.parse(raw_packet)
except ParseError:
self.logger.exception('Received malformed packet')
return
try:
code = PCMPCode(packet.code)
except ValueError:
self.logger.error('Received packet with unknown code %d',
packet.code)
self.send_unknown_code(packet.code)
return
if code == PCMPCode.Discard_Request:
pass
elif code == PCMPCode.Echo_Request:
self.send_echo_reply(packet.information)
elif code == PCMPCode.Echo_Reply:
with self.ping_lock:
if self.ping_cb:
self.ping_timer.cancel()
self.ping_cb(True)
self.ping_cb = None
self.logger.debug('Echo-Reply: %s',
codecs.encode(packet.information, 'hex'))
elif code == PCMPCode.Port_Closed:
if len(packet.information) == 2:
if self.on_port_closed:
closed_port, = struct.unpack('!H', packet.information)
self.on_port_closed(closed_port)
else:
self.logger.error(
'Remote peer sent malformed Port-Closed packet: %s',
codecs.encode(packet.information, 'hex'))
elif code == PCMPCode.Unknown_Code:
if len(packet.information) == 1:
self.logger.error('Remote peer sent Unknown-Code(%d) packet',
struct.unpack('B', packet.information)[0])
else:
self.logger.error(
'Remote peer sent malformed Unknown-Code packet: %s',
codecs.encode(packet.information, 'hex'))
else:
assert False, 'Known code not handled'
def ping(self, result_cb, attempts=3, timeout=1.0):
'''Test the link quality by sending Echo-Request packets and
listening for Echo-Reply packets from the remote peer.
The ping is performed asynchronously. The `result_cb` callable
will be called when the ping completes. It will be called with
a single positional argument: a truthy value if the remote peer
responded to the ping, or a falsy value if all ping attempts
timed out.
'''
if attempts < 1:
raise ValueError('attempts must be positive')
if timeout <= 0:
raise ValueError('timeout must be positive')
with self.ping_lock:
if self.ping_cb:
raise exceptions.AlreadyInProgressError(
'another ping is currently in progress')
self.ping_cb = result_cb
self.ping_attempts_remaining = attempts - 1
self.ping_timeout = timeout
self.send_echo_request(b'')
self.ping_timer = threading.Timer(timeout,
self._ping_timer_expired)
self.ping_timer.daemon = True
self.ping_timer.start()
def _ping_timer_expired(self):
with self.ping_lock:
if not self.ping_cb:
# The Echo-Reply packet must have won the race
return
if self.ping_attempts_remaining:
self.ping_attempts_remaining -= 1
self.send_echo_request(b'')
self.ping_timer = threading.Timer(self.ping_timeout,
self._ping_timer_expired)
self.ping_timer.daemon = True
self.ping_timer.start()
else:
self.ping_cb(False)
self.ping_cb = None
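# Hedged usage sketch (assumption; in practice the transports in this
# package bind PCMP themselves via the classmethod above):
#
#   pcmp_socket = PulseControlMessageProtocol.bind(transport)
#   pcmp_socket.ping(lambda ok: logger.info('peer reachable: %s', ok),
#                    attempts=3, timeout=1.0)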

File diff suppressed because it is too large

View file

@ -0,0 +1,68 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division
import math
class OnlineStatistics(object):
'''Calculates various statistical properties of a data series
iteratively, without keeping the data items in memory.
Available statistics:
- Count
- Min
- Max
- Mean
- Variance
- Standard Deviation
The variance calculation algorithm is taken from
https://en.wikipedia.org/w/index.php?title=Algorithms_for_calculating_variance&oldid=715886413#Online_algorithm
'''
def __init__(self):
self.count = 0
self.min = float('nan')
self.max = float('nan')
self.mean = 0.0
self.M2 = 0.0
def update(self, datum):
self.count += 1
if self.count == 1:
self.min = datum
self.max = datum
else:
self.min = min(self.min, datum)
self.max = max(self.max, datum)
delta = datum - self.mean
self.mean += delta / self.count
self.M2 += delta * (datum - self.mean)
@property
def variance(self):
if self.count < 2:
return float('nan')
return self.M2 / (self.count - 1)
@property
def stddev(self):
return math.sqrt(self.variance)
def __str__(self):
return 'min/avg/max/stddev = {:.03f}/{:.03f}/{:.03f}/{:.03f}'.format(
self.min, self.mean, self.max, self.stddev)
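# Worked example (added for illustration): feeding the values 1.0, 2.0, 3.0
# yields a sample variance of 1.0 without ever storing the series.
#
#   s = OnlineStatistics()
#   for x in (1.0, 2.0, 3.0):
#       s.update(x)
#   # s.mean == 2.0; s.variance == 1.0; s.stddev == 1.0
#   # str(s) == 'min/avg/max/stddev = 1.000/2.000/3.000/1.000'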

View file

@ -0,0 +1,628 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
import threading
import time
try:
import queue
except ImportError:
import Queue as queue
import construct
from . import exceptions
from . import logging as pulse2_logging
from . import pcmp
from . import ppp
from . import stats
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
class Socket(object):
'''A socket for sending and receiving packets over a single port
of a PULSE2 transport.
'''
def __init__(self, transport, port):
self.transport = transport
self.port = port
self.closed = False
self.receive_queue = queue.Queue()
def on_receive(self, packet):
self.receive_queue.put((True, packet))
def receive(self, block=True, timeout=None):
if self.closed:
raise exceptions.SocketClosed('I/O operation on closed socket')
try:
info_good, info = self.receive_queue.get(block, timeout)
if not info_good:
assert self.closed
raise exceptions.SocketClosed('Socket closed during receive')
return info
except queue.Empty:
raise exceptions.ReceiveQueueEmpty
def send(self, information):
if self.closed:
raise exceptions.SocketClosed('I/O operation on closed socket')
self.transport.send(self.port, information)
def close(self):
if self.closed:
return
self.closed = True
self.transport.unregister_socket(self.port)
# Wake up the thread blocking on a receive (if any) so that it
# can abort the receive quickly.
self.receive_queue.put((False, None))
@property
def mtu(self):
return self.transport.mtu
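# Hedged usage sketch (port number and payload are illustrative): a Socket
# handed out by a transport's open_socket() blocks in receive() until a
# packet arrives for its port, or raises ReceiveQueueEmpty on timeout.
#
#   sock = transport.open_socket(0x0a01)
#   sock.send(b'request')
#   try:
#       reply = sock.receive(timeout=5.0)
#   except exceptions.ReceiveQueueEmpty:
#       reply = None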
class TransportControlProtocol(ppp.ControlProtocol):
def __init__(self, interface, transport, ncp_protocol, display_name=None):
ppp.ControlProtocol.__init__(self, display_name)
self.interface = interface
self.ncp_protocol = ncp_protocol
self.transport = transport
def up(self):
ppp.ControlProtocol.up(self, self.interface.connect(self.ncp_protocol))
def this_layer_up(self, *args):
self.transport.this_layer_up()
def this_layer_down(self, *args):
self.transport.this_layer_down()
BestEffortPacket = construct.Struct('BestEffortPacket', # noqa
construct.UBInt16('port'),
construct.UBInt16('length'),
construct.Field('information', lambda ctx: ctx.length - 4),
ppp.OptionalGreedyString('padding'),
)
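# Wire-format note (added for clarity; values are illustrative): a
# best-effort packet is a 2-byte port, a 2-byte length covering the 4-byte
# header plus payload, the payload itself, then optional padding.  For
# example, port 0x0a01 carrying b'hi' serializes to b'\x0a\x01\x00\x06hi'.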
class BestEffortTransportBase(object):
def __init__(self, interface, link_mtu):
self.logger = pulse2_logging.TaggedAdapter(
logger, {'tag': type(self).__name__})
self.sockets = {}
self.closed = False
self._mtu = link_mtu - 4
self.link_socket = interface.connect(self.PROTOCOL_NUMBER)
self.link_socket.on_packet = self.packet_received
def send(self, port, information):
if len(information) > self.mtu:
raise ValueError('Packet length (%d) exceeds transport MTU (%d)' % (
len(information), self.mtu))
packet = BestEffortPacket.build(construct.Container(
port=port, length=len(information)+4,
information=information, padding=b''))
self.link_socket.send(packet)
def packet_received(self, packet):
if self.closed:
self.logger.warning('Received packet on closed transport')
return
try:
fields = BestEffortPacket.parse(packet)
except (construct.ConstructError, ValueError):
self.logger.exception('Received malformed packet')
return
if len(fields.information) + 4 != fields.length:
self.logger.error('Received truncated or corrupt packet '
'(expected %d, got %d data bytes)',
fields.length-4, len(fields.information))
return
if fields.port in self.sockets:
self.sockets[fields.port].on_receive(fields.information)
else:
self.logger.warning('Received packet for unopened port %04X',
fields.port)
def open_socket(self, port, factory=Socket):
if self.closed:
raise ValueError('Cannot open socket on closed transport')
if port in self.sockets and not self.sockets[port].closed:
raise KeyError('Another socket is already opened '
'on port 0x%04x' % port)
socket = factory(self, port)
self.sockets[port] = socket
return socket
def unregister_socket(self, port):
del self.sockets[port]
def down(self):
'''Called by the Link when the link layer goes down.
This closes the Transport object. Once closed, the Transport
cannot be reopened.
'''
self.closed = True
self.close_all_sockets()
self.link_socket.close()
def close_all_sockets(self):
# A socket could try to unregister itself when closing, which
# would modify the sockets dict. Make a copy of the sockets
# collection before closing them so that we are not iterating
# over the dict when it could get modified.
for socket in list(self.sockets.values()):
socket.close()
self.sockets = {}
@property
def mtu(self):
return self._mtu
class BestEffortApplicationTransport(BestEffortTransportBase):
NCP_PROTOCOL_NUMBER = 0xBA29
PROTOCOL_NUMBER = 0x3A29
def __init__(self, interface, link_mtu):
BestEffortTransportBase.__init__(self, interface=interface,
link_mtu=link_mtu)
self.opened = threading.Event()
self.ncp = TransportControlProtocol(
interface=interface, transport=self,
ncp_protocol=self.NCP_PROTOCOL_NUMBER,
display_name='BestEffortControlProtocol')
self.ncp.up()
self.ncp.open()
def this_layer_up(self):
# We can't let PCMP bind itself using the public open_socket
# method as the method will block until self.opened is set, but
# it won't be set until we use PCMP Echo to test that the
# transport is ready to carry traffic. So we must manually bind
# the port without waiting.
self.pcmp = pcmp.PulseControlMessageProtocol(
self, pcmp.PulseControlMessageProtocol.PORT)
self.sockets[pcmp.PulseControlMessageProtocol.PORT] = self.pcmp
self.pcmp.on_port_closed = self.on_port_closed
self.pcmp.ping(self._ping_done)
def _ping_done(self, ping_check_succeeded):
# Don't need to do anything in the success case as receiving
# any packet is enough to set the transport as Opened.
if not ping_check_succeeded:
self.logger.warning('Ping check failed. Restarting transport.')
self.ncp.restart()
def this_layer_down(self):
self.opened.clear()
self.close_all_sockets()
def send(self, *args, **kwargs):
if self.closed:
raise exceptions.TransportNotReady(
'I/O operation on closed transport')
if not self.ncp.is_Opened():
raise exceptions.TransportNotReady(
'I/O operation before transport is opened')
BestEffortTransportBase.send(self, *args, **kwargs)
def packet_received(self, packet):
if self.ncp.is_Opened():
self.opened.set()
BestEffortTransportBase.packet_received(self, packet)
else:
self.logger.warning('Received packet before the transport is open. '
'Discarding.')
def open_socket(self, port, timeout=30.0, factory=Socket):
if not self.opened.wait(timeout):
return None
return BestEffortTransportBase.open_socket(self, port, factory)
def down(self):
self.ncp.down()
BestEffortTransportBase.down(self)
def on_port_closed(self, closed_port):
self.logger.info('Remote peer says port 0x%04X is closed; '
'closing socket', closed_port)
try:
self.sockets[closed_port].close()
except KeyError:
self.logger.exception('No socket is open on port 0x%04X!',
closed_port)
class SimplexTransport(BestEffortTransportBase):
PROTOCOL_NUMBER = 0x5021
def __init__(self, interface):
BestEffortTransportBase.__init__(self, interface=interface, link_mtu=0)
def send(self, *args, **kwargs):
raise NotImplementedError
@property
def mtu(self):
return 0
ReliableInfoPacket = construct.Struct('ReliableInfoPacket', # noqa
# BitStructs are parsed MSBit-first
construct.EmbeddedBitStruct(
construct.BitField('sequence_number', 7), # N(S) in LAPB
construct.Const(construct.Bit('discriminator'), 0),
construct.BitField('ack_number', 7), # N(R) in LAPB
construct.Flag('poll'),
),
construct.UBInt16('port'),
construct.UBInt16('length'),
construct.Field('information', lambda ctx: ctx.length - 6),
ppp.OptionalGreedyString('padding'),
)
ReliableSupervisoryPacket = construct.BitStruct(
'ReliableSupervisoryPacket',
construct.Const(construct.Nibble('reserved'), 0b0000),
construct.Enum(construct.BitField('kind', 2), # noqa
RR=0b00,
RNR=0b01,
REJ=0b10,
),
construct.Const(construct.BitField('discriminator', 2), 0b01),
construct.BitField('ack_number', 7), # N(R) in LAPB
construct.Flag('poll'),
construct.Alias('final', 'poll'),
)
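# Worked example (added for illustration): an RR supervisory packet that
# acknowledges N(R)=5 with the poll bit set packs into two bytes,
# 0b00000001 (reserved=0000, kind=RR=00, discriminator=01) and
# 0b00001011 (ack_number=5, poll=1), i.e. b'\x01\x0b'.  Information
# packets instead have the least-significant bit of their first byte
# cleared.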
def build_reliable_info_packet(sequence_number, ack_number, poll,
port, information):
return ReliableInfoPacket.build(construct.Container(
sequence_number=sequence_number, ack_number=ack_number, poll=poll,
port=port, information=information, length=len(information)+6,
discriminator=None, padding=b''))
def build_reliable_supervisory_packet(
kind, ack_number, poll=False, final=False):
return ReliableSupervisoryPacket.build(construct.Container(
kind=kind, ack_number=ack_number, poll=poll or final,
final=None, reserved=None, discriminator=None))
class ReliableTransport(object):
'''The reliable transport protocol, also known as TRAIN.
The protocol is based on LAPB from ITU-T Recommendation X.25.
'''
NCP_PROTOCOL_NUMBER = 0xBA33
COMMAND_PROTOCOL_NUMBER = 0x3A33
RESPONSE_PROTOCOL_NUMBER = 0x3A35
MODULUS = 128
max_retransmits = 10 # N2 system parameter in LAPB
retransmit_timeout = 0.2 # T1 system parameter
def __init__(self, interface, link_mtu):
self.logger = pulse2_logging.TaggedAdapter(
logger, {'tag': type(self).__name__})
self.send_queue = queue.Queue()
self.opened = threading.Event()
self.closed = False
self.last_sent_packet = None
# The sequence number of the next in-sequence I-packet to be Tx'ed
self.send_variable = 0 # V(S) in LAPB
self.retransmit_count = 0
self.waiting_for_ack = False
self.last_ack_number = 0 # N(R) of the most recently received packet
self.transmit_lock = threading.RLock()
self.retransmit_timer = None
# The expected sequence number of the next received I-packet
self.receive_variable = 0 # V(R) in LAPB
self.sockets = {}
self._mtu = link_mtu - 6
self.command_socket = interface.connect(
self.COMMAND_PROTOCOL_NUMBER)
self.response_socket = interface.connect(
self.RESPONSE_PROTOCOL_NUMBER)
self.command_socket.on_packet = self.command_packet_received
self.response_socket.on_packet = self.response_packet_received
self.ncp = TransportControlProtocol(
interface=interface, transport=self,
ncp_protocol=self.NCP_PROTOCOL_NUMBER,
display_name='ReliableControlProtocol')
self.ncp.up()
self.ncp.open()
@property
def mtu(self):
return self._mtu
def reset_stats(self):
self.stats = {
'info_packets_sent': 0,
'info_packets_received': 0,
'retransmits': 0,
'out_of_order_packets': 0,
'round_trip_time': stats.OnlineStatistics(),
}
self.last_packet_sent_time = None
def this_layer_up(self):
self.send_variable = 0
self.receive_variable = 0
self.retransmit_count = 0
self.last_ack_number = 0
self.waiting_for_ack = False
self.reset_stats()
# We can't let PCMP bind itself using the public open_socket
# method as the method will block until self.opened is set, but
# it won't be set until the peer sends us a packet over the
# transport. But we want to bind the port without waiting.
self.pcmp = pcmp.PulseControlMessageProtocol(
self, pcmp.PulseControlMessageProtocol.PORT)
self.sockets[pcmp.PulseControlMessageProtocol.PORT] = self.pcmp
self.pcmp.on_port_closed = self.on_port_closed
# Send an RR command packet to elicit an RR response from the
# remote peer. Receiving a response from the peer confirms that
# the transport is ready to carry traffic, at which point we
# will allow applications to start opening sockets.
self.send_supervisory_command(kind='RR', poll=True)
self.start_retransmit_timer()
def this_layer_down(self):
self.opened.clear()
if self.retransmit_timer:
self.retransmit_timer.cancel()
self.retransmit_timer = None
self.close_all_sockets()
self.logger.info('Info packets sent=%d retransmits=%d',
self.stats['info_packets_sent'],
self.stats['retransmits'])
self.logger.info('Info packets received=%d out-of-order=%d',
self.stats['info_packets_received'],
self.stats['out_of_order_packets'])
self.logger.info('Round-trip %s ms', self.stats['round_trip_time'])
def open_socket(self, port, timeout=30.0, factory=Socket):
if self.closed:
raise ValueError('Cannot open socket on closed transport')
if port in self.sockets and not self.sockets[port].closed:
raise KeyError('Another socket is already opened '
'on port 0x%04x' % port)
if not self.opened.wait(timeout):
return None
socket = factory(self, port)
self.sockets[port] = socket
return socket
def unregister_socket(self, port):
del self.sockets[port]
def down(self):
self.closed = True
self.close_all_sockets()
self.command_socket.close()
self.response_socket.close()
self.ncp.down()
def close_all_sockets(self):
for socket in list(self.sockets.values()):
socket.close()
self.sockets = {}
def on_port_closed(self, closed_port):
self.logger.info('Remote peer says port 0x%04X is closed; '
'closing socket', closed_port)
try:
self.sockets[closed_port].close()
except KeyError:
self.logger.exception('No socket is open on port 0x%04X!',
closed_port)
def _send_info_packet(self, port, information):
packet = build_reliable_info_packet(
sequence_number=self.send_variable,
ack_number=self.receive_variable,
poll=True, port=port, information=information)
self.command_socket.send(packet)
self.stats['info_packets_sent'] += 1
self.last_packet_sent_time = time.time()
def send(self, port, information):
if self.closed:
raise exceptions.TransportNotReady(
'I/O operation on closed transport')
if not self.opened.is_set():
raise exceptions.TransportNotReady(
'Attempted to send a packet while the reliable transport '
'is not open')
if len(information) > self.mtu:
raise ValueError('Packet length (%d) exceeds transport MTU (%d)' % (
len(information), self.mtu))
self.send_queue.put((port, information))
self.pump_send_queue()
def process_ack(self, ack_number):
with self.transmit_lock:
if not self.waiting_for_ack:
# Could be in the timer recovery condition (waiting for
# a response to an RR Poll command). This is a bit
# hacky and should probably be changed to use an
# explicit state machine when this transport is
# extended to support Go-Back-N ARQ.
if self.retransmit_timer:
self.retransmit_timer.cancel()
self.retransmit_timer = None
self.retransmit_count = 0
if (ack_number - 1) % self.MODULUS == self.send_variable:
if self.retransmit_timer:
self.retransmit_timer.cancel()
self.retransmit_timer = None
self.retransmit_count = 0
self.waiting_for_ack = False
self.send_variable = (self.send_variable + 1) % self.MODULUS
if self.last_packet_sent_time:
self.stats['round_trip_time'].update(
(time.time() - self.last_packet_sent_time) * 1000)
def pump_send_queue(self):
with self.transmit_lock:
if not self.waiting_for_ack:
try:
port, information = self.send_queue.get_nowait()
self.last_sent_packet = (port, information)
self.waiting_for_ack = True
self._send_info_packet(port, information)
self.start_retransmit_timer()
except queue.Empty:
pass
def start_retransmit_timer(self):
if self.retransmit_timer:
self.retransmit_timer.cancel()
self.retransmit_timer = threading.Timer(
self.retransmit_timeout,
self.retransmit_timeout_expired)
self.retransmit_timer.daemon = True
self.retransmit_timer.start()
def retransmit_timeout_expired(self):
with self.transmit_lock:
self.retransmit_count += 1
if self.retransmit_count < self.max_retransmits:
self.stats['retransmits'] += 1
if self.last_sent_packet:
self._send_info_packet(*self.last_sent_packet)
else:
# No info packet to retransmit; must be an RR command
# that needs to be retransmitted.
self.send_supervisory_command(kind='RR', poll=True)
self.start_retransmit_timer()
else:
self.logger.warning('Reached maximum number of retransmit '
'attempts')
self.ncp.restart()
def send_supervisory_command(self, kind, poll=False):
with self.transmit_lock:
command = build_reliable_supervisory_packet(
kind=kind, poll=poll, ack_number=self.receive_variable)
self.command_socket.send(command)
def send_supervisory_response(self, kind, final=False):
with self.transmit_lock:
response = build_reliable_supervisory_packet(
kind=kind, final=final, ack_number=self.receive_variable)
self.response_socket.send(response)
def command_packet_received(self, packet):
if not self.ncp.is_Opened():
self.logger.warning('Received command packet before transport '
'is open. Discarding.')
return
# Information packets have the LSBit of the first byte cleared.
is_info = (bytearray(packet[0])[0] & 0b1) == 0
try:
if is_info:
fields = ReliableInfoPacket.parse(packet)
else:
fields = ReliableSupervisoryPacket.parse(packet)
except (construct.ConstructError, ValueError):
self.logger.exception('Received malformed command packet')
self.ncp.restart()
return
self.opened.set()
if is_info:
if fields.sequence_number == self.receive_variable:
self.receive_variable = (
self.receive_variable + 1) % self.MODULUS
self.stats['info_packets_received'] += 1
if len(fields.information) + 6 == fields.length:
if fields.port in self.sockets:
self.sockets[fields.port].on_receive(
fields.information)
else:
self.logger.warning(
'Received packet on closed port %04X',
fields.port)
else:
self.logger.error(
'Received truncated or corrupt info packet '
'(expected %d data bytes, got %d)',
fields.length-6, len(fields.information))
else:
self.stats['out_of_order_packets'] += 1
self.send_supervisory_response(kind='RR', final=fields.poll)
else:
if fields.kind not in ('RR', 'REJ'):
self.logger.error('Received a %s command packet, which is not '
'yet supported by this implementation',
fields.kind)
# Pretend it is an RR packet
self.process_ack(fields.ack_number)
if fields.poll:
self.send_supervisory_response(kind='RR', final=True)
self.pump_send_queue()
def response_packet_received(self, packet):
if not self.ncp.is_Opened():
self.logger.error(
'Received response packet before transport is open. '
'Discarding.')
return
# Information packets cannot be responses; we only need to
# handle receiving Supervisory packets.
try:
fields = ReliableSupervisoryPacket.parse(packet)
except (construct.ConstructError, ValueError):
self.logger.exception('Received malformed response packet')
self.ncp.restart()
return
self.opened.set()
self.process_ack(fields.ack_number)
self.pump_send_queue()
if fields.kind not in ('RR', 'REJ'):
self.logger.error('Received a %s response packet, which is not '
'yet supported by this implementation.',
fields.kind)

View file

@ -0,0 +1,60 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Always prefer setuptools over distutils
from setuptools import setup, find_packages
# To use a consistent encoding
from codecs import open
from os import path
import sys
here = path.abspath(path.dirname(__file__))
# Get the long description from the README file
with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
long_description = f.read()
requires = [
'cobs',
'construct>=2.5.3,<2.8',
'pyserial>=2.7,<3',
'transitions>=0.4.0',
]
test_requires = []
if sys.version_info < (3, 3, 0):
test_requires.append('mock>=2.0.0')
if sys.version_info < (3, 4, 0):
requires.append('enum34')
setup(
name='pebble.pulse2',
version='0.0.7',
description='Python tools for connecting to PULSEv2 links',
long_description=long_description,
url='https://github.com/pebble/pulse2',
author='Pebble Technology Corporation',
author_email='cory@pebble.com',
packages=find_packages(exclude=['contrib', 'docs', 'tests']),
    namespace_packages=['pebble'],
install_requires=requires,
extras_require={
'test': test_requires,
},
    test_suite='tests',
)
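# Installation note (added; standard setuptools behaviour): the optional
# test dependencies declared above install with the 'test' extra, e.g.
#
#   pip install .[test]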

View file

@ -0,0 +1,14 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,60 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class FakeTimer(object):
TIMERS = []
def __init__(self, interval, function):
self.interval = interval
self.function = function
self.started = False
self.expired = False
self.cancelled = False
type(self).TIMERS.append(self)
def __repr__(self):
state_flags = ''.join([
'S' if self.started else 'N',
'X' if self.expired else '.',
'C' if self.cancelled else '.'])
return '<FakeTimer({}, {}) {} at {:#x}>'.format(
self.interval, self.function, state_flags, id(self))
def start(self):
if self.started:
raise RuntimeError("threads can only be started once")
self.started = True
def cancel(self):
self.cancelled = True
def expire(self):
'''Simulate the timeout expiring.'''
assert self.started, 'timer not yet started'
assert not self.expired, 'timer can only expire once'
self.expired = True
self.function()
@property
def is_active(self):
return self.started and not self.expired and not self.cancelled
@classmethod
def clear_timer_list(cls):
cls.TIMERS = []
@classmethod
def get_active_timers(cls):
return [t for t in cls.TIMERS if t.is_active]
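# Hedged example (mirrors how the test modules in this package use
# FakeTimer; `uut.start_some_timer()` is a hypothetical call):
#
#   with mock.patch('threading.Timer', new=FakeTimer):
#       uut.start_some_timer()
#       FakeTimer.TIMERS[-1].expire()   # simulate the timeout firing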

View file

@ -0,0 +1,156 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import unittest
from pebble.pulse2 import framing
class TestEncodeFrame(unittest.TestCase):
def test_empty_frame(self):
# CRC-32 of nothing is 0
# COBS encoding of b'\0\0\0\0' is b'\x01\x01\x01\x01\x01' (5 bytes)
self.assertEqual(framing.encode_frame(b''),
b'\x55\x01\x01\x01\x01\x01\x55')
def test_simple_data(self):
self.assertEqual(framing.encode_frame(b'abcdefg'),
b'\x55\x0cabcdefg\xa6\x6a\x2a\x31\x55')
def test_flag_in_datagram(self):
# ASCII 'U' is 0x55 hex
self.assertEqual(framing.encode_frame(b'QUACK'),
b'\x55\x0aQ\0ACK\xdf\x8d\x80\x74\x55')
def test_flag_in_fcs(self):
# crc32(b'R') -> 0x5767df55
# Since there is an \x55 byte in the FCS, it must be substituted,
# just like when that byte value is present in the datagram itself.
self.assertEqual(framing.encode_frame(b'R'),
b'\x55\x06R\0\xdf\x67\x57\x55')
class TestFrameSplitter(unittest.TestCase):
def setUp(self):
self.splitter = framing.FrameSplitter()
def test_basic_functionality(self):
self.splitter.write(b'\x55abcdefg\x55foobar\x55asdf\x55')
self.assertEqual(list(self.splitter),
[b'abcdefg', b'foobar', b'asdf'])
def test_wait_for_sync(self):
self.splitter.write(b'garbage data\x55frame 1\x55')
self.assertEqual(list(self.splitter), [b'frame 1'])
def test_doubled_flags(self):
self.splitter.write(b'\x55abcd\x55\x55efgh\x55')
self.assertEqual(list(self.splitter), [b'abcd', b'efgh'])
def test_multiple_writes(self):
self.splitter.write(b'\x55ab')
self.assertEqual(list(self.splitter), [])
self.splitter.write(b'cd\x55')
self.assertEqual(list(self.splitter), [b'abcd'])
def test_lots_of_writes(self):
for char in b'\x55abcd\x55ef':
self.splitter.write(bytearray([char]))
self.assertEqual(list(self.splitter), [b'abcd'])
def test_iteration_pops_frames(self):
self.splitter.write(b'\x55frame 1\x55frame 2\x55frame 3\x55')
self.assertEqual(next(iter(self.splitter)), b'frame 1')
self.assertEqual(list(self.splitter), [b'frame 2', b'frame 3'])
def test_stopiteration_latches(self):
# The iterator protocol requires that once an iterator raises
# StopIteration, it must continue to do so for all subsequent calls
# to its next() method.
self.splitter.write(b'\x55frame 1\x55')
iterator = iter(self.splitter)
self.assertEqual(next(iterator), b'frame 1')
        with self.assertRaises(StopIteration):
            next(iterator)
        with self.assertRaises(StopIteration):
            next(iterator)
self.splitter.write(b'\x55frame 2\x55')
with self.assertRaises(StopIteration):
next(iterator)
self.assertEqual(list(self.splitter), [b'frame 2'])
def test_max_frame_length(self):
splitter = framing.FrameSplitter(max_frame_length=6)
splitter.write(
b'\x5512345\x55123456\x551234567\x551234\x5512345678\x55')
self.assertEqual(list(splitter), [b'12345', b'123456', b'1234'])
def test_dynamic_max_length_1(self):
self.splitter.write(b'\x5512345')
self.splitter.max_frame_length = 6
self.splitter.write(b'6\x551234567\x551234\x55')
self.assertEqual(list(self.splitter), [b'123456', b'1234'])
def test_dynamic_max_length_2(self):
self.splitter.write(b'\x551234567')
self.splitter.max_frame_length = 6
self.splitter.write(b'89\x55123456\x55')
self.assertEqual(list(self.splitter), [b'123456'])
class TestDecodeTransparency(unittest.TestCase):
def test_easy_decode(self):
self.assertEqual(framing.decode_transparency(b'\x06abcde'), b'abcde')
def test_escaped_flag(self):
self.assertEqual(framing.decode_transparency(b'\x06Q\0ACK'), b'QUACK')
def test_flag_byte_in_frame(self):
with self.assertRaises(framing.DecodeError):
framing.decode_transparency(b'\x06ab\x55de')
def test_truncated_cobs_block(self):
with self.assertRaises(framing.DecodeError):
framing.decode_transparency(b'\x0aabc')
class TestStripFCS(unittest.TestCase):
def test_frame_too_short(self):
with self.assertRaises(framing.CorruptFrame):
framing.strip_fcs(b'abcd')
def test_good_fcs(self):
self.assertEqual(framing.strip_fcs(b'abcd\x11\xcd\x82\xed'), b'abcd')
def test_frame_corrupted(self):
with self.assertRaises(framing.CorruptFrame):
framing.strip_fcs(b'abce\x11\xcd\x82\xed')
def test_fcs_corrupted(self):
with self.assertRaises(framing.CorruptFrame):
framing.strip_fcs(b'abcd\x13\xcd\x82\xed')
class TestDecodeFrame(unittest.TestCase):
def test_it_works(self):
# Not much to test; decode_frame is just chained decode_transparency
# with strip_fcs, and both of those have already been tested separately.
self.assertEqual(framing.decode_frame(b'\x0aQ\0ACK\xdf\x8d\x80t'),
b'QUACK')

View file

@ -0,0 +1,261 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import threading
import time
import unittest
try:
from unittest import mock
except ImportError:
import mock
try:
import queue
except ImportError:
import Queue as queue
from pebble.pulse2 import exceptions, framing, link, ppp
class FakeIOStream(object):
def __init__(self):
self.read_queue = queue.Queue()
self.write_queue = queue.Queue()
self.closed = False
def read(self, length):
if self.closed:
raise IOError('I/O operation on closed FakeIOStream')
try:
return self.read_queue.get(timeout=0.001)
except queue.Empty:
return b''
def write(self, data):
if self.closed:
raise IOError('I/O operation on closed FakeIOStream')
self.write_queue.put(data)
def close(self):
self.closed = True
def pop_all_written_data(self):
data = []
try:
while True:
data.append(self.write_queue.get_nowait())
except queue.Empty:
pass
return data
class TestInterface(unittest.TestCase):
def setUp(self):
self.iostream = FakeIOStream()
self.uut = link.Interface(self.iostream)
self.addCleanup(self.iostream.close)
# Speed up test execution by overriding the LCP timeout
self.uut.lcp.restart_timeout = 0.001
self.uut.lcp.ping = self.fake_ping
self.ping_should_succeed = True
def fake_ping(self, cb, *args, **kwargs):
cb(self.ping_should_succeed)
def test_send_packet(self):
self.uut.send_packet(0x8889, b'data')
self.assertIn(framing.encode_frame(ppp.encapsulate(0x8889, b'data')),
self.iostream.pop_all_written_data())
def test_connect_returns_socket(self):
self.assertIsNotNone(self.uut.connect(0xf0f1))
def test_send_from_socket(self):
socket = self.uut.connect(0xf0f1)
socket.send(b'data')
self.assertIn(framing.encode_frame(ppp.encapsulate(0xf0f1, b'data')),
self.iostream.pop_all_written_data())
def test_interface_closing_closes_sockets_and_iostream(self):
socket1 = self.uut.connect(0xf0f1)
socket2 = self.uut.connect(0xf0f3)
self.uut.close()
self.assertTrue(socket1.closed)
self.assertTrue(socket2.closed)
self.assertTrue(self.iostream.closed)
def test_iostream_closing_closes_interface_and_sockets(self):
socket = self.uut.connect(0xf0f1)
self.iostream.close()
time.sleep(0.01) # Wait for receive thread to notice
self.assertTrue(self.uut.closed)
self.assertTrue(socket.closed)
def test_opening_two_sockets_on_same_protocol_is_an_error(self):
socket1 = self.uut.connect(0xf0f1)
with self.assertRaisesRegexp(ValueError, 'socket is already bound'):
socket2 = self.uut.connect(0xf0f1)
def test_closing_socket_allows_another_to_be_opened(self):
socket1 = self.uut.connect(0xf0f1)
socket1.close()
socket2 = self.uut.connect(0xf0f1)
self.assertIsNot(socket1, socket2)
def test_sending_from_closed_interface_is_an_error(self):
self.uut.close()
with self.assertRaisesRegexp(ValueError, 'closed interface'):
self.uut.send_packet(0x8889, b'data')
def test_get_link_returns_None_when_lcp_is_down(self):
self.assertIsNone(self.uut.get_link(timeout=0))
def test_get_link_from_closed_interface_is_an_error(self):
self.uut.close()
with self.assertRaisesRegexp(ValueError, 'closed interface'):
self.uut.get_link(timeout=0)
def test_get_link_when_lcp_is_up(self):
self.uut.on_link_up()
self.assertIsNotNone(self.uut.get_link(timeout=0))
def test_link_object_is_closed_when_lcp_goes_down(self):
self.uut.on_link_up()
link = self.uut.get_link(timeout=0)
self.assertFalse(link.closed)
self.uut.on_link_down()
self.assertTrue(link.closed)
def test_lcp_bouncing_doesnt_reopen_old_link_object(self):
self.uut.on_link_up()
link1 = self.uut.get_link(timeout=0)
self.uut.on_link_down()
self.uut.on_link_up()
link2 = self.uut.get_link(timeout=0)
self.assertTrue(link1.closed)
self.assertFalse(link2.closed)
def test_close_gracefully_shuts_down_lcp(self):
self.uut.lcp.receive_configure_request_acceptable(0, b'')
self.uut.lcp.receive_configure_ack()
self.uut.close()
self.assertTrue(self.uut.lcp.is_finished.is_set())
def test_ping_failure_triggers_lcp_restart(self):
self.ping_should_succeed = False
self.uut.lcp.restart = mock.Mock()
self.uut.on_link_up()
self.assertIsNone(self.uut.get_link(timeout=0))
self.uut.lcp.restart.assert_called_once_with()
class TestInterfaceSocket(unittest.TestCase):
def setUp(self):
self.interface = mock.MagicMock()
self.uut = link.InterfaceSocket(self.interface, 0xf2f1)
def test_socket_is_not_closed_when_constructed(self):
self.assertFalse(self.uut.closed)
def test_send(self):
self.uut.send(b'data')
self.interface.send_packet.assert_called_once_with(0xf2f1, b'data')
def test_close_sets_socket_as_closed(self):
self.uut.close()
self.assertTrue(self.uut.closed)
def test_close_unregisters_socket_with_interface(self):
self.uut.close()
self.interface.unregister_socket.assert_called_once_with(0xf2f1)
def test_close_calls_on_close_handler(self):
on_close = mock.Mock()
self.uut.on_close = on_close
self.uut.close()
on_close.assert_called_once_with()
def test_send_after_close_is_an_error(self):
self.uut.close()
with self.assertRaises(exceptions.SocketClosed):
self.uut.send(b'data')
def test_handle_packet(self):
self.uut.on_packet = mock.Mock()
self.uut.handle_packet(b'data')
self.uut.on_packet.assert_called_once_with(b'data')
def test_handle_packet_does_not_call_on_packet_handler_after_close(self):
on_packet = mock.Mock()
self.uut.on_packet = on_packet
self.uut.close()
self.uut.handle_packet(b'data')
on_packet.assert_not_called()
def test_context_manager(self):
with self.uut as uut:
self.assertIs(self.uut, uut)
self.assertFalse(self.uut.closed)
self.assertTrue(self.uut.closed)
def test_close_is_idempotent(self):
on_close = mock.Mock()
self.uut.on_close = on_close
self.uut.close()
self.uut.close()
self.assertEqual(1, self.interface.unregister_socket.call_count)
self.assertEqual(1, on_close.call_count)
class TestLink(unittest.TestCase):
def setUp(self):
transports_patcher = mock.patch.dict(
link.Link.TRANSPORTS, {'fake': mock.Mock()}, clear=True)
transports_patcher.start()
self.addCleanup(transports_patcher.stop)
self.uut = link.Link(mock.Mock(), 1500)
def test_open_socket(self):
socket = self.uut.open_socket(
transport='fake', port=0xabcd, timeout=1.0)
self.uut.transports['fake'].open_socket.assert_called_once_with(
0xabcd, 1.0)
self.assertIs(socket, self.uut.transports['fake'].open_socket())
def test_down(self):
self.uut.down()
self.assertTrue(self.uut.closed)
self.uut.transports['fake'].down.assert_called_once_with()
def test_on_close_callback_when_going_down(self):
self.uut.on_close = mock.Mock()
self.uut.down()
self.uut.on_close.assert_called_once_with()
def test_open_socket_after_down_is_an_error(self):
self.uut.down()
with self.assertRaisesRegexp(ValueError, 'closed Link'):
self.uut.open_socket('fake', 0xabcd)
def test_open_socket_with_bad_transport_name(self):
with self.assertRaisesRegexp(KeyError, "Unknown transport 'bad'"):
self.uut.open_socket('bad', 0xabcd)

View file

@ -0,0 +1,168 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
try:
from unittest import mock
except ImportError:
import mock
from pebble.pulse2 import pcmp
from .fake_timer import FakeTimer
class TestPCMP(unittest.TestCase):
def setUp(self):
self.uut = pcmp.PulseControlMessageProtocol(mock.Mock(), 1)
def test_close_unregisters_the_socket(self):
self.uut.close()
self.uut.transport.unregister_socket.assert_called_once_with(1)
def test_close_is_idempotent(self):
self.uut.close()
self.uut.close()
self.assertEqual(1, self.uut.transport.unregister_socket.call_count)
def test_send_unknown_code(self):
self.uut.send_unknown_code(42)
self.uut.transport.send.assert_called_once_with(1, b'\x82\x2a')
def test_send_echo_request(self):
self.uut.send_echo_request(b'abcdefg')
self.uut.transport.send.assert_called_once_with(1, b'\x01abcdefg')
def test_send_echo_reply(self):
self.uut.send_echo_reply(b'abcdefg')
self.uut.transport.send.assert_called_once_with(1, b'\x02abcdefg')
def test_on_receive_empty_packet(self):
self.uut.on_receive(b'')
self.uut.transport.send.assert_not_called()
def test_on_receive_message_with_unknown_code(self):
self.uut.on_receive(b'\x00')
self.uut.transport.send.assert_called_once_with(1, b'\x82\x00')
def test_on_receive_malformed_unknown_code_message_1(self):
self.uut.on_receive(b'\x82')
self.uut.transport.send.assert_not_called()
def test_on_receive_malformed_unknown_code_message_2(self):
self.uut.on_receive(b'\x82\x00\x01')
self.uut.transport.send.assert_not_called()
def test_on_receive_discard_request(self):
self.uut.on_receive(b'\x03')
self.uut.transport.send.assert_not_called()
def test_on_receive_discard_request_with_data(self):
self.uut.on_receive(b'\x03asdfasdfasdf')
self.uut.transport.send.assert_not_called()
def test_on_receive_echo_request(self):
self.uut.on_receive(b'\x01')
self.uut.transport.send.assert_called_once_with(1, b'\x02')
def test_on_receive_echo_request_with_data(self):
self.uut.on_receive(b'\x01a')
self.uut.transport.send.assert_called_once_with(1, b'\x02a')
def test_on_receive_echo_reply(self):
self.uut.on_receive(b'\x02')
self.uut.transport.send.assert_not_called()
def test_on_receive_echo_reply_with_data(self):
self.uut.on_receive(b'\x02abc')
self.uut.transport.send.assert_not_called()
def test_on_receive_port_closed_with_no_handler(self):
self.uut.on_receive(b'\x81\xab\xcd')
self.uut.transport.send.assert_not_called()
def test_on_receive_port_closed(self):
self.uut.on_port_closed = mock.Mock()
self.uut.on_receive(b'\x81\xab\xcd')
self.uut.on_port_closed.assert_called_once_with(0xabcd)
def test_on_receive_malformed_port_closed_message_1(self):
self.uut.on_port_closed = mock.Mock()
self.uut.on_receive(b'\x81\xab')
self.uut.on_port_closed.assert_not_called()
def test_on_receive_malformed_port_closed_message_2(self):
self.uut.on_port_closed = mock.Mock()
self.uut.on_receive(b'\x81\xab\xcd\xef')
self.uut.on_port_closed.assert_not_called()
class TestPing(unittest.TestCase):
def setUp(self):
FakeTimer.clear_timer_list()
timer_patcher = mock.patch('threading.Timer', new=FakeTimer)
timer_patcher.start()
self.addCleanup(timer_patcher.stop)
self.uut = pcmp.PulseControlMessageProtocol(mock.Mock(), 1)
def test_successful_ping(self):
cb = mock.Mock()
self.uut.ping(cb)
self.uut.on_receive(b'\x02')
cb.assert_called_once_with(True)
self.assertFalse(FakeTimer.get_active_timers())
def test_ping_succeeds_after_retry(self):
cb = mock.Mock()
self.uut.ping(cb, attempts=2)
FakeTimer.TIMERS[-1].expire()
self.uut.on_receive(b'\x02')
cb.assert_called_once_with(True)
self.assertFalse(FakeTimer.get_active_timers())
def test_ping_succeeds_after_multiple_retries(self):
cb = mock.Mock()
self.uut.ping(cb, attempts=3)
timer1 = FakeTimer.TIMERS[-1]
timer1.expire()
timer2 = FakeTimer.TIMERS[-1]
self.assertIsNot(timer1, timer2)
timer2.expire()
self.uut.on_receive(b'\x02')
cb.assert_called_once_with(True)
self.assertFalse(FakeTimer.get_active_timers())
def test_failed_ping(self):
cb = mock.Mock()
self.uut.ping(cb, attempts=1)
FakeTimer.TIMERS[-1].expire()
cb.assert_called_once_with(False)
self.assertFalse(FakeTimer.get_active_timers())
def test_ping_fails_after_multiple_retries(self):
cb = mock.Mock()
self.uut.ping(cb, attempts=3)
for _ in range(3):
FakeTimer.TIMERS[-1].expire()
cb.assert_called_once_with(False)
self.assertFalse(FakeTimer.get_active_timers())
def test_socket_close_aborts_ping(self):
cb = mock.Mock()
self.uut.ping(cb, attempts=3)
self.uut.close()
cb.assert_not_called()
self.assertFalse(FakeTimer.get_active_timers())

View file

@ -0,0 +1,589 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import unittest
try:
from unittest import mock
except ImportError:
import mock
import construct
from pebble.pulse2 import ppp, exceptions
from .fake_timer import FakeTimer
from . import timer_helper
class TestPPPEncapsulation(unittest.TestCase):
def test_ppp_encapsulate(self):
self.assertEqual(ppp.encapsulate(0xc021, b'Information'),
b'\xc0\x21Information')
class TestPPPUnencapsulate(unittest.TestCase):
def test_ppp_unencapsulate(self):
protocol, information = ppp.unencapsulate(b'\xc0\x21Information')
self.assertEqual((protocol, information), (0xc021, b'Information'))
def test_unencapsulate_empty_frame(self):
with self.assertRaises(ppp.UnencapsulationError):
ppp.unencapsulate(b'')
def test_unencapsulate_too_short_frame(self):
with self.assertRaises(ppp.UnencapsulationError):
ppp.unencapsulate(b'\x21')
def test_unencapsulate_empty_information(self):
protocol, information = ppp.unencapsulate(b'\xc0\x21')
self.assertEqual((protocol, information), (0xc021, b''))
class TestConfigurationOptionsParser(unittest.TestCase):
def test_no_options(self):
options = ppp.OptionList.parse(b'')
self.assertEqual(len(options), 0)
def test_one_empty_option(self):
options = ppp.OptionList.parse(b'\xaa\x02')
self.assertEqual(len(options), 1)
self.assertEqual(options[0].type, 0xaa)
self.assertEqual(options[0].data, b'')
def test_one_option_with_length(self):
options = ppp.OptionList.parse(b'\xab\x07Data!')
self.assertEqual((0xab, b'Data!'), options[0])
def test_multiple_options_empty_first(self):
options = ppp.OptionList.parse(b'\x22\x02\x23\x03a\x21\x04ab')
self.assertEqual([(0x22, b''), (0x23, b'a'), (0x21, b'ab')], options)
def test_multiple_options_dataful_first(self):
options = ppp.OptionList.parse(b'\x31\x08option\x32\x02')
self.assertEqual([(0x31, b'option'), (0x32, b'')], options)
def test_option_with_length_too_short(self):
with self.assertRaises(ppp.ParseError):
ppp.OptionList.parse(b'\x41\x01')
def test_option_list_with_malformed_option(self):
with self.assertRaises(ppp.ParseError):
ppp.OptionList.parse(b'\x0a\x02\x0b\x01\x0c\x03a')
def test_truncated_terminal_option(self):
with self.assertRaises(ppp.ParseError):
ppp.OptionList.parse(b'\x61\x02\x62\x03a\x63\x0ccandleja')
class TestConfigurationOptionsBuilder(unittest.TestCase):
def test_no_options(self):
serialized = ppp.OptionList.build([])
self.assertEqual(b'', serialized)
def test_one_empty_option(self):
serialized = ppp.OptionList.build([ppp.Option(0xaa, b'')])
self.assertEqual(b'\xaa\x02', serialized)
def test_one_option_with_length(self):
serialized = ppp.OptionList.build([ppp.Option(0xbb, b'Data!')])
self.assertEqual(b'\xbb\x07Data!', serialized)
def test_two_options(self):
serialized = ppp.OptionList.build([
ppp.Option(0xcc, b'foo'), ppp.Option(0xdd, b'xyzzy')])
self.assertEqual(b'\xcc\x05foo\xdd\x07xyzzy', serialized)
class TestLCPEnvelopeParsing(unittest.TestCase):
def test_packet_no_padding(self):
parsed = ppp.LCPEncapsulation.parse(b'\x01\xab\x00\x0aabcdef')
self.assertEqual(parsed.code, 1)
self.assertEqual(parsed.identifier, 0xab)
self.assertEqual(parsed.data, b'abcdef')
self.assertEqual(parsed.padding, b'')
def test_padding(self):
parsed = ppp.LCPEncapsulation.parse(b'\x01\xab\x00\x0aabcdefpadding')
self.assertEqual(parsed.data, b'abcdef')
self.assertEqual(parsed.padding, b'padding')
def test_truncated_packet(self):
with self.assertRaises(ppp.ParseError):
ppp.LCPEncapsulation.parse(b'\x01\xab\x00\x0aabcde')
def test_bogus_length(self):
with self.assertRaises(ppp.ParseError):
ppp.LCPEncapsulation.parse(b'\x01\xbc\x00\x03')
def test_empty_data(self):
parsed = ppp.LCPEncapsulation.parse(b'\x03\x01\x00\x04')
self.assertEqual((3, 1, b'', b''), parsed)
class TestLCPEnvelopeBuilder(unittest.TestCase):
def test_build_empty_data(self):
serialized = ppp.LCPEncapsulation.build(1, 0xfe, b'')
self.assertEqual(b'\x01\xfe\x00\x04', serialized)
def test_build_with_data(self):
serialized = ppp.LCPEncapsulation.build(3, 0x2a, b'Hello, world!')
self.assertEqual(b'\x03\x2a\x00\x11Hello, world!', serialized)
class TestProtocolRejectParsing(unittest.TestCase):
def test_protocol_and_info(self):
self.assertEqual((0xabcd, b'asdfasdf'),
ppp.ProtocolReject.parse(b'\xab\xcdasdfasdf'))
def test_empty_info(self):
self.assertEqual((0xf00d, b''),
ppp.ProtocolReject.parse(b'\xf0\x0d'))
def test_truncated_packet(self):
with self.assertRaises(ppp.ParseError):
ppp.ProtocolReject.parse(b'\xab')
class TestMagicNumberAndDataParsing(unittest.TestCase):
def test_magic_and_data(self):
self.assertEqual(
(0xabcdef01, b'datadata'),
ppp.MagicNumberAndData.parse(b'\xab\xcd\xef\x01datadata'))
def test_magic_no_data(self):
self.assertEqual(
(0xfeedface, b''),
ppp.MagicNumberAndData.parse(b'\xfe\xed\xfa\xce'))
def test_truncated_packet(self):
with self.assertRaises(ppp.ParseError):
ppp.MagicNumberAndData.parse(b'abc')
class TestMagicNumberAndDataBuilder(unittest.TestCase):
def test_build_empty_data(self):
serialized = ppp.MagicNumberAndData.build(0x12345678, b'')
self.assertEqual(b'\x12\x34\x56\x78', serialized)
def test_build_with_data(self):
serialized = ppp.MagicNumberAndData.build(0xabcdef01, b'foobar')
self.assertEqual(b'\xab\xcd\xef\x01foobar', serialized)
def test_build_with_named_attributes(self):
serialized = ppp.MagicNumberAndData.build(magic_number=0, data=b'abc')
self.assertEqual(b'\0\0\0\0abc', serialized)
class TestControlProtocolRestartTimer(unittest.TestCase):
def setUp(self):
FakeTimer.clear_timer_list()
timer_patcher = mock.patch('threading.Timer', new=FakeTimer)
timer_patcher.start()
self.addCleanup(timer_patcher.stop)
self.uut = ppp.ControlProtocol()
self.uut.timeout_retry = mock.Mock()
self.uut.timeout_giveup = mock.Mock()
self.uut.restart_count = 5
def test_timeout_event_called_if_generation_ids_match(self):
self.uut.restart_timer_expired(self.uut.restart_timer_generation_id)
self.uut.timeout_retry.assert_called_once_with()
def test_timeout_event_not_called_if_generation_ids_mismatch(self):
self.uut.restart_timer_expired(42)
self.uut.timeout_retry.assert_not_called()
self.uut.timeout_giveup.assert_not_called()
def test_timeout_event_not_called_after_stopped(self):
self.uut.start_restart_timer(1)
self.uut.stop_restart_timer()
FakeTimer.TIMERS[-1].expire()
self.uut.timeout_retry.assert_not_called()
self.uut.timeout_giveup.assert_not_called()
def test_timeout_event_not_called_from_old_timer_after_restart(self):
self.uut.start_restart_timer(1)
zombie_timer = FakeTimer.get_active_timers()[-1]
self.uut.start_restart_timer(1)
zombie_timer.expire()
self.uut.timeout_retry.assert_not_called()
self.uut.timeout_giveup.assert_not_called()
def test_timeout_event_called_only_once_after_restart(self):
self.uut.start_restart_timer(1)
self.uut.start_restart_timer(1)
for timer in FakeTimer.TIMERS:
timer.expire()
self.uut.timeout_retry.assert_called_once_with()
self.uut.timeout_giveup.assert_not_called()
class InstrumentedControlProtocol(ppp.ControlProtocol):
methods_to_mock = (
'this_layer_up this_layer_down this_layer_started '
'this_layer_finished send_packet start_restart_timer '
'stop_restart_timer').split()
attributes_to_mock = ('restart_timer',)
def __init__(self):
ppp.ControlProtocol.__init__(self)
for method in self.methods_to_mock:
setattr(self, method, mock.Mock())
for attr in self.attributes_to_mock:
setattr(self, attr, mock.NonCallableMock())
class ControlProtocolTestMixin(object):
CONTROL_CODE_ENUM = ppp.ControlCode
def _map_control_code(self, code):
try:
return int(code)
except ValueError:
return self.CONTROL_CODE_ENUM[code].value
def assert_packet_sent(self, code, identifier, body=b''):
self.fsm.send_packet.assert_called_once_with(
ppp.LCPEncapsulation.build(
self._map_control_code(code), identifier, body))
self.fsm.send_packet.reset_mock()
def incoming_packet(self, code, identifier, body=b''):
self.fsm.packet_received(
ppp.LCPEncapsulation.build(self._map_control_code(code),
identifier, body))
class TestControlProtocolFSM(ControlProtocolTestMixin, unittest.TestCase):
def setUp(self):
self.addCleanup(timer_helper.cancel_all_timers)
self.fsm = InstrumentedControlProtocol()
def test_open_down(self):
self.fsm.open()
self.fsm.this_layer_started.assert_called_once_with()
self.fsm.this_layer_up.assert_not_called()
self.fsm.this_layer_down.assert_not_called()
self.fsm.this_layer_finished.assert_not_called()
def test_closed_up(self):
self.fsm.up(mock.Mock())
self.fsm.this_layer_up.assert_not_called()
self.fsm.this_layer_down.assert_not_called()
self.fsm.this_layer_started.assert_not_called()
self.fsm.this_layer_finished.assert_not_called()
def test_trivial_handshake(self):
self.fsm.open()
self.fsm.up(mock.Mock())
self.assert_packet_sent('Configure_Request', 0)
self.incoming_packet('Configure_Ack', 0)
self.incoming_packet('Configure_Request', 17)
self.assert_packet_sent('Configure_Ack', 17)
self.assertEqual('Opened', self.fsm.state)
self.assertTrue(self.fsm.this_layer_up.called)
self.assertEqual(self.fsm.restart_count, self.fsm.max_configure)
def test_terminate_cleanly(self):
self.test_trivial_handshake()
self.fsm.close()
self.fsm.this_layer_down.assert_called_once_with()
self.assert_packet_sent('Terminate_Request', 42)
def test_remote_terminate(self):
self.test_trivial_handshake()
self.incoming_packet('Terminate_Request', 42)
self.assert_packet_sent('Terminate_Ack', 42)
self.assertTrue(self.fsm.this_layer_down.called)
self.assertTrue(self.fsm.start_restart_timer.called)
self.fsm.this_layer_finished.assert_not_called()
self.fsm.restart_timer_expired(self.fsm.restart_timer_generation_id)
self.assertTrue(self.fsm.this_layer_finished.called)
self.assertEqual('Stopped', self.fsm.state)
def test_remote_rejects_configure_request_code(self):
self.fsm.open()
self.fsm.up(mock.Mock())
received_packet = self.fsm.send_packet.call_args[0][0]
self.assert_packet_sent('Configure_Request', 0)
self.incoming_packet('Code_Reject', 3, received_packet)
self.assertEqual('Stopped', self.fsm.state)
self.assertTrue(self.fsm.this_layer_finished.called)
def test_receive_extended_code(self):
self.fsm.handle_unknown_code = mock.Mock()
self.test_trivial_handshake()
self.incoming_packet(42, 11, b'Life, the universe and everything')
self.fsm.handle_unknown_code.assert_called_once_with(
42, 11, b'Life, the universe and everything')
def test_receive_unimplemented_code(self):
self.test_trivial_handshake()
self.incoming_packet(0x55, 0)
self.assert_packet_sent('Code_Reject', 0, b'\x55\0\0\x04')
def test_code_reject_truncates_rejected_packet(self):
self.test_trivial_handshake()
self.incoming_packet(0xaa, 0x20, b'a'*1496) # 1500-byte Info
self.assert_packet_sent('Code_Reject', 0,
b'\xaa\x20\x05\xdc' + b'a'*1492)
def test_code_reject_identifier_changes(self):
self.test_trivial_handshake()
self.incoming_packet(0xaa, 0)
self.assert_packet_sent('Code_Reject', 0, b'\xaa\0\0\x04')
self.incoming_packet(0xaa, 0)
self.assert_packet_sent('Code_Reject', 1, b'\xaa\0\0\x04')
# Local events: up, down, open, close
# Option negotiation: reject, nak
# Exceptional situations: catastrophic code-reject
# Restart negotiation after opening
# Remote Terminate-Req, -Ack at various points in the lifecycle
# Negotiation infinite loop
# Local side gives up on negotiation
# Corrupt packets received
class TestLCPReceiveEchoRequest(ControlProtocolTestMixin, unittest.TestCase):
CONTROL_CODE_ENUM = ppp.LCPCode
def setUp(self):
self.addCleanup(timer_helper.cancel_all_timers)
self.fsm = ppp.LinkControlProtocol(mock.Mock())
self.fsm.send_packet = mock.Mock()
self.fsm.state = 'Opened'
def send_echo_request(self, identifier=0, data=b'\0\0\0\0'):
result = self.fsm.handle_unknown_code(
ppp.LCPCode.Echo_Request.value, identifier, data)
self.assertIsNot(result, NotImplemented)
def test_echo_request_is_dropped_when_not_in_opened_state(self):
self.fsm.state = 'Ack-Sent'
self.send_echo_request()
self.fsm.send_packet.assert_not_called()
def test_echo_request_elicits_reply(self):
self.send_echo_request()
self.assert_packet_sent('Echo_Reply', 0, b'\0\0\0\0')
def test_echo_request_with_data_is_echoed_in_reply(self):
self.send_echo_request(5, b'\0\0\0\0datadata')
self.assert_packet_sent('Echo_Reply', 5, b'\0\0\0\0datadata')
def test_echo_request_missing_magic_number_field_is_dropped(self):
self.send_echo_request(data=b'')
self.fsm.send_packet.assert_not_called()
def test_echo_request_with_nonzero_magic_number_is_dropped(self):
self.send_echo_request(data=b'\0\0\0\x01')
self.fsm.send_packet.assert_not_called()
class TestLCPPing(ControlProtocolTestMixin, unittest.TestCase):
CONTROL_CODE_ENUM = ppp.LCPCode
def setUp(self):
FakeTimer.clear_timer_list()
timer_patcher = mock.patch('threading.Timer', new=FakeTimer)
timer_patcher.start()
self.addCleanup(timer_patcher.stop)
self.fsm = ppp.LinkControlProtocol(mock.Mock())
self.fsm.send_packet = mock.Mock()
self.fsm.state = 'Opened'
def respond_to_ping(self):
[echo_request_packet], _ = self.fsm.send_packet.call_args
self.assertEqual(b'\x09'[0], echo_request_packet[0])
echo_response_packet = b'\x0a' + echo_request_packet[1:]
self.fsm.packet_received(echo_response_packet)
def test_ping_when_lcp_is_not_opened_is_an_error(self):
cb = mock.Mock()
self.fsm.state = 'Ack-Rcvd'
with self.assertRaises(ppp.LinkStateError):
self.fsm.ping(cb)
cb.assert_not_called()
def test_zero_attempts_is_an_error(self):
with self.assertRaises(ValueError):
self.fsm.ping(mock.Mock(), attempts=0)
def test_negative_attempts_is_an_error(self):
with self.assertRaises(ValueError):
self.fsm.ping(mock.Mock(), attempts=-1)
def test_zero_timeout_is_an_error(self):
with self.assertRaises(ValueError):
self.fsm.ping(mock.Mock(), timeout=0)
def test_negative_timeout_is_an_error(self):
with self.assertRaises(ValueError):
self.fsm.ping(mock.Mock(), timeout=-0.1)
def test_straightforward_ping(self):
cb = mock.Mock()
self.fsm.ping(cb)
cb.assert_not_called()
self.assertEqual(1, self.fsm.send_packet.call_count)
self.respond_to_ping()
cb.assert_called_once_with(True)
def test_one_timeout_before_responding(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=2)
FakeTimer.TIMERS[-1].expire()
cb.assert_not_called()
self.assertEqual(2, self.fsm.send_packet.call_count)
self.respond_to_ping()
cb.assert_called_once_with(True)
def test_one_attempt_with_no_reply(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=1)
FakeTimer.TIMERS[-1].expire()
self.assertEqual(1, self.fsm.send_packet.call_count)
cb.assert_called_once_with(False)
def test_multiple_attempts_with_no_reply(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=2)
timer_one = FakeTimer.TIMERS[-1]
timer_one.expire()
timer_two = FakeTimer.TIMERS[-1]
self.assertIsNot(timer_one, timer_two)
timer_two.expire()
self.assertEqual(2, self.fsm.send_packet.call_count)
cb.assert_called_once_with(False)
def test_late_reply(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=1)
FakeTimer.TIMERS[-1].expire()
self.respond_to_ping()
cb.assert_called_once_with(False)
def test_this_layer_down_during_ping(self):
cb = mock.Mock()
self.fsm.ping(cb)
self.fsm.this_layer_down()
FakeTimer.TIMERS[-1].expire()
cb.assert_not_called()
def test_echo_reply_with_wrong_identifier(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=1)
[echo_request_packet], _ = self.fsm.send_packet.call_args
echo_response_packet = bytearray(echo_request_packet)
echo_response_packet[0] = 0x0a
echo_response_packet[1] += 1
self.fsm.packet_received(bytes(echo_response_packet))
cb.assert_not_called()
FakeTimer.TIMERS[-1].expire()
cb.assert_called_once_with(False)
def test_echo_reply_with_wrong_data(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=1)
[echo_request_packet], _ = self.fsm.send_packet.call_args
# Generate a syntactically valid Echo-Reply with the right
# identifier but completely different data.
identifier = bytearray(echo_request_packet)[1]
echo_response_packet = bytes(
b'\x0a' + bytearray([identifier]) +
b'\0\x26\0\0\0\0bad reply bad reply bad reply.')
self.fsm.packet_received(bytes(echo_response_packet))
cb.assert_not_called()
FakeTimer.TIMERS[-1].expire()
cb.assert_called_once_with(False)
def test_successive_pings_use_different_identifiers(self):
self.fsm.ping(mock.Mock(), attempts=1)
[echo_request_packet_1], _ = self.fsm.send_packet.call_args
identifier_1 = bytearray(echo_request_packet_1)[1]
self.respond_to_ping()
self.fsm.ping(mock.Mock(), attempts=1)
[echo_request_packet_2], _ = self.fsm.send_packet.call_args
identifier_2 = bytearray(echo_request_packet_2)[1]
self.assertNotEqual(identifier_1, identifier_2)
def test_unsolicited_echo_reply_doesnt_break_anything(self):
self.fsm.packet_received(b'\x0a\0\0\x08\0\0\0\0')
def test_malformed_echo_reply(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=1)
# Only three bytes of Magic-Number
self.fsm.packet_received(b'\x0a\0\0\x07\0\0\0')
cb.assert_not_called()
# Trying to start a second ping while the first ping is still happening
def test_starting_a_ping_while_another_is_active_is_an_error(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=1)
cb2 = mock.Mock()
with self.assertRaises(exceptions.AlreadyInProgressError):
self.fsm.ping(cb2, attempts=1)
FakeTimer.TIMERS[-1].expire()
cb.assert_called_once_with(False)
cb2.assert_not_called()
# General tests:
# - Length too short for a valid packet
# - Packet truncated (length field > packet len)
# - Packet with padding
# OptionList codes:
# 1 Configure-Request
# 2 Configure-Ack
# 3 Configure-Nak
# 4 Configure-Reject
# Raw data codes:
# 5 Terminate-Request
# 6 Terminate-Ack
# 7 Code-Reject
# 8 Protocol-Reject
# - Empty Rejected-Information field
# - Rejected-Protocol field too short
# Magic number + data codes:
# 10 Echo-Reply
# 11 Discard-Request
# 12 Identification (RFC 1570)
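# For reference, a sketch of the code numbers the checklist above groups by
# payload format (per RFC 1661, plus Identification from RFC 1570). This is
# illustrative only; the enum these tests actually use is ppp.LCPCode.
import enum

class _SketchLCPCode(enum.IntEnum):
    Configure_Request = 1    # OptionList payload
    Configure_Ack = 2
    Configure_Nak = 3
    Configure_Reject = 4
    Terminate_Request = 5    # Raw data payload
    Terminate_Ack = 6
    Code_Reject = 7
    Protocol_Reject = 8
    Echo_Request = 9         # Magic number + data payload
    Echo_Reply = 10
    Discard_Request = 11
    Identification = 12      # RFC 1570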

View file

@ -0,0 +1,538 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import threading
import unittest
try:
from unittest import mock
except ImportError:
import mock
import construct
from pebble.pulse2 import exceptions, pcmp, transports
from .fake_timer import FakeTimer
from . import timer_helper
# Save a reference to the real threading.Timer for tests which need to
# use timers even while threading.Timer is patched with FakeTimer.
RealThreadingTimer = threading.Timer
class CommonTransportBeforeOpenedTestCases(object):
def test_send_raises_exception(self):
with self.assertRaises(exceptions.TransportNotReady):
self.uut.send(0xdead, b'not gonna get through')
def test_open_socket_returns_None_when_ncp_fails_to_open(self):
self.assertIsNone(self.uut.open_socket(0xbeef, timeout=0))
class CommonTransportTestCases(object):
def test_send_raises_exception_after_transport_is_closed(self):
self.uut.down()
with self.assertRaises(exceptions.TransportNotReady):
self.uut.send(0xaaaa, b'asdf')
def test_socket_is_closed_when_transport_is_closed(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
self.uut.down()
self.assertTrue(socket.closed)
with self.assertRaises(exceptions.SocketClosed):
socket.send(b'foo')
def test_opening_two_sockets_on_same_port_is_an_error(self):
socket1 = self.uut.open_socket(0xabcd, timeout=0)
with self.assertRaises(KeyError):
socket2 = self.uut.open_socket(0xabcd, timeout=0)
def test_closing_a_socket_allows_another_to_be_opened(self):
socket1 = self.uut.open_socket(0xabcd, timeout=0)
socket1.close()
socket2 = self.uut.open_socket(0xabcd, timeout=0)
def test_opening_socket_fails_after_transport_down(self):
self.uut.this_layer_down()
self.assertIsNone(self.uut.open_socket(0xabcd, timeout=0))
def test_opening_socket_succeeds_after_transport_bounces(self):
self.uut.this_layer_down()
self.uut.this_layer_up()
self.uut.open_socket(0xabcd, timeout=0)
class TestBestEffortTransportBeforeOpened(CommonTransportBeforeOpenedTestCases,
unittest.TestCase):
def setUp(self):
control_protocol_patcher = mock.patch(
'pebble.pulse2.transports.TransportControlProtocol')
control_protocol_patcher.start()
self.addCleanup(control_protocol_patcher.stop)
self.uut = transports.BestEffortApplicationTransport(
interface=mock.MagicMock(), link_mtu=1500)
self.uut.ncp.is_Opened.return_value = False
def test_open_socket_waits_for_ncp_to_open(self):
self.uut.ncp.is_Opened.return_value = True
def on_ping(cb, *args):
self.uut.packet_received(transports.BestEffortPacket.build(
construct.Container(port=0x0001, length=5,
information=b'\x02', padding=b'')))
cb(True)
with mock.patch.object(pcmp.PulseControlMessageProtocol, 'ping') \
as mock_ping:
mock_ping.side_effect = on_ping
open_thread = RealThreadingTimer(0.01, self.uut.this_layer_up)
open_thread.daemon = True
open_thread.start()
self.assertIsNotNone(self.uut.open_socket(0xbeef, timeout=0.5))
open_thread.join()
class TestBestEffortTransport(CommonTransportTestCases, unittest.TestCase):
def setUp(self):
self.addCleanup(timer_helper.cancel_all_timers)
self.uut = transports.BestEffortApplicationTransport(
interface=mock.MagicMock(), link_mtu=1500)
self.uut.ncp.receive_configure_request_acceptable(0, [])
self.uut.ncp.receive_configure_ack()
self.uut.packet_received(transports.BestEffortPacket.build(
construct.Container(port=0x0001, length=5,
information=b'\x02', padding=b'')))
def test_send(self):
self.uut.send(0xabcd, b'information')
self.uut.link_socket.send.assert_called_with(
transports.BestEffortPacket.build(construct.Container(
port=0xabcd, length=15, information=b'information',
padding=b'')))
def test_send_from_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
socket.send(b'info')
self.uut.link_socket.send.assert_called_with(
transports.BestEffortPacket.build(construct.Container(
port=0xabcd, length=8, information=b'info', padding=b'')))
def test_receive_from_socket_with_empty_queue(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
with self.assertRaises(exceptions.ReceiveQueueEmpty):
socket.receive(block=False)
def test_receive_from_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
self.uut.packet_received(
transports.BestEffortPacket.build(construct.Container(
port=0xabcd, length=8, information=b'info', padding=b'')))
self.assertEqual(b'info', socket.receive(block=False))
def test_receive_on_unopened_port_doesnt_reach_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
self.uut.packet_received(
transports.BestEffortPacket.build(construct.Container(
port=0xface, length=8, information=b'info', padding=b'')))
with self.assertRaises(exceptions.ReceiveQueueEmpty):
socket.receive(block=False)
def test_receive_malformed_packet(self):
self.uut.packet_received(b'garbage')
def test_send_equal_to_mtu(self):
self.uut.send(0xaaaa, b'a'*1496)
def test_send_greater_than_mtu(self):
with self.assertRaisesRegexp(ValueError, 'Packet length'):
self.uut.send(0xaaaa, b'a'*1497)
def test_transport_down_closes_link_socket_and_ncp(self):
self.uut.down()
self.uut.link_socket.close.assert_called_with()
self.assertIsNone(self.uut.ncp.socket)
def test_pcmp_port_closed_message_closes_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
self.assertFalse(socket.closed)
self.uut.packet_received(
transports.BestEffortPacket.build(construct.Container(
port=0x0001, length=7, information=b'\x81\xab\xcd',
padding=b'')))
self.assertTrue(socket.closed)
def test_pcmp_port_closed_message_without_socket(self):
self.uut.packet_received(
transports.BestEffortPacket.build(construct.Container(
port=0x0001, length=7, information=b'\x81\xaa\xaa',
padding=b'')))
class TestReliableTransportPacketBuilders(unittest.TestCase):
def test_build_info_packet(self):
self.assertEqual(
b'\x1e\x3f\xbe\xef\x00\x14Data goes here',
transports.build_reliable_info_packet(
sequence_number=15, ack_number=31, poll=True,
port=0xbeef, information=b'Data goes here'))
def test_build_receive_ready_packet(self):
self.assertEqual(
b'\x01\x18',
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=12))
def test_build_receive_ready_poll_packet(self):
self.assertEqual(
b'\x01\x19',
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=12, poll=True))
def test_build_receive_ready_final_packet(self):
self.assertEqual(
b'\x01\x19',
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=12, final=True))
def test_build_receive_not_ready_packet(self):
self.assertEqual(
b'\x05\x18',
transports.build_reliable_supervisory_packet(
kind='RNR', ack_number=12))
def test_build_reject_packet(self):
self.assertEqual(
b'\x09\x18',
transports.build_reliable_supervisory_packet(
kind='REJ', ack_number=12))
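# Sketch of the byte layout that the expected packets above encode,
# reconstructed from the test vectors rather than taken from the transports
# module itself (the _sketch_* helpers are hypothetical). Bit 0 of the first
# control byte distinguishes information frames (0) from supervisory
# frames (1), in the style of LAPB.
def _sketch_info_packet(sequence_number, ack_number, poll, port, information):
    import struct
    control = bytes(bytearray([sequence_number << 1,
                               (ack_number << 1) | int(poll)]))
    # The length field counts the 6 header bytes plus the payload, which is
    # why the MTU tests below cap the payload at link_mtu - 6.
    return (control + struct.pack('>HH', port, 6 + len(information))
            + information)

def _sketch_supervisory_packet(kind, ack_number, poll=False, final=False):
    kind_bits = {'RR': 0b00, 'RNR': 0b01, 'REJ': 0b10}[kind]
    return bytes(bytearray([(kind_bits << 2) | 0x01,
                            (ack_number << 1) | int(poll or final)]))

assert (_sketch_info_packet(15, 31, True, 0xbeef, b'Data goes here')
        == b'\x1e\x3f\xbe\xef\x00\x14Data goes here')
assert _sketch_supervisory_packet('RR', 12, poll=True) == b'\x01\x19'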
class TestReliableTransportBeforeOpened(CommonTransportBeforeOpenedTestCases,
unittest.TestCase):
def setUp(self):
self.addCleanup(timer_helper.cancel_all_timers)
self.uut = transports.ReliableTransport(
interface=mock.MagicMock(), link_mtu=1500)
def test_open_socket_waits_for_ncp_to_open(self):
self.uut.ncp.is_Opened = mock.Mock()
self.uut.ncp.is_Opened.return_value = True
self.uut.command_socket.send = lambda packet: (
self.uut.response_packet_received(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=0, final=True)))
open_thread = RealThreadingTimer(0.01, self.uut.this_layer_up)
open_thread.daemon = True
open_thread.start()
self.assertIsNotNone(self.uut.open_socket(0xbeef, timeout=0.5))
open_thread.join()
class TestReliableTransportConnectionEstablishment(unittest.TestCase):
expected_rr_packet = transports.build_reliable_supervisory_packet(
kind='RR', ack_number=0, poll=True)
def setUp(self):
FakeTimer.clear_timer_list()
timer_patcher = mock.patch('threading.Timer', new=FakeTimer)
timer_patcher.start()
self.addCleanup(timer_patcher.stop)
control_protocol_patcher = mock.patch(
'pebble.pulse2.transports.TransportControlProtocol')
control_protocol_patcher.start()
self.addCleanup(control_protocol_patcher.stop)
self.uut = transports.ReliableTransport(
interface=mock.MagicMock(), link_mtu=1500)
assert isinstance(self.uut.ncp, mock.MagicMock)
self.uut.ncp.is_Opened.return_value = True
self.uut.this_layer_up()
def send_rr_response(self):
self.uut.response_packet_received(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=0, final=True))
def test_rr_packet_is_sent_after_this_layer_up_event(self):
self.uut.command_socket.send.assert_called_once_with(
self.expected_rr_packet)
def test_rr_command_is_retransmitted_until_response_is_received(self):
for _ in range(3):
FakeTimer.TIMERS[-1].expire()
self.send_rr_response()
self.assertFalse(FakeTimer.get_active_timers())
self.assertEqual(self.uut.command_socket.send.call_args_list,
[mock.call(self.expected_rr_packet)]*4)
self.assertIsNotNone(self.uut.open_socket(0xabcd, timeout=0))
def test_transport_negotiation_restarts_if_no_responses(self):
for _ in range(self.uut.max_retransmits):
FakeTimer.TIMERS[-1].expire()
self.assertFalse(FakeTimer.get_active_timers())
self.assertIsNone(self.uut.open_socket(0xabcd, timeout=0))
self.uut.ncp.restart.assert_called_once_with()
class TestReliableTransport(CommonTransportTestCases,
unittest.TestCase):
def setUp(self):
FakeTimer.clear_timer_list()
timer_patcher = mock.patch('threading.Timer', new=FakeTimer)
timer_patcher.start()
self.addCleanup(timer_patcher.stop)
control_protocol_patcher = mock.patch(
'pebble.pulse2.transports.TransportControlProtocol')
control_protocol_patcher.start()
self.addCleanup(control_protocol_patcher.stop)
self.uut = transports.ReliableTransport(
interface=mock.MagicMock(), link_mtu=1500)
assert isinstance(self.uut.ncp, mock.MagicMock)
self.uut.ncp.is_Opened.return_value = True
self.uut.this_layer_up()
self.uut.command_socket.send.reset_mock()
self.uut.response_packet_received(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=0, final=True))
def test_send_with_immediate_ack(self):
self.uut.send(0xbeef, b'Just some packet data')
self.uut.command_socket.send.assert_called_once_with(
transports.build_reliable_info_packet(
sequence_number=0, ack_number=0, poll=True,
port=0xbeef, information=b'Just some packet data'))
self.assertEqual(1, len(FakeTimer.get_active_timers()))
self.uut.response_packet_received(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=1, final=True))
self.assertTrue(all(t.cancelled for t in FakeTimer.TIMERS))
def test_send_with_one_timeout_before_ack(self):
self.uut.send(0xabcd, b'this will be sent twice')
active_timers = FakeTimer.get_active_timers()
self.assertEqual(1, len(active_timers))
active_timers[0].expire()
self.assertEqual(1, len(FakeTimer.get_active_timers()))
self.uut.command_socket.send.assert_has_calls(
[mock.call(transports.build_reliable_info_packet(
sequence_number=0, ack_number=0,
poll=True, port=0xabcd,
information=b'this will be sent twice'))]*2)
self.uut.response_packet_received(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=1, final=True))
self.assertTrue(all(t.cancelled for t in FakeTimer.TIMERS))
def test_send_with_no_response(self):
self.uut.send(0xd00d, b'blarg')
for _ in range(self.uut.max_retransmits):
FakeTimer.get_active_timers()[-1].expire()
self.uut.ncp.restart.assert_called_once_with()
def test_receive_info_packet(self):
socket = self.uut.open_socket(0xcafe, timeout=0)
self.uut.command_packet_received(transports.build_reliable_info_packet(
sequence_number=0, ack_number=0, poll=True, port=0xcafe,
information=b'info'))
self.assertEqual(b'info', socket.receive(block=False))
self.uut.response_socket.send.assert_called_once_with(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=1, final=True))
def test_receive_duplicate_packet(self):
socket = self.uut.open_socket(0xba5e, timeout=0)
packet = transports.build_reliable_info_packet(
sequence_number=0, ack_number=0, poll=True, port=0xba5e,
information=b'all your base are belong to us')
self.uut.command_packet_received(packet)
self.assertEqual(b'all your base are belong to us',
socket.receive(block=False))
self.uut.response_socket.reset_mock()
self.uut.command_packet_received(packet)
self.uut.response_socket.send.assert_called_once_with(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=1, final=True))
with self.assertRaises(exceptions.ReceiveQueueEmpty):
socket.receive(block=False)
def test_queueing_multiple_packets_to_send(self):
packets = [(0xfeed, b'Some data'),
(0x6789, b'More data'),
(0xfeed, b'Third packet')]
for protocol, information in packets:
self.uut.send(protocol, information)
for seq, (port, information) in enumerate(packets):
self.uut.command_socket.send.assert_called_once_with(
transports.build_reliable_info_packet(
sequence_number=seq, ack_number=0, poll=True,
port=port, information=information))
self.uut.command_socket.send.reset_mock()
self.uut.response_packet_received(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=seq+1, final=True))
def test_send_equal_to_mtu(self):
self.uut.send(0xaaaa, b'a'*1494)
def test_send_greater_than_mtu(self):
with self.assertRaisesRegexp(ValueError, 'Packet length'):
self.uut.send(0xaaaa, b'a'*1496)
def test_send_from_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
socket.send(b'info')
self.uut.command_socket.send.assert_called_with(
transports.build_reliable_info_packet(
sequence_number=0, ack_number=0,
poll=True, port=0xabcd, information=b'info'))
def test_receive_from_socket_with_empty_queue(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
with self.assertRaises(exceptions.ReceiveQueueEmpty):
socket.receive(block=False)
def test_receive_from_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
self.uut.command_packet_received(transports.build_reliable_info_packet(
sequence_number=0, ack_number=0, poll=True, port=0xabcd,
information=b'info info info'))
self.assertEqual(b'info info info', socket.receive(block=False))
def test_receive_on_unopened_port_doesnt_reach_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
self.uut.command_packet_received(transports.build_reliable_info_packet(
sequence_number=0, ack_number=0, poll=True, port=0x3333,
information=b'info'))
with self.assertRaises(exceptions.ReceiveQueueEmpty):
socket.receive(block=False)
def test_receive_malformed_command_packet(self):
self.uut.command_packet_received(b'garbage')
self.uut.ncp.restart.assert_called_once_with()
def test_receive_malformed_response_packet(self):
self.uut.response_packet_received(b'garbage')
self.uut.ncp.restart.assert_called_once_with()
def test_transport_down_closes_link_sockets_and_ncp(self):
self.uut.down()
self.uut.command_socket.close.assert_called_with()
self.uut.response_socket.close.assert_called_with()
self.uut.ncp.down.assert_called_with()
def test_pcmp_port_closed_message_closes_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
self.assertFalse(socket.closed)
self.uut.command_packet_received(transports.build_reliable_info_packet(
sequence_number=0, ack_number=0, poll=True, port=0x0001,
information=b'\x81\xab\xcd'))
self.assertTrue(socket.closed)
def test_pcmp_port_closed_message_without_socket(self):
self.uut.command_packet_received(transports.build_reliable_info_packet(
sequence_number=0, ack_number=0, poll=True, port=0x0001,
information=b'\x81\xaa\xaa'))
class TestSocket(unittest.TestCase):
def setUp(self):
self.uut = transports.Socket(mock.Mock(), 1234)
def test_empty_receive_queue(self):
with self.assertRaises(exceptions.ReceiveQueueEmpty):
self.uut.receive(block=False)
def test_empty_receive_queue_blocking(self):
with self.assertRaises(exceptions.ReceiveQueueEmpty):
self.uut.receive(timeout=0.001)
def test_receive(self):
self.uut.on_receive(b'data')
self.assertEqual(b'data', self.uut.receive(block=False))
with self.assertRaises(exceptions.ReceiveQueueEmpty):
self.uut.receive(block=False)
def test_receive_twice(self):
self.uut.on_receive(b'one')
self.uut.on_receive(b'two')
self.assertEqual(b'one', self.uut.receive(block=False))
self.assertEqual(b'two', self.uut.receive(block=False))
def test_receive_interleaved(self):
self.uut.on_receive(b'one')
self.assertEqual(b'one', self.uut.receive(block=False))
self.uut.on_receive(b'two')
self.assertEqual(b'two', self.uut.receive(block=False))
def test_send(self):
self.uut.send(b'data')
self.uut.transport.send.assert_called_once_with(1234, b'data')
def test_close(self):
self.uut.close()
self.uut.transport.unregister_socket.assert_called_once_with(1234)
def test_send_after_close_is_an_error(self):
self.uut.close()
with self.assertRaises(exceptions.SocketClosed):
self.uut.send(b'data')
def test_receive_after_close_is_an_error(self):
self.uut.close()
with self.assertRaises(exceptions.SocketClosed):
self.uut.receive(block=False)
def test_blocking_receive_after_close_is_an_error(self):
self.uut.close()
with self.assertRaises(exceptions.SocketClosed):
self.uut.receive(timeout=0.001)
def test_close_during_blocking_receive_aborts_the_receive(self):
thread_started = threading.Event()
result = [None]
def test_thread():
thread_started.set()
try:
self.uut.receive(timeout=0.3)
except Exception as e:
result[0] = e
thread = threading.Thread(target=test_thread)
thread.daemon = True
thread.start()
assert thread_started.wait(timeout=0.5)
self.uut.close()
thread.join()
self.assertIsInstance(result[0], exceptions.SocketClosed)
def test_close_is_idempotent(self):
self.uut.close()
self.uut.close()
self.assertEqual(1, self.uut.transport.unregister_socket.call_count)
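# A sketch of the queue-plus-sentinel pattern that the tests above exercise:
# delivery goes through a thread-safe queue, and close() enqueues a sentinel
# so a receive() that is already blocked wakes up and raises SocketClosed
# instead of hanging until its timeout. Hypothetical names, assuming the
# Python 3 queue module; the class actually under test is transports.Socket.
class _SketchSocket(object):
    _CLOSED = object()  # Sentinel that unblocks a pending receive on close.

    def __init__(self, transport, port):
        import queue
        self.transport = transport
        self.port = port
        self.closed = False
        self._queue = queue.Queue()

    def on_receive(self, data):
        self._queue.put(data)

    def receive(self, block=True, timeout=None):
        import queue
        if self.closed:
            raise exceptions.SocketClosed()
        try:
            item = self._queue.get(block=block, timeout=timeout)
        except queue.Empty:
            raise exceptions.ReceiveQueueEmpty()
        if item is self._CLOSED:
            raise exceptions.SocketClosed()
        return item

    def send(self, data):
        if self.closed:
            raise exceptions.SocketClosed()
        self.transport.send(self.port, data)

    def close(self):
        if self.closed:
            return  # close() is idempotent.
        self.closed = True
        self.transport.unregister_socket(self.port)
        self._queue.put(self._CLOSED)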

View file

@ -0,0 +1,25 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
def cancel_all_timers():
'''Cancel all running timer threads in the process.
'''
for thread in threading.enumerate():
try:
thread.cancel()
except AttributeError:
pass