Skip to content

Commit 7594599

Browse files
committed
fix: system config model provider issues
- Fix default model config not saved correctly after refresh
- Remove max_tokens upper limit to allow flexible model specs
- Fix provider dropdown showing too many options
- Fix configured models not appearing in Agent/homepage after save
- Change ModelConfig.provider from Enum to str for flexibility
1 parent b547023 commit 7594599

3 files changed

Lines changed: 106 additions & 26 deletions

File tree

packages/derisk-core/src/derisk_core/config/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class LLMProvider(str, Enum):
1616
class ModelConfig(BaseModel):
1717
"""模型配置"""
1818

19-
provider: LLMProvider = LLMProvider.OPENAI
19+
provider: str = "openai"
2020
model_id: str = "gpt-4"
2121
api_key: Optional[str] = None
2222
base_url: Optional[str] = None

packages/derisk-serve/src/derisk_serve/model/api/endpoints.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,16 +127,94 @@ async def test_auth():
127127
async def model_params(worker_manager: WorkerManager = Depends(get_worker_manager)):
    """Return the list of available model parameter dicts.

    Models are merged from two sources:
      1. The running ``worker_manager`` (deployed worker models).
      2. The system configuration (``agent.llm`` providers from the JSON
         config), so models configured via the UI appear even before a
         worker hosts them.

    Returns:
        ``Result.succ(params)`` with a list of model dicts, or
        ``Result.failed`` on any error.
    """
    try:
        params = []

        # 1. Get models from worker_manager.
        workers = await worker_manager.supported_models()
        for worker in workers:
            for model in worker.models:
                model_dict = model.__dict__
                model_dict["host"] = worker.host
                model_dict["port"] = worker.port
                params.append(model_dict)

        # 2. Get models from system configuration (JSON configuration).
        agent_llm_conf = _resolve_agent_llm_conf()
        _append_configured_models(agent_llm_conf, params)

        return Result.succ(params)
    except Exception as e:
        return Result.failed(err_code="E000X", msg=f"model types failed {e}")


def _resolve_agent_llm_conf():
    """Locate the ``agent.llm`` configuration section, trying each known layout.

    Lookup order:
      1. Direct key ``"agent.llm"``.
      2. Nested dict access: ``"agent"`` -> ``"llm"``.
      3. Flattened keys with prefix ``"agent.llm."``.
      4. The ``app_config`` object (frontend JSON format), normalized from the
         frontend key names (``providers``/``models``) to the backend names
         (``provider``/``model``).

    Returns the configuration dict, or ``None`` when no source provides it.
    """
    system_app = SystemApp.get_instance() or global_system_app
    if not (system_app and system_app.config):
        return None

    # Try "agent.llm" direct key.
    agent_llm_conf = system_app.config.get("agent.llm")

    # If not found, try "agent" -> "llm" (nested dict access).
    if not agent_llm_conf:
        agent_conf = system_app.config.get("agent")
        if isinstance(agent_conf, dict):
            agent_llm_conf = agent_conf.get("llm")

    # Check for flattened keys (fallback).
    if not agent_llm_conf:
        flattened = system_app.config.get_all_by_prefix("agent.llm.")
        if flattened:
            agent_llm_conf = {}
            prefix_len = len("agent.llm.")
            for k, v in flattened.items():
                agent_llm_conf[k[prefix_len:]] = v

    # Also try app_config from configs dict (JSON config source).
    if not agent_llm_conf:
        app_config = system_app.config.configs.get("app_config")
        if app_config:
            agent_llm_attr = getattr(app_config, "agent_llm", None)
            if agent_llm_attr:
                # Convert frontend format to backend format.
                agent_llm_dict = (
                    agent_llm_attr.model_dump(mode="json")
                    if hasattr(agent_llm_attr, "model_dump")
                    else dict(agent_llm_attr)
                )
                # Convert providers -> provider, models -> model.
                if "providers" in agent_llm_dict:
                    providers = agent_llm_dict.pop("providers")
                    if isinstance(providers, list):
                        converted = []
                        for p in providers:
                            if isinstance(p, dict):
                                cp = dict(p)
                                if "models" in cp:
                                    cp["model"] = cp.pop("models")
                                converted.append(cp)
                        agent_llm_dict["provider"] = converted
                agent_llm_conf = agent_llm_dict

    return agent_llm_conf


def _append_configured_models(agent_llm_conf, params):
    """Append models from the multi-provider list structure
    (``[[agent.llm.provider]]``) into ``params``, skipping entries whose
    (model name, provider name) pair is already present.

    Configured-but-unhosted models are marked with a synthetic
    ``host`` of ``proxy@<provider>`` and ``port`` 0.
    """
    if not (agent_llm_conf and isinstance(agent_llm_conf.get("provider"), list)):
        return
    for p_conf in agent_llm_conf.get("provider"):
        if not (isinstance(p_conf, dict) and "model" in p_conf):
            continue
        p_models = p_conf.get("model")
        p_name = p_conf.get("provider", "unknown")
        if not isinstance(p_models, list):
            continue
        for m in p_models:
            if not (isinstance(m, dict) and "name" in m):
                continue
            m_name = m.get("name")
            # Deduplicate against models already reported by workers.
            already_present = any(
                p.get("model") == m_name and p.get("provider") == p_name
                for p in params
            )
            if not already_present:
                params.append(
                    {
                        "model": m_name,
                        "provider": p_name,
                        "worker_type": "llm",
                        "host": f"proxy@{p_name}",
                        "port": 0,
                        "enabled": True,
                    }
                )
