#-----------------------------------------------------------------------
#
#   Algorithms for Arithmetic in Base Fibonacci
#
#   Greg Ewing, November 2022
#
#   These routines take and return Base Fibonacci (BF) numbers
#   represented as lists of (0, 1).
#
#   The weights of the bits in an n-bit BF number are the Fibonacci
#   numbers [F(n), F(n-1), .., F(2), F(1)] where F(0) == 1, F(1) == 1.
#
#-----------------------------------------------------------------------

# Cache for memoising Fibonacci numbers; _Fc[n] holds F(n).
_Fc = [1, 1]

def F(n):
  """Returns the nth Fibonacci number, where F(0) == 1, F(1) == 1, F(2) == 2.

  Values are memoised in the module-level list _Fc, which is extended on
  demand. Raises ValueError if n is negative.
  """
  # A negative n would otherwise index the cache from its end and return
  # a value that depends on how far the cache happens to have grown.
  if n < 0:
    raise ValueError("Fibonacci index must be non-negative, got %s" % n)
  while len(_Fc) <= n:
    _Fc.append(_Fc[-1] + _Fc[-2])
  return _Fc[n]

#-----------------------------------------------------------------------
#   Conversion
#-----------------------------------------------------------------------

def fib_to_str(f):
  "Returns the bits of the given BF number as a string of '0' and '1' characters, in the same order."
  digits = "01"
  chars = []
  for bit in f:
    # Each bit is used directly as an index into the digit alphabet.
    chars.append(digits[bit])
  return "".join(chars)

def str_to_fib(s):
  """Returns the BF number (list of 0/1) spelled out by a string of '0' and '1'.

  Raises ValueError on any other character.
  """
  bits = []
  for ch in s:
    if ch not in ("0", "1"):
      raise ValueError("Invalid BF digit %s" % ch)
    bits.append(int(ch == "1"))
  return bits

def int_to_fib(n):
  """Convert unsigned int to BF.

  Works greedily from the largest Fibonacci weight not exceeding n down
  to F(1), so the result has no leading zero. Zero (and anything below 1)
  converts to the empty list.
  """
  # Find the index of the highest weight that fits into n.
  top = 0
  while F(top + 1) <= n:
    top += 1
  bits = []
  remaining = n
  for idx in range(top, 0, -1):
    weight = F(idx)
    take = weight <= remaining
    bits.append(int(take))
    if take:
      remaining -= weight
  return bits

def int_to_fib_signed(i, n):
  """Convert signed int to BF using an n-bit signed encoding.

  A negative i is first offset by F(n + 1); the result is then converted
  with int_to_fib, so it is not padded out to n bits.
  """
  unsigned = i + F(n + 1) if i < 0 else i
  return int_to_fib(unsigned)

def fib_to_list(f):
  """Returns the Fibonacci weights of the bits set in the given BF number.

  The last bit has weight F(1), the one before it F(2), and so on; the
  weights are listed starting from the last bit.
  """
  return [F(pos) for pos, bit in enumerate(reversed(f), start=1) if bit == 1]

def fib_to_int(f):
  "Convert BF to int by adding up the weights of its set bits."
  total = 0
  for weight in fib_to_list(f):
    total += weight
  return total

def fib_to_int_signed(f, n):
  """Convert n-bit signed BF to int.

  The unsigned value is taken modulo-style against F(n + 1): values of
  F(n + 1) // 2 or more are treated as negative and have F(n + 1)
  subtracted.
  """
  value = fib_to_int(f)
  modulus = F(n + 1)
  return value - modulus if value >= modulus // 2 else value

def lzstrip_fib(f):
  """Strip leading zeroes from a BF number.

  Returns a new list; f itself is not modified. A number consisting only
  of zeroes (or an empty one) yields an empty result.
  """
  # One scan plus one slice; deleting from the front bit by bit would
  # shift the remaining list on every step.
  for idx, bit in enumerate(f):
    if bit != 0:
      return f[idx:]
  return f[:0]

def lzpad_fib(f, n):
  """Pad BF number to length n with leading zeroes.

  Returns a new list; f itself is not modified. A number already n bits
  or longer is returned (as a copy) unchanged.
  """
  # A non-positive repeat count gives an empty prefix, so inputs that are
  # already long enough pass through untouched. Building the prefix in
  # one go avoids repeatedly inserting at the front of the list.
  return [0] * (n - len(f)) + list(f)

