"""
A simple layer over (Py)OpenGL, for ease of use: number types, uniforms, vertex buffers (VBO/VAO), shader programs and window-level settings.

Nicolas Belin

Last modification : 14-5-2024

"""
import math
import ctypes
import OpenGL
OpenGL.FULL_LOGGING = True
OpenGL.ERROR_ON_COPY = True
from OpenGL.GL import *

try:
    import pygame
except ModuleNotFoundError:
    try:
        import pygame_sdl2 as pygame # Debian packaging
    except ModuleNotFoundError:
        raise ModuleNotFoundError("No module named 'pygame'")

def verb(self, s= ''):
    """
    Print a log line on behalf of an object of this module.

    The line reads "<object_type>[<name>]: <s>"; the "[<name>]" part is left
    out when the object has no `name` attribute.
    :param self: any object having an `object_type` attribute
    :param s: (str) the message to print
    """
    has_name = 'name' in dir(self)
    suffix = f'[{self.name}]' if has_name else ''
    prefix = f"{self.object_type}{suffix}"
    print(f"{prefix}: {s}")

    
# Global switch for the OpenGL error checks written as
# `debug and Wgl.check_GL_error(code)` throughout this module.
debug = True
# Uncomment to silence every log line printed through verb():
#verb = lambda self, s: 0

class NumTypes:
    """
    The names of numbers types + conversion functions.

    Each instance ties together, for one numeric type, the name used by this
    module ('float', 'ushort', ...), the OpenGL type constant and the ctypes
    type. Every instance registers itself under its name in the class-level
    dictionary `_types`; initClass() registers the predefined types of `_pre`.
    """

    # name (str) -> NumTypes instance
    _types: dict = dict()
    # The predefined types, registered by initClass()
    _pre: tuple = (
        {'name':'float', 'gl':GL_FLOAT, 'c':ctypes.c_float},            # coordinates
        {'name':'ushort', 'gl':GL_UNSIGNED_SHORT, 'c':ctypes.c_ushort}, # indexes
        {'name':'uint', 'gl':GL_UNSIGNED_INT, 'c':ctypes.c_uint},       # indexes or RGBA
        {'name':'ubyte', 'gl':GL_UNSIGNED_BYTE, 'c':ctypes.c_ubyte},    # R, G, B or A
        {'name':'byte', 'gl':GL_BYTE, 'c':ctypes.c_byte})               # normal coordinates

    @staticmethod
    def is_defined(name):
        """
        :param name: (str) a type name, e.g. 'float'
        :returns: (bool) True if a type by this name has been registered
        """
        return name in NumTypes._types

    @staticmethod
    def get(name):
        """
        :param name: (str) a registered type name
        :returns: the NumTypes object registered under this name
        :raises KeyError: if no type by this name has been registered
        """
        return NumTypes._types[name]

    @staticmethod
    def initClass():
        """
        Register (again, if already done) the predefined types of `_pre`.
        """
        for d in NumTypes._pre:
            NumTypes(**d)

    def __init__(self, name = None, gl = None, c = None):
        """
        :param name: (str) the name used by this program
        :param gl:   (OpenGL type) the name used by OpenGL functions
        :param c:    (ctypes type) the name used by the ctypes module
        :returns: an object combining the 3 names
        :SE: add the returned object to a dictionary, for later use
        """
        self.name = name
        self.gl = gl
        self.c = c
        self._types[name] = self

    def __repr__(self):
        return f"NumTypes(name={self.name!r}, gl={self.gl!r}, c={self.c!r})"

    def conv(self, nlist):
        """
        Convert a homogeneous collection of numbers to a ctypes array.
        :param nlist: an iterable of numbers (list, tuple, generator, ...)
        :returns: a ctypes array of self.c, one element per number
        """
        values = list(nlist)
        return (self.c * len(values))(*values)

class Uniform:
    """
    A data input to the shaders of a rendering program, constant over a drawing.
    The matrices are to be given by columns (see is_transposed()).

    One Uniform object may be plugged into several programs. Its value is
    uploaded lazily by activate(): once to each program, after every change.
    """
    def __init__(self, name, utype, array_size=1):
        """
        :param name: (str) the name of the uniform, as used in the shaders
        :param utype: (str) the type of uniform: 'float', 'uint', 'vec2', 'vec3',
                      'vec4', 'mat2', 'mat3' or 'mat4'
        :param array_size: (int) the number of elements in the array, if this uniform is to be an array
        :returns: an object, to be used by drawing functions
        :raises Exception: if utype is unknown
        """
        self.name = name
        self.object_type = 'Uniform'
        self._transpose = GL_FALSE
        self._locs = dict()       # program id -> location of the uniform in that program
        self._plug_nr = dict()    # program id -> number of times plugged in that program
        self._array_size = array_size
        # True when the value changed since the last activate() call.
        # Subclasses also set it directly after editing self._value in place.
        self._updated = False
        # Ids of the programs that already received the current value.
        self._uploaded = set()
        types = {'float': (self.vec1fv, 'float'),
                 'uint': (self.vec1uiv, 'uint'),
                 'vec2': (self.vec2fv, 'float'),
                 'vec3': (self.vec3fv, 'float'),
                 'vec4': (self.vec4fv, 'float'),
                 'mat2': (self.matrix2fv, 'float'),
                 'mat3': (self.matrix3fv, 'float'),
                 'mat4': (self.matrix4fv, 'float')}
        if utype in types:
            self._activate_function, self._components_type  = types[utype]
            verb(self, f'New {utype}' + (f'[{array_size}]' if array_size > 1 else '') + ' uniform')
        else:
            raise Exception(f"Uniform(): unknown components' type : '{utype}'")

        self._value = None # a reference to a ctypes array

    def activate(self, program):
        """
        Upload the current value to `program`, unless it already has it.
        Uniform values are stored per program by OpenGL, so each plugged
        program must receive every new value once.
        Nothing is uploaded while no value has been set.
        :param program: (Program object) the program about to draw (in use)
        """
        if self._updated:
            # A new value: every program has to receive it again.
            self._uploaded.clear()
            self._updated = False
        if self._value is None or program.id in self._uploaded:
            return
        self._activate_function(program)
        self._uploaded.add(program.id)

    def plug(self, program):
        """
        Prepare the uniform for use with a rendering program.
        :param program: (Program object) the rendering program
        :returns: True
        :raises Exception: if the program has no active uniform by this name
        """
        program.use()
        loc = glGetUniformLocation(program.id, self.name)
        if loc == -1:
            raise Exception(f"Uniform['{self.name}'].plug('{program.name}'): no uniform by this name in the vertex shader")
        self._locs[program.id] = loc
        self._plug_nr[program.id] = self._plug_nr.get(program.id, 0) + 1
        # A (re)plugged program has not necessarily received the current value.
        self._uploaded.discard(program.id)
        debug and Wgl.check_GL_error(1)
        return True

    def unplug(self, program):
        """
        Remove a program that had been plugged.
        When it has been unplugged as many times as plugged, everything known
        about it is forgotten.
        :param program: (Program object) the program
        """
        self._plug_nr[program.id] -= 1
        if self._plug_nr[program.id] == 0:
            del self._plug_nr[program.id]
            del self._locs[program.id]
            self._uploaded.discard(program.id)

    def vec1fv(self, program):
        """
        This function and the following ones are to be called just before
        calling the drawing functions to initialise the uniform.
        """
        glUniform1fv(self._locs[program.id], self._array_size, self._value)

    def vec1uiv(self, program):
        glUniform1uiv(self._locs[program.id], self._array_size, self._value)

    def vec2fv(self, program):
        glUniform2fv(self._locs[program.id], self._array_size, self._value)

    def vec3fv(self, program):
        glUniform3fv(self._locs[program.id], self._array_size, self._value)

    def vec4fv(self, program):
        glUniform4fv(self._locs[program.id], self._array_size, self._value)

    def matrix2fv(self, program):
        glUniformMatrix2fv(self._locs[program.id], self._array_size,
                           self._transpose, self._value)

    def matrix3fv(self, program):
        glUniformMatrix3fv(self._locs[program.id], self._array_size,
                           self._transpose, self._value)

    def matrix4fv(self, program):
        glUniformMatrix4fv(self._locs[program.id], self._array_size,
                           self._transpose, self._value)

    def is_transposed(self, b):
        """
        Declare whether the matrix was given by lines.
        The value is uploaded again, since its interpretation changes.
        :param b: (bool) False if matrix given by columns (default)
        """
        self._transpose = GL_TRUE if b else GL_FALSE
        self._updated = True

    def set_value(self, v):
        """
        Set the uniform's value (a new ctypes array is created).
        :param v: (list or tuple) the value of the uniform
        """
        self._value = NumTypes.get(self._components_type).conv(v)
        self._updated = True

    def update_element(self, ve, i):
        """
        Update the value of an element, in place.
        No conversion needed.
        :param ve: (python number) the element's new value
        :param i: (int) the index of the element to be updated
        """
        self._value[i] = ve
        self._updated = True
    
