9
0
Fork 0
barebox/scripts/remote/ratp.py

774 lines
21 KiB
Python

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import crcmod
import logging
import struct
from enum import Enum
from time import sleep
try:
    from time import monotonic
except ImportError:
    # The running interpreter's ``time`` module has no monotonic() (the
    # shebang targets python2); use the package-local replacement instead.
    from .missing import monotonic
# Payload checksum function: crcmod's predefined 'xmodem' CRC.  RatpPacket
# packs/unpacks its result as a 16-bit big-endian trailer after the payload.
csum_func = crcmod.predefined.mkCrcFun('xmodem')
class RatpState(Enum):
    """States a RATP connection can be in.

    Each member's value is the hyphenated state name; the trailing number in
    the comments is the state's ordinal as annotated by the original author.
    """

    listen = "listen"                # 1
    syn_sent = "syn-sent"            # 2
    syn_received = "syn-received"    # 3
    established = "established"      # 4
    fin_wait = "fin-wait"            # 5
    last_ack = "last-ack"            # 6
    closing = "closing"              # 7
    time_wait = "time-wait"          # 8
    closed = "closed"                # 9
class RatpInvalidHeader(ValueError):
    """Raised when a received 4-octet header has a bad synch octet or checksum."""
class RatpInvalidPayload(ValueError):
    """Raised when a received payload does not match its trailing checksum."""
class RatpError(ValueError):
    """Raised for connection-level failures (reset, refused, closed, retransmit limit)."""
class RatpPacket(object):
    """One RATP packet: a 4-octet header plus an optional payload.

    Header octets, in wire order: synch (always 0x01), control, length and a
    header checksum chosen so that ``control + length + csum`` is 0xff modulo
    256.  The control octet holds, from bit 7 down to bit 0, the flags
    SYN, ACK, FIN, RST, the sequence number SN, the acknowledge number AN,
    EOR and SO.  They are exposed as the ``c_*`` attributes; ``c_sn`` and
    ``c_an`` are ints (0 or 1), the others are bools.

    A packet is created either by parsing a received header (``data``) or
    from a string of flag letters (``flags``).
    """

    def __init__(self, data=None, flags=''):
        """Build a packet.

        data  -- 4 header octets to parse; takes precedence over ``flags``.
        flags -- any combination of the letters 'S' (SYN), 'A' (ACK),
                 'F' (FIN), 'R' (RST) and 'E' (EOR).

        Raises RatpInvalidHeader if ``data`` has a wrong synch octet or a bad
        header checksum.  struct.error propagates if ``data`` is not exactly
        four octets long.
        """
        self.payload = None
        self.synch = 0x01
        self._control = 0
        self.length = 0
        self.csum = 0
        self.c_syn = False
        self.c_ack = False
        self.c_fin = False
        self.c_rst = False
        self.c_sn = 0
        self.c_an = 0
        self.c_eor = False
        self.c_so = False

        if data:
            self.synch, self._control, self.length, self.csum = \
                struct.unpack('!BBBB', data)
            if self.synch != 0x01:
                raise RatpInvalidHeader("invalid synch octet (%x != %x)" %
                                        (self.synch, 0x01))
            # A valid header sums to 0xff over control, length and checksum.
            csum = (self._control + self.length + self.csum) & 0xff
            if csum != 0xff:
                raise RatpInvalidHeader("invalid csum octet (%x != %x)" %
                                        (csum, 0xff))
            self._unpack_control()
        elif flags:
            self.c_syn = 'S' in flags
            self.c_ack = 'A' in flags
            self.c_fin = 'F' in flags
            self.c_rst = 'R' in flags
            self.c_eor = 'E' in flags

    def __repr__(self):
        """Return a compact description, e.g. ``RatpPacket(ACK,SN=0,AN=1,DATA=5)``."""
        parts = []
        if self.c_syn:
            parts.append("SYN")
        if self.c_ack:
            parts.append("ACK")
        if self.c_fin:
            parts.append("FIN")
        if self.c_rst:
            parts.append("RST")
        parts.append("SN=%i" % self.c_sn)
        parts.append("AN=%i" % self.c_an)
        if self.c_eor:
            parts.append("EOR")
        if self.c_so:
            parts.append("SO")
        parts.append("DATA=%i" % self.length)
        return "RatpPacket(" + ",".join(parts) + ")"

    def _pack_control(self):
        """Assemble the control octet from the ``c_*`` attributes."""
        self._control = (
            self.c_syn << 7 |
            self.c_ack << 6 |
            self.c_fin << 5 |
            self.c_rst << 4 |
            self.c_sn << 3 |
            self.c_an << 2 |
            self.c_eor << 1 |
            self.c_so << 0
        )

    def _unpack_control(self):
        """Split the control octet into the ``c_*`` attributes."""
        ctrl = self._control
        self.c_syn = bool(ctrl & 0x80)
        self.c_ack = bool(ctrl & 0x40)
        self.c_fin = bool(ctrl & 0x20)
        self.c_rst = bool(ctrl & 0x10)
        # Sequence numbers are kept as ints so that parsed packets carry the
        # same type as locally built ones (which assign 0/1 directly).
        self.c_sn = (ctrl >> 3) & 1
        self.c_an = (ctrl >> 2) & 1
        self.c_eor = bool(ctrl & 0x02)
        self.c_so = bool(ctrl & 0x01)

    def pack(self):
        """Return the 4-octet header, recomputing control and checksum first."""
        self._pack_control()
        self.csum = ((self._control + self.length) & 0xff) ^ 0xff
        return struct.pack('!BBBB', self.synch, self._control, self.length,
                           self.csum)

    def unpack_payload(self, payload):
        """Verify ``payload`` (data + 16-bit checksum) and store the data part.

        Raises RatpInvalidPayload if the trailing big-endian checksum does
        not match ``csum_func`` over the preceding octets.
        """
        body, trailer = payload[:-2], payload[-2:]
        (c_recv,) = struct.unpack('!H', trailer)
        c_calc = csum_func(body)
        if c_recv != c_calc:
            raise RatpInvalidPayload("bad checksum (%04x != %04x)" %
                                     (c_recv, c_calc))
        self.payload = body

    def pack_payload(self):
        """Return the payload followed by its 16-bit big-endian checksum."""
        c_calc = csum_func(self.payload)
        return self.payload + struct.pack('!H', c_calc)
