Skip to content
Commits on Source (2)
local cjson = require "cjson.safe"
local redis = require "resty.redis"
local function read_full_body()
ngx.req.read_body()
local data = ngx.req.get_body_data()
if data then return data end
local file = ngx.req.get_body_file()
if not file then return "" end
local f, err = io.open(file, "rb")
if not f then return nil, ("failed to open body file: " .. (err or "")) end
local content = f:read("*a")
f:close()
return content or ""
end
local function extract_token(content_type, body)
content_type = content_type or ""
if content_type:find("application/json", 1, true) then
local obj = cjson.decode(body or "")
if type(obj) == "table" then
return obj.token, obj
end
end
-- 默认按 x-www-form-urlencoded 解析
local args = ngx.decode_args(body or "", 0)
if args and args.token then
return args.token, args
end
return nil, nil
end
local function replace_token_in_body(content_type, body, new_value, parsed)
content_type = content_type or ""
if content_type:find("application/json", 1, true) then
local obj = parsed
if type(obj) ~= "table" then
obj = cjson.decode(body or "") or {}
end
obj.token = new_value
return cjson.encode(obj) or body
else
-- form-urlencoded
local args = parsed
if type(args) ~= "table" then
args = ngx.decode_args(body or "", 0) or {}
end
args.token = new_value
return ngx.encode_args(args)
end
end
do
local has_hrtime = false
if ngx.hrtime and type(ngx.hrtime) == "function" then
has_hrtime = true
end
if not has_hrtime then
ngx.update_time()
end
local start_clock = has_hrtime and ngx.hrtime() or ngx.now()
local content_type = ngx.req.get_headers()["content-type"] or ""
local body, berr = read_full_body()
if not body then
ngx.log(ngx.ERR, "token_auth: failed to read request body: ", berr or "unknown error")
ngx.status = 400
ngx.header["Content-Type"] = "application/json"
ngx.say('{"code":400,"msg":"Token已过期或不存在,请购买新token"}')
return ngx.exit(400)
end
local token, parsed = extract_token(content_type, body)
if not token or token == "" then
ngx.log(ngx.WARN, "token_auth: missing token in body, content-type=", content_type)
ngx.status = 401
ngx.header["Content-Type"] = "application/json"
ngx.say('{"code":401,"msg":"Token已过期或不存在,请购买新token"}')
return ngx.exit(401)
end
local red = redis:new()
red:set_timeout(100)
local ok, err = red:connect(ngx.var.redis_host or "127.0.0.1", tonumber(ngx.var.redis_port or 6379))
if not ok then
ngx.log(ngx.ERR, "token_auth: redis connect failed for token=", token, ", err=", err or "unknown")
ngx.status = 500
ngx.header["Content-Type"] = "application/json"
ngx.say('{"code":500,"msg":"服务异常,请稍后再试或联系管理员"}')
return ngx.exit(500)
end
if ngx.var.redis_password and ngx.var.redis_password ~= "" then
local auth_ok, auth_err = red:auth(ngx.var.redis_password)
if not auth_ok then
ngx.log(ngx.ERR, "token_auth: redis auth failed for token=", token, ", err=", auth_err or "unknown")
ngx.status = 500
ngx.header["Content-Type"] = "application/json"
ngx.say('{"code":500,"msg":"服务异常,请稍后再试或联系管理员"}')
return ngx.exit(500)
end
end
-- 使用 Hash: token:<token_value>
local token_key = "token:" .. token
local exists = red:exists(token_key)
if exists == 0 then
ngx.log(ngx.INFO, "token_auth: token not found token=", token)
ngx.status = 401
ngx.header["Content-Type"] = "application/json"
ngx.say('{"code":401,"msg":"Token已过期或不存在,请购买新token"}')
return ngx.exit(401)
end
-- 读取必要字段
local fields = { "is_discarded", "is_locked", "lock_until_ts", "end_ts" }
local res, herr = red:hmget(token_key, unpack(fields))
if not res then
ngx.log(ngx.ERR, "token_auth: redis hmget failed token=", token, ", err=", herr or "unknown")
ngx.status = 500
ngx.header["Content-Type"] = "application/json"
ngx.say('{"code":500,"msg":"Token已过期或不存在,请购买新token"}')
return ngx.exit(500)
end
local is_discarded = tonumber(res[1] or 0) or 0
local is_locked = tonumber(res[2] or 0) or 0
local lock_until_ts = tonumber(res[3] or 0) or 0
local end_ts = tonumber(res[4] or 0) or 0
local now = ngx.time()
-- 提前初始化用于事件上报的参数,确保任一锁分支都能上报
local ip = ngx.var.remote_addr or "unknown"
local window_sec = tonumber(ngx.var.token_window_sec or 60) or 60
local limit_per_ip = tonumber(ngx.var.token_ip_limit or 10) or 10
local lock_sec = tonumber(ngx.var.token_lock_sec or 43200) or 43200 -- 默认12小时
local overSetKey = "token:overips:" .. token
-- 作废
if is_discarded == 1 then
ngx.log(ngx.INFO, "token_auth: token discarded token=", token)
ngx.status = 401
ngx.header["Content-Type"] = "application/json"
ngx.say('{"code":401,"msg":"Token已过期或不存在,请购买新token"}')
return ngx.exit(401)
end
-- 过期
if end_ts > 0 and now > end_ts then
ngx.log(ngx.INFO, "token_auth: token expired token=", token, ", end_ts=", end_ts, ", now=", now)
ngx.status = 401
ngx.header["Content-Type"] = "application/json"
ngx.say('{"code":401,"msg":"Token已过期或不存在,请购买新token"}')
return ngx.exit(401)
end
-- 持久锁
if is_locked == 1 then
if lock_until_ts == 0 or now < lock_until_ts then
ngx.log(ngx.INFO, "token_auth: token locked token=", token, ", lock_until_ts=", lock_until_ts, ", now=", now)
ngx.status = 403
ngx.header["Content-Type"] = "application/json"
ngx.say('{"code":403,"msg":"检测到同一token多IP高频访问,已锁定,请等待12个小时解锁"}')
return ngx.exit(403)
end
end
-- 临时锁(风控锁)
local tmp_lock_key = "token_lock:" .. token
local tmp_locked = red:exists(tmp_lock_key)
if tmp_locked == 1 then
ngx.log(ngx.INFO, "token_auth: token temporarily locked token=", token)
-- 已处于临时锁时也上报 Streams 事件(一次性去重)
local dedupe_key = "token:notif:" .. token
local first = red:setnx(dedupe_key, 1)
if first == 1 then
red:expire(dedupe_key, 60)
local over = red:scard(overSetKey) or 0
red:xadd("token:lock_events", "MAXLEN", "~", 10000, "*",
"token", token,
"ip", ip,
"ts", tostring(now),
"reason", "tmp_locked",
"window_sec", tostring(window_sec),
"limit_per_ip", tostring(limit_per_ip),
"over_ips", tostring(over)
)
end
ngx.status = 403
ngx.header["Content-Type"] = "application/json"
ngx.say('{"code":403,"msg":"检测到同一token多IP高频访问,已锁定,请等待12个小时解锁"}')
return ngx.exit(403)
end
-- 防刷计数与可能锁定(EVAL)
local ipCountKey = "token:ipcnt:" .. token .. ":" .. ip
-- overSetKey 已在前面定义
local lockKey = tmp_lock_key
local script = [[
local c = redis.call('INCR', KEYS[1])
redis.call('EXPIRE', KEYS[1], tonumber(ARGV[1]))
if c > tonumber(ARGV[2]) then
redis.call('SADD', KEYS[2], ARGV[4])
redis.call('EXPIRE', KEYS[2], tonumber(ARGV[1]))
end
local over = redis.call('SCARD', KEYS[2])
if over >= 2 then
if redis.call('EXISTS', KEYS[3]) == 0 then
redis.call('SET', KEYS[3], 1, 'EX', tonumber(ARGV[3]))
end
return 1
end
return 0
]]
local eval_res, eval_err = red:eval(script, 3, ipCountKey, overSetKey, lockKey,
window_sec, limit_per_ip, lock_sec, ip)
if not eval_res then
ngx.log(ngx.ERR, "token_auth: redis eval failed token=", token, ", err=", eval_err or "unknown")
-- Redis 异常时,按需决定放行或拦截,这里放行
else
if tonumber(eval_res) == 1 then
-- 触发临时锁后,写入 Streams 事件并去重节流
local dedupe_key = "token:notif:" .. token
local first = red:setnx(dedupe_key, 1)
if first == 1 then
red:expire(dedupe_key, 60)
local over = red:scard(overSetKey) or 0
red:xadd("token:lock_events", "MAXLEN", "~", 10000, "*",
"token", token,
"ip", ip,
"ts", tostring(now),
"reason", "multi_ip_high_freq",
"window_sec", tostring(window_sec),
"limit_per_ip", tostring(limit_per_ip),
"over_ips", tostring(over)
)
end
ngx.log(ngx.WARN, "token_auth: token locked by anti-abuse token=", token,
", ip=", ip, ", limit_per_ip=", limit_per_ip, ", window_sec=", window_sec)
ngx.status = 403
ngx.header["Content-Type"] = "application/json"
ngx.say('{"code":403,"msg":"检测到同一token多IP高频访问,已锁定,请等待12个小时解锁"}')
return ngx.exit(403)
end
end
ngx.log(ngx.DEBUG, "token_auth: token validation passed token=", token,
", ip=", ip)
local uri = ngx.var.uri or ""
local upstream_map = {
["/api-new/tushare"] = "http://123.57.69.240:9002/tq",
["/api-new/tushare/pro_bar"] = "http://123.57.69.240:9002/tp",
}
local target = upstream_map[uri]
if target then
ngx.var.target_upstream = target
ngx.log(ngx.DEBUG, "token_auth: route mapped uri=", uri, ", target=", target)
end
-- 替换 token 为占位符
local replacement = ngx.var.token_placeholder or "s6f0bde77fb2fe31f729f14fc849ddc5378"
local new_body = replace_token_in_body(content_type, body, replacement, parsed)
ngx.req.set_body_data(new_body or body)
ngx.req.clear_header("Content-Length")
ngx.req.set_header("Content-Length", #(new_body or body))
local cost_ms
if has_hrtime then
local elapsed_ns = ngx.hrtime() - start_clock
cost_ms = elapsed_ns / 1e6
else
ngx.update_time()
cost_ms = (ngx.now() - start_clock) * 1000
end
ngx.log(ngx.INFO, "token_auth: request finished token=", token,
", uri=", uri, ", cost_ms=", string.format("%.2f", cost_ms))
red:set_keepalive(60000, 100)
return
end
...@@ -39,7 +39,7 @@ class TestTuDataComprehensive(unittest.TestCase): ...@@ -39,7 +39,7 @@ class TestTuDataComprehensive(unittest.TestCase):
def setUp(self): def setUp(self):
"""Setup test fixtures before each test method.""" """Setup test fixtures before each test method."""
# 优先从环境变量读取token,如果没有则使用有效的默认token # 优先从环境变量读取token,如果没有则使用有效的默认token
self.test_token = os.environ.get('TUSHARE_TOKEN', "1ab08efbf57546eab5a62499848c542a") self.test_token = os.environ.get('TUSHARE_TOKEN', "426a56d8a78c4b96a41b1c5f58b8120d")
# Mock DataFrame response # Mock DataFrame response
self.mock_response_df = pd.DataFrame({ self.mock_response_df = pd.DataFrame({
'ts_code': ['000001.SZ', '000002.SZ'], 'ts_code': ['000001.SZ', '000002.SZ'],
...@@ -183,7 +183,7 @@ class TestTuDataComprehensive(unittest.TestCase): ...@@ -183,7 +183,7 @@ class TestTuDataComprehensive(unittest.TestCase):
def test_daily(self): def test_daily(self):
"""Test daily method""" """Test daily method"""
print(" 📊 测试日线数据接口") print(" 📊 测试日线数据接口")
result = self._mock_api_response('daily', ts_code='000001.SZ') result = self._mock_api_response('daily', ts_code='000001.SZ', start_date='20180701', end_date='20180718')
self._validate_dataframe_result(result, ['ts_code', 'trade_date', 'close'], min_rows=1) self._validate_dataframe_result(result, ['ts_code', 'trade_date', 'close'], min_rows=1)
# 验证数据类型 # 验证数据类型
......
...@@ -95,7 +95,7 @@ class pro_api: ...@@ -95,7 +95,7 @@ class pro_api:
Returns: Returns:
pd.DataFrame: A DataFrame containing the fetched stock data. pd.DataFrame: A DataFrame containing the fetched stock data.
""" """
url = "http://114.132.244.63/api-tushare/tushare" url = "http://114.132.244.63/api-new/tushare"
params = { params = {
'token': get_token(), 'token': get_token(),
...@@ -902,7 +902,7 @@ def pro_bar(ts_code='', api=None, start_date='', end_date='', freq='D', asset='E ...@@ -902,7 +902,7 @@ def pro_bar(ts_code='', api=None, start_date='', end_date='', freq='D', asset='E
""" """
""" """
url = "http://114.132.244.63/api-tushare/tushare/pro_bar" url = "http://114.132.244.63/api-new/tushare/pro_bar"
params = { params = {
'token':get_token(), 'token':get_token(),
......
...@@ -115,13 +115,29 @@ async def pro_bar_view(request: Request): ...@@ -115,13 +115,29 @@ async def pro_bar_view(request: Request):
t5 = time.time() t5 = time.time()
if hasattr(resp, "status_code") and hasattr(resp, "content"): if hasattr(resp, "status_code") and hasattr(resp, "content"):
await logger.ainfo(f"[pro_bar_view] finished, token={token}, method=pro_bar, total={time.time()-start_time:.4f}s") # 统一在 API 层输出一次包含元字段与耗时的日志
await logger.ainfo(
f"[pro_bar_view] elapsed_ms={(time.time()-start_time)*1000:.2f} "
f"api_name={body.get('api')} token={token} "
f"start_date={body.get('start_date')} end_date={body.get('end_date')} "
f"trade_date={body.get('trade_date')}"
)
return Response(content=resp.content, status_code=resp.status_code, headers=dict(resp.headers), media_type=resp.headers.get("content-type", None)) return Response(content=resp.content, status_code=resp.status_code, headers=dict(resp.headers), media_type=resp.headers.get("content-type", None))
await logger.ainfo(f"[pro_bar_view] finished, token={token}, method=pro_bar, total={time.time()-start_time:.4f}s") await logger.ainfo(
f"[pro_bar_view] elapsed_ms={(time.time()-start_time)*1000:.2f} "
f"api_name={body.get('api')} token={token} "
f"start_date={body.get('start_date')} end_date={body.get('end_date')} "
f"trade_date={body.get('trade_date')}"
)
return resp return resp
except Exception as e: except Exception as e:
await logger.aerror(f"[pro_bar_view] exception={str(e)}, token={token}, method=pro_bar, total={time.time()-start_time:.4f}s") await logger.aerror(
f"[pro_bar_view] exception={str(e)} elapsed_ms={(time.time()-start_time)*1000:.2f} "
f"api_name={body.get('api')} token={token} "
f"start_date={body.get('start_date')} end_date={body.get('end_date')} "
f"trade_date={body.get('trade_date')}"
)
return Response(content=str(e), status_code=500) return Response(content=str(e), status_code=500)
except Exception as e: except Exception as e:
...@@ -140,16 +156,16 @@ async def tushare_entry(request: Request): ...@@ -140,16 +156,16 @@ async def tushare_entry(request: Request):
ok, msg = check_token(token, client_ip) ok, msg = check_token(token, client_ip)
t3 = time.time() t3 = time.time()
if not ok: if not ok:
await logger.ainfo(f"[tushare_entry] token check failed, token={token}, total={time.time()-start_time:.4f}s") await logger.ainfo(f"[tushare_entry] token check failed, token={token}, elapsed_ms={(time.time()-start_time)*1000:.2f}")
return Response(content=msg, status_code=401) return Response(content=msg, status_code=401)
api_name = body.get("api_name") api_name = body.get("api_name")
t4 = time.time() t4 = time.time()
if not api_name: if not api_name:
await logger.ainfo(f"[tushare_entry] api_name missing, token={token}, total={time.time()-start_time:.4f}s") await logger.ainfo(f"[tushare_entry] api_name missing, token={token}, elapsed_ms={(time.time()-start_time)*1000:.2f}")
return {"success": False, "msg": "api_name 不能为空"} return {"success": False, "msg": "api_name 不能为空"}
# 动态分发 # 动态分发
if not hasattr(pro, api_name): if not hasattr(pro, api_name):
await logger.ainfo(f"[tushare_entry] api_name not supported, token={token}, method={api_name}, total={time.time()-start_time:.4f}s") await logger.ainfo(f"[tushare_entry] api_name not supported, token={token}, method={api_name}, elapsed_ms={(time.time()-start_time)*1000:.2f}")
return {"success": False, "msg": f"不支持的api_name: {api_name}"} return {"success": False, "msg": f"不支持的api_name: {api_name}"}
method = getattr(pro, api_name) method = getattr(pro, api_name)
t5 = time.time() t5 = time.time()
...@@ -158,12 +174,31 @@ async def tushare_entry(request: Request): ...@@ -158,12 +174,31 @@ async def tushare_entry(request: Request):
resp = await run_in_threadpool(method, **body) resp = await run_in_threadpool(method, **body)
t6 = time.time() t6 = time.time()
if hasattr(resp, "status_code") and hasattr(resp, "content"): if hasattr(resp, "status_code") and hasattr(resp, "content"):
await logger.ainfo(f"[tushare_entry] finished, token={token}, method={api_name}, total={time.time()-start_time:.4f}s") # 统一在 API 层输出一次包含元字段与耗时的日志
params = body.get("params", {}) or {}
await logger.ainfo(
f"[tushare_entry] elapsed_ms={(time.time()-start_time)*1000:.2f} "
f"api_name={api_name} token={token} "
f"start_date={params.get('start_date')} end_date={params.get('end_date')} "
f"trade_date={params.get('trade_date')}"
)
return Response(content=resp.content, status_code=resp.status_code, headers=dict(resp.headers), media_type=resp.headers.get("content-type", None)) return Response(content=resp.content, status_code=resp.status_code, headers=dict(resp.headers), media_type=resp.headers.get("content-type", None))
await logger.ainfo(f"[tushare_entry] finished, token={token}, method={api_name}, total={time.time()-start_time:.4f}s") params = body.get("params", {}) or {}
await logger.ainfo(
f"[tushare_entry] elapsed_ms={(time.time()-start_time)*1000:.2f} "
f"api_name={api_name} token={token} "
f"start_date={params.get('start_date')} end_date={params.get('end_date')} "
f"trade_date={params.get('trade_date')}"
)
return resp return resp
except Exception as e: except Exception as e:
await logger.aerror(f"[tushare_entry] exception={str(e)}, token={token}, method={api_name}, total={time.time()-start_time:.4f}s") params = body.get("params", {}) or {}
await logger.aerror(
f"[tushare_entry] exception={str(e)} elapsed_ms={(time.time()-start_time)*1000:.2f} "
f"api_name={api_name} token={token} "
f"start_date={params.get('start_date')} end_date={params.get('end_date')} "
f"trade_date={params.get('trade_date')}"
)
return Response(content=str(e), status_code=500) return Response(content=str(e), status_code=500)
@router.post("/unlock_token") @router.post("/unlock_token")
......
...@@ -3,22 +3,40 @@ import requests.adapters ...@@ -3,22 +3,40 @@ import requests.adapters
import pandas as pd import pandas as pd
import time import time
import logging import logging
import asyncio
from app.service.config import get_tushare_token from app.service.config import get_tushare_token
import importlib
try:
from urllib3.util import Retry # type: ignore
except Exception:
try:
from requests.packages.urllib3.util.retry import Retry # type: ignore
except Exception:
Retry = None # type: ignore
# 配置日志
logger = logging.getLogger(__name__)
# 创建长连接会话与连接池,减少 TIME_WAIT # 创建长连接会话与连接池,减少 TIME_WAIT
_session = requests.Session() _session = requests.Session()
_adapter = requests.adapters.HTTPAdapter(pool_connections=20, pool_maxsize=100, max_retries=0) if 'Retry' in globals() and Retry is not None:
_retry = Retry(
total=3,
connect=3,
read=3,
backoff_factor=0.3,
status_forcelist=[502, 503, 504],
allowed_methods=frozenset(['POST'])
)
_adapter = requests.adapters.HTTPAdapter(pool_connections=20, pool_maxsize=100, max_retries=_retry)
else:
_adapter = requests.adapters.HTTPAdapter(pool_connections=20, pool_maxsize=100, max_retries=3)
_session.mount('http://', _adapter) _session.mount('http://', _adapter)
_session.mount('https://', _adapter) _session.mount('https://', _adapter)
_session.headers.update({'Connection': 'keep-alive'}) _session.headers.update({'Connection': 'keep-alive'})
# async HTTP client # async HTTP client
try: try:
import httpx httpx = importlib.import_module("httpx")
_async_client = httpx.AsyncClient(timeout=30, headers={'Connection': 'keep-alive'}) _async_client = httpx.AsyncClient(timeout=30, headers={'Connection': 'keep-alive'})
except Exception: except Exception:
httpx = None httpx = None
...@@ -65,7 +83,17 @@ class pro_api: ...@@ -65,7 +83,17 @@ class pro_api:
'params': params_data, 'params': params_data,
'fields': fields_data, 'fields': fields_data,
} }
# 优先走异步 httpx(在无事件循环的线程中用 asyncio.run 包装)
if 'httpx' in globals() and httpx is not None:
try:
asyncio.get_running_loop()
# 当前线程已有事件循环(无法在已运行的 loop 中直接 asyncio.run),回退到同步请求
except RuntimeError:
# 当前线程没有运行中的事件循环,创建临时事件循环执行异步请求
async def _do():
return await self.async_query(api_name, fields=fields_data, params=params_data)
return asyncio.run(_do())
# 同步回退
response = _session.post(url, json=params, timeout=30) response = _session.post(url, json=params, timeout=30)
return response return response
...@@ -85,10 +113,11 @@ class pro_api: ...@@ -85,10 +113,11 @@ class pro_api:
'params': params_data, 'params': params_data,
'fields': fields_data, 'fields': fields_data,
} }
if _async_client is None: if 'httpx' in globals() and httpx is not None:
async with httpx.AsyncClient(timeout=30, headers={'Connection': 'keep-alive'}) as client:
return await client.post(url, json=payload)
# 若无 httpx 可用,回退为同步请求(注意:在事件循环中调用会阻塞)
return _session.post(url, json=payload, timeout=30) return _session.post(url, json=payload, timeout=30)
resp = await _async_client.post(url, json=payload)
return resp
...@@ -877,17 +906,13 @@ def pro_bar(ts_code='', api=None, start_date='', end_date='', freq='D', asset='E ...@@ -877,17 +906,13 @@ def pro_bar(ts_code='', api=None, start_date='', end_date='', freq='D', asset='E
"contract_type" : contract_type "contract_type" : contract_type
} }
logger.info(f"=== tushare_funet pro_bar 请求开始 ===")
logger.info(f"请求URL: {url}")
logger.info(f"请求参数: {params}")
if 'min' in freq : if 'min' in freq :
logger.error("此接口为单独权限,和积分没有关系,需要单独购买")
logger.info(f"=== tushare_funet pro_bar 请求结束 ===")
return '此接口为单独权限,和积分没有关系,需要单独购买' return '此接口为单独权限,和积分没有关系,需要单独购买'
else: else:
_t0 = time.perf_counter()
response = _session.post(url, json=params,) response = _session.post(url, json=params,)
_elapsed_ms = (time.perf_counter() - _t0) * 1000.0
# 仅输出一条日志,包含耗时与元字段
return response return response
async def pro_bar_async(ts_code='', api=None, start_date='', end_date='', freq='D', asset='E', async def pro_bar_async(ts_code='', api=None, start_date='', end_date='', freq='D', asset='E',
......
...@@ -79,7 +79,7 @@ class Logger: ...@@ -79,7 +79,7 @@ class Logger:
"""设置日志器""" """设置日志器"""
# 创建日志器 # 创建日志器
self.logger = logging.getLogger(self.name) self.logger = logging.getLogger(self.name)
self.logger.setLevel(logging.DEBUG) self.logger.setLevel(logging.INFO)
# 若根日志器尚未配置任何处理器,则在此完成异步日志配置(默认到项目 logs 目录) # 若根日志器尚未配置任何处理器,则在此完成异步日志配置(默认到项目 logs 目录)
root_logger = logging.getLogger() root_logger = logging.getLogger()
...@@ -97,11 +97,11 @@ class Logger: ...@@ -97,11 +97,11 @@ class Logger:
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setFormatter(fmt) console_handler.setFormatter(fmt)
console_handler.setLevel(logging.DEBUG) console_handler.setLevel(logging.INFO)
q = queue.SimpleQueue() q = queue.SimpleQueue()
qh = QueueHandler(q) qh = QueueHandler(q)
root_logger.setLevel(logging.DEBUG) root_logger.setLevel(logging.INFO)
root_logger.addHandler(qh) root_logger.addHandler(qh)
listener = QueueListener(q, file_handler, console_handler) listener = QueueListener(q, file_handler, console_handler)
listener.daemon = True listener.daemon = True
......
#!/bin/bash #!/bin/bash
set -e set -e # 遇到错误退出(避免无效命令影响后续执行)
# 进入脚本所在目录(确保相对路径正确)
cd "$(dirname "$0")" cd "$(dirname "$0")"
# 激活虚拟环境
source /home/leewcc/tushare-web-api1/myenv/bin/activate source /home/leewcc/tushare-web-api1/myenv/bin/activate
pip install -r requirements.txt
# 支持外部传入端口与进程数(workers),默认 8000 / 1 # 可选:安装依赖(如果需要动态更新依赖,可保留;否则可注释掉,避免每次启动都执行)
PORT="${PORT:-${1:-8000}}" # pip install -r requirements.txt
WORKERS="${WORKERS:-${2:-1}}"
# 解析参数:$1=端口,$2=工作进程数(默认值:端口8000,进程1)
PORT="${1:-8000}" # 第一个参数是端口,默认8000
WORKERS="${2:-1}" # 第二个参数是工作进程数,默认1
# 使用 exec 让 uvicorn 取代当前进程,便于 Supervisor 正确管理 # 启动 uvicorn(关键:确保 --workers 传递的是正确的 WORKERS 变量)
exec uvicorn app.main:app --host 0.0.0.0 --port "$PORT" --workers "$WORKERS" exec uvicorn app.main:app --host 0.0.0.0 --port "$PORT" --workers "$WORKERS"
\ No newline at end of file
...@@ -11,6 +11,7 @@ def create_app(config_name=None): ...@@ -11,6 +11,7 @@ def create_app(config_name=None):
"""创建Flask应用""" """创建Flask应用"""
import os import os
from config.settings import config from config.settings import config
from config.settings import Config
# 获取配置 # 获取配置
if config_name is None: if config_name is None:
...@@ -36,6 +37,26 @@ def create_app(config_name=None): ...@@ -36,6 +37,26 @@ def create_app(config_name=None):
# 注册请求处理器 # 注册请求处理器
register_request_handlers(app) register_request_handlers(app)
# 启动后台消费者(可配置开关),遵循开发热重载仅主进程启动
try:
if Config.AUTO_START_LOCK_EVENTS_CONSUMER:
should_start = True
if app.config.get('DEBUG'):
# Flask debug 模式会启动两次,只有主进程启动
if os.environ.get('WERKZEUG_RUN_MAIN') != 'true':
should_start = False
if should_start:
from app.background.lock_events_consumer import LockEventsConsumer
consumer = LockEventsConsumer()
consumer.start()
# 保存到扩展字典,便于管理/测试
if not hasattr(app, 'extensions'):
app.extensions = {}
app.extensions['lock_events_consumer'] = consumer
except Exception:
# 后台线程启动失败不影响主服务,但会在日志中体现
pass
return app return app
def init_extensions(app): def init_extensions(app):
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
后台线程:消费 Redis Streams token:lock_events,并通过 Service 层执行业务更新
"""
import os
import socket
import threading
import time
from typing import List, Optional
from config.settings import Config
from app.utils.redis_client import get_redis
from app.services.token_service import TokenService
from app.services.token_intercept_service import TokenInterceptService
from app.utils.logger import get_logger
logger = get_logger(__name__)
class LockEventsConsumer:
"""
简单的基于线程的 Redis Stream 消费者
- 使用 Consumer Group 保证多进程/多线程可水平扩展
- 仅通过 Service 层执行业务,符合分层约束
"""
def __init__(
self,
stream: Optional[str] = None,
group: Optional[str] = None,
consumer_name: Optional[str] = None,
block_ms: int = 5000,
batch_size: int = 50,
):
self.stream = stream or Config.REDIS_STREAM_LOCK_EVENTS
self.group = group or Config.REDIS_STREAM_GROUP
# 确保消费者名在同一主机多进程时也唯一
hostname = socket.gethostname()
pid = os.getpid()
self.consumer_name = consumer_name or f"{hostname}-{pid}"
self.block_ms = block_ms
self.batch_size = batch_size
self._stop_event = threading.Event()
self._thread: Optional[threading.Thread] = None
def _ensure_group(self, r):
try:
r.xgroup_create(self.stream, self.group, id='0', mkstream=True)
logger.info(f"创建Consumer Group成功 stream={self.stream}, group={self.group}")
except Exception:
# 已存在则忽略
pass
def _handle_event(self, fields: dict):
token_value = fields.get('token')
ip = fields.get('ip')
# ts = fields.get('ts') # 可选使用
# reason = fields.get('reason') # 可选使用
if not token_value:
return
token_info = TokenService.get_token_by_value(token_value)
if not token_info:
return
ip_list: List[str] = [ip] if ip else []
TokenInterceptService.create_intercept_record(
token_info['id'], ip_list, created_by='gateway'
)
try:
# 业务层负责乐观锁与审计等
TokenService.lock_token(token_info['id'], operator='gateway')
except Exception:
# 防御性忽略单条失败,避免阻塞消费
pass
def _run_loop(self):
r = get_redis()
self._ensure_group(r)
logger.info(f"LockEventsConsumer 启动: stream={self.stream}, group={self.group}, consumer={self.consumer_name}")
while not self._stop_event.is_set():
try:
resp = r.xreadgroup(
self.group,
self.consumer_name,
{self.stream: '>'},
count=self.batch_size,
block=self.block_ms
)
if not resp:
continue
for _, messages in resp:
for msg_id, fields in messages:
try:
self._handle_event(fields)
finally:
try:
r.xack(self.stream, self.group, msg_id)
except Exception:
pass
except Exception as e:
# 防止异常退出,打印并短暂休眠后重试
logger.error(f"LockEventsConsumer 异常: {e}")
time.sleep(1.0)
logger.info("LockEventsConsumer 已停止")
def start(self):
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._thread = threading.Thread(target=self._run_loop, name="LockEventsConsumer", daemon=True)
self._thread.start()
def stop(self, timeout: Optional[float] = 5.0):
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=timeout)
...@@ -8,6 +8,8 @@ from sqlalchemy.orm import Session ...@@ -8,6 +8,8 @@ from sqlalchemy.orm import Session
from sqlalchemy import and_, or_ from sqlalchemy import and_, or_
from app.models.token import TokenIntercept, Token from app.models.token import TokenIntercept, Token
from app.utils.redis_client import get_redis
from config.settings import Config
from app.database import get_db_session, close_db_session from app.database import get_db_session, close_db_session
from app.utils.logger import get_logger from app.utils.logger import get_logger
...@@ -55,6 +57,7 @@ class TokenInterceptService: ...@@ -55,6 +57,7 @@ class TokenInterceptService:
session.commit() session.commit()
logger.info(f"拦截记录创建成功: Token ID {token_id}, IP数量 {len(ip_list)}") logger.info(f"拦截记录创建成功: Token ID {token_id}, IP数量 {len(ip_list)}")
# 可选:推送到其它事件流或更新 Redis 辅助索引
return intercept_record.to_dict() return intercept_record.to_dict()
except Exception as e: except Exception as e:
......
...@@ -9,6 +9,7 @@ from sqlalchemy import and_, or_ ...@@ -9,6 +9,7 @@ from sqlalchemy import and_, or_
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
from app.models.token import Token, TokenIntercept from app.models.token import Token, TokenIntercept
from app.utils.redis_client import get_redis
from app.database import get_db_session, close_db_session from app.database import get_db_session, close_db_session
from app.utils.logger import get_logger from app.utils.logger import get_logger
...@@ -18,6 +19,37 @@ logger = get_logger(__name__) ...@@ -18,6 +19,37 @@ logger = get_logger(__name__)
class TokenService: class TokenService:
"""Token服务类""" """Token服务类"""
@staticmethod
def _is_tmp_locked(token_value: str) -> bool:
"""
判断是否处于临时锁(Redis 风控锁)
"""
try:
r = get_redis()
tmp_lock_key = f"token_lock:{token_value}"
return bool(r.exists(tmp_lock_key))
except Exception as re:
logger.warning(f"检查临时锁失败,按未锁处理: {re}")
return False
@staticmethod
def _apply_effective_lock(token_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""
将 is_locked 扩展为“有效锁定状态”:数据库锁 or Redis 临时锁 其一即可视为锁定。
同时返回 tmp_locked 字段用于前端/调用方可视化区分。
"""
if not token_dict:
return token_dict
token_value = token_dict.get('token_value')
if not token_value:
return token_dict
tmp_locked = TokenService._is_tmp_locked(token_value)
db_locked = bool(token_dict.get('is_locked'))
effective_locked = db_locked or tmp_locked
token_dict['tmp_locked'] = int(tmp_locked)
token_dict['is_locked'] = int(effective_locked)
return token_dict
@staticmethod @staticmethod
def generate_token() -> str: def generate_token() -> str:
"""生成Token值""" """生成Token值"""
...@@ -71,6 +103,24 @@ class TokenService: ...@@ -71,6 +103,24 @@ class TokenService:
session.commit() session.commit()
logger.info(f"Token创建成功: {token_value}") logger.info(f"Token创建成功: {token_value}")
# 同步到 Redis
try:
r = get_redis()
key = f"token:{token_value}"
r.hset(key, mapping={
'id': token.id,
'token_value': token.token_value,
'is_discarded': int(getattr(token, 'is_discarded', False)),
'is_locked': int(getattr(token, 'is_locked', False)),
'lock_until_ts': 0,
'start_ts': int(token.start_time.timestamp()) if token.start_time else 0,
'end_ts': int(token.end_time.timestamp()) if token.end_time else 0,
'credits': token.credits,
'validity_period': token.validity_period,
'version': token.version,
})
except Exception as re:
logger.warning(f"Redis 同步创建Token失败: {re}")
return token.to_dict() return token.to_dict()
except Exception as e: except Exception as e:
...@@ -112,6 +162,10 @@ class TokenService: ...@@ -112,6 +162,10 @@ class TokenService:
if token_value: if token_value:
query = query.filter(Token.token_value.like(f"%{token_value}%")) query = query.filter(Token.token_value.like(f"%{token_value}%"))
else:
# 未进行模糊搜索时,自动过滤已过期的Token
now = datetime.now()
query = query.filter(Token.end_time >= now)
# 计算总数 # 计算总数
total = query.count() total = query.count()
...@@ -121,7 +175,7 @@ class TokenService: ...@@ -121,7 +175,7 @@ class TokenService:
tokens = query.order_by(Token.created_at.desc()).offset(offset).limit(page_size).all() tokens = query.order_by(Token.created_at.desc()).offset(offset).limit(page_size).all()
# 转换为字典列表 # 转换为字典列表
token_list = [token.to_dict() for token in tokens] token_list = [TokenService._apply_effective_lock(token.to_dict()) for token in tokens]
return { return {
'items': token_list, 'items': token_list,
...@@ -147,7 +201,7 @@ class TokenService: ...@@ -147,7 +201,7 @@ class TokenService:
operator: 操作人 operator: 操作人
Returns: Returns:
是否成功 是否成功(幂等:若数据库与临时锁均已是未锁定,则返回True)
""" """
session = get_db_session() session = get_db_session()
try: try:
...@@ -155,10 +209,13 @@ class TokenService: ...@@ -155,10 +209,13 @@ class TokenService:
if not token: if not token:
logger.warning(f"解锁Token失败: Token不存在 - {token_value}") logger.warning(f"解锁Token失败: Token不存在 - {token_value}")
return False return False
if not token.is_locked:
logger.warning(f"解锁Token失败: Token未锁定 - {token_value}") db_locked = bool(token.is_locked)
return False tmp_locked = TokenService._is_tmp_locked(token_value)
# 乐观锁更新
db_unlocked_ok = False
# 仅当数据库为锁定状态时才尝试乐观解锁
if db_locked:
result = session.query(Token).filter( result = session.query(Token).filter(
Token.token_value == token.token_value, Token.token_value == token.token_value,
Token.version == token.version Token.version == token.version
...@@ -170,10 +227,33 @@ class TokenService: ...@@ -170,10 +227,33 @@ class TokenService:
}) })
if result == 0: if result == 0:
logger.warning(f"解锁Token失败: 版本冲突 - {token_value}") logger.warning(f"解锁Token失败: 版本冲突 - {token_value}")
return False # 或返回 VERSION_CONFLICT 错误码 db_unlocked_ok = False
else:
session.commit() session.commit()
logger.info(f"Token解锁成功: {token.token_value}") logger.info(f"Token解锁成功(数据库): {token.token_value}")
db_unlocked_ok = True
# 清理 Redis 临时锁和哈希状态(无论数据库是否锁定都做幂等处理)
redis_cleared = False
try:
r = get_redis()
key = f"token:{token_value}"
# 删除风控临时锁键
tmp_lock_key = f"token_lock:{token_value}"
del_res = r.delete(tmp_lock_key)
redis_cleared = (del_res or 0) > 0 or not tmp_locked
# 同步哈希视图为未锁定;仅在DB成功更新时同步版本号提升
mapping = {'is_locked': 0, 'lock_until_ts': 0}
if db_unlocked_ok:
mapping['version'] = token.version + 1
r.hset(key, mapping=mapping)
except Exception as re:
logger.warning(f"Redis 解锁/同步失败: {re}")
# 成功条件:任一来源原本为锁定并被成功清理;或二者原本均未锁定(幂等)
if (db_locked and db_unlocked_ok) or (tmp_locked and redis_cleared) or (not db_locked and not tmp_locked):
return True return True
return False
except Exception as e: except Exception as e:
session.rollback() session.rollback()
logger.error(f"解锁Token异常: {str(e)}") logger.error(f"解锁Token异常: {str(e)}")
...@@ -195,7 +275,8 @@ class TokenService: ...@@ -195,7 +275,8 @@ class TokenService:
session = get_db_session() session = get_db_session()
try: try:
token = session.query(Token).filter(Token.id == token_id).first() token = session.query(Token).filter(Token.id == token_id).first()
return token.to_dict() if token else None token_dict = token.to_dict() if token else None
return TokenService._apply_effective_lock(token_dict)
except Exception as e: except Exception as e:
logger.error(f"获取Token信息异常: {str(e)}") logger.error(f"获取Token信息异常: {str(e)}")
return None return None
...@@ -216,7 +297,8 @@ class TokenService: ...@@ -216,7 +297,8 @@ class TokenService:
session = get_db_session() session = get_db_session()
try: try:
token = session.query(Token).filter(Token.token_value == token_value).first() token = session.query(Token).filter(Token.token_value == token_value).first()
return token.to_dict() if token else None token_dict = token.to_dict() if token else None
return TokenService._apply_effective_lock(token_dict)
except Exception as e: except Exception as e:
logger.error(f"获取Token信息异常: {str(e)}") logger.error(f"获取Token信息异常: {str(e)}")
return None return None
...@@ -259,6 +341,13 @@ class TokenService: ...@@ -259,6 +341,13 @@ class TokenService:
return False # 或返回 VERSION_CONFLICT 错误码 return False # 或返回 VERSION_CONFLICT 错误码
session.commit() session.commit()
logger.info(f"Token锁定成功: {token.token_value}") logger.info(f"Token锁定成功: {token.token_value}")
# 同步 Redis: is_locked=1,可选设置 lock_until_ts(若有策略)
try:
r = get_redis()
key = f"token:{token.token_value}"
r.hset(key, mapping={'is_locked': 1, 'version': token.version + 1})
except Exception as re:
logger.warning(f"Redis 同步锁定失败: {re}")
return True return True
except Exception as e: except Exception as e:
session.rollback() session.rollback()
...@@ -267,6 +356,7 @@ class TokenService: ...@@ -267,6 +356,7 @@ class TokenService:
finally: finally:
close_db_session() close_db_session()
@staticmethod @staticmethod
def discard_token(token_value: str, operator: str) -> bool: def discard_token(token_value: str, operator: str) -> bool:
""" """
...@@ -299,6 +389,13 @@ class TokenService: ...@@ -299,6 +389,13 @@ class TokenService:
return False return False
session.commit() session.commit()
logger.info(f"Token废弃成功: {token.token_value}") logger.info(f"Token废弃成功: {token.token_value}")
# 同步 Redis
try:
r = get_redis()
key = f"token:{token_value}"
r.hset(key, mapping={'is_discarded': 1, 'version': token.version + 1})
except Exception as re:
logger.warning(f"Redis 同步废弃失败: {re}")
return True return True
except Exception as e: except Exception as e:
session.rollback() session.rollback()
...@@ -348,6 +445,13 @@ class TokenService: ...@@ -348,6 +445,13 @@ class TokenService:
return False return False
session.commit() session.commit()
logger.info(f"Token到期日修改成功: {token.token_value} -> {end_date}") logger.info(f"Token到期日修改成功: {token.token_value} -> {end_date}")
# 同步 Redis
try:
r = get_redis()
key = f"token:{token_value}"
r.hset(key, mapping={'end_ts': int(new_end_time.timestamp()), 'validity_period': new_validity_period, 'version': token.version + 1})
except Exception as re:
logger.warning(f"Redis 同步到期日失败: {re}")
return True return True
except Exception as e: except Exception as e:
session.rollback() session.rollback()
......
from typing import Optional
from redis import Redis
from config.settings import Config
_redis_client: Optional[Redis] = None
def get_redis() -> Redis:
global _redis_client
if _redis_client is None:
_redis_client = Redis(
host=Config.REDIS_HOST,
port=Config.REDIS_PORT,
db=Config.REDIS_DB,
password=Config.REDIS_PASSWORD,
decode_responses=True,
)
return _redis_client
...@@ -47,6 +47,18 @@ class Config: ...@@ -47,6 +47,18 @@ class Config:
# CORS配置 # CORS配置
CORS_ORIGINS = os.environ.get('CORS_ORIGINS') or '*' CORS_ORIGINS = os.environ.get('CORS_ORIGINS') or '*'
# Redis 配置
REDIS_HOST = os.environ.get('REDIS_HOST') or '127.0.0.1'
REDIS_PORT = int(os.environ.get('REDIS_PORT') or 6379)
REDIS_DB = int(os.environ.get('REDIS_DB') or 0)
REDIS_PASSWORD = os.environ.get('REDIS_PASSWORD') or None
# Redis Streams 配置
REDIS_STREAM_LOCK_EVENTS = os.environ.get('REDIS_STREAM_LOCK_EVENTS') or 'token:lock_events'
REDIS_STREAM_GROUP = os.environ.get('REDIS_STREAM_GROUP') or 'lock_consumer_group'
# 后台消费者自动启动开关
AUTO_START_LOCK_EVENTS_CONSUMER = (os.environ.get('AUTO_START_LOCK_EVENTS_CONSUMER', 'true').lower() == 'true')
class DevelopmentConfig(Config): class DevelopmentConfig(Config):
"""开发环境配置""" """开发环境配置"""
DEBUG = True DEBUG = True
......
...@@ -12,3 +12,4 @@ SQLAlchemy>=2.0.25 ...@@ -12,3 +12,4 @@ SQLAlchemy>=2.0.25
PyMySQL==1.1.0 PyMySQL==1.1.0
bcrypt==4.1.2 bcrypt==4.1.2
python-dateutil>=2.8.2 python-dateutil>=2.8.2
redis==5.0.4
\ No newline at end of file