Skip to content
Commits on Source (2)
...@@ -269,6 +269,12 @@ class TestTuDataComprehensive(unittest.TestCase): ...@@ -269,6 +269,12 @@ class TestTuDataComprehensive(unittest.TestCase):
result = self._mock_api_response('top_inst', trade_date='20240101') result = self._mock_api_response('top_inst', trade_date='20240101')
self._validate_dataframe_result(result) self._validate_dataframe_result(result)
def test_stock_st(self):
"""Test stock_st method"""
print(" ⚠️ 测试ST股票标记接口")
result = self._mock_api_response('stock_st', ts_code='000001.SZ')
self._validate_dataframe_result(result)
# Financial Statement Methods Tests # Financial Statement Methods Tests
def test_income(self): def test_income(self):
"""Test income method""" """Test income method"""
...@@ -434,6 +440,12 @@ class TestTuDataComprehensive(unittest.TestCase): ...@@ -434,6 +440,12 @@ class TestTuDataComprehensive(unittest.TestCase):
result = self._mock_api_response('index_weight', index_code='000001.SH') result = self._mock_api_response('index_weight', index_code='000001.SH')
self._validate_dataframe_result(result) self._validate_dataframe_result(result)
def test_ci_index_member(self):
"""Test ci_index_member method"""
print(" 🧩 测试指数成分成员接口")
result = self._mock_api_response('ci_index_member', index_code='000001.SH')
self._validate_dataframe_result(result)
def test_index_dailybasic(self): def test_index_dailybasic(self):
"""Test index_dailybasic method""" """Test index_dailybasic method"""
print(" 📊 测试指数每日指标接口") print(" 📊 测试指数每日指标接口")
......
...@@ -840,6 +840,10 @@ class pro_api: ...@@ -840,6 +840,10 @@ class pro_api:
def etf_index(self, api_name='etf_index', **kwargs): def etf_index(self, api_name='etf_index', **kwargs):
return self.query(token=self.token, api_name=api_name, **kwargs) return self.query(token=self.token, api_name=api_name, **kwargs)
def ci_index_member(self, api_name='ci_index_member', **kwargs):
return self.query(token=self.token, api_name=api_name, **kwargs)
def stock_st(self, api_name='stock_st', **kwargs):
return self.query(token=self.token, api_name=api_name, **kwargs)
import pandas as pd import pandas as pd
import os import os
......
// 环境配置 // 环境配置
export const environment = { export const environment = {
// API基础地址 // API基础地址
API_BASE_URL: 'http://localhost:7777', API_BASE_URL: 'http://114.132.244.63/token-tushare',
// Tushare服务地址 // Tushare服务地址
TUSHARE_API_URL: 'http://localhost:8001' TUSHARE_API_URL: 'http://114.132.244.63:8000'
} }
export default environment; export default environment;
\ No newline at end of file
...@@ -72,7 +72,8 @@ const formData = reactive({ ...@@ -72,7 +72,8 @@ const formData = reactive({
const loading = ref(false) const loading = ref(false)
const API_BASE_URL = import.meta.env.VITE_API_BASE_URL import { environment } from '../config/environment'
const API_BASE_URL = environment.API_BASE_URL
onMounted(() => { onMounted(() => {
checkSavedLogin() checkSavedLogin()
......
from fastapi import APIRouter, Request, Response, Header, Body from fastapi import APIRouter, Request, Response, Header, Body
from starlette.concurrency import run_in_threadpool
from app.service import tushare_funet from app.service import tushare_funet
from app.services import TokenService from app.services import TokenService
import inspect import inspect
from datetime import datetime, timedelta from datetime import datetime, timedelta
from app.utils.logger import get_logger from app.utils.logger import get_logger
from app.utils.redis_bus import publish_event
router = APIRouter(prefix="/tushare", tags=["tushare"]) router = APIRouter(prefix="/tushare", tags=["tushare"])
...@@ -93,24 +95,30 @@ async def pro_bar_view(request: Request): ...@@ -93,24 +95,30 @@ async def pro_bar_view(request: Request):
t3 = time.time() t3 = time.time()
if not ok: if not ok:
logger.error(f"[pro_bar_view] token check failed, total: {time.time() - start_time:.4f}s, body: {t2-t1:.4f}s, check_token: {t3-t2:.4f}s, msg: {msg}") await logger.aerror(f"[pro_bar_view] token check failed, token={token}, total={time.time()-start_time:.4f}s")
return Response(content=msg, status_code=401) return Response(content=msg, status_code=401)
t4 = time.time() t4 = time.time()
try: try:
resp = tushare_funet.pro_bar(**body) # 优先异步 httpx 实现
if hasattr(tushare_funet, 'pro_bar_async'):
resp = await tushare_funet.pro_bar_async(**body)
else:
resp = await run_in_threadpool(tushare_funet.pro_bar, **body)
t5 = time.time() t5 = time.time()
if hasattr(resp, "status_code") and hasattr(resp, "content"): if hasattr(resp, "status_code") and hasattr(resp, "content"):
await logger.ainfo(f"[pro_bar_view] finished, token={token}, method=pro_bar, total={time.time()-start_time:.4f}s")
return Response(content=resp.content, status_code=resp.status_code, headers=dict(resp.headers), media_type=resp.headers.get("content-type", None)) return Response(content=resp.content, status_code=resp.status_code, headers=dict(resp.headers), media_type=resp.headers.get("content-type", None))
await logger.ainfo(f"[pro_bar_view] finished, token={token}, method=pro_bar, total={time.time()-start_time:.4f}s")
return resp return resp
except Exception as e: except Exception as e:
logger.error(f"[pro_bar_view] pro_bar call exception, total: {time.time() - start_time:.4f}s, body: {t2-t1:.4f}s, check_token: {t3-t2:.4f}s, get_params: {t4-t3:.4f}s, exception: {str(e)}") await logger.aerror(f"[pro_bar_view] exception={str(e)}, token={token}, method=pro_bar, total={time.time()-start_time:.4f}s")
return Response(content=str(e), status_code=500) return Response(content=str(e), status_code=500)
except Exception as e: except Exception as e:
logger.error(f"[pro_bar_view] request processing exception, total: {time.time() - start_time:.4f}s, exception: {str(e)}") await logger.aerror(f"[pro_bar_view] request exception={str(e)}, total={time.time()-start_time:.4f}s")
return Response(content=str(e), status_code=500) return Response(content=str(e), status_code=500)
@router.post("") @router.post("")
...@@ -125,29 +133,30 @@ async def tushare_entry(request: Request): ...@@ -125,29 +133,30 @@ async def tushare_entry(request: Request):
ok, msg = check_token(token, client_ip) ok, msg = check_token(token, client_ip)
t3 = time.time() t3 = time.time()
if not ok: if not ok:
logger.info(f"[tushare_entry] token check failed, total: {time.time() - start_time:.4f}s, body: {t2-t1:.4f}s, check_token: {t3-t2:.4f}s") await logger.ainfo(f"[tushare_entry] token check failed, token={token}, total={time.time()-start_time:.4f}s")
return Response(content=msg, status_code=401) return Response(content=msg, status_code=401)
api_name = body.get("api_name") api_name = body.get("api_name")
t4 = time.time() t4 = time.time()
if not api_name: if not api_name:
logger.info(f"[tushare_entry] api_name missing, total: {time.time() - start_time:.4f}s, body: {t2-t1:.4f}s, check_token: {t3-t2:.4f}s, api_name: {t4-t3:.4f}s") await logger.ainfo(f"[tushare_entry] api_name missing, token={token}, total={time.time()-start_time:.4f}s")
return {"success": False, "msg": "api_name 不能为空"} return {"success": False, "msg": "api_name 不能为空"}
# 动态分发 # 动态分发
if not hasattr(pro, api_name): if not hasattr(pro, api_name):
logger.info(f"[tushare_entry] api_name not supported, total: {time.time() - start_time:.4f}s, body: {t2-t1:.4f}s, check_token: {t3-t2:.4f}s, api_name: {t4-t3:.4f}s") await logger.ainfo(f"[tushare_entry] api_name not supported, token={token}, method={api_name}, total={time.time()-start_time:.4f}s")
return {"success": False, "msg": f"不支持的api_name: {api_name}"} return {"success": False, "msg": f"不支持的api_name: {api_name}"}
method = getattr(pro, api_name) method = getattr(pro, api_name)
t5 = time.time() t5 = time.time()
try: try:
resp = method(**body) # 统一走对应方法(方法内部会调用 query),放入线程池避免阻塞
resp = await run_in_threadpool(method, **body)
t6 = time.time() t6 = time.time()
if hasattr(resp, "status_code") and hasattr(resp, "content"): if hasattr(resp, "status_code") and hasattr(resp, "content"):
logger.info(f"[tushare_entry] finished, total: {time.time() - start_time:.4f}s, body: {t2-t1:.4f}s, check_token: {t3-t2:.4f}s, api_name: {t4-t3:.4f}s, get_method: {t5-t4:.4f}s, method_call: {t6-t5:.4f}s, response: {time.time()-t6:.4f}s") await logger.ainfo(f"[tushare_entry] finished, token={token}, method={api_name}, total={time.time()-start_time:.4f}s")
return Response(content=resp.content, status_code=resp.status_code, headers=dict(resp.headers), media_type=resp.headers.get("content-type", None)) return Response(content=resp.content, status_code=resp.status_code, headers=dict(resp.headers), media_type=resp.headers.get("content-type", None))
logger.info(f"[tushare_entry] finished, total: {time.time() - start_time:.4f}s, body: {t2-t1:.4f}s, check_token: {t3-t2:.4f}s, api_name: {t4-t3:.4f}s, get_method: {t5-t4:.4f}s, method_call: {t6-t5:.4f}s, response: {time.time()-t6:.4f}s") await logger.ainfo(f"[tushare_entry] finished, token={token}, method={api_name}, total={time.time()-start_time:.4f}s")
return resp return resp
except Exception as e: except Exception as e:
logger.error(f"[tushare_entry] exception, total: {time.time() - start_time:.4f}s, body: {t2-t1:.4f}s, check_token: {t3-t2:.4f}s, api_name: {t4-t3:.4f}s, get_method: {t5-t4:.4f}s, exception: {str(e)}") await logger.aerror(f"[tushare_entry] exception={str(e)}, token={token}, method={api_name}, total={time.time()-start_time:.4f}s")
return Response(content=str(e), status_code=500) return Response(content=str(e), status_code=500)
@router.post("/unlock_token") @router.post("/unlock_token")
...@@ -162,6 +171,8 @@ async def unlock_token_api(data: dict = Body(...)): ...@@ -162,6 +171,8 @@ async def unlock_token_api(data: dict = Body(...)):
token_info['is_locked'] = False token_info['is_locked'] = False
if token in TOKEN_IP_MAP: if token in TOKEN_IP_MAP:
TOKEN_IP_MAP[token]['locked_at'] = None TOKEN_IP_MAP[token]['locked_at'] = None
# 通知其他进程刷新该 token
publish_event({"type": "token_unlock", "token": token})
return {"success": True, "msg": "内存解锁成功"} return {"success": True, "msg": "内存解锁成功"}
@router.post("/add_token") @router.post("/add_token")
...@@ -173,6 +184,7 @@ async def add_token_api(data: dict = Body(...)): ...@@ -173,6 +184,7 @@ async def add_token_api(data: dict = Body(...)):
if not token_info: if not token_info:
return {"success": False, "msg": "数据库中未找到该 token"} return {"success": False, "msg": "数据库中未找到该 token"}
ALL_TOKENS[token_value] = token_info ALL_TOKENS[token_value] = token_info
publish_event({"type": "token_add", "token": token_value})
return {"success": True, "msg": "token 已添加到内存", "token_value": token_value} return {"success": True, "msg": "token 已添加到内存", "token_value": token_value}
@router.post("/remove_token") @router.post("/remove_token")
...@@ -187,4 +199,5 @@ async def remove_token_api(data: dict = Body(...)): ...@@ -187,4 +199,5 @@ async def remove_token_api(data: dict = Body(...)):
TOKEN_IP_MAP.pop(token_value, None) TOKEN_IP_MAP.pop(token_value, None)
if token_value in LOCKED_TOKENS: if token_value in LOCKED_TOKENS:
LOCKED_TOKENS.discard(token_value) LOCKED_TOKENS.discard(token_value)
publish_event({"type": "token_remove", "token": token_value})
return {"success": True, "msg": "token 已从内存移除", "token_value": token_value} return {"success": True, "msg": "token 已从内存移除", "token_value": token_value}
\ No newline at end of file
...@@ -3,6 +3,8 @@ from fastapi.middleware.cors import CORSMiddleware ...@@ -3,6 +3,8 @@ from fastapi.middleware.cors import CORSMiddleware
from app.api import tushare from app.api import tushare
from app.database import db from app.database import db
from config.settings import Config from config.settings import Config
from app.utils.redis_bus import start_subscriber
from app.utils.logger import get_logger
# 启动时初始化数据库 # 启动时初始化数据库
class DummyApp: class DummyApp:
...@@ -33,6 +35,41 @@ app.include_router(tushare.router) ...@@ -33,6 +35,41 @@ app.include_router(tushare.router)
def startup_event(): def startup_event():
from app.api.tushare import load_all_tokens from app.api.tushare import load_all_tokens
load_all_tokens() load_all_tokens()
# 启动 Redis 订阅,跨进程同步内存 token 状态
rlogger = get_logger("redis_sync").logger
def _on_event(evt: dict):
try:
et = evt.get("type")
tv = evt.get("token")
if not tv:
return
rlogger.info(f"redis handle start type={et} token={tv}")
# 直接操作内存字典,避免 DB
if et == "token_remove":
from app.api.tushare import ALL_TOKENS, TOKEN_IP_MAP, LOCKED_TOKENS
ALL_TOKENS.pop(tv, None)
TOKEN_IP_MAP.pop(tv, None)
if tv in LOCKED_TOKENS:
LOCKED_TOKENS.discard(tv)
rlogger.info(f"redis handled token_remove token={tv}")
elif et == "token_add":
from app.services import TokenService as Svc
info = Svc.get_token_by_value(tv)
if info:
from app.api.tushare import ALL_TOKENS
ALL_TOKENS[tv] = info
rlogger.info(f"redis handled token_add token={tv} found={bool(info)}")
elif et == "token_unlock":
from app.api.tushare import ALL_TOKENS, TOKEN_IP_MAP
info = ALL_TOKENS.get(tv)
if info:
info['is_locked'] = False
if tv in TOKEN_IP_MAP:
TOKEN_IP_MAP[tv]['locked_at'] = None
rlogger.info(f"redis handled token_unlock token={tv}")
except Exception:
pass
start_subscriber(_on_event)
@app.get("/") @app.get("/")
def read_root(): def read_root():
......
import requests import requests
import requests.adapters
import pandas as pd import pandas as pd
import time import time
import logging import logging
...@@ -8,6 +9,21 @@ from app.service.config import get_tushare_token ...@@ -8,6 +9,21 @@ from app.service.config import get_tushare_token
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 创建长连接会话与连接池,减少 TIME_WAIT
_session = requests.Session()
_adapter = requests.adapters.HTTPAdapter(pool_connections=20, pool_maxsize=100, max_retries=0)
_session.mount('http://', _adapter)
_session.mount('https://', _adapter)
_session.headers.update({'Connection': 'keep-alive'})
# async HTTP client
try:
import httpx
_async_client = httpx.AsyncClient(timeout=30, headers={'Connection': 'keep-alive'})
except Exception:
httpx = None
_async_client = None
class pro_api: class pro_api:
# def __init__(self, token): # def __init__(self, token):
# self.token = token # 实例变量 name # self.token = token # 实例变量 name
...@@ -50,26 +66,29 @@ class pro_api: ...@@ -50,26 +66,29 @@ class pro_api:
'fields': fields_data, 'fields': fields_data,
} }
# # print(f"=== tushare_funet query 请求开始 ===") response = _session.post(url, json=params, timeout=30)
# # print(f"请求URL: {url}") return response
# # print(f"API名称: {api_name}")
# # print(f"Param参数: {params}")
# # print(f"kwargs参数: {kwargs}")
response = requests.post(url, json=params, timeout=30)
# # print(f"=== tushare_funet query 响应信息 ===")
# # print(f"响应状态码: {response.status_code}")
# # print(f"响应内容长度: {len(response.text)}")
# # print(f"响应头: {dict(response.headers)}")
# # print(f"响应内容前500字符: {response.text[:500]}")
# if response.status_code != 200:
# print(f"响应错误内容: {response.text}")
# # print(f"=== tushare_funet query 请求结束 ===") async def async_query(self, api_name, fields='', **kwargs):
"""
Async query via httpx.AsyncClient for general tushare API calls.
return response Returns:
httpx.Response when using async client, or requests.Response fallback
"""
url = "http://120.53.122.167:9002/tq"
params_data = kwargs.get('params', {})
fields_data = kwargs.get('fields', fields)
payload = {
'token': self.token,
'api_name': api_name,
'params': params_data,
'fields': fields_data,
}
if _async_client is None:
return _session.post(url, json=payload, timeout=30)
resp = await _async_client.post(url, json=payload)
return resp
...@@ -757,6 +776,12 @@ class pro_api: ...@@ -757,6 +776,12 @@ class pro_api:
return self.query(api_name=api_name, **kwargs) return self.query(api_name=api_name, **kwargs)
def etf_index(self, api_name='etf_index', **kwargs): def etf_index(self, api_name='etf_index', **kwargs):
return self.query(api_name=api_name, **kwargs) return self.query(api_name=api_name, **kwargs)
def ci_index_member(self, api_name='ci_index_member', **kwargs):
return self.query(api_name=api_name, **kwargs)
def stock_st(self, api_name='stock_st', **kwargs):
return self.query(api_name=api_name, **kwargs)
import pandas as pd import pandas as pd
import os import os
...@@ -851,5 +876,45 @@ def pro_bar(ts_code='', api=None, start_date='', end_date='', freq='D', asset='E ...@@ -851,5 +876,45 @@ def pro_bar(ts_code='', api=None, start_date='', end_date='', freq='D', asset='E
return '此接口为单独权限,和积分没有关系,需要单独购买' return '此接口为单独权限,和积分没有关系,需要单独购买'
else: else:
response = requests.post(url, json=params,) response = _session.post(url, json=params,)
return response return response
async def pro_bar_async(ts_code='', api=None, start_date='', end_date='', freq='D', asset='E',
exchange='',
adj = None,
ma = [],
factors = None,
adjfactor = False,
offset = None,
limit = None,
fields = '',
contract_type = '', token=None):
"""
Async version of pro_bar, using httpx.AsyncClient
"""
url = "http://120.53.122.167:9002/tp"
params = {
'token': get_tushare_token(),
'ts_code':ts_code,
'api':api,
'start_date':start_date,
'end_date':end_date,
'freq':freq,
'asset':asset,
'exchange':exchange,
'adj' :adj,
'ma' :ma,
"factors" : factors,
"adjfactor" : adjfactor,
"offset" : offset,
"limit" :limit,
"fields" : fields,
"contract_type" : contract_type
}
if 'min' in freq:
return '此接口为单独权限,和积分没有关系,需要单独购买'
if _async_client is None:
# fallback to sync session in thread if httpx is unavailable
return _session.post(url, json=params)
resp = await _async_client.post(url, json=params)
return resp
\ No newline at end of file
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
import os import os
import logging import logging
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler, QueueHandler, QueueListener
import queue
from datetime import datetime from datetime import datetime
def setup_logging(app): def setup_logging(app):
...@@ -24,29 +25,40 @@ def setup_logging(app): ...@@ -24,29 +25,40 @@ def setup_logging(app):
'%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s' '%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s'
) )
# 文件处理器 - 按大小轮转 # 文件/控制台处理器(用于后台监听线程)
file_handler = RotatingFileHandler( file_handler = RotatingFileHandler(log_file,
log_file, maxBytes=app.config.get('LOG_MAX_BYTES', 10*1024*1024),
maxBytes=app.config.get('LOG_MAX_BYTES', 10*1024*1024), # 10MB
backupCount=app.config.get('LOG_BACKUP_COUNT', 5), backupCount=app.config.get('LOG_BACKUP_COUNT', 5),
encoding='utf-8' encoding='utf-8')
)
file_handler.setFormatter(log_format) file_handler.setFormatter(log_format)
file_handler.setLevel(getattr(logging, app.config.get('LOG_LEVEL', 'INFO'))) file_handler.setLevel(getattr(logging, app.config.get('LOG_LEVEL', 'INFO')))
# 控制台处理器
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setFormatter(log_format) console_handler.setFormatter(log_format)
console_handler.setLevel(logging.DEBUG) console_handler.setLevel(logging.INFO)
# 配置根日志 # 队列 + 监听器,实现非阻塞异步日志
root_logger = logging.getLogger() log_queue = queue.SimpleQueue()
root_logger.setLevel(logging.DEBUG) queue_handler = QueueHandler(log_queue)
# 避免重复添加处理器 root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
# 避免重复添加
if not root_logger.handlers: if not root_logger.handlers:
root_logger.addHandler(file_handler) root_logger.addHandler(queue_handler)
root_logger.addHandler(console_handler) listener = QueueListener(log_queue, file_handler, console_handler)
listener.daemon = True
listener.start()
# 静音第三方库的 DEBUG 噪音
for noisy in [
'urllib3', 'urllib3.connectionpool',
'httpx', 'httpcore',
'sqlalchemy.engine',
]:
nlog = logging.getLogger(noisy)
nlog.setLevel(logging.WARNING)
nlog.propagate = True
# 创建应用日志器 # 创建应用日志器
app_logger = logging.getLogger('tushare_web') app_logger = logging.getLogger('tushare_web')
...@@ -65,42 +77,36 @@ class Logger: ...@@ -65,42 +77,36 @@ class Logger:
def _setup_logger(self): def _setup_logger(self):
"""设置日志器""" """设置日志器"""
# 创建logs目录
if self.log_file:
log_dir = os.path.dirname(self.log_file)
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir)
# 创建日志器 # 创建日志器
self.logger = logging.getLogger(self.name) self.logger = logging.getLogger(self.name)
self.logger.setLevel(logging.DEBUG) self.logger.setLevel(logging.DEBUG)
# 避免重复添加处理器 # 若根日志器尚未配置任何处理器,则在此完成异步日志配置(默认到项目 logs 目录)
if self.logger.handlers: root_logger = logging.getLogger()
return if not root_logger.handlers:
# 计算默认日志目录:项目根的 logs
# 日志格式 base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
formatter = logging.Formatter( default_log_dir = os.path.join(base_dir, 'logs')
'%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s' os.makedirs(default_log_dir, exist_ok=True)
) log_file = self.log_file or os.path.join(default_log_dir, 'tushare_web.log')
# 文件处理器 - 按大小轮转 fmt = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s')
if self.log_file: file_handler = RotatingFileHandler(log_file, maxBytes=10*1024*1024, backupCount=5, encoding='utf-8')
file_handler = RotatingFileHandler( file_handler.setFormatter(fmt)
self.log_file,
maxBytes=10*1024*1024, # 10MB
backupCount=5,
encoding='utf-8'
)
file_handler.setFormatter(formatter)
file_handler.setLevel(logging.INFO) file_handler.setLevel(logging.INFO)
self.logger.addHandler(file_handler)
# 控制台处理器
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter) console_handler.setFormatter(fmt)
console_handler.setLevel(logging.DEBUG) console_handler.setLevel(logging.DEBUG)
self.logger.addHandler(console_handler)
q = queue.SimpleQueue()
qh = QueueHandler(q)
root_logger.setLevel(logging.DEBUG)
root_logger.addHandler(qh)
listener = QueueListener(q, file_handler, console_handler)
listener.daemon = True
listener.start()
# 子日志器不再加额外处理器,复用根日志器的 QueueHandler
def debug(self, message): def debug(self, message):
"""调试日志""" """调试日志"""
...@@ -122,6 +128,18 @@ class Logger: ...@@ -122,6 +128,18 @@ class Logger:
"""严重错误日志""" """严重错误日志"""
self.logger.critical(message) self.logger.critical(message)
# ----- 异步快捷方法(可 await)-----
async def adebug(self, message):
self.logger.debug(message)
async def ainfo(self, message):
self.logger.info(message)
async def awarning(self, message):
self.logger.warning(message)
async def aerror(self, message):
self.logger.error(message)
async def acritical(self, message):
self.logger.critical(message)
def log_request(self, request): def log_request(self, request):
"""记录请求日志""" """记录请求日志"""
self.info(f"请求: {request.method} {request.url} - IP: {request.remote_addr}") self.info(f"请求: {request.method} {request.url} - IP: {request.remote_addr}")
......
import json
import threading
from typing import Callable, Optional
import logging
from config.settings import Config
import time
try:
import redis
except Exception: # pragma: no cover
redis = None
_logger = logging.getLogger("redis_bus")
def _get_redis_client() -> Optional["redis.Redis"]:
if redis is None:
_logger.warning("redis library not installed; pub/sub disabled")
return None
# 统一从 settings.Config 读取配置,避免模块内直接访问环境变量
url = getattr(Config, "REDIS_URL", None)
if url:
return redis.from_url(url, decode_responses=True)
host = getattr(Config, "REDIS_HOST", "127.0.0.1")
port = int(getattr(Config, "REDIS_PORT", 6379))
db = int(getattr(Config, "REDIS_DB", 0))
return redis.Redis(host=host, port=port, db=db, decode_responses=True)
CHANNEL = "tushare_token_events"
def publish_event(event: dict) -> None:
client = _get_redis_client()
if not client:
return
try:
client.publish(CHANNEL, json.dumps(event))
_logger.info(f"redis publish success channel={CHANNEL} event={event}")
except Exception as e: # pragma: no cover
_logger.warning(f"redis publish failed: {e}")
def start_subscriber(callback: Callable[[dict], None]) -> Optional[threading.Thread]:
client = _get_redis_client()
if not client:
return None
def _loop():
# 永久循环,网络中断时自动重连
while True:
try:
local_client = _get_redis_client()
if not local_client:
time.sleep(2)
continue
pubsub = local_client.pubsub()
pubsub.subscribe(CHANNEL)
_logger.info(f"redis subscribed channel={CHANNEL}")
for msg in pubsub.listen():
if msg.get("type") != "message":
continue
try:
data = json.loads(msg.get("data", "{}"))
_logger.info(f"redis received event channel={CHANNEL} event={data}")
callback(data)
except Exception as e: # pragma: no cover
_logger.warning(f"redis message handling failed: {e}")
except Exception as e: # 连接错误,重连
_logger.warning(f"redis subscriber error: {e}, reconnecting in 1s")
time.sleep(1)
th = threading.Thread(target=_loop, name="redis-subscriber", daemon=True)
th.start()
return th
...@@ -37,6 +37,12 @@ class Config: ...@@ -37,6 +37,12 @@ class Config:
# CORS配置 # CORS配置
CORS_ORIGINS = os.environ.get('CORS_ORIGINS') or '*' CORS_ORIGINS = os.environ.get('CORS_ORIGINS') or '*'
# Redis 配置(用于跨进程事件同步等)
REDIS_URL = os.environ.get('REDIS_URL')
REDIS_HOST = os.environ.get('REDIS_HOST') or '127.0.0.1'
REDIS_PORT = int(os.environ.get('REDIS_PORT') or 6379)
REDIS_DB = int(os.environ.get('REDIS_DB') or 0)
class DevelopmentConfig(Config): class DevelopmentConfig(Config):
"""开发环境配置""" """开发环境配置"""
DEBUG = True DEBUG = True
......
#!/bin/bash
# 配置参数
APP_DIR="tushare-web-api" # 源代码目录(构建后的目录)
NEW_DIR="${APP_DIR}1" # 生产环境目标目录(字符串连接:APP_DIR + 1)
# 备份文件夹增加时间后缀(YYYYMMDDHHSS)
TIMESTAMP=$(date +%Y%m%d%H%S)
BACKUP_DIR="${NEW_DIR}-backup-${TIMESTAMP}"
SERVICE_NAME="tushareweb1" # Supervisor中的服务名(根据实际情况修改)
# 备份现有生产目录 NEW_DIR
echo "开始备份生产目录 $NEW_DIR$BACKUP_DIR..."
if [ -d "$NEW_DIR" ]; then
cp -r "$NEW_DIR" "$BACKUP_DIR"
if [ $? -eq 0 ]; then
echo "备份完成"
else
echo "错误:备份 $NEW_DIR 失败"
exit 1
fi
else
echo "警告:未找到 $NEW_DIR 目录,跳过备份"
fi
# 复制 APP_DIR 到 APP_DIR + 1(字符串连接)
echo "开始复制 $APP_DIR$NEW_DIR ..."
if [ ! -d "$APP_DIR" ]; then
echo "错误:源目录 $APP_DIR 不存在"
exit 1
fi
# 如目标目录已存在,执行覆盖式复制(不删除目录,覆盖同名文件)
if [ -d "$NEW_DIR" ]; then
echo "目标目录已存在,执行覆盖式复制..."
cp -r "$APP_DIR"/. "$NEW_DIR"/
RC=$?
else
cp -r "$APP_DIR" "$NEW_DIR"
RC=$?
fi
if [ $RC -eq 0 ]; then
echo "复制/覆盖完成:$APP_DIR -> $NEW_DIR"
else
echo "错误:复制 $APP_DIR$NEW_DIR 失败"
exit 1
fi
# 重启服务
echo "重启 $SERVICE_NAME 服务..."
supervisorctl restart "$SERVICE_NAME"
if [ $? -eq 0 ]; then
echo "服务重启成功"
else
echo "警告:服务重启可能失败,请检查服务状态"
exit 1
fi
echo "部署完成"
exit 0
...@@ -4,6 +4,8 @@ pydantic ...@@ -4,6 +4,8 @@ pydantic
sqlalchemy sqlalchemy
pymysql pymysql
requests requests
httpx>=0.24
redis>=4.5
pandas pandas
python-dateutil python-dateutil
python-dotenv python-dotenv
......
#!/bin/bash #!/bin/bash
set -e
cd "$(dirname "$0")" cd "$(dirname "$0")"
source /home/leewcc/tushare-web-api/myenv/bin/activate source /home/leewcc/tushare-web-api1/myenv/bin/activate
pip install -r requirements.txt pip install -r requirements.txt
exec uvicorn app.main:app --host 0.0.0.0 --port 8000
# 支持外部传入端口与进程数(workers),默认 8000 / 1
PORT="${PORT:-${1:-8000}}"
WORKERS="${WORKERS:-${2:-1}}"
# 使用 exec 让 uvicorn 取代当前进程,便于 Supervisor 正确管理
exec uvicorn app.main:app --host 0.0.0.0 --port "$PORT" --workers "$WORKERS"
\ No newline at end of file
#!/bin/bash #!/bin/bash
set -e # 出错时立即退出,避免无效执行
cd "$(dirname "$0")" cd "$(dirname "$0")"
source /home/leewcc/tushare-web-back/vene/bin/activate
pip install -r requirements.txt # 仅在虚拟环境不存在时才创建(避免重复创建)
if [ ! -d "vene" ]; then
python3 -m venv vene
fi
# 激活虚拟环境
source ./vene/bin/activate # 用相对路径更可靠
# 仅在依赖未安装或有更新时手动执行,注释掉自动安装
# pip install -r requirements.txt
# 设置环境变量
export FLASK_APP=run.py export FLASK_APP=run.py
export FLASK_ENV=production export FLASK_ENV=production
flask run --host=0.0.0.0 --port=7777
# 启动命令:使用 exec 确保进程替换,便于后续管理
exec flask run --host=0.0.0.0 --port=7777
\ No newline at end of file