
import time, wgpu, ctypes
from typing import Callable
from rendercanvas.auto import RenderCanvas, loop

class Timer:
    """Measure the time elapsed, in seconds, since the Timer was created.

    Uses ``time.perf_counter()``, a monotonic high-resolution clock, so the
    elapsed time never jumps backwards or forwards when the system clock is
    adjusted (which ``time.time()`` is subject to).
    """

    def __init__(self):
        # Reference instant; only differences of perf_counter() are meaningful.
        self._starting_time = time.perf_counter()

    def get_current(self):
        """Return the number of seconds elapsed since this Timer was created."""
        return time.perf_counter() - self._starting_time


# WGSL vertex shader for a quad drawn as a 4-vertex triangle strip with no
# vertex buffer: corner positions are picked with @builtin(vertex_index).
#
# The uniform struct must match the Python ``uniform_data`` ctypes array
# (4 x c_float, 16 bytes):
#   index 0    -> time
#   index 1    -> padding (vec2<f32> is 8-byte aligned in WGSL uniform layout)
#   index 2, 3 -> scale.x, scale.y
# Clip-space corners are multiplied by ``u_data.scale`` (set on resize to keep
# the aspect ratio), while ``position`` forwards the unscaled corner remapped
# from [-1, 1] to [0, 1] to the fragment stage at @location(0).
vertex_shader_source = """

struct VertexInput {
    @builtin(vertex_index) vertex_index : u32,
};

struct Uniform_data {
    time: f32,
    scale: vec2<f32>,
};

@group(0) @binding(0) var<uniform> u_data: Uniform_data;

struct VertexOutput {
    @location(0) position: vec2f,
    @builtin(position) pos: vec4f,
};

@vertex
fn vs_main(in: VertexInput) -> VertexOutput {
    var positions = array<vec2f, 4>(
        vec2f(-1.0, -1.0),
        vec2f(-1.0, 1.0),
        vec2f(1.0, -1.0),
        vec2f(1.0, 1.0),
    );

    let index = i32(in.vertex_index);
    var out: VertexOutput;
    out.pos = vec4f(positions[index]*u_data.scale, 0.0, 1.0);
    out.position = (positions[index] + vec2f(1.0, 1.0)) / 2.0;
    return out;
}
"""


# %% Functions to create wgpu objects

def get_render_pipeline_kwargs(canvas, device, pipeline_layout, render_texture_format):
    """Configure the canvas context and build ``create_render_pipeline`` kwargs.

    The canvas' wgpu context is configured for ``device`` using
    ``render_texture_format``, or the context's preferred format for the
    device's adapter when that argument is None.  Both shader modules are
    compiled here: the module-level ``vertex_shader_source`` (entry point
    ``vs_main``) and ``fragment_shader_source`` (entry point ``main``).

    Returns a dict meant to be splatted into
    ``device.create_render_pipeline(**kwargs)``.
    """
    ctx = canvas.get_wgpu_context()
    texture_format = (
        ctx.get_preferred_format(device.adapter)
        if render_texture_format is None
        else render_texture_format
    )
    ctx.configure(device=device, format=texture_format)

    vs_module = device.create_shader_module(code=vertex_shader_source)
    fs_module = device.create_shader_module(code=fragment_shader_source)

    # Empty blend components leave every field at the API default values.
    color_target = {
        "format": texture_format,
        "blend": {"alpha": {}, "color": {}},
    }

    return {
        "layout": pipeline_layout,
        "vertex": {"module": vs_module, "entry_point": "vs_main"},
        # The quad is emitted as 4 vertices forming a triangle strip.
        "primitive": {"topology": wgpu.PrimitiveTopology.triangle_strip},
        "depth_stencil": None,
        "multisample": None,
        "fragment": {
            "module": fs_module,
            "entry_point": "main",
            "targets": [color_target],
        },
    }


