Open_Duck_Mini_Interact/ark_api_module.py
2025-09-29 09:19:40 +08:00

101 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from openai import OpenAI
import time
import sys
import queue # 新增:用于缓存实时文本片段
import threading # 新增:用于并行处理语音播放
# 原代码7. 火山方舟API调用完整逻辑
class ArkAPIController:
    """Call the Ark chat-completions endpoint and speak the streamed reply.

    The reply is streamed, echoed to stdout piece by piece, cut into
    fragments at punctuation marks and handed to a background thread that
    plays each fragment through the injected TTS controller.
    """

    # Characters that close a speakable fragment: ASCII marks plus their
    # full-width CJK counterparts.  An empty string must never be used as a
    # marker: `"" in text` is always true, which would flush every delta.
    SPEECH_BREAK_CHARS = frozenset(".。!!??,,;;")

    def __init__(self, ark_api_key, ark_model_id, tts_controller, feedback_text):
        """Store the configuration and start the speech worker thread.

        Args:
            ark_api_key: API key passed to the OpenAI-compatible client.
            ark_model_id: Model identifier sent with every request.
            tts_controller: Object exposing ``speak(text)``.
            feedback_text: Mapping indexed by content type (and by
                ``"api_error"``) whose values are spoken as feedback.
        """
        self.ARK_API_KEY = ark_api_key
        self.ARK_MODEL_ID = ark_model_id
        self.tts_controller = tts_controller
        self.FEEDBACK_TEXT = feedback_text
        # Chat history is kept inside this object; only the most recent
        # MAX_CONTEXT_LEN * 2 entries are sent with each chat request.
        self.chat_context = []
        self.MAX_CONTEXT_LEN = 10
        # Fragments waiting to be spoken, consumed by a daemon thread so
        # playback runs in parallel with receiving the stream.
        self.speech_queue = queue.Queue()
        self.speech_thread = threading.Thread(
            target=self._process_speech_queue, daemon=True
        )
        self.speech_thread.start()

    @classmethod
    def _has_speech_break(cls, text):
        """Return True when *text* contains any fragment-ending punctuation."""
        return not cls.SPEECH_BREAK_CHARS.isdisjoint(text)

    def _process_speech_queue(self):
        """Speak queued fragments until a ``None`` sentinel is received.

        Every dequeued item is acknowledged with ``task_done()`` even when
        playback fails, so ``speech_queue.join()`` cannot block forever, and
        a TTS failure does not terminate the worker.
        """
        while True:
            text = self.speech_queue.get()  # blocks until an item arrives
            try:
                if text is None:  # exit signal
                    break
                self.tts_controller.speak(text)
            except Exception as exc:
                # Keep the worker alive; report the failure and move on.
                print(f"\n❌ 语音播放失败: {exc}")
            finally:
                self.speech_queue.task_done()

    def call_ark_api(self, content_type: str, content: dict):
        """Send one request, stream the reply to stdout and TTS, return it.

        Args:
            content_type: ``"chat"`` or ``"image_recog"``.  It also indexes
                ``FEEDBACK_TEXT`` for the spoken acknowledgement.  Any other
                value results in a request with an empty message list.
            content: Must contain ``"prompt"``; for ``"image_recog"`` it must
                also contain ``"image_base64"``.

        Returns:
            The full reply text on success, or an error message string when
            the request or the streaming fails.
        """
        # Spoken acknowledgement, played synchronously before the request.
        self.tts_controller.speak(self.FEEDBACK_TEXT[content_type])
        client = OpenAI(
            base_url="https://ark.cn-beijing.volces.com/api/v3",
            api_key=self.ARK_API_KEY
        )
        try:
            messages = []
            if content_type == "chat":
                messages.extend(self.chat_context[-self.MAX_CONTEXT_LEN * 2:])
                messages.append({
                    "role": "user",
                    "content": [{"type": "text", "text": content["prompt"]}],
                })
            elif content_type == "image_recog":
                messages.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{content['image_base64']}"
                            },
                        },
                        {"type": "text", "text": content["prompt"]},
                    ],
                })

            response = client.chat.completions.create(
                model=self.ARK_MODEL_ID,
                messages=messages,
                max_tokens=300,
                temperature=0.7 if content_type == "chat" else 0.3,
                stream=True
            )

            response_parts = []
            pending_speech = ""  # text received since the last flush to TTS
            print("\n" + "=" * 50)
            print("🤖 回应:", end="", flush=True)
            for chunk in response:
                if not (chunk.choices and chunk.choices[0].delta.content):
                    continue
                piece = chunk.choices[0].delta.content
                response_parts.append(piece)
                pending_speech += piece
                print(piece, end="", flush=True)
                # Short pause after each printed piece (presumably a
                # typewriter effect for the console output).
                time.sleep(0.05)
                # Hand the fragment to the speech thread once it contains a
                # punctuation mark.
                if self._has_speech_break(pending_speech):
                    self.speech_queue.put(pending_speech)
                    pending_speech = ""
            # Flush whatever is left after the stream ends.
            if pending_speech:
                self.speech_queue.put(pending_speech)
            print("\n" + "=" * 50 + "\n")

            # Block until every queued fragment has been processed.
            self.speech_queue.join()

            full_response = "".join(response_parts)
            # Record the exchange for subsequent chat requests.
            if content_type == "chat" and full_response.strip():
                self.chat_context.append({
                    "role": "user",
                    "content": [{"type": "text", "text": content["prompt"]}],
                })
                self.chat_context.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": full_response}],
                })
            return full_response
        except Exception as e:
            error_msg = f"❌ API调用失败{e}"
            print("\n" + "=" * 50)
            print(error_msg)
            print("=" * 50 + "\n")
            self.tts_controller.speak(self.FEEDBACK_TEXT["api_error"])
            return error_msg