class Ortho(Uniform):
    """
    A 4x4 matrix uniform: an orthographic projection along the (Oz) axis.
    """

    def __init__(self, name, xmin, xmax, ymin, ymax, zmin, zmax):
        """
        Create a 4x4 matrix uniform, mapping linearly the box :
        [xmin;xmax] x [ymin;ymax] x [zmin;zmax]
        onto the cube :
        [-1;1] x [-1;1] x [-1;1]
        (zmin is sent to -1 and zmax to +1).
        :param name: (str) name of the uniform
        :param xmin, xmax, ymin,...: (float) coordinates of the box's vertices
        """
        super().__init__(name, 'mat4')
        # Reference bounds: resize() always starts from these.
        self._xmin, self._xmax = xmin, xmax
        self._ymin, self._ymax = ymin, ymax
        self._zmin, self._zmax = zmin, zmax
        self._window_ratio = None
        m = (1.0, 0.0, 0.0, 0.0,
             0.0, 1.0, 0.0, 0.0,
             0.0, 0.0, 1.0, 0.0,
             0.0, 0.0, 0.0, 1.0)
        self.set_value(m)
        self.change(xmin, xmax, ymin, ymax, zmin, zmax)

    def change(self, xmin, xmax, ymin, ymax, zmin, zmax):
        """
        Update the value of the matrix, without changing its reference.
        Note: the bounds given here are not remembered; resize() works
        from the bounds given to the constructor.
        """
        self._value[0] = 2.0 / (xmax - xmin)
        self._value[5] = 2.0 / (ymax - ymin)
        self._value[10] = 2.0 / (zmax - zmin)
        self._value[12] = -(xmax + xmin) / (xmax - xmin)
        self._value[13] = -(ymax + ymin) / (ymax - ymin)
        self._value[14] = -(zmax + zmin) / (zmax - zmin)
        self._updated = True

    def resize(self, w, h):
        """
        Update the value of the matrix, without changing its reference.
        Upon resizing the viewport, maintain a constant w/h ratio: the first
        call records the reference ratio, later calls widen the x range (wider
        window) or the y range (taller window) around its center.
        A viewport with a null or negative size (e.g. a minimized window) is
        ignored.
        :param w: (float) width of the new viewport
        :param h: (float) height of the new viewport
        :SE: change the value of the uniform
        """
        if w <= 0 or h <= 0:
            return
        ratio = w / h
        if self._window_ratio is None:
            self._window_ratio = ratio
            return
        r = ratio / self._window_ratio
        if r >= 1.0:
            l = (self._xmax - self._xmin) * r / 2.0
            m = (self._xmax + self._xmin ) / 2.0
            self.change(m-l, m+l, self._ymin, self._ymax,
                        self._zmin, self._zmax)
        else:
            l = (self._ymax - self._ymin) / r / 2.0
            m = (self._ymax + self._ymin ) / 2.0
            self.change(self._xmin, self._xmax, m-l, m+l,
                        self._zmin, self._zmax)
                
            
class Transform2D(Uniform):
    """
    A 4x4 matrix uniform acting on the (Oxy) plane. It applies to a point:
    * a rotation of angle a (radian) around the axis [Oz)
    * and a scaling of ratio s centered on O (these two commute),
    * then a translation (tx, ty, 0).
    """

    def __init__(self, name, a = 0.0, s = 1.0, tx = 0.0, ty = 0.0):
        """
        :param name: (str) the uniform's name
        :param a: (float) the rotation angle (radian)
        :param s: (float) the scale's ratio
        :param tx, ty: (float) the translation vector's coordinates
        """
        super().__init__(name, 'mat4')
        # Start from the 4x4 identity matrix (16 floats, by columns).
        identity = tuple(1.0 if row == col else 0.0
                         for col in range(4) for row in range(4))
        self.set_value(identity)
        self.update_rot_scale(math.cos(a), math.sin(a), s)
        self.update_trans(tx, ty)

    def update_rot_scale(self, cos_a, sin_a, s):
        """
        Update the rotation and scaling part of the matrix.
        :param cos_a: (float) the cosine of the angle
        :param sin_a: (float) the sine of the angle
        :param s: (float) the scale's ratio
        :SE: change the value of the uniform but keep the reference toward it
        """
        scaled_cos, scaled_sin = s * cos_a, s * sin_a
        # column 0 = (s.cos, s.sin), column 1 = (-s.sin, s.cos)
        for index, coeff in ((0, scaled_cos), (1, scaled_sin),
                             (4, -scaled_sin), (5, scaled_cos)):
            self._value[index] = coeff
        self._updated = True

    def update_trans(self, tx, ty):
        """
        Update the translation vector (last column of the matrix).
        :param tx, ty: (float) the translation vector's coordinates
        :SE: change the value of the uniform but keep the reference toward it
        """
        self._value[12], self._value[13] = tx, ty
        self._updated = True

class Frustum(Uniform):
    """
    A 4x4 matrix uniform used for perspective projection,
    looking along the (Oz) axis, toward negative numbers.
    That is, the matrix squeezes a frustum along the z axis into the cube :
    [-1;1] x [-1;1] x [-1;1]
    The near clipping plane is sent to z = +1 and the far one to z = -1,
    so the nearest fragments get the greatest depth.
    """

    def __init__(self, name, a=0.5, p=1.0, e=10.0, n=1.0, f=-1.0):
        """
        :param name: (str) the uniform's name
        :param a: (float) vertical aperture angle (radian)
        :param p: (float) width/height ratio
        :param e: (float) eye's z
        :param n: (float) near clipping plane z
        :param f: (float) far clipping plane z
        """
        super().__init__(name, 'mat4')
        self._p = p
        m = (1.0, 0.0, 0.0, 0.0,
             0.0, 1.0, 0.0, 0.0,
             0.0, 0.0, 1.0, 0.0,
             0.0, 0.0, 0.0, 1.0)
        self.set_value(m)
        self.change(a, e, n, f, p)

    def change(self, a, e, n, f, p = None, q = None):
        """
        Update the value of the matrix, without changing its reference.
        :param a: (float) vertical aperture angle (radian), unused if q is given
        :param e: (float) eye's z
        :param n: (float) near clipping plane z
        :param f: (float) far clipping plane z
        :param p: (float) width/height ratio; None keeps the current one
        :param q: (float) half-height of the frustum on the near plane;
                  None computes it from a
        :SE: change the value of the uniform
        """
        if p is None:
            p = self._p
        else:
            self._p = p
        if q is None:
            q = (e - n) * math.tan(a / 2.0)
        # From now on, n and f are relative to the eye.
        n -= e
        f -= e
        self._value[0] = 1.0 / (q * p)
        self._value[5] = 1.0 / q
        self._value[10] = -(f + n) / (n * (f - n))
        self._value[11] = 1.0 / n
        self._value[14] = (f * (n + e) + n * (f + e)) / (n * (f - n))
        self._value[15] = -e / n
        self._updated = True

    def resize(self, w, h):
        """
        Update the value of the matrix, without changing its reference.
        Upon resizing the viewport, keep the vertical aperture and adapt the
        horizontal one to the new w/h ratio.
        A viewport with a null or negative size (e.g. a minimized window) is
        ignored.
        :param w: (float) width of the new viewport
        :param h: (float) height of the new viewport
        :SE: change the value of the uniform
        """
        if w <= 0 or h <= 0:
            return
        self._value[0] *= self._p * h / w
        self._p = w / h
        self._updated = True
            
