|
7 | 7 | from loguru import logger |
8 | 8 | from ...schemas.responses import response_success, response_fail |
9 | 9 | from data_engine.utils.cache_utils import DATA_JUICER_MODELS_CACHE |
| 10 | +from typing import Optional |
10 | 11 |
|
11 | 12 | router = APIRouter() |
12 | 13 |
|
@@ -55,7 +56,7 @@ def check_model_for_md_to_jsonl( |
55 | 56 |
|
56 | 57 | # Get the first matching model information |
57 | 58 | model_info = data['data'][0] |
58 | | - model_name_from_api = model_info.get('name', model_name) # Get name field from API response |
| 59 | + model_name_from_api = model_info.get('path', model_name) # Get path field from API response |
59 | 60 | repository = model_info.get('repository', {}) |
60 | 61 | http_clone_url = repository.get('http_clone_url') |
61 | 62 |
|
@@ -219,3 +220,87 @@ def check_model_for_md_to_jsonl( |
219 | 220 | except Exception as e: |
220 | 221 | logger.error(f'Error checking model: {str(e)}') |
221 | 222 | return response_fail(msg=f'检查模型时发生错误: {str(e)}') |
| 223 | + |
| 224 | + |
| 225 | +@router.get("/list-models", summary="获取模型列表") |
| 226 | +def list_models( |
| 227 | + page: int = Query(1, description="页码,从1开始"), |
| 228 | + per_page: int = Query(16, description="每页数量,默认16条"), |
| 229 | + search: Optional[str] = Query(None, description="搜索关键词"), |
| 230 | + sort: str = Query("trending", description="排序方式") |
| 231 | +): |
| 232 | + """ |
| 233 | + Fetch model list from OpenCSG Hub and return basic information. |
| 234 | + |
| 235 | + Returns: |
| 236 | + - path: Model path |
| 237 | + - updated_at: Last update time |
| 238 | + - first_tag: show_name of the first tag with category == "task" |
| 239 | + - downloads: Download count |
| 240 | + - description: Model description |
| 241 | + |
| 242 | + :param page: Page number |
| 243 | + :param per_page: Items per page, default 16 |
| 244 | + :param search: Search keyword (optional) |
| 245 | + :param sort: Sort method, default trending |
| 246 | + :return: Model list |
| 247 | + """ |
| 248 | + try: |
| 249 | + # Get CSGHUB_ENDPOINT from environment |
| 250 | + csghub_endpoint = os.getenv('CSGHUB_ENDPOINT', 'https://hub.opencsg.com') |
| 251 | + api_url = f'{csghub_endpoint}/api/v1/models' |
| 252 | + |
| 253 | + # Set request parameters |
| 254 | + params = { |
| 255 | + 'page': page, |
| 256 | + 'per': per_page, |
| 257 | + 'search': search or '', |
| 258 | + 'sort': sort, |
| 259 | + 'source': '' |
| 260 | + } |
| 261 | + |
| 262 | + logger.info(f'Fetching models from {api_url} with params: {params}') |
| 263 | + |
| 264 | + # Send request |
| 265 | + response = requests.get(api_url, params=params, timeout=30) |
| 266 | + response.raise_for_status() |
| 267 | + data = response.json() |
| 268 | + |
| 269 | + # Extract required fields |
| 270 | + models_data = [] |
| 271 | + if data.get('data'): |
| 272 | + for model in data['data']: |
| 273 | + # Find the first tag with category == "task" and get its show_name |
| 274 | + first_tag_show_name = '' |
| 275 | + if model.get('tags'): |
| 276 | + for tag in model['tags']: |
| 277 | + if tag.get('category') == 'task': |
| 278 | + first_tag_show_name = tag.get('show_name', '') |
| 279 | + break |
| 280 | + |
| 281 | + model_info = { |
| 282 | + 'path': model.get('path', ''), |
| 283 | + 'updated_at': model.get('updated_at', ''), |
| 284 | + 'first_tag': first_tag_show_name, |
| 285 | + 'downloads': model.get('downloads', 0), |
| 286 | + 'description': model.get('description', '') |
| 287 | + } |
| 288 | + models_data.append(model_info) |
| 289 | + |
| 290 | + # Build result |
| 291 | + result = { |
| 292 | + 'models': models_data, |
| 293 | + 'total': data.get('total', 0), |
| 294 | + 'page': page, |
| 295 | + 'per_page': per_page |
| 296 | + } |
| 297 | + |
| 298 | + logger.info(f'Successfully fetched {len(models_data)} models') |
| 299 | + return response_success(data=result, msg='获取模型列表成功') |
| 300 | + |
| 301 | + except requests.RequestException as e: |
| 302 | + logger.error(f'Failed to request OpenCSG API: {str(e)}') |
| 303 | + return response_fail(msg=f'无法连接到OpenCSG Hub API: {str(e)}') |
| 304 | + except Exception as e: |
| 305 | + logger.error(f'Error fetching models: {str(e)}') |
| 306 | + return response_fail(msg=f'获取模型列表时发生错误: {str(e)}') |
0 commit comments