101 lines
4.5 KiB
Python
101 lines
4.5 KiB
Python
from openai import OpenAI
|
||
import time
|
||
import sys
|
||
import queue # 新增:用于缓存实时文本片段
|
||
import threading # 新增:用于并行处理语音播放
|
||
# 原代码7. 火山方舟API调用完整逻辑
|
||
class ArkAPIController:
    """Stream replies from the Volcano Ark chat API and speak them as they arrive.

    The controller keeps a rolling chat history, prints streamed text to
    stdout, and hands punctuation-delimited segments to a background worker
    thread that calls ``tts_controller.speak`` so playback overlaps with the
    network stream.
    """

    # Characters that close a speakable segment: ASCII forms plus the
    # ideographic full stop and the full-width ! ? , ; used in Chinese text.
    # Escapes are used so the full-width forms cannot be silently folded to
    # ASCII by editors or text normalization.
    _SPEECH_BREAK_CHARS = frozenset(".!?,;\u3002\uff01\uff1f\uff0c\uff1b")

    def __init__(self, ark_api_key, ark_model_id, tts_controller, feedback_text):
        """Store credentials and collaborators, then start the speech worker.

        Args:
            ark_api_key: API key passed to the OpenAI-compatible client.
            ark_model_id: Model identifier sent with every completion request.
            tts_controller: Object exposing ``speak(text)``.
            feedback_text: Mapping from a content type (and ``"api_error"``)
                to the prompt spoken for it.
        """
        self.ARK_API_KEY = ark_api_key
        self.ARK_MODEL_ID = ark_model_id
        self.tts_controller = tts_controller
        self.FEEDBACK_TEXT = feedback_text
        self.chat_context = []  # alternating user/assistant messages
        self.MAX_CONTEXT_LEN = 10  # conversation rounds replayed to the model

        # Queue of text segments waiting to be spoken, drained by a daemon
        # thread so it never keeps the interpreter alive on exit.
        self.speech_queue = queue.Queue()
        self.speech_thread = threading.Thread(
            target=self._process_speech_queue, daemon=True
        )
        self.speech_thread.start()

    @classmethod
    def _has_speech_break(cls, text):
        """Return True when *text* contains a segment-ending punctuation mark."""
        return any(ch in cls._SPEECH_BREAK_CHARS for ch in text)

    def _process_speech_queue(self):
        """Speak queued segments until a ``None`` exit signal is received.

        Every dequeued item is acknowledged with ``task_done()``, including
        the exit signal and segments whose playback raised, so that
        ``speech_queue.join()`` can never block forever on a lost item.
        """
        while True:
            text = self.speech_queue.get()  # blocks until an item arrives
            try:
                if text is None:  # exit signal
                    break
                self.tts_controller.speak(text)
            except Exception as exc:
                # One failed segment must not kill the worker; report it and
                # keep serving the remaining segments.
                print(f"❌ 语音播放失败:{exc}")
            finally:
                self.speech_queue.task_done()

    def close(self, timeout=None):
        """Ask the speech worker to exit and wait up to *timeout* seconds for it."""
        if self.speech_thread.is_alive():
            self.speech_queue.put(None)
            self.speech_thread.join(timeout)

    def call_ark_api(self, content_type: str, content: dict) -> str:
        """Send one request, stream the reply to stdout and TTS, and return it.

        Args:
            content_type: ``"chat"`` or ``"image_recog"``.
            content: Request payload. ``content["prompt"]`` is always read;
                ``content["image_base64"]`` is read for ``"image_recog"``.

        Returns:
            The full reply text, or an error message string when the request
            or the streaming fails.
        """
        # Spoken acknowledgement of the requested operation (synchronous).
        self.tts_controller.speak(self.FEEDBACK_TEXT[content_type])

        client = OpenAI(
            base_url="https://ark.cn-beijing.volces.com/api/v3",
            api_key=self.ARK_API_KEY,
        )

        try:
            messages = []
            if content_type == "chat":
                messages.extend(self.chat_context[-self.MAX_CONTEXT_LEN * 2:])
                messages.append(
                    {"role": "user", "content": [{"type": "text", "text": content["prompt"]}]}
                )
            elif content_type == "image_recog":
                image_url = f"data:image/jpeg;base64,{content['image_base64']}"
                messages.append({
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": image_url}},
                        {"type": "text", "text": content["prompt"]},
                    ],
                })
            else:
                # Fail before the network round trip instead of sending an
                # empty message list.
                raise ValueError(f"unsupported content_type: {content_type!r}")

            response = client.chat.completions.create(
                model=self.ARK_MODEL_ID,
                messages=messages,
                max_tokens=300,
                temperature=0.7 if content_type == "chat" else 0.3,
                stream=True,
            )

            reply_parts = []
            pending_speech = ""  # text received since the last queued segment
            print("\n" + "=" * 50)
            print("🤖 回应:", end="", flush=True)

            for chunk in response:
                if not chunk.choices:
                    continue
                piece = chunk.choices[0].delta.content
                if not piece:
                    continue

                reply_parts.append(piece)
                pending_speech += piece
                print(piece, end="", flush=True)
                time.sleep(0.05)  # paces the console output

                # The buffer is emptied on every flush, so only the new piece
                # can introduce a break character.
                if self._has_speech_break(piece):
                    self.speech_queue.put(pending_speech)
                    pending_speech = ""

            # Speak whatever trailed the last punctuation mark.
            if pending_speech:
                self.speech_queue.put(pending_speech)

            print("\n" + "=" * 50 + "\n")

            # Block until every queued segment has been handled.
            self.speech_queue.join()

            full_response = "".join(reply_parts)
            if content_type == "chat" and full_response.strip():
                self.chat_context.append(
                    {"role": "user", "content": [{"type": "text", "text": content["prompt"]}]}
                )
                self.chat_context.append(
                    {"role": "assistant", "content": [{"type": "text", "text": full_response}]}
                )
                # Only the newest MAX_CONTEXT_LEN rounds are ever replayed,
                # so drop older entries instead of letting the list grow.
                keep = self.MAX_CONTEXT_LEN * 2
                if keep > 0 and len(self.chat_context) > keep:
                    del self.chat_context[:-keep]

            return full_response

        except Exception as e:
            error_msg = f"❌ API调用失败:{str(e)}"
            print("\n" + "=" * 50)
            print(error_msg)
            print("=" * 50 + "\n")
            # Let already-queued segments finish first so the error prompt is
            # not spoken at the same time from a second thread.
            # NOTE(review): assumes tts_controller.speak is not safe to call
            # concurrently — confirm against the TTS implementation.
            self.speech_queue.join()
            self.tts_controller.speak(self.FEEDBACK_TEXT["api_error"])
            return error_msg