class Transform3D(Uniform):
    """
    A 4x4 matrix uniform to place things along the (Oz) axis (negative).
    Move the camera, with 3 successive transformations :
    Rotate( [Oz), phi ) -> Rotate( [Ox), theta ) -> Translate( 0, 0, r )
    NOTE(review): older documentation named [Oy) for phi, but the
    coefficients set by change() are those of Rx(theta) . Rz(phi).
    """

    def __init__(self, name, theta = 0.0, phi = 0.0, r = 0.0):
        """
        :param name: (str) the uniform's name
        :param theta: (float) angle (radian) of the rotation around axis [Ox)
        :param phi: (float) angle (radian) of the rotation around axis [Oz)
        :param r: (float) translation (0, 0, r)
        """
        super().__init__(name, 'mat4')
        # Start from the 4x4 identity matrix (16 floats, by columns).
        identity = tuple(1.0 if row == col else 0.0
                         for col in range(4) for row in range(4))
        self.set_value(identity)
        self.change(theta, phi, r)

    def change(self, theta, phi, r):
        """
        Update the value of the matrix, without changing its reference.
        :param theta: (float) angle (radian) of the rotation around axis [Ox)
        :param phi: (float) angle (radian) of the rotation around axis [Oz)
        :param r: (float) translation (0, 0, r)
        :SE: change the value of the uniform
        """
        ct, st = math.cos(theta), math.sin(theta)
        cp, sp = math.cos(phi), math.sin(phi)
        # index in the column-major array -> coefficient
        entries = {
            0: cp, 1: ct * sp, 2: st * sp,      # column 0
            4: -sp, 5: ct * cp, 6: st * cp,     # column 1
            9: -st, 10: ct,                     # column 2
            14: r,                              # column 3 (translation)
        }
        for index, coeff in entries.items():
            self._value[index] = coeff
        self._updated = True
        
# Expose the specialised uniforms as attributes of Uniform (e.g. Uniform.Ortho).
Uniform.Ortho, Uniform.Transform2D, Uniform.Frustum, Uniform.Transform3D = (
    Ortho, Transform2D, Frustum, Transform3D)

class VBO:
    """
    Vertex Buffer Object (VBO): a buffer whose data are stored in the video RAM.
    It contains either vertex attributes data or indexes data.
    """
    object_type = 'VBO'

    def __init__(self, type, name = None, usage = 'static'):
        """
        Create a VBO (Vertex Buffer Object).
        :param type: (str) 'attributes' or 'indexes'
        :param name: (str) this VBO's name (defaults to its OpenGL id)
        :param usage: (str) 'static' or 'dynamic' (will the data change ?);
                      any value other than 'static' selects GL_DYNAMIC_DRAW
        :raises Exception: if type is unknown
        """
        # Validate before glGenBuffers() so that no buffer name is leaked.
        if type == 'indexes':
            target = GL_ELEMENT_ARRAY_BUFFER
        elif type == 'attributes':
            target = GL_ARRAY_BUFFER
        else:
            raise Exception("VBO(): unknown buffer type '" + type +"'")
        self.id = glGenBuffers(1)
        self.name = self.id if name is None else name
        self.type = type
        self.target = target
        self.usage = GL_STATIC_DRAW if usage == 'static' else GL_DYNAMIC_DRAW
        self.nr = None          # number of cells (vertices or indexes) in the buffer
        self._divisor = None    # None: per-vertex data, n >= 1: per n instances
        self.attributes = []
        verb(self, f'New {usage} {type} buffer')
        debug and Wgl.check_GL_error(2)

    def _check_attribute(self, d):
        """
        Validate one attribute description and complete it with its OpenGL type.
        :param d: (dict) {'name':name, 'dtype':dtype, 'dim':dim, ...}
        :returns: the ctypes field (name, ctypes array type) of the attribute
        :raises Exception: if dtype is unknown
        """
        name, dtype, dim = d['name'], d['dtype'], d['dim']
        if not NumTypes.is_defined(dtype):
            raise Exception(f"VBO.specify_attributes(): "
                            f" type de données '{dtype}' inconnu")
        ntype = NumTypes.get(dtype)
        d['gltype'] = ntype.gl
        size_vec = ctypes.sizeof(ntype.c) * dim
        if size_vec % 4 != 0:
            print(f"Warning : alignement problem with attribute '{name}' : "
                  f"size {size_vec} is not divisible by 4.")
        return name, ntype.c * dim

    def specify_attributes(self, *attributes_list):
        """
        Define the VBO's attributes (a new call replaces the previous definition).
        :param attributes_list: the attributes, one dict each

        Attribute : {'name':name, 'dtype':dtype, 'dim':dim}

        name: the attribute's name as it appears in the shader
        dtype: the type of data consumed by the attribute (a NumTypes name)
        dim: the attribute's dimension

        The dicts are completed in place with 'gltype' and 'first' (the byte
        offset of the attribute inside a cell).
        :returns: True
        :raises Exception: on a non-attributes VBO or an unknown dtype
        :raises ValueError: if no attribute is given
        """
        if self.type != 'attributes':
            raise Exception("VBO.specify_attributes() : wrong VBO type")
        if not attributes_list:
            raise ValueError("VBO.specify_attributes(): at least one attribute is required")
        self.attributes = []
        if len(attributes_list) > 1: # Interleaved attributes data
            struct_fields = [self._check_attribute(d) for d in attributes_list]
            class cell_type(ctypes.Structure):
                _fields_ = struct_fields
            # Offsets and stride are read from ctypes, so that they match the
            # layout (padding included) of the data built with cell_type.
            for d in attributes_list:
                d['first'] = getattr(cell_type, d['name']).offset
                self.attributes.append(d)
            self.cell_type = cell_type
            self.stride = ctypes.sizeof(cell_type)
            verb(self, f'Contains data for {len(attributes_list)} attributes : ' +
                 ', '.join([f'{a["name"]} of type {a["dtype"]}[{a["dim"]}]' for a in attributes_list]))
        else: # only one attribute data in this VBO
            d = attributes_list[0]
            name, field_type = self._check_attribute(d)
            d['first'] = 0
            self.attributes.append(d)
            self.cell_type = field_type
            self.stride = ctypes.sizeof(field_type)
            verb(self, f'Contains data for 1 attribute : {name} of type {d["dtype"]}[{d["dim"]}]')
        self.cell_size = ctypes.sizeof(self.cell_type)
        return True

    def specify_indexes(self, itype = 'ushort'):
        """
        Define the type of the indexes stored in this VBO.
        :param itype: (str) the NumTypes name of the indexes' type
        :raises Exception: on a non-indexes VBO
        """
        if self.type != 'indexes':
            raise Exception("VBO.specify_indexes() : wrong VBO type")
        ntype = NumTypes.get(itype)
        self.cell_type = ntype.c
        self.cell_size = ctypes.sizeof(self.cell_type)
        self.indexes_type = ntype.gl

    def set_divisor(self, divisor):
        '''
        Set VBO's divisor, to do instanced rendering
        :param divisor: (int)
                 if 0, one value of the attributes is used per vertex (default)
                 if n >= 1, one value of the attribute is used per n instances
        :raises ValueError: if divisor is negative
        '''
        if divisor < 0:
            raise ValueError(f"VBO.set_divisor(): negative divisor {divisor}")
        if divisor > 0:
            verb(self, 'Attributes values will be updated only once per ' + ('instance' if divisor == 1 else f'{divisor} instances'))
            self._divisor = divisor
        else:
            # Back to per-vertex data (the default).
            self._divisor = None

    def set_size(self, cell_nr):
        """
        Allocate (uninitialised) room for cell_nr cells in VRAM.
        :param cell_nr: (int) the number of cells (vertices or indexes)
        """
        # nr counts cells, as in load(): VAO uses it as a vertex/index count.
        self.nr = cell_nr
        glBindBuffer(self.target, self.id)
        glBufferData(self.target, cell_nr * self.cell_size, None, self.usage)
        debug and Wgl.check_GL_error(3)

    def load(self, data, size=None):
        """
        Load the data to the VBO in VRAM
        :param data: (list or tuple) the data to be loaded in the VBO
        :param size: (int) the number of cells of the buffer, if different from data length
        """
        if size is None:
            size = len(data)
        self.nr = size
        glBindBuffer(self.target, self.id)
        glBufferData(self.target, (self.cell_type * self.nr)(*data), self.usage)
        debug and Wgl.check_GL_error(4)

    def update(self, data, offset=0):
        """
        Update a part of VBO's data.
        :param data: (list or tuple) the data to be loaded in the VBO
        :param offset: (int) the number of cells before inserting data
        """
        glBindBuffer(self.target, self.id)
        glBufferSubData(self.target, offset*self.cell_size, (self.cell_type * len(data))(*data))
        debug and Wgl.check_GL_error(4)

    def bind(self, program):
        """
        Bind the VBO to the rendering program (VAO).
        :param program: (Program object) the rendering program
        :returns: True
        :raises Exception: if an attribute is not active in the program
        """
        glBindBuffer(self.target, self.id)
        for attr in self.attributes:
            name = attr['name']
            loc = glGetAttribLocation(program.id, name)
            if loc < 0:
                raise Exception(f"VBO.bind('{program.name}'): attribute '{name}' is not active in vertex shader")
            glVertexAttribPointer(loc, attr['dim'], attr['gltype'], GL_FALSE,
                                  self.stride, ctypes.c_void_p(attr['first']))
            if self._divisor is not None:
                glVertexAttribDivisor(loc, self._divisor)
            glEnableVertexAttribArray(loc)
        debug and Wgl.check_GL_error(5)
        return True

    def delete(self):
        """Delete the buffer in VRAM."""
        glDeleteBuffers(1, [self.id])
        
