Skip to content

Commit 8e4d2d3

Browse files
authored
Merge pull request #14 from second-state/concurrent_tts
(v2) Concurrent tts
2 parents 66e00c4 + e627a85 commit 8e4d2d3

File tree

13 files changed

+1135
-38
lines changed

13 files changed

+1135
-38
lines changed

src/ai/bailian/cosyvoice.rs

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,10 @@ impl Default for CosyVoiceVersion {
5959
}
6060

6161
impl CosyVoiceTTS {
62+
const WEBSOCKET_URL: &str = "wss://dashscope.aliyuncs.com/api-ws/v1/inference";
63+
6264
pub async fn connect(token: String) -> anyhow::Result<Self> {
63-
let url = format!("wss://dashscope.aliyuncs.com/api-ws/v1/inference");
65+
let url = Self::WEBSOCKET_URL;
6466

6567
let client = reqwest::Client::new();
6668
let response = client
@@ -79,6 +81,22 @@ impl CosyVoiceTTS {
7981
})
8082
}
8183

84+
pub async fn reconnect(&mut self) -> anyhow::Result<()> {
85+
let url = Self::WEBSOCKET_URL;
86+
87+
let client = reqwest::Client::new();
88+
let response = client
89+
.get(url)
90+
.bearer_auth(&self.token)
91+
.header("X-DashScope-DataInspection", "enable")
92+
.upgrade()
93+
.send()
94+
.await?;
95+
self.websocket = response.into_websocket().await?;
96+
self.synthesis_started = false;
97+
Ok(())
98+
}
99+
82100
pub async fn start_synthesis(
83101
&mut self,
84102
model: CosyVoiceVersion,
@@ -222,4 +240,23 @@ async fn test_cosyvoice_tts() {
222240
};
223241
let wav = crate::util::pcm_to_wav(&audio_data, config);
224242
std::fs::write("./resources/test/cosyvoice_out.wav", wav).unwrap();
243+
244+
let text = "Hello, this is CosyVoice V2";
245+
tts.start_synthesis(CosyVoiceVersion::V2, None, Some(24000), text)
246+
.await
247+
.unwrap();
248+
249+
let mut audio_data = bytes::BytesMut::new();
250+
while let Ok(Some(chunk)) = tts.next_audio_chunk().await {
251+
audio_data.extend_from_slice(&chunk);
252+
}
253+
254+
println!("Audio data size: {} bytes", audio_data.len());
255+
let config = crate::util::WavConfig {
256+
channels: 1,
257+
sample_rate: 24000,
258+
bits_per_sample: 16,
259+
};
260+
let wav = crate::util::pcm_to_wav(&audio_data, config);
261+
std::fs::write("./resources/test/cosyvoice_out2.wav", wav).unwrap();
225262
}

src/ai/bailian/realtime_asr.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,21 @@ impl ParaformerRealtimeV2Asr {
7575
})
7676
}
7777

