본문 바로가기
기타

[ChemBERTa][SMILES][Drug Discovery][신약개발]ChemBERTa를 활용한 분자 구조 예측 모델 개발

by Chandler.j 2024. 8. 26.
반응형

1. 개요

  • 프로젝트 목표: 이 프로젝트는 ChemBERTa 모델을 활용하여 분자 구조(SMILES 문자열)와 관련된 예측을 수행하는 것입니다. 주어진 데이터셋에서 분자 구조를 학습하고, 이를 바탕으로 IC50 값을 예측하는 모델을 개발합니다.

2. 주요 라이브러리 및 모델 로드

  • 사용된 라이브러리: transformers, torch, pandas, sklearn
  • 모델 및 토크나이저 로드:
    • 모델 이름: seyonec/ChemBERTa-zinc-base-v1
    • RobertaTokenizer와 RobertaForSequenceClassification을 사용하여 ChemBERTa 모델과 토크나이저를 로드합니다.
model_name = "seyonec/ChemBERTa-zinc-base-v1"
tokenizer = RobertaTokenizer.from_pretrained(model_name)
model = RobertaForSequenceClassification.from_pretrained(model_name, num_labels=1)

 


3. 데이터 로드 및 전처리

  • 데이터 파일: train.csv와 test.csv 파일에서 데이터를 불러옵니다.
  • 데이터 전처리:
    • SMILES 문자열을 토큰화하고, 동적 패딩을 적용합니다.
    • IC50 값은 별도의 스케일링 없이 그대로 사용합니다.
train_df = pd.read_csv('../data/train.csv')
test_df = pd.read_csv('../data/test.csv')

def tokenize_smiles(smiles):
    return tokenizer(smiles, padding=True, truncation=True, max_length=128, return_tensors='pt')

train_tokens = [tokenize_smiles(smiles) for smiles in train_df['Smiles']]
test_tokens = [tokenize_smiles(smiles) for smiles in test_df['Smiles']]

4. 데이터셋 클래스 정의

  • CustomDataset 클래스: 모델 학습에 사용할 Dataset 클래스를 정의하여, SMILES 토큰과 IC50 값을 포함합니다.
  • getitemlen 메서드: 데이터셋의 인덱싱 및 길이를 정의하여 모델 학습에 필요한 데이터를 제공합니다.
class CustomDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item

    def __len__(self):
        return len(self.labels)

5. 모델 학습 설정

  • TrainingArguments 설정: 학습 파라미터를 정의합니다.
    • 에폭 수: 5
    • 배치 사이즈: 16
    • 학습률: 2e-5
  • Trainer 객체 생성: 모델 학습을 위한 Trainer 객체를 설정합니다.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator
)

6. 모델 학습 및 평가

  • 모델 학습: Trainer 객체를 통해 모델을 학습시킵니다.
  • 모델 평가: 학습된 모델의 성능을 평가합니다.
trainer.train()
trainer.evaluate()

 


7. Full code and dataset

python_script.py
0.00MB
test.csv
0.01MB
train.csv
0.40MB


TOP

Designed by 티스토리