Spaces:
Paused
Paused
| import os, gc | |
| from typing import AsyncGenerator | |
| from rwkv.model import RWKV | |
| from rwkv.utils import PIPELINE, PIPELINE_ARGS | |
| from asyncio import sleep | |
class Answerer:
    """Stream text completions from an RWKV model.

    Wraps an ``rwkv.model.RWKV`` checkpoint and its ``PIPELINE`` tokenizer and
    exposes generation as an async generator through ``__call__``.
    """

    def __init__(self, model: str, vocab: str, strategy: str, ctx_limit: int):
        """Load ``models/{model}.pth`` and build its tokenizer pipeline.

        Args:
            model: File stem of the checkpoint under ``models/``; the ``.pth``
                extension is appended here.
            vocab: Passed unchanged to ``PIPELINE`` (presumably a tokenizer
                file name or identifier understood by rwkv — confirm).
            strategy: Passed unchanged to ``RWKV`` as ``strategy``.
            ctx_limit: Prompt-length cap used by ``__call__``. For positive
                values only the last ``ctx_limit`` prompt tokens are fed to
                the model.
        """
        # NOTE(review): the rwkv package appears to read RWKV_JIT_ON when
        # ``rwkv.model`` is imported. That import already ran at the top of
        # this module, so this assignment is likely too late to have any
        # effect. Confirm, and move it above the rwkv imports if needed.
        os.environ["RWKV_JIT_ON"] = "1"
        # os.environ["RWKV_CUDA_ON"] = "1"
        self.__model = RWKV(f"models/{model}.pth", strategy=strategy)
        self.__pipeline = PIPELINE(self.__model, vocab)
        self.ctx_limit = ctx_limit

    async def __call__(
        self,
        input: str,
        max_output_length_tk: int,
        chaos=0.1,
        repetitiveness=0.3,
        diversity=0,
        _count_penalty=1,
        penalty_decay: float = 0.996,
    ) -> AsyncGenerator[str, None]:
        """Generate a continuation of ``input``, yielding the text so far.

        Each yielded value is the *cumulative* decoded output, not only the
        newest piece. Sampled tokens are buffered until they decode without a
        U+FFFD replacement character, so a character spanning several tokens
        is never emitted half-finished. Tokens still buffered when generation
        ends are discarded.

        Args:
            input: Prompt text. Surrounding whitespace is stripped and only the
                last ``self.ctx_limit`` prompt tokens are used.
            max_output_length_tk: Maximum number of tokens to sample. Values
                <= 0 produce no output.
            chaos: Passed to ``sample_logits`` as ``temperature``.
            repetitiveness: Passed to ``sample_logits`` as ``top_p``.
            diversity: Flat penalty (``alpha_presence``) subtracted from the
                logit of every token already generated in this call.
            _count_penalty: Per-occurrence penalty (``alpha_frequency``),
                multiplied by each token's decayed occurrence count.
            penalty_decay: Factor applied to every occurrence count after each
                sampled token, so older repetitions are penalised less.

        Yields:
            The decoded output accumulated so far.
        """
        args = PIPELINE_ARGS(
            temperature=chaos,
            top_p=repetitiveness,
            alpha_frequency=_count_penalty,
            alpha_presence=diversity,
            token_ban=[],
            # Generation stops when token id 0 is sampled (presumably the
            # tokenizer's end-of-text token — confirm).
            token_stop=[0],
        )
        occurrences: dict[int, float] = {}
        pending: list[int] = []  # sampled tokens not yet emitted as text
        result = ""
        out = state = None
        # The first forward pass consumes the truncated prompt. Every later
        # pass feeds only the token sampled on the previous step, together
        # with the state returned by the previous forward call.
        next_input = self.__pipeline.encode(input.strip())[-self.ctx_limit:]
        try:
            for _ in range(max_output_length_tk):
                out, state = self.__model.forward(next_input, state)
                for banned in args.token_ban:
                    out[banned] = -float("inf")
                for seen, count in occurrences.items():
                    out[seen] -= args.alpha_presence + count * args.alpha_frequency
                token = self.__pipeline.sample_logits(
                    out,
                    temperature=args.temperature,
                    top_p=args.top_p,
                )
                if token in args.token_stop:
                    break
                next_input = [token]
                pending.append(token)
                for seen in occurrences:
                    occurrences[seen] *= penalty_decay
                occurrences[token] = occurrences.get(token, 0) + 1
                text = self.__pipeline.decode(pending)
                # U+FFFD means the buffered tokens end mid-character; keep
                # buffering until the sequence decodes cleanly.
                if "\ufffd" not in text:
                    pending.clear()
                    result += text
                    yield result
                # forward() blocks the event loop, so hand control back to
                # other tasks after every step.
                await sleep(.02)
        finally:
            # Runs on normal completion, on early stop, when the consumer
            # closes the generator, and when forward/sample raises. Rebinding
            # (instead of ``del``) cannot fail on names that were never set.
            occurrences.clear()
            pending.clear()
            out = state = next_input = None
            gc.collect()