1
Some checks failed
CI / build (push) Has been cancelled

This commit is contained in:
jrhlh
2025-07-18 18:49:59 +08:00
parent 3eeebab4f8
commit bd1fc93771
30 changed files with 119 additions and 177 deletions

View File

@ -13,77 +13,40 @@ from cryptography.hazmat.primitives import serialization
bp = Blueprint('auth', __name__, url_prefix='/auth')
logger = logging.getLogger(__name__)
# 添加 CORS 头
FRONTEND_ORIGINS = {
"http://localhost:8080",
"http://127.0.0.1:8080",
"http://[::1]:8080",
"http://localhost:5173",
"http://127.0.0.1:5173",
"http://[::1]:5173"
}
# 允许所有局域网IP或指定前端源
FRONTEND_ORIGINS = {"*"} # 使用通配符表示接受所有来源或指定具体的局域网IP如{"http://192.168.x.x:8080"}
def add_cors_headers(response):
origin = request.headers.get('Origin')
if origin in FRONTEND_ORIGINS:
response.headers['Access-Control-Allow-Origin'] = origin
if origin in FRONTEND_ORIGINS or "*" in FRONTEND_ORIGINS:
response.headers['Access-Control-Allow-Origin'] = origin if origin else '*'
response.headers['Access-Control-Allow-Credentials'] = 'true'
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
response.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS'
return response
# 辅助函数创建JWT令牌
def create_jwt_token(payload, secret_key, algorithm="HS256", expires_in=7200):
# 添加过期时间
payload_with_exp = payload.copy()
payload_with_exp["exp"] = int((datetime.datetime.utcnow() + datetime.timedelta(seconds=expires_in)).timestamp())
# JWT头部
header = {"alg": algorithm, "typ": "JWT"}
# 编码头部和载荷
encoded_header = base64.urlsafe_b64encode(json.dumps(header).encode('utf-8')).rstrip(b'=').decode('utf-8')
encoded_payload = base64.urlsafe_b64encode(json.dumps(payload_with_exp).encode('utf-8')).rstrip(b'=').decode(
'utf-8')
# 组合头部和载荷
encoded_payload = base64.urlsafe_b64encode(json.dumps(payload_with_exp).encode('utf-8')).rstrip(b'=').decode('utf-8')
message = f"{encoded_header}.{encoded_payload}"
# 创建签名
if algorithm == "HS256":
# 使用HMAC-SHA256创建签名
signature = hmac.new(
secret_key.encode('utf-8'),
message.encode('utf-8'),
hashlib.sha256
).digest()
signature = hmac.new(secret_key.encode('utf-8'), message.encode('utf-8'), hashlib.sha256).digest()
encoded_signature = base64.urlsafe_b64encode(signature).rstrip(b'=').decode('utf-8')
elif algorithm == "RS256":
# 使用RSA-SHA256创建签名 (生产环境中应妥善管理私钥)
private_key = serialization.load_pem_private_key(
secret_key.encode('utf-8'),
password=None
)
signature = private_key.sign(
message.encode('utf-8'),
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
private_key = serialization.load_pem_private_key(secret_key.encode('utf-8'), password=None)
signature = private_key.sign(message.encode('utf-8'), padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH), hashes.SHA256())
encoded_signature = base64.urlsafe_b64encode(signature).rstrip(b'=').decode('utf-8')
else:
raise ValueError(f"不支持的算法: {algorithm}")
# 组合JWT
jwt_token = f"{encoded_header}.{encoded_payload}.{encoded_signature}"
return jwt_token
@bp.route('/login', methods=['POST', 'OPTIONS'])
def login():
if request.method == "OPTIONS":
@ -99,11 +62,8 @@ def login():
response = jsonify({'message': '缺少必要字段'})
return add_cors_headers(response), 400
# 获取数据库连接
db = current_app.get_db()
cursor = db.cursor()
# 查询用户注意生产环境应使用参数化查询防止SQL注入
cursor.execute("SELECT * FROM user WHERE username = ?", (username,))
user_row = cursor.fetchone()
@ -112,40 +72,29 @@ def login():
response = jsonify({'message': '用户不存在'})
return add_cors_headers(response), 401
# 将元组结果转换为字典(如果需要)
if isinstance(user_row, tuple):
user_dict = dict(zip([column[0] for column in cursor.description], user_row))
else:
user_dict = user_row
# 明文密码比对(⚠️ 不推荐用于生产环境)
if password != user_dict['password']:
logger.warning(f"密码错误: {username}")
response = jsonify({'message': '密码错误'})
return add_cors_headers(response), 401
# 检查用户状态
if user_dict['status'] != 'Active':
logger.warning(f"用户已禁用: {username}")
response = jsonify({'message': '用户已禁用'})
return add_cors_headers(response), 403
# 判断是否为管理员基于permission_level字段
is_admin = user_dict['permission_level'] == 'Admin'
secret_key = os.getenv('SECRET_KEY', '默认密钥') # 生产环境建议使用环境变量设置
# 构建 JWT Token
secret_key = os.getenv('SECRET_KEY', '默认密钥') # 建议设置环境变量
# 使用我们自己的函数创建JWT
token = create_jwt_token(
{
'user_id': user_dict['id'],
'username': user_dict['username'],
'is_admin': is_admin,
},
{'user_id': user_dict['id'], 'username': user_dict['username'], 'is_admin': is_admin},
secret_key,
algorithm="HS256",
expires_in=2 * 60 * 60 # 2小时
expires_in=2 * 60 * 60
)
response_data = jsonify({
@ -153,19 +102,18 @@ def login():
'message': '登录成功',
'username': user_dict['username'],
'is_admin': is_admin,
'user_id': user_dict['id'] # 可选返回用户ID
'user_id': user_dict['id']
})
response = add_cors_headers(response_data)
# 设置 Cookie注意生产环境应启用 secure=True
response.set_cookie(
'token',
value=token,
max_age=2 * 60 * 60, # 2小时
max_age=2 * 60 * 60,
httponly=True,
samesite='None',
secure=False # 开发环境使用False生产环境使用True
samesite='None', # 注意开发环境下可以设为None生产环境推荐使用'Lax'或'Strict'
secure=False # 开发环境使用False生产环境应设为True
)
logger.info(f"用户登录成功: {username}")
@ -174,4 +122,4 @@ def login():
except Exception as e:
logger.error(f"登录过程发生错误: {str(e)}", exc_info=True)
response = jsonify({'message': '服务器内部错误'})
return add_cors_headers(response), 500
return add_cors_headers(response), 500

