113 lines
3.2 KiB
Python
113 lines
3.2 KiB
Python
|
import sophus
|
||
|
import sympy
|
||
|
import sys
|
||
|
import unittest
|
||
|
|
||
|
|
||
|
class Complex:
    """Complex number ``real + imag*i`` stored as a pair of components.

    The components are stored exactly as passed to the constructor, so they
    may be plain numbers or sympy expressions.  ``subs`` requires components
    that provide a ``subs`` method, and ``simplify`` hands the components to
    ``sympy.simplify``.
    """

    # Defining __eq__ without __hash__ already makes instances unhashable in
    # Python 3; state it explicitly so the intent is visible.
    __hash__ = None

    def __init__(self, real, imag):
        """Store the real and imaginary components unchanged."""
        self.real = real
        self.imag = imag

    def __mul__(self, right):
        """Return the complex product ``self * right``.

        ``right`` only needs to expose ``real`` and ``imag`` attributes.
        """
        return Complex(self.real * right.real - self.imag * right.imag,
                       self.imag * right.real + self.real * right.imag)

    def __add__(self, right):
        """Return the component-wise sum ``self + right``."""
        return Complex(self.real + right.real, self.imag + right.imag)

    def __neg__(self):
        """Return the additive inverse ``-self``."""
        return Complex(-self.real, -self.imag)

    def __truediv__(self, scalar):
        """Return ``self`` with both components divided by ``scalar``."""
        return Complex(self.real / scalar, self.imag / scalar)

    def __repr__(self):
        """Return a representation of the form ``( <real> + <imag>i )``."""
        return "( " + repr(self.real) + " + " + repr(self.imag) + "i )"

    def __getitem__(self, key):
        """Return a component using the convention ``[real, imag]``.

        Key 0 selects the real part and key 1 the imaginary part.

        Raises:
            IndexError: for any other key.  Besides rejecting bad indices,
                this is what terminates Python's ``__getitem__``-based
                iteration, so ``list(c)`` and ``re, im = c`` work.
        """
        # Compare with == rather than indexing a tuple so that integer-like
        # keys (e.g. sympy integers) are accepted as well.
        if key == 0:
            return self.real
        if key == 1:
            return self.imag
        raise IndexError("Complex index out of range (expected 0 or 1)")

    def squared_norm(self):
        """Return ``real**2 + imag**2``, treating the number as a 2-tuple."""
        return self.real**2 + self.imag**2

    def conj(self):
        """Return the complex conjugate (imaginary part negated)."""
        return Complex(self.real, -self.imag)

    def inv(self):
        """Return the multiplicative inverse ``conj(self) / squared_norm``."""
        return self.conj() / self.squared_norm()

    @staticmethod
    def identity():
        """Return the multiplicative identity ``1 + 0i``."""
        return Complex(1, 0)

    @staticmethod
    def zero():
        """Return the additive identity ``0 + 0i``."""
        return Complex(0, 0)

    def __eq__(self, other):
        """Return True when ``other`` is a Complex with equal components.

        Any object that is not a ``Complex`` compares unequal.
        """
        if isinstance(other, Complex):
            return self.real == other.real and self.imag == other.imag
        return False

    def subs(self, x, y):
        """Return a copy with ``subs(x, y)`` applied to both components.

        Both components must provide a ``subs`` method (as sympy
        expressions do).
        """
        return Complex(self.real.subs(x, y), self.imag.subs(x, y))

    def simplify(self):
        """Return a copy with ``sympy.simplify`` applied to each component."""
        return Complex(sympy.simplify(self.real),
                       sympy.simplify(self.imag))

    @staticmethod
    def Da_a_mul_b(a, b):
        """Derivative of complex multiplication wrt left multiplier a.

        Returns the 2x2 sympy matrix ``[[b.real, -b.imag], [b.imag, b.real]]``.
        The product is linear in ``a``, so ``a`` itself does not appear in
        the result.
        """
        return sympy.Matrix([[b.real, -b.imag],
                             [b.imag, b.real]])

    @staticmethod
    def Db_a_mul_b(a, b):
        """Derivative of complex multiplication wrt right multiplicand b.

        Returns the 2x2 sympy matrix ``[[a.real, -a.imag], [a.imag, a.real]]``.
        The product is linear in ``b``, so ``b`` itself does not appear in
        the result.
        """
        return sympy.Matrix([[a.real, -a.imag],
                             [a.imag, a.real]])
|
||
|
|
||
|
|
||
|
class TestComplex(unittest.TestCase):
    """Symbolic unit tests for Complex built from real-valued sympy symbols."""

    def setUp(self):
        """Create two symbolic complex numbers, a = x + y*i and b = u + v*i."""
        re_a, im_a = sympy.symbols('x y', real=True)
        self.a = Complex(re_a, im_a)
        re_b, im_b = sympy.symbols('u v', real=True)
        self.b = Complex(re_b, im_b)

    def test_muliplications(self):
        """a times its inverse simplifies to the identity on either side."""
        right_inverse = self.a * self.a.inv()
        self.assertEqual(right_inverse.simplify(), Complex.identity())

        left_inverse = self.a.inv() * self.a
        self.assertEqual(left_inverse.simplify(), Complex.identity())

    def test_derivatives(self):
        """The closed-form Jacobians of a*b match symbolic differentiation."""
        lhs, rhs = self.a, self.b

        def jacobian(wrt):
            # Entry (row, col) is d (a*b)[row] / d wrt[col]; the product is
            # rebuilt for every entry, exactly as a direct lambda would.
            return sympy.Matrix(
                2, 2,
                lambda row, col: sympy.diff((lhs * rhs)[row], wrt[col]))

        self.assertEqual(jacobian(lhs), Complex.Da_a_mul_b(lhs, rhs))
        self.assertEqual(jacobian(rhs), Complex.Db_a_mul_b(lhs, rhs))
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # unittest.main() runs the tests in this module and then terminates the
    # interpreter through sys.exit(), so no statement may follow it here.
    unittest.main()
|