Skip to content

packages/engine/scram/src/zbdd.h

Zero-Suppressed Binary Decision Diagram facilities.

Namespaces

Name
scram
scram::core
scram::core::zbdd

Classes

Name
class scram::core::SetNode — Representation of non-terminal nodes in ZBDD.
struct scram::core::PairHash — Functor for hashing a pair of ordered numbers.
struct scram::core::TripletHash — Functor for hashing triplets of ordered numbers.
class scram::core::Zbdd — Zero-Suppressed Binary Decision Diagrams for set manipulations.
class scram::core::Zbdd::const_iterator — Iterator over products in a ZBDD container.
class scram::core::zbdd::CutSetContainer — Storage for generated cut sets in MOCUS.

Source code

cpp
/*
 * Copyright (C) 2014-2018 Olzhas Rakhimov
 * Copyright (C) 2023 OpenPRA ORG Inc.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */


#pragma once

#include <cassert>
#include <cstdint>

#include <array>
#include <iterator>
#include <map>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include <boost/functional/hash.hpp>
#include <boost/iterator/iterator_facade.hpp>
#include <boost/noncopyable.hpp>

#include "bdd.h"
#include "pdag.h"

namespace scram::core {

/// Non-terminal ZBDD vertex.
///
/// On top of the NonTerminal base it stores three per-node annotations,
/// each exposed as an overloaded getter/setter pair: a minimality flag,
/// a maximum set order, and a 64-bit product count.
class SetNode : public NonTerminal<SetNode> {
 public:
  using NonTerminal::NonTerminal;  // Reuse every base-class constructor.

  /// @returns The stored minimality flag (false until set).
  bool minimal() const { return minimal_; }
  /// Overwrites the minimality flag.
  void minimal(bool flag) { minimal_ = flag; }

  /// @returns The stored maximum set order (0 until set).
  int max_set_order() const { return max_set_order_; }
  /// Overwrites the maximum set order.
  void max_set_order(int order) { max_set_order_ = order; }

  /// @returns The stored count (0 until set).
  std::int64_t count() const { return count_; }
  /// Overwrites the stored count.
  void count(std::int64_t number) { count_ = number; }

 private:
  bool minimal_{false};        ///< Minimality mark.
  int max_set_order_{0};       ///< Maximum set order annotation.
  std::int64_t count_{0};      ///< Count annotation.
};

/// Project smart-pointer (IntrusivePtr) to a SetNode.
using SetNodePtr = IntrusivePtr<SetNode>;

/// Hash functor for ordered pairs of integers; (a, b) and (b, a) are
/// distinct keys.
struct PairHash {
  /// @param key  The pair to hash.
  /// @returns Boost's combined hash value of the pair.
  std::size_t operator()(const std::pair<int, int>& key) const {
    const std::size_t digest = boost::hash_value(key);
    return digest;
  }
};

/// Hash map keyed by an ordered pair of integers.
///
/// @tparam Value  The mapped type.
template <typename Value>
using PairTable = std::unordered_map<std::pair<int, int>, Value, PairHash>;

/// Fixed-size key of three ordered integers (e.g. the key returned by
/// Zbdd::GetResultKey for two operands and an order limit).
using Triplet = std::array<int, 3>;

/// Hash functor for Triplet keys; element order matters.
struct TripletHash {
  /// @param key  The triplet to hash.
  /// @returns Boost's range hash over the three elements in order.
  std::size_t operator()(const Triplet& key) const {
    return boost::hash_range(key.cbegin(), key.cend());
  }
};

/// Hash map keyed by a Triplet.
///
/// @tparam Value  The mapped type.
template <typename Value>
using TripletTable = std::unordered_map<Triplet, Value, TripletHash>;

class Zbdd : private boost::noncopyable {
 public:
  friend class const_iterator;
  using VertexPtr = IntrusivePtr<Vertex<SetNode>>;  
  using TerminalPtr = IntrusivePtr<Terminal<SetNode>>;  

