//===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation file for the abstraction of a tensor type, and JSON loading
// utils.
//
//===----------------------------------------------------------------------===//
#include "llvm/Config/config.h"

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/raw_ostream.h"
#include <array>
#include <cassert>
#include <numeric>

using namespace llvm;

namespace llvm {

// Generate the explicit specializations of TensorSpec::getDataType for each
// supported tensor element type.
#define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
  template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }

SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)

#undef TFUTILS_GETDATATYPE_IMPL

static std::array<std::string, static_cast<size_t>(TensorType::Total)>
    TensorTypeNames{"INVALID",
#define TFUTILS_GETNAME_IMPL(T, _) #T,
                    SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL)
#undef TFUTILS_GETNAME_IMPL
};

StringRef toString(TensorType TT) {
  return TensorTypeNames[static_cast<size_t>(TT)];
}

void TensorSpec::toJSON(json::OStream &OS) const {
  OS.object([&]() {
    OS.attribute("name", name());
    OS.attribute("type", toString(type()));
    OS.attribute("port", port());
    OS.attributeArray("shape", [&]() {
      for (size_t D : shape())
        OS.value(static_cast<int64_t>(D));
    });
  });
}

TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
                       size_t ElementSize, const std::vector<int64_t> &Shape)
    : Name(Name), Port(Port), Type(Type), Shape(Shape),
      ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
                                   std::multiplies<int64_t>())),
      ElementSize(ElementSize) {}

// Parse a TensorSpec from a JSON dict of the shape produced by toJSON above,
// i.e. with "name" (string), "type" (the textual name of one of the
// SUPPORTED_TENSOR_TYPES), "port" (int), and "shape" (int array) properties.
std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                                const json::Value &Value) {
  auto EmitError =
      [&](const llvm::Twine &Message) -> std::optional<TensorSpec> {
    std::string S;
    llvm::raw_string_ostream OS(S);
    OS << Value;
    Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
    return std::nullopt;
  };
  // FIXME: accept a Path as a parameter, and use it for error reporting.
  json::Path::Root Root("tensor_spec");
  json::ObjectMapper Mapper(Value, Root);
  if (!Mapper)
    return EmitError("Value is not a dict");

  std::string TensorName;
  int TensorPort = -1;
  std::string TensorType;
  std::vector<int64_t> TensorShape;

  if (!Mapper.map<std::string>("name", TensorName))
    return EmitError("'name' property not present or not a string");
  if (!Mapper.map<std::string>("type", TensorType))
    return EmitError("'type' property not present or not a string");
  if (!Mapper.map<int>("port", TensorPort))
    return EmitError("'port' property not present or not an int");
  if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
    return EmitError("'shape' property not present or not an int array");

#define PARSE_TYPE(T, E)                                                       \
  if (TensorType == #T)                                                        \
    return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
  SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
#undef PARSE_TYPE
  return std::nullopt;
}

} // namespace llvm