Skip to content

Commit c206b33

Browse files
committed
Add metadata extraction
1 parent 7690219 commit c206b33

File tree

3 files changed

+91
-36
lines changed

3 files changed

+91
-36
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ tempfile = "3"
1515

1616
rayon = { version = "1.11.0" }
1717
anyhow = "1.0.100"
18+
csv = "1.4.0"
1819

1920
[profile.release]
2021
opt-level = 3

src/main.rs

Lines changed: 89 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
use std::fs::{File, create_dir_all, read_dir};
22
use std::io::Write;
33
use std::path::{Path, PathBuf};
4-
use std::process;
4+
use std::process::{self};
5+
use std::sync::Mutex;
56
use std::sync::atomic::{AtomicUsize, Ordering};
67

78
use anyhow::{Context, Result};
89
use arrow::ipc::reader::StreamReader;
910
use arrow::record_batch::RecordBatch;
10-
use clap::{Parser, ValueEnum};
11+
use clap::{ArgAction, Parser, ValueEnum};
1112
use parquet::arrow::ArrowWriter;
1213
use parquet::file::properties::WriterProperties;
1314
use polars::prelude::*;
@@ -42,6 +43,10 @@ struct Args {
4243
/// Number of threads to use for processing
4344
#[arg(long, default_value_t = 3)]
4445
threads: usize,
46+
47+
/// CSV file where transcriptions should be written
48+
#[arg(long, action = ArgAction::Set)]
49+
metadata_file: Option<PathBuf>,
4550
}
4651

4752
fn arrow_to_parquet(filename: &Path) -> Result<DataFrame> {
@@ -73,10 +78,11 @@ fn batches_to_parquet(batches: &[RecordBatch]) -> Result<DataFrame> {
7378

7479
let tmp_file = writer.into_inner()?;
7580

76-
// Read in parquet file
81+
// Read in parquet file and unnest the audio column
7782
let df = ParquetReader::new(tmp_file)
78-
.with_columns(Some(vec!["audio".to_string()]))
79-
.finish()?;
83+
.with_columns(Some(vec!["audio".to_string(), "transcription".to_string()]))
84+
.finish()?
85+
.unnest(["audio"])?;
8086

8187
Ok(df)
8288
}
@@ -86,9 +92,10 @@ fn read_parquet(filename: &Path) -> Result<DataFrame> {
8692
.with_context(|| format!("Failed to open parquet file: {}", filename.display()))?;
8793

8894
let df = ParquetReader::new(file)
89-
.with_columns(Some(vec!["audio".to_string()]))
95+
.with_columns(Some(vec!["audio".to_string(), "transcription".to_string()]))
9096
.finish()
91-
.context("Failed to read parquet file into DataFrame")?;
97+
.context("Failed to read parquet file into DataFrame")?
98+
.unnest(["audio"])?;
9299

93100
Ok(df)
94101
}
@@ -106,7 +113,12 @@ fn write_file(filename: &Path, data: &[u8]) -> Result<()> {
106113
Ok(())
107114
}
108115

109-
fn process_file(filename: &Path, format: Format, output_dir: &Path) -> Result<usize> {
116+
fn process_file(
117+
filename: &Path,
118+
format: Format,
119+
output_dir: &Path,
120+
metadata_records: &Mutex<Vec<(String, String)>>,
121+
) -> Result<usize> {
110122
// Convert the file to a DataFrame
111123
let df = match format {
112124
Format::Arrow => arrow_to_parquet(filename)
@@ -115,36 +127,50 @@ fn process_file(filename: &Path, format: Format, output_dir: &Path) -> Result<us
115127
.with_context(|| format!("Error processing parquet file {}", filename.display()))?,
116128
};
117129

118-
let num_rows = df.height();
119-
120-
for row in df.iter() {
121-
let struct_series = row.struct_()?;
130+
// Extract the series from the DataFrame
131+
let path_series = df.column("path")?.str()?;
132+
let array_series = df.column("bytes")?.binary()?;
133+
let transcription_series = df.column("transcription")?.str()?;
122134

123-
let all_bytes = struct_series.field_by_name("bytes")?;
124-
let all_paths = struct_series.field_by_name("path")?;
125-
126-
// Extract files in parallel
127-
(0..all_paths.len()).into_par_iter().try_for_each(|idx| {
128-
let path = all_paths.get(idx)?;
129-
let bytes = all_bytes.get(idx)?;
130-
131-
let filename = match path {
132-
AnyValue::String(b) => b.to_string(),
133-
_ => return Ok::<(), PolarsError>(()),
134-
};
135+
let num_rows = df.height();
135136

136-
let bytes = match bytes {
137-
AnyValue::Binary(b) => b,
138-
_ => return Ok(()),
139-
};
137+
let records: Vec<_> = (0..num_rows)
138+
.into_par_iter()
139+
.filter_map(|i| {
140+
if let (Some(path_val), Some(transcription), Some(array_series_inner)) = (
141+
path_series.get(i),
142+
transcription_series.get(i),
143+
array_series.get(i),
144+
) {
145+
Some((path_val, transcription, array_series_inner))
146+
} else {
147+
None
148+
}
149+
})
150+
.collect();
151+
152+
let local_metadata: Vec<(String, String)> = records
153+
.par_iter()
154+
.map(|(path_val, transcription, array_series_inner)| {
155+
let original_path = Path::new(path_val);
156+
let file_stem = original_path.file_stem().unwrap_or_default();
157+
let extension = original_path.extension().unwrap_or_default();
158+
159+
let audio_filename_str = format!(
160+
"{}.{}",
161+
file_stem.to_string_lossy(),
162+
extension.to_string_lossy()
163+
);
164+
let audio_filename = output_dir.join(&audio_filename_str);
165+
let audio_data: &[u8] = array_series_inner;
166+
write_file(&audio_filename, audio_data).expect("Failed to write audio file");
140167

141-
let path = output_dir.join(filename.clone());
168+
(audio_filename_str, transcription.to_string())
169+
})
170+
.collect();
142171

143-
let _ = write_file(&path, bytes);
172+
metadata_records.lock().unwrap().extend(local_metadata);
144173

145-
Ok(())
146-
})?;
147-
}
148174
Ok(num_rows)
149175
}
150176

@@ -169,13 +195,15 @@ fn main() -> Result<()> {
169195
)
170196
})?;
171197