def create_pipeline_layout(device):
    """Create the uniform buffers, the bind group and the pipeline layout.

    Two buffers of ``ctypes.sizeof(uniform_data)`` bytes are allocated:

    * the uniform buffer, bound at ``@group(0) @binding(0)`` and visible to
      the vertex and fragment stages;
    * a mappable staging buffer, attached to it as
      ``uniform_buffer.copy_buffer``, used only by the map-and-copy upload
      path.

    Returns ``(pipeline_layout, uniform_buffer, bind_groups)``.
    """
    nbytes = ctypes.sizeof(uniform_data)

    # GPU-side uniform buffer; its contents are uploaded every frame.
    uniform_buffer = device.create_buffer(
        size=nbytes,
        usage=wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST,
    )
    # Staging buffer for the alternative upload path
    # (write the mapped memory, then copy buffer-to-buffer on the GPU).
    uniform_buffer.copy_buffer = device.create_buffer(
        size=nbytes,
        usage=wgpu.BufferUsage.MAP_WRITE | wgpu.BufferUsage.COPY_SRC,
    )

    # A single bind group (group 0) holding one entry: the uniform buffer.
    # Outer list index == bind group index; inner list == that group's entries.
    group_entry_specs = [
        [
            {
                "binding": 0,
                "resource": {
                    "buffer": uniform_buffer,
                    "offset": 0,
                    "size": uniform_buffer.size,
                },
            }
        ]
    ]
    group_layout_specs = [
        [
            {
                "binding": 0,
                "visibility": wgpu.ShaderStage.VERTEX | wgpu.ShaderStage.FRAGMENT,
                "buffer": {},
            }
        ]
    ]

    layouts = []
    groups = []
    for entry_spec, layout_spec in zip(group_entry_specs, group_layout_specs):
        layout = device.create_bind_group_layout(entries=layout_spec)
        layouts.append(layout)
        groups.append(device.create_bind_group(layout=layout, entries=entry_spec))

    pipeline_layout = device.create_pipeline_layout(bind_group_layouts=layouts)
    return pipeline_layout, uniform_buffer, groups



def get_draw_function(
    canvas, device, render_pipeline, uniform_buffer, bind_groups
) -> Callable[[], None]:
    """Build the per-frame callback to pass to ``canvas.request_draw``.

    Each call of the returned function:

    1. stores the elapsed time of the module-level ``timer`` in
       ``uniform_data[0]``;
    2. uploads the whole ``uniform_data`` array into ``uniform_buffer`` with
       ``device.queue.write_buffer``;
    3. records a render pass that clears the current canvas texture to opaque
       black, binds ``render_pipeline`` and every bind group in
       ``bind_groups`` (group index = list position), draws 4 vertices / 1
       instance, and submits it.
    """

    def update_transform():
        # Only the time slot changes per frame; the scale slots (2 and 3)
        # are written by the resize handler.
        uniform_data[0] = timer.get_current()

    def upload_uniform_buffer_sync():
        device.queue.write_buffer(uniform_buffer, 0, uniform_data)

    def upload_uniform_buffer_sync2():
        # Alternative upload path (needlessly complicated, currently unused):
        # fill the mappable staging buffer, then copy it into the uniform
        # buffer with a GPU command.
        staging = uniform_buffer.copy_buffer
        staging.map_sync(wgpu.MapMode.WRITE)
        staging.write_mapped(uniform_data)
        staging.unmap()
        encoder = device.create_command_encoder()
        encoder.copy_buffer_to_buffer(
            staging, 0, uniform_buffer, 0, ctypes.sizeof(uniform_data)
        )
        device.queue.submit([encoder.finish()])

    def draw_frame():
        frame_texture = canvas.get_context("wgpu").get_current_texture()
        target_view = frame_texture.create_view()
        color_attachment = {
            "view": target_view,
            "resolve_target": None,
            "clear_value": (0, 0, 0, 1),
            "load_op": wgpu.LoadOp.clear,
            "store_op": wgpu.StoreOp.store,
        }

        encoder = device.create_command_encoder()
        rpass = encoder.begin_render_pass(color_attachments=[color_attachment])
        rpass.set_pipeline(render_pipeline)
        for slot, group in enumerate(bind_groups):
            rpass.set_bind_group(slot, group)
        rpass.draw(4, 1, 0, 0)
        rpass.end()

        device.queue.submit([encoder.finish()])

    def draw_frame_sync():
        update_transform()
        upload_uniform_buffer_sync()
        draw_frame()

    return draw_frame_sync


