Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	:gem: [Feature] Enable stream output of chat completions
Browse files- tests/openai.py +38 -12
    	
        tests/openai.py
    CHANGED
    
    | @@ -1,4 +1,6 @@ | |
| 1 | 
             
            import copy
         | 
|  | |
|  | |
| 2 | 
             
            import uuid
         | 
| 3 |  | 
| 4 | 
             
            from pathlib import Path
         | 
| @@ -46,17 +48,16 @@ class OpenaiAPI: | |
| 46 | 
             
                            "http": http_proxy,
         | 
| 47 | 
             
                            "https": http_proxy,
         | 
| 48 | 
             
                        }
         | 
|  | |
|  | |
| 49 | 
             
                    else:
         | 
| 50 | 
             
                        self.requests_proxies = None
         | 
| 51 |  | 
| 52 | 
             
                def log_request(self, url, method="GET"):
         | 
| 53 | 
            -
                    if ENVER["http_proxy"]:
         | 
| 54 | 
            -
                        logger.note(f"> Using Proxy:", end=" ")
         | 
| 55 | 
            -
                        logger.mesg(f"{ENVER['http_proxy']}")
         | 
| 56 | 
             
                    logger.note(f"> {method}:", end=" ")
         | 
| 57 | 
             
                    logger.mesg(f"{url}", end=" ")
         | 
| 58 |  | 
| 59 | 
            -
                def log_response(self, res: requests.Response, stream=False):
         | 
| 60 | 
             
                    status_code = res.status_code
         | 
| 61 | 
             
                    status_code_str = f"[{status_code}]"
         | 
| 62 |  | 
| @@ -64,12 +65,35 @@ class OpenaiAPI: | |
| 64 | 
             
                        logger_func = logger.success
         | 
| 65 | 
             
                    else:
         | 
| 66 | 
             
                        logger_func = logger.warn
         | 
|  | |
| 67 | 
             
                    logger_func(status_code_str)
         | 
| 68 |  | 
| 69 | 
            -
                    if  | 
| 70 | 
            -
                         | 
| 71 | 
            -
             | 
| 72 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 73 |  | 
| 74 | 
             
                def get_models(self):
         | 
| 75 | 
             
                    self.log_request(self.api_models)
         | 
| @@ -111,7 +135,7 @@ class OpenaiAPI: | |
| 111 | 
             
                                "metadata": {},
         | 
| 112 | 
             
                            }
         | 
| 113 | 
             
                        ],
         | 
| 114 | 
            -
                        "parent_message_id":  | 
| 115 | 
             
                        "model": "text-davinci-002-render-sha",
         | 
| 116 | 
             
                        "timezone_offset_min": -480,
         | 
| 117 | 
             
                        "suggestions": [],
         | 
| @@ -124,22 +148,24 @@ class OpenaiAPI: | |
| 124 | 
             
                        "websocket_request_id": str(uuid.uuid4()),
         | 
| 125 | 
             
                    }
         | 
| 126 | 
             
                    self.log_request(self.api_conversation, method="POST")
         | 
| 127 | 
            -
                     | 
|  | |
| 128 | 
             
                        self.api_conversation,
         | 
| 129 | 
             
                        headers=requests_headers,
         | 
| 130 | 
             
                        json=post_data,
         | 
| 131 | 
             
                        proxies=self.requests_proxies,
         | 
| 132 | 
             
                        timeout=10,
         | 
| 133 | 
             
                        impersonate="chrome120",
         | 
|  | |
| 134 | 
             
                    )
         | 
| 135 | 
            -
                    self.log_response(res, stream=True)
         | 
| 136 |  | 
| 137 |  | 
| 138 | 
             
            if __name__ == "__main__":
         | 
| 139 | 
             
                api = OpenaiAPI()
         | 
| 140 | 
             
                # api.get_models()
         | 
| 141 | 
             
                api.auth()
         | 
| 142 | 
            -
                prompt = " | 
| 143 | 
             
                api.chat_completions(prompt)
         | 
| 144 |  | 
| 145 | 
             
                # python -m tests.openai
         | 
|  | |
| 1 | 
             
            import copy
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
            import re
         | 
| 4 | 
             
            import uuid
         | 
| 5 |  | 
| 6 | 
             
            from pathlib import Path
         | 
|  | |
| 48 | 
             
                            "http": http_proxy,
         | 
| 49 | 
             
                            "https": http_proxy,
         | 
| 50 | 
             
                        }
         | 
| 51 | 
            +
                        logger.note(f"> Using Proxy:", end=" ")
         | 
| 52 | 
            +
                        logger.mesg(f"{ENVER['http_proxy']}")
         | 
| 53 | 
             
                    else:
         | 
