SivaResearch committed on
Commit
397093f
·
verified ·
1 Parent(s): fa031c3

Update app.py

Files changed (1)
app.py +76 -44
app.py CHANGED
@@ -1,49 +1,81 @@
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
-def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
-    formatted_text = ""
-    for message in messages:
-        if message["role"] == "system":
-            formatted_text += "<|system|>\n" + message["content"] + "\n"
-        elif message["role"] == "user":
-            formatted_text += "<|user|>\n" + message["content"] + "\n"
-        elif message["role"] == "assistant":
-            formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
-        else:
-            raise ValueError(
-                "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
-                    message["role"]
-                )
-            )
-    formatted_text += "<|assistant|>\n"
-    formatted_text = bos + formatted_text if add_bos else formatted_text
-    return formatted_text
-
-
-def inference(input_prompts, model, tokenizer):
-    input_prompts = [
-        create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
-        for input_prompt in input_prompts
-    ]
-
-    encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
-    encodings = encodings.to(device)
-
-    with torch.inference_mode():
-        outputs = model.generate(encodings.input_ids, do_sample=False, max_new_tokens=250)
-
-    output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True)
-
-    input_prompts = [
-        tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts
-    ]
-    output_texts = [output_text[len(input_prompt) :] for input_prompt, output_text in zip(input_prompts, output_texts)]
-    return output_texts
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+# Load model and tokenizer directly
+tokenizer = AutoTokenizer.from_pretrained("ai4bharat/Airavata")
+model = AutoModelForCausalLM.from_pretrained("ai4bharat/Airavata")
+
+def generate_response(prompt):
+    # Tokenize input prompt and generate response
+    inputs = tokenizer(prompt, return_tensors="pt", max_length=256, truncation=True)
+    outputs = model.generate(**inputs)
+    response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+
+    return response
+
+# Define Gradio interface
+iface = gr.Interface(
+    fn=generate_response,
+    inputs=gr.Textbox(),
+    outputs=gr.Textbox(),
+    live=True,
+    title="CAMAI",
+    description="Enter a prompt to generate text.",
+)
+
+# Launch Gradio interface
+iface.launch()
+
+
+
+
+
+# import torch
+# from transformers import AutoTokenizer, AutoModelForCausalLM
+# import gradio as gr
+
+# device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+# def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
+#     formatted_text = ""
+#     for message in messages:
+#         if message["role"] == "system":
+#             formatted_text += "<|system|>\n" + message["content"] + "\n"
+#         elif message["role"] == "user":
+#             formatted_text += "<|user|>\n" + message["content"] + "\n"
+#         elif message["role"] == "assistant":
+#             formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
+#         else:
+#             raise ValueError(
+#                 "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
+#                     message["role"]
+#                 )
+#             )
+#     formatted_text += "<|assistant|>\n"
+#     formatted_text = bos + formatted_text if add_bos else formatted_text
+#     return formatted_text
+
+
+# def inference(input_prompts, model, tokenizer):
+#     input_prompts = [
+#         create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
+#         for input_prompt in input_prompts
+#     ]
+
+#     encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
+#     encodings = encodings.to(device)
+
+#     with torch.inference_mode():
+#         outputs = model.generate(encodings.input_ids, do_sample=False, max_new_tokens=250)
+
+#     output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True)
+
+#     input_prompts = [
+#         tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts
+#     ]
+#     output_texts = [output_text[len(input_prompt) :] for input_prompt, output_text in zip(input_prompts, output_texts)]
+#     return output_texts


model_name = "ai4bharat/Airavata"
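
For reference, the updated generate_response tokenizes the prompt, calls model.generate with its default settings, and decodes the first returned sequence. Below is a minimal sketch of exercising that same flow directly from a Python session, outside Gradio; the explicit max_new_tokens budget and the example prompt are illustrative assumptions and are not part of this commit.

# Illustrative sketch only, not part of the commit. Assumes the
# ai4bharat/Airavata weights can be downloaded and run on the default device.
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("ai4bharat/Airavata")
model = AutoModelForCausalLM.from_pretrained("ai4bharat/Airavata")

def generate_response(prompt, max_new_tokens=128):
    # Same tokenize/generate/decode flow as the new app.py; max_new_tokens is
    # an assumed explicit budget (the committed code relies on generate() defaults).
    inputs = tokenizer(prompt, return_tensors="pt", max_length=256, truncation=True)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

print(generate_response("Write two sentences about the Himalayas."))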