# Open_Duck_Mini_Interact/ark_api_module.py
# (code-viewer metadata: 101 lines, 4.5 KiB, Python; last changed 2025-09-29 09:19:40 +08:00)
from openai import OpenAI
import time
import sys
import queue # 新增:用于缓存实时文本片段
import threading # 新增:用于并行处理语音播放
# 原代码7. 火山方舟API调用完整逻辑
class ArkAPIController:
    """Stream replies from the Ark chat-completions endpoint and speak them.

    A reply is printed delta by delta as it arrives. Deltas are buffered into
    clause-sized fragments that a background thread hands to the TTS
    controller, so speech starts before the full reply has been received.

    Attributes:
        ARK_API_KEY: API key passed to the OpenAI-compatible client.
        ARK_MODEL_ID: Model identifier sent with every request.
        tts_controller: Object exposing ``speak(text)``.
        FEEDBACK_TEXT: Mapping from content type (and ``"api_error"``) to the
            prompt spoken for it.
        chat_context: Stored user/assistant messages of previous chat turns.
        MAX_CONTEXT_LEN: Number of past turns (user + assistant pairs) that
            are replayed to the model.
        speech_queue: Fragments waiting to be spoken by the worker thread.
        speech_thread: Daemon thread running ``_process_speech_queue``.
    """

    # Clause-ending punctuation in ASCII and full-width form. A buffered
    # fragment is sent to TTS as soon as it contains one of these marks.
    SPEECH_BREAK_CHARS = frozenset(".。!!??,,;;")
    # Pause after each printed delta, in seconds (typewriter effect).
    STREAM_PRINT_DELAY = 0.05

    def __init__(self, ark_api_key, ark_model_id, tts_controller, feedback_text):
        """Store the configuration and start the speech worker thread.

        Args:
            ark_api_key: API key for the Ark endpoint.
            ark_model_id: Model identifier to request.
            tts_controller: TTS object supplied by the caller; must provide
                ``speak(text)``.
            feedback_text: Mapping of spoken feedback prompts, indexed by
                content type and by ``"api_error"``.
        """
        self.ARK_API_KEY = ark_api_key
        self.ARK_MODEL_ID = ark_model_id
        self.tts_controller = tts_controller
        self.FEEDBACK_TEXT = feedback_text
        self.chat_context = []
        self.MAX_CONTEXT_LEN = 10
        self.speech_queue = queue.Queue()
        # Daemon thread: it must not keep the interpreter alive on exit.
        self.speech_thread = threading.Thread(
            target=self._process_speech_queue, daemon=True
        )
        self.speech_thread.start()

    def _process_speech_queue(self):
        """Speak queued fragments until a ``None`` sentinel is received.

        Every ``get()`` is matched by exactly one ``task_done()``, including
        for the sentinel and for fragments whose playback raised. Otherwise a
        single TTS failure would kill this thread and make every later
        ``speech_queue.join()`` block forever.
        """
        while True:
            text = self.speech_queue.get()  # blocks until an item arrives
            try:
                if text is None:  # shutdown signal
                    break
                self.tts_controller.speak(text)
            except Exception as exc:
                # Best effort: drop the fragment that failed, keep the worker alive.
                print(f"⚠️ 语音播放失败: {exc}")
            finally:
                self.speech_queue.task_done()

    @classmethod
    def _has_speech_break(cls, text):
        """Return True if *text* contains any clause-ending punctuation."""
        return not cls.SPEECH_BREAK_CHARS.isdisjoint(text)

    def _build_messages(self, content_type, content):
        """Build the request message list for *content_type*.

        ``"chat"`` replays the most recent stored turns followed by the new
        prompt; ``"image_recog"`` sends one message holding a base64 JPEG data
        URL plus the prompt. Any other content type yields an empty list.
        """
        messages = []
        if content_type == "chat":
            messages.extend(self.chat_context[-self.MAX_CONTEXT_LEN * 2:])
            messages.append({
                "role": "user",
                "content": [{"type": "text", "text": content["prompt"]}],
            })
        elif content_type == "image_recog":
            image_url = f"data:image/jpeg;base64,{content['image_base64']}"
            messages.append({
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {"type": "text", "text": content["prompt"]},
                ],
            })
        return messages

    def _stream_reply(self, response):
        """Print the streamed reply, queue speakable fragments, return the text.

        Args:
            response: Iterable of streaming chunks; each chunk exposes
                ``choices[0].delta.content``.

        Returns:
            The concatenation of all non-empty deltas.
        """
        pieces = []
        pending = ""  # text received since the last fragment was queued
        for chunk in response:
            if not (chunk.choices and chunk.choices[0].delta.content):
                continue  # chunk carries no text
            delta = chunk.choices[0].delta.content
            pieces.append(delta)
            pending += delta
            print(delta, end="", flush=True)
            time.sleep(self.STREAM_PRINT_DELAY)
            # Hand the buffer to TTS once it holds a clause boundary.
            if self._has_speech_break(pending):
                self.speech_queue.put(pending)
                pending = ""
        # Whatever is left after the stream ends has no trailing punctuation.
        if pending:
            self.speech_queue.put(pending)
        return "".join(pieces)

    def call_ark_api(self, content_type: str, content: dict):
        """Send one request to the Ark API and stream the reply as text and speech.

        The feedback prompt for *content_type* is spoken synchronously first.
        The reply is then streamed, printed, and spoken fragment by fragment;
        the call returns only after all queued speech has been played. For
        ``"chat"`` requests with a non-blank reply, the turn is stored in
        ``chat_context``.

        Args:
            content_type: ``"chat"`` or ``"image_recog"``. Any other value
                sends an empty message list.
            content: Request payload. ``"prompt"`` is read for both types;
                ``"image_base64"`` is also read for ``"image_recog"``.

        Returns:
            The full reply text, or an error message string if anything
            inside the request/stream phase raised.

        Raises:
            KeyError: Presumably, when ``FEEDBACK_TEXT`` is a dict without an
                entry for *content_type*; that lookup happens before the
                guarded section.
        """
        # Spoken acknowledgement, played before the request is sent.
        self.tts_controller.speak(self.FEEDBACK_TEXT[content_type])
        client = OpenAI(
            base_url="https://ark.cn-beijing.volces.com/api/v3",
            api_key=self.ARK_API_KEY
        )
        try:
            messages = self._build_messages(content_type, content)
            response = client.chat.completions.create(
                model=self.ARK_MODEL_ID,
                messages=messages,
                max_tokens=300,
                temperature=0.7 if content_type == "chat" else 0.3,
                stream=True
            )
            print("\n" + "="*50)
            print("🤖 回应:", end="", flush=True)
            full_response = self._stream_reply(response)
            print("\n" + "="*50 + "\n")
            # Block until the worker has spoken every queued fragment.
            self.speech_queue.join()
            if content_type == "chat" and full_response.strip():
                self.chat_context.append({"role": "user", "content": [{"type": "text", "text": content["prompt"]}]})
                self.chat_context.append({"role": "assistant", "content": [{"type": "text", "text": full_response}]})
                # Only the last MAX_CONTEXT_LEN turns are ever replayed, so
                # drop older entries instead of letting the list grow forever.
                del self.chat_context[:-self.MAX_CONTEXT_LEN * 2]
            return full_response
        except Exception as e:
            error_msg = f"❌ API调用失败{str(e)}"
            print(f"\n" + "="*50)
            print(error_msg)
            print("="*50 + "\n")
            self.tts_controller.speak(self.FEEDBACK_TEXT["api_error"])
            return error_msg