140218

141219

142220
@router.get("/models")

web/src/components/config/LLMSettingsSection.tsx

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -78,25 +78,29 @@ function buildSecretReference(secretName?: string) {
7878

7979
function deriveDefaultProviderName(config: AppConfig) {
8080
const providers = config.agent_llm?.providers || [];
81-
const defaultBaseUrl = config.default_model?.base_url || "";
8281
const defaultProvider = normalizeProviderName(
8382
String(config.default_model?.provider || "")
8483
);
85-
86-
const matchedByBaseUrl = providers.find(
87-
(item) =>
88-
normalizeProviderName(item.api_base || "") ===
89-
normalizeProviderName(defaultBaseUrl)
90-
);
91-
if (matchedByBaseUrl) {
92-
return normalizeProviderName(matchedByBaseUrl.provider);
84+
85+
if (defaultProvider && defaultProvider !== "custom") {
86+
const matchedProvider = providers.find(
87+
(item) => normalizeProviderName(item.provider) === defaultProvider
88+
);
89+
if (matchedProvider) {
90+
return normalizeProviderName(matchedProvider.provider);
91+
}
9392
}
9493

95-
const matchedByProvider = providers.find(
96-
(item) => normalizeProviderName(item.provider) === defaultProvider
97-
);
98-
if (matchedByProvider) {
99-
return normalizeProviderName(matchedByProvider.provider);
94+
const defaultBaseUrl = config.default_model?.base_url || "";
95+
if (defaultBaseUrl) {
96+
const matchedByBaseUrl = providers.find(
97+
(item) =>
98+
normalizeProviderName(item.api_base || "") ===
99+
normalizeProviderName(defaultBaseUrl)
100+
);
101+
if (matchedByBaseUrl) {
102+
return normalizeProviderName(matchedByBaseUrl.provider);
103+
}
100104
}
101105

102106
return defaultProvider || normalizeProviderName(providers[0]?.provider) || "openai";
@@ -197,11 +201,12 @@ export default function LLMSettingsSection({ config, onChange }: Props) {
197201
const providerOptions = useMemo(() => {
198202
const values = new Set<string>();
199203
BUILTIN_PROVIDER_OPTIONS.forEach((item) => values.add(item.value));
200-
llmKeys.forEach((item) => values.add(normalizeProviderName(item.provider)));
201-
Object.keys(modelSuggestionsByProvider).forEach((item) => values.add(item));
202204
configuredProviders.forEach((item: LLMProviderConfig) => {
203205
if (item?.provider) {
204-
values.add(normalizeProviderName(item.provider));
206+
const normalized = normalizeProviderName(item.provider);
207+
if (!BUILTIN_DEFAULT_MODEL_PROVIDERS.has(normalized)) {
208+
values.add(normalized);
209+
}
205210
}
206211
});
207212
return Array.from(values)
@@ -213,7 +218,7 @@ export default function LLMSettingsSection({ config, onChange }: Props) {
213218
BUILTIN_PROVIDER_OPTIONS.find((item) => item.value === value)?.label ||
214219
value,
215220
}));
216-
}, [configuredProviders, llmKeys, modelSuggestionsByProvider]);
221+
}, [configuredProviders]);
217222

218223
async function loadSupportedModels() {
219224
setLoadingModels(true);
@@ -366,11 +371,7 @@ export default function LLMSettingsSection({ config, onChange }: Props) {
366371
...config,
367372
default_model: {
368373
...config.default_model,
369-
provider: (
370-
BUILTIN_DEFAULT_MODEL_PROVIDERS.has(selectedProviderName)
371-
? selectedProviderName
372-
: "custom"
373-
) as AppConfig["default_model"]["provider"],
374+
provider: selectedProviderName as AppConfig["default_model"]["provider"],
374375
model_id: resolvedModelId,
375376
base_url: selectedProvider.api_base || config.default_model?.base_url,
376377
temperature: resolvedTemperature,
@@ -797,11 +798,12 @@ export default function LLMSettingsSection({ config, onChange }: Props) {
797798
<Form.Item
798799
name={[modelField.name, "max_new_tokens"]}
799800
label="Max Tokens"
801+
tooltip="请根据模型实际支持的最大token数设置,不同模型限制不同"
800802
>
801803
<InputNumber
802804
style={{ width: "100%" }}
803805
min={1}
804-
max={128000}
806+
placeholder="4096"
805807
/>
806808
</Form.Item>
807809
<Form.Item

0 commit comments

Comments
 (0)