File size: 2,453 Bytes
82f9e44
 
 
 
 
 
ec30812
 
7e1aa1c
 
ec30812
7e1aa1c
 
 
ec30812
17a3eb0
ec30812
17a3eb0
 
 
ec30812
 
17a3eb0
 
ec30812
17a3eb0
 
ec30812
 
 
 
 
17a3eb0
 
ec30812
4adbe84
 
e6bd9b6
d727d22
 
82f9e44
ec30812
d727d22
 
419d496
28553d3
 
419d496
ec30812
d727d22
 
 
 
82f9e44
 
 
 
 
d727d22
419d496
 
ec30812
d727d22
 
ec30812
d727d22
 
 
 
82f9e44
ec30812
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from tokenizers import Tokenizer
from pathlib import Path
from config import get_config, get_weights_file_path
from train import get_model

# Load tokenizer
def get_tokenizer(config) -> "Tokenizer":
    """Load the trained tokenizer referenced by ``config['tokenizer_file']``.

    Args:
        config: Mapping that holds a ``'tokenizer_file'`` path.

    Returns:
        The ``tokenizers.Tokenizer`` deserialized from that file.

    Raises:
        FileNotFoundError: If nothing exists at the configured path.
    """
    tokenizer_path = Path(config['tokenizer_file'])
    if not tokenizer_path.exists():
        # Build a single message string: passing several positional args to
        # an OSError subclass makes them (errno, strerror) and garbles output.
        raise FileNotFoundError(f"Can't find tokenizer file: {tokenizer_path}")
    print("Loading tokenizer from", tokenizer_path)
    return Tokenizer.from_file(str(tokenizer_path))

# Setup config
config = get_config("./openweb.config.json")
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = get_tokenizer(config)

# Token IDs. `Tokenizer.token_to_id` returns None for tokens that are not in
# the vocabulary.
pad_token_id = tokenizer.token_to_id("<pad>")
# If this is None, the end-of-sequence check in generate_response never fires
# and generation only stops at its length limit.
eos_token_id = tokenizer.token_to_id("</s>")
user_token_id = tokenizer.token_to_id("<user>")
ai_token_id = tokenizer.token_to_id("<ai>")

# generate_response wraps every prompt in these two markers; fail fast at
# start-up instead of failing obscurely inside torch.tensor on every request.
_missing_markers = [
    token
    for token, token_id in (("<user>", user_token_id), ("<ai>", ai_token_id))
    if token_id is None
]
if _missing_markers:
    raise ValueError(
        f"Tokenizer {config['tokenizer_file']} lacks required special "
        f"token(s): {', '.join(_missing_markers)}"
    )

# Load model: build the architecture on `device`, then copy the trained
# weights from the checkpoint selected by config['preload'] into it.
model = get_model(config, tokenizer.get_vocab_size()).to(device)
model_path = get_weights_file_path(config, config['preload'])
# The checkpoint tensors are materialized on the CPU first; load_state_dict
# then copies them into the parameters, which already live on `device`.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state = torch.load(model_path, map_location="cpu")
model.load_state_dict(state['model_state_dict'])
# Inference only: put layers with train/eval differences (e.g. dropout, if the
# architecture has any) into evaluation mode.
model.eval()

def _sample_next_token(logits, temperature, top_k):
    """Sample one token id per row from the top-k of temperature-scaled logits.

    Args:
        logits: Unnormalized scores whose last dimension indexes the
            vocabulary (presumably (1, vocab_size) for the caller below).
        temperature: Positive divisor applied to the logits; lower values make
            sampling greedier.
        top_k: Number of highest-scoring candidates to sample from; clamped
            to the size of the last dimension.

    Returns:
        A tensor of shape ``logits.shape[:-1] + (1,)`` holding vocabulary ids.

    Raises:
        ValueError: If ``temperature`` is not positive or ``top_k`` < 1.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")
    k = min(top_k, logits.size(-1))
    top_logits, top_indices = torch.topk(logits / temperature, k)
    probs = torch.softmax(top_logits, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    # `choice` is a position inside the top-k; map it back to a vocabulary id.
    return top_indices.gather(-1, choice)


# Streaming text generation
def generate_response(prompt: str, history, *, temperature: float = 0.5,
                      top_k: int = 20, max_new_tokens: int = 1024):
    """Stream a sampled reply to ``prompt`` as a growing plain string.

    The prompt is wrapped as ``<user> prompt <ai>`` and tokens are sampled
    with temperature + top-k until ``</s>`` is produced or ``max_new_tokens``
    tokens have been generated. The model context is a sliding window over
    the most recent ``config['seq_len']`` tokens.

    Args:
        prompt: The user's message.
        history: Ignored — the model was not trained with conversation
            history, so previous turns are not fed to it.
        temperature: Sampling temperature; must be positive.
        top_k: Number of candidate tokens sampled from at each step.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The full reply decoded so far (not just the newest piece) after each
        token; ``"Prompt too long."`` if the wrapped prompt exceeds
        ``config['seq_len']`` tokens. At least one value is always yielded.
    """
    seq_len = config['seq_len']
    prompt_ids = [user_token_id, *tokenizer.encode(prompt).ids, ai_token_id]
    if len(prompt_ids) > seq_len:
        yield "Prompt too long."
        return

    context = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
    generated_ids = []
    generated_text = ""

    for _ in range(max_new_tokens):
        # Inference only: without no_grad each step records an autograd graph
        # through the whole model. No `yield` inside, so grad mode is restored
        # before control returns to the consumer.
        with torch.no_grad():
            out = model.decode(context)
            logits = model.project(out[:, -1])
        next_token = _sample_next_token(logits, temperature, top_k)
        next_id = next_token.item()
        if next_id == eos_token_id:
            break  # do not show the end-of-sequence token itself

        generated_ids.append(next_id)
        # Decode the whole reply each step: decoding tokens one at a time can
        # drop inter-word spaces or split multi-byte characters, depending on
        # the tokenizer's decoder.
        generated_text = tokenizer.decode(generated_ids)
        yield generated_text  # plain string for ChatInterface

        context = torch.cat([context, next_token], dim=1)
        context = context[:, -seq_len:]  # keep the last seq_len tokens

    if not generated_ids:
        # EOS came first (or max_new_tokens == 0): still emit one (empty) reply.
        yield generated_text