coo7 commited on
Commit
ed5f2a4
·
verified ·
1 Parent(s): 4930319

Upload sse.go

Browse files
Files changed (1) hide show
  1. internal/monica/sse.go +164 -0
internal/monica/sse.go ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package monica
2
+
3
+ import (
4
+ "bufio"
5
+ "fmt"
6
+ "io"
7
+ "log"
8
+ "sync"
9
+ "time"
10
+
11
+ "monica-proxy/internal/types"
12
+ "monica-proxy/internal/utils"
13
+ "net/http"
14
+ "strings"
15
+
16
+ "github.com/bytedance/sonic"
17
+ "github.com/sashabaranov/go-openai"
18
+ )
19
+
20
+ const (
21
+ sseObject = "chat.completion.chunk"
22
+ sseFinish = "[DONE]"
23
+ flushInterval = 100 * time.Millisecond // 刷新间隔
24
+ bufferSize = 4096 // 缓冲区大小
25
+ )
26
+
27
+ // SSEData 用于解析 Monica SSE json
28
+ type SSEData struct {
29
+ Text string `json:"text"`
30
+ Finished bool `json:"finished"`
31
+ }
32
+
33
+ var sseDataPool = sync.Pool{
34
+ New: func() interface{} {
35
+ return &SSEData{}
36
+ },
37
+ }
38
+
39
+ // StreamMonicaSSEToClient 将 Monica SSE 转成前端可用的流
40
+ func StreamMonicaSSEToClient(model string, w io.Writer, r io.Reader) error {
41
+ reader := bufio.NewReaderSize(r, bufferSize)
42
+ writer := bufio.NewWriterSize(w, bufferSize)
43
+ defer writer.Flush()
44
+
45
+ chatId := utils.RandStringUsingMathRand(29)
46
+ now := time.Now().Unix()
47
+ fingerprint := utils.RandStringUsingMathRand(10)
48
+
49
+ // 创建一个定时刷新的 ticker
50
+ ticker := time.NewTicker(flushInterval)
51
+ defer ticker.Stop()
52
+
53
+ // 创建一个 done channel 用于清理
54
+ done := make(chan struct{})
55
+ defer close(done)
56
+
57
+ // 启动一个 goroutine 定期刷新缓冲区
58
+ go func() {
59
+ for {
60
+ select {
61
+ case <-ticker.C:
62
+ if f, ok := w.(http.Flusher); ok {
63
+ writer.Flush()
64
+ f.Flush()
65
+ }
66
+ case <-done:
67
+ return
68
+ }
69
+ }
70
+ }()
71
+
72
+ for {
73
+ line, err := reader.ReadString('\n')
74
+ if err != nil {
75
+ if err == io.EOF {
76
+ return nil
77
+ }
78
+ return fmt.Errorf("read error: %w", err)
79
+ }
80
+
81
+ // Monica SSE 的行前缀一般是 "data: "
82
+ if !strings.HasPrefix(line, "data: ") {
83
+ continue
84
+ }
85
+
86
+ jsonStr := strings.TrimPrefix(line, "data: ")
87
+ if jsonStr == "" {
88
+ continue
89
+ }
90
+
91
+ // 从对象池获取 SSEData
92
+ sseObj := sseDataPool.Get().(*SSEData)
93
+ if err := sonic.UnmarshalString(jsonStr, sseObj); err != nil {
94
+ sseDataPool.Put(sseObj)
95
+ // 记录错误但继续处理
96
+ log.Printf("Error unmarshaling SSE data: %v", err)
97
+ continue
98
+ }
99
+
100
+ // 将拆分后的文字写回
101
+ var sseMsg types.ChatCompletionStreamResponse
102
+ if sseObj.Finished {
103
+ sseMsg = types.ChatCompletionStreamResponse{
104
+ ID: "chatcmpl-" + chatId,
105
+ Object: sseObject,
106
+ Created: now,
107
+ Model: model,
108
+ Choices: []types.ChatCompletionStreamChoice{
109
+ {
110
+ Index: 0,
111
+ Delta: openai.ChatCompletionStreamChoiceDelta{
112
+ Role: openai.ChatMessageRoleAssistant,
113
+ },
114
+ FinishReason: openai.FinishReasonStop,
115
+ },
116
+ },
117
+ }
118
+ } else {
119
+ sseMsg = types.ChatCompletionStreamResponse{
120
+ ID: "chatcmpl-" + chatId,
121
+ Object: sseObject,
122
+ SystemFingerprint: fingerprint,
123
+ Created: now,
124
+ Model: model,
125
+ Choices: []types.ChatCompletionStreamChoice{
126
+ {
127
+ Index: 0,
128
+ Delta: openai.ChatCompletionStreamChoiceDelta{
129
+ Role: openai.ChatMessageRoleAssistant,
130
+ Content: sseObj.Text,
131
+ },
132
+ FinishReason: openai.FinishReasonNull,
133
+ },
134
+ },
135
+ }
136
+ }
137
+
138
+ var sb strings.Builder
139
+ sb.WriteString("data: ")
140
+ sendLine, _ := sonic.MarshalString(sseMsg)
141
+ sb.WriteString(sendLine)
142
+ sb.WriteString("\n\n")
143
+
144
+ // 写入缓冲区
145
+ if _, err := writer.WriteString(sb.String()); err != nil {
146
+ sseDataPool.Put(sseObj)
147
+ return fmt.Errorf("write error: %w", err)
148
+ }
149
+
150
+ // 如果发现 finished=true,就可以结束
151
+ if sseObj.Finished {
152
+ writer.WriteString(fmt.Sprintf("data: %s\n\n", sseFinish))
153
+ writer.Flush()
154
+ if f, ok := w.(http.Flusher); ok {
155
+ f.Flush()
156
+ }
157
+ sseDataPool.Put(sseObj)
158
+ return nil
159
+ }
160
+
161
+ // 将对象放回对象池
162
+ sseDataPool.Put(sseObj)
163
+ }
164
+ }