164 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			164 lines
		
	
	
		
			5.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | import websocket | |||
|  | import json | |||
|  | import hmac | |||
|  | import hashlib | |||
|  | import base64 | |||
|  | import time | |||
|  | import ssl | |||
|  | from urllib.parse import urlencode | |||
|  | import datetime | |||
|  | from threading import Lock | |||
|  | from flask import current_app | |||
|  | 
 | |||
# Module-wide processing-status registry and the lock that guards it.
# NOTE(review): neither name is referenced by the functions in this section;
# presumably other modules import them — confirm before changing.
processing_status = {}
status_lock = Lock()
|  | 
 | |||
def process_ai_message(appid, api_key, api_secret, spark_url, domain, messages):
    """Send a chat request over the Spark WebSocket API and return the reply.

    Each attempt does the following:

    * signs a fresh connection URL via ``generate_auth_url``;
    * opens a WebSocket with a 30 s socket timeout;
    * sends one chat request;
    * concatenates every streamed ``content`` fragment until a frame has
      ``header.status == 2`` or the server stops sending data.

    A timeout, a WebSocket error, any other exception, or a blank answer
    consumes one of up to three attempts.

    Args:
        appid: Application id written to ``header.app_id``.
        api_key: API key embedded in the signed connection URL.
        api_secret: Secret used to HMAC-sign the connection URL.
        spark_url: WebSocket endpoint URL to connect to.
        domain: Value sent as ``parameter.chat.domain``.
        messages: Sent verbatim as ``payload.message.text``; presumably a
            list of role/content dicts — confirm against callers.

    Returns:
        The answer text from the first attempt that yields a non-blank
        answer. If every attempt fails, returns an apology string that embeds
        the last error message. Network/API failures are reported through
        this return value and the Flask app logger, not raised.
    """
    max_retries = 3
    last_error = None

    for attempt in range(1, max_retries + 1):
        # Reset per attempt. If signing or connecting fails, cleanup must not
        # touch an unbound name or a socket left over from a previous attempt.
        ws = None
        try:
            url = generate_auth_url(spark_url, api_key, api_secret)

            # NOTE(review): ssl.CERT_NONE disables server certificate
            # verification, exposing the signed URL and the conversation to a
            # man-in-the-middle. Kept for backward compatibility; consider
            # dropping sslopt once the deployment's CA bundle is confirmed.
            ws = websocket.create_connection(
                url,
                sslopt={"cert_reqs": ssl.CERT_NONE},
                timeout=30,
            )

            # The uid suffix is the zero-based attempt index.
            request = _build_chat_request(appid, domain, messages, attempt - 1)
            ws.send(json.dumps(request))
            answer = _receive_answer(ws)

            if answer.strip():
                return answer

            last_error = "AI返回空回答"
            current_app.logger.warning(f"AI返回空回答,重试 {attempt}/{max_retries}")

        except TimeoutError as te:
            last_error = str(te)
            current_app.logger.error(f"AI接口超时 ({attempt}/{max_retries}): {str(te)}")
        except websocket.WebSocketException as we:
            last_error = str(we)
            current_app.logger.error(f"WebSocket错误 ({attempt}/{max_retries}): {str(we)}")
        except Exception as e:
            last_error = str(e)
            current_app.logger.error(f"处理错误 ({attempt}/{max_retries}): {str(e)}")
        finally:
            # Runs on success, on an empty answer and on every error path,
            # including WebSocketException, so the socket is never leaked.
            _close_quietly(ws)

    return f"抱歉,AI处理失败: {last_error or '未知错误'}"


def _build_chat_request(appid, domain, messages, attempt_index):
    """Return the chat request body sent as the first WebSocket frame.

    ``attempt_index`` is appended to the uid, which distinguishes retries
    made within the same second.
    """
    return {
        "header": {
            "app_id": appid,
            "uid": f"user_{int(time.time())}_{attempt_index}",
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "temperature": 0.7,
                "max_tokens": 2048,
                "top_k": 3,
            }
        },
        "payload": {
            "message": {
                "text": messages,
            }
        },
    }


def _receive_answer(ws, overall_timeout=45):
    """Collect streamed answer fragments from ``ws`` until the final frame.

    Stops when ``recv`` returns an empty value or when a frame carries
    ``header.status == 2``. Every ``content`` entry under
    ``payload.choices.text`` is kept, in arrival order.

    Raises:
        TimeoutError: if a single ``recv`` hits the socket timeout, or if
            more than ``overall_timeout`` seconds have elapsed when a
            non-final frame is processed.
    """
    parts = []
    start_time = time.time()
    while True:
        try:
            response = ws.recv()
        except websocket.WebSocketTimeoutException as err:
            current_app.logger.warning("WebSocket接收超时")
            raise TimeoutError("WebSocket接收超时") from err
        if not response:
            break

        data = json.loads(response)
        if "payload" in data and "choices" in data["payload"]:
            choices = data["payload"]["choices"]
            if "text" in choices:
                parts.extend(
                    item["content"] for item in choices["text"] if "content" in item
                )

        # A status of 2 marks the last frame of the streamed reply.
        if "header" in data and data["header"].get("status") == 2:
            break

        # The overall deadline is checked only between frames. Each recv is
        # separately bounded by the connection's socket timeout.
        if time.time() - start_time > overall_timeout:
            current_app.logger.warning("AI接口响应超时")
            raise TimeoutError("AI接口响应超时")

    return "".join(parts)


def _close_quietly(ws):
    """Close ``ws`` if a connection was opened, ignoring errors from close()."""
    if ws is None:
        return
    try:
        ws.close()
    except Exception:
        # Best-effort cleanup: a failed close must not mask the original
        # error or discard an answer that was already received.
        pass
|  | 
 | |||
def generate_auth_url(api_url, api_key, api_secret):
    """Build a signed WebSocket URL for HMAC-SHA256 authentication.

    The signature covers the ``host`` and ``date`` headers and the request
    line ``GET <path> HTTP/1.1``. The Base64-encoded authorization string,
    the date and the host are appended to ``api_url`` as query parameters.

    Args:
        api_url: WebSocket endpoint such as ``wss://host/path``. Assumed to
            carry no query string of its own, because one is appended with
            ``?``.
        api_key: API key embedded in the authorization string.
        api_secret: Secret used as the HMAC-SHA256 key.

    Returns:
        ``api_url`` followed by ``?authorization=...&date=...&host=...``.
    """
    from email.utils import formatdate
    from urllib.parse import urlparse

    url = urlparse(api_url)
    host = url.netloc
    # An empty path is requested as "/" (RFC 7230 §5.3.1), so sign that
    # request line instead of "GET  HTTP/1.1".
    path = url.path or "/"

    # RFC 1123 date in GMT. formatdate always emits English day/month names,
    # whereas strftime's %a/%b follow the process locale and could corrupt
    # the signed string.
    date = formatdate(time.time(), usegmt=True)

    # Canonical string to sign: host header, date header, request line.
    signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
    signature_sha = hmac.new(
        api_secret.encode("utf-8"),
        signature_origin.encode("utf-8"),
        digestmod=hashlib.sha256,
    ).digest()
    signature_sha_base64 = base64.b64encode(signature_sha).decode()

    authorization_origin = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature_sha_base64}"'
    )
    # The whole authorization string is Base64-encoded again for the query.
    authorization = base64.b64encode(authorization_origin.encode()).decode()

    query_params = {
        "authorization": authorization,
        "date": date,
        "host": host,
    }
    return f"{api_url}?{urlencode(query_params)}"