def setup_drawing_sync(
    canvas, power_preference="high-performance", limits=None, format=None
) -> Callable[[], None]:
    """Initialise wgpu for ``canvas`` and return its draw function.

    Requests an adapter with ``power_preference`` and a device with
    ``limits`` as its required limits.  It then creates the uniform buffer,
    bind groups and pipeline layout, configures the canvas context with
    ``format`` (the preferred format when None) and builds the render
    pipeline.  The canvas only has to implement WgpuCanvasInterface.

    Returns the zero-argument callable to hand to ``canvas.request_draw``.
    """
    adapter = wgpu.gpu.request_adapter_sync(power_preference=power_preference)
    device = adapter.request_device_sync(required_limits=limits)

    layout, uniform_buffer, bind_groups = create_pipeline_layout(device)
    render_pipeline = device.create_render_pipeline(
        **get_render_pipeline_kwargs(canvas, device, layout, format)
    )

    return get_draw_function(
        canvas, device, render_pipeline, uniform_buffer, bind_groups
    )


def key_down_handler(event):
    """Stop the event loop when the 'q' key is pressed; ignore any other key."""
    if event['key'] != 'q':
        return
    loop.stop()

def resize_handler(event):
    """Track the new canvas size and rescale the quad to keep aspect ``ratio``.

    The event's 'width' and 'height' are rounded and stored in the
    module-level ``width``/``height``.  The quad scale
    (``uniform_data[2]``, ``uniform_data[3]``) is then shrunk along whichever
    axis is too long:

    * window wider than ``ratio``  -> x scale < 1, y scale = 1;
    * window taller than ``ratio`` -> x scale = 1, y scale < 1;
    * window exactly at ``ratio``  -> (1.0, 1.0).

    For a zero-sized area the aspect ratio is undefined, so the previous scale
    is kept.
    """
    global width, height
    new_width, new_height = round(event['width']), round(event['height'])
    print(f'Resize from {width}x{height} to {new_width}x{new_height}')
    width, height = new_width, new_height
    if width <= 0 or height <= 0:
        # Avoid ZeroDivisionError; keep the last valid scale.
        return
    if width / height > ratio:
        uniform_data[2] = height / width * ratio
        uniform_data[3] = 1.0
    else:
        # Also handles the exact-ratio case (yielding 1.0), so a scale left
        # over from an earlier resize is never kept.
        uniform_data[2] = 1.0
        uniform_data[3] = width / height / ratio


# Design size of the image; `ratio` is the aspect ratio that resize_handler
# preserves by scaling the quad.
width = 500
height = 500
ratio = width / height

# Fragment shader source (compiled as WGSL via create_shader_module; its entry
# point must be named "main"). The path is relative to the working directory.
FRAGMENT_SHADER_FILENAME='shader.rs'
with open(FRAGMENT_SHADER_FILENAME, encoding='utf-8') as f:
    fragment_shader_source = f.read()

timer = Timer()
# Uniform data uploaded every frame, laid out like the WGSL struct
# Uniform_data: [time, padding (vec2<f32> is 8-byte aligned), scale.x, scale.y].
uniform_data = (ctypes.c_float*4)(timer.get_current(), 0.0, 1.0, 1.0)

# Open the window at the size the globals above describe, so the initial
# scale of (1.0, 1.0) is already correct before any resize event arrives.
canvas = RenderCanvas(
    size=(width, height), title="Shader Art WebGPU", update_mode="continuous"
)

canvas.add_event_handler(key_down_handler, "key_down")
canvas.add_event_handler(resize_handler, "resize")

draw_frame = setup_drawing_sync(canvas)
canvas.request_draw(draw_frame)

loop.run()
