File size: 3,684 Bytes
ed5f2a4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
package monica
import (
"bufio"
"fmt"
"io"
"log"
"sync"
"time"
"monica-proxy/internal/types"
"monica-proxy/internal/utils"
"net/http"
"strings"
"github.com/bytedance/sonic"
"github.com/sashabaranov/go-openai"
)
// Stream conversion settings.
const (
	// sseObject is the "object" value stamped on every emitted chunk.
	sseObject = "chat.completion.chunk"
	// sseFinish is the payload of the terminating "data: [DONE]" event.
	sseFinish = "[DONE]"
	// flushInterval is how often the periodic flusher pushes buffered output.
	flushInterval = time.Second / 10
	// bufferSize is the size, in bytes, of the bufio reader and writer.
	bufferSize = 4 << 10
)
// SSEData is one decoded JSON event from Monica's SSE stream.
type SSEData struct {
	Text     string `json:"text"`     // text fragment carried by this event
	Finished bool   `json:"finished"` // true on the event that ends the reply
}
// sseDataPool recycles *SSEData values across stream events. A value handed
// out by Get may still hold fields from its previous use, so callers should
// reset it before decoding into it.
var sseDataPool = sync.Pool{
	New: func() interface{} { return new(SSEData) },
}
// StreamMonicaSSEToClient 将 Monica SSE 转成前端可用的流
func StreamMonicaSSEToClient(model string, w io.Writer, r io.Reader) error {
reader := bufio.NewReaderSize(r, bufferSize)
writer := bufio.NewWriterSize(w, bufferSize)
defer writer.Flush()
chatId := utils.RandStringUsingMathRand(29)
now := time.Now().Unix()
fingerprint := utils.RandStringUsingMathRand(10)
// 创建一个定时刷新的 ticker
ticker := time.NewTicker(flushInterval)
defer ticker.Stop()
// 创建一个 done channel 用于清理
done := make(chan struct{})
defer close(done)
// 启动一个 goroutine 定期刷新缓冲区
go func() {
for {
select {
case <-ticker.C:
if f, ok := w.(http.Flusher); ok {
writer.Flush()
f.Flush()
}
case <-done:
return
}
}
}()
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
return nil
}
return fmt.Errorf("read error: %w", err)
}
// Monica SSE 的行前缀一般是 "data: "
if !strings.HasPrefix(line, "data: ") {
continue
}
jsonStr := strings.TrimPrefix(line, "data: ")
if jsonStr == "" {
continue
}
// 从对象池获取 SSEData
sseObj := sseDataPool.Get().(*SSEData)
if err := sonic.UnmarshalString(jsonStr, sseObj); err != nil {
sseDataPool.Put(sseObj)
// 记录错误但继续处理
log.Printf("Error unmarshaling SSE data: %v", err)
continue
}
// 将拆分后的文字写回
var sseMsg types.ChatCompletionStreamResponse
if sseObj.Finished {
sseMsg = types.ChatCompletionStreamResponse{
ID: "chatcmpl-" + chatId,
Object: sseObject,
Created: now,
Model: model,
Choices: []types.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Role: openai.ChatMessageRoleAssistant,
},
FinishReason: openai.FinishReasonStop,
},
},
}
} else {
sseMsg = types.ChatCompletionStreamResponse{
ID: "chatcmpl-" + chatId,
Object: sseObject,
SystemFingerprint: fingerprint,
Created: now,
Model: model,
Choices: []types.ChatCompletionStreamChoice{
{
Index: 0,
Delta: openai.ChatCompletionStreamChoiceDelta{
Role: openai.ChatMessageRoleAssistant,
Content: sseObj.Text,
},
FinishReason: openai.FinishReasonNull,
},
},
}
}
var sb strings.Builder
sb.WriteString("data: ")
sendLine, _ := sonic.MarshalString(sseMsg)
sb.WriteString(sendLine)
sb.WriteString("\n\n")
// 写入缓冲区
if _, err := writer.WriteString(sb.String()); err != nil {
sseDataPool.Put(sseObj)
return fmt.Errorf("write error: %w", err)
}
// 如果发现 finished=true,就可以结束
if sseObj.Finished {
writer.WriteString(fmt.Sprintf("data: %s\n\n", sseFinish))
writer.Flush()
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
sseDataPool.Put(sseObj)
return nil
}
// 将对象放回对象池
sseDataPool.Put(sseObj)
}
}
|