class VAO:
    """
    Vertex Array Object (VAO) are VRAM objects that specify the attributes and their VBOs
    When VAO is set, only uniforms are still to be transferred to the VRAM before rendering can start
    """
    object_type = 'VAO'
    has_indexes = False
    primitive_type = None
    vtx_nr = 0
    idx_nr = 0
    _index_size = 0     # bytes per index, set by check_vbo() for an index VBO

    def __init__(self, ptype, vbos, name = None):
        """
        Create a VAO.
        :param ptype: (str) the type of the primitive to be used
        :param vbos: (list) a list of VBO objects
        :param name: (str) the name of the VAO
        :raises Exception: on an unknown primitive or inconsistent VBOs
        """
        self.id = glGenVertexArrays(1)
        self.name = self.id if name is None else name
        verb(self, f"Prepare to render {ptype} from " + ', '.join([f'VBO[{vbo.name}]' for vbo in vbos]))
        types = {'triangles': GL_TRIANGLES,
                 'triangle_fan': GL_TRIANGLE_FAN,
                 'triangle_strip': GL_TRIANGLE_STRIP,
                 'lines': GL_LINES,
                 'line_strip': GL_LINE_STRIP,
                 'line_loop': GL_LINE_LOOP,
                 'points': GL_POINTS}
        try:
            if ptype not in types:
                self.primitive_type = None
                raise Exception(f"VAO[{self.name}] : unknown draw primitive : '{ptype}'")
            self.primitive_type = types[ptype]
            self.vbos = vbos
            for vbo in vbos:
                self.check_vbo(vbo)
        except Exception:
            # Do not leak the vertex array generated above.
            glDeleteVertexArrays(1, (ctypes.c_uint * 1)(self.id))
            raise
        Wgl.check_GL_error(6)

    def bind(self):
        """
        Bind the VAO.
        :SE: the VAO is bound to its target
        """
        glBindVertexArray(self.id)

    def unbind(self):
        """
        Unbind the VAO.
        :SE: no VAO is bound
        """
        glBindVertexArray(0)

    def check_vbo(self, vbo):
        """
        Check a VBO length and record the number of elements to draw.
        :param vbo: (VBO object) the VBO to be checked
        :raises Exception: if two per-vertex VBOs have different vertex counts
        """
        if vbo.type == 'indexes':
            self.has_indexes = True
            self.nr = vbo.nr
            self.indexes_type = vbo.indexes_type
            self._index_size = vbo.cell_size
        else: # vbo.type == 'attributes':
            if not self.vtx_nr and vbo._divisor is None:
                self.vtx_nr = vbo.nr
            elif self.vtx_nr != vbo.nr and vbo._divisor is None:
                raise Exception(f"VAO['{self.name}'].check_vbo(): two VBOs have different vertex counts")
            if not self.has_indexes:
                self.nr = self.vtx_nr

    def _index_offset(self, first):
        """Byte offset of the index number `first` inside the index VBO."""
        return first * self._index_size

    def draw(self, first, nr, instances_nr):
        """
        Start the drawing.
        :param first: (int) the first element (vertex, or index if the VAO has
                      an index VBO) to be drawn
        :param nr: (int) draw only nr elements; None: from first to the end
        :param instances_nr: (int) how many instances are to be drawn
        """
        if nr is None:
            nr = max(self.nr - first, 0)
        self.bind()
        if self.has_indexes:
            # With an index buffer bound, the "indices" argument is a byte offset.
            offset = ctypes.c_void_p(self._index_offset(first))
            if instances_nr == 1:
                glDrawElements(self.primitive_type, nr, self.indexes_type, offset)
            else:
                glDrawElementsInstanced(self.primitive_type, nr, self.indexes_type,
                                        offset, instances_nr)
        else:
            if instances_nr == 1:
                glDrawArrays(self.primitive_type, first, nr)
            else:
                glDrawArraysInstanced(self.primitive_type, first, nr, instances_nr)
        self.unbind()
        debug and Wgl.check_GL_error(7)

    def plug(self, program):
        """
        All vertex attributes specifications are encapsulated in the VAO.
        :param program: (Program object) the program which is to be used for rendering
        """
        self.bind()
        for vbo in self.vbos:
            vbo.bind(program)
        self.unbind()
        debug and Wgl.check_GL_error(8)

    def delete(self):
        """
        Delete the VAO.
        """
        array = (ctypes.c_uint * 1)(self.id)
        glDeleteVertexArrays(1, array)
        debug and Wgl.check_GL_error(9)

        
