利用 “diart“ 和 OpenAI 的 Whisper 简化实时转录

锦通  金牌会员 | 2024-8-11 08:34:23 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 554|帖子 554|积分 1662

利用 "diart" 和 OpenAI 的 Whisper 简化实时转录

工作原理

Diart 是一个基于人工智能的 Python 库,用于实时记载说话者语言(即 "谁在什么时候说话"),它创建在 pyannote.audio 模型之上,专为实时音频流(如麦克风)而设计。
只需几行代码,diart 就能让您得到类似这样的实时发言者标签:

与此同时,Whisper 是 OpenAI 最新推出的一种为自动语音识别(ASR)而练习的模型,它对嘈杂情况的适应本领特别强,非常适合现实生活中的利用案例。
准备工作


  • 按照此处的阐明安装 diart
  • 利用 pip install git+https://github.com/linto-ai/whisper-timestamped 安装 whisper-timestamped
在这篇文章的其余部分,我将利用 RxPY(Python 的反应式编程扩展)来处理流媒体部分。如果你对它不熟悉,我发起你看看这个文档页面,了解一下基本知识。
简而言之,反应式编程就是对来自给定源(在我们的例子中是麦克风)的发射项(在我们的例子中是音频块)进行操作。
结合听和写

让我们先概述一下源代码,然后将其分解成若干块,以便更好地理解它。
  1. import logging
  2. import traceback
  3. import diart.operators as dops
  4. import rich
  5. import rx.operators as ops
  6. from diart import OnlineSpeakerDiarization, PipelineConfig
  7. from diart.sources import MicrophoneAudioSource
  8. # Suppress whisper-timestamped warnings for a clean output
  9. logging.getLogger("whisper_timestamped").setLevel(logging.ERROR)
  10. config = PipelineConfig(
  11.     duration=5,
  12.     step=0.5,
  13.     latency="min",
  14.     tau_active=0.5,
  15.     rho_update=0.1,
  16.     delta_new=0.57
  17. )
  18. dia = OnlineSpeakerDiarization(config)
  19. source = MicrophoneAudioSource(config.sample_rate)
  20. asr = WhisperTranscriber(model="small")
  21. transcription_duration = 2
  22. batch_size = int(transcription_duration // config.step)
  23. source.stream.pipe(
  24.     dops.rearrange_audio_stream(
  25.         config.duration, config.step, config.sample_rate
  26.     ),
  27.     ops.buffer_with_count(count=batch_size),
  28.     ops.map(dia),
  29.     ops.map(concat),
  30.     ops.filter(lambda ann_wav: ann_wav[0].get_timeline().duration() > 0),
  31.     ops.starmap(asr),
  32.     ops.map(colorize_transcription),
  33. ).subscribe(on_next=rich.print, on_error=lambda _: traceback.print_exc())
  34. print("Listening...")
  35. source.read()
复制代码
创建发言者记载模块

起首,我们创建了流媒体(又称 "在线")扬声器日记系统以及与本地麦克风相连的音频源。
我们将系统设置为利用 5 秒的滑动窗口,步长为 500 毫秒(默认值),并将延迟设置为最小值(500 毫秒),以进步相应速度。
  1. # If you have a GPU, you can also set device=torch.device("cuda")
  2. config = PipelineConfig(
  3.     duration=5,
  4.     step=0.5,
  5.     latency="min",
  6.     tau_active=0.5,
  7.     rho_update=0.1,
  8.     delta_new=0.57
  9. )
  10. dia = OnlineSpeakerDiarization(config)
  11. source = MicrophoneAudioSource(config.sample_rate)
复制代码
设置中的三个附加参数可调节扬声器识别的灵敏度:


  • tau_active=0.5: 只识别发言概率高于 50% 的发言者。
  • rho_update=0.1: Diart 会自动网络发言者的信息以自我改进(别担心,这是在本地完成的,不会与任何人共享)。在这里,我们只利用每位发言者 100ms 以上的语音进行自我改进。
  • delta_new=0.57:这是一个介于 0 和 2 之间的内部阈值,用于调节新发言人的检测。该值越小,系统对语音差异越敏感。
创建 ASR 模块

接下来,我们利用我为这篇文章创建的 WhisperTranscriber 类加载语音识别模型。
  1. # If you have a GPU, you can also set device="cuda"
  2. asr = WhisperTranscriber(model="small")
复制代码
该类的界说如下:
  1. import os
  2. import sys
  3. import numpy as np
  4. import whisper_timestamped as whisper
  5. from pyannote.core import Segment
  6. from contextlib import contextmanager
  7. @contextmanager
  8. def suppress_stdout():
  9.     # Auxiliary function to suppress Whisper logs (it is quite verbose)
  10.     # All credit goes to: https://thesmithfam.org/blog/2012/10/25/temporarily-suppress-console-output-in-python/
  11.     with open(os.devnull, "w") as devnull:
  12.         old_stdout = sys.stdout
  13.         sys.stdout = devnull
  14.         try:
  15.             yield
  16.         finally:
  17.             sys.stdout = old_stdout
  18. class WhisperTranscriber:
  19.     def __init__(self, model="small", device=None):
  20.         self.model = whisper.load_model(model, device=device)
  21.         self._buffer = ""
  22.     def transcribe(self, waveform):
  23.         """Transcribe audio using Whisper"""
  24.         # Pad/trim audio to fit 30 seconds as required by Whisper
  25.         audio = waveform.data.astype("float32").reshape(-1)
  26.         audio = whisper.pad_or_trim(audio)
  27.         # Transcribe the given audio while suppressing logs
  28.         with suppress_stdout():
  29.             transcription = whisper.transcribe(
  30.                 self.model,
  31.                 audio,
  32.                 # We use past transcriptions to condition the model
  33.                 initial_prompt=self._buffer,
  34.                 verbose=True  # to avoid progress bar
  35.             )
  36.         return transcription
  37.     def identify_speakers(self, transcription, diarization, time_shift):
  38.         """Iterate over transcription segments to assign speakers"""
  39.         speaker_captions = []
  40.         for segment in transcription["segments"]:
  41.             # Crop diarization to the segment timestamps
  42.             start = time_shift + segment["words"][0]["start"]
  43.             end = time_shift + segment["words"][-1]["end"]
  44.             dia = diarization.crop(Segment(start, end))
  45.             # Assign a speaker to the segment based on diarization
  46.             speakers = dia.labels()
  47.             num_speakers = len(speakers)
  48.             if num_speakers == 0:
  49.                 # No speakers were detected
  50.                 caption = (-1, segment["text"])
  51.             elif num_speakers == 1:
  52.                 # Only one speaker is active in this segment
  53.                 spk_id = int(speakers[0].split("speaker")[1])
  54.                 caption = (spk_id, segment["text"])
  55.             else:
  56.                 # Multiple speakers, select the one that speaks the most
  57.                 max_speaker = int(np.argmax([
  58.                     dia.label_duration(spk) for spk in speakers
  59.                 ]))
  60.                 caption = (max_speaker, segment["text"])
  61.             speaker_captions.append(caption)
  62.         return speaker_captions
  63.     def __call__(self, diarization, waveform):
  64.         # Step 1: Transcribe
  65.         transcription = self.transcribe(waveform)
  66.         # Update transcription buffer
  67.         self._buffer += transcription["text"]
  68.         # The audio may not be the beginning of the conversation
  69.         time_shift = waveform.sliding_window.start
  70.         # Step 2: Assign speakers
  71.         speaker_transcriptions = self.identify_speakers(transcription, diarization, time_shift)
  72.         return speaker_transcriptions
复制代码
转录器执行一个简单的操作,接收音频块及其日记,并按照以下步骤操作:

  • 用 Whisper 转录音频片段(带单词时间戳)
  • 通过调整单词和说话人之间的时间戳,为转录的每个片段指定说话人
将两个模块放在一起

既然我们已经创建了日记化和转录模块,那么我们就可以界说对每个音频块应用的操作链:
  1. import traceback
  2. import rich
  3. import rx.operators as ops
  4. import diart.operators as dops
  5. # Split the stream into 2s chunks for transcription
  6. transcription_duration = 2
  7. # Apply models in batches for better efficiency
  8. batch_size = int(transcription_duration // config.step)
  9. # Chain of operations to apply on the stream of microphone audio
  10. source.stream.pipe(
  11.     # Format audio stream to sliding windows of 5s with a step of 500ms
  12.     dops.rearrange_audio_stream(
  13.         config.duration, config.step, config.sample_rate
  14.     ),
  15.     # Wait until a batch is full
  16.     # The output is a list of audio chunks
  17.     ops.buffer_with_count(count=batch_size),
  18.     # Obtain diarization prediction
  19.     # The output is a list of pairs `(diarization, audio chunk)`
  20.     ops.map(dia),
  21.     # Concatenate 500ms predictions/chunks to form a single 2s chunk
  22.     ops.map(concat),
  23.     # Ignore this chunk if it does not contain speech
  24.     ops.filter(lambda ann_wav: ann_wav[0].get_timeline().duration() > 0),
  25.     # Obtain speaker-aware transcriptions
  26.     # The output is a list of pairs `(speaker: int, caption: str)`
  27.     ops.starmap(asr),
  28.     # Color transcriptions according to the speaker
  29.     # The output is plain text with color references for rich
  30.     ops.map(colorize_transcription),
  31. ).subscribe(
  32.     on_next=rich.print,  # print colored text
  33.     on_error=lambda _: traceback.print_exc()  # print stacktrace if error
  34. )
复制代码
在上述代码中,来自麦克风的所有音频块都将通过我们界说的操作链推送。
在这一系列操作中,我们起首利用 rearrange_audio_stream 将音频格式化为 5 秒钟的小块,小块之间的间隔为 500 毫秒。然后,我们利用 buffer_with_count 填充下一个批次,并应用日记化。请注意,批量大小的界说与转录窗口的大小相匹配。
接下来,我们将批次中不重叠的 500ms 日记化预测毗连起来,并应用我们的 WhisperTranscriber,只有在音频包含语音的情况下才能得到说话者感知转录。如果没有检测到语音,我们就跳过这一大块,等候下一块。
末了,我们将利用 rich 库为文本着色并打印到尺度输出中。
由于整个操作链可能有点晦涩难明,我还准备了一个操作示意图,盼望能让各人对算法有一个清晰的认识:

你可能已经注意到,我还没有界说 concat 和 colorize_transcriptions,但它们是非常简单的实用函数:
  1. import numpy as np
  2. from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow
  3. def concat(chunks, collar=0.05):
  4.     """
  5.     Concatenate predictions and audio
  6.     given a list of `(diarization, waveform)` pairs
  7.     and merge contiguous single-speaker regions
  8.     with pauses shorter than `collar` seconds.
  9.     """
  10.     first_annotation = chunks[0][0]
  11.     first_waveform = chunks[0][1]
  12.     annotation = Annotation(uri=first_annotation.uri)
  13.     data = []
  14.     for ann, wav in chunks:
  15.         annotation.update(ann)
  16.         data.append(wav.data)
  17.     annotation = annotation.support(collar)
  18.     window = SlidingWindow(
  19.         first_waveform.sliding_window.duration,
  20.         first_waveform.sliding_window.step,
  21.         first_waveform.sliding_window.start,
  22.     )
  23.     data = np.concatenate(data, axis=0)
  24.     return annotation, SlidingWindowFeature(data, window)
  25. def colorize_transcription(transcription):
  26.     """
  27.     Unify a speaker-aware transcription represented as
  28.     a list of `(speaker: int, text: str)` pairs
  29.     into a single text colored by speakers.
  30.     """
  31.     colors = 2 * [
  32.         "bright_red", "bright_blue", "bright_green", "orange3", "deep_pink1",
  33.         "yellow2", "magenta", "cyan", "bright_magenta", "dodger_blue2"
  34.     ]
  35.     result = []
  36.     for speaker, text in transcription:
  37.         if speaker == -1:
  38.             # No speakerfound for this text, use default terminal color
  39.             result.append(text)
  40.         else:
  41.             result.append(f"[{colors[speaker]}]{text}")
  42.     return "\n".join(result)
复制代码
如果您对 pyannote.audio 中利用的 Annotation 和 SlidingWindowFeature 类不熟悉,我发起您检察一下它们的官方文档页面。
在这里,我们利用 SlidingWindowFeature 作为音频块的 numpy 数组封装器,这些音频块还带有 SlidingWindow 实例提供的时间戳。
我们还利用 Annotation 作为首选数据布局来表示日记化预测。它们可被视为包含说话者 ID 以及开始和竣事时间戳的片段有序列表。
结论

在这篇文章中,我们将 diart 流媒体扬声器日记库与 OpenAI 的 Whisper 结合起来,以得到实时的扬声器彩色转录。
为了方便起见,作者在 GitHub gist 中提供了完备的脚本。


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

锦通

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表