# In [4]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

class CA:
    """Conway's Game of Life on a square grid whose edges wrap around.

    Attributes:
        size: Side length of the square grid.
        x, y: ``np.arange(size)`` index vectors (not used by the class itself).
        cells: ``(size, size)`` float array; 1 marks a live cell, 0 a dead one.
        plt: Image artist returned by ``plt.imshow``; only set once
            ``image_setup`` has been called.
    """

    def __init__(self, size):
        """Create an empty (all-dead) ``size`` x ``size`` grid."""
        self.size = size
        self.x = np.arange(size)
        self.y = np.arange(size)
        self.cells = np.zeros((size, size))

    def pattern_setup(self, pattern):
        """Switch on the cells of ``pattern``, positioned relative to the centre.

        Args:
            pattern: Iterable of ``(d0, d1)`` integer offsets; each is added to
                ``size // 2`` along the first and second grid axis.
        """
        centre = self.size // 2
        for offset in pattern:
            # Wrap so that large patterns land on the torus instead of
            # raising IndexError; in-range offsets are unaffected.
            row = (offset[0] + centre) % self.size
            col = (offset[1] + centre) % self.size
            self.cells[row, col] = 1

    def random_setup(self, n):
        """Switch on ``n`` uniformly drawn cells.

        Positions are drawn with replacement, so fewer than ``n`` distinct
        cells may end up alive.
        """
        for _ in range(n):
            # Use the instance's own size, not a module-level global.
            a, b = np.random.randint(0, self.size, 2)
            self.cells[a, b] = 1

    def image_setup(self):
        """Create the image artist for the current grid and store it in ``self.plt``."""
        self.plt = plt.imshow(
            self.cells,
            interpolation='nearest',
            # 'lower' puts index [0, 0] at the bottom-left; matplotlib only
            # accepts 'upper' or 'lower' here.
            origin='lower',
            # Fixed 0..1 range: deriving it from the data collapses the
            # colour scale when the grid is all dead or all alive at set-up.
            vmin=0,
            vmax=1,
            cmap=plt.cm.binary,
        )

    def update(self):
        """Advance the grid by one generation.

        A live cell (value 1) survives with two or three live neighbours; a
        dead cell (value 0) becomes alive with exactly three. Every other cell
        is dead in the next generation. Neighbourhoods wrap around the edges.
        """
        grid = self.cells
        # Sum the eight shifted copies of the grid. np.roll wraps around, which
        # reproduces the modular (toroidal) neighbourhood in one vectorised pass.
        neighbours = sum(
            np.roll(np.roll(grid, di, axis=0), dj, axis=1)
            for di in (-1, 0, 1)
            for dj in (-1, 0, 1)
            if di or dj
        )
        survives = (grid == 1) & ((neighbours == 2) | (neighbours == 3))
        born = (grid == 0) & (neighbours == 3)
        self.cells = (survives | born).astype(float)

    def plot(self):
        """Push the current grid into the image artist and return the artist.

        ``image_setup`` must have been called first.
        """
        self.plt.set_data(self.cells)
        return self.plt
        
# --- Simulation and figure set-up -------------------------------------------
# Side length of the square grid; also used for the axis limits below.
size = 60

# Seed pattern given as offsets from the grid centre (first axis, second axis).
# NOTE: despite the name, these five cells form an R-pentomino (in a
# rotated/reflected orientation), not the seven-cell "acorn" pattern.
acorn = (
    (1, 2),
    (1, 1),
    (1, 0),
    (2, 2),
    (0, 1),
)

fig, ax = plt.subplots()
ax.set_xlim(-1, size)
ax.set_ylim(-1, size)

# Build the automaton: random cells first, then the seed pattern on top,
# and finally the image artist that the animation will update.
ca = CA(size)
ca.random_setup(900)  # 900 draws with replacement, so at most 900 live cells
ca.pattern_setup(acorn)
ca.image_setup()

def update(data):
    """Animation callback: step the automaton once and redraw it.

    ``data`` is the value supplied by the frame source and is ignored.
    Returns a one-element tuple holding the updated image artist.
    """
    ca.update()
    image = ca.plot()
    return (image,)

def data_gen():
    """Yield the constant 1 forever, acting as an endless frame source."""
    frame = 1
    while True:
        yield frame

# Run the animation: `update` is called for every value produced by
# `data_gen`, with a 50 ms delay between frames and blitting disabled.
# The result is bound to `ani` so the animation object stays referenced
# while the window is open.
ani = animation.FuncAnimation(
    fig,
    update,
    frames=data_gen,
    interval=50,
    blit=False,
)
plt.show()