class Program:
    """
    A rendering program, that is a rendering pipeline with some compiled and linked shaders.
    """

    def __init__(self, wgl, name=None, **shaders_d):
        """
        Create a rendering program, compiling and linking shaders.
        Vertex and fragment shaders are mandatory, geometry shader is optional
        :param wgl: (Wgl object) a window to be used for drawing
        :param name: (str) the program's name
        :param vs_filename: (str) the filename of the vertex shader (VS)
        :param fs_filename: (str) the filename of the fragment shader (FS)
        :param gs_filename: (str) the filename of the geometry shader (GS)
        :param vs_source: (str) the source code of the vertex shader (VS)
        :param fs_source: (str) the source code of the fragment shader (FS)
        :param gs_source: (str) the source code of the geometry shader (GS)
        :raises Exception: on a missing source, a compilation or a linking error
        """
        self.is_ready = False
        self.object_type = 'Program'
        self.wgl = wgl
        self.inputs = dict()

        self.id = glCreateProgram()
        self.name = self.id if name is None else name
        stypes = ['vs', 'fs']
        if 'gs_source' in shaders_d or 'gs_filename' in shaders_d:
            stypes.append('gs')
        shaders = []
        try:
            for stype in stypes:
                shaders.append(self.create_shader(stype, *self.get_source(shaders_d, stype)))
            for shader in shaders:
                glAttachShader(self.id, shader)
            Wgl.check_GL_error(11)
            glLinkProgram(self.id)
            Wgl.check_GL_error(12)
            log_length = glGetProgramiv(self.id, GL_INFO_LOG_LENGTH, None)
            Wgl.check_GL_error(13)
            verb(self, 'Linking the shaders')
            if log_length > 0:
                print(glGetProgramInfoLog(self.id).decode('UTF-8'))
            if glGetProgramiv(self.id, GL_LINK_STATUS, None) == GL_FALSE:
                raise Exception('Shaders linking error !')
        finally:
            # The shader objects are no longer needed by this module; an
            # attached shader is only flagged for deletion by OpenGL.
            for shader in shaders:
                glDeleteShader(shader)
        self.is_ready = True

    def get_source(self, d, stype):
        """
        Find the source code of a shader in the keyword arguments.
        :param d: (dict) the keyword arguments given to the constructor
        :param stype: (str) 'vs', 'fs' or 'gs'
        :returns: (source, filename) ; filename is None for an inline source
        :raises Exception: if neither '<stype>_filename' nor '<stype>_source' is given
        """
        file_key = stype + '_filename'
        source_key = stype + '_source'
        if file_key in d:
            filename = d[file_key]
            verb(self, f"Loading {filename}")
            with open(filename, 'r', encoding='utf-8') as shader_file:
                src = shader_file.read()
        elif source_key in d:
            src = d[source_key]
            filename = None
        else:
            raise Exception(f"Program[{self.name}]: require a {stype} source !")
        return src, filename

    def create_shader(self, stype, source, filename):
        """
        Build a shader from its source code.
        :param stype: (str) the shader's type: 'vs', 'fs' or 'gs'
        :param source: (str) the shader's source code
        :param filename: (str) the source filename or None
        :returns: the compiled shader
        :raises ValueError: on an unknown shader type
        :raises Exception: on a compilation error
        """
        if stype not in ('vs', 'fs', 'gs'):
            raise ValueError(f"Program[{self.name}]: unknown shader type '{stype}'")
        gl_type = {'vs': GL_VERTEX_SHADER,
                   'fs': GL_FRAGMENT_SHADER,
                   'gs': GL_GEOMETRY_SHADER}[stype]
        shader = glCreateShader(gl_type)
        glShaderSource(shader, source)
        glCompileShader(shader)
        log_length = glGetShaderiv(shader, GL_INFO_LOG_LENGTH, None)
        label = stype.upper() if filename is None else filename
        verb(self, f"Compiling {label}")
        if log_length > 0:
            verb(self, f'[{label}] Compilation log :')
            print(glGetShaderInfoLog(shader).decode('UTF-8'))
        if glGetShaderiv(shader, GL_COMPILE_STATUS, None) == GL_FALSE:
            glDeleteShader(shader)
            raise Exception('Shader compilation error !')
        return shader

    def use(self):
        glUseProgram(self.id)

    def plug(self, name, primitive, vbos, uniforms, first=0, count=None, rendered=True, instances_nr=1):
        """
        Plug some inputs (VBOs + uniforms) to the program.
        Plugging again under an existing name replaces the previous inputs.
        :param name: (str) the name of the input
        :param primitive: (str) the type of openGL primitive to render
        :param vbos: (list) a list of VBOs
        :param uniforms: (list) a list of uniforms
        :param first: (int) the first element to render
        :param count: (int) the number of elements to render
        :param rendered: (bool) Should these inputs be rendered ?
        :param instances_nr: (int) the number of instances of the vbos to be rendered (duplicates)
        :SE: the dictionary is added to the inputs
        """
        if name in self.inputs:
            # Release the previous VAO and uniform plugs instead of leaking them.
            self.unplug(name)
        plugged = []
        vao = None
        try:
            for uniform in uniforms:
                uniform.plug(self)
                plugged.append(uniform)
            vao = VAO(primitive, vbos, name=name)
            vao.plug(self)
        except Exception:
            # Roll back so that a failed plug leaves no half-registered input.
            if vao is not None:
                vao.delete()
            for uniform in plugged:
                uniform.unplug(self)
            raise
        self.inputs[name] = {'name': name, 'primitive': primitive, 'vbos': vbos, 'uniforms': uniforms,
                             'first': first, 'count': count, 'rendered': rendered, 'vao': vao,
                             'instances_nr':instances_nr}
        verb(self, 'Plugging in ' + ', '.join([('' if instances_nr == 1 else (str(instances_nr) + '*')) + f'VAO[{vao.name}]']
                                              + [f'Uniform[{u.name}]' for u in uniforms]))

    def unplug(self, name):
        """
        Remove the inputs by the name 'name' from the program
        :raises KeyError: if no inputs are plugged under this name
        """
        input_d = self.inputs[name]
        for uniform in input_d['uniforms']:
            uniform.unplug(self)
        input_d['vao'].delete()
        del self.inputs[name]

    def draw(self):
        """
        Render through the program all the inputs previously plugged.
        """
        self.use()
        for input_d in self.inputs.values():
            if input_d['rendered']:
                for uniform in input_d['uniforms']:
                    uniform.activate(self)
                input_d['vao'].draw(input_d['first'], input_d['count'], input_d['instances_nr'])

    def change_draw_range(self, name, first, nr, instances_nr=1):
        """
        Change the range of the elements to be rendered
        in the 'name' plugged inputs
        :param name: (str) the name of the inputs, as defined upon plugging
        :param first: (int) the offset of the first element to render
        :param nr: (int) the number of element to render
        :param instances_nr: (int) the number of instances of the vbos to be rendered (duplicates)
        :raises NameError: if no inputs are plugged under this name
        """
        if name not in self.inputs:
            raise NameError(f"No plugged input by the name '{name}'")
        self.inputs[name]['first'] = first
        self.inputs[name]['count'] = nr
        self.inputs[name]['instances_nr'] = instances_nr
            
