Skip to content

Commit c077d83

Browse files
authored
Merge pull request #5 from sandialabs/convert_tensor_index_base
Allow setting index-base for input/output tensor in convert_tensor
2 parents 3c451ad + c290e2a commit c077d83

File tree

3 files changed

+38
-20
lines changed

3 files changed

+38
-20
lines changed

src/Genten_TensorIO.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,9 @@ queryFile()
10881088
template <typename ExecSpace>
10891089
TensorWriter<ExecSpace>::
10901090
TensorWriter(const std::string& fname,
1091-
const bool comp) : filename(fname), compressed(comp) {}
1091+
const ttb_indx ib,
1092+
const bool comp) :
1093+
filename(fname), index_base(ib), compressed(comp) {}
10921094

10931095
template <typename ExecSpace>
10941096
void
@@ -1119,7 +1121,9 @@ writeText(const SptensorT<ExecSpace>& X) const
11191121
{
11201122
Sptensor X_host = create_mirror_view(X);
11211123
deep_copy(X_host, X);
1122-
export_sptensor(filename, X_host, true, 15, true, compressed);
1124+
if (index_base != 0 && index_base != 1)
1125+
Genten::error("Writing a sparse tensor requires index base of 0 or 1");
1126+
export_sptensor(filename, X_host, true, 15, index_base==0, compressed);
11231127
}
11241128

11251129
template <typename ExecSpace>

