0%

Tensorflow源码解析——算子注册

什么是op

op和kernel是TF框架中最重要的两个概念,如果一定要做一个类比的话,可以认为op相当于函数声明,kernel相当于函数实现。举个例子,对于矩阵相乘,可以声明一个op叫做MatMul,指明它的名称,输入,输出,参数,以及对参数的限制等。op只是告诉我们,这个操作的目的是什么,op内部有哪些可定制的东西,但不会提供具体实现。op在某种设备上的具体实现方法,是由kernel决定的。TF的计算图由节点构成,而每个节点对应了一个op,在构建计算图时,我们只知道不同节点对应的操作是什么,而不知道运行时这个操作是怎样实现的。也就是说,op是编译期概念,而kernel是运行期概念。

那为什么要把操作和它的实现分离呢?是为了实现TF代码的可移植性。我们可以把TF构建的计算图想象为Java的字节码,而计算图在执行的时候,需要考虑可用的设备资源,相当于我们在运行Java字节码的时候,需要考虑当前所在的操作系统,选择合适的字节码实现。因为TF的目标是在多设备上运行,但我们在编码的时候,是无法预先知道某一个操作具体是在哪种设备上运行的,因此,将op和它的实现分离,可以让我们在设计计算图的时候,更专注于它的结构,而不是具体实现。当我们构建完成一个计算图之后,在一个包含GPU的设备上,它可以利用对应操作在GPU上的kernel,充分利用GPU的高计算性能,在一个仅包含CPU的设备上,它也可以利用对应操作在CPU上的kenrel,完成计算功能。这就提高了TF代码在不同设备之间的可移植性。

注册方式

下面是tensorflow代码中注册Argmax算子的代码:

1
2
3
4
5
6
7
8
REGISTER_OP("ArgMax")
.Input("input: T")
.Input("dimension: Tidx")
.Output("output: output_type")
.Attr("T: numbertype")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("output_type: {int32, int64} = DT_INT64")
.SetShapeFn(ArgOpShape);

通过REGISTER_OP宏进行算子注册,注册的内容有:

  • Input:算子的输入
  • Output:算子的输出
  • Attr:算子的属性,比如Argmax算子,有个属性是axis,在哪根轴上求最大值的下标
  • ShapeFn:用于shape推断

下面分析这个算子是如何被注册进去的。

OpDef

OpDef的定义在tensorflow\core\framework\op_def.proto

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
message OpDef {
// Op names starting with an underscore are reserved for internal use.
// Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*".
string name = 1;

// For describing inputs and outputs.
message ArgDef {
// Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*".
string name = 1;

// Human readable description.
string description = 2;

DataType type = 3;
string type_attr = 4; // if specified, attr must have type "type"
string number_attr = 5; // if specified, attr must have type "int"
// If specified, attr must have type "list(type)", and none of
// type, type_attr, and number_attr may be specified.
string type_list_attr = 6;

bool is_ref = 16;
};

OpDef中最核心的数据成员是操作名称、输入、输出、参数。

对于其中的几个难理解的点,作出说明:

  • ArgDef中的3-6四个字段,是为了描述·输入或输出的类型。当输入或输出是一个张量时,type或type_attr被设置为这个张量的数据类型,当输入或输出是一个由相同数据类型的张量构成的序列时,number_attr被设置为int对应的标识,当输入或输出是一个由张量构成的列表时,type_list_attr被设置为list(type)对应的标识;
  • AttrDef中的has_minimum字段,表明这个属性是否有最小值,如果数据类型是int,那么minimum就是允许的最小值,如果数据类型是列表,那么minimum就是列表的最短长度,is_aggregate这个字段,表明当前的操作是否是可聚集的。一个可聚集的操作是,能接受任意数量相同类型和形状的输入,并且保持输出与每个输入的类型和形状相同,这个字段对于操作的优化非常重要,如果一个操作是可聚集的,并且其输入来自多个不同的设备,那么我们就可以把聚集优化成一个树形的操作,先在设备内部对输入做聚集,最后在操作所在的设备集中,这样可以提高效率。这种优化对于分布式的机器学习模型训练非常有帮助,Spark ML中的TreeAggregate就实现了这样的优化。
  • is_stateful这个字段,表明当前的op是否带有状态的,什么操op会带有状态呢?比如Variable;

通过protoc工具用proto文件生成.h文件。命令为:

1
2
3
4
5
./protoc \ 
-I=/home/anan/tensorflow1.12/tensorflow-1.12.0/ \
--cpp_out=/home/anan/tensorflow1.12/tensorflow-1.12.0/tensorflow/core/framework/
/home/z00354782/tensorflow_1.12/tensorflow-
1.12.0/tensorflow/core/framework/op_def.proto

从中找到OpDef的定义:

1
2
3
4
5
6
7
8
9
10
11
12
13
class OpDef : public::google::protobuf::Message {
private:
::google::protobuf::RepeatedPtrField<::tensorflow::OpDef_ArgDef> input_arg_;
::google::protobuf::RepeatedPtrField<::tensorflow::OpDef_ArgDef> output_arg_;
::google::protobuf::RepeatedPtrField<::tensorflow::OpDef_ArgDef> attr_;
::google::protobuf::internal::ArenaStringPtr name_;
::google::protobuf::internal::ArenaStringPtr summary_;
::google::protobuf::internal::ArenaStringPtr description_;
bool is_commutative_;
bool is_aggregate_;
bool is_stateful_;
bool allows_uninitialized_input_;
}

为了方便进行OpDef的构建,TF还设计了OpDefBuilder类,它的私有数据成员如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// Builder class passed to the REGISTER_OP() macro.
class OpDefBuilder {
public:
// ...

private:
OpRegistrationData op_reg_data_;
std::vector<string> attrs_;
std::vector<string> inputs_;
std::vector<string> outputs_;
std::vector<string> control_outputs_;
string doc_;
std::vector<string> errors_;
};

可以看到,除了errors_字段外,其他内容几乎就是把OpDef的结构原封不动的搬了过来。

op_def_builder.h中还有一个新的结构,OpRegistrationData,他的结构如下:

1
2
3
4
5
6
7
8
9
10
11
12
struct OpRegistrationData {
public:
OpRegistrationData() {}
OpRegistrationData(const OpDef& def) : op_def(def) {}
OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn,
bool is_function = false)
: op_def(def), shape_inference_fn(fn), is_function_op(is_function) {}

OpDef op_def;
OpShapeInferenceFn shape_inference_fn;
bool is_function_op = false;
};

