diff --git a/api/endpoint.go b/api/endpoint.go
index 1a34160b..291565b9 100644
--- a/api/endpoint.go
+++ b/api/endpoint.go
@@ -33,12 +33,13 @@ type CreateEndpointInput struct {
 	WorkersMin int `json:"workersMin"`
 	WorkersMax int `json:"workersMax"`
 	FlashBootType string `json:"flashBootType"`
+	ModelReferences []string `json:"modelReferences"`
 }
 
 // there are many more fields in the result of the query but I just care about these for CLI port
 type Endpoint struct {
 	Name string `json:"name"`
-	Id string
+	Id string `json:"id"`
 }
 type EndpointOut struct {
 	Data *EndpointData `json:"data"`
@@ -225,6 +226,64 @@ func UpdateEndpointTemplate(endpointId string, templateId string) (err error) {
 	return
 }
 
+// UpdateEndpointModel replaces the model references of the endpoint with the
+// given ID through the saveEndpoint GraphQL mutation; endpointName is sent
+// along with the ID. It returns an error when the request fails, the HTTP
+// status is not 200, the body is not valid JSON, the response lists GraphQL
+// errors, or the response carries no "data" object.
+func UpdateEndpointModel(endpointId string, endpointName string, modelReferences []string) (err error) {
+	input := Input{
+		Query: `
+		mutation saveEndpoint($input: EndpointInput!) {
+			saveEndpoint(input: $input) {
+				id
+				modelReferences
+			}
+		}
+		`,
+		Variables: map[string]interface{}{"input": map[string]interface{}{
+			"id":              endpointId,
+			"name":            endpointName,
+			"modelReferences": modelReferences,
+		}},
+	}
+	res, err := Query(input)
+	if err != nil {
+		return
+	}
+	defer res.Body.Close()
+	rawData, err := io.ReadAll(res.Body)
+	if err != nil {
+		return
+	}
+	if res.StatusCode != 200 {
+		err = fmt.Errorf("statuscode %d: %s", res.StatusCode, string(rawData))
+		return
+	}
+	data := make(map[string]interface{})
+	if err = json.Unmarshal(rawData, &data); err != nil {
+		return
+	}
+	gqlErrors, ok := data["errors"].([]interface{})
+	if ok && len(gqlErrors) > 0 {
+		// Error entries are untyped JSON, so use checked assertions: a
+		// malformed entry must produce an error, not a panic.
+		firstErr, _ := gqlErrors[0].(map[string]interface{})
+		if msg, isString := firstErr["message"].(string); isString {
+			err = errors.New(msg)
+		} else {
+			err = fmt.Errorf("graphql error: %v", gqlErrors[0])
+		}
+		return
+	}
+	gqldata, ok := data["data"].(map[string]interface{})
+	if !ok || gqldata == nil {
+		err = fmt.Errorf("data is nil: %s", string(rawData))
+		return
+	}
+	return
+}
+
 func GetEndpoints() (endpoints []*Endpoint, err error) {
 	input := Input{
 		Query: `
@@ -248,6 +307,7 @@ func GetEndpoints() (endpoints []*Endpoint, err error) {
 					workersMin
 					workersStandby
 					gpuCount
+					modelReferences
 					env {
 						key
 						value
diff --git a/cmd/endpoint/endpoint.go b/cmd/endpoint/endpoint.go
new file mode 100644
index 00000000..a4e19d1b
--- /dev/null
+++ b/cmd/endpoint/endpoint.go
@@ -0,0 +1,17 @@
+// Package endpoint implements the "endpoint" command group of runpodctl.
+package endpoint
+
+import (
+	"github.com/spf13/cobra"
+)
+
+// Cmd is the parent "endpoint" command; subcommands are attached in init.
+var Cmd = &cobra.Command{
+	Use:   "endpoint",
+	Short: "manage serverless endpoints",
+	Long:  "manage serverless endpoints on runpod",
+}
+
+func init() {
+	Cmd.AddCommand(modelCmd)
+}
diff --git a/cmd/endpoint/update_model.go b/cmd/endpoint/update_model.go
new file mode 100644
index 00000000..479accff
--- /dev/null
+++ b/cmd/endpoint/update_model.go
@@ -0,0 +1,54 @@
+package endpoint
+
+import (
+	"fmt"
+
+	"github.com/runpod/runpodctl/api"
+	"github.com/spf13/cobra"
+)
+
+// modelCmd replaces the model references of an existing endpoint. The first
+// positional argument is the endpoint ID and every following argument is a
+// model reference, so at least two arguments are required.
+var modelCmd = &cobra.Command{
+	Use:   "model <endpoint-id> <model-ref>...",
+	Short: "update model references on an endpoint",
+	Long:  "set the model references (cached models) for a serverless endpoint",
+	Args:  cobra.MinimumNArgs(2),
+	RunE:  runModel,
+}
+
+// runModel looks up the endpoint identified by args[0] in the list returned
+// by api.GetEndpoints, because the update call also needs the endpoint's
+// name, and then sets its model references to args[1:].
+func runModel(cmd *cobra.Command, args []string) error {
+	endpointID := args[0]
+	modelRefs := args[1:]
+
+	endpoints, err := api.GetEndpoints()
+	if err != nil {
+		return fmt.Errorf("failed to list endpoints: %w", err)
+	}
+
+	// Track the match explicitly so an endpoint whose name is empty is not
+	// misreported as missing.
+	var endpointName string
+	found := false
+	for _, ep := range endpoints {
+		if ep.Id == endpointID {
+			endpointName = ep.Name
+			found = true
+			break
+		}
+	}
+	if !found {
+		return fmt.Errorf("endpoint %s not found", endpointID)
+	}
+
+	if err := api.UpdateEndpointModel(endpointID, endpointName, modelRefs); err != nil {
+		return fmt.Errorf("failed to update endpoint model: %w", err)
+	}
+
+	fmt.Printf("updated model references for endpoint %s (%s)\n", endpointName, endpointID)
+	return nil
+}
diff --git a/cmd/project/functions.go b/cmd/project/functions.go
index c8cffc50..2639f66e 100644
--- a/cmd/project/functions.go
+++ b/cmd/project/functions.go
@@ -580,6 +580,7 @@ func deployProject(networkVolumeId string) (endpointId string, err error) {
 	flashboot := true
 	flashBootType := "FLASHBOOT"
 	idleTimeout := 5
+	var modelRefs []string
 	endpointConfig, ok := config.Get("endpoint").(*toml.Tree)
 	if ok {
 		if min, ok := endpointConfig.Get("active_workers").(int64); ok {
@@ -597,6 +598,17 @@ func deployProject(networkVolumeId string) (endpointId string, err error) {
 		if idle, ok := endpointConfig.Get("idle_timeout").(int64); ok {
 			idleTimeout = int(idle)
 		}
+		if refs, ok := endpointConfig.Get("model_refs").([]interface{}); ok {
+			for _, r := range refs {
+				// Entries come from a user-edited TOML file: reject a
+				// non-string with an error instead of panicking.
+				ref, isString := r.(string)
+				if !isString {
+					return "", fmt.Errorf("model_refs entries must be strings, got %T", r)
+				}
+				modelRefs = append(modelRefs, ref)
+			}
+		}
 	}
 	if err != nil {
 		deployedEndpointId, err = api.CreateEndpoint(&api.CreateEndpointInput{
@@ -610,6 +622,7 @@ func deployProject(networkVolumeId string) (endpointId string, err error) {
 			WorkersMin: minWorkers,
 			WorkersMax: maxWorkers,
 			FlashBootType: flashBootType,
+			ModelReferences: modelRefs,
 		})
 		if err != nil {
 			fmt.Println("error making endpoint")
diff --git a/cmd/project/tomlBuilder.go b/cmd/project/tomlBuilder.go
index 8fca466b..0662ab58 100644
--- a/cmd/project/tomlBuilder.go
+++ b/cmd/project/tomlBuilder.go
@@ -77,6 +77,11 @@ active_workers = 0
 max_workers = 3
 flashboot = true
 
+# model_refs - List of model references to cache on the endpoint workers.
+# Format: "owner/model-name" or "owner/model-name:branch".
+# Example: ["runpod/stable-diffusion-v1-5", "meta-llama/Llama-2-7b-chat-hf"]
+# model_refs = []
+
 [runtime]
 # python_version - Python version to use for the project.
 #
diff --git a/cmd/root.go b/cmd/root.go
index 205b8ef0..19d545a5 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -7,6 +7,7 @@ import (
 	"github.com/runpod/runpodctl/cmd/billing"
 	"github.com/runpod/runpodctl/cmd/config"
 	"github.com/runpod/runpodctl/cmd/datacenter"
 	"github.com/runpod/runpodctl/cmd/doctor"
+	"github.com/runpod/runpodctl/cmd/endpoint"
 	"github.com/runpod/runpodctl/cmd/gpu"
 	"github.com/runpod/runpodctl/cmd/hub"
@@ -86,6 +87,7 @@ func registerCommands() {
 	rootCmd.AddCommand(pod.Cmd)
 	rootCmd.AddCommand(serverless.Cmd)
 	rootCmd.AddCommand(template.Cmd)
+	rootCmd.AddCommand(endpoint.Cmd)
 	rootCmd.AddCommand(model.Cmd)
 	rootCmd.AddCommand(volume.Cmd)
 	rootCmd.AddCommand(registry.Cmd)