class RatpConnection(object):
    """Transport-independent RATP connection state machine.

    The class handles packet framing, retransmission with a smoothed
    round-trip estimate, and the per-state behaviour procedures ``_a`` ...
    ``_i1`` (the code references RFC 916).  It performs no I/O itself: a
    subclass must provide ``_write_raw(data)`` and ``_read_raw(size=1)``,
    as SerialRatpConnection does.

    Behaviour procedures take the received packet and return a truthy value
    when the next procedure of the current state's chain should run, and a
    falsy value when processing of the packet is finished.  Several of them
    raise RatpError to report a reset, refused or closed connection.
    """

    def __init__(self):
        self._state = RatpState.closed
        # NOTE(review): never set to False inside this class, so the
        # "Connection refused" branches of _d1/_f1 are not reached unless a
        # subclass or caller clears it -- confirm intent.
        self._passive = True
        self._input = b''               # raw octets read but not yet parsed
        self._s_sn = 0                  # SN of the last sequenced packet sent
        self._r_sn = 0                  # SN of the last accepted packet
        self._retrans = None            # packet awaiting acknowledgement
        self._retrans_counter = None
        self._retrans_deadline = None
        self._time_wait_deadline = None
        self._r_mdl = None              # length field of the peer's SYN
        self._s_mdl = 0xff              # length field sent in our SYN
        self._rx_buf = []  # reassembly buffer
        self._rx_queue = []             # completely reassembled messages
        self._tx_queue = []             # not used within this class
        self._rtt_alpha = 0.8
        self._rtt_beta = 2.0
        self._srtt = 0.2
        self._rto_min, self._rto_max = 0.2, 1
        self._tx_timestamp = None
        self.total_retransmits = 0
        self.total_crc_errors = 0

    # -- small helpers -----------------------------------------------------

    @staticmethod
    def _next_sn(sn):
        """Return the other value of a one-bit sequence number."""
        return (sn + 1) % 2

    def _send_control(self, flags, sn, an=0, length=0):
        """Build a payload-less packet with the given fields and write it."""
        pkt = RatpPacket(flags=flags)
        pkt.c_sn = sn
        pkt.c_an = an
        pkt.length = length
        self._write(pkt)

    def _update_srtt(self, rtt):
        """Fold a new round-trip sample into the smoothed estimate."""
        self._srtt = (self._rtt_alpha * self._srtt) + \
            ((1.0 - self._rtt_alpha) * rtt)
        logging.info("SRTT: %r", self._srtt)

    def _get_rto(self):
        """Return the retransmit timeout, clamped to [_rto_min, _rto_max]."""
        return min(self._rto_max,
                   max(self._rto_min, self._rtt_beta * self._srtt))

    # -- packet I/O --------------------------------------------------------

    def _write(self, pkt):
        """Send ``pkt`` and, if it consumes a sequence number, arm retransmission.

        Packets with payload or one of SO/SYN/RST/FIN become the pending
        retransmit packet; writing again while one is pending counts as a
        retransmission.  Raises RatpError after more than 10 retransmissions.
        """
        if pkt.payload or pkt.c_so or pkt.c_syn or pkt.c_rst or pkt.c_fin:
            self._s_sn = pkt.c_sn
            if not self._retrans:
                self._retrans = pkt
                self._retrans_counter = 0
            else:
                self.total_retransmits += 1
                self._retrans_counter += 1
                if self._retrans_counter > 10:
                    raise RatpError("Maximum retransmit count exceeded")
            self._retrans_deadline = monotonic()+self._get_rto()

        logging.info("Write: %r", pkt)

        self._write_raw(pkt.pack())
        if pkt.payload:
            self._write_raw(pkt.pack_payload())
        self._tx_timestamp = monotonic()

    def _check_rto(self):
        """Re-send the pending packet once its retransmit deadline has passed."""
        if self._retrans is None:
            return
        if self._retrans_deadline < monotonic():
            logging.debug("Retransmit...")
            self._write(self._retrans)

    def _check_time_wait(self):
        """In time-wait, move to closed after the deadline, else sleep briefly."""
        if self._state != RatpState.time_wait:
            return
        remaining = self._time_wait_deadline - monotonic()
        if remaining < 0:
            self._state = RatpState.closed
        else:
            logging.debug("Time-Wait: %.2f remaining", remaining)
            sleep(min(remaining, 0.1))

    def _read(self):
        """Try to read one packet; return it, or None if none is available.

        None is also returned after an invalid header (one octet is dropped
        to resynchronise) and after a payload checksum failure.
        """
        if len(self._input) < 4:
            self._input += self._read_raw(4-len(self._input))
        if len(self._input) < 4:
            return

        try:
            pkt = RatpPacket(data=self._input[:4])
        except RatpInvalidHeader as e:
            logging.info("%r", e)
            self._input = self._input[1:]
            return

        self._input = self._input[4:]
        logging.info("Read: %r", pkt)

        # These packets carry no payload section on the wire.
        if pkt.c_syn or pkt.c_rst or pkt.c_so or pkt.c_fin:
            return pkt
        if pkt.length == 0:
            return pkt

        # Payload is followed by a 2-octet checksum.  Ask for exactly what is
        # still missing instead of polling one octet at a time.
        needed = pkt.length + 2
        while len(self._input) < needed:
            self._input += self._read_raw(needed - len(self._input))

        try:
            pkt.unpack_payload(self._input[:needed])
        except RatpInvalidPayload as e:
            logging.info("%r", e)
            self.total_crc_errors += 1
            return
        finally:
            self._input = self._input[needed:]

        return pkt

    def _close(self):
        """Placeholder; does nothing."""
        pass

    # -- behaviour procedures ----------------------------------------------

    def _a(self, r):
        """Listen state: answer a SYN with SYN+ACK, an ACK with RST."""
        logging.info("A")

        if r.c_rst:
            return True

        if r.c_ack:
            self._send_control('R', sn=r.c_an)
            return False

        if r.c_syn:
            self._r_mdl = r.length
            self._send_control('SA', sn=0, an=self._next_sn(r.c_sn),
                               length=self._s_mdl)
            self._state = RatpState.syn_received
            return False

        return False

    def _b(self, r):
        """Syn-sent state: complete the handshake or handle a crossing SYN."""
        logging.info("B")

        if r.c_ack and r.c_an != self._next_sn(self._s_sn):
            # Acknowledges something we did not send.
            if not r.c_rst:
                self._send_control('R', sn=r.c_an)
            return False

        if r.c_rst:
            if r.c_ack:
                self._retrans = None
                # FIXME: delete the TCB
                self._state = RatpState.closed
            return False

        if r.c_syn:
            if r.c_ack:
                self._r_mdl = r.length
                self._retrans = None
                self._r_sn = r.c_sn
                self._send_control('A', sn=r.c_an, an=self._next_sn(r.c_sn))
                self._state = RatpState.established
            else:
                self._retrans = None
                self._send_control('SA', sn=0, an=self._next_sn(r.c_sn),
                                   length=self._s_mdl)
                self._state = RatpState.syn_received
            return False

        return False

    def _c1(self, r):
        """Syn-received: continue on a new SN, otherwise re-acknowledge."""
        logging.info("C1")

        if r.c_sn != self._r_sn:
            return True

        if r.c_rst or r.c_fin:
            return False

        self._send_control('A', sn=r.c_an, an=self._next_sn(r.c_sn))
        return False

    def _c2(self, r):
        """Sequence check for data-carrying packets in connected states."""
        logging.info("C2")

        if r.length == 0 and r.c_so == 0:
            return True

        if r.c_sn != self._r_sn:
            return True

        if r.c_rst or r.c_fin:
            return False

        if r.c_syn:
            self._send_control('RA', sn=r.c_an, an=self._next_sn(r.c_sn))
            self._retrans = None
            # FIXME: inform the user "Error: Connection reset"
            self._state = RatpState.closed
            return False

        # FIXME: only ack duplicate data packages?
        # This is not documented in RFC 916
        if r.length or r.c_so:
            logging.info("C2: duplicate data packet, dropping")
            self._send_control('A', sn=r.c_an, an=self._next_sn(r.c_sn))

        return False

    def _d1(self, r):
        """Syn-received: handle RST (back to listen, or refuse if active)."""
        logging.info("D1")

        if not r.c_rst:
            return True

        self._retrans = None
        if self._passive:
            self._state = RatpState.listen
            return False

        self._state = RatpState.closed
        raise RatpError("Connection refused")

    def _d2(self, r):
        """Established/fin-wait: a RST closes the connection with an error."""
        logging.info("D2")

        if not r.c_rst:
            return True

        self._retrans = None
        self._state = RatpState.closed
        raise RatpError("Connection reset")

    def _d3(self, r):
        """Last-ack/closing/time-wait: a RST silently closes the connection."""
        logging.info("D3")

        if not r.c_rst:
            return True

        self._state = RatpState.closed
        return False

    def _e(self, r):
        """An unexpected SYN resets the connection."""
        logging.info("E")

        if not r.c_syn:
            return True

        # Cleared before writing: the RST then becomes the pending packet.
        self._retrans = None
        self._send_control('R', sn=r.c_an if r.c_ack else 0)
        self._state = RatpState.closed
        raise RatpError("Connection reset")

    def _f1(self, r):
        """Syn-received: require an ACK for our SYN, else reset the peer."""
        logging.info("F1")

        if not r.c_ack:
            return False

        if r.c_an == self._next_sn(self._s_sn):
            return True

        self._retrans = None
        self._send_control('R', sn=r.c_an)
        if self._passive:
            self._state = RatpState.listen
            return False

        self._state = RatpState.closed
        raise RatpError("Connection refused")

    def _f2(self, r):
        """Established: an ACK for the pending packet clears retransmission."""
        logging.info("F2")

        if not r.c_ack:
            return False

        if r.c_an == self._next_sn(self._s_sn) and self._retrans:
            self._retrans = None
            self._update_srtt(monotonic()-self._tx_timestamp)
            # FIXME: inform the user with an "Ok" if a buffer has been
            # entirely acknowledged.  Another packet containing data may
            # now be sent.

        return True

    def _f3(self, r):
        """Closing states: continue only for packets carrying ACK."""
        logging.info("F3")
        return bool(r.c_ack)

    def _g(self, r):
        """Closed state: ignore RST packets, answer anything else with RST.

        RFC 916 (behaviour G) sends ``<SN=received AN><CTL=RST>`` when the
        packet carried ACK and ``<SN=0><AN=received SN+1><CTL=RST,ACK>``
        otherwise.  A RST must not be answered, which would let two closed
        ends reset each other forever.
        """
        logging.info("G")

        if r.c_rst:
            return False

        self._retrans = None
        if r.c_ack:
            self._send_control('R', sn=r.c_an)
        else:
            self._send_control('RA', sn=0, an=self._next_sn(r.c_sn))

        return False

    def _h1(self, r):
        """Syn-received: handshake complete, enter established."""
        logging.info("H1")

        # FIXME: initial data?
        self._state = RatpState.established
        self._r_sn = r.c_sn
        return False

    def _h2(self, r):
        """Established: a FIN from the peer is answered with FIN+ACK."""
        logging.info("H2")

        if not r.c_fin:
            return True

        if self._retrans is not None:
            # FIXME: inform the user "Warning: Data left unsent.", "Connection closing."
            self._retrans = None

        self._send_control('FA', sn=r.c_an, an=self._next_sn(r.c_sn))
        self._state = RatpState.last_ack
        raise RatpError("Connection closed by remote")

    def _h3(self, r):
        """Fin-wait: process the peer's FIN."""
        logging.info("H3")

        if not r.c_fin:
            # Our fin was lost, rely on retransmission
            return False

        if r.length or r.c_so:
            # Data on a FIN packet: reset.
            self._retrans = None
            self._send_control('RA', sn=r.c_an, an=self._next_sn(r.c_sn))
            self._state = RatpState.closed
            raise RatpError("Connection reset")

        our_fin_acked = r.c_an == self._next_sn(self._s_sn)
        self._retrans = None
        self._send_control('A', sn=r.c_an, an=self._next_sn(r.c_sn))
        if our_fin_acked:
            self._time_wait_deadline = monotonic() + self._get_rto()
            self._state = RatpState.time_wait
        else:
            self._state = RatpState.closing
        return False

    def _h4(self, r):
        """Last-ack: an ACK for our FIN moves to time-wait."""
        logging.info("H4")

        if r.c_an == self._next_sn(self._s_sn):
            self._retrans = None
            self._time_wait_deadline = monotonic() + self._get_rto()
            self._state = RatpState.time_wait

        return False

    def _h5(self, r):
        """Closing: an ACK for our FIN moves to time-wait."""
        logging.info("H5")

        if r.c_an == self._next_sn(self._s_sn):
            self._time_wait_deadline = monotonic() + self._get_rto()
            self._state = RatpState.time_wait

        return False

    def _h6(self, r):
        """Time-wait: re-acknowledge a repeated FIN+ACK and restart the timer."""
        logging.info("H6")

        if not r.c_ack:
            return False

        if not r.c_fin:
            return False

        self._retrans = None
        self._send_control('A', sn=r.c_an, an=self._next_sn(r.c_sn))
        self._time_wait_deadline = monotonic() + self._get_rto()
        return False

    def _i1(self, r):
        """Established: accept data, reassemble on EOR and acknowledge."""
        logging.info("I1")

        if r.c_so:
            self._r_sn = r.c_sn
            # Single-octet packet: the length field is the data octet.
            # struct.pack yields a byte string on Python 2 and 3 alike.
            self._rx_buf.append(struct.pack('!B', r.length))
        elif r.length:
            self._r_sn = r.c_sn
            self._rx_buf.append(r.payload)
        else:
            return False

        # reassemble
        if r.c_eor:
            logging.info("Reassembling %i frames", len(self._rx_buf))
            self._rx_queue.append(b''.join(self._rx_buf))
            self._rx_buf = []

        self._send_control('A', sn=r.c_an, an=self._next_sn(r.c_sn))
        return False

    # -- state machine driver ----------------------------------------------

    def _machine(self, pkt):
        """Run the behaviour-procedure chain of the current state on ``pkt``.

        The chain is selected once, from the state at entry; it stops at the
        first procedure that returns a falsy value.
        """
        logging.info("State: %r", self._state)

        chains = {
            RatpState.listen: (self._a,),
            RatpState.syn_sent: (self._b,),
            RatpState.syn_received:
                (self._c1, self._d1, self._e, self._f1, self._h1),
            RatpState.established:
                (self._c2, self._d2, self._e, self._f2, self._h2, self._i1),
            RatpState.fin_wait:
                (self._c2, self._d2, self._e, self._f3, self._h3),
            RatpState.last_ack:
                (self._c2, self._d3, self._e, self._f3, self._h4),
            RatpState.closing:
                (self._c2, self._d3, self._e, self._f3, self._h5),
            RatpState.time_wait:
                (self._d3, self._e, self._f3, self._h6),
            RatpState.closed: (self._g,),
        }
        for procedure in chains.get(self._state, ()):
            if not procedure(pkt):
                break

    def _poll(self, deadline, wake_on_rx):
        """Shared loop of wait() and wait1().

        Processes incoming packets (or services the timers when nothing was
        read) until ``deadline`` passes, no packet awaits acknowledgement,
        or -- if ``wake_on_rx`` -- a reassembled message is queued.
        """
        while deadline is None or deadline > monotonic():
            pkt = self._read()
            if pkt:
                self._machine(pkt)
            else:
                self._check_rto()
                self._check_time_wait()
            if not self._retrans or (wake_on_rx and self._rx_queue):
                return

    def wait(self, deadline):
        """Poll until nothing is pending, a message arrived, or ``deadline``.

        ``deadline`` is a monotonic() timestamp, or None for no time limit.
        """
        self._poll(deadline, wake_on_rx=True)

    def wait1(self, deadline):
        """Like wait(), but queued received messages do not end the wait."""
        self._poll(deadline, wake_on_rx=False)

    # -- user interface ----------------------------------------------------

    def listen(self):
        """Enter the listen state."""
        logging.info("LISTEN")
        self._state = RatpState.listen

    def connect(self, timeout=5.0):
        """Send a SYN and wait up to ``timeout`` seconds for the handshake."""
        deadline = monotonic() + timeout
        logging.info("CONNECT")
        self._retrans = None
        self._send_control('S', sn=0, length=self._s_mdl)
        self._state = RatpState.syn_sent
        self.wait(deadline)

    def send_one(self, data, eor=True, timeout=1.0):
        """Send one data packet and wait until it has been acknowledged.

        ``timeout`` is accepted for interface compatibility but not applied:
        the wait is bounded by the retransmit limit in _write(), which raises
        RatpError.  Asserts that the connection is established and that no
        packet is pending.
        """
        logging.info("SEND_ONE (len=%i, eor=%r)", len(data), eor)
        assert self._state == RatpState.established
        assert self._retrans is None
        snd = RatpPacket(flags='A')
        snd.c_eor = eor
        snd.c_sn = self._next_sn(self._s_sn)
        snd.c_an = self._next_sn(self._r_sn)
        snd.length = len(data)
        snd.payload = data
        self._write(snd)
        self.wait1(deadline=None)

    def send(self, data, timeout=1.0):
        """Send ``data``, split into packets of at most 255 octets.

        Only the final packet carries EOR.
        """
        logging.info("SEND (len=%i)", len(data))
        while len(data) > 255:
            self.send_one(data[:255], eor=False, timeout=timeout)
            data = data[255:]
        self.send_one(data, eor=True, timeout=timeout)

    def recv(self, timeout=1.0):
        """Return the next reassembled message, or None if none is available.

        If nothing is queued, wait() is called once with a deadline of
        ``timeout`` seconds.  Asserts that the connection is established.
        """
        deadline = monotonic() + timeout

        assert self._state == RatpState.established
        if self._rx_queue:
            return self._rx_queue.pop(0)
        self.wait(deadline)
        if self._rx_queue:
            return self._rx_queue.pop(0)

    def close(self, timeout=1.0):
        """Close the connection, allowing ``timeout`` seconds to reach time-wait.

        Once in time-wait, waits without limit until the state leaves it.
        The outcome is only logged.
        """
        deadline = monotonic() + timeout
        logging.info("CLOSE")
        if self._state == RatpState.established:
            # FIXME: only F?
            self._send_control('FA', sn=self._next_sn(self._s_sn),
                               an=self._next_sn(self._r_sn))
            self._state = RatpState.fin_wait
        while deadline > monotonic() and not self._state == RatpState.time_wait:
            self.wait(deadline)
        while self._state == RatpState.time_wait:
            self.wait(None)
        if self._state == RatpState.closed:
            logging.info("CLOSE: success")
        else:
            logging.info("CLOSE: failure")

    def abort(self):
        """Log the request; no further action is implemented."""
        logging.info("ABORT")

    def status(self):
        """Return the current RatpState."""
        logging.info("STATUS")
        return self._state
class SerialRatpConnection(RatpConnection):
    """RatpConnection whose raw I/O goes through a serial-port object.

    ``port`` must offer ``read(size)``, ``write(data)``, ``flushInput()`` and
    the ``timeout`` / ``writeTimeout`` attributes.
    """

    def __init__(self, port):
        """Take over ``port``: short read timeout, blocking writes, empty input."""
        super(SerialRatpConnection, self).__init__()
        self.__port = port
        port.timeout = 0.01
        port.writeTimeout = None
        port.flushInput()

    def _write_raw(self, data):
        """Write ``data`` to the port and return the port's write() result."""
        if data:
            logging.debug("-> %r", bytearray(data))
        port = self.__port
        return port.write(data)

    def _read_raw(self, size=1):
        """Read up to ``size`` octets from the port and return them."""
        received = self.__port.read(size)
        if received:
            logging.debug("<- %r", bytearray(received))
        return received