78+
pub async fn reconnect(&mut self) -> anyhow::Result<()> {
79+
let url = format!("wss://dashscope.aliyuncs.com/api-ws/v1/inference");
80+
81+
let client = reqwest::Client::new();
82+
let response = client
83+
.get(url)
84+
.bearer_auth(&self.token)
85+
.header("X-DashScope-DataInspection", "enable")
86+
.upgrade()
87+
.send()
88+
.await?;
89+
self.websocket = response.into_websocket().await?;
90+
Ok(())
91+
}
92+
7893
pub async fn start_pcm_recognition(&mut self) -> anyhow::Result<()> {
7994
let task_id = Uuid::new_v4().to_string();
8095
log::info!("Starting asr task with ID: {}", task_id);
@@ -110,6 +125,13 @@ impl ParaformerRealtimeV2Asr {
110125
log::debug!("Received message: {:?}", text);
111126

112127
let response: ResponseMessage = serde_json::from_str(&text)?;
128+
if response.header.task_id != self.task_id {
129+
log::warn!(
130+
"Received message for different task_id: {}",
131+
response.header.task_id
132+
);
133+
continue;
134+
}
113135

114136
if response.is_task_started() {
115137
log::info!("Recognition task started");
@@ -187,6 +209,7 @@ impl ParaformerRealtimeV2Asr {
187209
}
188210
}
189211

212+
// cargo test --package echokit_server --bin echokit_server -- ai::bailian::realtime_asr::test_paraformer_asr --exact --show-output
190213
#[tokio::test]
191214
async fn test_paraformer_asr() {
192215
env_logger::init();
@@ -202,6 +225,21 @@ async fn test_paraformer_asr() {
202225
.unwrap();
203226
asr.start_pcm_recognition().await.unwrap();
204227

228+
asr.send_audio(audio_data.clone()).await.unwrap();
229+
asr.finish_task().await.unwrap();
230+
231+
loop {
232+
if let Ok(Some(sentence)) = asr.next_result().await {
233+
println!("{:?}", sentence);
234+
if sentence.sentence_end {
235+
println!();
236+
}
237+
} else {
238+
break;
239+
}
240+
}
241+
242+
asr.start_pcm_recognition().await.unwrap();
205243
asr.send_audio(audio_data).await.unwrap();
206244
asr.finish_task().await.unwrap();
207245

src/ai/elevenlabs/tts.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,21 @@ impl ElevenlabsTTS {
9898
token: String,
9999
voice: String,
100100
output_format: OutputFormat,
101+
) -> anyhow::Result<Self> {
102+
let client = reqwest::Client::new();
103+
Self::new_with_client(&client, token, voice, output_format).await
104+
}
105+
106+
pub async fn new_with_client(
107+
client: &reqwest::Client,
108+
token: String,
109+
voice: String,
110+
output_format: OutputFormat,
101111
) -> anyhow::Result<Self> {
102112
let url = format!(
103113
"wss://api.elevenlabs.io/v1/text-to-speech/{voice}/stream-input?model_id={MODEL_ID}&output_format={output_format}",
104114
);
105115

106-
let client = reqwest::Client::new();
107-
108116
let response = client
109117
.get(url)
110118
.header("xi-api-key", &token)
@@ -203,6 +211,7 @@ impl ElevenlabsTTS {
203211
}
204212
}
205213

214+
// cargo test --package echokit_server --bin echokit_server -- ai::elevenlabs::tts::test_elevenlabs_tts --exact --show-output
206215
#[tokio::test]
207216
async fn test_elevenlabs_tts() {
208217
env_logger::init();

src/ai/tts.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ use bytes::Bytes;
22

33
/// return: wav_audio: 16bit,32k,single-channel.
44
pub async fn gsv(
5+
client: &reqwest::Client,
56
tts_url: &str,
67
speaker: &str,
78
text: &str,
89
sample_rate: Option<usize>,
910
) -> anyhow::Result<Bytes> {
1011
log::debug!("speaker: {speaker}, text: {text}");
11-
let client = reqwest::Client::new();
1212
let res = client
1313
.post(tts_url)
1414
.json(&serde_json::json!({"speaker": speaker, "input": text, "sample_rate": sample_rate}))
@@ -34,7 +34,10 @@ async fn test_gsv() {
3434
let tts_url = "http://localhost:8000/v1/audio/speech";
3535
let speaker = "ad";
3636
let text = "你好,我是胡桃";
37-
let wav_audio = gsv(tts_url, speaker, text, Some(16000)).await.unwrap();
37+
let client = reqwest::Client::new();
38+
let wav_audio = gsv(&client, tts_url, speaker, text, Some(16000))
39+
.await
40+
.unwrap();
3841
let header = hound::WavReader::new(wav_audio.as_ref()).unwrap();
3942
let spec = header.spec();
4043
println!("wav header: {:?}", spec);
@@ -44,13 +47,13 @@ async fn test_gsv() {
4447

4548
/// return: pcm_chunk: 16bit,32k,single-channel.
4649
pub async fn stream_gsv(
50+
client: &reqwest::Client,
4751
tts_url: &str,
4852
speaker: &str,
4953
text: &str,
5054
sample_rate: Option<usize>,
5155
) -> anyhow::Result<reqwest::Response> {
5256
log::debug!("speaker: {speaker}, text: {text}");
53-
let client = reqwest::Client::new();
5457
let res = client
5558
.post(tts_url)
5659
.json(&serde_json::json!({"speaker": speaker, "input": text, "sample_rate": sample_rate}))
@@ -69,9 +72,14 @@ pub async fn stream_gsv(
6972
}
7073

7174
/// return: wav_audio: 16bit,48k,single-channel.
72-
pub async fn groq(model: &str, token: &str, voice: &str, text: &str) -> anyhow::Result<Bytes> {
75+
pub async fn groq(
76+
client: &reqwest::Client,
77+
model: &str,
78+
token: &str,
79+
voice: &str,
80+
text: &str,
81+
) -> anyhow::Result<Bytes> {
7382
log::debug!("groq tts. voice: {voice}, text: {text}");
74-
let client = reqwest::Client::new();
7583
let res = client
7684
.post("https://api.groq.com/openai/v1/audio/speech")
7785
.bearer_auth(token)
@@ -102,7 +110,10 @@ async fn test_groq() {
102110
let token = std::env::var("GROQ_API_KEY").unwrap();
103111
let speaker = "Aaliyah-PlayAI";
104112
let text = "你好,我是胡桃";
105-
let wav_audio = groq("playai-tts", &token, speaker, text).await.unwrap();
113+
let client = reqwest::Client::new();
114+
let wav_audio = groq(&client, "playai-tts", &token, speaker, text)
115+
.await
116+
.unwrap();
106117
let mut reader = wav_io::reader::Reader::from_vec(wav_audio.to_vec()).unwrap();
107118
let head = reader.read_header().unwrap();
108119
println!("wav header: {:?}", head);

src/main.rs

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,15 @@ async fn routes(
5454
let mut tool_set = ai::openai::tool::ToolSet::default();
5555
let mut real_config: Option<StableRealtimeConfig> = None;
5656
match &config.config {
57-
config::AIConfig::Stable {
58-
llm,
59-
tts,
60-
asr: ASRConfig::Whisper(asr),
61-
} => {
62-
real_config = Some(StableRealtimeConfig {
63-
llm: llm.clone(),
64-
tts: tts.clone(),
65-
asr: asr.clone(),
66-
});
57+
config::AIConfig::Stable { llm, tts, asr } => {
58+
if let ASRConfig::Whisper(asr) = asr {
59+
real_config = Some(StableRealtimeConfig {
60+
llm: llm.clone(),
61+
tts: tts.clone(),
62+
asr: asr.clone(),
63+
});
64+
}
65+
6766
for server in &llm.mcp_server {
6867
match server.type_ {
6968
config::MCPType::SSE => {
@@ -87,22 +86,53 @@ async fn routes(
8786
_ => {}
8887
}
8988

89+
let record_config = Arc::new(services::ws_record::WsRecordSetting {
90+
record_callback_url: config.record.callback_url,
91+
});
92+
93+
let ws_setting = Arc::new(services::ws::WsSetting::new(
94+
hello_wav.clone(),
95+
config.config.clone(),
96+
tool_set.clone(),
97+
));
98+
9099
let mut router = Router::new()
91100
// .route("/", get(handler))
92-
.route("/ws/{id}", any(services::mixed_handler))
101+
.route("/v1/ws/{id}", any(services::mixed_handler))
93102
.route("/v1/chat/{id}", any(services::ws::ws_handler))
94103
.route("/v1/record/{id}", any(services::ws_record::ws_handler))
95104
.nest("/downloads", services::file::new_file_service("./record"))
96-
.layer(axum::Extension(Arc::new(services::ws::WsSetting::new(
97-
hello_wav,
98-
config.config,
99-
tool_set,
100-
))))
101-
.layer(axum::Extension(Arc::new(
102-
services::ws_record::WsRecordSetting {
103-
record_callback_url: config.record.callback_url,
104-
},
105-
)));
105+
.layer(axum::Extension(ws_setting.clone()))
106+
.layer(axum::Extension(record_config.clone()));
107+
108+
if let config::AIConfig::Stable { llm, tts, asr } = config.config {
109+
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
110+
// let tool_set = tool_set;
111+
tokio::spawn(async move {
112+
if let Err(e) =
113+
crate::services::ws::stable::run_session_manager(&llm, &tts, &asr, &tool_set, rx)
114+
.await
115+
{
116+
log::error!("Stable session manager exited with error: {}", e);
117+
}
118+
});
119+
120+
router = router
121+
.route("/ws/{id}", any(services::v2_mixed_handler))
122+
.route("/v2/stable_ws/{id}", any(services::ws::stable::ws_handler))
123+
.layer(axum::Extension(Arc::new(
124+
services::ws::stable::StableWsSetting {
125+
sessions: tx,
126+
hello_wav,
127+
},
128+
)))
129+
.layer(axum::Extension(record_config.clone()));
130+
} else {
131+
router = router
132+
.route("/ws/{id}", any(services::mixed_handler))
133+
.layer(axum::Extension(ws_setting.clone()))
134+
.layer(axum::Extension(record_config.clone()));
135+
}
106136

107137
if let Some(real_config) = real_config {
108138
log::info!(

src/services/mod.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,28 @@ pub async fn mixed_handler(
4343
.into_response()
4444
}
4545
}
46+
47+
pub async fn v2_mixed_handler(
48+
Extension(record_setting): Extension<Arc<ws_record::WsRecordSetting>>,
49+
Extension(pool): Extension<Arc<ws::stable::StableWsSetting>>,
50+
ws: WebSocketUpgrade,
51+
Path(id): Path<String>,
52+
Query(params): Query<ConnectQueryParams>,
53+
) -> Response {
54+
if params.record {
55+
ws_record::ws_handler(Extension(record_setting), ws, Path(id))
56+
.await
57+
.into_response()
58+
} else {
59+
ws::stable::ws_handler(
60+
Extension(pool),
61+
ws,
62+
Path(id),
63+
Query(ws::ConnectQueryParams {
64+
reconnect: params.reconnect,
65+
}),
66+
)
67+
.await
68+
.into_response()
69+
}
70+
}

src/services/realtime_ws.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,9 +1169,11 @@ async fn tts_and_send(
11691169
item_id: Option<String>,
11701170
text: String,
11711171
) -> anyhow::Result<()> {
1172+
let client = reqwest::Client::new();
11721173
match tts_config {
11731174
crate::config::TTSConfig::Stable(tts) => {
1174-
let wav_data = crate::ai::tts::gsv(&tts.url, &tts.speaker, &text, Some(24000)).await?;
1175+
let wav_data =
1176+
crate::ai::tts::gsv(&client, &tts.url, &tts.speaker, &text, Some(24000)).await?;
11751177
let duration_sec = send_wav(tx, response_id, item_id, text, wav_data).await?;
11761178
log::info!("Stable TTS duration: {:?}", duration_sec);
11771179
Ok(())
@@ -1184,13 +1186,15 @@ async fn tts_and_send(
11841186
}
11851187
crate::config::TTSConfig::Groq(groq) => {
11861188
let wav_data =
1187-
crate::ai::tts::groq(&groq.model, &groq.api_key, &groq.voice, &text).await?;
1189+
crate::ai::tts::groq(&client, &groq.model, &groq.api_key, &groq.voice, &text)
1190+
.await?;
11881191
let duration_sec = send_wav(tx, response_id, item_id, text, wav_data).await?;
11891192
log::info!("Groq TTS duration: {:?}", duration_sec);
11901193
Ok(())
11911194
}
11921195
crate::config::TTSConfig::StreamGSV(stream_tts) => {
11931196
let resp = crate::ai::tts::stream_gsv(
1197+
&client,
11941198
&stream_tts.url,
11951199
&stream_tts.speaker,
11961200
&text,

0 commit comments

Comments
 (0)