src/Genten_TensorIO.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ template <typename ExecSpace>
188188
class TensorWriter {
189189
public:
190190
TensorWriter(const std::string& filename,
191+
const ttb_indx index_base = 0,
191192
const bool compressed = false);
192193

193194
void writeBinary(const SptensorT<ExecSpace>& X,
@@ -199,6 +200,7 @@ class TensorWriter {
199200
void writeText(const TensorT<ExecSpace>& X) const;
200201
private:
201202
std::string filename;
203+
ttb_indx index_base;
202204
bool compressed;
203205
};
204206

tools/convert_tensor.cpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
template <typename TensorType>
4444
void print_tensor_stats(const TensorType& x)
4545
{
46-
std::cout << " Stats: ";
46+
std::cout << " Stats: ";
4747
const ttb_indx nd = x.ndims();
4848
for (ttb_indx i=0; i<nd; ++i) {
4949
std::cout << x.size(i);
@@ -65,19 +65,22 @@ void print_tensor_stats(const TensorType& x)
6565

6666
template <typename TensorType>
6767
void save_tensor(const TensorType& x_in, const std::string& filename,
68-
const std::string format, const std::string type, bool gz,
68+
const std::string format, const std::string type,
69+
const ttb_indx index_base, bool gz,
6970
bool header)
7071
{
7172
std::cout << "\nOutput:\n"
72-
<< " File: " << filename << std::endl
73-
<< " Format: " << format << std::endl
74-
<< " Type: " << type;
73+
<< " File: " << filename << std::endl
74+
<< " Index Base: " << index_base << std::endl
75+
<< " Format: " << format << std::endl
76+
<< " Type: " << type;
7577
if (type == "text" && gz)
7678
std::cout << " (compressed)";
7779
if (type == "binary" && !header)
7880
std::cout << " (no header)";
7981
std::cout << std::endl;
80-
Genten::TensorWriter<Genten::DefaultHostExecutionSpace> writer(filename,gz);
82+
Genten::TensorWriter<Genten::DefaultHostExecutionSpace> writer(
83+
filename,index_base,gz);
8184
if (format == "sparse") {
8285
Genten::Sptensor x_out(x_in);
8386
print_tensor_stats(x_out);
@@ -97,10 +100,12 @@ void save_tensor(const TensorType& x_in, const std::string& filename,
97100
}
98101

99102
void read_tensor_file(const std::string& filename,
103+
const ttb_indx index_base,
100104
std::string& format, std::string& type, bool gz,
101105
Genten::Sptensor& x_sparse, Genten::Tensor& x_dense)
102106
{
103-
Genten::TensorReader<Genten::DefaultHostExecutionSpace> reader(filename,0,gz);
107+
Genten::TensorReader<Genten::DefaultHostExecutionSpace> reader(
108+
filename,index_base,gz);
104109
reader.read();
105110

106111
if (reader.isSparse()) {
@@ -125,14 +130,16 @@ int main(int argc, char* argv[])
125130
auto args = Genten::build_arg_list(argc,argv);
126131
const bool help =
127132
Genten::parse_ttb_bool(args, "--help", "--no-help", false);
128-
if (argc < 9 || argc > 11 || help) {
133+
if (argc < 9 || argc > 16 || help) {
129134
std::cout << "\nconvert-tensor: a helper utility for converting tensor data between\n"
130135
<< "tensor formats (sparse or dense), and file types (text or binary).\n\n"
131136
<< "Usage: " << argv[0] << " --input-file <string> --output-file <string> --output-format <sparse|dense> --output-type <text|binary> [options] \n"
132137
<< "Options:\n"
133-
<< " --input-gz Input tensor is Gzip compressed (text-only, default: off)\n"
134-
<< " --output-gz Output tensor is Gzip compressed (text-only, default: off)\n"
135-
<< " --output-header Write header to output file (binary-only, default: on)\n";
138+
<< " --input-gz Input tensor is Gzip compressed (text-only, default: off)\n"
139+
<< " --output-gz Output tensor is Gzip compressed (text-only, default: off)\n"
140+
<< " --output-header Write header to output file (binary-only, default: on)\n"
141+
<< " --input-index-base Starting index for input tensor (sparse-only, default: 0)\n"
142+
<< " --output-index-base Starting index for output tensor (sparse-only, default: 0)\n";
136143
return 0;
137144
}
138145

@@ -153,6 +160,10 @@ int main(int argc, char* argv[])
153160
Genten::parse_ttb_bool(args, "--output-gz", "--no-output-gz", false);
154161
const bool output_header =
155162
Genten::parse_ttb_bool(args, "--output-header", "--no-output-header", true);
163+
const ttb_indx input_index_base =
164+
Genten::parse_ttb_indx(args, "--input-index-base", 0, 0, INT_MAX);
165+
const ttb_indx output_index_base =
166+
Genten::parse_ttb_indx(args, "--output-index-base", 0, 0, INT_MAX);
156167

157168
if (input_filename == "")
158169
Genten::error("input filename must be specified");
@@ -168,29 +179,30 @@ int main(int argc, char* argv[])
168179
Genten::error("No header option only supported for binary output files");
169180

170181
std::cout << "\nInput:\n"
171-
<< " File: " << input_filename << std::endl;
182+
<< " File: " << input_filename << std::endl
183+
<< " Index base: " << input_index_base << std::endl;
172184

173185
std::string input_format = "unknown";
174186
std::string input_type = "unknown";
175187
Genten::Sptensor x_sparse;
176188
Genten::Tensor x_dense;
177-
read_tensor_file(input_filename, input_format, input_type, input_gz,
178-
x_sparse, x_dense);
189+
read_tensor_file(input_filename, input_index_base, input_format, input_type,
190+
input_gz, x_sparse, x_dense);
179191

180-
std::cout << " Format: " << input_format << std::endl
181-
<< " Type: " << input_type;
192+
std::cout << " Format: " << input_format << std::endl
193+
<< " Type: " << input_type;
182194
if (input_type == "text" && input_gz)
183195
std::cout << " (compressed)";
184196
std::cout << std::endl;
185197
if (input_format == "sparse") {
186198
print_tensor_stats(x_sparse);
187199
save_tensor(x_sparse, output_filename, output_format, output_type,
188-
output_gz, output_header);
200+
output_index_base, output_gz, output_header);
189201
}
190202
else if (input_format == "dense") {
191203
print_tensor_stats(x_dense);
192204
save_tensor(x_dense, output_filename, output_format, output_type,
193-
output_gz, output_header);
205+
output_index_base, output_gz, output_header);
194206
}
195207
else
196208
Genten::error("Invalid input tensor format!");

0 commit comments

Comments
 (0)