Skip to content

Commit 2ae1676

Browse files
JingTYHaiHui886
authored andcommitted
fix the bug of #79.Unable to replace model
1 parent cc9c9c2 commit 2ae1676

3 files changed

Lines changed: 134 additions & 58 deletions

File tree

Lines changed: 127 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,45 @@
1+
import json
2+
import requests
13
from typing import Dict
24

35
from loguru import logger
46

5-
from data_engine.utils.availability_utils import AvailabilityChecking
6-
from data_engine.utils.model_utils import get_model, prepare_model
7-
8-
from ..base_op import OPERATORS, UNFORKABLE, Mapper, Sample, Param, DataType
7+
from ..base_op import OPERATORS, Mapper, Sample, Param, DataType
98

109
DEFAULT_PROMPT_TEMPLATE = """
1110
为了输出下面代码片段,请生成对应prompt内容,该prompt应该用中文详细描述需求, 比如使用python实现什么功能。请回复:prompt=?
1211
代码片段:
1312
{input_data}
1413
"""
1514

16-
OP_NAME = 'generate_code_qa_pair_mapper'
17-
18-
with AvailabilityChecking(['torch', 'transformers'], OP_NAME):
19-
import torch
15+
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
2016

21-
# avoid hanging when calling model in multiprocessing
22-
torch.set_num_threads(1)
17+
OP_NAME = 'generate_code_qa_pair_mapper'
2318

2419

25-
@UNFORKABLE.register_module(OP_NAME)
2620
@OPERATORS.register_module(OP_NAME)
2721
class GenerateCodeQAPairMapper(Mapper):
28-
_accelerator = 'cuda'
22+
"""
23+
Mapper to generate code QA pairs using remote LLM API.
24+
Supports OpenAI-compatible API formats including Qwen, DeepSeek, GPT, etc.
25+
"""
26+
_accelerator = 'cpu'
2927

