File size: 741 Bytes
92a5a61
b3ef6d9
 
bd6f27f
b3ef6d9
435bf96
b3ef6d9
 
bd6f27f
b3ef6d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd6f27f
b3ef6d9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the DialoGPT model and tokenizer once at start-up; every request reuses them.
model_name = "Astrea/DialoGPT-small-akari"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


def respond(message, max_new_tokens=50):
    """Generate Akari's single-turn reply to *message*.

    Args:
        message: The user's text as delivered by the input ``gr.Textbox``
            (a plain ``str``; ``None`` is treated as empty).
        max_new_tokens: Upper bound on the number of tokens generated for the
            reply, not counting the prompt tokens.

    Returns:
        The decoded reply text with special tokens stripped. An empty or
        whitespace-only message returns ``""`` without running the model.
    """
    if not message or not message.strip():
        return ""

    # DialoGPT separates conversational turns with the EOS token; appending it
    # tells the model the user's turn is finished and a reply should follow.
    input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            # Count only generated tokens, so long prompts still leave room to reply.
            max_new_tokens=max_new_tokens,
            num_beams=5,
            no_repeat_ngram_size=2,
            # top_k / top_p / temperature only take effect when sampling is on.
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
            # NOTE(review): DialoGPT's GPT-2-style tokenizer is expected to have no
            # pad token; reusing EOS avoids generate()'s padding warning.
            pad_token_id=tokenizer.eos_token_id,
        )

    # For a causal LM, generate() returns prompt + continuation; keep only the reply.
    reply_ids = output_ids[0, input_ids.shape[-1]:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True)


# Define the Gradio interface: one text box in, the model's reply text out.
akari = gr.Interface(
    fn=respond,
    inputs=gr.Textbox(),
    outputs=gr.Textbox(),
)

akari.launch()