def zero_extend_fib(f, n):
  """Make BF exactly n bits long by padding with leading zeroes.

  Raises ValueError if f is already longer than n bits. Nothing is
  written to stdout; the failure is reported through the exception only.
  """
  if len(f) > n:
    raise ValueError("BF number exceeds %s bits" % n)
  return lzpad_fib(f, n)

#-----------------------------------------------------------------------
#   Normalisation
#-----------------------------------------------------------------------

# State machine behind _norm_parity. The output is always the state on
# entry; the new state is 1 only when a 1 bit arrives in state 0.
_norm_parity_state_transitions = {
  # (state, input): (output, state)
  (0, 0): (0, 0),
  (0, 1): (0, 1),
  (1, 0): (1, 0),
  (1, 1): (1, 0),
}

def _norm_parity(f):
  """Returns the parity bits used by normalise_fib, one per bit of f.

  Scanning f from its first bit, the bit emitted at a position is 1
  exactly when the run of consecutive 1 bits ending just before that
  position has odd length.
  """
  transitions = _norm_parity_state_transitions
  parities = []
  current = 0
  for bit in f:
    emitted, current = transitions[current, bit]
    parities.append(emitted)
  return parities

def _expand_norm_state_transitions(inp):
  """Expands the wildcards in a normalisation transition table.

  Keys of inp are (parity, input, carry, n) tuples. A parity or carry of
  "*" stands for both 0 and 1, and an input of "**" stands for all four
  bit pairs. Returns a new dict with one entry per concrete key, each
  mapped to the value of the entry it was expanded from.
  """
  any_bit = (0, 1)
  any_pair = ((0, 0), (0, 1), (1, 0), (1, 1))
  expanded = {}
  for key, value in inp.items():
    parity_spec, input_spec, carry_spec, n = key
    parity_choices = any_bit if parity_spec == "*" else (parity_spec,)
    input_choices = any_pair if input_spec == "**" else (input_spec,)
    carry_choices = any_bit if carry_spec == "*" else (carry_spec,)
    for parity in parity_choices:
      for pair in input_choices:
        for carry in carry_choices:
          expanded[parity, pair, carry, n] = value
  return expanded

# Transition table for the pass made by normalise_fib, which walks the
# number from its last bit towards its first.
#
# Key fields:
#   Parity - the _norm_parity output for the current position
#   Input  - the pair (bit before the current one, current bit)
#   Carry  - carry state from the previous step
#   N      - second state bit; when it is 1 the step always outputs 0,
#            sets Carry to 1 and clears N, whatever the other fields are
# "*" (Parity, Carry) and "**" (Input) are wildcards that
# _expand_norm_state_transitions replaces by every concrete value.
#
# NOTE(review): with N == 0 there are no entries for Input (0, 1) with
# Carry 1, nor for Parity 0 / Input (1, 1) / Carry 1. Presumably those
# combinations cannot arise; if one did, the lookup in normalise_fib
# would raise KeyError. Confirm.
_norm_state_transitions = _expand_norm_state_transitions({
  # (Parity, Input, Carry, N) --> (Output, Carry, N)
  ("*", (0, 0),  0,  0): (0, 0, 0),
  ("*", (0, 1),  0,  0): (1, 0, 0),
  ("*", (1, 0),  0,  0): (0, 0, 0),
  ( 1,  (1, 1),  0,  0): (0, 0, 1),
  ( 0,  (1, 1),  0,  0): (1, 0, 0),
  ("*", (0, 0),  1,  0): (1, 0, 0),
  ( 1,  (1, 0),  1,  0): (0, 0, 1),
  ( 0,  (1, 0),  1,  0): (1, 0, 0),
  ( 1,  (1, 1),  1,  0): (1, 0, 1),
  ("*",  "**",  "*", 1): (0, 1, 0),
})

def normalise_fib(f):
  """Run the normalisation state machine over a BF number.

  The number is extended by two leading zero bits, its parity bits are
  computed, and the transition table is then applied from the last bit
  towards the first, producing len(f) + 1 output bits. Returns those
  bits with leading zeroes stripped.
  """
  # NOTE(review): presumably the result is the form with no two adjacent
  # 1 bits (e.g. [1, 1] becomes [1, 0, 0]) - confirm for the general case.
  width = len(f)
  padded = [0, 0] + f
  parities = _norm_parity(padded)
  carry = 0
  n_state = 0
  emitted = []
  for pos in range(1, width + 2):
    key = (parities[-pos], (padded[-pos - 1], padded[-pos]), carry, n_state)
    out, carry, n_state = _norm_state_transitions[key]
    emitted.append(out)
  # Bits were produced last-to-first; put them back in reading order.
  emitted.reverse()
  return lzstrip_fib(emitted)

