summaryrefslogtreecommitdiffstats
path: root/src/shader_recompiler/backend/spirv/emit_context.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/shader_recompiler/backend/spirv/emit_context.cpp')
-rw-r--r--src/shader_recompiler/backend/spirv/emit_context.cpp43
1 files changed, 36 insertions, 7 deletions
diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index bf2210899..01b77a7d1 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -140,7 +140,27 @@ Id DefineVariable(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin
return id;
}
+u32 NumVertices(InputTopology input_topology) {
+ switch (input_topology) {
+ case InputTopology::Points:
+ return 1;
+ case InputTopology::Lines:
+ return 2;
+ case InputTopology::LinesAdjacency:
+ return 4;
+ case InputTopology::Triangles:
+ return 3;
+ case InputTopology::TrianglesAdjacency:
+ return 6;
+ }
+ throw InvalidArgument("Invalid input topology {}", input_topology);
+}
+
Id DefineInput(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin = std::nullopt) {
+ if (ctx.stage == Stage::Geometry) {
+ const u32 num_vertices{NumVertices(ctx.profile.input_topology)};
+ type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], num_vertices));
+ }
return DefineVariable(ctx, type, builtin, spv::StorageClass::Input);
}
@@ -455,12 +475,16 @@ void EmitContext::DefineSharedMemory(const IR::Program& program) {
void EmitContext::DefineAttributeMemAccess(const Info& info) {
const auto make_load{[&] {
+ const bool is_array{stage == Stage::Geometry};
const Id end_block{OpLabel()};
const Id default_label{OpLabel()};
- const Id func_type_load{TypeFunction(F32[1], U32[1])};
+ const Id func_type_load{is_array ? TypeFunction(F32[1], U32[1], U32[1])
+ : TypeFunction(F32[1], U32[1])};
const Id func{OpFunction(F32[1], spv::FunctionControlMask::MaskNone, func_type_load)};
const Id offset{OpFunctionParameter(U32[1])};
+ const Id vertex{is_array ? OpFunctionParameter(U32[1]) : Id{}};
+
AddLabel();
const Id base_index{OpShiftRightArithmetic(U32[1], offset, Constant(U32[1], 2U))};
const Id masked_index{OpBitwiseAnd(U32[1], base_index, Constant(U32[1], 3U))};
@@ -472,7 +496,7 @@ void EmitContext::DefineAttributeMemAccess(const Info& info) {
labels.push_back(OpLabel());
}
const u32 base_attribute_value = static_cast<u32>(IR::Attribute::Generic0X) >> 2;
- for (u32 i = 0; i < info.input_generics.size(); i++) {
+ for (u32 i = 0; i < info.input_generics.size(); ++i) {
if (!info.input_generics[i].used) {
continue;
}
@@ -486,7 +510,10 @@ void EmitContext::DefineAttributeMemAccess(const Info& info) {
size_t label_index{0};
if (info.loads_position) {
AddLabel(labels[label_index]);
- const Id result{OpLoad(F32[1], OpAccessChain(input_f32, input_position, masked_index))};
+ const Id pointer{is_array
+ ? OpAccessChain(input_f32, input_position, vertex, masked_index)
+ : OpAccessChain(input_f32, input_position, masked_index)};
+ const Id result{OpLoad(F32[1], pointer)};
OpReturnValue(result);
++label_index;
}
@@ -502,7 +529,9 @@ void EmitContext::DefineAttributeMemAccess(const Info& info) {
continue;
}
const Id generic_id{input_generics.at(i)};
- const Id pointer{OpAccessChain(type->pointer, generic_id, masked_index)};
+ const Id pointer{is_array
+ ? OpAccessChain(type->pointer, generic_id, vertex, masked_index)
+ : OpAccessChain(type->pointer, generic_id, masked_index)};
const Id value{OpLoad(type->id, pointer)};
const Id result{type->needs_cast ? OpBitcast(F32[1], value) : value};
OpReturnValue(result);
@@ -910,13 +939,13 @@ void EmitContext::DefineOutputs(const Info& info) {
}
if (info.stores_point_size || profile.fixed_state_point_size) {
if (stage == Stage::Fragment) {
- throw NotImplementedException("Storing PointSize in Fragment stage");
+ throw NotImplementedException("Storing PointSize in fragment stage");
}
output_point_size = DefineOutput(*this, F32[1], spv::BuiltIn::PointSize);
}
if (info.stores_clip_distance) {
if (stage == Stage::Fragment) {
- throw NotImplementedException("Storing PointSize in Fragment stage");
+ throw NotImplementedException("Storing ClipDistance in fragment stage");
}
const Id type{TypeArray(F32[1], Constant(U32[1], 8U))};
clip_distances = DefineOutput(*this, type, spv::BuiltIn::ClipDistance);
@@ -924,7 +953,7 @@ void EmitContext::DefineOutputs(const Info& info) {
if (info.stores_viewport_index &&
(profile.support_viewport_index_layer_non_geometry || stage == Shader::Stage::Geometry)) {
if (stage == Stage::Fragment) {
- throw NotImplementedException("Storing ViewportIndex in Fragment stage");
+ throw NotImplementedException("Storing ViewportIndex in fragment stage");
}
viewport_index = DefineOutput(*this, U32[1], spv::BuiltIn::ViewportIndex);
}