From 8be6e1c5221066a49b6ad27efbd20a999a7c16b3 Mon Sep 17 00:00:00 2001 From: Fernando Sahmkow Date: Fri, 28 Jun 2019 20:54:21 -0400 Subject: shader_ir: Corrections to outward movements and misc stuffs --- src/common/CMakeLists.txt | 4 + src/video_core/CMakeLists.txt | 1 + src/video_core/shader/ast.cpp | 185 ++++++++++++++++++++++++--------- src/video_core/shader/ast.h | 53 ++++++++-- src/video_core/shader/control_flow.cpp | 14 ++- src/video_core/shader/expr.cpp | 75 +++++++++++++ src/video_core/shader/expr.h | 36 ++++++- 7 files changed, 310 insertions(+), 58 deletions(-) create mode 100644 src/video_core/shader/expr.cpp (limited to 'src') diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index dfed8b51d..afc5ff736 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -60,9 +60,13 @@ add_custom_command(OUTPUT scm_rev.cpp "${VIDEO_CORE}/shader/decode/video.cpp" "${VIDEO_CORE}/shader/decode/warp.cpp" "${VIDEO_CORE}/shader/decode/xmad.cpp" + "${VIDEO_CORE}/shader/ast.cpp" + "${VIDEO_CORE}/shader/ast.h" "${VIDEO_CORE}/shader/control_flow.cpp" "${VIDEO_CORE}/shader/control_flow.h" "${VIDEO_CORE}/shader/decode.cpp" + "${VIDEO_CORE}/shader/expr.cpp" + "${VIDEO_CORE}/shader/expr.h" "${VIDEO_CORE}/shader/node.h" "${VIDEO_CORE}/shader/node_helper.cpp" "${VIDEO_CORE}/shader/node_helper.h" diff --git a/src/video_core/CMakeLists.txt b/src/video_core/CMakeLists.txt index 32049a2e7..33fa88762 100644 --- a/src/video_core/CMakeLists.txt +++ b/src/video_core/CMakeLists.txt @@ -110,6 +110,7 @@ add_library(video_core STATIC shader/control_flow.cpp shader/control_flow.h shader/decode.cpp + shader/expr.cpp shader/expr.h shader/node_helper.cpp shader/node_helper.h diff --git a/src/video_core/shader/ast.cpp b/src/video_core/shader/ast.cpp index 56a1b29f3..d521a7b52 100644 --- a/src/video_core/shader/ast.cpp +++ b/src/video_core/shader/ast.cpp @@ -12,18 +12,22 @@ namespace VideoCommon::Shader { ASTZipper::ASTZipper() = default; -ASTZipper::ASTZipper(ASTNode new_first) : first{}, last{} { + +void ASTZipper::Init(ASTNode new_first, ASTNode parent) { + ASSERT(new_first->manager == nullptr); first = new_first; last = new_first; ASTNode current = first; while (current) { current->manager = this; + current->parent = parent; last = current; current = current->next; } } void ASTZipper::PushBack(ASTNode new_node) { + ASSERT(new_node->manager == nullptr); new_node->previous = last; if (last) { last->next = new_node; @@ -37,38 +41,55 @@ void ASTZipper::PushBack(ASTNode new_node) { } void ASTZipper::PushFront(ASTNode new_node) { + ASSERT(new_node->manager == nullptr); new_node->previous.reset(); new_node->next = first; if (first) { - first->previous = first; + first->previous = new_node; } - first = new_node; - if (!last) { + if (last == first) { last = new_node; } + first = new_node; new_node->manager = this; } void ASTZipper::InsertAfter(ASTNode new_node, ASTNode at_node) { + ASSERT(new_node->manager == nullptr); if (!at_node) { PushFront(new_node); return; } + ASTNode next = at_node->next; + if (next) { + next->previous = new_node; + } new_node->previous = at_node; if (at_node == last) { last = new_node; } - new_node->next = at_node->next; + new_node->next = next; at_node->next = new_node; new_node->manager = this; } -void ASTZipper::SetParent(ASTNode new_parent) { - ASTNode current = first; - while (current) { - current->parent = new_parent; - current = current->next; +void ASTZipper::InsertBefore(ASTNode new_node, ASTNode at_node) { + ASSERT(new_node->manager == nullptr); + if (!at_node) { + PushBack(new_node); + return; } + ASTNode previous = at_node->previous; + if (previous) { + previous->next = new_node; + } + new_node->next = at_node; + if (at_node == first) { + first = new_node; + } + new_node->previous = previous; + at_node->previous = new_node; + new_node->manager = this; } void ASTZipper::DetachTail(ASTNode node) { @@ -80,11 +101,22 @@ void ASTZipper::DetachTail(ASTNode node) { } last = node->previous; + last->next.reset(); node->previous.reset(); + ASTNode current = node; + while (current) { + current->manager = nullptr; + current->parent.reset(); + current = current->next; + } } void ASTZipper::DetachSegment(ASTNode start, ASTNode end) { ASSERT(start->manager == this && end->manager == this); + if (start == end) { + DetachSingle(start); + return; + } ASTNode prev = start->previous; ASTNode post = end->next; if (!prev) { @@ -131,7 +163,6 @@ void ASTZipper::DetachSingle(ASTNode node) { node->parent.reset(); } - void ASTZipper::Remove(ASTNode node) { ASSERT(node->manager == this); ASTNode next = node->next; @@ -178,12 +209,7 @@ public: } void operator()(ExprPredicate const& expr) { - u32 pred = static_cast(expr.predicate); - if (pred > 7) { - inner += "!"; - pred -= 8; - } - inner += "P" + std::to_string(pred); + inner += "P" + std::to_string(expr.predicate); } void operator()(ExprCondCode const& expr) { @@ -253,6 +279,10 @@ public: ");\n"; } + void operator()(ASTBlockDecoded& ast) { + inner += Ident() + "Block;\n"; + } + void operator()(ASTVarSet& ast) { ExprPrinter expr_parser{}; std::visit(expr_parser, *ast.condition); @@ -282,7 +312,7 @@ public: current = current->GetNext(); } scope--; - inner += Ident() + "} while (" + expr_parser.GetResult() + ")\n"; + inner += Ident() + "} while (" + expr_parser.GetResult() + ");\n"; } void operator()(ASTReturn& ast) { @@ -333,8 +363,6 @@ std::string ASTManager::Print() { return printer.GetResult(); } -#pragma optimize("", off) - void ASTManager::Decompile() { auto it = gotos.begin(); while (it != gotos.end()) { @@ -348,11 +376,12 @@ void ASTManager::Decompile() { } if (DirectlyRelated(goto_node, label)) { u32 goto_level = goto_node->GetLevel(); - u32 label_level = goto_node->GetLevel(); - while (label_level > goto_level) { + u32 label_level = label->GetLevel(); + while (label_level < goto_level) { MoveOutward(goto_node); - goto_level++; + goto_level--; } + // TODO(Blinkhawk): Implement Lifting and Inward Movements } if (label->GetParent() == goto_node->GetParent()) { bool is_loop = false; @@ -375,13 +404,11 @@ void ASTManager::Decompile() { } it++; } - /* for (ASTNode label : labels) { auto& manager = label->GetManager(); manager.Remove(label); } labels.clear(); - */ } bool ASTManager::IndirectlyRelated(ASTNode first, ASTNode second) { @@ -410,87 +437,149 @@ bool ASTManager::DirectlyRelated(ASTNode first, ASTNode second) { max = second; } - while (min_level < max_level) { - min_level++; - min = min->GetParent(); + while (max_level > min_level) { + max_level--; + max = max->GetParent(); } return (min->GetParent() == max->GetParent()); } +void ASTManager::ShowCurrentState(std::string state) { + LOG_CRITICAL(HW_GPU, "\nState {}:\n\n{}\n", state, Print()); + SanityCheck(); +} + +void ASTManager::SanityCheck() { + for (auto label : labels) { + if (!label->GetParent()) { + LOG_CRITICAL(HW_GPU, "Sanity Check Failed"); + } + } +} + void ASTManager::EncloseDoWhile(ASTNode goto_node, ASTNode label) { + // ShowCurrentState("Before DoWhile Enclose"); + enclose_count++; ASTZipper& zipper = goto_node->GetManager(); ASTNode loop_start = label->GetNext(); if (loop_start == goto_node) { zipper.Remove(goto_node); + // ShowCurrentState("Ignore DoWhile Enclose"); return; } ASTNode parent = label->GetParent(); Expr condition = goto_node->GetGotoCondition(); zipper.DetachSegment(loop_start, goto_node); - ASTNode do_while_node = ASTBase::Make(parent, condition, ASTZipper(loop_start)); - zipper.InsertAfter(do_while_node, label); + ASTNode do_while_node = ASTBase::Make(parent, condition); ASTZipper* sub_zipper = do_while_node->GetSubNodes(); - sub_zipper->SetParent(do_while_node); + sub_zipper->Init(loop_start, do_while_node); + zipper.InsertAfter(do_while_node, label); sub_zipper->Remove(goto_node); + // ShowCurrentState("After DoWhile Enclose"); } void ASTManager::EncloseIfThen(ASTNode goto_node, ASTNode label) { + // ShowCurrentState("Before IfThen Enclose"); + enclose_count++; ASTZipper& zipper = goto_node->GetManager(); ASTNode if_end = label->GetPrevious(); if (if_end == goto_node) { zipper.Remove(goto_node); + // ShowCurrentState("Ignore IfThen Enclose"); return; } ASTNode prev = goto_node->GetPrevious(); - ASTNode parent = label->GetParent(); Expr condition = goto_node->GetGotoCondition(); - Expr neg_condition = MakeExpr(condition); + bool do_else = false; + if (prev->IsIfThen()) { + Expr if_condition = prev->GetIfCondition(); + do_else = ExprAreEqual(if_condition, condition); + } + ASTNode parent = label->GetParent(); zipper.DetachSegment(goto_node, if_end); - ASTNode if_node = ASTBase::Make(parent, condition, ASTZipper(goto_node)); - zipper.InsertAfter(if_node, prev); + ASTNode if_node; + if (do_else) { + if_node = ASTBase::Make(parent); + } else { + Expr neg_condition = MakeExprNot(condition); + if_node = ASTBase::Make(parent, neg_condition); + } ASTZipper* sub_zipper = if_node->GetSubNodes(); - sub_zipper->SetParent(if_node); + sub_zipper->Init(goto_node, if_node); + zipper.InsertAfter(if_node, prev); sub_zipper->Remove(goto_node); + // ShowCurrentState("After IfThen Enclose"); } void ASTManager::MoveOutward(ASTNode goto_node) { + // ShowCurrentState("Before MoveOutward"); + outward_count++; ASTZipper& zipper = goto_node->GetManager(); ASTNode parent = goto_node->GetParent(); + ASTZipper& zipper2 = parent->GetManager(); + ASTNode grandpa = parent->GetParent(); bool is_loop = parent->IsLoop(); - bool is_if = parent->IsIfThen() || parent->IsIfElse(); + bool is_else = parent->IsIfElse(); + bool is_if = parent->IsIfThen(); ASTNode prev = goto_node->GetPrevious(); + ASTNode post = goto_node->GetNext(); Expr condition = goto_node->GetGotoCondition(); - u32 var_index = NewVariable(); - Expr var_condition = MakeExpr(var_index); - ASTNode var_node = ASTBase::Make(parent, var_index, condition); zipper.DetachSingle(goto_node); - zipper.InsertAfter(var_node, prev); - goto_node->SetGotoCondition(var_condition); if (is_loop) { + u32 var_index = NewVariable(); + Expr var_condition = MakeExpr(var_index); + ASTNode var_node = ASTBase::Make(parent, var_index, condition); + ASTNode var_node_init = ASTBase::Make(parent, var_index, true_condition); + zipper2.InsertBefore(var_node_init, parent); + zipper.InsertAfter(var_node, prev); + goto_node->SetGotoCondition(var_condition); ASTNode break_node = ASTBase::Make(parent, var_condition); zipper.InsertAfter(break_node, var_node); - } else if (is_if) { - ASTNode post = var_node->GetNext(); + } else if (is_if || is_else) { if (post) { + u32 var_index = NewVariable(); + Expr var_condition = MakeExpr(var_index); + ASTNode var_node = ASTBase::Make(parent, var_index, condition); + ASTNode var_node_init = ASTBase::Make(parent, var_index, true_condition); + if (is_if) { + zipper2.InsertBefore(var_node_init, parent); + } else { + zipper2.InsertBefore(var_node_init, parent->GetPrevious()); + } + zipper.InsertAfter(var_node, prev); + goto_node->SetGotoCondition(var_condition); zipper.DetachTail(post); - ASTNode if_node = ASTBase::Make(parent, var_condition, ASTZipper(post)); - zipper.InsertAfter(if_node, var_node); + ASTNode if_node = ASTBase::Make(parent, MakeExprNot(var_condition)); ASTZipper* sub_zipper = if_node->GetSubNodes(); - sub_zipper->SetParent(if_node); + sub_zipper->Init(post, if_node); + zipper.InsertAfter(if_node, var_node); + } else { + Expr if_condition; + if (is_if) { + if_condition = parent->GetIfCondition(); + } else { + ASTNode if_node = parent->GetPrevious(); + if_condition = MakeExprNot(if_node->GetIfCondition()); + } + Expr new_condition = MakeExprAnd(if_condition, condition); + goto_node->SetGotoCondition(new_condition); } } else { UNREACHABLE(); } - ASTZipper& zipper2 = parent->GetManager(); ASTNode next = parent->GetNext(); if (is_if && next && next->IsIfElse()) { zipper2.InsertAfter(goto_node, next); + goto_node->SetParent(grandpa); + // ShowCurrentState("After MoveOutward"); return; } zipper2.InsertAfter(goto_node, parent); + goto_node->SetParent(grandpa); + // ShowCurrentState("After MoveOutward"); } } // namespace VideoCommon::Shader diff --git a/src/video_core/shader/ast.h b/src/video_core/shader/ast.h index 22ac8884c..4276f66a9 100644 --- a/src/video_core/shader/ast.h +++ b/src/video_core/shader/ast.h @@ -4,6 +4,7 @@ #pragma once +#include #include #include #include @@ -21,6 +22,7 @@ class ASTProgram; class ASTIfThen; class ASTIfElse; class ASTBlockEncoded; +class ASTBlockDecoded; class ASTVarSet; class ASTGoto; class ASTLabel; @@ -28,7 +30,7 @@ class ASTDoWhile; class ASTReturn; class ASTBreak; -using ASTData = std::variant; using ASTNode = std::shared_ptr; @@ -43,7 +45,8 @@ enum class ASTZipperType : u32 { class ASTZipper final { public: ASTZipper(); - ASTZipper(ASTNode first); + + void Init(ASTNode first, ASTNode parent); ASTNode GetFirst() { return first; @@ -56,7 +59,7 @@ public: void PushBack(ASTNode new_node); void PushFront(ASTNode new_node); void InsertAfter(ASTNode new_node, ASTNode at_node); - void SetParent(ASTNode new_parent); + void InsertBefore(ASTNode new_node, ASTNode at_node); void DetachTail(ASTNode node); void DetachSingle(ASTNode node); void DetachSegment(ASTNode start, ASTNode end); @@ -74,14 +77,14 @@ public: class ASTIfThen { public: - ASTIfThen(Expr condition, ASTZipper nodes) : condition(condition), nodes{nodes} {} + ASTIfThen(Expr condition) : condition(condition), nodes{} {} Expr condition; ASTZipper nodes; }; class ASTIfElse { public: - ASTIfElse(ASTZipper nodes) : nodes{nodes} {} + ASTIfElse() : nodes{} {} ASTZipper nodes; }; @@ -92,6 +95,12 @@ public: u32 end; }; +class ASTBlockDecoded { +public: + ASTBlockDecoded(NodeBlock& new_nodes) : nodes(std::move(new_nodes)) {} + NodeBlock nodes; +}; + class ASTVarSet { public: ASTVarSet(u32 index, Expr condition) : index{index}, condition{condition} {} @@ -114,7 +123,7 @@ public: class ASTDoWhile { public: - ASTDoWhile(Expr condition, ASTZipper nodes) : condition(condition), nodes{nodes} {} + ASTDoWhile(Expr condition) : condition(condition), nodes{} {} Expr condition; ASTZipper nodes; }; @@ -132,6 +141,8 @@ public: Expr condition; }; +using TransformCallback = std::function; + class ASTBase { public: explicit ASTBase(ASTNode parent, ASTData data) : parent{parent}, data{data} {} @@ -195,6 +206,14 @@ public: return nullptr; } + Expr GetIfCondition() const { + auto inner = std::get_if(&data); + if (inner) { + return inner->condition; + } + return nullptr; + } + void SetGotoCondition(Expr new_condition) { auto inner = std::get_if(&data); if (inner) { @@ -210,6 +229,18 @@ public: return std::holds_alternative(data); } + bool IsBlockEncoded() const { + return std::holds_alternative(data); + } + + void TransformBlockEncoded(TransformCallback& callback) { + auto block = std::get_if(&data); + const u32 start = block->start; + const u32 end = block->end; + NodeBlock nodes = callback(start, end); + data = ASTBlockDecoded(nodes); + } + bool IsLoop() const { return std::holds_alternative(data); } @@ -245,6 +276,7 @@ public: explicit ASTManager() { main_node = ASTBase::Make(ASTNode{}); program = std::get_if(main_node->GetInnerData()); + true_condition = MakeExpr(true); } void DeclareLabel(u32 address) { @@ -283,7 +315,13 @@ public: void Decompile(); + void ShowCurrentState(std::string state); + void SanityCheck(); + + bool IsFullyDecompiled() { + return gotos.size() == 0; + } private: bool IndirectlyRelated(ASTNode first, ASTNode second); @@ -309,6 +347,9 @@ private: u32 variables{}; ASTProgram* program; ASTNode main_node; + Expr true_condition; + u32 outward_count{}; + u32 enclose_count{}; }; } // namespace VideoCommon::Shader diff --git a/src/video_core/shader/control_flow.cpp b/src/video_core/shader/control_flow.cpp index bea7f767c..7a21d870f 100644 --- a/src/video_core/shader/control_flow.cpp +++ b/src/video_core/shader/control_flow.cpp @@ -423,7 +423,16 @@ void InsertBranch(ASTManager& mm, const BlockBranchInfo& branch) { result = MakeExpr(cond.cc); } if (cond.predicate != Pred::UnusedIndex) { - Expr extra = MakeExpr(cond.predicate); + u32 pred = static_cast(cond.predicate); + bool negate; + if (pred > 7) { + negate = true; + pred -= 8; + } + Expr extra = MakeExpr(pred); + if (negate) { + extra = MakeExpr(extra); + } if (result) { return MakeExpr(extra, result); } @@ -460,8 +469,9 @@ void DecompileShader(CFGRebuildState& state) { InsertBranch(manager, block.branch); } } + //manager.ShowCurrentState("Before Decompiling"); manager.Decompile(); - LOG_CRITICAL(HW_GPU, "Decompiled Shader:\n{} \n", manager.Print()); + //manager.ShowCurrentState("After Decompiling"); } std::optional ScanFlow(const ProgramCode& program_code, diff --git a/src/video_core/shader/expr.cpp b/src/video_core/shader/expr.cpp new file mode 100644 index 000000000..ebce6339b --- /dev/null +++ b/src/video_core/shader/expr.cpp @@ -0,0 +1,75 @@ +// Copyright 2019 yuzu Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include +#include + +#include "video_core/shader/expr.h" + +namespace VideoCommon::Shader { + +bool ExprAnd::operator==(const ExprAnd& b) const { + return (*operand1 == *b.operand1) && (*operand2 == *b.operand2); +} + +bool ExprOr::operator==(const ExprOr& b) const { + return (*operand1 == *b.operand1) && (*operand2 == *b.operand2); +} + +bool ExprNot::operator==(const ExprNot& b) const { + return (*operand1 == *b.operand1); +} + +bool ExprIsBoolean(Expr expr) { + return std::holds_alternative(*expr); +} + +bool ExprBooleanGet(Expr expr) { + return std::get_if(expr.get())->value; +} + +Expr MakeExprNot(Expr first) { + if (std::holds_alternative(*first)) { + return std::get_if(first.get())->operand1; + } + return MakeExpr(first); +} + +Expr MakeExprAnd(Expr first, Expr second) { + if (ExprIsBoolean(first)) { + return ExprBooleanGet(first) ? second : first; + } + if (ExprIsBoolean(second)) { + return ExprBooleanGet(second) ? first : second; + } + return MakeExpr(first, second); +} + +Expr MakeExprOr(Expr first, Expr second) { + if (ExprIsBoolean(first)) { + return ExprBooleanGet(first) ? first : second; + } + if (ExprIsBoolean(second)) { + return ExprBooleanGet(second) ? second : first; + } + return MakeExpr(first, second); +} + +bool ExprAreEqual(Expr first, Expr second) { + return (*first) == (*second); +} + +bool ExprAreOpposite(Expr first, Expr second) { + if (std::holds_alternative(*first)) { + return ExprAreEqual(std::get_if(first.get())->operand1, second); + } + if (std::holds_alternative(*second)) { + return ExprAreEqual(std::get_if(second.get())->operand1, first); + } + return false; +} + +} // namespace VideoCommon::Shader diff --git a/src/video_core/shader/expr.h b/src/video_core/shader/expr.h index 94678f09a..f012f6fcf 100644 --- a/src/video_core/shader/expr.h +++ b/src/video_core/shader/expr.h @@ -30,6 +30,8 @@ class ExprAnd final { public: ExprAnd(Expr a, Expr b) : operand1{a}, operand2{b} {} + bool operator==(const ExprAnd& b) const; + Expr operand1; Expr operand2; }; @@ -38,6 +40,8 @@ class ExprOr final { public: ExprOr(Expr a, Expr b) : operand1{a}, operand2{b} {} + bool operator==(const ExprOr& b) const; + Expr operand1; Expr operand2; }; @@ -46,6 +50,8 @@ class ExprNot final { public: ExprNot(Expr a) : operand1{a} {} + bool operator==(const ExprNot& b) const; + Expr operand1; }; @@ -53,20 +59,32 @@ class ExprVar final { public: ExprVar(u32 index) : var_index{index} {} + bool operator==(const ExprVar& b) const { + return var_index == b.var_index; + } + u32 var_index; }; class ExprPredicate final { public: - ExprPredicate(Pred predicate) : predicate{predicate} {} + ExprPredicate(u32 predicate) : predicate{predicate} {} + + bool operator==(const ExprPredicate& b) const { + return predicate == b.predicate; + } - Pred predicate; + u32 predicate; }; class ExprCondCode final { public: ExprCondCode(ConditionCode cc) : cc{cc} {} + bool operator==(const ExprCondCode& b) const { + return cc == b.cc; + } + ConditionCode cc; }; @@ -74,6 +92,10 @@ class ExprBoolean final { public: ExprBoolean(bool val) : value{val} {} + bool operator==(const ExprBoolean& b) const { + return value == b.value; + } + bool value; }; @@ -83,4 +105,14 @@ Expr MakeExpr(Args&&... args) { return std::make_shared(T(std::forward(args)...)); } +bool ExprAreEqual(Expr first, Expr second); + +bool ExprAreOpposite(Expr first, Expr second); + +Expr MakeExprNot(Expr first); + +Expr MakeExprAnd(Expr first, Expr second); + +Expr MakeExprOr(Expr first, Expr second); + } // namespace VideoCommon::Shader -- cgit v1.2.3