  class const_iterator
      : public boost::iterator_facade<const_iterator, const std::vector<int>,
                                      boost::forward_traversal_tag> {
    friend class boost::iterator_core_access;

    class module_iterator {
     public:
      module_iterator(const SetNode* node, const Zbdd& zbdd, const_iterator* it,
                      bool sentinel = false)
          : sentinel_(sentinel),
            start_pos_(it->product_.size()),
            end_pos_(start_pos_),
            it_(*it),
            node_(node),
            zbdd_(zbdd) {
        if (!sentinel_) {
          sentinel_ = !GenerateProduct(zbdd_.root());
          end_pos_ = it_.product_.size();
        }
      }

      module_iterator(module_iterator&&)  = default;

      explicit operator bool() const { return !sentinel_; }

      void operator++() {
        if (sentinel_)
          return;
        assert(end_pos_ >= start_pos_ && "Corrupted sentinel.");
        while (start_pos_ != static_cast<int>(it_.product_.size())) {
          if (!module_stack_.empty() &&
              static_cast<int>(it_.product_.size()) == module_stack_.back().end_pos_) {
            const SetNode* node = module_stack_.back().node_;
            for (++module_stack_.back(); module_stack_.back();
                 ++module_stack_.back()) {
              if (GenerateProduct(node->high()))
                goto outer_break;
            }
            module_stack_.pop_back();
            if (GenerateProduct(node->low()))
              break;

          } else if (GenerateProduct(Pop()->low())) {
            break;
          }
        }
      outer_break:
        end_pos_ = it_.product_.size();
        sentinel_ = start_pos_ == end_pos_;
      }

     private:
      bool GenerateProduct(const VertexPtr& vertex)  {
        if (vertex->terminal()) {
          if (!Terminal<SetNode>::Ref(vertex).value())
            return false;
          if (it_.probability_enabled_ && it_.current_probability_ * it_.zbdd_.pdag_->initiating_event_frequency() < it_.cut_off_)
            return false;
          return true;
        }
        if (static_cast<int>(it_.product_.size()) >= it_.zbdd_.settings().limit_order())
          return false;
        const SetNode& node = SetNode::Ref(vertex);
        if (node.module()) {
          module_stack_.emplace_back(
              &node, *zbdd_.modules_.find(node.index())->second, &it_);
          for (; module_stack_.back(); ++module_stack_.back()) {
            if (GenerateProduct(node.high()))
              return true;
          }
          assert(static_cast<int>(it_.product_.size()) == module_stack_.back().start_pos_);
          module_stack_.pop_back();
          return GenerateProduct(node.low());

        } else {
          it_.PushLiteral(&node);
          if (!it_.ShouldPruneHighBranch() && GenerateProduct(node.high()))
            return true;
          it_.PopLiteral();
          return GenerateProduct(node.low());
        }
      }

      const SetNode* Pop()  {
        assert(start_pos_ < static_cast<int>(it_.product_.size()) && "Access beyond the range!");
        return it_.PopLiteral();
      }

      void Push(const SetNode* set_node)  {
        it_.PushLiteral(set_node);
      }

      bool sentinel_;  
      const int start_pos_;  
      int end_pos_;  
      const_iterator& it_;  
      const SetNode* node_;  
      const Zbdd& zbdd_;  
      std::vector<module_iterator> module_stack_;
    };

    void PushLiteral(const SetNode* set_node)  {
      node_stack_.push_back(set_node);
      product_.push_back(set_node->index());
      if (!probability_enabled_)
        return;
      probability_stack_.push_back(current_probability_);
      current_probability_ *= zbdd_.LiteralProbability(set_node->index());
    }

    const SetNode* PopLiteral()  {
      assert(!node_stack_.empty() && "PopLiteral on empty stack.");
      const SetNode* leaf = node_stack_.back();
      node_stack_.pop_back();
      product_.pop_back();
      if (probability_enabled_) {
        assert(!probability_stack_.empty() && "Probability stack mismatch.");
        current_probability_ = probability_stack_.back();
        probability_stack_.pop_back();
      }
      return leaf;
    }

    bool ShouldPruneHighBranch() const {
      if (!probability_enabled_) return false;
      double freq = zbdd_.pdag_->initiating_event_frequency();
      return current_probability_ * freq < cut_off_;
    }

   public:
    explicit const_iterator(const Zbdd& zbdd, bool sentinel = false)
        : sentinel_(sentinel),
          zbdd_(zbdd),
          cut_off_(zbdd.settings().cut_off()),
          probability_enabled_(zbdd.HasProbabilityContext()),
          it_(nullptr, zbdd, this, sentinel) {
      sentinel_ = !it_;
      current_probability_ = 1.0;
      if (!probability_enabled_)
        cut_off_ = 0.0;
    }

    const_iterator(const const_iterator& other)
        : sentinel_(other.sentinel_),
          zbdd_(other.zbdd_),
          cut_off_(other.cut_off_),
          probability_enabled_(other.probability_enabled_),
          it_(nullptr, zbdd_, this, sentinel_) {
      current_probability_ = 1.0;
      if (!probability_enabled_)
        cut_off_ = 0.0;
      assert(*this == other && "Copy ctor is only for begin/end iterators.");
    }

   private:
    void increment() {
      assert(!sentinel_ && "Incrementing an end iterator.");
      ++it_;
      sentinel_ = !it_;
    }
    bool equal(const const_iterator& other) const {
      assert(!(sentinel_ && !product_.empty()) && "Uncleared products.");
      return sentinel_ == other.sentinel_ && &zbdd_ == &other.zbdd_ &&
             product_ == other.product_;
    }
    const std::vector<int>& dereference() const {
      assert(!sentinel_ && "Dereferencing end iterator.");
      return product_;
    }

    bool sentinel_;  
    const Zbdd& zbdd_;  
    std::vector<int> product_;  
    std::vector<const SetNode*> node_stack_;  
    std::vector<double> probability_stack_;  
    double current_probability_ = 1.0;  
    double cut_off_ = 0.0;  
    bool probability_enabled_ = false;  
    module_iterator it_;  
  };