在这个结构中,除了屋面熟知的OpDef之外,还包含一个OpShapeInferenceFn结构,他的定义如下:

1
2
typedef std::function<Status(shape_inference::InferenceContext* c)>
OpShapeInferenceFn;

这个结构的定义中,涉及到了我们后面要讲到的形状推断的内容,这里我们只需要知道,OpShapeInferenceFn是一个帮助操作根据输入形状对输出形状进行推断的函数即可。

Op注册

上面的例子中使用REGISTER_OP宏进行Op注册,看一下这个宏的定义:

1
2
3
4
5
6
7
#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
#define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
#define REGISTER_OP_UNIQ(ctr, name) \
static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \
TF_ATTRIBUTE_UNUSED = \
::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \
name)>(name)

注:__COUNTER__宏表示自动计数,最终的定义是register_op0register_op1register_op2依次往后排。

1
2
3
4
5
6
7
8
9
static ::tensorflow::register_op::OpDefBuilderReceiver register_op0   = \
::tensorflow::register_op::OpDefBuilderWrapper<true>("Argmax") \
.Input("input: T")
.Input("dimension: Tidx")
.Output("output: output_type")
.Attr("T: numbertype")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("output_type: {int32, int64} = DT_INT64")
.SetShapeFn(ArgOpShape);

也就是说,生成一个OpDefBuilderWrapper对象,并链式调用它的InputOutputAttr等方法。

OpDefBuilderWrapper的定义为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52

// Template specialization that forwards all calls to the contained builder.
template <>
class OpDefBuilderWrapper<true> {
public:
explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
OpDefBuilderWrapper<true>& Attr(string spec) {
builder_.Attr(std::move(spec));
return *this;
}
OpDefBuilderWrapper<true>& Input(string spec) {
builder_.Input(std::move(spec));
return *this;
}
OpDefBuilderWrapper<true>& Output(string spec) {
builder_.Output(std::move(spec));
return *this;
}
OpDefBuilderWrapper<true>& SetIsCommutative() {
builder_.SetIsCommutative();
return *this;
}
OpDefBuilderWrapper<true>& SetIsAggregate() {
builder_.SetIsAggregate();
return *this;
}
OpDefBuilderWrapper<true>& SetIsStateful() {
builder_.SetIsStateful();
return *this;
}
OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() {
builder_.SetAllowsUninitializedInput();
return *this;
}
OpDefBuilderWrapper<true>& Deprecated(int version, string explanation) {
builder_.Deprecated(version, std::move(explanation));
return *this;
}
OpDefBuilderWrapper<true>& Doc(string text) {
builder_.Doc(std::move(text));
return *this;
}
OpDefBuilderWrapper<true>& SetShapeFn(
Status (*fn)(shape_inference::InferenceContext*)) {
builder_.SetShapeFn(fn);
return *this;
}
const ::tensorflow::OpDefBuilder& builder() const { return builder_; }

private:
mutable ::tensorflow::OpDefBuilder builder_;
};

