Astrea commited on
Commit
b3ef6d9
·
1 Parent(s): bd6f27f

Error fixed?!

Browse files
Files changed (1) hide show
  1. app.py +23 -14
app.py CHANGED
@@ -1,18 +1,27 @@
1
  import gradio as gr
2
- import random
3
- import time
4
 
5
- ith gr.Blocks() as demo:
6
- chatbot = gr.Chatbot()
7
- msg = gr.Textbox()
8
- clear = gr.ClearButton([msg, chatbot])
9
 
10
- def respond(message, chat_history):
11
- bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
12
- chat_history.append((message, bot_message))
13
- time.sleep(2)
14
- return "", chat_history
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- msg.submit(respond, [msg, chatbot], [msg, chatbot])
17
-
18
- gr.load("models/Astrea/DialoGPT-small-akari").launch()
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
+ # Load the DialoGPT model and tokenizer
6
+ model_name = "models/Astrea/DialoGPT-small-akari"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForCausalLM.from_pretrained(model_name)
9
 
10
+ # Define the Gradio interface
11
+ akari = gr.Interface(
12
+ fn=lambda messages: {
13
+ "response": model.generate(
14
+ tokenizer.encode(messages["input"], return_tensors="pt"),
15
+ max_length=50,
16
+ num_beams=5,
17
+ no_repeat_ngram_size=2,
18
+ top_k=50,
19
+ top_p=0.95,
20
+ temperature=0.7,
21
+ )[0].decode("utf-8")
22
+ },
23
+ inputs=gr.Textbox(),
24
+ outputs=gr.Textbox(),
25
+ )
26
 
27
+ akari.launch()