#-----------------------------------------------------------------------
#   Negation
#-----------------------------------------------------------------------

# State machine behind _generate_negation_end_bits. State 0 means no
# 1 bit has been seen yet; the first 1 bit moves it to state 1 for good.
_negation_end_state_transitions = {
# (State, Input) --> (Output, State)
  (0, 0): (1, 0),
  (0, 1): (1, 1),
  (1, 0): (0, 1),
  (1, 1): (0, 1),
}

def _generate_negation_end_bits(f):
  """Returns the "end" bits used by negate_fib, one per bit of f.

  The end bit at a position is 1 exactly when no 1 bit occurs anywhere
  after that position in f.
  """
  transitions = _negation_end_state_transitions
  seen = 0
  flags = []
  for bit in reversed(f):
    flag, seen = transitions[seen, bit]
    flags.append(flag)
  # The scan ran from the last bit backwards; restore reading order.
  flags.reverse()
  return flags

# Transition table for negate_fib. It holds two kinds of entry, selected
# by the D flag carried over from the previous step:
#
#   D == 0: keyed by (End, State); the value is a list indexed by Input,
#           where Input is (current bit << 1) + next bit and End is the
#           end bit for the next position.
#   D == 1: keyed by State alone; End and Input are ignored for that step.
#
# Each list has three slots, so an Input of 3 (two adjacent 1 bits)
# raises IndexError, and a None slot fails when negate_fib unpacks it.
# NOTE(review): presumably the None slots and the missing (End, State)
# keys are unreachable for normalised input - confirm.
_negation_state_transitions = {
# D == 0: (End, State) --> (Out, State, D)[Input]
  (0, 2): [(0, 3, 0), (0, 3, 0), None     ],
  (0, 3): [(1, 2, 0), (0, 3, 1), (0, 2, 0)],
  (1, 0): [(0, 0, 0), None,      None     ],
  (1, 2): [(1, 0, 0), (0, 3, 0), None     ],
  (1, 3): [None     , (1, 0, 1), (0, 2, 1)],
# D == 1: state --> (Out, State, D)
  0: (0, 0, 0),
  2: (1, 0, 0),
  3: (1, 2, 0),
}

def negate_fib(f):
  """Negate a BF number, returning a result of the same length as f.

  The number is padded with one leading and two trailing zero bits and
  fed, first bit first, through the negation state machine together with
  its end bits. The output bit produced for the leading pad is dropped.
  """
  # NOTE(review): hand-tracing small normalised inputs gives the additive
  # inverse modulo F(len(f) + 1), e.g. [0, 0, 1] -> [1, 0, 1]; confirm
  # for the general case. Inputs with two adjacent 1 bits raise
  # IndexError from the table lookup.
  table = _negation_state_transitions
  width = len(f)
  padded = [0] + f + [0, 0]
  end_bits = _generate_negation_end_bits(padded)
  state, delay = 2, 0
  bits = []
  for pos in range(width + 1):
    pair = (padded[pos] << 1) + padded[pos + 1]
    at_end = end_bits[pos + 1]
    # A step flagged by the previous one ignores the input entirely.
    entry = table[state] if delay else table[at_end, state][pair]
    out, state, delay = entry
    bits.append(out)
  if state & 2:
    bits[-1] = 1
  return bits[1:]

#-----------------------------------------------------------------------
#   Comparison and Sign Testing
#-----------------------------------------------------------------------

# Transition table for compare_fibs_unsigned.
#
# State is the verdict so far: 0 while all bits seen were equal, 1 once
# the first number is known to be greater, -1 once it is known to be
# smaller. Input is (f1 bit << 1) + f2 bit, i.e. 0 = both 0, 1 = only f2
# set, 2 = only f1 set, 3 = both set. States 1 and -1 never change again.
_comparison_state_transitions = {
# State --> State[Input]
   0: [ 0, -1,  1,  0],
   1: [ 1,  1,  1,  1],
  -1: [-1, -1, -1, -1],
}

