3demo / app.py
sadgaj's picture
Update app.py
c573591
raw
history blame
1.54 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from rdkit.Chem import Draw
from rdkit import Chem
import selfies as sf
# Lazily-populated (tokenizer, model) pair so the checkpoint is loaded once
# per process instead of on every request.
_MOLGEN = None


def _load_molgen():
    """Return the cached MolGen tokenizer and model, loading them on first use."""
    global _MOLGEN
    if _MOLGEN is None:
        _MOLGEN = (
            AutoTokenizer.from_pretrained("zjunlp/MolGen"),
            AutoModelForSeq2SeqLM.from_pretrained("zjunlp/MolGen"),
        )
    return _MOLGEN


def greet1(name):
    """Generate candidate molecule strings from the input text with MolGen.

    Args:
        name: Text handed to the tokenizer. Presumably a SELFIES string,
            since the outputs are later passed to the SELFIES decoder --
            TODO confirm against the model card.

    Returns:
        A list with one decoded string per generated sequence
        (``num_return_sequences=5``), with special tokens dropped and all
        spaces removed.
    """
    tokenizer, model = _load_molgen()
    sf_input = tokenizer(name, return_tensors="pt")

    # Beam search: keep 5 beams and return all 5 of them.
    molecules = model.generate(
        input_ids=sf_input["input_ids"],
        attention_mask=sf_input["attention_mask"],
        max_length=15,
        min_length=5,
        num_return_sequences=5,
        num_beams=5,
    )

    # The tokenizer separates tokens with spaces; strip them so each result
    # is one contiguous string.
    sf_output = [
        tokenizer.decode(
            g,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ).replace(" ", "")
        for g in molecules
    ]
    return sf_output
def greet2(name):
    """Render the molecules generated for *name* as a grid image.

    The generated strings are obtained by calling ``greet1(name)``, decoded
    to SMILES with ``selfies``, parsed with RDKit and drawn four per row.

    Args:
        name: The same input text that ``greet1`` accepts.

    Returns:
        The image object produced by ``Draw.MolsToGridImage``.
    """
    # Generate inside this function: results computed in another callback
    # are not visible here.
    sf_output = greet1(name)

    smis = [sf.decoder(selfies_str) for selfies_str in sf_output]
    mols = [Chem.MolFromSmiles(smi) for smi in smis]

    img = Draw.MolsToGridImage(
        mols,
        molsPerRow=4,
        subImgSize=(200, 200),
        legends=[""] * len(mols),  # one empty legend per molecule
    )
    return img
# Text in, text out: shows the strings returned by greet1.
greeter_1 = gr.Interface(fn=greet1, inputs="text", outputs="text")

# Text in, image out: shows the grid image returned by greet2.
greeter_2 = gr.Interface(fn=greet2, inputs="text", outputs="image")

# Combine both interfaces into one demo and start the web server.
demo = gr.Parallel(greeter_1, greeter_2)
demo.launch()
#iface = gr.Interface(fn=greet, inputs="text", outputs="image", title="Molecular Language Model as Multi-task Generator",
# )
#iface.launch()