通过链式调用,把Input、Output、Attr等描述保存到OpDefBuiIder的attrs_、inputs_、outputs_属性中。例如,Input的处理为:

1
2
3
4
OpDefBuilder& OpDefBuilder::Input(string spec) {
inputs_.push_back(std::move(spec));
return *this;
}

OpDefBuilderWrapperOpDefBuilder的包装器,其成员包含一个OpDefBuilder的对象,它的API都是设置型的,且都返回对象本身,提供 链式的方式进行属性设置。值得注意的是,这个类名后面跟着一个true,它的含义等会再看。

最终把OpDefBuilderWrapper类型的对象用于构造OpDefBuilderReceiver

OpDefBuilderReceiver定义为:

1
2
3
4
5
6
7
8
9
10
struct OpDefBuilderReceiver {
// To call OpRegistry::Global()->Register(...), used by the
// REGISTER_OP macro below.
// Note: These are implicitly converting constructors.
OpDefBuilderReceiver(
const OpDefBuilderWrapper<true>& wrapper); // NOLINT(runtime/explicit)
constexpr OpDefBuilderReceiver(const OpDefBuilderWrapper<false>&) {
} // NOLINT(runtime/explicit)
};
} // namespace register_op

OpDefBuilderReceiver的构造函数的实现为:

1
2
3
4
5
6
7
OpDefBuilderReceiver::OpDefBuilderReceiver(
const OpDefBuilderWrapper<true>& wrapper) {
OpRegistry::Global()->Register(
[wrapper](OpRegistrationData* op_reg_data) -> Status {
return wrapper.builder().Finalize(op_reg_data);
});
}

相当于是OpDefBuilderWrapper构造时,以OpDefBuilderWrapper为参数,在构造函数中调用OpRegistry::Global()->Register(...)

也就是说,REGISTER_OP绕了一圈,先用OpDefBuilderWrapper对操作进行封装,然后把它作为参数传递给OpDefBuilderReceiver的构造函数,而在这个构造函数中,完成了对算子的注册。

真正的注册过程就是OpRegistryRegister方法中完成的,下面具体看一下注册类的实现。

注册类

为了方便对操作进行统一管理,TF提出了OP注册器的概念。这个OP注册器的作用,是为各种OP提供一个统一的管理接囗。

操作注册类的继承结构如下:

image
image

其中,OpRegistryInterface是一个接口类,它提供了注册类最基础的查找功能:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// Users that want to look up an OpDef by type name should take an
// OpRegistryInterface. Functions accepting a
// (const) OpRegistryInterface* may call LookUp() from multiple threads.
class OpRegistryInterface {
public:
virtual ~OpRegistryInterface();

// Returns an error status and sets *op_reg_data to nullptr if no OpDef is
// registered under that name, otherwise returns the registered OpDef.
// Caller must not delete the returned pointer.
virtual Status LookUp(const string& op_type_name,
const OpRegistrationData** op_reg_data) const = 0;

// Shorthand for calling LookUp to get the OpDef.
Status LookUpOpDef(const string& op_type_name, const OpDef** op_def) const;
};

OpRegistry类继承了OpRegistryInterface类。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
// The standard implementation of OpRegistryInterface, along with a
// global singleton used for registering ops via the REGISTER
// macros below. Thread-safe.
//
// Example registration:
// OpRegistry::Global()->Register(
// [](OpRegistrationData* op_reg_data)->Status {
// // Populate *op_reg_data here.
// return Status::OK();
// });
class OpRegistry : public OpRegistryInterface {
public:
typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;

OpRegistry();
~OpRegistry() override;

void Register(const OpRegistrationDataFactory& op_data_factory);

Status LookUp(const string& op_type_name,
const OpRegistrationData** op_reg_data) const override;

// Fills *ops with all registered OpDefs (except those with names
// starting with '_' if include_internal == false) sorted in
// ascending alphabetical order.
void Export(bool include_internal, OpList* ops) const;

// Returns ASCII-format OpList for all registered OpDefs (except
// those with names starting with '_' if include_internal == false).
string DebugString(bool include_internal) const;

// A singleton available at startup.
static OpRegistry* Global();

// Get all registered ops.
void GetRegisteredOps(std::vector<OpDef>* op_defs);

// Get all `OpRegistrationData`s.
void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);