class Wgl:
    """
    The window object, OpenGL side.
    Holds the window-level OpenGL configuration (depth test, blending, face
    culling, clear color, line settings, viewport) and the rendering programs
    that are drawn on every refresh.
    """

    def __init__(self, interface, blending = False, depth_test = False,
                 culling = False):
        """
        Create a wgl object that encapsulates the OpenGL configuration
        at the window level.
        :param interface: (sdl object) the sdl object that provides the window
                          and the OpenGL context
        :param blending: (bool) enable transparency (RGBA)
        :param depth_test: (bool) only the nearest fragment is kept and becomes a pixel
        :param culling: (bool) only display one of the triangles' faces
        :raises AssertionError: if both blending and depth_test are requested
        """
        # Reject the incompatible combination before any OpenGL state is
        # changed or this object is registered with the SDL layer.
        assert not (depth_test and blending), "Blending and depth tests aren't compatible"
        self.object_type = 'Wgl'
        self.sdl = interface
        self.programs = dict()
        NumTypes.initClass()
        # register_resize() calls self.resize() right away (sets the viewport)
        self.sdl.register_resize(self)
        self.clear_bits = GL_COLOR_BUFFER_BIT
        self.set_depth_test(depth_test)
        self.set_blending(blending)
        self.set_face_culling(culling)
        verb(self, f"OpenGL version = {glGetString(GL_VERSION).decode('UTF-8')}")
        verb(self, f"OpenGL vendor = {glGetString(GL_VENDOR).decode('UTF-8')}")
        verb(self, f"OpenGL renderer = {glGetString(GL_RENDERER).decode('UTF-8')}")
        verb(self, f"Shading language version = {glGetString(GL_SHADING_LANGUAGE_VERSION).decode('UTF-8')}")
        verb(self, f"Linewidth in {glGetFloat(GL_ALIASED_LINE_WIDTH_RANGE)} pixels")
        Wgl.check_GL_error(14)

    def set_depth_test(self, do_it):
        """
        Do tests on the fragments depth, only the nearest becomes a pixel.
        Also selects which buffers clear() wipes: the depth buffer is only
        cleared while the depth test is enabled.
        :param do_it: (bool)
        """
        if do_it:
            glDepthFunc(GL_GEQUAL)
            glEnable(GL_DEPTH_TEST)
            # OpenGL clamps the clear depth to [0, 1]: this clears to 0.0,
            # the value every fragment passes against with GL_GEQUAL.
            glClearDepth(-1.0)
            self.clear_bits = GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT
        else:
            glDisable(GL_DEPTH_TEST)
            # No depth test anymore: stop clearing the depth buffer.
            self.clear_bits = GL_COLOR_BUFFER_BIT

    def set_blending(self, do_it):
        """
        Blend fragments to make pixels (source alpha over destination).
        :param do_it: (bool)
        """
        if do_it:
            glEnable(GL_BLEND)
            glBlendEquation(GL_FUNC_ADD)
            glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
        else:
            glDisable(GL_BLEND)

    def set_face_culling(self, do_it):
        """
        Only the front face of triangles is displayed.
        The front is given by the vertices' order: counter-clockwise
        (the OpenGL default winding, not changed here).
        :param do_it: (bool)
        """
        if do_it:
            glEnable(GL_CULL_FACE)
        else:
            glDisable(GL_CULL_FACE)

    def resize(self, w, h):
        """
        Define the viewport, that is the window area where OpenGL renders.
        :param w: (int) width in pixels
        :param h: (int) height in pixels
        :SE: the viewport is defined
        """
        verb(self, f"Viewport changed to {w}x{h}")
        glViewport(0, 0, w, h)

    def set_clear_color(self, r, g, b, a=255):
        """
        Define the color used when clearing the screen before rendering.
        :param r, g, b, a: (int) the color's components, bytes (0-255)
        :SE: this color is used hereafter
        """
        glClearColor(r/255, g/255, b/255, a/255)

    def set_lines_antialiasing(self, do_it):
        """
        Do lines have to be antialiased?
        :param do_it: (bool) yes or no
        """
        if do_it:
            glEnable(GL_LINE_SMOOTH)
        else:
            glDisable(GL_LINE_SMOOTH)

    def set_linewidth(self, width):
        """
        Set the width of lines
        :param width: (float) the width in pixels
        """
        glLineWidth(width)

    def add_program(self, program):
        """
        Add a rendering program.
        :param program: (Program object) a rendering program object
        :SE: the program is stored, keyed by its id, and drawn on refresh
        """
        self.programs[program.id] = program

    def clear(self):
        """
        Clear the buffers selected by set_depth_test() (color, maybe depth).
        """
        glClear(self.clear_bits)

    def refresh(self):
        """
        Start the whole rendering process: clear, then draw every program.
        """
        self.clear()
        for program in self.programs.values():
            program.draw()

    @staticmethod
    def check_GL_error(pos):
        """
        Print every pending OpenGL error.
        glGetError() returns (and resets) one error flag per call, so it is
        polled until it yields GL_NO_ERROR.
        :param pos: an identifier of the call site, printed with each message
        :returns: (bool) True if at least one error was pending
        """
        ignored = ("The offending command is ignored and has no other side effect "
                   "than to set the error flag.")
        errors_d = {
            GL_INVALID_ENUM:
                "INVALID_ENUM : An unacceptable value is specified for an "
                "enumerated argument. " + ignored,
            GL_INVALID_VALUE:
                "INVALID_VALUE: A numeric argument is out of range. " + ignored,
            GL_INVALID_OPERATION:
                "INVALID_OPERATION: The specified operation is not allowed in "
                "the current state. " + ignored,
            GL_INVALID_FRAMEBUFFER_OPERATION:
                "INVALID_FRAMEBUFFER_OPERATION: The framebuffer object is not "
                "complete. " + ignored,
            GL_OUT_OF_MEMORY:
                "OUT_OF_MEMORY: There is not enough memory left to execute the "
                "command. The state of the GL is undefined, except for the state "
                "of the error flags, after this error is recorded.",
            GL_STACK_UNDERFLOW:
                "STACK_UNDERFLOW: An attempt has been made to perform an operation "
                "that would cause an internal stack to underflow.",
            GL_STACK_OVERFLOW:
                "STACK_OVERFLOW: An attempt has been made to perform an operation "
                "that would cause an internal stack to overflow."}
        found = False
        # Bounded so a broken/missing context can never make this spin forever.
        for _ in range(32):
            error_id = glGetError()
            if error_id == GL_NO_ERROR:
                break
            found = True
            message = errors_d.get(error_id, f"Unknown OpenGL error {error_id}")
            print(f"glGetError() in #{pos} : {message}")
        return found


