Update app.py
app.py CHANGED
@@ -1,15 +1,83 @@
+import torch
+from typing import List
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    StoppingCriteria,
+    StoppingCriteriaList,
+    pipeline,
+)
+from langchain.chains import ConversationChain
+from langchain.chains.conversation.memory import ConversationBufferWindowMemory
+from langchain.llms import HuggingFacePipeline
+from langchain import PromptTemplate
+
 # Load the model and tokenizer
 tokenizer = AutoTokenizer.from_pretrained("gpt2")
 model = AutoModelForCausalLM.from_pretrained("gpt2")
 
+# Deterministic, bounded decoding with a strong repetition penalty
+generation_config = model.generation_config
+generation_config.temperature = 0
+generation_config.num_return_sequences = 1
+generation_config.max_new_tokens = 256
+generation_config.use_cache = False
+generation_config.repetition_penalty = 1.7
+generation_config.pad_token_id = tokenizer.eos_token_id
+generation_config.eos_token_id = tokenizer.eos_token_id
+
+
+class StopGenerationCriteria(StoppingCriteria):
+    """Stop generating as soon as the model begins a new speaker turn."""
+
+    def __init__(
+        self, tokens: List[List[str]], tokenizer: AutoTokenizer, device: torch.device
+    ):
+        stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
+        self.stop_token_ids = [
+            torch.tensor(x, dtype=torch.long, device=device) for x in stop_token_ids
+        ]
+
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+    ) -> bool:
+        # Fire when the tail of the sequence matches any stop sequence
+        for stop_ids in self.stop_token_ids:
+            if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
+                return True
+        return False
+
+
+# Instantiated after the class definition so the name is in scope
+stop_tokens = [["Human", ":"], ["AI", ":"]]
+stopping_criteria = StoppingCriteriaList(
+    [StopGenerationCriteria(stop_tokens, tokenizer, model.device)]
+)
+
+generation_pipeline = pipeline(
+    model=model,
+    tokenizer=tokenizer,
+    return_full_text=True,
+    task="text-generation",
+    stopping_criteria=stopping_criteria,
+    generation_config=generation_config,
+)
+
+llm = HuggingFacePipeline(pipeline=generation_pipeline)
+
+template = """
+The following is a conversation between a human and an AI.
+
+Current conversation:
+
+{history}
+
+Human: {input}
+AI:""".strip()
+prompt = PromptTemplate(input_variables=["history", "input"], template=template)
+
+# Keep only the last 6 exchanges in the rolling prompt window
+memory = ConversationBufferWindowMemory(memory_key="history", k=6)
+
+chain = ConversationChain(
+    llm=llm,
+    memory=memory,
+    prompt=prompt,
+    verbose=True,
+)
+
 def generate_response(input_text):
-    inputs = tokenizer(input_text, return_tensors="pt")
-    outputs = model.generate(inputs.input_ids, max_length=50)
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return
+    # ConversationChain.invoke returns a dict; the reply is under "response"
+    res = chain.invoke(input_text)
+    return res["response"]
 
 iface = gr.Interface(fn=generate_response, inputs="text", outputs="text")
 iface.launch()
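As a quick sanity check on the stopping logic this commit adds — a minimal sketch, assuming the gpt2 tokenizer and the stop_tokens list from app.py above; the expected value in the final comment is what the tail comparison should produce, not output captured from the Space:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Each stop sequence is a list of vocabulary tokens, mapped token-by-token to ids.
stop_tokens = [["Human", ":"], ["AI", ":"]]
stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in stop_tokens]

# The criteria fires when the tail of the generated ids equals a stop sequence.
generated = torch.tensor([tokenizer.convert_tokens_to_ids(["Hello", "AI", ":"])])
tail = generated[0][-len(stop_token_ids[1]):]
print(torch.eq(tail, torch.tensor(stop_token_ids[1])).all())  # tensor(True)

Matching on token ids keeps the check cheap, but it only catches the exact tokenizations listed: a variant such as " Human:" with a leading space maps to different GPT-2 ids and would slip through.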