 use std::fs::{File, create_dir_all, read_dir};
 use std::io::Write;
 use std::path::{Path, PathBuf};
-use std::process;
+use std::process::{self};
+use std::sync::Mutex;
 use std::sync::atomic::{AtomicUsize, Ordering};

 use anyhow::{Context, Result};
 use arrow::ipc::reader::StreamReader;
 use arrow::record_batch::RecordBatch;
-use clap::{Parser, ValueEnum};
+use clap::{ArgAction, Parser, ValueEnum};
 use parquet::arrow::ArrowWriter;
 use parquet::file::properties::WriterProperties;
 use polars::prelude::*;
@@ -42,6 +43,10 @@ struct Args {
     /// Number of threads to use for processing
     #[arg(long, default_value_t = 3)]
     threads: usize,
+
+    /// CSV file where transcriptions should be written
+    #[arg(long, action = ArgAction::Set)]
+    metadata_file: Option<PathBuf>,
 }

 fn arrow_to_parquet(filename: &Path) -> Result<DataFrame> {
@@ -73,10 +78,11 @@ fn batches_to_parquet(batches: &[RecordBatch]) -> Result<DataFrame> {

     let tmp_file = writer.into_inner()?;

-    // Read in parquet file
+    // Read in parquet file and unnest the audio column
     let df = ParquetReader::new(tmp_file)
-        .with_columns(Some(vec!["audio".to_string()]))
-        .finish()?;
+        .with_columns(Some(vec!["audio".to_string(), "transcription".to_string()]))
+        .finish()?
+        .unnest(["audio"])?;

     Ok(df)
 }
@@ -86,9 +92,10 @@ fn read_parquet(filename: &Path) -> Result<DataFrame> {
         .with_context(|| format!("Failed to open parquet file: {}", filename.display()))?;

     let df = ParquetReader::new(file)
-        .with_columns(Some(vec!["audio".to_string()]))
+        .with_columns(Some(vec!["audio".to_string(), "transcription".to_string()]))
         .finish()
-        .context("Failed to read parquet file into DataFrame")?;
+        .context("Failed to read parquet file into DataFrame")?
+        .unnest(["audio"])?;

     Ok(df)
 }
@@ -106,7 +113,12 @@ fn write_file(filename: &Path, data: &[u8]) -> Result<()> {
     Ok(())
 }

-fn process_file(filename: &Path, format: Format, output_dir: &Path) -> Result<usize> {
+fn process_file(
+    filename: &Path,
+    format: Format,
+    output_dir: &Path,
+    metadata_records: &Mutex<Vec<(String, String)>>,
+) -> Result<usize> {
     // Convert the file to a DataFrame
     let df = match format {
         Format::Arrow => arrow_to_parquet(filename)
@@ -115,36 +127,50 @@ fn process_file(filename: &Path, format: Format, output_dir: &Path) -> Result<us
             .with_context(|| format!("Error processing parquet file {}", filename.display()))?,
     };

-    let num_rows = df.height();
-
-    for row in df.iter() {
-        let struct_series = row.struct_()?;
+    // Extract the series from the DataFrame
+    let path_series = df.column("path")?.str()?;
+    let array_series = df.column("bytes")?.binary()?;
+    let transcription_series = df.column("transcription")?.str()?;

-        let all_bytes = struct_series.field_by_name("bytes")?;
-        let all_paths = struct_series.field_by_name("path")?;
-
-        // Extract files in parallel
-        (0..all_paths.len()).into_par_iter().try_for_each(|idx| {
-            let path = all_paths.get(idx)?;
-            let bytes = all_bytes.get(idx)?;
-
-            let filename = match path {
-                AnyValue::String(b) => b.to_string(),
-                _ => return Ok::<(), PolarsError>(()),
-            };
+    let num_rows = df.height();

-            let bytes = match bytes {
-                AnyValue::Binary(b) => b,
-                _ => return Ok(()),
-            };
+    let records: Vec<_> = (0..num_rows)
+        .into_par_iter()
+        .filter_map(|i| {
+            if let (Some(path_val), Some(transcription), Some(array_series_inner)) = (
+                path_series.get(i),
+                transcription_series.get(i),
+                array_series.get(i),
+            ) {
+                Some((path_val, transcription, array_series_inner))
+            } else {
+                None
+            }
+        })
+        .collect();
+
+    let local_metadata: Vec<(String, String)> = records
+        .par_iter()
+        .map(|(path_val, transcription, array_series_inner)| {
+            let original_path = Path::new(path_val);
+            let file_stem = original_path.file_stem().unwrap_or_default();
+            let extension = original_path.extension().unwrap_or_default();
+
+            let audio_filename_str = format!(
+                "{}.{}",
+                file_stem.to_string_lossy(),
+                extension.to_string_lossy()
+            );
+            let audio_filename = output_dir.join(&audio_filename_str);
+            let audio_data: &[u8] = array_series_inner;
+            write_file(&audio_filename, audio_data).expect("Failed to write audio file");

-            let path = output_dir.join(filename.clone());
+            (audio_filename_str, transcription.to_string())
+        })
+        .collect();

-            let _ = write_file(&path, bytes);
+    metadata_records.lock().unwrap().extend(local_metadata);

-            Ok(())
-        })?;
-    }
     Ok(num_rows)
 }

@@ -169,13 +195,15 @@ fn main() -> Result<()> {
         )
     })?;

+    let metadata_records = Mutex::new(Vec::new());
+
     if let Some(input_file) = args.input {
         if !input_file.is_file() {
             eprintln!("Input is not a file: {}", input_file.display());
             process::exit(1);
         }
         println!("Processing file: {}...", input_file.display());
-        let rows = process_file(&input_file, args.format, &args.output)?;
+        let rows = process_file(&input_file, args.format, &args.output, &metadata_records)?;
         println!("Total number of rows processed: {}", rows);
     }

@@ -192,7 +220,7 @@ fn main() -> Result<()> {
             .filter_map(Result::ok)
             .filter(|entry| {
                 entry.path().is_file()
-                    && entry
+                    && entry // TODO: this is not correct, should be based on format
                         .path()
                         .extension()
                         .is_some_and(|ext| ext == "parquet" || ext == "arrow")
@@ -204,7 +232,7 @@ fn main() -> Result<()> {
         files_to_process.into_iter().for_each(|entry| {
             let path = entry.path();
             println!("Processing file: {}...", path.display());
-            match process_file(&path, args.format, &args.output) {
+            match process_file(&path, args.format, &args.output, &metadata_records) {
                 Ok(rows) => {
                     total_rows.fetch_add(rows, Ordering::SeqCst);
                 }
@@ -218,6 +246,31 @@ fn main() -> Result<()> {
         );
     }

+    if let Some(metadata_file_path) = args.metadata_file {
+        println!("Writing metadata to {}...", metadata_file_path.display());
+        let records = metadata_records.into_inner().unwrap();
+        if !records.is_empty() {
+            let mut df = DataFrame::new(vec![
+                Column::new(
+                    "file_name".into(),
+                    records.iter().map(|(f, _)| f.as_str()).collect::<Vec<_>>(),
+                ),
+                Column::new(
+                    "transcription".into(),
+                    records.iter().map(|(_, t)| t.as_str()).collect::<Vec<_>>(),
+                ),
+            ])?;
+
+            let mut file = File::create(&metadata_file_path).with_context(|| {
+                format!(
+                    "Failed to create metadata file: {}",
+                    metadata_file_path.display()
+                )
+            })?;
+            CsvWriter::new(&mut file).finish(&mut df)?;
+        }
+    }
+
     println!("Done!");

     Ok(())
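
For reference, the accumulation pattern the reworked process_file relies on (build each file's records locally in parallel, then lock the shared Mutex once per file) can be sketched in isolation as below. This is a minimal standalone example, not part of the commit; it assumes only the rayon crate, and the file names and transcription values are purely illustrative.

use rayon::prelude::*;
use std::sync::Mutex;

fn main() {
    // Shared accumulator, mirroring `metadata_records` in main().
    let metadata_records: Mutex<Vec<(String, String)>> = Mutex::new(Vec::new());

    // Stand-in for the per-file loop; the real code calls process_file() here.
    for file_idx in 0..4 {
        // Stand-in for the row-level par_iter inside process_file();
        // the (file_name, transcription) pairs below are made up.
        let local_metadata: Vec<(String, String)> = (0..3)
            .into_par_iter()
            .map(|row| {
                (
                    format!("clip_{file_idx}_{row}.wav"),
                    format!("transcription {row}"),
                )
            })
            .collect();

        // Lock once per file rather than once per row to keep contention low.
        metadata_records.lock().unwrap().extend(local_metadata);
    }

    let records = metadata_records.into_inner().unwrap();
    println!("{} records collected", records.len());
}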