charanhu committed
Commit 4dce766 · 1 Parent(s): f5b98ac

Update app.py

Files changed (1)
  1. app.py +39 -62
app.py CHANGED
@@ -1,72 +1,49 @@
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import time

  # Load model and tokenizer
  tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T")
  model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T")

- def generate_text(prompt, max_length=100, min_length=20, temperature=1.0):
-     # Tokenize the prompt
-     input_ids = tokenizer.encode(prompt, return_tensors="pt")
-
-     # Generate text
-     output = model.generate(
-         input_ids,
-         max_length=max_length,
-         min_length=min_length,
-         num_return_sequences=1,
-         temperature=temperature
-     )
-
-     # Decode the generated output
-     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-
-     return generated_text
-
- def print_like_dislike(x: gr.LikeData):
-     print(x.index, x.value, x.liked)
-
- def add_text(history, text):
-     history = history + [(text, None)]
-     return history, gr.Textbox(value="", interactive=False)
-
- def bot(history, max_length=100, min_length=20, temperature=1.0):
-     prompt = history[-1][0]
-     response = generate_text(prompt, max_length=max_length, min_length=min_length, temperature=temperature)
-     history[-1][1] = ""
-     for character in response:
-         history[-1][1] += character
-         time.sleep(0.05)
-         yield history
-
- with gr.Blocks() as demo:
-     chatbot = gr.Chatbot(
-         [],
-         elem_id="chatbot",
-         bubble_full_width=False,
-         avatar_images=(None, None),  # Set avatar image path or URL
-     )
-
-     with gr.Row():
-         txt = gr.Textbox(
-             scale=4,
-             show_label=False,
-             placeholder="Enter text and press enter, or upload an image",
-             container=False,
-         )
-
-     max_len_slider = gr.Slider(0, 2048, 100, label="Max Length")
-     min_len_slider = gr.Slider(0, 2048, 20, label="Min Length")
-     temp_slider = gr.Slider(0.1, 2.0, 1.0, label="Temperature")
-
-     txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-         bot, [chatbot, max_len_slider, min_len_slider, temp_slider], chatbot
-     )
-     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
-
-     chatbot.like(print_like_dislike, None, None)
-
- demo.queue()
- if __name__ == "__main__":
-     demo.launch()
  import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+ from threading import Thread

  # Load model and tokenizer
  tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T")
  model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T")

+ class StopOnTokens(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         # Stop generation as soon as the last generated token is one of
+         # these hard-coded token ids.
+         stop_ids = [29, 0]
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
+ def predict(message, history):
+     history_transformer_format = history + [[message, ""]]
+     stop = StopOnTokens()
+
+     # Flatten the chat history into a single <human>/<bot> prompt string.
+     messages = "".join(["".join(["\n<human>:" + item[0], "\n<bot>:" + item[1]])
+                         for item in history_transformer_format])
+
+     # Requires a GPU: the tokenized inputs are moved to "cuda" here.
+     model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=1,
+         top_k=50,
+         temperature=1.0,
+         num_beams=1,
+         stopping_criteria=StoppingCriteriaList([stop])
+     )
+     # generate() blocks until done, so run it in a background thread and
+     # stream decoded tokens back to the UI as they arrive.
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     partial_message = ""
+     for new_token in streamer:
+         if new_token != '<':
+             partial_message += new_token
+             yield partial_message
+
+ gr.ChatInterface(predict).queue().launch()
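
For context, the new predict() follows the standard transformers token-streaming pattern: model.generate() runs in a worker thread while TextIteratorStreamer yields decoded text chunks, and gr.ChatInterface re-renders the growing reply on each yield. Below is a minimal, self-contained sketch of that pattern (not part of the commit); it picks the device at runtime, since the committed code moves only the inputs to "cuda" and would also need model.to("cuda") on a GPU Space. The example prompt text and max_new_tokens value are illustrative assumptions.

from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T").to(device)

# Same prompt format the app builds from its chat history.
prompt = "\n<human>:Hello, who are you?\n<bot>:"
inputs = tokenizer([prompt], return_tensors="pt").to(device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until finished, so it runs in a worker thread while
# the main thread consumes decoded text chunks from the streamer.
thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=64))
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()

The thread-plus-streamer split is what lets predict() yield partial strings: the consumer loop runs concurrently with generation instead of waiting for the full completion.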