Skip to content

Commit f37162f

Browse files
committed
refactor: redesign default model configuration with is_default field
Breaking Changes: - Remove separate default_model config section - Add is_default field to model configuration - Simplify UI: set default model directly in model list - Each provider can have one default model Changes: - Backend: Add is_default to LLMProviderModelConfig schema - Frontend: Rewrite LLMSettingsSection with new design - Add model helper functions for finding default model - Update type definitions Benefits: - Single source of truth: no sync issues - Better UX: set default directly in model list - Simpler code: no complex derivation logic
1 parent 2135ef6 commit f37162f

7 files changed

Lines changed: 1455 additions & 715 deletions

File tree

packages/derisk-app/src/derisk_app/openapi/api_v1/helpers/__init__.py

Whitespace-only changes.
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
def _model_entry(provider_config, model_config) -> Dict[str, Any]:
    """Build the result dict for one (provider, model) pair.

    Shared by the ``is_default`` lookup and the first-model fallback in
    :func:`find_default_model`, which previously duplicated this literal.
    """
    temperature = getattr(model_config, "temperature", None)
    max_new_tokens = getattr(model_config, "max_new_tokens", None)
    return {
        "provider": provider_config.provider,
        "model_name": model_config.name,
        # Explicit None checks: the previous `value or default` form would
        # silently replace a legitimate 0 / 0.0 (e.g. temperature=0.0 for
        # deterministic sampling) with the fallback.
        "temperature": 0.7 if temperature is None else temperature,
        "max_new_tokens": 4096 if max_new_tokens is None else max_new_tokens,
        "is_multimodal": getattr(model_config, "is_multimodal", False),
        "api_base": provider_config.api_base,
        "api_key_ref": provider_config.api_key_ref,
    }


def find_default_model(config: AppConfig) -> Optional[Dict[str, Any]]:
    """Find the default model in the application config.

    Scans ``config.agent_llm.providers`` for the first model flagged
    ``is_default``; if none is flagged, falls back to the first model of
    the first provider.

    Args:
        config: Application config; expected to expose
            ``agent_llm.providers`` (duck-typed, accessed defensively).

    Returns:
        Dict with ``provider``, ``model_name``, ``temperature``,
        ``max_new_tokens``, ``is_multimodal``, ``api_base`` and
        ``api_key_ref``, or ``None`` if no model is configured.
    """
    try:
        agent_llm = getattr(config, "agent_llm", None)
        if not agent_llm or not hasattr(agent_llm, "providers"):
            return None

        providers = agent_llm.providers or []

        # First model explicitly flagged as is_default wins.
        for provider_config in providers:
            if not hasattr(provider_config, "models"):
                continue
            for model_config in provider_config.models or []:
                if getattr(model_config, "is_default", False):
                    return _model_entry(provider_config, model_config)

        # No is_default flag anywhere: fall back to the first provider's
        # first model so callers still get a usable default.
        if providers and hasattr(providers[0], "models") and providers[0].models:
            return _model_entry(providers[0], providers[0].models[0])

        return None
    except Exception as e:
        # Best-effort: config shape is user-supplied; never crash the caller.
        logger.warning(f"Failed to find default model: {e}")
        return None
54+
55+
56+
def get_all_models_from_config(config: AppConfig) -> List[str]:
    """Collect the names of every model declared in the config.

    Args:
        config: Application config; expected to expose
            ``agent_llm.providers`` (accessed defensively).

    Returns:
        List of model names, in declaration order; empty list when the
        config is missing, malformed, or declares no models.
    """
    names: List[str] = []
    try:
        agent_llm = getattr(config, "agent_llm", None)
        if not agent_llm or not hasattr(agent_llm, "providers"):
            return names

        for provider in agent_llm.providers or []:
            # Providers without a models attribute (or with models=None)
            # simply contribute nothing.
            for model in getattr(provider, "models", None) or []:
                name = getattr(model, "name", None)
                if name:
                    names.append(name)

        return names
    except Exception as e:
        logger.warning(f"Failed to get all models from config: {e}")
        return names

packages/derisk-core/src/derisk_core/config/schema.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,13 @@ class OAuth2Config(BaseModel):
139139

140140

141141
class LLMProviderModelConfig(BaseModel):
    """Model configuration nested under an LLM provider.

    Per the redesign, the default model is marked directly in the model
    list via ``is_default`` (single source of truth) instead of a separate
    ``default_model`` config section. Each provider is expected to have at
    most one model with ``is_default=True`` — NOTE(review): this is not
    enforced by the schema itself; confirm it is validated elsewhere.
    """

    # Model identifier, e.g. "gpt-4o", "deepseek-chat" (required).
    name: str = Field(..., description="模型名称,如 gpt-4o, deepseek-chat")
    # Sampling temperature passed to the model.
    temperature: float = Field(0.7, description="模型温度参数")
    # Maximum number of tokens the model may generate.
    max_new_tokens: int = Field(4096, description="最大生成token数")
    # Whether the model accepts multimodal (image) input.
    is_multimodal: bool = Field(False, description="是否支持多模态(图片输入)")
    # Marks this model as the provider's default.
    is_default: bool = Field(False, description="是否为该provider下的默认模型")
148149

149150

150151
class LLMProviderConfig(BaseModel):

0 commit comments

Comments
 (0)