AminFaraji committed
Commit 1b679df · verified · 1 Parent(s): c74eeed

Update app.py

Files changed (1)
  1. app.py +71 -3
app.py CHANGED
@@ -1,15 +1,83 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    StoppingCriteria,
+    StoppingCriteriaList,
+    pipeline,
+)
+import torch
+from typing import List
+from langchain.chains import ConversationChain
+from langchain.chains.conversation.memory import ConversationBufferWindowMemory
+from langchain.llms import HuggingFacePipeline
+from langchain import PromptTemplate
+
 # Load the model and tokenizer
 tokenizer = AutoTokenizer.from_pretrained("gpt2")
 model = AutoModelForCausalLM.from_pretrained("gpt2")
 
+# Decoding settings: greedy decoding (do_sample defaults to False, so
+# temperature has no effect here), capped at 256 new tokens per turn.
+generation_config = model.generation_config
+generation_config.temperature = 0
+generation_config.num_return_sequences = 1
+generation_config.max_new_tokens = 256
+generation_config.use_cache = False
+generation_config.repetition_penalty = 1.7
+generation_config.pad_token_id = tokenizer.eos_token_id
+generation_config.eos_token_id = tokenizer.eos_token_id
+
+
+class StopGenerationCriteria(StoppingCriteria):
+    """Stop generation once the model emits any of the given token sequences."""
+
+    def __init__(
+        self, tokens: List[List[str]], tokenizer: AutoTokenizer, device: torch.device
+    ):
+        stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
+        self.stop_token_ids = [
+            torch.tensor(x, dtype=torch.long, device=device) for x in stop_token_ids
+        ]
+
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+    ) -> bool:
+        for stop_ids in self.stop_token_ids:
+            if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
+                return True
+        return False
+
+
+# Stop as soon as the model starts inventing the next turn ("Human:" / "AI:").
+# Note: the class must be defined before it is instantiated here.
+stop_tokens = [["Human", ":"], ["AI", ":"]]
+stopping_criteria = StoppingCriteriaList(
+    [StopGenerationCriteria(stop_tokens, tokenizer, model.device)]
+)
+
+generation_pipeline = pipeline(
+    model=model,
+    tokenizer=tokenizer,
+    return_full_text=True,
+    task="text-generation",
+    stopping_criteria=stopping_criteria,
+    generation_config=generation_config,
+)
+
+llm = HuggingFacePipeline(pipeline=generation_pipeline)
+
+template = """
+The following is a friendly conversation between a human and an AI.
+Current conversation:
+
+{history}
+
+Human: {input}
+AI:""".strip()
+prompt = PromptTemplate(input_variables=["history", "input"], template=template)
+
+# Keep only the last 6 turns of conversation in the prompt
+memory = ConversationBufferWindowMemory(
+    memory_key="history", k=6, return_only_outputs=True
+)
+
+chain = ConversationChain(
+    llm=llm,
+    memory=memory,
+    prompt=prompt,
+    verbose=True,
+)
+
+
 def generate_response(input_text):
-    inputs = tokenizer(input_text, return_tensors="pt")
-    outputs = model.generate(inputs.input_ids, max_length=50)
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return response
+    # ConversationChain.invoke returns a dict; the reply text is under "response"
+    res = chain.invoke(input_text)
+    return res["response"]
 
 iface = gr.Interface(fn=generate_response, inputs="text", outputs="text")
 iface.launch()
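
A minimal sketch of how the new pieces behave (not part of the commit; it assumes the objects defined in app.py above, generation_pipeline and chain, are already in scope): the custom stopping criteria ends generation the moment GPT-2 begins a new "Human:"/"AI:" turn, and ConversationChain returns a dict whose reply lives under the "response" key, which is why the Gradio handler returns res["response"].

# Sketch only, assuming the app.py definitions above are importable.

# 1) The raw pipeline halts once a stop sequence ("Human:" / "AI:") appears.
out = generation_pipeline("Human: What is the capital of France?\nAI:")
print(out[0]["generated_text"])  # prompt plus reply, since return_full_text=True

# 2) The chain returns a dict; only the reply text goes back to Gradio.
res = chain.invoke("Hello there!")
print(res.keys())       # dict_keys(['input', 'history', 'response'])
print(res["response"])  # the assistant reply shown in the Gradio textbox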