3028
def __init__(self,
31-
hf_model,
32-
trust_remote_code: bool = True,
33-
prompt_template: str = None,
34-
# {'temperature': 0.2, 'top_k': 10, 'top_p': 0.95}
35-
sampling_params: Dict = {
36-
'temperature': 0.2, 'top_k': 10, 'top_p': 0.95},
29+
model_url: str = 'https://api.deepseek.com/chat/completions',
30+
model_name: str = 'deepseek-chat',
31+
auth_token: str = '',
32+
system_prompt: str = None,
33+
sampling_params: Dict = None,
3734
*args,
3835
**kwargs):
3936
"""
4037
Initialization method.
4138
42-
:param hf_model: Hugginface model id.
43-
:param trust_remote_code: passed to transformers
44-
:param prompt_template: Prompt template for generate samples.
45-
Please make sure the template contains "{augmented_data}",
46-
which corresponds to the augmented samples.
39+
:param model_url: API endpoint URL (OpenAI-compatible format).
40+
:param model_name: Model name to use.
41+
:param auth_token: API authentication token.
42+
:param system_prompt: System prompt for the model.
4743
:param sampling_params: Sampling parameters for text generation.
4844
e.g {'temperature': 0.9, 'top_p': 0.95}
4945
:param args: extra args
@@ -52,53 +48,131 @@ def __init__(self,
5248
super().__init__(*args, **kwargs)
5349
self.num_proc = 1
5450

55-
if prompt_template is None:
56-
prompt_template = DEFAULT_PROMPT_TEMPLATE
51+
self.model_url = model_url
52+
self.model_name = model_name
53+
self.auth_token = auth_token
54+
55+
if not self.model_url:
56+
raise ValueError("model_url is required")
57+
if not self.auth_token:
58+
raise ValueError("auth_token is required")
5759

58-
self.prompt_template = prompt_template
60+
if system_prompt is None:
61+
system_prompt = DEFAULT_SYSTEM_PROMPT
62+
self.system_prompt = system_prompt
5963

60-
self.model_key = prepare_model(
61-
model_type='opcsg_inference',
62-
pretrained_model_name_or_path=hf_model,
63-
trust_remote_code=trust_remote_code)
64+
if sampling_params is None:
65+
sampling_params = {'temperature': 0.2, 'top_k': 10, 'top_p': 0.95}
6466
self.sampling_params = sampling_params
6567

66-
def build_prompt(self, sample, prompt_template):
67-
return prompt_template.format(input_data=sample)
68+
def build_prompt(self, code_snippet):
69+
return DEFAULT_PROMPT_TEMPLATE.format(input_data=code_snippet)
6870

6971
def process(self, sample=None, rank=None):
70-
model, _ = get_model(self.model_key, rank=rank)
71-
data = sample[self.text_key]
72-
input_prompt = self.build_prompt(data,
73-
self.prompt_template)
74-
75-
response_str = model.generate(
76-
message=input_prompt, sampling_params=self.sampling_params, system_prompt='You are a helpful assistant.')
77-
logger.debug(f'input_prompt is: {input_prompt}')
78-
logger.debug(f'response_str is: {response_str}')
79-
message_list = {self.text_key: {
80-
'input': response_str.replace('prompt=', ''), 'response': data}}
81-
82-
return message_list
72+
try:
73+
data = sample[self.text_key]
74+
input_prompt = self.build_prompt(data)
75+
76+
messages = [
77+
{
78+
"role": "system",
79+
"content": self.system_prompt
80+
},
81+
{
82+
"role": "user",
83+
"content": input_prompt
84+
}
85+
]
86+
87+
headers = {
88+
'Authorization': f'Bearer {self.auth_token}',
89+
'Content-Type': 'application/json'
90+
}
91+
92+
request_data = {
93+
"model": self.model_name,
94+
"messages": messages,
95+
"stream": False,
96+
}
97+
# Merge sampling_params
98+
if self.sampling_params:
99+
request_data.update(self.sampling_params)
100+
101+
logger.info(f'Calling API: {self.model_url}, Model: {self.model_name}')
102+
logger.debug(f'input_prompt is: {input_prompt}')
103+
104+
response = requests.post(
105+
url=self.model_url,
106+
headers=headers,
107+
json=request_data,
108+
timeout=120
109+
)
110+
response.raise_for_status()
111+
112+
result = response.json()
113+
114+
if 'choices' not in result:
115+
logger.error(f'API response missing "choices" field: {result}')
116+
return sample
117+
118+
response_str = result['choices'][0]['message']['content']
119+
120+
logger.debug(f'response_str is: {response_str}')
121+
122+
# Extract content after "prompt="
123+
generated_prompt = response_str.replace('prompt=', '').strip()
124+
125+
message_list = {
126+
self.text_key: {
127+
'input': generated_prompt,
128+
'response': data
129+
}
130+
}
131+
132+
return message_list
133+
134+
except requests.exceptions.RequestException as e:
135+
logger.error(f'HTTP request error: {e}')
136+
logger.warning(f'API call failed, returning original sample')
137+
except (KeyError, IndexError, json.JSONDecodeError) as e:
138+
logger.error(f'API response parsing error: {e}')
139+
logger.warning(f'Response parsing failed, returning original sample')
140+
except Exception as e:
141+
logger.error(f'Unexpected error: {e}')
142+
logger.warning(f'Exception occurred, returning original sample')
143+
144+
# Return original sample on failure
145+
return sample
83146

84147
@classmethod
85148
@property
86149
def description(cls):
87-
return """Mapper to generate new instruction data based on code.
88-
"""
150+
return """Code QA pair generator: Generate requirement description prompts from code snippets. Supports OpenAI-compatible APIs including Qwen, DeepSeek, GPT, etc."""
89151

90152
@classmethod
91153
@property
92154
def sample(cls):
93-
return Sample('def hello_world():\n print("Hello, World!")\nhello_world()',
94-
'message:[{"input": "create hello word function by python", "response": "def hello_world():\n print("Hello, World!")\nhello_world()" }]')
155+
return Sample(
156+
'def hello_world():\n print("Hello, World!")\nhello_world()',
157+
'message:[{"input": "Write a Python function named hello_world that prints Hello, World! and call it", "response": "def hello_world():\\n print(\\"Hello, World!\\")\\nhello_world()" }]'
158+
)
95159

96160
@classmethod
97161
@property
98162
def init_params(cls):
99163
return [
100-
Param("hf_model", DataType.STRING, {
101-
"AIWizards/Llama2-Chinese-7b-Chat": "AIWizards/Llama2-Chinese-7b-Chat",
102-
}, "AIWizards/Llama2-Chinese-7b-Chat"),
103-
Param("prompt_template", DataType.STRING, None, None),
164+
Param("model_url", DataType.STRING, {
165+
"https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions": "Qwen API",
166+
"https://api.deepseek.com/chat/completions": "DeepSeek API",
167+
"https://api.openai.com/v1/chat/completions": "OpenAI API",
168+
}, "https://api.deepseek.com/chat/completions"),
169+
Param("model_name", DataType.STRING, {
170+
"qwen-plus": "qwen-plus",
171+
"qwen-max": "qwen-max",
172+
"deepseek-chat": "deepseek-chat",
173+
"deepseek-reasoner": "deepseek-reasoner",
174+
"gpt-4": "gpt-4",
175+
"gpt-3.5-turbo": "gpt-3.5-turbo",
176+
}, "deepseek-chat"),
177+
Param("auth_token", DataType.STRING, {}, ""),
104178
]

data_server/algo_templates/mapper/algo_template_mapper.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from sqlalchemy.orm import Session
2+
from sqlalchemy import desc
23
from typing import List, Optional, Tuple
34
import yaml
45

@@ -141,7 +142,7 @@ def get_templates_by_query(db: Session, user_id: str,
141142

142143
total = query.count()
143144

144-
145-
templates = query.offset((page - 1) * page_size).limit(page_size).all()
145+
# desc_by_id
146+
templates = query.order_by(desc(AlgoTemplate.id)).offset((page - 1) * page_size).limit(page_size).all()
146147

147148
return templates, total

data_server/database/Initialization_data/operator_config.sql

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Target Server Version : 150010 (150010)
1313
File Encoding : 65001
1414
15-
Date: 27/11/2025 10:04:18
15+
Date: 03/12/2025 15:44:33
1616
*/
1717

1818

@@ -83,8 +83,6 @@ INSERT INTO "public"."operator_config" VALUES (44, 35, 'chars_to_remove', 'input
8383
INSERT INTO "public"."operator_config" VALUES (5, 2, 'rep_len', 'number', NULL, '10', '0', NULL, NULL, 'f', 'f', '1', NULL, '2025-07-25 17:12:04.424873', '2025-07-25 17:12:04.424873', NULL);
8484
INSERT INTO "public"."operator_config" VALUES (6, 2, 'min_ratio', 'slider', NULL, '0', '0', '1', '0.01', 'f', 'f', NULL, NULL, '2025-07-25 17:12:04.424873', '2025-07-25 17:12:04.424873', NULL);
8585
INSERT INTO "public"."operator_config" VALUES (7, 2, 'max_ratio', 'slider', NULL, '0.5', '0', '1', '0.01', 'f', 'f', NULL, NULL, '2025-07-25 17:12:04.424873', '2025-07-25 17:12:04.424873', NULL);
86-
INSERT INTO "public"."operator_config" VALUES (19, 8, 'prompt_template', 'input', NULL, NULL, NULL, NULL, NULL, 'f', 'f', NULL, NULL, '2025-07-28 21:56:42.474364', '2025-07-28 21:56:42.474364', NULL);
87-
INSERT INTO "public"."operator_config" VALUES (18, 8, 'hf_model', 'select', '[23]', '23', NULL, NULL, NULL, 'f', 'f', NULL, NULL, '2025-07-28 21:56:42.474364', '2025-07-28 21:56:42.474364', NULL);
8886
INSERT INTO "public"."operator_config" VALUES (13, 13, 'max_len', 'number', NULL, '136028', '0', NULL, NULL, 'f', 'f', '1', NULL, '2025-07-25 17:23:47.885255', '2025-07-25 17:23:47.885255', NULL);
8987
INSERT INTO "public"."operator_config" VALUES (12, 13, 'min_len', 'number', NULL, '10', '0', NULL, NULL, 'f', 'f', '1', NULL, '2025-07-25 17:23:47.885255', '2025-07-25 17:23:47.885255', NULL);
9088
INSERT INTO "public"."operator_config" VALUES (4, 15, 'max_ratio', 'number', NULL, '999999', NULL, NULL, NULL, 'f', 'f', '1', NULL, '2025-07-25 17:07:16.881312', '2025-07-25 17:07:16.881312', NULL);
@@ -196,6 +194,9 @@ INSERT INTO "public"."operator_config" VALUES (145, 59, 'dimensions', 'number',
196194
INSERT INTO "public"."operator_config" VALUES (133, 59, 'model_name', 'input', NULL, 'text-embedding-v4', NULL, NULL, NULL, 'f', 'f', NULL, NULL, '2025-09-02 10:34:57.837909', '2025-09-02 10:34:57.837909', NULL);
197195
INSERT INTO "public"."operator_config" VALUES (146, 59, 'model_url', 'input', NULL, 'https://dashscope.aliyuncs.com/compatible-mode/v1', NULL, NULL, NULL, 'f', 'f', NULL, NULL, '2025-08-25 16:47:28.412958', '2025-08-25 16:47:28.412958', NULL);
198196
INSERT INTO "public"."operator_config" VALUES (142, 61, 'model_name', 'input', NULL, 'qwen-plus', NULL, NULL, NULL, 'f', 'f', NULL, NULL, '2025-08-26 18:26:29.905881', '2025-08-26 18:26:29.905881', NULL);
197+
INSERT INTO "public"."operator_config" VALUES (162, 8, 'auth_token', 'input', NULL, NULL, NULL, NULL, NULL, 'f', 'f', NULL, NULL, '2025-12-03 14:55:10', '2025-12-03 14:55:13', NULL);
198+
INSERT INTO "public"."operator_config" VALUES (18, 8, 'model_name', 'input', NULL, 'deepseek-chat', NULL, NULL, NULL, 'f', 'f', NULL, NULL, '2025-07-28 21:56:42.474364', '2025-07-28 21:56:42.474364', NULL);
199+
INSERT INTO "public"."operator_config" VALUES (19, 8, 'model_url', 'input', NULL, 'https://www.sophnet.com/api/open-apis/v1/chat/completions', NULL, NULL, NULL, 'f', 'f', NULL, NULL, '2025-07-28 21:56:42.474364', '2025-07-28 21:56:42.474364', NULL);
199200

200201
-- ----------------------------
201202
-- Indexes structure for table operator_config

0 commit comments

Comments
 (0)