1+ import json
2+ import requests
13from typing import Dict
24
35from loguru import logger
46
5- from data_engine .utils .availability_utils import AvailabilityChecking
6- from data_engine .utils .model_utils import get_model , prepare_model
7-
8- from ..base_op import OPERATORS , UNFORKABLE , Mapper , Sample , Param , DataType
7+ from ..base_op import OPERATORS , Mapper , Sample , Param , DataType
98
# Prompt template sent to the LLM: given a code snippet, ask the model to
# write (in Chinese) the requirement prompt that would produce that code,
# answering in the form "prompt=...".
DEFAULT_PROMPT_TEMPLATE = """
为了输出下面代码片段,请生成对应prompt内容,该prompt应该用中文详细描述需求, 比如使用python实现什么功能。请回复:prompt=?
代码片段:
{input_data}
"""

# Default system prompt used when the caller does not supply one.
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."

# Registry key under which this operator is registered.
OP_NAME = 'generate_code_qa_pair_mapper'
2318
2419
@OPERATORS.register_module(OP_NAME)
class GenerateCodeQAPairMapper(Mapper):
    """
    Mapper to generate code QA pairs using a remote LLM API.

    For each sample's code snippet it asks an OpenAI-compatible
    chat-completion endpoint (Qwen, DeepSeek, GPT, etc.) to produce a
    requirement-description prompt, then emits an
    ``{'input': generated_prompt, 'response': code}`` pair.
    """
    _accelerator = 'cpu'

    def __init__(self,
                 model_url: str = 'https://api.deepseek.com/chat/completions',
                 model_name: str = 'deepseek-chat',
                 auth_token: str = '',
                 system_prompt: str = None,
                 sampling_params: Dict = None,
                 *args,
                 timeout: int = 120,
                 **kwargs):
        """
        Initialization method.

        :param model_url: API endpoint URL (OpenAI-compatible format).
        :param model_name: Model name to use.
        :param auth_token: API authentication token.
        :param system_prompt: System prompt for the model; defaults to
            ``DEFAULT_SYSTEM_PROMPT`` when ``None``.
        :param sampling_params: Sampling parameters for text generation,
            e.g. ``{'temperature': 0.9, 'top_p': 0.95}``. Defaults to
            ``{'temperature': 0.2, 'top_k': 10, 'top_p': 0.95}``.
        :param timeout: Per-request timeout in seconds (keyword-only;
            previously hard-coded to 120).
        :param args: extra args
        :param kwargs: extra kwargs
        """
        super().__init__(*args, **kwargs)
        # Remote HTTP calls here are done sequentially; keep one worker.
        self.num_proc = 1

        self.model_url = model_url
        self.model_name = model_name
        self.auth_token = auth_token
        self.timeout = timeout

        # Fail fast on misconfiguration instead of erroring on first call.
        if not self.model_url:
            raise ValueError("model_url is required")
        if not self.auth_token:
            raise ValueError("auth_token is required")

        if system_prompt is None:
            system_prompt = DEFAULT_SYSTEM_PROMPT
        self.system_prompt = system_prompt

        # Build the default dict per instance to avoid a shared mutable
        # default argument.
        if sampling_params is None:
            sampling_params = {'temperature': 0.2, 'top_k': 10, 'top_p': 0.95}
        self.sampling_params = sampling_params

    def build_prompt(self, code_snippet):
        """Fill ``code_snippet`` into the QA-generation prompt template."""
        return DEFAULT_PROMPT_TEMPLATE.format(input_data=code_snippet)

    def process(self, sample=None, rank=None):
        """
        Generate a QA pair for one sample via the remote chat API.

        On success returns ``{text_key: {'input': generated_prompt,
        'response': code}}``; on any failure (HTTP error, malformed
        response, unexpected exception) the original ``sample`` is
        returned unchanged so the pipeline keeps running.

        :param sample: input sample dict containing ``self.text_key``.
        :param rank: unused; kept for interface compatibility.
        """
        try:
            data = sample[self.text_key]
            input_prompt = self.build_prompt(data)

            messages = [
                {
                    "role": "system",
                    "content": self.system_prompt
                },
                {
                    "role": "user",
                    "content": input_prompt
                }
            ]

            headers = {
                'Authorization': f'Bearer {self.auth_token}',
                'Content-Type': 'application/json'
            }

            request_data = {
                "model": self.model_name,
                "messages": messages,
                "stream": False,
            }
            # Merge user-supplied sampling parameters into the request body.
            if self.sampling_params:
                request_data.update(self.sampling_params)

            logger.info(
                f'Calling API: {self.model_url}, Model: {self.model_name}')
            logger.debug(f'input_prompt is: {input_prompt}')

            response = requests.post(
                url=self.model_url,
                headers=headers,
                json=request_data,
                timeout=self.timeout
            )
            response.raise_for_status()

            result = response.json()

            if 'choices' not in result:
                logger.error(f'API response missing "choices" field: {result}')
                return sample

            response_str = result['choices'][0]['message']['content']

            logger.debug(f'response_str is: {response_str}')

            # The model is asked to answer as "prompt=<text>"; strip the tag.
            generated_prompt = response_str.replace('prompt=', '').strip()

            message_list = {
                self.text_key: {
                    'input': generated_prompt,
                    'response': data
                }
            }

            return message_list

        except requests.exceptions.RequestException as e:
            logger.error(f'HTTP request error: {e}')
            logger.warning('API call failed, returning original sample')
        except (KeyError, IndexError, json.JSONDecodeError) as e:
            logger.error(f'API response parsing error: {e}')
            logger.warning('Response parsing failed, returning original sample')
        except Exception as e:
            # Broad boundary catch: this operator must never kill the
            # pipeline, so log and fall through to the original sample.
            logger.error(f'Unexpected error: {e}')
            logger.warning('Exception occurred, returning original sample')

        # Reached only on failure.
        return sample

    # NOTE(review): chained @classmethod + @property is deprecated in
    # Python 3.11 and removed in 3.13 — confirm the target runtime, or
    # migrate these to plain classmethods if the framework allows.
    @classmethod
    @property
    def description(cls):
        """One-line operator description shown in the operator registry."""
        return """Code QA pair generator: Generate requirement description prompts from code snippets. Supports OpenAI-compatible APIs including Qwen, DeepSeek, GPT, etc."""

    @classmethod
    @property
    def sample(cls):
        """Example (input, output) pair illustrating this operator."""
        return Sample(
            'def hello_world():\n    print("Hello, World!")\nhello_world()',
            'message:[{"input": "Write a Python function named hello_world that prints Hello, World! and call it", "response": "def hello_world():\\n    print(\\"Hello, World!\\")\\nhello_world()" }]'
        )

    @classmethod
    @property
    def init_params(cls):
        """Configurable constructor parameters exposed to the UI/config."""
        return [
            Param("model_url", DataType.STRING, {
                "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions": "Qwen API",
                "https://api.deepseek.com/chat/completions": "DeepSeek API",
                "https://api.openai.com/v1/chat/completions": "OpenAI API",
            }, "https://api.deepseek.com/chat/completions"),
            Param("model_name", DataType.STRING, {
                "qwen-plus": "qwen-plus",
                "qwen-max": "qwen-max",
                "deepseek-chat": "deepseek-chat",
                "deepseek-reasoner": "deepseek-reasoner",
                "gpt-4": "gpt-4",
                "gpt-3.5-turbo": "gpt-3.5-turbo",
            }, "deepseek-chat"),
            Param("auth_token", DataType.STRING, {}, ""),
        ]
0 commit comments