def compare_fibs_unsigned(f1, f2):
  """Unsigned comparison of BF numbers. Returns -1, 0, 1 for <, =, >.

  Both numbers are zero-extended to a common length and compared bit by
  bit from the first bit; the first differing bit decides. Because the
  comparison is bitwise, it reflects numeric order only when both inputs
  are in the same form: [0, 1, 1] and [1, 0, 0] both have value 3 but
  compare as unequal.
  """
  width = max(len(f1), len(f2))
  left = zero_extend_fib(f1, width)
  right = zero_extend_fib(f2, width)
  verdict = 0
  for a, b in zip(left, right):
    verdict = _comparison_state_transitions[verdict][(a << 1) + b]
  return verdict

# State machine behind first_negative_fib. It takes no input and simply
# cycles 0 -> 1 -> 2 -> 0, emitting the repeating bit pattern 0, 1, 0.
_first_negative_state_transitions = {
# State --> (Output, State)
  0: (0, 1),
  1: (1, 2),
  2: (0, 0),
}

def first_negative_fib(n):
  """Returns the n-bit BF number that fib_sign uses as the negative threshold.

  The bits follow the repeating pattern 0, 1, 0 starting at the first
  bit; the output the state machine would produce next is then OR-ed
  into the last bit. For n == 0 the result is the empty number. Raises
  ValueError if n is negative.
  """
  # NOTE(review): checked by hand for n = 1..7, where the value equals
  # F(n + 1) // 2, the threshold used by fib_to_int_signed.
  if n < 0:
    raise ValueError("Bit count must be non-negative, got %s" % n)
  tbl = _first_negative_state_transitions
  result = []
  state = 0
  for _ in range(n):
    out, state = tbl[state]
    result.append(out)
  # With zero bits there is no last bit to adjust.
  if result:
    out, state = tbl[state]
    result[-1] |= out
  return result

def fib_sign(f, n):
  """Returns 1 if n-bit signed BF number is negative, else 0.

  A number counts as negative when it compares, unsigned, greater than
  or equal to first_negative_fib(n).
  """
  return int(compare_fibs_unsigned(f, first_negative_fib(n)) >= 0)
  
def compare_fibs_signed(f1, f2):
  """Signed comparison of BF numbers. Returns -1, 0, 1 for <, =, >.

  Both numbers are interpreted as signed values of the same width, the
  length of the longer one. Each is prefixed with an extra leading bit
  and the results are compared unsigned.
  """
  n = max(len(f1), len(f2))
  # The extra bit is the INVERTED sign (1 for non-negative, 0 for
  # negative), so every non-negative value ranks above every negative
  # one. Prefixing the raw sign bit would rank negatives highest.
  # Within one sign, unsigned bit order already matches signed order.
  sf1 = [1 - fib_sign(f1, n)] + zero_extend_fib(f1, n)
  sf2 = [1 - fib_sign(f2, n)] + zero_extend_fib(f2, n)
  return compare_fibs_unsigned(sf1, sf2)

#-----------------------------------------------------------------------
#   Addition
#-----------------------------------------------------------------------

# Transition table for add_fibs.
#
# Each state maps to a list indexed by Input, where Input is the sum of
# the two operand bits at the current position (0, 1 or 2). Each entry
# gives the output bit and the next state. add_fibs tests the final
# state with "& 8", which is non-zero only for states 8 and 9.
#
# NOTE(review): state 9 has no entry for Input 2 (None), which would
# fail when add_fibs unpacks it; presumably unreachable for normalised
# operands - confirm.
_addition_state_transition_table = {
# State --> [(Output, State)][Input]
  0: [(0, 0), (0, 4), (0, 9)],
  1: [(0, 2), (0, 8), (1, 0)],
  2: [(0, 4), (0, 9), (1, 1)],
  4: [(0, 8), (1, 0), (1, 4)],
  8: [(1, 0), (1, 4), (1, 9)],
  9: [(1, 2), (1, 8), None         ],
}

def add_fibs(f1, f2):
  """Add BF numbers.

  Both operands are zero-extended to a common length and given two extra
  trailing zero bits, then fed first bit first through the addition
  state machine, one column sum at a time. The raw output is passed
  through normalise_fib, so the result carries no leading zeroes.
  """
  width = max(len(f1), len(f2))
  a = zero_extend_fib(f1, width) + [0, 0]
  b = zero_extend_fib(f2, width) + [0, 0]
  state = 0
  digits = []
  for x, y in zip(a, b):
    out, state = _addition_state_transition_table[state][x + y]
    digits.append(out)
  # A final state of 8 or 9 forces the last output bit on.
  if state & 8:
    digits[-1] = 1
  return normalise_fib(digits)
