~bsprague/gpt4-tui

b6cb7670e8c1a7d4ac3ea871d0cfe151ff8d10ac — Brandon Sprague 3 months ago b1f1f15
Add Claude support, update history format, better summary prompt

A large refactoring to support multiple backends with similar, but not identical, APIs. Backend selection is now done with `--backend={gpt4,claude}`, and the `--api_key` parameter should supply the API key for the chosen service.
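
For example (binary name `gpt4-tui` and environment variables are illustrative, not part of this change):

    gpt4-tui --backend=gpt4 --api_key="$OPENAI_API_KEY"
    gpt4-tui --backend=claude --api_key="$ANTHROPIC_API_KEY"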

The history file has been changed from `titles` to `history`, and is now formatted as one JSON object per line so that more metadata can be stored (e.g. which bot you chatted with).
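
A line in the new `history` file looks roughly like this (values illustrative; fields match the `convoMetadata` struct below):

    {"filename":"2024-03-05_12-34-56-convo.json","title":"Example chat","backend":"Anthropic","model":"claude-3-opus-20240229","at":"2024-03-05T12:34:56Z"}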

The summary prompt has also been updated: it didn't work well with Claude, and it frequently wrapped the title in quotes on both GPT-4 and Claude 3.
2 files changed, 311 insertions(+), 51 deletions(-)

M gpt/gpt.go
M main.go
M gpt/gpt.go => gpt/gpt.go +290 -48
@@ 15,15 15,19 @@ import (
	"time"
)

const API_ENDPOINT = "https://api.openai.com/v1/chat/completions"

var dataPrefix = []byte("data: ")
var doneMessage = []byte("[DONE]")
var (
	eventPrefix     = []byte("event: ")
	dataPrefix      = []byte("data: ")
	gpt4DoneMessage = []byte("[DONE]")
	claudeDoneEvent = []byte("message_stop")
)

type Client struct {
	apiKey  string
	dataDir string

	bot Bot

	pastConvos []PastConvo

	// Current convo information


@@ 35,8 39,157 @@ type Client struct {
	dirty    bool // This convo has been edited.
}

func NewClient(apiKey, dataDir string) (*Client, error) {
	c := &Client{apiKey: apiKey, dataDir: dataDir}
type Bot struct {
	name    string
	company string

	endpoint          string
	model             string
	summaryModel      string
	maxTokens         int
	keyHeader         string
	keyValue          func(key string) string
	additionalHeaders map[string]string
	parser            func(io.Reader) Parser
}

type Parser interface {
	Next() (string, error)
}

type gpt4Parser struct {
	r *EventStreamReader
}

func (g *gpt4Parser) Next() (string, error) {
	dat, err := g.r.ReadEvent()
	if errors.Is(err, io.EOF) {
		return "", io.EOF
	} else if err != nil {
		return "", fmt.Errorf("failed to read response: %w", err)
	}
	if !bytes.HasPrefix(dat, dataPrefix) {
		return "", fmt.Errorf("malformed message %q", string(dat))
	}
	if bytes.Equal(dat[len(dataPrefix):], gpt4DoneMessage) {
		return "", io.EOF
	}

	var chunk Chunk
	if err := json.Unmarshal(dat[len(dataPrefix):], &chunk); err != nil {
		return "", fmt.Errorf("failed to parse chunk: %w", err)
	}
	return chunk.Choices[0].Delta.Content, nil
}

func GPT4Parser(r io.Reader) Parser {
	return &gpt4Parser{
		r: NewEventStreamReader(r, 8192),
	}
}

type claudeParser struct {
	r *EventStreamReader
}

func (g *claudeParser) Next() (string, error) {
	for {
		dat, err := g.r.ReadEvent()
		if errors.Is(err, io.EOF) {
			return "", io.EOF
		} else if err != nil {
			return "", fmt.Errorf("failed to read response: %w", err)
		}
		if !bytes.HasPrefix(dat, eventPrefix) {
			return "", fmt.Errorf("malformed message %q", string(dat))
		}
		newline := bytes.IndexByte(dat, '\n')
		if newline == -1 {
			return "", errors.New("event had no newline endpoint")
		}
		evt := dat[len(eventPrefix):newline]

		if bytes.Equal(evt, []byte("message_stop")) {
			return "", io.EOF
		}
		if !bytes.Equal(evt, []byte("content_block_delta")) {
			// A ping, or content_block_start, or message_delta, or something else we can _probably_ ignore
			// See https://docs.anthropic.com/claude/reference/messages-streaming
			continue
		}

		msg := dat[newline+1:]
		if !bytes.HasPrefix(msg, dataPrefix) {
			return "", fmt.Errorf("malformed data %q", string(msg))
		}

		var data ClaudeData
		if err := json.Unmarshal(msg[len(dataPrefix):], &data); err != nil {
			return "", fmt.Errorf("failed to parse chunk: %w", err)
		}
		if data.Type != "content_block_delta" {
			return "", fmt.Errorf("unexpected data type %q", data.Type)
		}
		if data.Delta.Type != "text_delta" {
			return "", fmt.Errorf("unexpected delta type %q", data.Delta.Type)
		}
		return data.Delta.Text, nil
	}
}

type ClaudeDelta struct {
	Type string `json:"type"`
	Text string `json:"text"`
}

type ClaudeData struct {
	Type  string      `json:"type"`
	Index int         `json:"index"`
	Delta ClaudeDelta `json:"delta"`
}

func ClaudeParser(r io.Reader) Parser {
	return &claudeParser{
		r: NewEventStreamReader(r, 8192),
	}
}

var GPT4 = Bot{
	name:         "OpenAI's GPT-4",
	company:      "OpenAI",
	endpoint:     "https://api.openai.com/v1/chat/completions",
	model:        "gpt-4-1106-preview", // "gpt-4"
	summaryModel: "gpt-4-1106-preview", // "gpt-4"
	keyHeader:    "Authorization",
	keyValue:     func(key string) string { return "Bearer " + key },
	additionalHeaders: map[string]string{
		"User-Agent": "OpenAI-Chat-CLI",
	},
	parser: GPT4Parser,
}

var Claude = Bot{
	name:         "Anthropic's Claude 3",
	company:      "Anthropic",
	endpoint:     "https://api.anthropic.com/v1/messages",
	model:        "claude-3-opus-20240229",
	summaryModel: "claude-3-sonnet-20240229",
	keyHeader:    "x-api-key",
	maxTokens:    4096,
	keyValue:     func(key string) string { return key },
	additionalHeaders: map[string]string{
		"anthropic-version": "2023-06-01",
		"anthropic-beta":    "messages-2023-12-15",
	},
	parser: ClaudeParser,
}

func (c *Client) Backend() string {
	return c.bot.name
}

func NewClient(apiKey, dataDir string, bot Bot) (*Client, error) {
	c := &Client{apiKey: apiKey, dataDir: dataDir, bot: bot}
	pcs, err := c.loadPastConvos()
	if err != nil {
		return nil, fmt.Errorf("failed to load existing convos: %w", err)


@@ 65,9 218,10 @@ type Message struct {
}

type Payload struct {
	Model    string    `json:"model"`
	Messages []Message `json:"messages"`
	Stream   bool      `json:"stream"`
	Model     string    `json:"model"`
	MaxTokens int       `json:"max_tokens,omitempty"`
	Messages  []Message `json:"messages"`
	Stream    bool      `json:"stream"`
}

type Delta struct {


@@ 89,12 243,13 @@ func (c *Client) GetTitle(respHandler func(msg string)) error {
	}

	data := Payload{
		Model:  "gpt-4-1106-preview", // "gpt-4"
		Stream: true,
		Model:     c.bot.summaryModel,
		Stream:    true,
		MaxTokens: c.bot.maxTokens,
		Messages: []Message{
			{
				Role: "user",
				Content: fmt.Sprintf(`Summarize the following conversation prompt into a short title.
				Content: fmt.Sprintf(`Summarize the following conversation prompt into a short title. Do not wrap in quotes, do not include anything else in your response except the short title.

"""
%s


@@ 119,9 274,10 @@ func (c *Client) Chat(prompt string, respHandler func(msg string)) error {
		Content: prompt,
	})
	data := Payload{
		Model:    "gpt-4",
		Stream:   true,
		Messages: c.messages,
		Model:     c.bot.model,
		Stream:    true,
		MaxTokens: c.bot.maxTokens,
		Messages:  c.messages,
	}
	msg, err := c.sendRequest(data, respHandler)
	if err != nil {


@@ 140,14 296,16 @@ func (c *Client) sendRequest(data Payload, respHandler func(msg string)) (string
		return "", fmt.Errorf("failed to marshal payload: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, API_ENDPOINT, bytes.NewReader(payloadBytes))
	req, err := http.NewRequest(http.MethodPost, c.bot.endpoint, bytes.NewReader(payloadBytes))
	if err != nil {
		return "", fmt.Errorf("failed to send request to OpenAI API: %w", err)
	}

	req.Header.Set("Authorization", "Bearer "+c.apiKey)
	req.Header.Set(c.bot.keyHeader, c.bot.keyValue(c.apiKey))
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("User-Agent", "OpenAI-Chat-CLI")
	for k, v := range c.bot.additionalHeaders {
		req.Header.Set(k, v)
	}

	client := &http.Client{}
	resp, err := client.Do(req)


@@ 156,27 314,19 @@ func (c *Client) sendRequest(data Payload, respHandler func(msg string)) (string
	}
	defer resp.Body.Close()

	r := NewEventStreamReader(resp.Body, 8192)
	r := c.bot.parser(resp.Body)
	var msg strings.Builder
	for {
		dat, err := r.ReadEvent()
		content, err := r.Next()
		if errors.Is(err, io.EOF) {
			break
		} else if err != nil {
			return "", fmt.Errorf("failed to read response: %w", err)
			return "", fmt.Errorf("error during parsing: %w", err)
		}
		if !bytes.HasPrefix(dat, dataPrefix) {
			return "", fmt.Errorf("malformed message %q", string(dat))
		}
		if bytes.Equal(dat[len(dataPrefix):], doneMessage) {
			break
		if content == "" {
			continue
		}

		var chunk Chunk
		if err := json.Unmarshal(dat[len(dataPrefix):], &chunk); err != nil {
			return "", fmt.Errorf("failed to parse chunk: %w", err)
		}
		content := chunk.Choices[0].Delta.Content
		msg.WriteString(content)
		respHandler(content)
	}


@@ 288,6 438,10 @@ func (c *Client) titlesPath() string {
	return filepath.Join(c.dataDir, "titles")
}

func (c *Client) historyPath() string {
	return filepath.Join(c.dataDir, "history")
}

type ConvoID string

type PastConvo struct {


@@ 296,49 450,137 @@ type PastConvo struct {
	At    time.Time
}

func (c *Client) loadPastConvos() ([]PastConvo, error) {
	f, err := os.Open(c.titlesPath())
	if errors.Is(err, fs.ErrNotExist) {
		return []PastConvo{}, nil
	} else if err != nil {
		return nil, fmt.Errorf("failed to open titles file: %w", err)
	}
	defer f.Close()

	sc := bufio.NewScanner(f)
func loadPastConvos_OldFormat(r io.Reader) ([]PastConvo, error) {
	sc := bufio.NewScanner(r)
	var out []PastConvo
	for sc.Scan() {
		path, title, ok := strings.Cut(sc.Text(), " ---- ")
		dat := sc.Bytes()
		path, title, ok := bytes.Cut(dat, []byte(" ---- "))
		if !ok {
			return nil, fmt.Errorf("malformed title file had line %q without delimiter", sc.Text())
		}
		at, err := time.Parse("2006-01-02_15-04-05", strings.TrimSuffix(path, "-convo.json"))
		at, err := time.Parse("2006-01-02_15-04-05", string(bytes.TrimSuffix(path, []byte("-convo.json"))))
		if err != nil {
			return nil, fmt.Errorf("failed to parse time from convo path %q: %w", path, err)
		}
		out = append(out, PastConvo{
			ID:    ConvoID(path),
			Title: title,
			Title: string(title),
			At:    at,
		})
	}
	return out, nil
}

func (c *Client) loadPastConvos_NewFormat() ([]PastConvo, error) {
	f, err := os.Open(c.historyPath())
	if errors.Is(err, fs.ErrNotExist) {
		return []PastConvo{}, nil
	} else if err != nil {
		return nil, fmt.Errorf("failed to open history file: %w", err)
	}
	defer f.Close()

	sc := bufio.NewScanner(f)
	var out []PastConvo
	for sc.Scan() {
		var md convoMetadata
		dat := sc.Bytes()
		if err := json.Unmarshal(dat, &md); err != nil {
			return nil, fmt.Errorf("failed to unmarshal convo metadata: %w", err)
		}

		out = append(out, PastConvo{
			ID:    md.Filename,
			Title: md.Title,
			At:    md.At,
		})
	}
	return out, nil
}

func (c *Client) loadPastConvos() ([]PastConvo, error) {
	f, err := os.Open(c.titlesPath())
	if errors.Is(err, fs.ErrNotExist) {
		return c.loadPastConvos_NewFormat()
	} else if err != nil {
		return nil, fmt.Errorf("failed to open titles file: %w", err)
	}
	defer f.Close()

	// Load the old format, convert to the new format
	pcs, err := loadPastConvos_OldFormat(f)
	if err != nil {
		return nil, fmt.Errorf("failed to load convos from 'titles' file: %w", err)
	}

	// Close the old file.
	if err := f.Close(); err != nil {
		return nil, fmt.Errorf("failed to close title file: %w", err)
	}

	return out, nil
	hf, err := os.Create(c.historyPath())
	if err != nil {
		return nil, fmt.Errorf("failed to truncate title file: %w", err)
	}
	defer hf.Close() // Best-effort

	for _, pc := range pcs {
		md := convoMetadata{
			Filename: pc.ID,
			Title:    pc.Title,
			At:       pc.At,
			// Just assume old things were all GPT-4
			Backend: "OpenAI",
			Model:   "gpt-4",
		}
		if err := json.NewEncoder(hf).Encode(md); err != nil {
			return nil, fmt.Errorf("failed to write convo metadata to file during conversion: %w", err)
		}
	}

	if err := hf.Close(); err != nil {
		return nil, fmt.Errorf("failed to close history file")
	}

	// Delete the titles file
	if err := os.Remove(c.titlesPath()); err != nil {
		return nil, fmt.Errorf("failed to delete titles file: %w", err)
	}

	return pcs, nil
}

type convoMetadata struct {
	Filename ConvoID   `json:"filename"`
	Title    string    `json:"title"`
	Backend  string    `json:"backend"`
	Model    string    `json:"model"`
	At       time.Time `json:"at"`
}

func (c *Client) addTitle(id ConvoID) error {
	f, err := os.OpenFile(c.titlesPath(), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644)
	f, err := os.OpenFile(c.historyPath(), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644)
	if err != nil {
		return fmt.Errorf("failed to open or create titles file: %w", err)
	}
	defer f.Close() // Best-effort

	if _, err := io.WriteString(f, string(id)+" ---- "+c.title+"\n"); err != nil {
		return fmt.Errorf("failed to write title: %w", err)
	at, err := time.Parse("2006-01-02_15-04-05", strings.TrimSuffix(string(id), "-convo.json"))
	if err != nil {
		return fmt.Errorf("failed to parse ID as time: %w", err)
	}

	md := convoMetadata{
		Filename: id,
		Title:    c.title,
		Backend:  c.bot.company,
		Model:    c.bot.model,
		At:       at,
	}

	if err := json.NewEncoder(f).Encode(md); err != nil {
		return fmt.Errorf("failed to write convo metadata to file: %w", err)
	}

	if err := f.Close(); err != nil {

M main.go => main.go +21 -3
@@ 46,6 46,7 @@ var (
)

type Chatter interface {
	Backend() string
	SetConvo(id gpt.ConvoID) ([]gpt.Message, error)
	GetTitle(respHandler func(msg string)) error
	Chat(prompt string, respHandler func(msg string)) error


@@ 65,6 66,10 @@ func newTestChatter(client *gpt.Client) (*testChatter, error) {
	}, nil
}

func (t *testChatter) Backend() string {
	return "Test"
}

func (t *testChatter) SetConvo(id gpt.ConvoID) ([]gpt.Message, error) {
	msgs, err := t.gpt.SetConvo(id)
	if err != nil {


@@ 186,7 191,7 @@ func initialModel(chatter Chatter, pastConvos []gpt.PastConvo) *model {
	return &model{
		convoList: convoList,
		page:      chatPage,
		convoName: "New Conversation",
		convoName: "New Conversation with " + chatter.Backend(),
		chatter:   chatter,
		textarea:  ti,
		err:       nil,


@@ 471,7 476,8 @@ func (m *model) View() string {
func run(args []string) error {
	fs := flag.NewFlagSet(args[0], flag.ContinueOnError)
	var (
		apiKey   = fs.String("api_key", "", "The API key to use with the OpenAI API")
		apiKey   = fs.String("api_key", "", "The API key to use with the backend API")
		backend  = fs.String("backend", "", "The name of the backend to use, either 'gpt4' or 'claude'")
		testMode = fs.Bool("test_mode", false, "If true, don't send any real messages.")
	)
	if err := fs.Parse(args[1:]); err != nil {


@@ 488,7 494,19 @@ func run(args []string) error {
	}

	dataDir := filepath.Join(u.HomeDir, ".local", "share", "gpt4-client")
	client, err := gpt.NewClient(*apiKey, dataDir)

	var bot gpt.Bot
	switch *backend {
	case "gpt4":
		bot = gpt.GPT4
	case "claude":
		bot = gpt.Claude
	case "":
		return errors.New("no --backend specified, use 'gpt4' or 'claude'")
	default:
		return fmt.Errorf("unknown backend %q", *backend)
	}
	client, err := gpt.NewClient(*apiKey, dataDir, bot)
	if err != nil {
		return fmt.Errorf("failed to init GPT client: %w", err)
	}