commit 84981115056086dd481af844a7531d645e22c4ba
Author: yumoqing
Date: Wed Jul 16 14:32:27 2025 +0800

    first commit

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..ddb495d
--- /dev/null
+++ b/README.md
@@ -0,0 +1,2 @@
+# rtcllm
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..a0c39d1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+aiortc
+websockets
+webrtcvad
+faster-whisper
+git+https://github.com/suno-ai/bark
+nvidia-cublas-cu11
+nvidia-cudnn-cu11
diff --git a/rtcllm/a2a.py b/rtcllm/a2a.py
new file mode 100644
index 0000000..0d2463a
--- /dev/null
+++ b/rtcllm/a2a.py
@@ -0,0 +1,80 @@
+import os
+from traceback import print_exc
+import io
+import asyncio
+from av import AudioFrame
+
+from aiortc import VideoStreamTrack, AudioStreamTrack
+
+# Bytes per sample depend on the format: 's16' is 2 bytes, 's32' is 4 bytes, etc.
+def bytes_to_audio_frame(bytes_data, format='s16', layout='mono', sample_rate=16000):
+    """
+    Build an av.AudioFrame from raw byte data.
+
+    Arguments:
+    - bytes_data: the bytes holding the audio content.
+    - format: audio sample format string, e.g. 's16' for 16-bit signed integers.
+    - layout: channel layout, e.g. 'mono' or 'stereo'.
+    - sample_rate: sampling rate in Hz.
+    """
+    # Create the AudioFrame from the given parameters
+    # (the sample count assumes 2 bytes per sample, i.e. the 's16' format)
+    frame = AudioFrame(format=format,
+                       layout=layout,
+                       samples=len(bytes_data) // 2)
+
+    # Copy the byte data into the AudioFrame and record its sample rate
+    frame.planes[0].update(bytes_data)
+    frame.sample_rate = sample_rate
+    return frame
+
+class LLMAudioStreamTrack(AudioStreamTrack):
+    def __init__(self, omni_infer):
+        super().__init__()
+        self.oi = omni_infer
+        self.audio_iters = []
+        self.cur_iters = None
+        self.tmp_files = []
+
+    async def recv(self):
+        try:
+            b = self.get_audio_bytes()
+            if b is None:
+                return await super().recv()
+            frame = bytes_to_audio_frame(b)
+            print('LLMAudioStreamTrack return frame ...')
+            return frame
+        except Exception as e:
+            print_exc()
+            print(f'{self.__class__.__name__} recv() exception happened')
+            return None
+
+    def set_cur_audio_iter(self):
+        if len(self.audio_iters) == 0:
+            return False
+        self.cur_iters = self.audio_iters[0]
+        self.audio_iters.remove(self.cur_iters)
+        return True
+
+    def get_audio_bytes(self):
+        if self.cur_iters is None:
+            if not self.set_cur_audio_iter():
+                return None
+        try:
+            b = next(self.cur_iters)
+            return b
+        except StopIteration:
+            self.cur_iters = None
+            if len(self.tmp_files) > 0:
+                tf = self.tmp_files[0]
+                self.tmp_files.remove(tf)
+                os.remove(tf)
+            return self.get_audio_bytes()
+
+    def _feed(self, audio_file):
+        if audio_file is None:
+            print(f'*****{self.__class__.__name__}._feed(),{audio_file=}')
+            return
+        # only track the temp file once we know it is a real path
+        self.tmp_files.append(audio_file)
+        abiters = self.oi.run_AT_batch_stream(audio_file)
+        self.audio_iters.append(abiters)
+        return abiters
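+
+if __name__ == '__main__':
+    # Minimal round-trip sketch of the helper above; assumes 16 kHz mono
+    # 's16' audio, where 320 zero bytes are 160 samples, i.e. 10 ms of silence.
+    silence = b'\x00' * 320
+    f = bytes_to_audio_frame(silence)
+    print(f'{f.samples=}, {f.sample_rate=}')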
diff --git a/rtcllm/aav.py b/rtcllm/aav.py
new file mode 100644
index 0000000..7470aff
--- /dev/null
+++ b/rtcllm/aav.py
@@ -0,0 +1,154 @@
+import random
+from traceback import print_exc
+from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
+from aiortc import MediaStreamTrack, VideoStreamTrack, AudioStreamTrack
+from aiortc.mediastreams import MediaStreamError
+
+class MyMediaPlayer(MediaPlayer):
+    def __init__(self, file, *args, **kwargs):
+        super().__init__(file, *args, **kwargs)
+        # MediaPlayer does not keep the path around; remember it so
+        # source_reload() below can reopen the same file.
+        self._file_path = file
+
+class MyTrackBase(MediaStreamTrack):
+    def __init__(self, source):
+        super().__init__()
+        self.dumb = None
+        self.source = source
+        self.set_source_track()
+        print(f'{self.kind=}, {self.__class__.__name__}, {dir(self)}')
+
+    def set_source_track(self):
+        if self.kind == 'audio':
+            self.track = self.source.audio
+        else:
+            self.track = self.source.video
+
+    def source_reload(self):
+        print(f'reload the source----{self.source._file_path}')
+        self.set_source(MyMediaPlayer(self.source._file_path))
+
+    def set_source(self, source):
+        self.source = source
+        self.set_source_track()
+
+    async def recv(self):
+        print(f'============{self.__class__.__name__} recv()')
+        if self.source is None:
+            return await self.dumb.recv()
+        try:
+            f = await self.track.recv()
+            return f
+        except MediaStreamError:
+            return await self.dumb.recv()
+        except Exception as e:
+            print_exc()
+            print(f'{e}')
+            return await self.dumb.recv()
+
+class MyAudioStreamTrack(MyTrackBase):
+    kind = 'audio'
+    def __init__(self, source):
+        super().__init__(source)
+        self.dumb = AudioStreamTrack()
+
+class MyVideoStreamTrack(MyTrackBase):
+    kind = 'video'
+    def __init__(self, source):
+        super().__init__(source)
+        # the fallback track must be a video track, not an audio track
+        self.dumb = VideoStreamTrack()
+
+"""
+class MyAudioStreamTrack(AudioStreamTrack):
+    def __init__(self, source=None):
+        super().__init__()
+        self.source = source
+        self.set_source_track()
+        print(f'{self.kind=}, {self.__class__.__name__}, {dir(self)}')
+
+    async def recv(self):
+        try:
+            return await self._recv()
+        except MediaStreamError:
+            print(f'{self.__class__.__name__} reached end of the media ...')
+            self.source_reload()
+            return await self._recv()
+        except Exception as e:
+            print_exc()
+            raise e
+
+    async def _recv(self):
+        if self.source is None:
+            print('self._recv() return None')
+            return None
+        f = await self.track.recv()
+        return f
+
+    def set_source_track(self):
+        if self.kind == 'audio':
+            self.track = self.source.audio
+        else:
+            self.track = self.source.video
+
+    def source_reload(self):
+        print(f'reload the source----{self.source._file_path}')
+        self.set_source(MyMediaPlayer(self.source._file_path))
+
+    def set_source(self, source):
+        self.source = source
+        self.set_source_track()
+
+class MyVideoStreamTrack(VideoStreamTrack):
+    def __init__(self, source=None):
+        super().__init__()
+        self.source = source
+        self.set_source_track()
+        print(f'{self.kind=}, {self.__class__.__name__}, {dir(self)}')
+
+    async def recv(self):
+        try:
+            return await self._recv()
+        except MediaStreamError:
+            print(f'{self.__class__.__name__} reached end of the media ...')
+            self.source_reload()
+            return self._recv()
+        except Exception as e:
+            print_exc()
+            raise e
+
+    async def _recv(self):
+        if self.source is None:
+            return None
+        f = await self.track.recv()
+        return f
+
+    def set_source_track(self):
+        if self.kind == 'audio':
+            self.track = self.source.audio
+        else:
+            self.track = self.source.video
+
+    def source_reload(self):
+        print(f'reload the source----{self.source._file_path}')
+        self.set_source(MyMediaPlayer(self.source._file_path))
+
+    def set_source(self, source):
+        self.source = source
+        self.set_source_track()
+
+
+class LoopingVideoTrack(VideoStreamTrack):
+    '''
+    A video track that loops a video file.
+    '''
+    def __init__(self, filename):
+        super().__init__()
+        self.player = MyMediaPlayer(filename)
+        print(dir(self.player))
+
+    async def recv(self):
+        frame = await self.player.video.recv()
+        if self.player.video.readyState != 'live':
+            self.player = MyMediaPlayer(self.player._file_path)
+            frame = await self.player.video.recv()
+        return frame
+"""
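+
+if __name__ == '__main__':
+    # Wiring sketch (assumes a local media file named sample.mp4 exists);
+    # shows how the looping tracks are meant to be attached to a connection.
+    from aiortc import RTCPeerConnection
+    player = MyMediaPlayer('sample.mp4')
+    pc = RTCPeerConnection()
+    pc.addTrack(MyAudioStreamTrack(player))
+    pc.addTrack(MyVideoStreamTrack(player))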
diff --git a/rtcllm/audio_mix.py b/rtcllm/audio_mix.py
new file mode 100644
index 0000000..e2360e2
--- /dev/null
+++ b/rtcllm/audio_mix.py
@@ -0,0 +1,75 @@
+import asyncio
+from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
+from av import AudioFrame
+from aiortc.contrib.media import MediaPlayer
+import numpy as np
+
+class MixedAudioTrack(MediaStreamTrack):
+    kind = "audio"
+
+    def __init__(self, tracks):
+        super().__init__()
+        self.tracks = tracks
+        self._timestamp = 0
+
+    def add_track(self, track):
+        if track in self.tracks:
+            return
+        self.tracks.append(track)
+
+    def del_track(self, track):
+        tracks = [ t for t in self.tracks if t != track ]
+        self.tracks = tracks
+
+    async def recv(self):
+        # Collect one frame from every source track
+        audio_data = []
+        sample_rate = 16000
+        for track in self.tracks:
+            frame = await track.recv()
+            if frame:
+                sample_rate = frame.sample_rate
+                audio_data.append(frame.to_ndarray())
+
+        # Bail out if there is no usable audio
+        if not audio_data:
+            return None
+
+        # Flatten every frame to a 1-D int16 numpy array
+        audio_arrays = [np.asarray(a, dtype=np.int16).reshape(-1) for a in audio_data]
+
+        # Make sure all arrays have the same length
+        min_length = min(len(arr) for arr in audio_arrays)
+        audio_arrays = [arr[:min_length] for arr in audio_arrays]
+
+        # Mix with an int32 accumulator and clip back into the int16 range;
+        # summing directly in int16 would wrap around on loud input
+        mixed = np.sum(audio_arrays, axis=0, dtype=np.int32)
+        mixed_audio = np.clip(mixed, -32768, 32767).astype(np.int16)
+
+        # Build the outgoing audio frame
+        new_frame = AudioFrame(format="s16", layout="stereo", samples=len(mixed_audio) // 2)
+        new_frame.planes[0].update(mixed_audio.tobytes())
+        new_frame.sample_rate = sample_rate
+        new_frame.pts = self._timestamp
+        self._timestamp += new_frame.samples
+        return new_frame
+
+if __name__ == '__main__':
+    # Example: create two audio tracks
+    track1 = MediaPlayer('audio1.wav').audio
+    track2 = MediaPlayer('audio2.wav').audio
+
+    # Create the mixed audio track
+    mixed_track = MixedAudioTrack([track1, track2])
+
+    # Create an RTCPeerConnection and add the mixed track
+    pc = RTCPeerConnection()
+    pc.addTrack(mixed_track)
+
+    # The code below only sets up the WebRTC offer;
+    # extend it as needed for a full connection.
+    async def create_offer():
+        offer = await pc.createOffer()
+        await pc.setLocalDescription(offer)
+        print("Local description set successfully")
+        # normally the SDP would be sent to the remote peer here
+
+    # Start the event loop
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(create_offer())
+    loop.run_forever()
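+
+# Why the int32 accumulator above matters (a small illustration):
+#
+#     >>> import numpy as np
+#     >>> a = np.array([30000, -30000], dtype=np.int16)
+#     >>> a + a                      # int16 addition wraps around on overflow
+#     array([-5536,  5536], dtype=int16)
+#     >>> np.clip(a.astype(np.int32) + a, -32768, 32767).astype(np.int16)
+#     array([ 32767, -32768], dtype=int16)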
diff --git a/rtcllm/examples.py b/rtcllm/examples.py
new file mode 100644
index 0000000..282674b
--- /dev/null
+++ b/rtcllm/examples.py
@@ -0,0 +1,214 @@
+import argparse
+import asyncio
+import json
+import logging
+import os
+import ssl
+import uuid
+
+import cv2
+from aiohttp import web
+from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
+from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
+from av import VideoFrame
+
+ROOT = os.path.dirname(__file__)
+
+logger = logging.getLogger("pc")
+pcs = set()
+relay = MediaRelay()
+
+
+class VideoTransformTrack(MediaStreamTrack):
+    """
+    A video stream track that transforms frames from another track.
+    """
+
+    kind = "video"
+
+    def __init__(self, track, transform):
+        super().__init__()  # don't forget this!
+        self.track = track
+        self.transform = transform
+
+    async def recv(self):
+        frame = await self.track.recv()
+
+        if self.transform == "cartoon":
+            img = frame.to_ndarray(format="bgr24")
+
+            # prepare color
+            img_color = cv2.pyrDown(cv2.pyrDown(img))
+            for _ in range(6):
+                img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
+            img_color = cv2.pyrUp(cv2.pyrUp(img_color))
+
+            # prepare edges
+            img_edges = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+            img_edges = cv2.adaptiveThreshold(
+                cv2.medianBlur(img_edges, 7),
+                255,
+                cv2.ADAPTIVE_THRESH_MEAN_C,
+                cv2.THRESH_BINARY,
+                9,
+                2,
+            )
+            img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2RGB)
+
+            # combine color and edges
+            img = cv2.bitwise_and(img_color, img_edges)
+
+            # rebuild a VideoFrame, preserving timing information
+            new_frame = VideoFrame.from_ndarray(img, format="bgr24")
+            new_frame.pts = frame.pts
+            new_frame.time_base = frame.time_base
+            return new_frame
+        elif self.transform == "edges":
+            # perform edge detection
+            img = frame.to_ndarray(format="bgr24")
+            img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
+
+            # rebuild a VideoFrame, preserving timing information
+            new_frame = VideoFrame.from_ndarray(img, format="bgr24")
+            new_frame.pts = frame.pts
+            new_frame.time_base = frame.time_base
+            return new_frame
+        elif self.transform == "rotate":
+            # rotate image
+            img = frame.to_ndarray(format="bgr24")
+            rows, cols, _ = img.shape
+            M = cv2.getRotationMatrix2D((cols / 2, rows / 2), frame.time * 45, 1)
+            img = cv2.warpAffine(img, M, (cols, rows))
+
+            # rebuild a VideoFrame, preserving timing information
+            new_frame = VideoFrame.from_ndarray(img, format="bgr24")
+            new_frame.pts = frame.pts
+            new_frame.time_base = frame.time_base
+            return new_frame
+        else:
+            return frame
+
+
+async def index(request):
+    content = open(os.path.join(ROOT, "index.html"), "r").read()
+    return web.Response(content_type="text/html", text=content)
+
+
+async def javascript(request):
+    content = open(os.path.join(ROOT, "client.js"), "r").read()
+    return web.Response(content_type="application/javascript", text=content)
+
+
+async def offer(request):
+    params = await request.json()
+    offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
+
+    pc = RTCPeerConnection()
+    pc_id = "PeerConnection(%s)" % uuid.uuid4()
+    pcs.add(pc)
+
+    def log_info(msg, *args):
+        logger.info(pc_id + " " + msg, *args)
+
+    log_info("Created for %s", request.remote)
+
+    # prepare local media
+    player = MediaPlayer(os.path.join(ROOT, "demo-instruct.wav"))
+    if args.record_to:
+        recorder = MediaRecorder(args.record_to)
+    else:
+        recorder = MediaBlackhole()
+
+    @pc.on("datachannel")
+    def on_datachannel(channel):
+        @channel.on("message")
+        def on_message(message):
+            if isinstance(message, str) and message.startswith("ping"):
+                channel.send("pong" + message[4:])
+
+    @pc.on("connectionstatechange")
+    async def on_connectionstatechange():
+        log_info("Connection state is %s", pc.connectionState)
+        if pc.connectionState == "failed":
+            await pc.close()
+            pcs.discard(pc)
+
+    @pc.on("track")
+    def on_track(track):
+        log_info("Track %s received", track.kind)
+
+        if track.kind == "audio":
+            pc.addTrack(player.audio)
+            recorder.addTrack(track)
+        elif track.kind == "video":
+            pc.addTrack(
+                VideoTransformTrack(
+                    relay.subscribe(track), transform=params["video_transform"]
+                )
+            )
+            if args.record_to:
+                recorder.addTrack(relay.subscribe(track))
+
+        @track.on("ended")
+        async def on_ended():
+            log_info("Track %s ended", track.kind)
+            await recorder.stop()
+
+    # handle offer
+    await pc.setRemoteDescription(offer)
+    await recorder.start()
+
+    # send answer
+    answer = await pc.createAnswer()
+    await pc.setLocalDescription(answer)
+
+    return web.Response(
+        content_type="application/json",
+        text=json.dumps(
+            {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
+        ),
+    )
+
+
+async def on_shutdown(app):
+    # close peer connections
+    coros = [pc.close() for pc in pcs]
+    await asyncio.gather(*coros)
+    pcs.clear()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="WebRTC audio / video / data-channels demo"
+    )
+    parser.add_argument("--cert-file", help="SSL certificate file (for HTTPS)")
+    parser.add_argument("--key-file", help="SSL key file (for HTTPS)")
+    parser.add_argument(
+        "--host", default="0.0.0.0", help="Host for HTTP server (default: 0.0.0.0)"
+    )
+    parser.add_argument(
+        "--port", type=int, default=8080, help="Port for HTTP server (default: 8080)"
+    )
+    parser.add_argument("--record-to", help="Write received media to a file.")
+    parser.add_argument("--verbose", "-v", action="count")
+    args = parser.parse_args()
+
+    if args.verbose:
+        logging.basicConfig(level=logging.DEBUG)
+    else:
+        logging.basicConfig(level=logging.INFO)
+
+    if args.cert_file:
+        ssl_context = ssl.SSLContext()
+        ssl_context.load_cert_chain(args.cert_file, args.key_file)
+    else:
+        ssl_context = None
+
+    app = web.Application()
+    app.on_shutdown.append(on_shutdown)
+    app.router.add_get("/", index)
+    app.router.add_get("/client.js", javascript)
+    app.router.add_post("/offer", offer)
+    web.run_app(
+        app, access_log=None, host=args.host, port=args.port, ssl_context=ssl_context
+    )
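+
+# Client-side sketch for exercising /offer (assumes this server is running
+# on localhost:8080; aiohttp and aiortc provide everything used here):
+#
+#     import aiohttp, asyncio
+#     from aiortc import RTCPeerConnection, RTCSessionDescription
+#
+#     async def call():
+#         pc = RTCPeerConnection()
+#         pc.addTransceiver('video')
+#         await pc.setLocalDescription(await pc.createOffer())
+#         async with aiohttp.ClientSession() as s:
+#             async with s.post('http://localhost:8080/offer', json={
+#                 'sdp': pc.localDescription.sdp,
+#                 'type': pc.localDescription.type,
+#                 'video_transform': 'edges',
+#             }) as resp:
+#                 data = await resp.json()
+#         await pc.setRemoteDescription(
+#             RTCSessionDescription(sdp=data['sdp'], type=data['type']))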
diff --git a/rtcllm/rec.py b/rtcllm/rec.py
new file mode 100644
index 0000000..f7f6d52
--- /dev/null
+++ b/rtcllm/rec.py
@@ -0,0 +1,70 @@
+import sys
+import select
+import pyaudio
+from vad import MyVad, bytes2frame, to16000_160_frames
+
+def check_input():
+    if select.select([sys.stdin], [], [], 0.0)[0]:
+        input_str = sys.stdin.readline().strip()
+        return input_str
+    else:
+        return None
+
+def mic(seconds=None, frames_ms=20):
+    # PyAudio setup
+    p = pyaudio.PyAudio()
+
+    # Audio format
+    FORMAT = pyaudio.paInt16        # 16-bit resolution
+    CHANNELS = 1                    # 1 channel for mono
+    RATE = 16000                    # 16 kHz sampling rate
+    CHUNK = int(RATE / 100) * 2     # 320 samples = one 20 ms buffer
+    # PyAudio input stream
+    stream = p.open(format=FORMAT,
+                    channels=CHANNELS,
+                    rate=RATE,
+                    input=True,
+                    frames_per_buffer=CHUNK)
+
+    print("Recording and streaming audio...")
+    if seconds is not None:
+        cnt = int(seconds * 1000 / frames_ms)
+    else:
+        cnt = 0
+    i = 0
+    while True:
+        if cnt == 0:
+            if check_input() == 'q':
+                break
+        else:
+            if i >= cnt:
+                break
+        # Read audio frames from the microphone
+        data = stream.read(CHUNK)
+        # print(f'{data.__class__.__name__}, {len(data)=}, {CHUNK=}')
+        frames = to16000_160_frames(bytes2frame(data))
+        for frame in frames:
+            yield frame
+        i += 1
+
+    # Close the audio stream and release PyAudio
+    stream.close()
+    p.terminate()
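+
+def _check_vad_frame_sizes():
+    # Sanity-check sketch: webrtcvad only accepts 10/20/30 ms frames of
+    # 16-bit mono PCM at 8/16/32/48 kHz, so 160 samples at 16 kHz (10 ms)
+    # and 320 samples (20 ms) are both valid, while 100 samples is not.
+    import webrtcvad
+    assert webrtcvad.valid_rate_and_frame_length(16000, 160)
+    assert webrtcvad.valid_rate_and_frame_length(16000, 320)
+    assert not webrtcvad.valid_rate_and_frame_length(16000, 100)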
+
+if __name__ == '__main__':
+    import asyncio
+
+    def cb(f):
+        print(f'{f} voice wave file')
+
+    async def main():
+        # MyVad.vad_check() is a coroutine, so it has to be awaited
+        vad = MyVad(callback=cb)
+        i = 0
+        for f in mic():
+            if i == 0:
+                print(f'{f.sample_rate=}, {f.samples=}, {f.layout=} ')
+            i += 1
+            await vad.vad_check(f)
+        print('recording stopped')
+
+    asyncio.run(main())
diff --git a/rtcllm/rtc.old.py b/rtcllm/rtc.old.py
new file mode 100644
index 0000000..b96ccd3
--- /dev/null
+++ b/rtcllm/rtc.old.py
@@ -0,0 +1,246 @@
+import os
+import sys
+sys.path.append('./mini_omni')
+
+import asyncio
+import random
+import json
+
+from functools import partial
+
+from appPublic.dictObject import DictObject
+from appPublic.hf import hf_socks5proxy
+from appPublic.worker import awaitify
+from aiortc import MediaStreamTrack, VideoStreamTrack, AudioStreamTrack, RTCPeerConnection, RTCSessionDescription, RTCIceCandidate
+from aiortc.sdp import candidate_from_sdp, candidate_to_sdp
+from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
+
+# from websockets.asyncio.client import connect
+from websockets.client import connect
+import websockets
+from stt import asr
+from vad import AudioTrackVad
+from a2a import LLMAudioStreamTrack
+from aav import MyMediaPlayer, MyAudioStreamTrack, MyVideoStreamTrack
+
+from mini_omni.inference import OmniInference
+
+videos = ['./1.mp4', './2.mp4']
+
+async def pc_get_local_candidates(pc, peer):
+    its = pc.__dict__.get('_RTCPeerConnection__iceTransports')
+    coros = map(lambda t: t.iceGatherer.gather(), its)
+    await asyncio.gather(*coros)
+
+    if not peer.l_candidates:
+        peer.l_candidates = []
+        peer.sdp_id = 0
+
+    for t in its:
+        for c in t._connection.local_candidates:
+            if c not in peer.l_candidates:
+                # print(f'{c=}, {dir(c)}')
+                c.sdpMid = str(peer.sdp_id)
+                peer.sdp_id += 1
+                peer.l_candidates.append(c)
+                pc.emit('icecandidate', c)
+
+RTCPeerConnection.get_local_candidates = pc_get_local_candidates
+
+class RTCLLM:
+    def __init__(self, ws_url, iceServers):
+        self.ws_url = ws_url
+        self.omni_infer = OmniInference(ckpt_dir='/d/models/mini-omni')
+        self.iceServers = iceServers
+        self.peers = DictObject()
+        self.vid = 0
+        self.info = {
+            'id':'rtcllm_agent',
+            'name':'rtcllm_agent'
+        }
+        self.dc = None
+        self.loop = asyncio.get_event_loop()
+
+    async def ws_send(self, s):
+        await self.ws.send(s)
+
+    def get_pc(self, data):
+        return self.peers[data['from'].id].pc
+
+    async def run(self):
+        async with connect(self.ws_url) as self.ws:
+            await self.login()
+            while True:
+                msg = await self.ws.recv()
+                data = DictObject(**json.loads(msg))
+                d = data.data
+                print(f'ws recv(): {d.type=}')
+                func = self.handlers.get(d.type)
+                if func:
+                    f = partial(func, self)
+                    await f(d)
+            await self.ws.close()
+
+    async def login(self):
+        await self.ws_send(json.dumps({
+            'type':'login',
+            'info':self.info
+        }))
+
+    async def on_icecandidate(self, pc, to, candidate):
+        if candidate:
+            candi = {
+                'candidate':'candidate:' + candidate_to_sdp(candidate),
+                'sdpMid':candidate.sdpMid,
+                'type': candidate.type
+            }
+            # print('***********on_icecandidate()', candi)
+            await self.ws_send(json.dumps({
+                "type":"iceCandidate",
+                "to":to,
+                "candidate":candi
+            }))
+
+    async def save_onlineList(self, data):
+        # print(f'{self}, {type(self)}')
+        self.onlineList = data.onlineList
+
+    async def vad_voiceend(self, peer, audio):
+        if audio is not None:
+            feed = awaitify(peer.llmtrack._feed)
+            ret = await feed(audio)
+            print(f'self.feed("{audio}") return {ret}')
+            # os.remove(audio)
+
+    async def auto_accept_call(self, data):
+        opts = DictObject(iceServers=self.iceServers)
+        pc = RTCPeerConnection(opts)
+        player = MyMediaPlayer('./1.mp4')
+        llmtrack = LLMAudioStreamTrack(self.omni_infer)
+        self.peers[data['from'].id] = DictObject(**{
+            'info':data['from'],
+            'llmtrack':llmtrack,
+            'player':player,
+            'pc':pc
+        })
+        pc.addTrack(llmtrack)
+        # pc.addTrack(MyAudioStreamTrack(player))
+        pc.addTrack(MyVideoStreamTrack(player))
+        # pc.addTrack(LoopingVideoTrack('./1.mp4'))
+        # pc.addTrack(player.video)
+        await self.ws_send(json.dumps({'type':'callAccepted', 'to':data['from']}))
+        print('auto_accept_call() end')
+
+    async def pc_track(self, peerid, track):
+        peer = self.peers[peerid]
+        pc = peer.pc
+        if track.kind == 'audio':
+            f = partial(self.vad_voiceend, peer)
+            vadtrack = AudioTrackVad(track, stage=3, onvoiceend=f)
+            peer.vadtrack = vadtrack
+            vadtrack.start_vad()
+
+    async def pc_connectionState_changed(self, peerid):
+        peer = self.peers[peerid]
+        pc = peer.pc
+        print(f'conn_state={pc.connectionState} ...........')
+        if pc.connectionState == 'connected':
+            peer.dc = pc.createDataChannel(peer.info.name)
+            return
+        if pc.connectionState == 'closed':
+            await pc.close()
+            if peer.dc:
+                peer.dc.close()
+            await self.ws_send(json.dumps({
+                'type':'disconnect',
+                'to': peer.info
+            }))
+            peers = {
+                k:v for k,v in self.peers.items() if k != peerid
+            }
+            self.peers = peers
+
+    async def response_offer(self, data):
+        pc = self.get_pc(data)
+        peer = self.peers[data['from'].id]
+        if pc is None:
+            print(f'{self.peers=}, {data=}')
+            return
+        pc.on("connectionstatechange", partial(self.pc_connectionState_changed, data['from'].id))
+        pc.on('track', partial(self.pc_track, data['from'].id))
+        pc.on('icecandidate', partial(self.on_icecandidate, pc, data['from']))
+
+        offer = RTCSessionDescription(**data.offer)
+        await pc.setRemoteDescription(offer)
+        answer = await pc.createAnswer()
+        await pc.setLocalDescription(answer)
+        await self.ws_send(json.dumps({
+            'type':'answer',
+            'answer':{'type':pc.localDescription.type, 'sdp':pc.localDescription.sdp},
+            'to':data['from']
+        }))
+        cands = await pc.get_local_candidates(peer)
+
+        """
+        offer = await pc.createOffer()
+        await pc.setLocalDescription(offer)
+        await self.ws_send(json.dumps({
+            'type':'offer',
+            'offer': {'type':pc.localDescription.type, 'sdp':pc.localDescription.sdp},
+            'to':data['from']
+        }))
+        """
+
+    async def accept_answer(self, data):
+        pc = self.get_pc(data)
+        answer = RTCSessionDescription(**data.answer)
+        await pc.setRemoteDescription(answer)
+
+    async def accept_iceCandidate(self, data):
+        pc = self.get_pc(data)
+        candidate = data.candidate
+        # print('accepted candidate=', candidate)
+        """
+        rtc_candidate = RTCIceCandidate(
+            ip=ip,
+            port=port,
+            protocol=protocol,
+            priority=priority,
+            foundation=foundation,
+            component=component,
+            type=type,
+            sdpMid=candidate['sdpMid'],
+            sdpMLineIndex=candidate['sdpMLineIndex']
+        )
+        """
+        rtc_candidate = candidate_from_sdp(candidate['candidate'].split(":", 1)[1])
+        rtc_candidate.sdpMid = candidate['sdpMid']
+        rtc_candidate.sdpMLineIndex = candidate['sdpMLineIndex']
+        await pc.addIceCandidate(rtc_candidate)
+        # print('addIceCandidate ok')
+
+    handlers = {
+        'onlineList':save_onlineList,
+        'callRequest':auto_accept_call,
+        'offer':response_offer,
+        'answer':accept_answer,
+        'iceCandidate':accept_iceCandidate,
+    }
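+
+def _candidate_parsing_example():
+    # Sketch of the parsing used in accept_iceCandidate() above: browsers
+    # send 'candidate:<foundation> <component> ...' and candidate_from_sdp()
+    # expects only the part after the 'candidate:' prefix.  The sample line
+    # below is made up for illustration.
+    line = 'candidate:1 1 udp 2122260223 192.0.2.1 54400 typ host'
+    c = candidate_from_sdp(line.split(':', 1)[1])
+    c.sdpMid = '0'
+    c.sdpMLineIndex = 0
+    return c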
+
+async def main():
+    hf_socks5proxy()
+    agent = RTCLLM(ws_url='wss://sage.open-computing.cn/wss/ws/rtc_signaling.ws',
+        iceServers=[{
+            'urls':'stun:stun.open-computing.cn:13478'},{
+            'urls':'turn:stun.open-computing.cn:13479',
+            'username':'turn',
+            'credential':'server'
+        }])
+    print('running ...')
+    await agent.run()
+
+if __name__ == '__main__':
+    asyncio.run(main())
diff --git a/rtcllm/rtc.py b/rtcllm/rtc.py
new file mode 100644
index 0000000..b4ac9f5
--- /dev/null
+++ b/rtcllm/rtc.py
@@ -0,0 +1,426 @@
+import os
+import sys
+sys.path.append('./mini_omni')
+
+import asyncio
+import random
+import json
+
+from functools import partial
+
+from appPublic.dictObject import DictObject
+from appPublic.hf import hf_socks5proxy
+from appPublic.worker import awaitify
+from appPublic.log import debug, error, exception, critical
+from appPublic.uniqueID import getID
+from aiortc import MediaStreamTrack, VideoStreamTrack, AudioStreamTrack, RTCPeerConnection, RTCSessionDescription, RTCIceCandidate
+from aiortc import RTCConfiguration, RTCIceServer
+from aiortc.sdp import candidate_from_sdp, candidate_to_sdp
+from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
+
+# from websockets.asyncio.client import connect
+from websockets.client import connect
+import websockets
+from stt import asr
+from vad import AudioTrackVad
+from a2a import LLMAudioStreamTrack
+from aav import MyMediaPlayer, MyAudioStreamTrack, MyVideoStreamTrack
+
+from mini_omni.inference import OmniInference
+
+
+class RTCSignaling:
+    """
+    opts:
+    {
+        signaling_url:
+        info:
+        connect_opts:
+        onclose:
+        onlogin:
+    }
+    """
+    def __init__(self, opts):
+        self.signaling_url = opts.signaling_url
+        self.info = opts.info
+        self.connect_opts = opts.connect_opts
+        self.peers = DictObject()   # keyed by peer id
+        self.sessions = DictObject()
+        self.handlers = DictObject()
+        self.sessionhandlers = DictObject()
+        # the websocket itself is opened in run()
+        self.hb_task = None
+        self.socket = None
+        self.heartbeat_period = opts.heartbeat_period
+        self.onclose = opts.onclose
+        self.onlogin = opts.onlogin
+        self.onopen = opts.onopen
+        if not self.heartbeat_period:
+            self.heartbeat_period = 0
+
+    async def run(self):
+        async with connect(self.signaling_url) as self.socket:
+            await self.login()
+            while True:
+                msg = await self.socket.recv()
+                data = DictObject(**json.loads(msg))
+                d = data.data
+                debug(f'ws recv(): {d.type=}')
+                await self.signaling_recvdata(d)
+            await self.socket.close()
+
+    def reconnect(self):
+        debug('error happened')
+        if self.hb_task:
+            self.hb_task.cancel()
+            self.hb_task = None
+
+        if self.onclose:
+            self.onclose()
+
+        return
+
+    def del_handler(self, sessionid):
+        handlers = { k:v \
+            for k,v in self.handlers.items() if k!=sessionid}
+        self.handlers = handlers
+
+    async def signaling_recvdata(self, data):
+        debug(f'ws recv data={data}')
+        if data.session:
+            sessionid = data.session.sessionid
+            sessiontype = data.session.sessiontype
+            handler = self.handlers.get(sessionid)
+            if not handler:
+                k = self.sessionhandlers.get(sessiontype)
+                if not k:
+                    e = Exception('recvdata_handler() exception(' + sessiontype + ')')
+                    raise e
+
+                h = k(self, data.session, self.connect_opts)
+                # a Python method is already bound, no .bind() needed
+                handler = h.recvdata_handler
+                self.add_handler(sessionid, handler)
+
+            await handler(data)
+        else:
+            await self.recvdata_handler(data)
+
+    def add_handler(self, key, handler):
+        self.handlers[key] = handler
+
+    def add_sessionhandler(self, sessiontype, handler):
+        self.sessionhandlers[sessiontype] = handler
+
+    async def recvdata_handler(self, data):
+        debug(f'recv data={data}')
+        if data.type == 'online':
+            for p in data.online:
+                d = self.peers.get(p.id)
+                if not d:
+                    d = DictObject()
+                d.update(p)
+                self.peers[p.id] = d
+
+            if self.onlogin:
+                self.onlogin(data.online)
+
+            return
+
+        debug(f'recv data= {data} NOT HANDLED')
+
+    async def new_session(self, sessiontype, peer):
+        k = self.sessionhandlers.get(sessiontype)
+        if not k:
+            e = Exception('new_session() exception(' + sessiontype + ')')
+            raise e
+        sessionid = getID()
+        session = DictObject(
+            sessiontype=sessiontype,
+            sessionid=sessionid
+        )
+        d = DictObject(**{
+            "type":'new_session',
+            "session": session
+        })
+        await self.send_data(d)
+        opts = DictObject(**self.connect_opts)
+        opts.peer_info = peer
+        h = k(self, session, opts)
+        self.add_handler(sessionid, h.recvdata_handler)
+        return h
+
+    async def login(self):
+        debug(f'login send {self.heartbeat_period=}')
+        d = DictObject(
+            type='login'
+        )
+        await self.send_data(d)
+        if self.heartbeat_period > 0:
+            debug(f'call login again in {self.heartbeat_period} seconds')
+            loop = asyncio.get_event_loop()
+            self.hb_task = loop.call_later(
+                self.heartbeat_period,
+                lambda: asyncio.create_task(self.login()))
+
+    async def logout(self):
+        d = DictObject(
+            type='logout'
+        )
+        await self.send_data(d)
+
+    async def send_data(self, d):
+        d.msgfrom = self.info
+        s = json.dumps(d)
+        debug(f'send_data() {s=}')
+        await self.socket.send(s)
+
+    async def socket_send(self, s):
+        await self.socket.send(s)
+
+class RTCP2PConnect:
+    """
+    opts:{
+        ice_servers:
+        peer_info:
+        auto_callaccept: True or False
+        media_options: { video: True or False, audio: True or False }
+        data_connect: True or False
+    }
+    """
+    def __init__(self, signaling, session, opts):
+        self.id = getID()
+        self.signaling = signaling
+        self.session = session
+        self.requester = False
+        self.opts = opts
+        self.peers = DictObject()
+        self.signal_handlers = DictObject()
+        self.local_stream = None
+        self.localVideo = None
+        self.add_handler('sessioncreated', self.h_sessioncreated)
+        self.add_handler('callrequest', self.h_callrequest)
+        self.add_handler('callaccepted', self.h_callaccepted)
+        self.add_handler('offer', self.h_offer)
+        self.add_handler('answer', self.h_answer)
+        self.add_handler('icecandidate', self.h_icecandidate)
+        self.add_handler('sessionquit', self.h_sessionquit)
+
+    def add_handler(self, typ, f):
+        self.signal_handlers[typ] = f
+
+    def get_handler(self, typ):
+        return self.signal_handlers.get(typ)
+
+    async def p2pconnect(self, peer, role='requester'):
+        # no local media capture on the agent side;
+        # outgoing tracks are added in createPeerConnection()
+        p = self.peers.get(peer.id)
+        if not p:
+            await self.createPeerConnection(peer)
+        else:
+            debug(f'{peer=}, connect exists {self=}')
+
+        debug(f'p2pconnect() called {self=} {peer=}')
+
+    async def h_sessioncreated(self, data):
+        await self.p2pconnect(self.opts.peer_info, 'requester')
+        if self.opts.peer_info:
+            d = DictObject(**{
+                "type":'callrequest',
+                "msgto":self.opts.peer_info
+            })
+            await self.signaling_send(d)
+            self.requester = True
+
+    async def h_callrequest(self, data):
+        if self.opts.auto_callaccept or True:   # always accept for now
+            await self.p2pconnect(data.msgfrom, 'responser')
+            d = DictObject(**{
+                "type":'callaccepted',
+                "msgto":data.msgfrom
+            })
+            await self.signaling_send(d)
+            return
+
+    async def h_callaccepted(self, data):
+        await self.createDataChannel(data.msgfrom)
+        await self.send_offer(data.msgfrom, True)
+
+    async def h_offer(self, data):
+        debug(f'h_offer(), {self=}, peer={data.msgfrom}')
+        pc = self.peers[data.msgfrom.id].pc
+        offer = RTCSessionDescription(**data.offer)
+        await pc.setRemoteDescription(offer)
+        answer = await pc.createAnswer()
+        await pc.setLocalDescription(answer)
+        d = DictObject(**{
+            "type":'answer',
+            "answer":{'type':pc.localDescription.type, 'sdp':pc.localDescription.sdp},
+            "msgto":data.msgfrom
+        })
+        await self.signaling_send(d)
+        if self.peers[data.msgfrom.id].role != 'requester':
+            await self.send_offer(data.msgfrom)
+
+    async def h_answer(self, data):
+        desc = RTCSessionDescription(**data.answer)
+        pc = self.peers[data.msgfrom.id].pc
+        await pc.setRemoteDescription(desc)
+
+    async def h_icecandidate(self, data):
+        # browsers send the full 'candidate:...' string; aiortc wants
+        # only the SDP part, as in rtc.old.py
+        candidate = candidate_from_sdp(data.candidate['candidate'].split(':', 1)[1])
+        candidate.sdpMid = data.candidate.sdpMid
+        candidate.sdpMLineIndex = data.candidate.sdpMLineIndex
+        pc = self.peers[data.msgfrom.id].pc
+        await pc.addIceCandidate(candidate)
+
+    async def h_sessionquit(self, data):
+        self.peer_close(data.msgfrom)
+
+    async def send_offer(self, peer, initial=False):
+        debug(f'send_offer(), peers={self.peers}, {peer=}')
+        pc = self.peers[peer.id].pc
+        if initial:
+            self.peers[peer.id].role = 'requester'
+        else:
+            self.peers[peer.id].role = 'responser'
+
+        offer = await pc.createOffer()
+        await pc.setLocalDescription(offer)
+        d = DictObject(**{
+            "type":'offer',
+            "offer":{'type':pc.localDescription.type, 'sdp':pc.localDescription.sdp},
+            "msgto":peer
+        })
+        await self.signaling_send(d)
+
+    async def send_candidate(self, peer, candidate):
+        # aiortc has no per-candidate event; this mirrors the browser-side
+        # handler and is kept for a trickle-ICE capable transport
+        debug(f'send_candidate called, {peer=},{candidate=}')
+        if candidate:
+            await self.signaling_send(DictObject(**{
+                "type": 'icecandidate',
+                "msgto": peer,
+                "candidate": {
+                    'candidate': 'candidate:' + candidate_to_sdp(candidate),
+                    'sdpMid': candidate.sdpMid,
+                    'sdpMLineIndex': candidate.sdpMLineIndex
+                }
+            }))
+
+    async def signaling_send(self, d):
+        d.session = self.session
+        await self.signaling.send_data(d)
+
+    async def recvdata_handler(self, data):
+        debug(f'recvdata={data}, handlers={self.signal_handlers}')
+        f = self.get_handler(data.type)
+        if f:
+            await f(data)
+            return
+        debug(f'recvdata={data} NOT HANDLED')
+
+    async def ice_statechange(self, peer):
+        pc = self.peers[peer.id].pc
+        debug(f'oniceconnectionstatechange, pc.iceConnectionState is {pc.iceConnectionState}.')
+
+    async def connection_statechange(self, peer):
+        pc = self.peers[peer.id].pc
+        debug(f'{peer.id} state changed. new state={pc.connectionState}')
+        if pc.connectionState == 'disconnected':
+            self.peer_close(peer)
+            if self.opts.on_pc_disconnected:
+                self.opts.on_pc_disconnected(peer)
+            return
+        if pc.connectionState == 'connected':
+            debug(f'state is connected, data_connect={self.opts.data_connect}')
+            if self.opts.on_pc_connected:
+                self.opts.on_pc_connected(peer)
+
+    async def dc_accepted(self, peer, channel):
+        debug('accept datachannel ....')
+        self.peers[peer.id].dc = channel
+        await self.dc_created(peer, channel)
+
+    async def dc_created(self, peer, dc):
+        debug(f'dc_created..... {dc}')
+        dc.on('message', partial(self.datachannel_message, peer))
+        dc.on('open', partial(self.datachannel_open, peer))
+        dc.on('close', partial(self.datachannel_close, peer))
+
+    async def datachannel_message(self, peer, message):
+        debug(f'datachannel_message(): {message=}')
+        dc = self.peers[peer.id].dc
+        if self.opts.on_dc_message:
+            await self.opts.on_dc_message(dc, message)
+
+    async def datachannel_open(self, peer):
+        debug('datachannel_open()')
+        dc = self.peers[peer.id].dc
+        if self.opts.on_dc_open:
+            await self.opts.on_dc_open(dc)
+
+    async def datachannel_close(self, peer):
+        debug('datachannel_close()')
+        dc = self.peers[peer.id].dc
+        if self.opts.on_dc_close:
+            await self.opts.on_dc_close(dc)
+
+    async def createDataChannel(self, peer):
+        pc = self.peers[peer.id].pc
+        self.peers[peer.id].dc = pc.createDataChannel('chat', ordered=True)
+        dc = self.peers[peer.id].dc
+        await self.dc_created(peer, dc)
+        debug(f'dc created {dc}')
+
+    async def createPeerConnection(self, peer):
+        # aiortc expects an RTCConfiguration, not a plain dict
+        configuration = RTCConfiguration(
+            iceServers=[RTCIceServer(**s) for s in self.opts.ice_servers]
+        )
+        debug(f'RTCPC {configuration=}')
+        pc = RTCPeerConnection(configuration)
+        if self.local_stream:
+            for track in self.local_stream.getTracks():
+                pc.addTrack(track)
+        self.peers[peer.id] = peer
+        self.peers[peer.id].pc = pc
+        self.peers[peer.id].role = ''
+        # aiortc bundles candidates into the SDP (vanilla ICE), so no
+        # per-candidate 'icecandidate' event is registered here
+        pc.on('iceconnectionstatechange', partial(self.ice_statechange, peer))
+        pc.on('connectionstatechange', partial(self.connection_statechange, peer))
+        pc.on('datachannel', partial(self.dc_accepted, peer))
+        pc.on('track', partial(self.onpctrack, peer))
+
+    async def onpctrack(self, peer, track):
+        # default no-op; a subclass is expected to override this
+        debug(f'track {track.kind} received from {peer.id}')
+
+    def peer_close(self, peer):
+        pc = self.peers[peer.id].pc
+        for r in pc.getReceivers():
+            if r.track:
+                r.track.stop()
+
+        for s in pc.getSenders():
+            if s.track:
+                s.track.stop()
+
+        dc = self.peers[peer.id].dc
+        if dc:
+            dc.close()
+        # pc.close() is a coroutine in aiortc
+        asyncio.create_task(pc.close())
+        self.peers = {k:v for k,v in self.peers.items() if k!=peer.id}
+        if len(self.peers) == 0:
+            if self.local_stream:
+                for track in self.local_stream.getTracks():
+                    track.stop()
+            self.signaling.del_handler(self.session.sessionid)
diff --git a/rtcllm/stt.py b/rtcllm/stt.py
new file mode 100644
index 0000000..291b62f
--- /dev/null
+++ b/rtcllm/stt.py
@@ -0,0 +1,65 @@
+import aiohttp
+from appPublic.dictObject import DictObject
+from appPublic.oauth_client import OAuthClient
+
+desc = {
+    "path":"/api/generate.dspy",
+    "method":"POST",
+    "headers":[
+        {
+            "name":"Content-Type",
+            "value":"application/json"
+        }
+    ],
+    "data":[{
+        "name":"audio_file",
+        "value":"${audio_file}"
+    },{
+        "name":"model",
+        "value":"whisper"
+    }],
+    "resp":[
+        {
+            "name":"content",
+            "value":"content"
+        }
+    ]
+}
+opts = {
+    "data":{
+        "oops":0
+    },
+    "asr":desc
+}
+
+async def asr(a_file):
+    if not a_file:
+        return
+    """
+    r = None
+    with open(a_file, 'rb') as f:
+        oc = OAuthClient(DictObject(**opts))
+        r = await oc('https://sage.open-computing.cn/asr', 'asr', {'audio_file':f})
+        print(f'{r=}')
+        return r
+    """
+    with open(a_file, 'rb') as f:
+        async with aiohttp.ClientSession() as session:
+            async with session.post("https://sage.open-computing.cn/asr/api/generate.dspy", data={
+                "model":"whisper",
+                "audio_file":f
+            }) as response:
+                r = await response.json()
+                print(f'{r=}')
+                return r
+    """
+    segments, info = model.transcribe(a_file, beam_size=5)
+    txt = ''
+    for s in segments:
+        txt += s.text
+    return {
+        'content': txt,
+        'language': info.language
+    }
+    """
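+
+if __name__ == '__main__':
+    # Quick manual test sketch; assumes a local test.wav exists and the
+    # ASR endpoint above is reachable.
+    import asyncio
+    print(asyncio.run(asr('test.wav')))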
diff --git a/rtcllm/vad.py b/rtcllm/vad.py
new file mode 100644
index 0000000..02f1f43
--- /dev/null
+++ b/rtcllm/vad.py
@@ -0,0 +1,181 @@
+import base64
+from inspect import isfunction, iscoroutinefunction
+from traceback import print_exc
+import asyncio
+import collections
+import contextlib
+from appPublic.folderUtils import temp_file
+from appPublic.worker import awaitify
+from aiortc import MediaStreamTrack
+from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
+import webrtcvad
+import wave
+from scipy.io.wavfile import write
+import numpy as np
+import av
+from av import AudioLayout, AudioResampler, AudioFrame, AudioFormat
+
+def frames_write_wave(frames):
+    path = temp_file(suffix='.wav')
+    output_container = av.open(path, 'w')
+    out_stream = output_container.add_stream('pcm_s16le')
+    for frame in frames:
+        for packet in out_stream.encode(frame):
+            output_container.mux(packet)
+    for packet in out_stream.encode(None):
+        output_container.mux(packet)
+    output_container.close()
+    return path
+
+def bytes2frame(byts, channels=1, sample_rate=16000):
+    audio_data = np.frombuffer(byts, np.int16)
+    audio_data = audio_data.reshape((channels, -1))
+    layout = 'mono'
+    if channels == 2:
+        layout = 'stereo'
+    # Create an AV frame from the audio data
+    frame = av.AudioFrame.from_ndarray(audio_data, format='s16', layout=layout)
+    frame.sample_rate = sample_rate
+    return frame
+
+def frame2bytes(frame):
+    audio_array = frame.to_ndarray()
+    audio_bytes = audio_array.tobytes()
+    return audio_bytes
+
+def resample(frame, sample_rate=None):
+    if sample_rate is None:
+        sample_rate = frame.sample_rate
+    r = AudioResampler(format='s16', layout='mono', rate=sample_rate)
+    frame = r.resample(frame)
+    return frame
+
+_remind_byts = b''
+
+def to16000_160_frames(frame):
+    # module-level variant of AudioTrackVad.to16000_160_frames() so that
+    # rec.py can import it; keeps leftover bytes between calls to always
+    # emit 160-sample (10 ms) frames
+    global _remind_byts
+    frames = resample(frame, sample_rate=16000)
+    for f in frames:
+        if f.samples == 160:
+            return frames
+    ret_frames = []
+    for f in frames:
+        b1 = _remind_byts + frame2bytes(f)
+        while len(b1) >= 320:
+            b = b1[:320]
+            b1 = b1[320:]
+            ret_frames.append(bytes2frame(b))
+        _remind_byts = b1
+    return ret_frames
+
+class MyVad(webrtcvad.Vad):
+    def __init__(self, callback=None):
+        super().__init__(3)
+        self.voiced_frames = []
+        self.num_padding_frames = 40
+        self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
+        self.onvoiceend = callback
+        self.triggered = False
+        self.cnt = 0
+
+    def voice_duration(self):
+        duration = 0
+        for f in self.voiced_frames:
+            duration = f.samples * 1000 / f.sample_rate + duration
+        return duration
+
+    async def vad_check(self, inframe):
+        """
+        only supports frames with sample_rate == 16000 and samples == 160
+        """
+        frame = inframe
+        byts = frame2bytes(frame)
+        if self.cnt == 0:
+            f = frame
+            print(f'{f.sample_rate=}, {f.samples=},{f.layout=}, {len(byts)=}')
+        if not webrtcvad.valid_rate_and_frame_length(frame.sample_rate, frame.samples):
+            print(f'invalid rate/frame length: {frame.sample_rate=}, {frame.samples=}')
+        is_speech = self.is_speech(byts, frame.sample_rate, length=frame.samples)
+        if not self.triggered:
+            self.ring_buffer.append((inframe, is_speech))
+            num_voiced = len([f for f, speech in self.ring_buffer if speech])
+            # If we're NOTTRIGGERED and more than 90% of the frames in
+            # the ring buffer are voiced frames, then enter the
+            # TRIGGERED state.
+            if num_voiced > 0.9 * self.ring_buffer.maxlen:
+                self.triggered = True
+                # We want to yield all the audio we see from now until
+                # we are NOTTRIGGERED, but we have to start with the
+                # audio that's already in the ring buffer.
+                for f, s in self.ring_buffer:
+                    self.voiced_frames.append(f)
+                self.ring_buffer.clear()
+                # print('start voice .....', len(self.voiced_frames))
+        else:
+            # We're in the TRIGGERED state, so collect the audio data
+            # and add it to the ring buffer.
+            self.voiced_frames.append(inframe)
+            self.ring_buffer.append((frame, is_speech))
+            num_unvoiced = len([f for f, speech in self.ring_buffer if not speech])
+            # If more than 90% of the frames in the ring buffer are
+            # unvoiced, then enter NOTTRIGGERED and yield whatever
+            # audio we've collected.
+            if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
+                self.triggered = False
+                duration = self.voice_duration()
+                if duration > 500:
+                    ret = frames_write_wave(self.voiced_frames)
+                    if self.onvoiceend:
+                        if iscoroutinefunction(self.onvoiceend):
+                            await self.onvoiceend(ret)
+                        else:
+                            self.onvoiceend(ret)
+                else:
+                    print('-----short voice------')
+
+                self.ring_buffer.clear()
+                self.voiced_frames = []
+        self.cnt += 1
+
+class AudioTrackVad(MediaStreamTrack):
+    def __init__(self, track, stage=3, onvoiceend=None):
+        super().__init__()
+        self.track = track
+        self.vad = MyVad(callback=onvoiceend)
+        # self.sample_rate = self.track.getSettings().sampleRate
+        # frameSize = self.track.getSettings().frameSize
+        # self.frame_duration = frameSize / self.sample_rate
+        self.frame_duration = 0.02  # seconds; loop.call_later() wants seconds
+        self.remind_byts = b''
+        self.loop = asyncio.get_event_loop()
+        self.task = None
+        self.debug = True
+        self.running = False
+
+    def start_vad(self):
+        self.running = True
+        self.task = self.loop.call_later(self.frame_duration, self._recv)
+
+    def _recv(self):
+        asyncio.create_task(self.recv())
+
+    def stop(self):
+        self.running = False
+
+    def to16000_160_frames(self, frame):
+        frames = resample(frame, sample_rate=16000)
+        ret_frames = []
+        for f in frames:
+            if f.samples == 160:
+                return frames
+        for f in frames:
+            b1 = self.remind_byts + frame2bytes(f)
+            while len(b1) >= 320:
+                b = b1[:320]
+                b1 = b1[320:]
+                ret_frames.append(bytes2frame(b))
+            self.remind_byts = b1
+        return ret_frames
+
+    async def recv(self):
+        frame = await self.track.recv()
+        self.sample_rate = frame.sample_rate
+        duration = (frame.samples * 1000) / frame.sample_rate
+        # print(f'{self.__class__.__name__}.recv(): {duration=}, {frame.samples=}, {frame.sample_rate=}')
+        try:
+            frames = self.to16000_160_frames(frame)
+            for frame in frames:
+                await self.vad.vad_check(frame)
+        except Exception as e:
+            print(f'{e=}')
+            print_exc()
+            return
+        if self.task:
+            self.task.cancel()
+        if self.running:
+            self.task = self.loop.call_later(self.frame_duration, self._recv)
+        return frame
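+
+# Wiring sketch: wrap an incoming aiortc audio track so every voiced segment
+# is written to a temporary wav file (`track` is assumed to come from a
+# pc.on('track') callback, as in rtc.old.py):
+#
+#     async def on_voice(wav_path):
+#         print('voice segment saved to', wav_path)
+#
+#     vadtrack = AudioTrackVad(track, stage=3, onvoiceend=on_voice)
+#     vadtrack.start_vad()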
diff --git a/rtcllm/wss.py b/rtcllm/wss.py
new file mode 100644
index 0000000..964ee1d
--- /dev/null
+++ b/rtcllm/wss.py
@@ -0,0 +1,142 @@
+import asyncio
+import json
+
+from websockets.client import connect
+import websockets
+from appPublic.background import Background
+from appPublic.uniqueID import getID
+from appPublic.dictObject import DictObject
+from appPublic.log import exception
+
+class Wss:
+    """
+    new session
+    session_type: p2p, aba, meeting
+
+    signaling data format
+    create a new session
+    {
+        type:new_session,
+        session:{
+            sessiontype:
+            sessionid
+        }
+        from:
+        to:
+    }
+    (will broadcast to all peers online) login:
+    {
+        type:login,
+        from:
+    }
+    (will broadcast to all peers online) logout:
+    {
+        type:logout
+        from:
+    }
+    """
+    def __init__(self, app, wss_url, info=None, opts={}):
+        self.wss_url = wss_url
+        if info is None:
+            info = {
+                'id':getID()
+            }
+        self.opts = opts
+        self.info = info
+        self.ws = None
+        self.handler_name = 'sessionid'
+        self.app = app
+        self.running = False
+        self.app.wss = self
+        self.handlers = {}
+        self.sessionhandlers = {}
+        self.peers = []
+
+    def add_handler(self, key, handler):
+        self.handlers[key] = handler
+
+    def start(self):
+        loop = asyncio.get_running_loop()
+        f = asyncio.run_coroutine_threadsafe(self._start(), loop)
+        return f.result()
+
+    def add_sessionhandler(self, sessiontype, handler):
+        self.sessionhandlers[sessiontype] = handler
+
+    async def _start(self):
+        self.ws = await connect(self.wss_url)
+        self.running = True
+        while self.running:
+            msg = await self.ws.recv()
+            data = DictObject(**json.loads(msg))
+            name = data.get(self.handler_name)
+            handler = self.handlers.get(name)
+            if handler is None:
+                handler = self.recvdata_handler
+            await handler(data)
+
+    def add_peer(self, info):
+        self.peers = [ i for i in self.peers if i.id != info.id ]
+        self.peers.append(info)
+
+    def delete_peer(self, info):
+        self.peers = [ i for i in self.peers if i.id != info.id ]
+
+    async def recvdata_handler(self, data):
+        if data.type == 'login':
+            self.add_peer(data.msgfrom)
+            return
+        if data.type == 'logout':
+            self.delete_peer(data.msgfrom)
+            return
+        if data.type == 'new_session':
+            h = self.sessionhandlers.get(data.session.sessiontype)
+            i = h(data.msgfrom)
+            k = data.session.sessionid
+            self.add_handler(k, i.recvdata_handler)
+            return
+
+    async def stop(self):
+        self.running = False
+        await self.ws.close()
+
+    async def restart(self):
+        try:
+            await self.stop()
+        except:
+            pass
+        # reconnect and resume the receive loop
+        await self._start()
+
+    async def send(self, dic):
+        dic.msgfrom = self.info
+        txt = json.dumps(dic)
+        try:
+            return await self.ws.send(txt)
+        except:
+            await self.restart()
+            return await self.send(dic)
+
+    async def new_session(self, sessiontype, opts):
+        k = self.sessionhandlers.get(sessiontype)
+        if k is None:
+            e = Exception(f'Sessiontype({sessiontype}) not registered')
+            exception(f'new_session() exception {e}')
+            raise e
+        d = DictObject()
+        d.type = 'new_session'
+        d.msgfrom = self.info
+        sessionid = getID()
+        d.session = DictObject(sessiontype=sessiontype, sessionid=sessionid)
+        await self.send(d)
+        h = k(self, sessionid, opts)
+        self.add_handler(sessionid, h.recvdata_handler)
+        return h
+
+    async def login(self):
+        d = DictObject()
+        d.msgfrom = self.info
+        d.type = 'login'
+        await self.send(d)
+
+    async def logout(self):
+        d = DictObject()
+        d.msgfrom = self.info
+        d.type = 'logout'
+        await self.send(d)
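+
+if __name__ == '__main__':
+    # Usage sketch; the URL and the bare App container are placeholders,
+    # not part of the real deployment.
+    class App:
+        pass
+
+    async def demo():
+        wss = Wss(App(), 'wss://example.com/ws/rtc_signaling.ws')
+        recv_loop = asyncio.create_task(wss._start())  # connect + receive loop
+        await asyncio.sleep(1)                         # crude wait for the socket
+        await wss.login()
+        await recv_loop
+
+    asyncio.run(demo())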