View File

@ -3,37 +3,37 @@ import logging
bp = Blueprint('register', __name__)
logger = logging.getLogger(__name__)
FRONTEND_ORIGINS = [
"http://localhost:8080",
"http://127.0.0.1:8080",
"http://[::1]:8080",
"http://localhost:5173",
"http://127.0.0.1:5173",
"http://[::1]:5173"
]
# 支持所有局域网设备访问(或指定 IP
FRONTEND_ORIGINS = ["*"] # 也可以替换为具体的 IP 地址列表
def add_cors_headers(response):
origin = request.headers.get('Origin')
if origin in FRONTEND_ORIGINS or "*" in FRONTEND_ORIGINS:
response.headers['Access-Control-Allow-Origin'] = origin if origin else '*'
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
response.headers['Access-Control-Allow-Methods'] = 'POST, OPTIONS'
response.headers['Access-Control-Allow-Credentials'] = 'true'
return response
@bp.route('/register', methods=['POST', 'OPTIONS'])
def register():
if request.method == 'OPTIONS':
origin = request.headers.get('Origin')
if origin in FRONTEND_ORIGINS:
response = jsonify()
response.headers.add('Access-Control-Allow-Origin', origin)
response.headers.add('Access-Control-Allow-Headers', 'Content-Type, Authorization')
response.headers.add('Access-Control-Allow-Methods', 'POST, OPTIONS')
response.headers.add('Access-Control-Allow-Credentials', 'true')
return response, 200
response = jsonify()
return add_cors_headers(response), 200
try:
logger.info("收到注册请求")
data = request.get_json()
if not data:
return jsonify({'message': '请求数据为空'}), 400
response = jsonify({'message': '请求数据为空'})
return add_cors_headers(response), 400
username = data.get('username')
password = data.get('password')
if not all([username, password]):
return jsonify({'message': '缺少用户名或密码'}), 400
response = jsonify({'message': '缺少用户名或密码'})
return add_cors_headers(response), 400
db = current_app.get_db()
cursor = db.cursor()
@ -41,7 +41,8 @@ def register():
# 检查用户名是否已存在
cursor.execute("SELECT id FROM user WHERE username = ?", (username,))
if cursor.fetchone():
return jsonify({'message': '用户名已存在'}), 400
response = jsonify({'message': '用户名已存在'})
return add_cors_headers(response), 400
# 插入数据时包含 permission_level默认设为 Operator
cursor.execute(
@ -51,9 +52,11 @@ def register():
db.commit()
logger.info(f"用户 {username} 注册成功")
return jsonify({'message': '注册成功'}), 201
response = jsonify({'message': '注册成功'})
return add_cors_headers(response), 201
except Exception as e:
logger.error(f"服务器内部错误: {str(e)}", exc_info=True)
db.rollback()
return jsonify({'message': '服务器内部错误'}), 500
response = jsonify({'message': '服务器内部错误'})
return add_cors_headers(response), 500