File size: 2,040 Bytes
1a3e79b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import gradio as gr
from smollm_training import SmolLMConfig, tokenizer, SmolLM


# Load the model
def load_model():
    """Build a SmolLM model and load its pretrained weights from disk.

    Returns:
        SmolLM: the base model (not the Lightning wrapper) in eval mode,
        with weights loaded onto the CPU.
    """
    config = SmolLMConfig()
    # Create the plain base model so the checkpoint's state-dict keys
    # match directly (the Lightning module would prefix them).
    model = SmolLM(config)

    # weights_only=True restricts torch.load to tensor/primitive data,
    # preventing arbitrary-code execution from a tampered pickle file.
    state_dict = torch.load("model_weights.pt", map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)

    model.eval()  # disable dropout/batch-norm updates for inference
    return model


def generate_text(prompt, max_tokens, temperature=0.8, top_k=40):
    """Return a model-generated continuation of *prompt*.

    Args:
        prompt: seed text to continue.
        max_tokens: number of new tokens to sample.
        temperature: sampling temperature passed to the model.
        top_k: top-k filtering value passed to the model.

    Any exception is caught and reported as a plain message string so the
    Gradio UI always receives text instead of a traceback.
    """
    try:
        # Tokenize the prompt and move it to wherever the model lives.
        target_device = next(model.parameters()).device
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(target_device)

        # Sample the continuation without tracking gradients.
        with torch.no_grad():
            output_ids = model.generate(  # base model exposes generate directly
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k,
            )

        # Decode the single generated sequence back into text.
        return tokenizer.decode(output_ids[0].tolist())

    except Exception as e:
        return f"An error occurred: {str(e)}"


# Load the model once at import time so every request reuses the same instance.
model = load_model()

# Build the UI widgets up front, then wire them into a single Interface.
_prompt_box = gr.Textbox(
    label="Enter your prompt", placeholder="Once upon a time...", lines=3
)
_token_slider = gr.Slider(
    minimum=50,
    maximum=500,
    value=100,
    step=10,
    label="Maximum number of tokens",
)
_output_box = gr.Textbox(label="Generated Text", lines=10)

# `demo` is the conventional top-level app object Gradio hosts look for.
demo = gr.Interface(
    fn=generate_text,
    inputs=[_prompt_box, _token_slider],
    outputs=_output_box,
    title="SmolLM Text Generator",
    description="Enter a prompt and the model will generate a continuation.",
    examples=[
        ["Once upon a time", 100],
        ["The future of AI is", 200],
        ["In a galaxy far far away", 150],
    ],
)

if __name__ == "__main__":
    # Launch the local Gradio server only when run as a script.
    demo.launch()