| 54 | 
             
                        self.requests_proxies = None
         | 
| 55 |  | 
| 56 | 
             
                def log_request(self, url, method="GET"):
         | 
|  | |
|  | |
|  | |
| 57 | 
             
                    logger.note(f"> {method}:", end=" ")
         | 
| 58 | 
             
                    logger.mesg(f"{url}", end=" ")
         | 
| 59 |  | 
| 60 | 
            +
                def log_response(self, res: requests.Response, stream=False, verbose=False):
         | 
| 61 | 
             
                    status_code = res.status_code
         | 
| 62 | 
             
                    status_code_str = f"[{status_code}]"
         | 
| 63 |  | 
|  | |
| 65 | 
             
                        logger_func = logger.success
         | 
| 66 | 
             
                    else:
         | 
| 67 | 
             
                        logger_func = logger.warn
         | 
| 68 | 
            +
             | 
| 69 | 
             
                    logger_func(status_code_str)
         | 
| 70 |  | 
| 71 | 
            +
                    if verbose:
         | 
| 72 | 
            +
                        if stream:
         | 
| 73 | 
            +
                            if not hasattr(self, "content_offset"):
         | 
| 74 | 
            +
                                self.content_offset = 0
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                            for line in res.iter_lines():
         | 
| 77 | 
            +
                                line = line.decode("utf-8")
         | 
| 78 | 
            +
                                line = re.sub(r"^data:\s*", "", line)
         | 
| 79 | 
            +
                                if re.match(r"^\[DONE\]", line):
         | 
| 80 | 
            +
                                    logger.success("\n[Finished]")
         | 
| 81 | 
            +
                                    break
         | 
| 82 | 
            +
                                line = line.strip()
         | 
| 83 | 
            +
                                if line:
         | 
| 84 | 
            +
                                    try:
         | 
| 85 | 
            +
                                        data = json.loads(line, strict=False)
         | 
| 86 | 
            +
                                        role = data["message"]["author"]["role"]
         | 
| 87 | 
            +
                                        if role != "assistant":
         | 
| 88 | 
            +
                                            continue
         | 
| 89 | 
            +
                                        content = data["message"]["content"]["parts"][0]
         | 
| 90 | 
            +
                                        delta_content = content[self.content_offset :]
         | 
| 91 | 
            +
                                        self.content_offset = len(content)
         | 
| 92 | 
            +
                                        logger_func(delta_content, end="")
         | 
| 93 | 
            +
                                    except Exception as e:
         | 
| 94 | 
            +
                                        logger.warn(e)
         | 
| 95 | 
            +
                        else:
         | 
| 96 | 
            +
                            logger_func(res.json())
         | 
| 97 |  | 
| 98 | 
             
                def get_models(self):
         | 
| 99 | 
             
                    self.log_request(self.api_models)
         | 
|  | |
| 135 | 
             
                                "metadata": {},
         | 
| 136 | 
             
                            }
         | 
| 137 | 
             
                        ],
         | 
| 138 | 
            +
                        "parent_message_id": "",
         | 
| 139 | 
             
                        "model": "text-davinci-002-render-sha",
         | 
| 140 | 
             
                        "timezone_offset_min": -480,
         | 
| 141 | 
             
                        "suggestions": [],
         | 
|  | |
| 148 | 
             
                        "websocket_request_id": str(uuid.uuid4()),
         | 
| 149 | 
             
                    }
         | 
| 150 | 
             
                    self.log_request(self.api_conversation, method="POST")
         | 
| 151 | 
            +
                    s = requests.Session()
         | 
| 152 | 
            +
                    res = s.post(
         | 
| 153 | 
             
                        self.api_conversation,
         | 
| 154 | 
             
                        headers=requests_headers,
         | 
| 155 | 
             
                        json=post_data,
         | 
| 156 | 
             
                        proxies=self.requests_proxies,
         | 
| 157 | 
             
                        timeout=10,
         | 
| 158 | 
             
                        impersonate="chrome120",
         | 
| 159 | 
            +
                        stream=True,
         | 
| 160 | 
             
                    )
         | 
| 161 | 
            +
                    self.log_response(res, stream=True, verbose=True)
         | 
| 162 |  | 
| 163 |  | 
| 164 | 
             
            if __name__ == "__main__":
         | 
| 165 | 
             
                api = OpenaiAPI()
         | 
| 166 | 
             
                # api.get_models()
         | 
| 167 | 
             
                api.auth()
         | 
| 168 | 
            +
                prompt = "你的名字?"
         | 
| 169 | 
             
                api.chat_completions(prompt)
         | 
| 170 |  | 
| 171 | 
             
                # python -m tests.openai
         | 