198+
let metadata_records = Mutex::new(Vec::new());
199+
172200
if let Some(input_file) = args.input {
173201
if !input_file.is_file() {
174202
eprintln!("Input is not a file: {}", input_file.display());
175203
process::exit(1);
176204
}
177205
println!("Processing file: {}...", input_file.display());
178-
let rows = process_file(&input_file, args.format, &args.output)?;
206+
let rows = process_file(&input_file, args.format, &args.output, &metadata_records)?;
179207
println!("Total number of rows processed: {}", rows);
180208
}
181209

@@ -192,7 +220,7 @@ fn main() -> Result<()> {
192220
.filter_map(Result::ok)
193221
.filter(|entry| {
194222
entry.path().is_file()
195-
&& entry
223+
&& entry // TODO: this is not correct, should be based on format
196224
.path()
197225
.extension()
198226
.is_some_and(|ext| ext == "parquet" || ext == "arrow")
@@ -204,7 +232,7 @@ fn main() -> Result<()> {
204232
files_to_process.into_iter().for_each(|entry| {
205233
let path = entry.path();
206234
println!("Processing file: {}...", path.display());
207-
match process_file(&path, args.format, &args.output) {
235+
match process_file(&path, args.format, &args.output, &metadata_records) {
208236
Ok(rows) => {
209237
total_rows.fetch_add(rows, Ordering::SeqCst);
210238
}
@@ -218,6 +246,31 @@ fn main() -> Result<()> {
218246
);
219247
}
220248

249+
if let Some(metadata_file_path) = args.metadata_file {
250+
println!("Writing metadata to {}...", metadata_file_path.display());
251+
let records = metadata_records.into_inner().unwrap();
252+
if !records.is_empty() {
253+
let mut df = DataFrame::new(vec![
254+
Column::new(
255+
"file_name".into(),
256+
records.iter().map(|(f, _)| f.as_str()).collect::<Vec<_>>(),
257+
),
258+
Column::new(
259+
"transcription".into(),
260+
records.iter().map(|(_, t)| t.as_str()).collect::<Vec<_>>(),
261+
),
262+
])?;
263+
264+
let mut file = File::create(&metadata_file_path).with_context(|| {
265+
format!(
266+
"Failed to create metadata file: {}",
267+
metadata_file_path.display()
268+
)
269+
})?;
270+
CsvWriter::new(&mut file).finish(&mut df)?;
271+
}
272+
}
273+
221274
println!("Done!");
222275

223276
Ok(())

0 commit comments

Comments
 (0)