  Zbdd(Bdd* bdd, const Settings& settings) ;

  Zbdd(const Pdag* graph, const Settings& settings) ;

  virtual ~Zbdd()  = default;

  void Analyze(const Pdag* graph = nullptr) ;

  void SetProbabilityContext(const Pdag* pdag);

  const Zbdd& products() const { return *this; }

  auto begin() const { return const_iterator(*this); }
  auto end() const { return const_iterator(*this, /*sentinel=*/true); }

  std::size_t size() const { return std::distance(begin(), end()); }

  bool empty() const { return begin() == end(); }

  bool base() const { return root_ == kBase_; }

 protected:
  explicit Zbdd(const Settings& settings, bool coherent = false,
                int module_index = 0, const Pdag* pdag = nullptr) ;

  const VertexPtr& root() const { return root_; }

  void root(const VertexPtr& vertex) { root_ = vertex; }

  const Settings& settings() const { return kSettings_; }

  bool HasProbabilityContext() const { return pdag_ && kSettings_.cut_off() > 0.0; }

  double LiteralProbability(int literal) const;

  const std::map<int, std::unique_ptr<Zbdd>>& modules() const {
    return modules_;
  }

  void Log() ;

  SetNodePtr FindOrAddVertex(int index, const VertexPtr& high,
                             const VertexPtr& low, int order,
                             bool module = false,
                             bool coherent = false) ;

  SetNodePtr FindOrAddVertex(const Gate& gate, const VertexPtr& high,
                             const VertexPtr& low) ;

  template <Connective Type>
  VertexPtr Apply(const VertexPtr& arg_one, const VertexPtr& arg_two,
                  int limit_order) ;

  VertexPtr Apply(Connective type, const VertexPtr& arg_one,
                  const VertexPtr& arg_two, int limit_order) ;

  template <Connective Type>
  VertexPtr Apply(const SetNodePtr& arg_one, const SetNodePtr& arg_two,
                  int limit_order) ;

  VertexPtr EliminateComplements(
      const VertexPtr& vertex,
      std::unordered_map<int, VertexPtr>* wide_results) ;

  void EliminateConstantModules() ;

  VertexPtr Minimize(const VertexPtr& vertex) ;

  int GatherModules(const VertexPtr& vertex, int current_order,
                    std::map<int, std::pair<bool, int>>* modules) ;

  void ApplySubstitutions(
      const std::vector<Pdag::Substitution>& substitutions) ;

  void ClearTables()  {
    and_table_.clear();
    or_table_.clear();
    minimal_results_.clear();
    subsume_table_.clear();
    prune_results_.clear();
  }

  void Freeze()  {
    unique_table_.Release();
    Zbdd::ClearTables();
    and_table_.reserve(0);
    or_table_.reserve(0);
    minimal_results_.reserve(0);
    subsume_table_.reserve(0);
  }

  void JoinModule(int index, std::unique_ptr<Zbdd> container)  {
    assert(!modules_.count(index));
    assert(container->root()->terminal() ||
           SetNode::Ref(container->root()).minimal());
    modules_.emplace(index, std::move(container));
  }

  const TerminalPtr kBase_;  
  const TerminalPtr kEmpty_;  

 private:
  using SetNodeWeakPtr = WeakIntrusivePtr<SetNode>;  
  using ComputeTable = TripletTable<VertexPtr>;  
  using ModuleEntry = std::pair<const int, std::unique_ptr<Zbdd>>;

  Zbdd(const Bdd::Function& module, bool coherent, Bdd* bdd,
       const Settings& settings, int module_index = 0) ;

  Zbdd(const Gate& gate, const Settings& settings) ;

  SetNodePtr FindOrAddVertex(const SetNodePtr& node, const VertexPtr& high,
                             const VertexPtr& low) ;

