File size: 1,386 Bytes
acd7cf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import math
from dataclasses import dataclass, field
from typing import List, Union, Optional


@dataclass
class Token:
    """One token of text together with the probability the model assigned to it.

    ``logprob`` is derived from ``prob`` on every access; it is not stored.
    """

    # Surface text of the token.
    text: str
    # Probability assigned to this token (expected to lie in [0, 1]).
    prob: float
    # Alternative candidates for this position; element type is not fixed here
    # (presumably other Token objects produced by generate_topk_per_token — confirm).
    top_candidates: List = field(default_factory=list)
    # Optional per-token value, None when unset (presumably perplexity — confirm with callers).
    ppl: Union[float, None] = field(default=None)

    @property
    def logprob(self) -> float:
        """Return the natural logarithm of ``prob``.

        A probability of exactly 0 yields ``-math.inf``; ``math.log(0)`` would
        otherwise raise ``ValueError`` on a perfectly valid (e.g. underflowed)
        probability.

        Raises:
            ValueError: if ``prob`` is negative.
        """
        if self.prob == 0:
            # log(0) is -inf mathematically; math.log refuses it, so special-case it.
            return -math.inf
        return math.log(self.prob)


@dataclass
class TopkTokenModel:
    """Base interface for models that report per-token probability details.

    The fields are generation settings; nothing in this base class reads them.
    All three async methods raise ``NotImplementedError`` here and are meant
    to be overridden by concrete model classes.
    """

    # --- generation settings (interpretation is left to subclasses) ---
    do_sample: bool = field(default=False)
    temperature: float = field(default=0)
    max_tokens: int = field(default=4096)
    repetition_penalty: float = field(default=1.05)
    num_beams: int = field(default=1)
    topk: int = field(default=50)
    topp: float = field(default=0.95)

    # How many top-k candidate tokens to produce for every generated token.
    topk_per_token: int = field(default=5)

    async def generate_topk_per_token(self, text: str) -> List[Token]:
        """Produce text, probability and candidates for every output token.

        Intended for visualizing the model's inference step by step.
        Subclasses must override this.
        """
        raise NotImplementedError()

    async def generate_inputs_prob(self, text: str, history: Optional[List[str]] = None) -> List[Token]:
        """Produce text and probability for every token of the input ``text``.

        Intended for visualizing perplexity (ppl) over the input.
        Subclasses must override this.
        """
        raise NotImplementedError()

    async def generate_answer(self, text: str, history: Optional[List[str]] = None) -> str:
        """Return the model's answer to ``text`` as a string.

        Subclasses must override this.
        """
        raise NotImplementedError()