From 6b758d109608052f4892394c0445c5f1098aa815 Mon Sep 17 00:00:00 2001 From: Timo Riegebauer Date: Mon, 23 Feb 2026 21:37:17 +0000 Subject: [PATCH] go-ollama: changed development to dev container and added streaming capabilities to methods v0.1.1 --- .devcontainer/devcontainer.json | 22 +++++++++ .github/dependabot.yml | 12 +++++ copy_model.go | 16 +++--- create_model.go | 77 ++++++++++++++++++++++++++--- delete_model.go | 14 +++--- generate_chat_message.go | 87 ++++++++++++++++++++++++++++++--- generate_embeddings.go | 16 +++--- generate_response.go | 86 +++++++++++++++++++++++++++++--- get_version.go | 14 +++--- list_models.go | 14 +++--- list_running_models.go | 14 +++--- pull_model.go | 75 +++++++++++++++++++++++++--- push_model.go | 78 +++++++++++++++++++++++++---- show_model_details.go | 16 +++--- 14 files changed, 448 insertions(+), 93 deletions(-) create mode 100644 .devcontainer/devcontainer.json create mode 100644 .github/dependabot.yml diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..ceb7cca --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,22 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/go . +{ + "name": "Go", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "image": "mcr.microsoft.com/devcontainers/go:2-1.25-trixie" + + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + + // Use 'postCreateCommand' to run commands after the container is created. + // "postCreateCommand": "go version", + + // Configure tool-specific properties. + // "customizations": {}, + + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. + // "remoteUser": "root" +} diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..f33a02c --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for more information: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +# https://containers.dev/guide/dependabot + +version: 2 +updates: + - package-ecosystem: "devcontainers" + directory: "/" + schedule: + interval: weekly diff --git a/copy_model.go b/copy_model.go index 9c5b2b5..19896df 100644 --- a/copy_model.go +++ b/copy_model.go @@ -13,15 +13,15 @@ type CopyModelRequest struct { Destination string `json:"destination"` } -func (o Ollama) CopyModel(reqBody CopyModelRequest) (int, error) { +func (o Ollama) CopyModel(reqBody CopyModelRequest) error { reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - return -1, err + return err } - req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/create", o.baseUrl), bytes.NewReader(reqBodyBytes)) + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/copy", o.baseUrl), bytes.NewReader(reqBodyBytes)) if err != nil { - return -1, err + return err } for key, val := range o.customHeaders { @@ -31,12 +31,12 @@ func (o Ollama) CopyModel(reqBody CopyModelRequest) (int, error) { resp, err := http.DefaultClient.Do(req) if err != nil { - return -1, err + return err } defer resp.Body.Close() - if resp.StatusCode != 200 { - return resp.StatusCode, errors.New("status code is not 200") + if resp.StatusCode != http.StatusOK { + return errors.New("status code is not 200") } - return resp.StatusCode, nil + return nil } diff --git a/create_model.go b/create_model.go index c08c405..c82176c 100644 --- a/create_model.go +++ b/create_model.go @@ -1,6 +1,7 @@ package ollama import ( + "bufio" "bytes" "encoding/json" "errors" @@ -41,15 +42,24 @@ type CreateModelResponse struct { Status string `json:"status"` } -func (o Ollama) CreateModel(reqBody CreateModelRequest) (CreateModelResponse, int, error) { +type CreateModelResponseStream struct { + Status string `json:"status"` + Digest string `json:"digest"` + Total int `json:"total"` + Completed int `json:"completed"` +} + +func (o Ollama) CreateModel(reqBody CreateModelRequest) (CreateModelResponse, error) { + reqBody.Stream = PtrOf(false) + reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - return CreateModelResponse{}, -1, err + return CreateModelResponse{}, err } req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/create", o.baseUrl), bytes.NewReader(reqBodyBytes)) if err != nil { - return CreateModelResponse{}, -1, err + return CreateModelResponse{}, err } for key, val := range o.customHeaders { @@ -59,17 +69,68 @@ func (o Ollama) CreateModel(reqBody CreateModelRequest) (CreateModelResponse, in resp, err := http.DefaultClient.Do(req) if err != nil { - return CreateModelResponse{}, -1, err + return CreateModelResponse{}, err } defer resp.Body.Close() - if resp.StatusCode != 200 { - return CreateModelResponse{}, resp.StatusCode, errors.New("status code is not 200") + if resp.StatusCode != http.StatusOK { + return CreateModelResponse{}, errors.New("status code is not 200") } var respBody CreateModelResponse if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return CreateModelResponse{}, -1, err + return CreateModelResponse{}, err } - return respBody, resp.StatusCode, nil + return respBody, nil +} + +func (o Ollama) CreateModelStream(reqBody CreateModelRequest, onChunk func(chunk CreateModelResponseStream)) error { + reqBody.Stream = PtrOf(true) + + reqBodyBytes, err := json.Marshal(reqBody) + if err != nil { + return err + } + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/create", o.baseUrl), bytes.NewReader(reqBodyBytes)) + if err != nil { + return err + } + + for key, val := range o.customHeaders { + req.Header.Set(key, val) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.New("status code is not 200") + } + + scanner := bufio.NewScanner(resp.Body) + + for scanner.Scan() { + line := bytes.TrimSpace(scanner.Bytes()) + + var chunk CreateModelResponseStream + if err := json.Unmarshal(line, &chunk); err != nil { + return err + } + + onChunk(chunk) + if chunk.Status == "success" { + break + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil } diff --git a/delete_model.go b/delete_model.go index e978dcd..5fc14ba 100644 --- a/delete_model.go +++ b/delete_model.go @@ -12,15 +12,15 @@ type DeleteModelRequest struct { Model string `json:"model"` } -func (o Ollama) DeleteModel(reqBody DeleteModelRequest) (int, error) { +func (o Ollama) DeleteModel(reqBody DeleteModelRequest) error { reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - return -1, err + return err } req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("%s/delete", o.baseUrl), bytes.NewReader(reqBodyBytes)) if err != nil { - return -1, err + return err } for key, val := range o.customHeaders { @@ -30,12 +30,12 @@ func (o Ollama) DeleteModel(reqBody DeleteModelRequest) (int, error) { resp, err := http.DefaultClient.Do(req) if err != nil { - return -1, err + return err } defer resp.Body.Close() - if resp.StatusCode != 200 { - return resp.StatusCode, errors.New("status code is not 200") + if resp.StatusCode != http.StatusOK { + return errors.New("status code is not 200") } - return resp.StatusCode, nil + return nil } diff --git a/generate_chat_message.go b/generate_chat_message.go index 3dcdba2..a15c5f9 100644 --- a/generate_chat_message.go +++ b/generate_chat_message.go @@ -1,6 +1,7 @@ package ollama import ( + "bufio" "bytes" "encoding/json" "errors" @@ -96,15 +97,34 @@ type GenerateChatMessageResponse struct { } `json:"logprobs"` } -func (o Ollama) GenerateChatMessage(reqBody GenerateChatMessageRequest) (GenerateChatMessageResponse, int, error) { +type GenerateChatMessageResponseStream struct { + Model string `json:"model"` + CreatedAt string `json:"created_at"` + Message struct { + Role string `json:"role"` + Content string `json:"content"` + Thinking string `json:"thinking"` + ToolCalls []struct { + Function struct { + Name string `json:"name"` + Description string `json:"description"` + Arguments map[string]any `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + Images []string `json:"images"` + } `json:"message"` + Done bool `json:"done"` +} + +func (o Ollama) GenerateChatMessage(reqBody GenerateChatMessageRequest) (GenerateChatMessageResponse, error) { reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - return GenerateChatMessageResponse{}, -1, err + return GenerateChatMessageResponse{}, err } req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/chat", o.baseUrl), bytes.NewReader(reqBodyBytes)) if err != nil { - return GenerateChatMessageResponse{}, -1, err + return GenerateChatMessageResponse{}, err } for key, val := range o.customHeaders { @@ -114,17 +134,68 @@ func (o Ollama) GenerateChatMessage(reqBody GenerateChatMessageRequest) (Generat resp, err := http.DefaultClient.Do(req) if err != nil { - return GenerateChatMessageResponse{}, -1, err + return GenerateChatMessageResponse{}, err } defer resp.Body.Close() - if resp.StatusCode != 200 { - return GenerateChatMessageResponse{}, resp.StatusCode, errors.New("status code is not 200") + if resp.StatusCode != http.StatusOK { + return GenerateChatMessageResponse{}, errors.New("status code is not 200") } var respBody GenerateChatMessageResponse if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return GenerateChatMessageResponse{}, -1, err + return GenerateChatMessageResponse{}, err } - return respBody, resp.StatusCode, nil + return respBody, nil +} + +func (o Ollama) GenerateChatMessageStream(reqBody GenerateChatMessageRequest, onChunk func(chunk GenerateChatMessageResponseStream)) error { + reqBody.Stream = PtrOf(true) + + reqBodyBytes, err := json.Marshal(reqBody) + if err != nil { + return err + } + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/chat", o.baseUrl), bytes.NewReader(reqBodyBytes)) + if err != nil { + return err + } + + for key, val := range o.customHeaders { + req.Header.Set(key, val) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.New("status code is not 200") + } + + scanner := bufio.NewScanner(resp.Body) + + for scanner.Scan() { + line := bytes.TrimSpace(scanner.Bytes()) + + var chunk GenerateChatMessageResponseStream + if err := json.Unmarshal(line, &chunk); err != nil { + return err + } + + onChunk(chunk) + if chunk.Done { + return nil + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil } diff --git a/generate_embeddings.go b/generate_embeddings.go index c38c98d..fac3fe0 100644 --- a/generate_embeddings.go +++ b/generate_embeddings.go @@ -36,15 +36,15 @@ type GenerateEmbeddingsResponse struct { PromptEvalCount int `json:"prompt_eval_count"` } -func (o Ollama) GenerateEmbeddings(reqBody GenerateEmbeddingsRequest) (GenerateEmbeddingsResponse, int, error) { +func (o Ollama) GenerateEmbeddings(reqBody GenerateEmbeddingsRequest) (GenerateEmbeddingsResponse, error) { reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - return GenerateEmbeddingsResponse{}, -1, err + return GenerateEmbeddingsResponse{}, err } req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/embed", o.baseUrl), bytes.NewReader(reqBodyBytes)) if err != nil { - return GenerateEmbeddingsResponse{}, -1, err + return GenerateEmbeddingsResponse{}, err } for key, val := range o.customHeaders { @@ -54,17 +54,17 @@ func (o Ollama) GenerateEmbeddings(reqBody GenerateEmbeddingsRequest) (GenerateE resp, err := http.DefaultClient.Do(req) if err != nil { - return GenerateEmbeddingsResponse{}, -1, err + return GenerateEmbeddingsResponse{}, err } defer resp.Body.Close() - if resp.StatusCode != 200 { - return GenerateEmbeddingsResponse{}, resp.StatusCode, errors.New("status code is not 200") + if resp.StatusCode != http.StatusOK { + return GenerateEmbeddingsResponse{}, errors.New("status code is not 200") } var respBody GenerateEmbeddingsResponse if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return GenerateEmbeddingsResponse{}, -1, err + return GenerateEmbeddingsResponse{}, err } - return respBody, resp.StatusCode, nil + return respBody, nil } diff --git a/generate_response.go b/generate_response.go index 0e45663..3545cae 100644 --- a/generate_response.go +++ b/generate_response.go @@ -1,6 +1,7 @@ package ollama import ( + "bufio" "bytes" "encoding/json" "errors" @@ -16,6 +17,7 @@ type GenerateResponseRequest struct { Format string `json:"format,omitempty"` System string `json:"system,omitempty"` Stream *bool `json:"stream,omitempty"` + Think *bool `json:"think,omitempty"` Raw *bool `json:"raw,omitempty"` KeepAlive string `json:"keep_alive,omitempty"` Options *GenerateResponseRequestOptions `json:"options,omitempty"` @@ -59,15 +61,32 @@ type GenerateResponseResponse struct { } `json:"logprobs"` } -func (o Ollama) GenerateResponse(reqBody GenerateResponseRequest) (GenerateResponseResponse, int, error) { +type GenerateResponseResponseStream struct { + Model string `json:"model"` + CreatedAt string `json:"created_at"` + Response string `json:"response"` + Thinking string `json:"thinking"` + Done bool `json:"done"` + DoneReason string `json:"done_reason"` + TotalDuration int `json:"total_duration"` + LoadDuration int `json:"load_duration"` + PromptEvalCount int `json:"prompt_eval_count"` + PromptEvalDuration int `json:"prompt_eval_duration"` + EvalCount int `json:"eval_count"` + EvalDuration int `json:"eval_duration"` +} + +func (o Ollama) GenerateResponse(reqBody GenerateResponseRequest) (GenerateResponseResponse, error) { + reqBody.Stream = PtrOf(false) + reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - return GenerateResponseResponse{}, -1, err + return GenerateResponseResponse{}, err } req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/generate", o.baseUrl), bytes.NewReader(reqBodyBytes)) if err != nil { - return GenerateResponseResponse{}, -1, err + return GenerateResponseResponse{}, err } for key, val := range o.customHeaders { @@ -77,17 +96,68 @@ func (o Ollama) GenerateResponse(reqBody GenerateResponseRequest) (GenerateRespo resp, err := http.DefaultClient.Do(req) if err != nil { - return GenerateResponseResponse{}, -1, err + return GenerateResponseResponse{}, err } defer resp.Body.Close() - if resp.StatusCode != 200 { - return GenerateResponseResponse{}, resp.StatusCode, errors.New("status code is not 200") + if resp.StatusCode != http.StatusOK { + return GenerateResponseResponse{}, errors.New("status code is not 200") } var respBody GenerateResponseResponse if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return GenerateResponseResponse{}, -1, err + return GenerateResponseResponse{}, err } - return respBody, resp.StatusCode, nil + return respBody, nil +} + +func (o Ollama) GenerateResponseStream(reqBody GenerateResponseRequest, onChunk func(chunk GenerateResponseResponseStream)) error { + reqBody.Stream = PtrOf(true) + + reqBodyBytes, err := json.Marshal(reqBody) + if err != nil { + return err + } + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/generate", o.baseUrl), bytes.NewReader(reqBodyBytes)) + if err != nil { + return err + } + + for key, val := range o.customHeaders { + req.Header.Set(key, val) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.New("status code is not 200") + } + + scanner := bufio.NewScanner(resp.Body) + + for scanner.Scan() { + line := bytes.TrimSpace(scanner.Bytes()) + + var chunk GenerateResponseResponseStream + if err := json.Unmarshal(line, &chunk); err != nil { + return err + } + + onChunk(chunk) + if chunk.Done { + return nil + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil } diff --git a/get_version.go b/get_version.go index 97af3ee..1d642fe 100644 --- a/get_version.go +++ b/get_version.go @@ -11,10 +11,10 @@ type GetVersionResponse struct { Version string `json:"version"` } -func (o Ollama) GetVersion() (GetVersionResponse, int, error) { +func (o Ollama) GetVersion() (GetVersionResponse, error) { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/version", o.baseUrl), nil) if err != nil { - return GetVersionResponse{}, -1, err + return GetVersionResponse{}, err } for key, val := range o.customHeaders { @@ -23,17 +23,17 @@ func (o Ollama) GetVersion() (GetVersionResponse, int, error) { resp, err := http.DefaultClient.Do(req) if err != nil { - return GetVersionResponse{}, -1, err + return GetVersionResponse{}, err } defer resp.Body.Close() - if resp.StatusCode != 200 { - return GetVersionResponse{}, resp.StatusCode, errors.New("status code is not 200") + if resp.StatusCode != http.StatusOK { + return GetVersionResponse{}, errors.New("status code is not 200") } var respBody GetVersionResponse if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return GetVersionResponse{}, -1, err + return GetVersionResponse{}, err } - return respBody, resp.StatusCode, nil + return respBody, nil } diff --git a/list_models.go b/list_models.go index ec7dd76..8032cd6 100644 --- a/list_models.go +++ b/list_models.go @@ -26,10 +26,10 @@ type ListModelsResponse struct { } `json:"models"` } -func (o Ollama) ListModels() (ListModelsResponse, int, error) { +func (o Ollama) ListModels() (ListModelsResponse, error) { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/tags", o.baseUrl), nil) if err != nil { - return ListModelsResponse{}, -1, err + return ListModelsResponse{}, err } for key, val := range o.customHeaders { @@ -38,17 +38,17 @@ func (o Ollama) ListModels() (ListModelsResponse, int, error) { resp, err := http.DefaultClient.Do(req) if err != nil { - return ListModelsResponse{}, -1, err + return ListModelsResponse{}, err } defer resp.Body.Close() - if resp.StatusCode != 200 { - return ListModelsResponse{}, resp.StatusCode, errors.New("status code is not 200") + if resp.StatusCode != http.StatusOK { + return ListModelsResponse{}, errors.New("status code is not 200") } var respBody ListModelsResponse if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return ListModelsResponse{}, -1, err + return ListModelsResponse{}, err } - return respBody, resp.StatusCode, nil + return respBody, nil } diff --git a/list_running_models.go b/list_running_models.go index a8ca066..c553dc9 100644 --- a/list_running_models.go +++ b/list_running_models.go @@ -26,10 +26,10 @@ type ListRunningModelsResponse struct { } `json:"models"` } -func (o Ollama) ListRunningModels() (ListRunningModelsResponse, int, error) { +func (o Ollama) ListRunningModels() (ListRunningModelsResponse, error) { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/ps", o.baseUrl), nil) if err != nil { - return ListRunningModelsResponse{}, -1, err + return ListRunningModelsResponse{}, err } for key, val := range o.customHeaders { @@ -38,17 +38,17 @@ func (o Ollama) ListRunningModels() (ListRunningModelsResponse, int, error) { resp, err := http.DefaultClient.Do(req) if err != nil { - return ListRunningModelsResponse{}, -1, err + return ListRunningModelsResponse{}, err } defer resp.Body.Close() - if resp.StatusCode != 200 { - return ListRunningModelsResponse{}, resp.StatusCode, errors.New("status code is not 200") + if resp.StatusCode != http.StatusOK { + return ListRunningModelsResponse{}, errors.New("status code is not 200") } var respBody ListRunningModelsResponse if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return ListRunningModelsResponse{}, -1, err + return ListRunningModelsResponse{}, err } - return respBody, resp.StatusCode, nil + return respBody, nil } diff --git a/pull_model.go b/pull_model.go index f8ebe17..b04272c 100644 --- a/pull_model.go +++ b/pull_model.go @@ -1,6 +1,7 @@ package ollama import ( + "bufio" "bytes" "encoding/json" "errors" @@ -18,15 +19,23 @@ type PullModelResponse struct { Status string `json:"status"` } -func (o Ollama) PullModel(reqBody PullModelRequest) (PullModelResponse, int, error) { +type PullModelResponseStream struct { + Status string `json:"status"` + Digest string `json:"digest"` + Total int `json:"total"` + Completed int `json:"completed"` +} + +func (o Ollama) PullModel(reqBody PullModelRequest) (PullModelResponse, error) { + reqBody.Stream = PtrOf(false) reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - return PullModelResponse{}, -1, err + return PullModelResponse{}, err } req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/pull", o.baseUrl), bytes.NewReader(reqBodyBytes)) if err != nil { - return PullModelResponse{}, -1, err + return PullModelResponse{}, err } for key, val := range o.customHeaders { @@ -36,17 +45,67 @@ func (o Ollama) PullModel(reqBody PullModelRequest) (PullModelResponse, int, err resp, err := http.DefaultClient.Do(req) if err != nil { - return PullModelResponse{}, -1, err + return PullModelResponse{}, err } defer resp.Body.Close() - if resp.StatusCode != 200 { - return PullModelResponse{}, resp.StatusCode, errors.New("status code is not 200") + if resp.StatusCode != http.StatusOK { + return PullModelResponse{}, errors.New("status code is not 200") } var respBody PullModelResponse if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return PullModelResponse{}, -1, err + return PullModelResponse{}, err } - return respBody, resp.StatusCode, nil + return respBody, nil +} + +func (o Ollama) PullModelStream(reqBody PullModelRequest, onChunk func(chunk PullModelResponseStream)) error { + reqBody.Stream = PtrOf(true) + reqBodyBytes, err := json.Marshal(reqBody) + if err != nil { + return err + } + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/pull", o.baseUrl), bytes.NewReader(reqBodyBytes)) + if err != nil { + return err + } + + for key, val := range o.customHeaders { + req.Header.Set(key, val) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.New("status code is not 200") + } + + scanner := bufio.NewScanner(resp.Body) + + for scanner.Scan() { + line := bytes.TrimSpace(scanner.Bytes()) + + var chunk PullModelResponseStream + if err := json.Unmarshal(line, &chunk); err != nil { + return err + } + + onChunk(chunk) + if chunk.Status == "success" { + break + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil } diff --git a/push_model.go b/push_model.go index bdef991..6c9be7f 100644 --- a/push_model.go +++ b/push_model.go @@ -1,6 +1,7 @@ package ollama import ( + "bufio" "bytes" "encoding/json" "errors" @@ -18,15 +19,23 @@ type PushModelResponse struct { Status string `json:"status"` } -func (o Ollama) PushModel(reqBody PushModelRequest) (PushModelResponse, int, error) { +type PushModelResponseStream struct { + Status string `json:"status"` + Digest string `json:"digest"` + Total int `json:"total"` + Completed int `json:"completed"` +} + +func (o Ollama) PushModel(reqBody PushModelRequest) (PushModelResponse, error) { + reqBody.Stream = PtrOf(false) reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - return PushModelResponse{}, -1, err + return PushModelResponse{}, err } - req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/pull", o.baseUrl), bytes.NewReader(reqBodyBytes)) + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/push", o.baseUrl), bytes.NewReader(reqBodyBytes)) if err != nil { - return PushModelResponse{}, -1, err + return PushModelResponse{}, err } for key, val := range o.customHeaders { @@ -36,17 +45,68 @@ func (o Ollama) PushModel(reqBody PushModelRequest) (PushModelResponse, int, err resp, err := http.DefaultClient.Do(req) if err != nil { - return PushModelResponse{}, -1, err + return PushModelResponse{}, err } defer resp.Body.Close() - if resp.StatusCode != 200 { - return PushModelResponse{}, resp.StatusCode, errors.New("status code is not 200") + if resp.StatusCode != http.StatusOK { + return PushModelResponse{}, errors.New("status code is not 200") } var respBody PushModelResponse if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return PushModelResponse{}, -1, err + return PushModelResponse{}, err } - return respBody, resp.StatusCode, nil + return respBody, nil +} + +func (o Ollama) PushModelStream(reqBody PushModelRequest, onChunk func(chunk PushModelResponseStream)) error { + reqBody.Stream = PtrOf(true) + + reqBodyBytes, err := json.Marshal(reqBody) + if err != nil { + return err + } + + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/push", o.baseUrl), bytes.NewReader(reqBodyBytes)) + if err != nil { + return err + } + + for key, val := range o.customHeaders { + req.Header.Set(key, val) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.New("status code is not 200") + } + + scanner := bufio.NewScanner(resp.Body) + + for scanner.Scan() { + line := bytes.TrimSpace(scanner.Bytes()) + + var chunk PushModelResponseStream + if err := json.Unmarshal(line, &chunk); err != nil { + return err + } + + onChunk(chunk) + if chunk.Status == "success" { + break + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil } diff --git a/show_model_details.go b/show_model_details.go index 4efbc87..4d68e1c 100644 --- a/show_model_details.go +++ b/show_model_details.go @@ -29,15 +29,15 @@ type ShowModelDetailsResponse struct { ModelInfo map[string]any `json:"model_info"` } -func (o Ollama) ShowModelDetails(reqBody ShowModelDetailsRequest) (ShowModelDetailsResponse, int, error) { +func (o Ollama) ShowModelDetails(reqBody ShowModelDetailsRequest) (ShowModelDetailsResponse, error) { reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - return ShowModelDetailsResponse{}, -1, err + return ShowModelDetailsResponse{}, err } req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/show", o.baseUrl), bytes.NewReader(reqBodyBytes)) if err != nil { - return ShowModelDetailsResponse{}, -1, err + return ShowModelDetailsResponse{}, err } for key, val := range o.customHeaders { @@ -47,17 +47,17 @@ func (o Ollama) ShowModelDetails(reqBody ShowModelDetailsRequest) (ShowModelDeta resp, err := http.DefaultClient.Do(req) if err != nil { - return ShowModelDetailsResponse{}, -1, err + return ShowModelDetailsResponse{}, err } defer resp.Body.Close() - if resp.StatusCode != 200 { - return ShowModelDetailsResponse{}, resp.StatusCode, errors.New("status code is not 200") + if resp.StatusCode != http.StatusOK { + return ShowModelDetailsResponse{}, errors.New("status code is not 200") } var respBody ShowModelDetailsResponse if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { - return ShowModelDetailsResponse{}, -1, err + return ShowModelDetailsResponse{}, err } - return respBody, resp.StatusCode, nil + return respBody, nil }