class Sdl:
    """
    Simple DirectMedia Layer (SDL2, used through pygame) provides the window,
    the OpenGL context and the management of user input and timer events.
    """
    object_type = 'Sdl'
    pygame = pygame
    is_fullscreen = False
    using_internal_event_loop = True
    looping = True
    # Class-level defaults kept for backward compatibility only: __init__
    # gives every instance its own containers, so two Sdl objects never
    # share their registrations.
    call_on_resize = []
    call_on_refresh = []
    keys_pygame = dict()
    keys_unicode = dict()
    keys_scancode = dict()
    user_events_dict = dict()

    def __init__(self, caption, width = None, height = None, ratio = None):
        """
        Initialise SDL and in particular its Display component, providing
        a window.
        :param caption: (str) the caption of the window
        :param width: (int) the width of the window
        :param height: (int) the height of the window
        :param ratio: (float) the ratio width/height, only used if one of the
                      latter is None; the window size is then derived from 3/4
                      of the first mode returned by pygame.display.list_modes
        """
        verb(self, "Initialising SDL2 (pygame)")
        self.call_on_resize = []
        self.call_on_refresh = []
        self.keys_pygame = dict()
        self.keys_unicode = dict()
        self.keys_scancode = dict()
        self.user_events_dict = dict()
        pygame.init()
        if width is None or height is None:
            w, h = pygame.display.list_modes(0, pygame.FULLSCREEN)[0]
            if ratio is None:
                ratio = w / h
            if ratio >= w / h:
                w = 3 * w // 4
                h = int(w / ratio)
            else:
                h = 3 * h // 4
                w = int(h * ratio)
        else:
            w, h = width, height
        self.width = w
        self.height = h
        self.ratio = w / h
        self.set_screen(w, h)
        pygame.display.set_caption(caption)
        pygame.mouse.set_visible(False)
        # Updated by check_event() on mouse motion and button events.
        self._mouse_pos = None
        # Key callbacks get True on press and False on release: these actions
        # must run once per keystroke, hence the 't and ...' guards.
        self.register_key(lambda t: t and self.stop_looping(), 'q')
        self.register_key(lambda t: t and self.toggle_fullscreen(), 'f')
        self.register_key(lambda t: t and self.dump_framebuffer(), 'd')

    def set_screen(self, w, h):
        """
        Set the dimensions of the screen.
        :param w: (int) screen's width
        :param h: (int) screen's height
        :SE: set or change the screen's size
        """
        flags = pygame.OPENGL | pygame.DOUBLEBUF
        if self.is_fullscreen:
            flags |= pygame.FULLSCREEN
        else:
            flags |= pygame.RESIZABLE

        #self.screen = pygame.display.set_mode((w, h), flags, vsync = 1)
        self.screen = pygame.display.set_mode((w, h), flags)

    def stop_looping(self):
        """
        Post a QUIT event.
        """
        pygame.event.post(pygame.event.Event(pygame.QUIT))

    def refresh(self):
        """
        Refresh the screen: call refresh() on every registered drawing object,
        then flip the double buffer.
        """
        for obj in self.call_on_refresh:
            obj.refresh()
        pygame.display.flip()

    def toggle_fullscreen(self):
        """
        Toggle fullscreen, posting a VIDEORESIZE event with the new size.
        """
        if self.is_fullscreen:
            self.is_fullscreen = False
            ev = pygame.event.Event(pygame.VIDEORESIZE)
            ev.size = (self.w_old, self.h_old)
            pygame.event.post(ev)
        else:
            self.is_fullscreen = True
            self.w_old, self.h_old = self.width, self.height
            w, h = pygame.display.list_modes(0, pygame.FULLSCREEN)[0]
            ev = pygame.event.Event(pygame.VIDEORESIZE)
            ev.size = (w, h)
            pygame.event.post(ev)

        pygame.display.toggle_fullscreen()

    def dump_framebuffer(self):
        """
        Dump the rendered pixels (framebuffer) in the file dump.png
        """
        # Imported here: Pillow is only needed by this debugging helper.
        # (The module cache makes repeated imports cheap.)
        from PIL import Image, ImageOps
        glReadBuffer(GL_FRONT_LEFT)
        raw = glReadPixels(0, 0, self.width, self.height,
                           GL_RGBA, GL_UNSIGNED_BYTE, None)
        img = Image.frombytes('RGBA', (self.width, self.height), raw)
        # OpenGL returns rows bottom-up, PIL expects them top-down.
        ImageOps.flip(img).save('dump.png')

    def get_mouse_position(self):
        """
        Get the mouse position in window coordinates (pixels from the left upper corner),
        or None until the mouse has been moved or clicked.
        """
        return self._mouse_pos

    def mouse_left_button(self, up, x, y):
        '''
        This method is called when the mouse left button is pressed or released
        :param up: (bool) True if pressed, False if released
        :param x: (int) the x position of the pointer in window coordinates
        :param y: (int) the y position of the pointer in window coordinates
        '''
        pass

    def mouse_middle_button(self, up, x, y):
        '''
        This method is called when the mouse middle button is pressed or released
        :param up: (bool) True if pressed, False if released
        :param x: (int) the x position of the pointer in window coordinates
        :param y: (int) the y position of the pointer in window coordinates
        '''
        pass

    def mouse_right_button(self, up, x, y):
        '''
        This method is called when the mouse right button is pressed or released
        :param up: (bool) True if pressed, False if released
        :param x: (int) the x position of the pointer in window coordinates
        :param y: (int) the y position of the pointer in window coordinates
        '''
        pass

    def mousewheel_up(self, x, y):
        '''
        This method is called when the mouse wheel is scrolled up (button 4)
        :param x: (int) the x position of the pointer in window coordinates
        :param y: (int) the y position of the pointer in window coordinates
        '''
        pass

    def mousewheel_down(self, x, y):
        '''
        This method is called when the mouse wheel is scrolled down (button 5)
        :param x: (int) the x position of the pointer in window coordinates
        :param y: (int) the y position of the pointer in window coordinates
        '''
        pass

    def register_key(self, func, key_name):
        """
        Register a function to be called when a key is pressed or released.
        :param func: (function) called with True on press, False on release
        :param key_name: (str) the name of the key, as in the 'pygame' object,
                               but without the prefix 'K_'
        """
        if len(key_name) == 1: # solve 'azerty' keyboards problem
            self.keys_unicode[key_name] = func
        else:
            pygame_key_number = getattr(pygame, 'K_' + key_name)
            self.keys_pygame[pygame_key_number] = func

    def register_user_event(self, name, func):
        """
        Register a user event, calling a function when it happens.
        :param name: (str) the name of the event (the event's 'utype' attribute)
        :param func: (function) the function to be called with the event's 'args'
        """
        self.user_events_dict[name] = func

    def register_drawing(self, obj):
        """
        Register an object, whose method refresh() will
        be called upon refreshing the screen.
        :param obj: (object) an object with a method obj.refresh()
        """
        self.call_on_refresh.append(obj)

    def set_timer(self, func, time):
        """
        Set a function to be called at regular time intervals.
        :param func: (function) called as func(t_now, t_last), in ms ticks
        :param time: (int) the time interval (in ms)
        """
        self.t0 = pygame.time.get_ticks()
        self.timer_func = func
        pygame.time.set_timer(pygame.USEREVENT+1, time)

    def get_current_time(self):
        """
        Return the current time (ms) since SDL has been initialised.
        """
        return pygame.time.get_ticks()

    def register_resize(self, obj):
        """
        Register an object whose method resize will be called upon resize events,
        and call this method a first time.
        :param obj: (object) an object with a method obj.resize(width, height)
        """
        obj.resize(self.width, self.height)
        self.call_on_resize.append(obj)

    def on_resize(self, w, h):
        """
        Called when the window changes size.
        :param w: (int) new window's width
        :param h: (int) new window's height
        """
        if self.width != w or self.height != h:
            self.width, self.height = w, h
            for obj in self.call_on_resize:
                obj.resize(w, h)
            self.refresh()

    def use_internal_event_loop(self, b):
        """
        Use this module event loop or not (default : True)
        Useful when the whole event loop, and in particular the refresh call,
        is outside this module
        :param b: (bool)
        """
        self.using_internal_event_loop = b

    def wait_event(self):
        """
        Block until an event is available and return it.
        """
        return pygame.event.wait()

    def check_event(self, ev = None):
        """
        Process one event, dispatching it to the registered callbacks.
        :param ev: (pygame event) the event to process; when None, the next
                   event is polled from the queue
        :returns: (bool) whether the event loop should keep running
        """
        if ev is None:
            ev = pygame.event.poll()
        if ev.type == pygame.QUIT:
            self.looping = False
        elif ev.type == pygame.KEYDOWN:
            if ev.scancode in self.keys_scancode:
                self.keys_scancode[ev.scancode](True)
            elif ev.unicode in self.keys_unicode:
                # The KEYUP event has no unicode field: remember this key's
                # scancode so that its release can be dispatched too.
                func = self.keys_unicode[ev.unicode]
                self.keys_scancode[ev.scancode] = func
                func(True)
            elif ev.key in self.keys_pygame:
                self.keys_pygame[ev.key](True)
        elif ev.type == pygame.KEYUP:
            if ev.scancode in self.keys_scancode:
                self.keys_scancode[ev.scancode](False)
            elif ev.key in self.keys_pygame:
                self.keys_pygame[ev.key](False)
        elif ev.type == pygame.MOUSEMOTION:
            self._mouse_pos = ev.pos
        elif ev.type == pygame.MOUSEBUTTONDOWN:
            self._mouse_pos = ev.pos
            if ev.button == 1:                          # left button
                self.mouse_left_button(True, *ev.pos)
            elif ev.button == 2:                        # middle button
                self.mouse_middle_button(True, *ev.pos)
            elif ev.button == 3:                        # right button
                self.mouse_right_button(True, *ev.pos)
            elif ev.button == 4:                        # mousewheel up
                self.mousewheel_up(*ev.pos)
            elif ev.button == 5:                        # mousewheel down
                self.mousewheel_down(*ev.pos)
        elif ev.type == pygame.MOUSEBUTTONUP:
            self._mouse_pos = ev.pos
            if ev.button == 1:
                self.mouse_left_button(False, *ev.pos)
            elif ev.button == 2:
                self.mouse_middle_button(False, *ev.pos)
            elif ev.button == 3:
                self.mouse_right_button(False, *ev.pos)
        elif ev.type == pygame.VIDEORESIZE:
            self.on_resize(*ev.size)
        elif ev.type == pygame.VIDEOEXPOSE:
            self.refresh()
        elif ev.type == pygame.USEREVENT:
            # USEREVENTs posted by other code may lack the 'utype' attribute.
            utype = getattr(ev, 'utype', None)
            if utype in self.user_events_dict:
                self.user_events_dict[utype](*ev.args)
        elif self.using_internal_event_loop and ev.type == pygame.USEREVENT+1: # Refresh event
            pygame.event.get(pygame.USEREVENT + 1) # remove all other refresh events in the queue
            t1 = self.t0
            self.t0 = pygame.time.get_ticks()
            self.timer_func(self.t0, t1)
        return self.looping

    def loop(self):
        """
        The event loop : wait for an event to happen, then act accordingly,
        until a QUIT event has been processed.
        """
        while self.looping:
            self.check_event(self.wait_event())

    def quit(self):
        """
        Quit SDL cleanly, closing the window.
        """
        pygame.quit()

class Task:
    """
    A unit of work performed by a Scheduler on every tick (every frame_rate ms).
    Tasks run one after the other by increasing order: a negative order runs
    before the redraw task (order 0.0), a positive order runs after it.
    """
    object_type = 'Task'
    # Counter shared by all tasks (a list so it can be mutated through self),
    # used to name the tasks created without an explicit name.
    id = [0]

    def __init__(self, function, order, name=None):
        """
        :param function: (function) called as function(t_now, t_last, t_start)
                                    and returning (require_redraw, keep_doing)
                                    * t_now: present time (ms)
                                    * t_last: time of the previous call (ms)
                                    * t_start: time when the task was enqueued (ms)
                                    * require_redraw (bool) a redraw is required
                                    * keep_doing (bool) call again next frame
        :param order: (float) when the function is called, relative to the
                              other enqueued tasks. The redraw task has order 0.0
        :param name: (str) the task name; an automatic number when None
        """
        if name is None:
            name = str(self.id[0])
            self.id[0] += 1
        self.perform = function
        self._order = order
        self._next = None
        self._enqueue_time = None
        self._when_done = []
        self.name = name
        verb(self, f'New task with order {order}')

    def enqueue_when_done(self, other):
        """
        Register a follow-up task, enqueued once this task stops
        (i.e. once it returns keep_doing = False).
        :param other: (Task) the follow-up task
        """
        self._when_done += [other]
        
