diff --git a/game/game.py b/game/game.py index d319230..6e1efd5 100644 --- a/game/game.py +++ b/game/game.py @@ -186,23 +186,28 @@ def create_scene(keyboard, mouse): sea_detail_texture = sea.load_detail_texture('data/sea_bump.png') sea_triangles = sea.sea_triangles(64, proj_far_z - 0.1, proj_ratio) + assert tiles_shader.u_height_sampler == tests_shader.u_height_sampler + assert tiles_shader.u_normal_sampler == tests_shader.u_normal_sampler + return SceneNode( PerfNode('frame', - TextureNode({1: heightmap, 2: normalmap}, + TextureNode({tiles_shader.u_height_sampler: heightmap, tiles_shader.u_normal_sampler: normalmap}, FuncNode(update_camera, (mouse, camera, environment)), - TextureNode({0: tiles_texture}, + TextureNode({tiles_shader.u_texture_sampler: tiles_texture}, ShaderNode(tiles_shader, InputNode(tiles_shader, camera, environment), PerfNode('tiles_batch', DrawNode(tiles_batch)))), FuncNode(update_tests, (blob, cube, clouds)), - TextureNode({0: tests_texture}, + TextureNode({tests_shader.u_texture_sampler: tests_texture}, ShaderNode(tests_shader, InputNode(tests_shader, camera, environment), PerfNode('tests_batch', DrawNode(tests_batch))))), FuncNode(update_sea, (camera, sea_phase)), - TextureNode({0: sea_polar_textures, 1: sea_detail_texture}, + TextureNode({ + sea_shader.u_sea_polar_sampler: sea_polar_textures, + sea_shader.u_sea_detail_sampler: sea_detail_texture}, ShaderNode(sea_shader, InputNode(sea_shader, camera, environment, sea_phase), PerfNode('sea_triangles', diff --git a/game/shader.py b/game/shader.py index 9ab5c40..c80095b 100644 --- a/game/shader.py +++ b/game/shader.py @@ -24,7 +24,7 @@ def _filter(line): return line def _subst(line): - if line.startswith('#include '): + if line.startswith('#include'): path = Path('.') / 'game' / 'shaders' lines = [] for name in line.split()[1:]: @@ -45,16 +45,27 @@ def _convert(line): def _parse(shader, vert_lines, frag_lines): uniforms = [] - for line in vert_lines: - if line.startswith('uniform '): - uniforms.append(line.split()[-1].strip(';')) - for line in frag_lines: - if line.startswith('uniform '): + bindings = {} + def collect(line): + if line.startswith('uniform'): name = line.split()[-1].strip(';') if name not in uniforms: uniforms.append(name) + elif line.startswith('layout(binding='): + name = line.split()[-1].strip(';') + value = int(line[line.index('=') + 1 : line.index(')')].strip()) + if name in bindings: + assert value == bindings[name] + else: + bindings[name] = value + for line in vert_lines: + collect(line) + for line in frag_lines: + collect(line) for name in uniforms: setattr(shader, name, resolve_input(shader._shader, bytes(name, 'utf-8'))) + for name, value in bindings.items(): + setattr(shader, name, value) class Shader: __slots__ = '_shader', '__dict__'