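# ConotoxinFinder α7 regression: a Gradio app that reads protein (or lowercase
# nucleotide) sequences from an uploaded FASTA-style file, scores each one with
# a fine-tuned ESM-2 (facebook/esm2_t6_8M_UR50D) single-output regression head
# loaded from best_model.pth, and returns the predictions as a CSV file.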
import random
from collections import OrderedDict

import torch
import numpy as np
import pandas as pd
from Bio.Seq import Seq
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, set_seed
import gradio as gr
def setup_seed(seed):
    """Seed every RNG in use so results are reproducible."""
    set_seed(seed)  # transformers-level seeding
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

setup_seed(4)
# Run everything on CPU; the ESM-2 t6 (8M-parameter) checkpoint is small enough for CPU inference.
device = "cpu"
model_checkpoint = "facebook/esm2_t6_8M_UR50D"
config = AutoConfig.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
def conotoxinfinder(files):
    # Read sequences from the uploaded file, skipping FASTA header lines.
    seqs = []
    with open(files, 'r') as fr:
        for line in fr:
            if not line.startswith('>'):
                line = line.replace('\n', '').replace(' ', '')
                if not line:
                    continue  # skip blank lines
                if line.islower():
                    # Lowercase lines are treated as nucleotide sequences and
                    # translated to protein before scoring.
                    seqs.append(str(Seq(line).translate()))
                else:
                    seqs.append(line)

    # Load the fine-tuned ESM-2 regression head (single output value).
    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=1)
    model.load_state_dict(torch.load("best_model.pth", map_location=device))
    model = model.to(device)
    model.eval()

    # Score each sequence; the raw model output is exponentiated before reporting.
    value_all = []
    for seq in seqs:
        tokenizer_test = tokenizer(seq, return_tensors='pt').to(device)
        with torch.no_grad():
            value = model(**tokenizer_test)
        value_all.append(np.exp(value["logits"][0].item()))

    # Write the results to a CSV and return its path for Gradio to serve.
    summary = OrderedDict()
    summary['Seq'] = seqs
    summary['Value'] = value_all
    summary_df = pd.DataFrame(summary)
    summary_df.to_csv('output.csv', index=False)
    return 'output.csv'
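# Hypothetical local usage, bypassing the web UI (assumes a FASTA file named
# "example.fasta" in the working directory):
#   csv_path = conotoxinfinder("example.fasta")
#   print(pd.read_csv(csv_path))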
with open("conotoxinfinder.md", "r") as f:
description = f.read()
iface = gr.Interface(fn=conotoxinfinder,
title="ConotoxinFinder α7 regression",
inputs=["file"
],
outputs= "file",
description=description
)
iface.launch()