class Scheduler:
    """
    Schedule the execution of tasks (see the Task class) with
    temporal granularity of frame_rate (milliseconds).
    Tasks form a singly linked list sorted by increasing order. On every timer
    tick each task is performed once; a task returning keep_doing = False is
    removed and its follow-up tasks are enqueued.
    """
    _redraw_required = True
    object_type = 'Scheduler'

    def __init__(self, sdl, fps = 30):
        """
        :param sdl: (Sdl) provides the clock, the timer and the screen refresh
        :param fps: (int) expected frames per second; the timer period is 1000 // fps ms
        """
        verb(self, f'Expected frame rate is {fps} fps')
        self._sdl = sdl
        self._frame_rate = 1000 // fps
        self._first_task = None
        self._redraw_nr = 0
        self._start_time = self._sdl.get_current_time()
        self._sdl.set_timer(self.tick, self._frame_rate)
        verb(self, f'Starting now, time is 0 ms')
        self.enqueue_now(Task(self.redraw, 0.0, 'redraw'))

    def tick(self, t0, t1):
        """
        Perform every enqueued task once, in list order.
        :param t0: present time, as given by the timer callback
        :param t1: time of the previous tick
        NOTE(review): t0/t1 come from the timer callback while
        task._enqueue_time is relative to the scheduler start
        (time_since_start) — confirm tasks don't subtract one from the other.
        """
        prev_task = None            # last task kept in the list so far
        task = self._first_task
        while task is not None:
            require_redraw, keep_doing = task.perform(t0, t1, task._enqueue_time)
            self._redraw_required |= require_redraw
            if keep_doing:
                prev_task = task
            else:
                # Unlink the finished task BEFORE enqueueing its follow-ups,
                # otherwise a follow-up inserted next to it would be unlinked too.
                if prev_task is None:
                    self._first_task = task._next
                else:
                    prev_task._next = task._next
                task._next = None
                while task._when_done:
                    self.enqueue_now(task._when_done.pop())
            # Read the successor from the list itself, so a task that was just
            # inserted right after prev_task is performed during this tick.
            task = self._first_task if prev_task is None else prev_task._next

    def redraw(self, t0, t1, first_time):
        """
        The built-in redraw task (order 0.0): refresh the screen when some
        task has required it since the last refresh.
        Counts every call, whether or not a refresh actually happens.
        :returns: (False, True): requires no redraw itself, never stops
        """
        self._redraw_nr += 1
        if self._redraw_required:
            self._sdl.refresh()
            self._redraw_required = False
        return False, True

    def enqueue_now(self, task):
        """
        Insert a task in the list, keeping it sorted by increasing order
        (a task goes before the first task of greater or equal order,
        except that it never goes ahead of an equal-order head).
        :param task: (Task) the task to enqueue
        """
        assert self._start_time >= 0, "Scheduler must be started first !"
        task._enqueue_time = self.time_since_start()
        verb(self, f'Enqueue Task[{task.name}] at {task._enqueue_time} ms')
        if self._first_task is None or task._order < self._first_task._order:
            # New head; this also clears any stale link left from a previous run.
            task._next = self._first_task
            self._first_task = task
            return
        prev_task = self._first_task
        while prev_task._next is not None and prev_task._next._order < task._order:
            prev_task = prev_task._next
        task._next = prev_task._next
        prev_task._next = task

    def stop(self):
        """
        Report the average rate at which the redraw task has been called.
        """
        elapsed = self.time_since_start()
        if elapsed > 0:
            verb(self, f'Frame rate has been {1000.0 * self._redraw_nr / elapsed:.2f} fps')
        else:
            verb(self, 'Frame rate cannot be computed: no time has elapsed')

    def time_since_start(self):
        """
        Return the time (ms) elapsed since the scheduler was created.
        """
        return self._sdl.get_current_time() - self._start_time
        
        
if __name__ == '__main__':
    vertex_shader_source = """
    #version 330 core

    #define INSTANCES_NR 30

    precision mediump float;

    in vec4 a_position;

    uniform vec3 u_center[INSTANCES_NR];
    uniform mat4 u_matrix3;

    out vec4 color;

    void main() {
       vec4 m = a_position;
       int i = gl_InstanceID;
       m.xy *= u_center[i].z;
       m.xy += u_center[i].xy;
       gl_Position = u_matrix3 *  m;
       color = vec4(vec3(1.0), 0.3);
    }
    """

    fragment_shader_source = """
    #version 330 core

    precision mediump float;

    in vec4 color;
    out vec4 fragment_color;

    void main() {
       fragment_color = color;
    }
    """

    def test():
        """
        Demo: translucent disks (one instanced triangle fan) bouncing on the
        floor and the side walls under a constant downward acceleration.
        """
        # When used from another module:
        # from SDL_OpenGL import Sdl, Wgl, Program, VBO, Uniform, Scheduler, Task
        import random

        sdl = Sdl('OpenGL')
        wgl = Wgl(sdl, blending=True)

        program = Program(wgl, vs_source=vertex_shader_source, fs_source=fragment_shader_source)
        if not program.is_ready:
            # Nothing can be drawn without the program: close the window and
            # stop, instead of going on with a dead SDL/OpenGL context.
            sdl.quit()
            return
        wgl.add_program(program)
        sdl.register_drawing(wgl)

        # Background color :
        wgl.set_clear_color(0, 0, 0)

        # Attributes : a triangle fan (center + n+1 rim points) of radius 1
        n = 40
        disk = [(0.0, 0.0)] + [(math.cos(k * 2.0 * math.pi / n), math.sin(k * 2.0 * math.pi / n)) for k in range(n+1)]

        disk_a = VBO('attributes')
        disk_a.specify_attributes({'name':'a_position', 'dtype':'float', 'dim':2})
        disk_a.load(disk)

        # Uniforms :
        xmax = 2.88
        ymax = 1.8
        ortho_u = Uniform.Ortho('u_matrix3', 0.0, xmax, 0.0, ymax, -1.0, 1.0)
        sdl.register_resize(ortho_u)

        # Must match INSTANCES_NR in the vertex shader (size of u_center).
        instances_nr = 30
        vi = 0.001
        positions = []   # flat list of (x, y, radius) per disk
        velocities = []  # flat list of (vx, vy) per disk
        for _ in range(instances_nr):
            r = random.uniform(0.05, 0.5)
            theta = random.uniform(0, math.pi * 2.0)
            positions += [random.uniform(r, xmax-r), random.uniform(r, ymax-r), r]
            velocities += [vi*math.cos(theta), vi*math.sin(theta)]
        center_u = Uniform('u_center', 'vec3', array_size=instances_nr)
        center_u.set_value(positions)

        # Plug the attributes + uniforms in the rendering program :
        program.plug('disk', 'triangle_fan', (disk_a,), (ortho_u, center_u), instances_nr = instances_nr)

        def move(t_now, t_last, t_start):
            # Same elapsed time for every disk: computed once per call.
            dt = t_now - t_last
            for i in range(instances_nr):
                x, y, r = positions[3*i], positions[3*i+1], positions[3*i+2]
                vx, vy = velocities[2*i], velocities[2*i+1]
                vy -= 0.000001 * dt
                x += vx * dt
                y += vy * dt
                # Reflect on the floor and on both side walls (no ceiling).
                if y < r:
                    y = r + r - y
                    vy *= -1.0
                if x > xmax - r:
                    x = 2.0 * (xmax - r) - x
                    vx *= -1.0
                if x < r:
                    x = r + r - x
                    vx *= -1.0
                positions[3*i], positions[3*i+1], positions[3*i+2] = x, y, r
                velocities[2*i], velocities[2*i+1] = vx, vy
            center_u.set_value(positions)
            return True, True

        scheduler = Scheduler(sdl, fps=60)
        move_task = Task(move, -1.0, name='move_disks')

        scheduler.enqueue_now(move_task)

        try:
            sdl.loop()
        finally:
            # Always report and close the window, even if a callback raised.
            scheduler.stop()
            sdl.quit()

    test()
