message OpDef { // Op names starting with an underscore are reserved for internal use. // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". string name = 1;
// For describing inputs and outputs. message ArgDef { // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". string name = 1;
// Human readable description. string description = 2;
DataType type = 3; string type_attr = 4; // if specified, attr must have type "type" string number_attr = 5; // if specified, attr must have type "int" // If specified, attr must have type "list(type)", and none of // type, type_attr, and number_attr may be specified. string type_list_attr = 6;
// Users that want to look up an OpDef by type name should take an // OpRegistryInterface. Functions accepting a // (const) OpRegistryInterface* may call LookUp() from multiple threads. classOpRegistryInterface { public: virtual ~OpRegistryInterface();
// Returns an error status and sets *op_reg_data to nullptr if no OpDef is // registered under that name, otherwise returns the registered OpDef. // Caller must not delete the returned pointer. virtual Status LookUp(conststring& op_type_name, const OpRegistrationData** op_reg_data)const= 0;
// Shorthand for calling LookUp to get the OpDef. Status LookUpOpDef(conststring& op_type_name, const OpDef** op_def)const; };
// The standard implementation of OpRegistryInterface, along with a // global singleton used for registering ops via the REGISTER // macros below. Thread-safe. // // Example registration: // OpRegistry::Global()->Register( // [](OpRegistrationData* op_reg_data)->Status { // // Populate *op_reg_data here. // return Status::OK(); // }); classOpRegistry :public OpRegistryInterface { public: typedefstd::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
Status LookUp(conststring& op_type_name, const OpRegistrationData** op_reg_data)constoverride;
// Fills *ops with all registered OpDefs (except those with names // starting with '_' if include_internal == false) sorted in // ascending alphabetical order. voidExport(bool include_internal, OpList* ops)const;
// Returns ASCII-format OpList for all registered OpDefs (except // those with names starting with '_' if include_internal == false). stringDebugString(bool include_internal)const;
// A singleton available at startup. static OpRegistry* Global();
// Get all registered ops. voidGetRegisteredOps(std::vector<OpDef>* op_defs);
// Get all `OpRegistrationData`s. voidGetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
// Watcher, a function object. // The watcher, if set by SetWatcher(), is called every time an op is // registered via the Register function. The watcher is passed the Status // obtained from building and adding the OpDef to the registry, and the OpDef // itself if it was successfully built. A watcher returns a Status which is in // turn returned as the final registration status. typedefstd::function<Status(const Status&, const OpDef&)> Watcher;
// An OpRegistry object has only one watcher. This interface is not thread // safe, as different clients are free to set the watcher any time. // Clients are expected to atomically perform the following sequence of // operations : // SetWatcher(a_watcher); // Register some ops; // op_registry->ProcessRegistrations(); // SetWatcher(nullptr); // Returns a non-OK status if a non-null watcher is over-written by another // non-null watcher. Status SetWatcher(const Watcher& watcher);
// Process the current list of deferred registrations. Note that calls to // Export, LookUp and DebugString would also implicitly process the deferred // registrations. Returns the status of the first failed op registration or // Status::OK() otherwise. Status ProcessRegistrations()const;
// Defer the registrations until a later call to a function that processes // deferred registrations are made. Normally, registrations that happen after // calls to Export, LookUp, ProcessRegistrations and DebugString are processed // immediately. Call this to defer future registrations. voidDeferRegistrations();
// Clear the registrations that have been deferred. voidClearDeferredRegistrations();
private: // Ensures that all the functions in deferred_ get called, their OpDef's // registered, and returns with deferred_ empty. Returns true the first // time it is called. Prints a fatal log if any op registration fails. boolMustCallDeferred()constEXCLUSIVE_LOCKS_REQUIRED(mu_);
// Calls the functions in deferred_ and registers their OpDef's // It returns the Status of the first failed op registration or Status::OK() // otherwise. Status CallDeferred()constEXCLUSIVE_LOCKS_REQUIRED(mu_);
// Add 'def' to the registry with additional data 'data'. On failure, or if // there is already an OpDef with that name registered, returns a non-okay // status. Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) constEXCLUSIVE_LOCKS_REQUIRED(mu_);
Status LookUpSlow(conststring& op_type_name, const OpRegistrationData** op_reg_data)const;
mutable mutex mu_; // Functions in deferred_ may only be called with mu_ held. mutablestd::vector<OpRegistrationDataFactory> deferred_ GUARDED_BY(mu_); // Values are owned. mutablestd::unordered_map<string, const OpRegistrationData*> registry_ GUARDED_BY(mu_); mutablebool initialized_ GUARDED_BY(mu_);
// Add 'def' to the registry with additional data 'data'. On failure, or if // there is already an OpDef with that name registered, returns a non-okay // status. Status OpRegistry::RegisterAlreadyLocked( const OpRegistrationDataFactory& op_data_factory)const{ std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData); Status s = op_data_factory(op_reg_data.get()); if (s.ok()) { s = ValidateOpDef(op_reg_data->op_def); if (s.ok() && !gtl::InsertIfNotPresent(®istry_, op_reg_data->op_def.name(), op_reg_data.get())) { s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name()); } } Status watcher_status = s; if (watcher_) { watcher_status = watcher_(s, op_reg_data->op_def); } if (s.ok()) { op_reg_data.release(); } else { op_reg_data.reset(); } return watcher_status; }
// Ensures that all the functions in deferred_ get called, their OpDef's // registered, and returns with deferred_ empty. Returns true the first // time it is called. Prints a fatal log if any op registration fails. boolOpRegistry::MustCallDeferred()const{ if (initialized_) returnfalse; initialized_ = true; for (size_t i = 0; i < deferred_.size(); ++i) { TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i])); } deferred_.clear(); returntrue; }
// Calls the functions in deferred_ and registers their OpDef's // It returns the Status of the first failed op registration or Status::OK() // otherwise. Status OpRegistry::CallDeferred()const{ if (initialized_) return Status::OK(); initialized_ = true; for (size_t i = 0; i < deferred_.size(); ++i) { Status s = RegisterAlreadyLocked(deferred_[i]); if (!s.ok()) { return s; } } deferred_.clear(); return Status::OK(); }