// Watcher, a function object.
// The watcher, if set by SetWatcher(), is called every time an op is
// registered via the Register function. The watcher is passed the Status
// obtained from building and adding the OpDef to the registry, and the OpDef
// itself if it was successfully built. A watcher returns a Status which is in
// turn returned as the final registration status.
typedef std::function<Status(const Status&, const OpDef&)> Watcher;

// An OpRegistry object has only one watcher. This interface is not thread
// safe, as different clients are free to set the watcher any time.
// Clients are expected to atomically perform the following sequence of
// operations :
// SetWatcher(a_watcher);
// Register some ops;
// op_registry->ProcessRegistrations();
// SetWatcher(nullptr);
// Returns a non-OK status if a non-null watcher is over-written by another
// non-null watcher.
Status SetWatcher(const Watcher& watcher);

// Process the current list of deferred registrations. Note that calls to
// Export, LookUp and DebugString would also implicitly process the deferred
// registrations. Returns the status of the first failed op registration or
// Status::OK() otherwise.
Status ProcessRegistrations() const;

// Defer the registrations until a later call to a function that processes
// deferred registrations are made. Normally, registrations that happen after
// calls to Export, LookUp, ProcessRegistrations and DebugString are processed
// immediately. Call this to defer future registrations.
void DeferRegistrations();

// Clear the registrations that have been deferred.
void ClearDeferredRegistrations();

private:
// Ensures that all the functions in deferred_ get called, their OpDef's
// registered, and returns with deferred_ empty. Returns true the first
// time it is called. Prints a fatal log if any op registration fails.
bool MustCallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);

// Calls the functions in deferred_ and registers their OpDef's
// It returns the Status of the first failed op registration or Status::OK()
// otherwise.
Status CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);

// Add 'def' to the registry with additional data 'data'. On failure, or if
// there is already an OpDef with that name registered, returns a non-okay
// status.
Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
const EXCLUSIVE_LOCKS_REQUIRED(mu_);

Status LookUpSlow(const string& op_type_name,
const OpRegistrationData** op_reg_data) const;

mutable mutex mu_;
// Functions in deferred_ may only be called with mu_ held.
mutable std::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_);
// Values are owned.
mutable std::unordered_map<string, const OpRegistrationData*> registry_
GUARDED_BY(mu_);
mutable bool initialized_ GUARDED_BY(mu_);

// Registry watcher.
mutable Watcher watcher_ GUARDED_BY(mu_);
};

OpRegistry类是单例模式,通过Global获取单例对象,并且是线程安全的。

注册函数Register的定义为:

1
2
3
4
5
6
7
8
void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
mutex_lock lock(mu_);
if (initialized_) {
TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
} else {
deferred_.push_back(op_data_factory);
}
}

其中,OpRegistrationDataFactory是一个function类型:

1
typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;

也就是说,Register注册时传入的是一个函数,最终在Register中完成对函数的调用。

从代码看,只有RegisterAlreadyLocked(op_data_factory)中可能产生对op_data_factory的调用,所以可以从这儿入手看注册过程。姑且不论initialized_字段的值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// Add 'def' to the registry with additional data 'data'. On failure, or if
// there is already an OpDef with that name registered, returns a non-okay
// status.
Status OpRegistry::RegisterAlreadyLocked(
const OpRegistrationDataFactory& op_data_factory) const {
std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
Status s = op_data_factory(op_reg_data.get());
if (s.ok()) {
s = ValidateOpDef(op_reg_data->op_def);
if (s.ok() &&
!gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
op_reg_data.get())) {
s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
}
}
Status watcher_status = s;
if (watcher_) {
watcher_status = watcher_(s, op_reg_data->op_def);
}
if (s.ok()) {
op_reg_data.release();
} else {
op_reg_data.reset();
}
return watcher_status;
}

函数的注释写的很清楚了,新增一个def到register中。失败或者算子name已经被注册,返回非okey结果。

这个函数中构造了一个OpRegistrationData对象,并最终对op_data_factory进行了调用。

OpRegistrationData的定义如下,其中包含了一个OpDef的变量。