  VertexPtr GetReducedVertex(const ItePtr& ite, bool complement,
                             const VertexPtr& high,
                             const VertexPtr& low) ;

  VertexPtr GetReducedVertex(const SetNodePtr& node, const VertexPtr& high,
                             const VertexPtr& low) ;

  Triplet GetResultKey(const VertexPtr& arg_one, const VertexPtr& arg_two,
                       int limit_order) ;

  VertexPtr ConvertBdd(const Bdd::VertexPtr& vertex, bool complement,
                       Bdd* bdd_graph, int limit_order,
                       PairTable<VertexPtr>* ites) ;

  VertexPtr ConvertBdd(const ItePtr& ite, bool complement, Bdd* bdd_graph,
                       int limit_order, PairTable<VertexPtr>* ites) ;

  VertexPtr ConvertBddPrimeImplicants(const ItePtr& ite, bool complement,
                                      Bdd* bdd_graph, int limit_order,
                                      PairTable<VertexPtr>* ites) ;

  VertexPtr
  ConvertGraph(const Gate& gate,
               std::unordered_map<int, std::pair<VertexPtr, int>>* gates,
               std::unordered_map<int, const Gate*>* module_gates) ;

  VertexPtr EliminateComplement(const SetNodePtr& node, const VertexPtr& high,
                                const VertexPtr& low) ;

  VertexPtr EliminateConstantModules(
      const VertexPtr& vertex,
      std::unordered_map<int, VertexPtr>* results) ;

  VertexPtr EliminateConstantModule(const SetNodePtr& node,
                                    const VertexPtr& high,
                                    const VertexPtr& low) ;

  VertexPtr Subsume(const VertexPtr& high, const VertexPtr& low) ;

  VertexPtr Prune(const VertexPtr& vertex, int limit_order) ;

  virtual bool IsGate(const SetNode& node)  { return node.module(); }

  bool MayBeUnity(const SetNode& node) ;

  int CountSetNodes(const VertexPtr& vertex) ;

  std::int64_t CountProducts(const VertexPtr& vertex, bool modules) ;

  void ClearMarks(const VertexPtr& vertex, bool modules) ;

  void ClearCounts(const VertexPtr& vertex, bool modules) ;

  void TestStructure(const VertexPtr& vertex, bool modules) ;

  const Settings kSettings_;  
  const Pdag* pdag_;  
  VertexPtr root_;  
  bool coherent_;  
  int module_index_;  

  UniqueTable<SetNode> unique_table_;

  ComputeTable and_table_;
  ComputeTable or_table_;

  std::unordered_map<int, VertexPtr> minimal_results_;
  PairTable<VertexPtr> subsume_table_;
  PairTable<VertexPtr> prune_results_;

  std::map<int, std::unique_ptr<Zbdd>> modules_;  
  int set_id_;  
};

namespace zbdd {

class CutSetContainer : public Zbdd {
 public:
  CutSetContainer(const Settings& settings, int module_index,
                  int gate_index_bound, const Pdag* pdag) ;

  VertexPtr ConvertGate(const Gate& gate) ;

  int GetNextGate()  {
    if (Zbdd::root()->terminal())
      return 0;
    SetNode& node = SetNode::Ref(Zbdd::root());
    return CutSetContainer::IsGate(node) && !node.module() ? node.index() : 0;
  }

  VertexPtr ExtractIntermediateCutSets(int index) ;

  VertexPtr ExpandGate(const VertexPtr& gate_zbdd,
                       const VertexPtr& cut_sets) ;

  void Merge(const VertexPtr& vertex) ;

  void EliminateComplements()  {
    std::unordered_map<int, VertexPtr> wide_results;
    Zbdd::root(Zbdd::EliminateComplements(Zbdd::root(), &wide_results));
  }

  void EliminateConstantModules()  { Zbdd::EliminateConstantModules(); }

  void Minimize()  { Zbdd::root(Zbdd::Minimize(Zbdd::root())); }

  std::map<int, std::pair<bool, int>> GatherModules()  {
    assert(Zbdd::modules().empty() && "Unexpected call with defined modules?!");
    std::map<int, std::pair<bool, int>> modules;
    Zbdd::GatherModules(Zbdd::root(), 0, &modules);
    return modules;
  }

  using Zbdd::JoinModule;  
  using Zbdd::Log;  

 private:
  bool IsGate(const SetNode& node)  override {
    return node.index() > gate_index_bound_;
  }

  int gate_index_bound_;  
};

}  // namespace zbdd

}  // namespace scram::core

Updated on 2026-01-09 at 21:59:13 +0000