@@ -192,6 +192,266 @@ ModelConfig ConfigFromModel(Model model);
192192
193193// Returns the sub-config for the ViT model of the PaliGemma model.
194194ModelConfig VitConfig (const ModelConfig& config);
// Functor returning the per-layer cache size (in elements) for one position:
// TConfig::kKVHeads * TConfig::kQKVDim * 2.
// NOTE(review): the factor of 2 presumably covers keys plus values -- confirm
// against the KV cache layout.
// The unnamed second template parameter defaults to void; it looks like a hook
// for enable_if-style partial specializations, none of which are visible here.
template <class TConfig, typename = void>
struct CacheLayerSize {
  constexpr size_t operator()() const {
    return TConfig::kKVHeads * TConfig::kQKVDim * 2;
  }
};
201+
202+ template <class TConfig , typename = void >
203+ struct CachePosSize {
204+ constexpr size_t operator ()() const {
205+ return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
206+ }
207+ };
208+
209+ struct ConfigNoSSM {
210+ static constexpr int kGriffinLayers = 0 ;
211+
212+ static constexpr int kConv1dWidth = 0 ;
213+ static constexpr bool kFFBiases = false ;
214+ static constexpr bool kSoftmaxAttnOutputBiases = false ;
215+ static constexpr bool kUseHalfRope = false ;
216+ static constexpr bool kUseLocalAttention = false ;
217+ static constexpr bool kInterleaveQKV = true ;
218+ static constexpr int kNumTensorScales = 0 ;
219+
220+ static constexpr PostQKType kPostQK = PostQKType::Rope;
221+ static constexpr ActivationType kActivation = ActivationType::Gelu;
222+ static constexpr ResidualType kResidual = ResidualType::Add;
223+
224+ // Self-extend parameters with defaul values
225+ static constexpr bool kSelfExtend = false ;
226+ static constexpr size_t kSelfExtendNgbSize = 0 ;
227+ static constexpr size_t kSelfExtendGrpSize = 1 ;
228+ };
229+
230+ struct ConfigBaseGemmaV1 : ConfigNoSSM {
231+ static constexpr float kAttCap = 0 .0f ;
232+ static constexpr float kFinalCap = 0 .0f ;
233+ static constexpr PostNormType kPostNorm = PostNormType::None;
234+ static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
235+ };
236+
237+ struct ConfigBaseGemmaV2 : ConfigNoSSM {
238+ static constexpr float kAttCap = 50 .0f ;
239+ static constexpr float kFinalCap = 30 .0f ;
240+ static constexpr PostNormType kPostNorm = PostNormType::Scale;
241+ };
242+
243+ template <typename TWeight>
244+ struct ConfigGemma27B : public ConfigBaseGemmaV2 {
245+ using Weight = TWeight; // make accessible where we only have a TConfig
246+
247+ static constexpr int kSeqLen = 8192 ;
248+ static constexpr int kVocabSize = 256000 ;
249+ static constexpr std::array<LayerAttentionType, 46 > kLayerConfig =
250+ FixedLayerConfig<46 >(LayerAttentionType::kGemma );
251+ static constexpr std::array<size_t , 46 > kAttentionWindowSizes =
252+ RepeatedAttentionWindowSizes<46 , 2 >({4096 , kSeqLen });
253+ static constexpr int kLayers = kLayerConfig .size();
254+ static constexpr int kGemmaLayers = kLayers ;
255+ static constexpr int kModelDim = 4608 ;
256+ static constexpr int kFFHiddenDim = 16 * 4608 / 2 ; // = 36864
257+ static constexpr int kHeads = 32 ;
258+ static constexpr int kKVHeads = 16 ;
259+ static constexpr int kQKVDim = 128 ; // query size == key size == value size
260+ static constexpr int kTopK = gcpp::kTopK ;
261+ static constexpr bool kAbsolutePE = false ;
262+ static constexpr QueryScaleType kQueryScale =
263+ QueryScaleType::SqrtModelDimDivNumHeads;
264+ };
265+
266+ template <typename TWeight>
267+ struct ConfigGemma9B : public ConfigBaseGemmaV2 {
268+ using Weight = TWeight; // make accessible where we only have a TConfig
269+
270+ static constexpr int kSeqLen = 8192 ;
271+ static constexpr int kVocabSize = 256000 ;
272+ static constexpr std::array<LayerAttentionType, 42 > kLayerConfig =
273+ FixedLayerConfig<42 >(LayerAttentionType::kGemma );
274+ static constexpr std::array<size_t , 42 > kAttentionWindowSizes =
275+ RepeatedAttentionWindowSizes<42 , 2 >({4096 , kSeqLen });
276+ static constexpr int kLayers = kLayerConfig .size();
277+ static constexpr int kGemmaLayers = kLayers ;
278+ static constexpr int kModelDim = 3584 ;
279+ static constexpr int kFFHiddenDim = 8 * 3584 / 2 ; // = 14336
280+ static constexpr int kHeads = 16 ;
281+ static constexpr int kKVHeads = 8 ;
282+ static constexpr int kQKVDim = 256 ; // query size == key size == value size
283+ static constexpr int kTopK = gcpp::kTopK ;
284+ static constexpr bool kAbsolutePE = false ;
285+ static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
286+ };
287+
288+ template <typename TWeight>
289+ struct ConfigGemma7B : public ConfigBaseGemmaV1 {
290+ using Weight = TWeight; // make accessible where we only have a TConfig
291+
292+ static constexpr int kSeqLen = gcpp::kSeqLen ;
293+ static constexpr int kVocabSize = 256000 ;
294+ static constexpr std::array<LayerAttentionType, 28 > kLayerConfig =
295+ FixedLayerConfig<28 >(LayerAttentionType::kGemma );
296+ static constexpr std::array<size_t , 28 > kAttentionWindowSizes =
297+ FixedAttentionWindowSizes<28 >(kSeqLen );
298+ static constexpr int kLayers = kLayerConfig .size();
299+ static constexpr int kGemmaLayers = kLayers ;
300+ static constexpr int kModelDim = 3072 ;
301+ static constexpr int kFFHiddenDim = 16 * 3072 / 2 ; // = 24576
302+ static constexpr int kHeads = 16 ;
303+ static constexpr int kKVHeads = 16 ; // standard MHA
304+ static constexpr int kQKVDim = 256 ; // query size == key size == value size
305+ static constexpr int kTopK = gcpp::kTopK ;
306+ static constexpr bool kAbsolutePE = false ;
307+ };
308+
309+ template <typename TWeight>
310+ struct ConfigGemma2B : public ConfigBaseGemmaV1 {
311+ using Weight = TWeight; // make accessible where we only have a TConfig
312+
313+ static constexpr int kSeqLen = gcpp::kSeqLen ;
314+ static constexpr int kVocabSize = 256000 ;
315+ static constexpr std::array<LayerAttentionType, 18 > kLayerConfig =
316+ FixedLayerConfig<18 >(LayerAttentionType::kGemma );
317+ static constexpr std::array<size_t , 18 > kAttentionWindowSizes =
318+ FixedAttentionWindowSizes<18 >(kSeqLen );
319+ static constexpr int kLayers = kLayerConfig .size();
320+ static constexpr int kGemmaLayers = kLayers ;
321+ static constexpr int kModelDim = 2048 ;
322+ static constexpr int kFFHiddenDim = 16 * 2048 / 2 ; // = 16384
323+ static constexpr int kHeads = 8 ;
324+ static constexpr int kKVHeads = 1 ;
325+ static constexpr int kQKVDim = 256 ; // query size == key size == value size
326+ static constexpr int kTopK = gcpp::kTopK ;
327+ static constexpr bool kAbsolutePE = false ;
328+ };
329+
330+ template <typename TWeight>
331+ struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
332+ using Weight = TWeight; // make accessible where we only have a TConfig
333+
334+ static constexpr int kSeqLen = 8192 ;
335+ static constexpr int kVocabSize = 256000 ;
336+ static constexpr std::array<LayerAttentionType, 26 > kLayerConfig =
337+ FixedLayerConfig<26 >(LayerAttentionType::kGemma );
338+ static constexpr std::array<size_t , 26 > kAttentionWindowSizes =
339+ RepeatedAttentionWindowSizes<26 , 2 >({4096 , kSeqLen });
340+ static constexpr int kLayers = kLayerConfig .size();
341+ static constexpr int kGemmaLayers = kLayers ;
342+ static constexpr int kModelDim = 2304 ;
343+ static constexpr int kFFHiddenDim = 8 * 2304 / 2 ; // = 9216
344+ static constexpr int kHeads = 8 ;
345+ static constexpr int kKVHeads = 4 ;
346+ static constexpr int kQKVDim = 256 ; // query size == key size == value size
347+ static constexpr int kTopK = gcpp::kTopK ;
348+ static constexpr bool kAbsolutePE = false ;
349+ static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
350+ };
351+
352+ template <typename TWeight>
353+ struct ConfigGemmaTiny : public ConfigNoSSM {
354+ using Weight = TWeight; // make accessible where we only have a TConfig
355+
356+ static constexpr int kSeqLen = 32 ;
357+ static constexpr int kVocabSize = 64 ;
358+ static constexpr std::array<LayerAttentionType, 3 > kLayerConfig =
359+ FixedLayerConfig<3 >(LayerAttentionType::kGemma );
360+ static constexpr std::array<size_t , 3 > kAttentionWindowSizes =
361+ FixedAttentionWindowSizes<3 >(kSeqLen );
362+ static constexpr int kLayers = kLayerConfig .size();
363+ static constexpr int kGemmaLayers = kLayers ;
364+ static constexpr int kModelDim = 128 ;
365+ static constexpr int kFFHiddenDim = 256 ;
366+ static constexpr int kHeads = 4 ;
367+ static constexpr int kKVHeads = 1 ;
368+ static constexpr int kQKVDim = 16 ; // query size == key size == value size
369+ static constexpr int kTopK = gcpp::kTopK ;
370+ static constexpr bool kAbsolutePE = false ;
371+ static constexpr PostNormType kPostNorm = PostNormType::None;
372+ static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
373+
374+ static constexpr float kAttCap = 0 .0f ;
375+ // This is required for optimize_test to pass.
376+ static constexpr float kFinalCap = 30 .0f ;
377+ };
378+
379+ template <typename TWeight>
380+ struct ConfigGriffin2B {
381+ using Weight = TWeight; // make accessible where we only have a TConfig
382+
383+ // Griffin uses local attention, so kSeqLen is actually the local attention
384+ // window.
385+ static constexpr int kSeqLen = 2048 ;
386+ static constexpr int kVocabSize = 256000 ;
387+ static constexpr std::array<LayerAttentionType, 26 > kLayerConfig = {
388+ LayerAttentionType::kGriffinRecurrentBlock ,
389+ LayerAttentionType::kGriffinRecurrentBlock ,
390+ LayerAttentionType::kGemma ,
391+ LayerAttentionType::kGriffinRecurrentBlock ,
392+ LayerAttentionType::kGriffinRecurrentBlock ,
393+ LayerAttentionType::kGemma ,
394+ LayerAttentionType::kGriffinRecurrentBlock ,
395+ LayerAttentionType::kGriffinRecurrentBlock ,
396+ LayerAttentionType::kGemma ,
397+ LayerAttentionType::kGriffinRecurrentBlock ,
398+ LayerAttentionType::kGriffinRecurrentBlock ,
399+ LayerAttentionType::kGemma ,
400+ LayerAttentionType::kGriffinRecurrentBlock ,
401+ LayerAttentionType::kGriffinRecurrentBlock ,
402+ LayerAttentionType::kGemma ,
403+ LayerAttentionType::kGriffinRecurrentBlock ,
404+ LayerAttentionType::kGriffinRecurrentBlock ,
405+ LayerAttentionType::kGemma ,
406+ LayerAttentionType::kGriffinRecurrentBlock ,
407+ LayerAttentionType::kGriffinRecurrentBlock ,
408+ LayerAttentionType::kGemma ,
409+ LayerAttentionType::kGriffinRecurrentBlock ,
410+ LayerAttentionType::kGriffinRecurrentBlock ,
411+ LayerAttentionType::kGemma ,
412+ LayerAttentionType::kGriffinRecurrentBlock ,
413+ LayerAttentionType::kGriffinRecurrentBlock ,
414+ };
415+ static constexpr std::array<size_t , 26 > kAttentionWindowSizes =
416+ FixedAttentionWindowSizes<26 >(kSeqLen );
417+ static constexpr int kLayers = kLayerConfig .size();
418+ static constexpr int kGemmaLayers =
419+ NumLayersOfTypeBefore (kLayerConfig , LayerAttentionType::kGemma , kLayers );
420+ static constexpr int kGriffinLayers =
421+ NumLayersOfTypeBefore (kLayerConfig ,
422+ LayerAttentionType::kGriffinRecurrentBlock ,
423+ kLayers );
424+ static constexpr int kModelDim = 2560 ;
425+ static constexpr int kFFHiddenDim = 7680 ;
426+ static constexpr int kHeads = 10 ;
427+ static constexpr int kKVHeads = 1 ;
428+ static constexpr int kQKVDim = 256 ; // query size == key size == value size
429+ static constexpr int kTopK = gcpp::kTopK ;
430+ static constexpr bool kAbsolutePE = false ;
431+ static constexpr PostNormType kPostNorm = PostNormType::None;
432+
433+ // No SoftCap.
434+ static constexpr float kAttCap = 0 .0f ;
435+ static constexpr float kFinalCap = 0 .0f ;
436+
437+ // SSM config.
438+ static constexpr int kConv1dWidth = 4 ;
439+ static constexpr bool kFFBiases = true ;
440+ static constexpr bool kSoftmaxAttnOutputBiases = true ;
441+ static constexpr bool kUseHalfRope = true ;
442+ static constexpr bool kUseLocalAttention = true ;
443+ static constexpr bool kInterleaveQKV = false ;
444+ static constexpr int kNumTensorScales = 140 ;
445+ static constexpr PostQKType kPostQK = PostQKType::Rope;
446+ static constexpr ActivationType kActivation = ActivationType::Gelu;
447+ static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
448+ static constexpr ResidualType kResidual = ResidualType::Add;
449+
450+ // Self-extend parameters with defaul values
451+ static constexpr bool kSelfExtend = false ;
452+ static constexpr size_t kSelfExtendNgbSize = 0 ;
453+ static constexpr size_t kSelfExtendGrpSize = 1 ;
454+ };
195455
196456} // namespace gcpp
197457
0 commit comments