1
2
3
4
5
6
7
8
9
10
11
12
struct OpRegistrationData {
public:
OpRegistrationData() {}
OpRegistrationData(const OpDef& def) : op_def(def) {}
OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn,
bool is_function = false)
: op_def(def), shape_inference_fn(fn), is_function_op(is_function) {}

OpDef op_def;
OpShapeInferenceFn shape_inference_fn;
bool is_function_op = false;
};

op_data_factory的调用构造了一个OpRegistrationData空对象,最终进入wrapper.builder().Finalize(op_reg_data)中进行处理。

wrapper.builder()返回的是OpDefBuilder对象。函数Finalize的实现为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const {
std::vector<string> errors = errors_;
*op_reg_data = op_reg_data_;

OpDef* op_def = &op_reg_data->op_def;
for (StringPiece attr : attrs_) {
FinalizeAttr(attr, op_def, &errors);
}
for (StringPiece input : inputs_) {
FinalizeInputOrOutput(input, false, op_def, &errors);
}
for (StringPiece output : outputs_) {
FinalizeInputOrOutput(output, true, op_def, &errors);
}
for (StringPiece control_output : control_outputs_) {
FinalizeControlOutput(control_output, op_def, &errors);
}
FinalizeDoc(doc_, op_def, &errors);

if (errors.empty()) return Status::OK();
return errors::InvalidArgument(str_util::Join(errors, "\n"));
}

这里把最开始wrapper中保存的inputs_outputs_attrs_等信息依次取出,用于构建OpDef对象。

得到的OpDef对象首先经过ValidateOpDef(op_reg_data->op_def);进行校验,然后插入到Registerregistry_中。

1
2
gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
op_reg_data.get()))

到这里就完成了一个算子的注册过程。

下面这个代码值得注意:

1
2
3
4
5
if (initialized_) {
TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
} else {
deferred_.push_back(op_data_factory);
}

只有在initialized_是true时,才进行注册,否则把op_data_factory放到deferred_这个vector中。

注意到Register类有如下两个方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// Ensures that all the functions in deferred_ get called, their OpDef's
// registered, and returns with deferred_ empty. Returns true the first
// time it is called. Prints a fatal log if any op registration fails.
bool OpRegistry::MustCallDeferred() const {
if (initialized_) return false;
initialized_ = true;
for (size_t i = 0; i < deferred_.size(); ++i) {
TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i]));
}
deferred_.clear();
return true;
}

// Calls the functions in deferred_ and registers their OpDef's
// It returns the Status of the first failed op registration or Status::OK()
// otherwise.
Status OpRegistry::CallDeferred() const {
if (initialized_) return Status::OK();
initialized_ = true;
for (size_t i = 0; i < deferred_.size(); ++i) {
Status s = RegisterAlreadyLocked(deferred_[i]);
if (!s.ok()) {
return s;
}
}
deferred_.clear();
return Status::OK();
}

可以看出,在特定的调用中,把deferred_中保存的算子注册函数全部取出,执行RegisterAlreadyLocked真正的执行算子注册过程。

这里有几点值得关注:

  • 注册函数Register的输入是一个函数引用,这个函数接收一个OpRegistrationData指针作为输入;
  • Watcher是一个监视器,当每次注册一个算子的时候,在注册步骤的最后都要调用一下这个监视器,它可方便对注册的操作进行监控,所有的算子注册动作都逃不过它的眼,可以根据需求定制特殊Watcher;
  • registry_`是已注册的算子真正存放的位置,它的结构很简单,是一个key为算子名、value为算子数据的map;
  • initialized_deferred_是与注册模式相关的两个数据成员,注册器在注册操作时,分为两种模式:
    • 即时注册模式懒惰注册模式
    • 注册模式通过initialized_字段区分,true为即时注册模式,false为懒惰注册模式;
    • 在懒惰注册模式中,被注册的算子先 被保存在deferred_向量中,在特定的函数调用时再将deferred_中的算子注册到registryy_,而即时注册模式下,待注册的算子不用经过deferred_,直接注册到registry_。 -懒惰注册模式的使用场景是,部分算子组合的注册是原子的,即要么全部注册,要么全部不注册,因为这些算子之间可能会有相互依赖关系。
    • 构造函数将initialized_设置为false,进入懒惰注册模式,随后一旦调用了MustCallDeferred或者CallDeferred中的任意一个,都会将initialized_设置为true,进入即时注册模式。想要重新返回懒惰注册模式也很简单,只需要调用DeferRegistrations即可。

参考

https://www.cnblogs.com/jicanghai/p/9539513.html

注:文中代码基于tensorflow1.12.0版本。