1+ using System ;
2+ using System . Diagnostics ;
3+ using System . Runtime . InteropServices ;
4+ using System . Text ;
5+ namespace GemmaCpp
6+ {
7+ public class GemmaException : Exception
8+ {
9+ public GemmaException ( string message ) : base ( message ) { }
10+ }
11+
12+ public class Gemma : IDisposable
13+ {
14+ private IntPtr _context ;
15+ private bool _disposed ;
16+
17+ // Optional: Allow setting DLL path
18+ public static string DllPath { get ; set ; } = "gemma.dll" ;
19+
20+ [ DllImport ( "kernel32.dll" , CharSet = CharSet . Unicode , SetLastError = true ) ]
21+ private static extern IntPtr LoadLibrary ( string lpFileName ) ;
22+
23+ static Gemma ( )
24+ {
25+ // Load DLL from specified path
26+ if ( LoadLibrary ( DllPath ) == IntPtr . Zero )
27+ {
28+ throw new DllNotFoundException ( $ "Failed to load { DllPath } . Error: { Marshal . GetLastWin32Error ( ) } ") ;
29+ }
30+ }
31+
32+ [ DllImport ( "gemma" , CallingConvention = CallingConvention . Cdecl ) ]
33+ private static extern IntPtr GemmaCreate (
34+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string tokenizerPath ,
35+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string modelType ,
36+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string weightsPath ,
37+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string weightType ,
38+ int maxLength ) ;
39+
40+ [ DllImport ( "gemma" , CallingConvention = CallingConvention . Cdecl ) ]
41+ private static extern void GemmaDestroy ( IntPtr context ) ;
42+
43+ // Delegate type for token callbacks
44+ public delegate bool TokenCallback ( string token ) ;
45+
46+ // Keep delegate alive for duration of calls
47+ private GCHandle _callbackHandle ;
48+
49+ [ UnmanagedFunctionPointer ( CallingConvention . Cdecl ) ]
50+ private delegate bool GemmaTokenCallback (
51+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string text ,
52+ IntPtr userData ) ;
53+
54+ [ DllImport ( "gemma" , CallingConvention = CallingConvention . Cdecl ) ]
55+ private static extern int GemmaGenerate (
56+ IntPtr context ,
57+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string prompt ,
58+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] StringBuilder output ,
59+ int maxLength ,
60+ GemmaTokenCallback callback ,
61+ IntPtr userData ) ;
62+
63+ [ DllImport ( "gemma" , CallingConvention = CallingConvention . Cdecl ) ]
64+ private static extern int GemmaCountTokens (
65+ IntPtr context ,
66+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string text ) ;
67+
68+ // Native callback delegate type
69+ [ UnmanagedFunctionPointer ( CallingConvention . Cdecl ) ]
70+ private delegate void GemmaLogCallback (
71+ [ MarshalAs ( UnmanagedType . LPUTF8Str ) ] string message ,
72+ IntPtr userData ) ;
73+
74+ [ DllImport ( "gemma" , CallingConvention = CallingConvention . Cdecl ) ]
75+ private static extern void GemmaSetLogCallback (
76+ IntPtr context ,
77+ GemmaLogCallback callback ,
78+ IntPtr userData ) ;
79+
80+ private GCHandle _logCallbackHandle ;
81+
82+ public Gemma ( string tokenizerPath , string modelType , string weightsPath , string weightType , int maxLength = 8192 )
83+ {
84+ _context = GemmaCreate ( tokenizerPath , modelType , weightsPath , weightType , maxLength ) ;
85+ if ( _context == IntPtr . Zero )
86+ {
87+ throw new GemmaException ( "Failed to create Gemma context" ) ;
88+ }
89+
90+ // optionally: set up logging
91+ /*
92+ GemmaLogCallback logCallback = (message, _) =>
93+ {
94+ #if UNITY_ENGINE
95+ Debug.Log($"Gemma: {message}");
96+ #else
97+ Debug.WriteLine($"Gemma: {message}");
98+ #endif
99+ };
100+ _logCallbackHandle = GCHandle.Alloc(logCallback);
101+ GemmaSetLogCallback(_context, logCallback, IntPtr.Zero);
102+ */
103+ }
104+
105+ public int CountTokens ( string prompt )
106+ {
107+ if ( _disposed )
108+ throw new ObjectDisposedException ( nameof ( Gemma ) ) ;
109+
110+ if ( _context == IntPtr . Zero )
111+ throw new GemmaException ( "Gemma context is invalid" ) ;
112+ int count = GemmaCountTokens ( _context , prompt ) ;
113+ return count ;
114+ }
115+
116+ public string Generate ( string prompt , int maxLength = 4096 )
117+ {
118+ return Generate ( prompt , null , maxLength ) ;
119+ }
120+
121+ public string Generate ( string prompt , TokenCallback callback , int maxLength = 4096 )
122+ {
123+ if ( _disposed )
124+ throw new ObjectDisposedException ( nameof ( Gemma ) ) ;
125+
126+ if ( _context == IntPtr . Zero )
127+ throw new GemmaException ( "Gemma context is invalid" ) ;
128+
129+ var output = new StringBuilder ( maxLength ) ;
130+ GemmaTokenCallback nativeCallback = null ;
131+
132+ if ( callback != null )
133+ {
134+ nativeCallback = ( text , _ ) => callback ( text ) ;
135+ _callbackHandle = GCHandle . Alloc ( nativeCallback ) ;
136+ }
137+
138+ try
139+ {
140+ int length = GemmaGenerate ( _context , prompt , output , maxLength ,
141+ nativeCallback , IntPtr . Zero ) ;
142+
143+ if ( length < 0 )
144+ throw new GemmaException ( "Generation failed" ) ;
145+
146+ return output . ToString ( ) ;
147+ }
148+ finally
149+ {
150+ if ( _callbackHandle . IsAllocated )
151+ _callbackHandle . Free ( ) ;
152+ }
153+ }
154+
155+ public void Dispose ( )
156+ {
157+ if ( ! _disposed )
158+ {
159+ if ( _context != IntPtr . Zero )
160+ {
161+ GemmaDestroy ( _context ) ;
162+ _context = IntPtr . Zero ;
163+ }
164+ if ( _logCallbackHandle . IsAllocated )
165+ _logCallbackHandle . Free ( ) ;
166+ _disposed = true ;
167+ }
168+ }
169+
170+ ~ Gemma ( )
171+ {
172+ Dispose ( ) ;
173+ }
174+ }
175+ }
0 commit comments