2020#define COMPRESS_STATS 0
2121
2222#include < stddef.h>
23+ #include < stdint.h>
2324#include < stdio.h>
2425
25- #include < array>
26- #include < cstdio>
2726#include < cstring>
2827#include < string>
2928#include < unordered_map>
3534#include " compression/io.h"
3635#include " compression/shared.h"
3736// IWYU pragma: end_exports
38- #include " compression/distortion.h"
39- #include " hwy/aligned_allocator.h"
40- #include " hwy/base.h" // BF16
41- #include " hwy/contrib/thread_pool/thread_pool.h"
37+ #include " util/allocator.h"
4238#if COMPRESS_STATS
39+ #include " compression/distortion.h"
4340#include " hwy/stats.h"
4441#endif
4542
4643namespace gcpp {
4744
48- // Compressed representation of floating-point elements. The array length may
49- // differ from the number of elements. Associated operations such as Dot are
50- // implemented in SIMD code and are thus non-member functions.
51- template <typename Packed, size_t kCapacity >
52- class CompressedArray {
53- public:
54- using value_type = Packed;
55-
56- // Note that whenever you access data(), you have to consider a scale() that
57- // may be different from 1.0f.
58- Packed* data () { return data_.data (); }
59- const Packed* data () const { return data_.data (); }
60- // The const accessor data_scale1() asserts (!) that the scale is 1.0f, so
61- // calling it means "I am sure the scale is 1 and therefore ignore the scale".
62- // A scale of 0 indicates that the scale has likely never been set, so is
63- // "implicitly 1".
64- const Packed* data_scale1 () const {
65- HWY_ASSERT (scale () == 1 .f || scale () == 0 .f );
66- return data_.data ();
67- }
68-
69- // Decoded elements should be multiplied by this to restore their original
70- // range. This is required because SfpStream can only encode a limited range
71- // of magnitudes.
72- float scale () const { return scale_[0 ]; }
73- void set_scale (float scale) { scale_[0 ] = scale; }
74-
75- constexpr size_t NumElements () const { return kCapacity ; }
76-
77- // Returns total number of packed elements for `BlobReader::Enqueue` and
78- // `Compress`. This differs from `NumElements` for `Packed=NuqStream`.
79- PackedSpan<Packed> GetSpan () { return MakeSpan (data (), data_.size ()); }
80- PackedSpan<const Packed> GetSpan () const {
81- return MakeSpan (data (), data_.size ());
82- }
83-
84- private:
85- std::array<Packed, CompressedArrayElements<Packed>(kCapacity )> data_;
86- // Blobs are at least kBlobAlign bytes anyway.
87- float scale_[kBlobAlign / sizeof (float )];
88- };
89-
90- // Yet another array class. This one is intended to be compatible with
91- // CompressedArray, but have both run-time sizing and compile-time constant
92- // size.
93- // It also provides easy conversion from/to a table of contents for a BlobStore
94- // file, and a templated (compile-time) accessor for a 2-d array of fixed inner
95- // dimension and type.
96- // The base class is intended for accessing the metadata, without needing to
97- // know any of the template arguments.
98- // It holds only a borrowed pointer to the data, but all metadata.
45+ // Base class for rank-1 or 2 tensors (vector or matrix).
46+ // Supports both dynamic and compile-time sizing.
47+ // Holds metadata and a non-owning pointer to the data, owned by the derived
48+ // MatStorageT class.
49+ // This class also provides easy conversion from/to a table of contents for a
50+ // BlobStore file, and a templated (compile-time) accessor for a 2-d array of
51+ // fixed inner dimension and type.
9952// It is designed to be put in a vector, and has default copy and operator=, so
10053// it is easy to read/write a blob_store file.
101- // The derived class or an external class owns the data.
10254class MatPtr {
10355 public:
10456 // Full constructor for dynamic sizing.
@@ -111,12 +63,12 @@ class MatPtr {
11163 rows_(rows),
11264 cols_(cols),
11365 ptr_(nullptr ) {}
114- // Default constructor doesn't set anything .
66+ // Default is to leave all fields default-initialized .
11567 MatPtr () = default;
11668 virtual ~MatPtr ();
11769
11870 // Number of hwy::uint128_t in a TOC entry.
119- // Note that the old-style BlobStore files Only have a list of keys and size.
71+ // Note that the old-style BlobStore files only have a list of keys and size.
12072 // The new-style BlobStore files have an entry called "toc" that contains a
12173 // vector of 4-tuples of
12274 // (name, type, (num_elements, element_size), (rows, cols)).
@@ -144,6 +96,7 @@ class MatPtr {
14496 }
14597
14698 // Compatibility interface for CompressedArray.
99+ // TODO: remove.
147100 template <typename T>
148101 T* data () {
149102 return HWY_RCAST_ALIGNED (T*, ptr_);
@@ -177,7 +130,6 @@ class MatPtr {
177130
178131 // Returns the number of bytes in the array.
179132 size_t SizeBytes () const { return num_elements_ * element_size_; }
180- size_t CompressedSize () const { return SizeBytes (); }
181133
182134 // Returns the number of rows in the 2-d array (outer dimension).
183135 size_t Rows () const { return rows_; }
@@ -211,8 +163,8 @@ class MatPtr {
211163 }
212164
213165 // Calls func on the upcasted type. Since MatPtr by design is not templated,
214- // here we provide a way to get to the derived type, provided that the type
215- // matches one of a known short-list .
166+ // here we provide a way to get to the derived type, provided that `Type()`
167+ // is one of the strings returned by `TypeName()` .
216168 template <class FuncT , typename ... TArgs>
217169 decltype (auto ) CallUpcasted(FuncT& func, TArgs&&... args);
218170
@@ -243,8 +195,6 @@ class MatPtr {
243195template <typename MatT>
244196class MatPtrT : public MatPtr {
245197 public:
246- using value_type = MatT;
247-
248198 // Full constructor for dynamic sizing.
249199 MatPtrT (const std::string& name, size_t rows, size_t cols)
250200 : MatPtr(name, TypeEnum<MatT>(), sizeof (MatT), rows, cols) {}
@@ -276,20 +226,13 @@ class MatPtrT : public MatPtr {
276226 }
277227 return name;
278228 }
229+
279230 // Sets the number of elements in the array. For use when the number of
280231 // elements is != rows * cols ONLY.
// The value stored in num_elements_ is CompressedArrayElements<MatT>(num_elements),
// not the raw argument.
// NOTE(review): presumably this converts a logical element count into the
// number of packed MatT units - confirm against the definition of
// CompressedArrayElements, which is not visible here.
281232 void SetNumElements (size_t num_elements) {
282233 num_elements_ = CompressedArrayElements<MatT>(num_elements);
283234 }
284235
285- // Fast 2-d accessor for a 2-d array of fixed inner dimension and type.
286- template <typename T = MatT, size_t kInner >
287- const T& AtT (size_t row, size_t col) const {
288- size_t index = row * kInner + col;
289- HWY_DASSERT (index < num_elements_);
290- return HWY_RCAST_ALIGNED (const T*, ptr_)[index];
291- }
292-
293236 // 2-d Accessor for a specific type but with a dynamic inner dimension.
294237 template <typename T = MatT>
295238 const T& At (size_t row, size_t col) const {
@@ -299,17 +242,15 @@ class MatPtrT : public MatPtr {
299242 }
300243
301244 // 1-d Accessor for a specific type.
302- template < typename T = MatT>
303- const T & At (size_t index) const {
245+ // TODO: replace this with a ForEach(), or at least a ForEachRow().
246+ const MatT & At (size_t index) const {
304247 HWY_DASSERT (index < num_elements_);
305- return HWY_RCAST_ALIGNED (const T*, ptr_)[index];
306- }
307- template <typename T = MatT>
308- T& At (size_t index) {
309- return HWY_RCAST_ALIGNED (T*, ptr_)[index];
248+ return HWY_RCAST_ALIGNED (const MatT*, ptr_)[index];
310249 }
250+ MatT& At (size_t index) { return HWY_RCAST_ALIGNED (MatT*, ptr_)[index]; }
311251
312252 // Compatibility interface for CompressedArray.
253+ // TODO: remove.
313254 template <typename T = MatT>
314255 T* data () {
315256 return HWY_RCAST_ALIGNED (T*, ptr_);
@@ -353,15 +294,14 @@ class MatStorageT : public MatPtrT<MatT> {
353294 public:
354295 // Full constructor for dynamic sizing.
355296 MatStorageT (const std::string& name, size_t rows, size_t cols)
356- : MatPtrT<MatT>(name, rows, cols),
357- data_ (hwy::AllocateAligned<MatT>(
358- hwy::DivCeil (this ->SizeBytes (), sizeof(MatT)))) {
359- this ->ptr_ = data_.get ();
297+ : MatPtrT<MatT>(name, rows, cols) {
298+ Allocate ();
360299 }
361300 // Copies the metadata from a MatPtr; storage can be allocated later.
// This constructor only initializes the MatPtrT<MatT> base from `other` and
// performs no allocation itself.
// NOTE(review): single-argument and not `explicit`, so a MatPtr converts
// implicitly to MatStorageT - confirm that this is intended.
362301 MatStorageT (const MatPtr& other) : MatPtrT<MatT>(other) {}
302+ ~MatStorageT () = default ;
363303
364- // No copying of MatStorageT as it contains big data .
304+ // Move-only because this contains a unique_ptr .
365305 MatStorageT (const MatStorageT& other) = delete ;
366306 MatStorageT& operator =(const MatStorageT& other) = delete ;
367307 MatStorageT (MatStorageT&& other) = default ;
@@ -377,7 +317,7 @@ class MatStorageT : public MatPtrT<MatT> {
377317 } else {
378318 this ->num_elements_ = num_elements;
379319 }
380- data_ = hwy::AllocateAligned <MatT>(num_elements);
320+ data_ = Allocator::Alloc <MatT>(num_elements);
381321 this ->ptr_ = data_.get ();
382322 }
383323
@@ -388,8 +328,6 @@ class MatStorageT : public MatPtrT<MatT> {
388328 }
389329
390330 private:
391- // Aligned data array.
392- // std::unique_ptr<MatT[]> data_;
393331 hwy::AlignedFreeUniquePtr<MatT[]> data_;
394332};
395333
@@ -507,7 +445,7 @@ class CompressStats {
507445};
508446#else
509447struct CompressStats {
510- void Notify (const DistortionStats& ) {}
// No-op stub in the #else branch: accepts and ignores any arguments.
// NOTE(review): a C-style ellipsis only conditionally supports class-type
// arguments with a non-trivial copy/move constructor or destructor. If call
// sites pass such a type, a variadic template
// (template <typename... Args> void Notify(Args&&...) {}) would be the
// portable equivalent - confirm against the callers.
448+ void Notify (... ) {}
511449 void NotifyIn (int ) {}
512450 void Assimilate (const CompressStats&) {}
513451 void PrintAll () {}
@@ -526,18 +464,17 @@ struct CompressWorkingSet {
526464
527465// Functor called for each tensor, which loads them and their scaling factors
528466// from BlobStore.
529- class CacheLoader {
467+ class ReadFromBlobStore {
530468 public:
531- explicit CacheLoader (const Path& blob_filename) {
// Opens the BlobStore at `blob_filename`, then loads its table of contents
// into file_toc_. The result code of each step is stored in err_ and any
// failure is reported on stderr.
469+ explicit ReadFromBlobStore (const Path& blob_filename) {
532470 err_ = reader_.Open (blob_filename);
533- if (err_ != 0 ) {
534- fprintf (stderr,
535- " Cached compressed weights does not exist yet (code %d), "
536- " loading from file: %s.\n " ,
537- err_, blob_filename.path .c_str ());
471+ if (HWY_UNLIKELY (err_ != 0 )) {
472+ fprintf (stderr, " Error %d opening BlobStore %s.\n " , err_,
473+ blob_filename.path .c_str ());
474+ return ; // avoid overwriting err_ to ensure ReadAll will fail.
538475 }
// err_ now holds the LoadToc result; a failure here is only logged.
539476 err_ = file_toc_.LoadToc (reader_);
540- if (err_ != 0 ) {
477+ if (HWY_UNLIKELY ( err_ != 0 ) ) {
541478 fprintf (stderr, " Found a TOC, but failed to load it (code %d)\n " , err_);
542479 }
543480 }
0 commit comments